14
14
import operator
15
15
import random
16
16
import types
17
+ import warnings
17
18
from copy import copy , deepcopy
18
19
from dataclasses import dataclass
19
20
from typing import (
@@ -49,15 +50,25 @@ class RestrictedWorkflowAccessError(temporalio.workflow.NondeterminismError):
49
50
qualified_name: Fully qualified name of what was accessed.
50
51
"""
51
52
52
- def __init__ (self , qualified_name : str ) -> None :
53
def __init__(
    self, qualified_name: str, *, override_message: Optional[str] = None
) -> None:
    """Create restricted workflow access error.

    Args:
        qualified_name: Fully qualified name of what was accessed.
        override_message: Message to use in place of the default one. A
            falsy value (``None`` or an empty string) selects the default
            message built from ``qualified_name``.
    """
    if override_message:
        message = override_message
    else:
        message = RestrictedWorkflowAccessError.default_message(qualified_name)
    super().__init__(message)
    self.qualified_name = qualified_name
62
+
63
@staticmethod
def default_message(qualified_name: str) -> str:
    """Build the standard message for a restricted access.

    Args:
        qualified_name: Fully qualified name of what was accessed.

    Returns:
        Message stating that ``qualified_name`` cannot be accessed from
        inside a workflow and suggesting to mark the import as pass through.
    """
    parts = (
        f"Cannot access {qualified_name} from inside a workflow. ",
        "If this is code from a module not used in a workflow or known to ",
        "only be used deterministically from a workflow, mark the import ",
        "as pass through.",
    )
    return "".join(parts)
60
- self .qualified_name = qualified_name
61
72
62
73
63
74
@dataclass (frozen = True )
@@ -182,6 +193,20 @@ def nested_child(path: Sequence[str], child: SandboxMatcher) -> SandboxMatcher:
182
193
time.
183
194
"""
184
195
196
+ leaf_message : Optional [str ] = None
197
+ """
198
+ Override message to use in error/warning. Defaults to a common message.
199
+ This is only applicable to leafs, so this must only be set when
200
+ ``match_self`` is ``True`` and this matcher is on ``children`` of a parent.
201
+ """
202
+
203
+ leaf_warning : Optional [Type [Warning ]] = None
204
+ """
205
+ If set, issues a warning instead of raising an error. This is only
206
+ applicable to leafs, so this must only be set when ``match_self`` is
207
+ ``True`` and this matcher is on ``children`` of a parent.
208
+ """
209
+
185
210
all : ClassVar [SandboxMatcher ]
186
211
"""Shortcut for an always-matched matcher."""
187
212
@@ -197,40 +222,67 @@ def nested_child(path: Sequence[str], child: SandboxMatcher) -> SandboxMatcher:
197
222
all_uses_runtime : ClassVar [SandboxMatcher ]
198
223
"""Shortcut for a matcher that matches any :py:attr:`use` at runtime."""
199
224
200
- def match_access (
225
def __post_init__(self):
    """Validate the leaf-only settings after initialization.

    Raises:
        ValueError: If ``leaf_message`` or ``leaf_warning`` is set (truthy)
            while ``match_self`` is not.
    """
    # Leaf settings are always acceptable on a self-matching matcher
    if self.match_self:
        return
    if self.leaf_message:
        raise ValueError("Cannot set leaf_message without match_self")
    if self.leaf_warning:
        raise ValueError("Cannot set leaf_warning without match_self")
231
+
232
def access_matcher(
    self, context: RestrictionContext, *child_path: str, include_use: bool = False
) -> Optional[SandboxMatcher]:
    """Find the matcher, if any, that matches access to a child path.

    Args:
        context: Current restriction context.
        child_path: Full path to the child being accessed.
        include_use: Whether to include the :py:attr:`use` set in the check.

    Returns:
        The matcher that matched, or ``None`` if nothing matched.
    """
    # Walk the path iteratively instead of recursing
    current = self
    for segment in child_path:
        # A runtime-only matcher never matches outside of runtime
        if current.only_runtime and not context.is_runtime:
            return None

        # Matched when the matcher matches itself or the access set covers
        # the segment. The "use" set only takes part when asked for, because
        # by default something may be accessed as long as it is not used.
        matched = (
            current.match_self
            or segment in current.access
            or "*" in current.access
            or (include_use and (segment in current.use or "*" in current.use))
        )
        if matched:
            return current

        # Descend to the named child, falling back to the wildcard child
        next_matcher = current.children.get(segment)
        if not next_matcher:
            next_matcher = current.children.get("*")
        if not next_matcher:
            return None
        current = next_matcher

    # Path exhausted: only a self-matching matcher counts as a match
    if current.only_runtime and not context.is_runtime:
        return None
    return current if current.match_self else None
268
+
269
def match_access(
    self, context: RestrictionContext, *child_path: str, include_use: bool = False
) -> bool:
    """Check whether access to a child path is matched.

    Args:
        context: Current restriction context.
        child_path: Full path to the child being accessed.
        include_use: Whether to include the :py:attr:`use` set in the check.

    Returns:
        ``True`` if a matcher matched the path, ``False`` otherwise.
    """
    found = self.access_matcher(context, *child_path, include_use=include_use)
    return found is not None
234
286
235
287
def child_matcher (self , * child_path : str ) -> Optional [SandboxMatcher ]:
236
288
"""Return a child matcher for the given path.
@@ -273,6 +325,10 @@ def __or__(self, other: SandboxMatcher) -> SandboxMatcher:
273
325
"""Combine this matcher with another."""
274
326
if self .only_runtime != other .only_runtime :
275
327
raise ValueError ("Cannot combine only-runtime and non-only-runtime" )
328
+ if self .leaf_message != other .leaf_message :
329
+ raise ValueError ("Cannot combine different messages" )
330
+ if self .leaf_warning != other .leaf_warning :
331
+ raise ValueError ("Cannot combine different warning values" )
276
332
if self .match_self or other .match_self :
277
333
return SandboxMatcher .all
278
334
new_children = dict (self .children ) if self .children else {}
@@ -287,6 +343,8 @@ def __or__(self, other: SandboxMatcher) -> SandboxMatcher:
287
343
use = self .use | other .use ,
288
344
children = new_children ,
289
345
only_runtime = self .only_runtime ,
346
+ leaf_message = self .leaf_message ,
347
+ leaf_warning = self .leaf_warning ,
290
348
)
291
349
292
350
def with_child_unrestricted (self , * child_path : str ) -> SandboxMatcher :
@@ -457,6 +515,28 @@ def _public_callables(parent: Any, *, exclude: Set[str] = set()) -> Set[str]:
457
515
# rewriter
458
516
only_runtime = True ,
459
517
),
518
+ "asyncio" : SandboxMatcher (
519
+ children = {
520
+ "as_completed" : SandboxMatcher (
521
+ children = {
522
+ "__call__" : SandboxMatcher (
523
+ match_self = True ,
524
+ leaf_warning = UserWarning ,
525
+ leaf_message = "asyncio.as_completed() is non-deterministic, use workflow.as_completed() instead" ,
526
+ )
527
+ },
528
+ ),
529
+ "wait" : SandboxMatcher (
530
+ children = {
531
+ "__call__" : SandboxMatcher (
532
+ match_self = True ,
533
+ leaf_warning = UserWarning ,
534
+ leaf_message = "asyncio.wait() is non-deterministic, use workflow.wait() instead" ,
535
+ )
536
+ },
537
+ ),
538
+ }
539
+ ),
460
540
# TODO(cretz): Fix issues with class extensions on restricted proxy
461
541
# "argparse": SandboxMatcher.all_uses_runtime,
462
542
"bz2" : SandboxMatcher (use = {"open" }),
@@ -689,12 +769,23 @@ def from_proxy(v: _RestrictedProxy) -> _RestrictionState:
689
769
matcher : SandboxMatcher
690
770
691
771
def assert_child_not_restricted(self, name: str) -> None:
    """Raise or warn if accessing the named child is restricted.

    Does nothing when the sandbox is currently unrestricted or when no
    matcher matches ``name``. Otherwise the restriction is logged and then
    either a warning is issued (when the matched matcher has a
    ``leaf_warning``) or an error is raised.

    Args:
        name: Name of the child being accessed.

    Raises:
        RestrictedWorkflowAccessError: If the access matched and the matched
            matcher has no ``leaf_warning`` set.
    """
    if temporalio.workflow.unsafe.is_sandbox_unrestricted():
        return
    matched = self.matcher.access_matcher(self.context, name)
    if not matched:
        return
    logger.warning("%s on %s restricted", name, self.name)
    qualified_name = f"{self.name}.{name}"
    if not matched.leaf_warning:
        raise RestrictedWorkflowAccessError(
            qualified_name, override_message=matched.leaf_message
        )
    # Matcher is configured to warn instead of raising an error
    warnings.warn(
        matched.leaf_message
        or RestrictedWorkflowAccessError.default_message(qualified_name),
        matched.leaf_warning,
    )
698
789
699
790
def set_on_proxy (self , v : _RestrictedProxy ) -> None :
700
791
# To prevent recursion, must use __setattr__ on object to set the
0 commit comments