diff --git a/temporalio/worker/workflow_sandbox/_restrictions.py b/temporalio/worker/workflow_sandbox/_restrictions.py index 3796fd7aa..fdc126809 100644 --- a/temporalio/worker/workflow_sandbox/_restrictions.py +++ b/temporalio/worker/workflow_sandbox/_restrictions.py @@ -27,6 +27,7 @@ Optional, Sequence, Set, + Tuple, Type, TypeVar, cast, @@ -952,20 +953,20 @@ def r_op(obj: Any, other: Any) -> Any: return cast(_OpF, r_op) +_do_not_restrict: Tuple[Type, ...] = (bool, int, float, complex, str, bytes, bytearray) +if HAVE_PYDANTIC: + # The datetime validator in pydantic_core + # https://github.com/pydantic/pydantic-core/blob/741961c05847d9e9ee517cd783e24c2b58e5596b/src/input/input_python.rs#L548-L582 + # does some runtime type inspection that a RestrictedProxy instance + # fails. For this reason we do not restrict date/datetime instances when + # Pydantic is being used. Other restricted types such as pathlib.Path + # and uuid.UUID which are likely to be used in Pydantic model fields + # currently pass Pydantic's validation when wrapped by RestrictedProxy. + _do_not_restrict += (datetime.date,) # e.g. datetime.datetime + + def _is_restrictable(v: Any) -> bool: - return v is not None and not isinstance( - v, - ( - bool, - int, - float, - complex, - str, - bytes, - bytearray, - datetime.date, # e.g. datetime.datetime - ), - ) + return v is not None and not isinstance(v, _do_not_restrict) class _RestrictedProxy: diff --git a/tests/contrib/pydantic/test_pydantic.py b/tests/contrib/pydantic/test_pydantic.py index 26764b40f..c70eee56f 100644 --- a/tests/contrib/pydantic/test_pydantic.py +++ b/tests/contrib/pydantic/test_pydantic.py @@ -1,6 +1,8 @@ import dataclasses +import datetime +import os +import pathlib import uuid -from datetime import datetime import pydantic import pytest @@ -9,6 +11,11 @@ from temporalio.client import Client from temporalio.contrib.pydantic import pydantic_data_converter from temporalio.worker import Worker +from temporalio.worker.workflow_sandbox._restrictions import ( + RestrictionContext, + SandboxMatcher, + _RestrictedProxy, +) from tests.contrib.pydantic.models import ( PydanticModels, PydanticModelWithStrictField, @@ -103,7 +110,7 @@ async def test_round_trip_misc_objects(client: Client): {"7": 7.0}, [{"7": 7.0}], ({"7": 7.0},), - datetime(2025, 1, 2, 3, 4, 5), + datetime.datetime(2025, 1, 2, 3, 4, 5), uuid.uuid4(), ) @@ -262,7 +269,9 @@ async def test_datetime_usage_in_workflow(client: Client): def test_pydantic_model_with_strict_field_outside_sandbox(): _test_pydantic_model_with_strict_field( - PydanticModelWithStrictField(strict_field=datetime(2025, 1, 2, 3, 4, 5)) + PydanticModelWithStrictField( + strict_field=datetime.datetime(2025, 1, 2, 3, 4, 5) + ) ) @@ -276,7 +285,9 @@ async def test_pydantic_model_with_strict_field_inside_sandbox(client: Client): workflows=[PydanticModelWithStrictFieldWorkflow], task_queue=tq, ): - orig = PydanticModelWithStrictField(strict_field=datetime(2025, 1, 2, 3, 4, 5)) + orig = PydanticModelWithStrictField( + strict_field=datetime.datetime(2025, 1, 2, 3, 4, 5) + ) result = await client.execute_workflow( PydanticModelWithStrictFieldWorkflow.run, orig, @@ -324,3 +335,46 @@ async def test_validation_error(client: Client): task_queue=task_queue_name, result_type=tuple[int], ) + + +class RestrictedProxyFieldsModel(BaseModel): + path_field: pathlib.Path + uuid_field: uuid.UUID + datetime_field: datetime.datetime + + +def test_model_instantiation_from_restricted_proxy_values(): + restricted_path_cls = _RestrictedProxy( + "Path", + pathlib.Path, + RestrictionContext(), + SandboxMatcher(), + ) + restricted_uuid_cls = _RestrictedProxy( + "uuid", + uuid.UUID, + RestrictionContext(), + SandboxMatcher(), + ) + restricted_datetime_cls = _RestrictedProxy( + "datetime", + datetime.datetime, + RestrictionContext(), + SandboxMatcher(), + ) + + restricted_path = restricted_path_cls("test/path") + restricted_uuid = restricted_uuid_cls(bytes=os.urandom(16), version=4) + restricted_datetime = restricted_datetime_cls(2025, 1, 2, 3, 4, 5) + + assert type(restricted_path) is _RestrictedProxy + assert type(restricted_uuid) is _RestrictedProxy + assert type(restricted_datetime) is not _RestrictedProxy + p = RestrictedProxyFieldsModel( + path_field=restricted_path, # type: ignore + uuid_field=restricted_uuid, # type: ignore + datetime_field=restricted_datetime, # type: ignore + ) + assert p.path_field == restricted_path + assert p.uuid_field == restricted_uuid + assert p.datetime_field == restricted_datetime