Skip to content

Commit 8333900

Browse files
authored
Conditionally whitelist datetime.datetime and add tests (#767)
1 parent b3f3662 commit 8333900

File tree

2 files changed

+72
-17
lines changed

2 files changed

+72
-17
lines changed

temporalio/worker/workflow_sandbox/_restrictions.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
Optional,
2828
Sequence,
2929
Set,
30+
Tuple,
3031
Type,
3132
TypeVar,
3233
cast,
@@ -952,20 +953,20 @@ def r_op(obj: Any, other: Any) -> Any:
952953
return cast(_OpF, r_op)
953954

954955

956+
_do_not_restrict: Tuple[Type, ...] = (bool, int, float, complex, str, bytes, bytearray)
957+
if HAVE_PYDANTIC:
958+
# The datetime validator in pydantic_core
959+
# https://github.com/pydantic/pydantic-core/blob/741961c05847d9e9ee517cd783e24c2b58e5596b/src/input/input_python.rs#L548-L582
960+
# does some runtime type inspection that a RestrictedProxy instance
961+
# fails. For this reason we do not restrict date/datetime instances when
962+
# Pydantic is being used. Other restricted types such as pathlib.Path
963+
# and uuid.UUID which are likely to be used in Pydantic model fields
964+
# currently pass Pydantic's validation when wrapped by RestrictedProxy.
965+
_do_not_restrict += (datetime.date,) # e.g. datetime.datetime
966+
967+
955968
def _is_restrictable(v: Any) -> bool:
956-
return v is not None and not isinstance(
957-
v,
958-
(
959-
bool,
960-
int,
961-
float,
962-
complex,
963-
str,
964-
bytes,
965-
bytearray,
966-
datetime.date, # e.g. datetime.datetime
967-
),
968-
)
969+
return v is not None and not isinstance(v, _do_not_restrict)
969970

970971

971972
class _RestrictedProxy:

tests/contrib/pydantic/test_pydantic.py

Lines changed: 58 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import dataclasses
2+
import datetime
3+
import os
4+
import pathlib
25
import uuid
3-
from datetime import datetime
46

57
import pydantic
68
import pytest
@@ -9,6 +11,11 @@
911
from temporalio.client import Client
1012
from temporalio.contrib.pydantic import pydantic_data_converter
1113
from temporalio.worker import Worker
14+
from temporalio.worker.workflow_sandbox._restrictions import (
15+
RestrictionContext,
16+
SandboxMatcher,
17+
_RestrictedProxy,
18+
)
1219
from tests.contrib.pydantic.models import (
1320
PydanticModels,
1421
PydanticModelWithStrictField,
@@ -103,7 +110,7 @@ async def test_round_trip_misc_objects(client: Client):
103110
{"7": 7.0},
104111
[{"7": 7.0}],
105112
({"7": 7.0},),
106-
datetime(2025, 1, 2, 3, 4, 5),
113+
datetime.datetime(2025, 1, 2, 3, 4, 5),
107114
uuid.uuid4(),
108115
)
109116

@@ -262,7 +269,9 @@ async def test_datetime_usage_in_workflow(client: Client):
262269

263270
def test_pydantic_model_with_strict_field_outside_sandbox():
264271
_test_pydantic_model_with_strict_field(
265-
PydanticModelWithStrictField(strict_field=datetime(2025, 1, 2, 3, 4, 5))
272+
PydanticModelWithStrictField(
273+
strict_field=datetime.datetime(2025, 1, 2, 3, 4, 5)
274+
)
266275
)
267276

268277

@@ -276,7 +285,9 @@ async def test_pydantic_model_with_strict_field_inside_sandbox(client: Client):
276285
workflows=[PydanticModelWithStrictFieldWorkflow],
277286
task_queue=tq,
278287
):
279-
orig = PydanticModelWithStrictField(strict_field=datetime(2025, 1, 2, 3, 4, 5))
288+
orig = PydanticModelWithStrictField(
289+
strict_field=datetime.datetime(2025, 1, 2, 3, 4, 5)
290+
)
280291
result = await client.execute_workflow(
281292
PydanticModelWithStrictFieldWorkflow.run,
282293
orig,
@@ -324,3 +335,46 @@ async def test_validation_error(client: Client):
324335
task_queue=task_queue_name,
325336
result_type=tuple[int],
326337
)
338+
339+
340+
class RestrictedProxyFieldsModel(BaseModel):
341+
path_field: pathlib.Path
342+
uuid_field: uuid.UUID
343+
datetime_field: datetime.datetime
344+
345+
346+
def test_model_instantiation_from_restricted_proxy_values():
347+
restricted_path_cls = _RestrictedProxy(
348+
"Path",
349+
pathlib.Path,
350+
RestrictionContext(),
351+
SandboxMatcher(),
352+
)
353+
restricted_uuid_cls = _RestrictedProxy(
354+
"uuid",
355+
uuid.UUID,
356+
RestrictionContext(),
357+
SandboxMatcher(),
358+
)
359+
restricted_datetime_cls = _RestrictedProxy(
360+
"datetime",
361+
datetime.datetime,
362+
RestrictionContext(),
363+
SandboxMatcher(),
364+
)
365+
366+
restricted_path = restricted_path_cls("test/path")
367+
restricted_uuid = restricted_uuid_cls(bytes=os.urandom(16), version=4)
368+
restricted_datetime = restricted_datetime_cls(2025, 1, 2, 3, 4, 5)
369+
370+
assert type(restricted_path) is _RestrictedProxy
371+
assert type(restricted_uuid) is _RestrictedProxy
372+
assert type(restricted_datetime) is not _RestrictedProxy
373+
p = RestrictedProxyFieldsModel(
374+
path_field=restricted_path, # type: ignore
375+
uuid_field=restricted_uuid, # type: ignore
376+
datetime_field=restricted_datetime, # type: ignore
377+
)
378+
assert p.path_field == restricted_path
379+
assert p.uuid_field == restricted_uuid
380+
assert p.datetime_field == restricted_datetime

0 commit comments

Comments
 (0)