From bef66b56e7ed0695664d2e98af2395c7ee89290d Mon Sep 17 00:00:00 2001 From: Thomas Hardy Date: Thu, 17 Apr 2025 16:06:07 -0700 Subject: [PATCH 01/21] Init commit - waiting for core changes --- .vscode/settings.json | 7 +++ temporalio/activity.py | 1 + temporalio/bridge/src/client.rs | 4 +- temporalio/bridge/src/worker.rs | 1 + temporalio/bridge/worker.py | 5 +- temporalio/client.py | 11 +++++ temporalio/exceptions.py | 12 +++++ temporalio/worker/_activity.py | 1 + tests/conftest.py | 2 + tests/helpers/__init__.py | 18 +++++++ tests/worker/test_workflow.py | 88 ++++++++++++++++++++++++++++++++- 11 files changed, 145 insertions(+), 5 deletions(-) create mode 100644 .vscode/settings.json diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 000000000..98523efd5 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,7 @@ +{ + "rust-analyzer.linkedProjects": [ + // Add the path to the Cargo.toml file of your Rust project + // relative to the workspace root + "temporalio/bridge/sdk-core/Cargo.toml" + ] +} \ No newline at end of file diff --git a/temporalio/activity.py b/temporalio/activity.py index c67fa0f38..c0ba5b92c 100644 --- a/temporalio/activity.py +++ b/temporalio/activity.py @@ -266,6 +266,7 @@ def heartbeat(*details: Any) -> None: Raises: RuntimeError: When not in an activity. """ + print("PUBLIC HEARTBEAT FN") heartbeat_fn = _Context.current().heartbeat if not heartbeat_fn: raise RuntimeError("Can only execute heartbeat after interceptor init") diff --git a/temporalio/bridge/src/client.rs b/temporalio/bridge/src/client.rs index f5c0aa750..e4aeba3a2 100644 --- a/temporalio/bridge/src/client.rs +++ b/temporalio/bridge/src/client.rs @@ -231,10 +231,12 @@ impl ClientRef { } "list_workflow_rules" => { rpc_call!(retry_client, call, list_workflow_rules) - } "patch_schedule" => { rpc_call!(retry_client, call, patch_schedule) } + "pause_activity" => { + rpc_call!(retry_client, call, pause_activity) + } "poll_activity_task_queue" => { rpc_call!(retry_client, call, poll_activity_task_queue) } diff --git a/temporalio/bridge/src/worker.rs b/temporalio/bridge/src/worker.rs index 7034c60c2..57873a25c 100644 --- a/temporalio/bridge/src/worker.rs +++ b/temporalio/bridge/src/worker.rs @@ -601,6 +601,7 @@ impl WorkerRef { } fn record_activity_heartbeat(&self, proto: &PyBytes) -> PyResult<()> { + println!("IN BRIDGE - RUST"); enter_sync!(self.runtime); let heartbeat = ActivityHeartbeat::decode(proto.as_bytes()) .map_err(|err| PyValueError::new_err(format!("Invalid proto: {}", err)))?; diff --git a/temporalio/bridge/worker.py b/temporalio/bridge/worker.py index 24f5c1227..9439989dd 100644 --- a/temporalio/bridge/worker.py +++ b/temporalio/bridge/worker.py @@ -231,9 +231,10 @@ async def complete_activity_task( def record_activity_heartbeat( self, comp: temporalio.bridge.proto.ActivityHeartbeat - ) -> None: + ) -> temporalio.bridge.proto.RecordActivityHeartbeatResponse: """Record an activity heartbeat.""" - self._ref.record_activity_heartbeat(comp.SerializeToString()) + print("IN BRIDGE - PYTHON") + return self._ref.record_activity_heartbeat(comp.SerializeToString()) def request_workflow_eviction(self, run_id: str) -> None: """Request a workflow be evicted.""" diff --git a/temporalio/client.py b/temporalio/client.py index 4cd9d1f19..266d831d5 100644 --- a/temporalio/client.py +++ b/temporalio/client.py @@ -5149,6 +5149,12 @@ def __init__(self) -> None: """Create async activity cancelled error.""" super().__init__("Activity cancelled") +class AsyncActivityPausedError(temporalio.exceptions.TemporalError): + """Error that occurs when async activity attempted heartbeat but was paused.""" + + def __init__(self) -> None: + """Create async activity paused error.""" + super().__init__("Activity paused") class ScheduleAlreadyRunningError(temporalio.exceptions.TemporalError): """Error when a schedule is already running.""" @@ -6289,6 +6295,9 @@ async def heartbeat_async_activity( ) if resp_by_id.cancel_requested: raise AsyncActivityCancelledError() + if resp_by_id.activity_paused: + raise AsyncActivityPausedError() + else: resp = await self._client.workflow_service.record_activity_task_heartbeat( temporalio.api.workflowservice.v1.RecordActivityTaskHeartbeatRequest( @@ -6303,6 +6312,8 @@ async def heartbeat_async_activity( ) if resp.cancel_requested: raise AsyncActivityCancelledError() + if resp.activity_paused: + raise AsyncActivityPausedError() async def complete_async_activity(self, input: CompleteAsyncActivityInput) -> None: result = ( diff --git a/temporalio/exceptions.py b/temporalio/exceptions.py index f045b36a0..643b5e43a 100644 --- a/temporalio/exceptions.py +++ b/temporalio/exceptions.py @@ -155,6 +155,18 @@ def details(self) -> Sequence[Any]: """User-defined details on the error.""" return self._details +class ActivityPausedError(FailureError): + """Error raised on activity pause.""" + + def __init__(self, message: str = "Activity paused.", *details: Any) -> None: + """Initialize an activity paused error.""" + super().__init__(message) + self._details = details + + @property + def details(self) -> Sequence[Any]: + """User-defined details on the error.""" + return self._details class TerminatedError(FailureError): """Error raised on workflow cancellation.""" diff --git a/temporalio/worker/_activity.py b/temporalio/worker/_activity.py index f38a27e12..826994e7d 100644 --- a/temporalio/worker/_activity.py +++ b/temporalio/worker/_activity.py @@ -241,6 +241,7 @@ async def _heartbeat_async( activity: _RunningActivity, task_token: bytes, ) -> None: + print("HEARTBEAT ASYNC FN") # Drain the queue, only taking the last value to actually heartbeat details: Optional[Sequence[Any]] = None while not activity.pending_heartbeats.empty(): diff --git a/tests/conftest.py b/tests/conftest.py index be99e117f..37b1fe89c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -115,6 +115,8 @@ async def env(env_type: str) -> AsyncGenerator[WorkflowEnvironment, None]: "frontend.workerVersioningDataAPIs=true", "--dynamic-config-value", "system.enableDeploymentVersions=true", + "--dynamic-config-value", + "frontend.activityAPIsEnabled=true", ], dev_server_download_version=DEV_SERVER_DOWNLOAD_VERSION, ) diff --git a/tests/helpers/__init__.py b/tests/helpers/__init__.py index da5259748..6fdcd478b 100644 --- a/tests/helpers/__init__.py +++ b/tests/helpers/__init__.py @@ -14,6 +14,7 @@ ) from temporalio.api.update.v1 import UpdateRef from temporalio.api.workflowservice.v1 import PollWorkflowExecutionUpdateRequest +from temporalio.api.workflow.v1 import PendingActivityInfo from temporalio.client import BuildIdOpAddNewDefault, Client, WorkflowHandle from temporalio.common import SearchAttributeKey from temporalio.service import RPCError, RPCStatusCode @@ -210,3 +211,20 @@ async def check_workflow_exists() -> bool: await assert_eq_eventually(True, check_workflow_exists) assert handle is not None return handle + + +async def assert_pending_activity_exists_eventually( + handle: WorkflowHandle, + activity_id: str, + timeout: timedelta = timedelta(seconds=5), +) -> PendingActivityInfo: + """Wait until a pending activity with the given ID exists and return it.""" + async def check() -> Optional[PendingActivityInfo]: + desc = await handle.describe() + for act in desc.raw_description.pending_activities: + if act.activity_id == activity_id: + return act + raise AssertionError(f"Activity with ID {activity_id} not found in pending activities") + + activity_info = await assert_eventually(check, timeout=timeout) + return activity_info diff --git a/tests/worker/test_workflow.py b/tests/worker/test_workflow.py index 45bcf4a41..ad4bdd9d0 100644 --- a/tests/worker/test_workflow.py +++ b/tests/worker/test_workflow.py @@ -48,6 +48,7 @@ from temporalio.api.workflowservice.v1 import ( GetWorkflowExecutionHistoryRequest, ResetStickyTaskQueueRequest, + PauseActivityRequest ) from temporalio.bridge.proto.workflow_activation import WorkflowActivation from temporalio.bridge.proto.workflow_completion import WorkflowActivationCompletion @@ -92,6 +93,7 @@ TemporalError, TimeoutError, WorkflowAlreadyStartedError, + ActivityPausedError ) from temporalio.runtime import ( BUFFERED_METRIC_KIND_COUNTER, @@ -119,6 +121,7 @@ assert_task_fail_eventually, assert_workflow_exists_eventually, ensure_search_attributes_present, + assert_pending_activity_exists_eventually, find_free_port, new_worker, workflow_update_exists, @@ -133,7 +136,6 @@ with workflow.unsafe.imports_passed_through(): import pytest - @workflow.defn class HelloWorkflow: @workflow.run @@ -7481,7 +7483,6 @@ async def test_expose_root_execution(client: Client, env: WorkflowEnvironment): assert child_wf_info_root.workflow_id == parent_desc.id assert child_wf_info_root.run_id == parent_desc.run_id - @workflow.defn(dynamic=True) class WorkflowDynamicConfigFnFailure: @workflow.dynamic_config @@ -7622,3 +7623,86 @@ async def test_workflow_missing_local_activity_no_activities(client: Client): handle, message_contains="Activity function say_hello is not registered on this worker, no available activities", ) +@activity.defn +async def heartbeat_activity() -> str: + while True: + try: + activity.heartbeat() + await asyncio.sleep(1) + except ActivityPausedError as e: + return "Paused" + +@workflow.defn +class ActivityHeartbeatWorkflow: + @workflow.run + async def run(self, activity_id: str) -> str: + await workflow.execute_activity( + heartbeat_activity, + activity_id=activity_id, + start_to_close_timeout=timedelta(seconds=10), + heartbeat_timeout=timedelta(seconds=2), + retry_policy=RetryPolicy(maximum_attempts=1), + ) + + result = await workflow.execute_activity( + heartbeat_activity, + activity_id=f"{activity_id}-2", + start_to_close_timeout=timedelta(seconds=10), + heartbeat_timeout=timedelta(seconds=2), + retry_policy=RetryPolicy(maximum_attempts=1), + ) + return result + +async def test_activity_pause(client: Client, env: WorkflowEnvironment): + async def pause_and_assert( + client: Client, handle: WorkflowHandle, activity_id: str + ): + """Pause the given activity and assert it becomes paused.""" + desc = await handle.describe() + req = PauseActivityRequest( + namespace=client.namespace, + execution=WorkflowExecution( + workflow_id=desc.raw_description.workflow_execution_info.execution.workflow_id, + run_id=desc.raw_description.workflow_execution_info.execution.run_id, + ), + id=activity_id, + ) + await client.workflow_service.pause_activity(req) + + # Assert eventually paused + async def check_paused() -> bool: + info = await assert_pending_activity_exists_eventually(handle, activity_id) + return info.paused + + await assert_eventually(check_paused) + + async with new_worker( + client, ActivityHeartbeatWorkflow, activities=[heartbeat_activity] + ) as worker: + test_activity_id = f"heartbeat-activity-{uuid.uuid4()}" + + handle: WorkflowHandle[str] = await client.start_workflow( + ActivityHeartbeatWorkflow.run, + test_activity_id, + id=f"test-activity-pause-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + + # Wait for first activity + activity_info_1 = await assert_pending_activity_exists_eventually(handle, test_activity_id) + # Assert not paused + assert not activity_info_1.paused + # Pause activity then assert it is paused + await pause_and_assert(client, handle, activity_info_1.activity_id) + + # Wait for second activity + activity_info_2 = await assert_pending_activity_exists_eventually( + handle, f"{test_activity_id}-2" + ) + # # Assert not paused + # assert not activity_info_2.paused + # # Pause activity then assert it is paused + # await pause_and_assert(client, handle, activity_info_2.activity_id) + + # # Assert workflow returned "Paused" + # assert await handle.result() == "Paused" From 7fad0b4da90bf550f5ca8e2f70282be59f71671b Mon Sep 17 00:00:00 2001 From: Thomas Hardy Date: Wed, 30 Apr 2025 23:40:40 -0700 Subject: [PATCH 02/21] add activity paused error usage --- temporalio/activity.py | 1 - temporalio/bridge/src/worker.rs | 1 - temporalio/bridge/worker.py | 5 ++--- temporalio/worker/_activity.py | 37 ++++++++++++++++++++++++++++----- tests/helpers/__init__.py | 2 +- tests/worker/test_workflow.py | 16 +++++++------- 6 files changed, 43 insertions(+), 19 deletions(-) diff --git a/temporalio/activity.py b/temporalio/activity.py index c0ba5b92c..c67fa0f38 100644 --- a/temporalio/activity.py +++ b/temporalio/activity.py @@ -266,7 +266,6 @@ def heartbeat(*details: Any) -> None: Raises: RuntimeError: When not in an activity. """ - print("PUBLIC HEARTBEAT FN") heartbeat_fn = _Context.current().heartbeat if not heartbeat_fn: raise RuntimeError("Can only execute heartbeat after interceptor init") diff --git a/temporalio/bridge/src/worker.rs b/temporalio/bridge/src/worker.rs index 57873a25c..7034c60c2 100644 --- a/temporalio/bridge/src/worker.rs +++ b/temporalio/bridge/src/worker.rs @@ -601,7 +601,6 @@ impl WorkerRef { } fn record_activity_heartbeat(&self, proto: &PyBytes) -> PyResult<()> { - println!("IN BRIDGE - RUST"); enter_sync!(self.runtime); let heartbeat = ActivityHeartbeat::decode(proto.as_bytes()) .map_err(|err| PyValueError::new_err(format!("Invalid proto: {}", err)))?; diff --git a/temporalio/bridge/worker.py b/temporalio/bridge/worker.py index 9439989dd..24f5c1227 100644 --- a/temporalio/bridge/worker.py +++ b/temporalio/bridge/worker.py @@ -231,10 +231,9 @@ async def complete_activity_task( def record_activity_heartbeat( self, comp: temporalio.bridge.proto.ActivityHeartbeat - ) -> temporalio.bridge.proto.RecordActivityHeartbeatResponse: + ) -> None: """Record an activity heartbeat.""" - print("IN BRIDGE - PYTHON") - return self._ref.record_activity_heartbeat(comp.SerializeToString()) + self._ref.record_activity_heartbeat(comp.SerializeToString()) def request_workflow_eviction(self, run_id: str) -> None: """Request a workflow be evicted.""" diff --git a/temporalio/worker/_activity.py b/temporalio/worker/_activity.py index 826994e7d..61174ad47 100644 --- a/temporalio/worker/_activity.py +++ b/temporalio/worker/_activity.py @@ -216,7 +216,12 @@ def _cancel( warnings.warn(f"Cannot find activity to cancel for token {task_token!r}") return logger.debug("Cancelling activity %s, reason: %s", task_token, cancel.reason) - activity.cancel(cancelled_by_request=True) + activity.cancel( + cancelled_by_request=cancel.reason + == temporalio.bridge.proto.activity_task.ActivityCancelReason.CANCELLED, + cancelled_by_pause=cancel.reason + == temporalio.bridge.proto.activity_task.ActivityCancelReason.PAUSED, + ) def _heartbeat(self, task_token: bytes, *details: Any) -> None: # We intentionally make heartbeating non-async, but since the data @@ -241,7 +246,6 @@ async def _heartbeat_async( activity: _RunningActivity, task_token: bytes, ) -> None: - print("HEARTBEAT ASYNC FN") # Drain the queue, only taking the last value to actually heartbeat details: Optional[Sequence[Any]] = None while not activity.pending_heartbeats.empty(): @@ -317,6 +321,21 @@ async def _run_activity( temporalio.exceptions.CancelledError("Cancelled"), completion.result.cancelled.failure, ) + elif ( + isinstance( + err, + ( + asyncio.CancelledError, + temporalio.exceptions.ActivityPausedError, + ), + ) + and running_activity.cancelled_by_pause + ): + temporalio.activity.logger.debug("Completing as paused") + await self._data_converter.encode_failure( + temporalio.exceptions.ActivityPausedError("Activity paused"), + completion.result.cancelled.failure, + ) else: if ( isinstance( @@ -571,23 +590,31 @@ class _RunningActivity: done: bool = False cancelled_by_request: bool = False cancelled_due_to_heartbeat_error: Optional[Exception] = None + cancelled_by_pause: bool = False def cancel( self, *, cancelled_by_request: bool = False, cancelled_due_to_heartbeat_error: Optional[Exception] = None, + cancelled_by_pause: bool = False, ) -> None: self.cancelled_by_request = cancelled_by_request self.cancelled_due_to_heartbeat_error = cancelled_due_to_heartbeat_error + self.cancelled_by_pause = cancelled_by_pause if self.cancelled_event: self.cancelled_event.set() if not self.done: # If there's a thread raiser, use it if self.cancel_thread_raiser: - self.cancel_thread_raiser.raise_in_thread( - temporalio.exceptions.CancelledError - ) + if self.cancelled_by_pause: + self.cancel_thread_raiser.raise_in_thread( + temporalio.exceptions.ActivityPausedError + ) + else: + self.cancel_thread_raiser.raise_in_thread( + temporalio.exceptions.CancelledError + ) # If not sync and there's a task, cancel it if not self.sync and self.task: # TODO(cretz): Check that Python >= 3.9 and set msg? diff --git a/tests/helpers/__init__.py b/tests/helpers/__init__.py index 6fdcd478b..e99811125 100644 --- a/tests/helpers/__init__.py +++ b/tests/helpers/__init__.py @@ -227,4 +227,4 @@ async def check() -> Optional[PendingActivityInfo]: raise AssertionError(f"Activity with ID {activity_id} not found in pending activities") activity_info = await assert_eventually(check, timeout=timeout) - return activity_info + return cast(PendingActivityInfo, activity_info) diff --git a/tests/worker/test_workflow.py b/tests/worker/test_workflow.py index ad4bdd9d0..de39b51a3 100644 --- a/tests/worker/test_workflow.py +++ b/tests/worker/test_workflow.py @@ -7629,7 +7629,7 @@ async def heartbeat_activity() -> str: try: activity.heartbeat() await asyncio.sleep(1) - except ActivityPausedError as e: + except (ActivityPausedError, asyncio.CancelledError): return "Paused" @workflow.defn @@ -7681,7 +7681,7 @@ async def check_paused() -> bool: ) as worker: test_activity_id = f"heartbeat-activity-{uuid.uuid4()}" - handle: WorkflowHandle[str] = await client.start_workflow( + handle = await client.start_workflow( ActivityHeartbeatWorkflow.run, test_activity_id, id=f"test-activity-pause-{uuid.uuid4()}", @@ -7699,10 +7699,10 @@ async def check_paused() -> bool: activity_info_2 = await assert_pending_activity_exists_eventually( handle, f"{test_activity_id}-2" ) - # # Assert not paused - # assert not activity_info_2.paused - # # Pause activity then assert it is paused - # await pause_and_assert(client, handle, activity_info_2.activity_id) + # Assert not paused + assert not activity_info_2.paused + # Pause activity then assert it is paused + await pause_and_assert(client, handle, activity_info_2.activity_id) - # # Assert workflow returned "Paused" - # assert await handle.result() == "Paused" + # Assert workflow returned "Paused" + assert await handle.result() == "Paused" From 0539be0a7853bc63760e341620ef1e14f2b7afd3 Mon Sep 17 00:00:00 2001 From: Thomas Hardy Date: Sat, 3 May 2025 01:38:34 -0400 Subject: [PATCH 03/21] working for async activities --- temporalio/activity.py | 30 +++++++++++++++++ temporalio/client.py | 16 +++------ temporalio/exceptions.py | 12 ------- temporalio/worker/_activity.py | 47 +++++++++----------------- tests/worker/test_workflow.py | 60 +++++++++++++++++++++++----------- 5 files changed, 90 insertions(+), 75 deletions(-) diff --git a/temporalio/activity.py b/temporalio/activity.py index c67fa0f38..d8e028da0 100644 --- a/temporalio/activity.py +++ b/temporalio/activity.py @@ -34,6 +34,9 @@ overload, ) +import temporalio.bridge +import temporalio.bridge.proto +import temporalio.bridge.proto.activity_task import temporalio.common import temporalio.converter @@ -135,6 +138,27 @@ def _logger_details(self) -> Mapping[str, Any]: _current_context: contextvars.ContextVar[_Context] = contextvars.ContextVar("activity") +@dataclass +class ActivityCancellationDetails: + not_found: bool = False + cancelled: bool = False + paused: bool = False + timed_out: bool = False + worker_shutdown: bool = False + + @staticmethod + def fromProto( + proto: temporalio.bridge.proto.activity_task.ActivityCancellationDetails, + ) -> ActivityCancellationDetails: + return ActivityCancellationDetails( + not_found=proto.is_not_found, + cancelled=proto.is_cancelled, + paused=proto.is_paused, + timed_out=proto.is_timed_out, + worker_shutdown=proto.is_worker_shutdown, + ) + + @dataclass class _Context: info: Callable[[], Info] @@ -148,6 +172,7 @@ class _Context: temporalio.converter.PayloadConverter, ] runtime_metric_meter: Optional[temporalio.common.MetricMeter] + cancellation_details: Callable[[], Optional[ActivityCancellationDetails]] _logger_details: Optional[Mapping[str, Any]] = None _payload_converter: Optional[temporalio.converter.PayloadConverter] = None _metric_meter: Optional[temporalio.common.MetricMeter] = None @@ -260,6 +285,11 @@ def info() -> Info: return _Context.current().info() +def cancellation_details() -> Optional[ActivityCancellationDetails]: + """Cancellation details of the currenct activity, if any""" + return _Context.current().cancellation_details() + + def heartbeat(*details: Any) -> None: """Send a heartbeat for the current activity. diff --git a/temporalio/client.py b/temporalio/client.py index 266d831d5..375e1b4cc 100644 --- a/temporalio/client.py +++ b/temporalio/client.py @@ -5149,12 +5149,6 @@ def __init__(self) -> None: """Create async activity cancelled error.""" super().__init__("Activity cancelled") -class AsyncActivityPausedError(temporalio.exceptions.TemporalError): - """Error that occurs when async activity attempted heartbeat but was paused.""" - - def __init__(self) -> None: - """Create async activity paused error.""" - super().__init__("Activity paused") class ScheduleAlreadyRunningError(temporalio.exceptions.TemporalError): """Error when a schedule is already running.""" @@ -6293,10 +6287,9 @@ async def heartbeat_async_activity( metadata=input.rpc_metadata, timeout=input.rpc_timeout, ) - if resp_by_id.cancel_requested: + # TODO(thomas): modify activity context (if applicable to async activities) + if resp_by_id.cancel_requested or resp_by_id.activity_paused: raise AsyncActivityCancelledError() - if resp_by_id.activity_paused: - raise AsyncActivityPausedError() else: resp = await self._client.workflow_service.record_activity_task_heartbeat( @@ -6310,10 +6303,9 @@ async def heartbeat_async_activity( metadata=input.rpc_metadata, timeout=input.rpc_timeout, ) - if resp.cancel_requested: + # TODO(thomas): modify activity context (if applicable to async activities) + if resp.cancel_requested or resp.activity_paused: raise AsyncActivityCancelledError() - if resp.activity_paused: - raise AsyncActivityPausedError() async def complete_async_activity(self, input: CompleteAsyncActivityInput) -> None: result = ( diff --git a/temporalio/exceptions.py b/temporalio/exceptions.py index 643b5e43a..f045b36a0 100644 --- a/temporalio/exceptions.py +++ b/temporalio/exceptions.py @@ -155,18 +155,6 @@ def details(self) -> Sequence[Any]: """User-defined details on the error.""" return self._details -class ActivityPausedError(FailureError): - """Error raised on activity pause.""" - - def __init__(self, message: str = "Activity paused.", *details: Any) -> None: - """Initialize an activity paused error.""" - super().__init__(message) - self._details = details - - @property - def details(self) -> Sequence[Any]: - """User-defined details on the error.""" - return self._details class TerminatedError(FailureError): """Error raised on workflow cancellation.""" diff --git a/temporalio/worker/_activity.py b/temporalio/worker/_activity.py index 61174ad47..cfc47a5d7 100644 --- a/temporalio/worker/_activity.py +++ b/temporalio/worker/_activity.py @@ -216,12 +216,10 @@ def _cancel( warnings.warn(f"Cannot find activity to cancel for token {task_token!r}") return logger.debug("Cancelling activity %s, reason: %s", task_token, cancel.reason) - activity.cancel( - cancelled_by_request=cancel.reason - == temporalio.bridge.proto.activity_task.ActivityCancelReason.CANCELLED, - cancelled_by_pause=cancel.reason - == temporalio.bridge.proto.activity_task.ActivityCancelReason.PAUSED, + activity.cancellation_details = ( + temporalio.activity.ActivityCancellationDetails.fromProto(cancel.details) ) + activity.cancel(cancelled_by_request=True) def _heartbeat(self, task_token: bytes, *details: Any) -> None: # We intentionally make heartbeating non-async, but since the data @@ -321,21 +319,6 @@ async def _run_activity( temporalio.exceptions.CancelledError("Cancelled"), completion.result.cancelled.failure, ) - elif ( - isinstance( - err, - ( - asyncio.CancelledError, - temporalio.exceptions.ActivityPausedError, - ), - ) - and running_activity.cancelled_by_pause - ): - temporalio.activity.logger.debug("Completing as paused") - await self._data_converter.encode_failure( - temporalio.exceptions.ActivityPausedError("Activity paused"), - completion.result.cancelled.failure, - ) else: if ( isinstance( @@ -544,6 +527,8 @@ async def _execute_activity( else running_activity.cancel_thread_raiser.shielded, payload_converter_class_or_instance=self._data_converter.payload_converter, runtime_metric_meter=None if sync_non_threaded else self._metric_meter, + # Function reference to the running activity's cancellation details + cancellation_details=lambda: running_activity.cancellation_details, ) ) temporalio.activity.logger.debug("Starting activity") @@ -590,31 +575,26 @@ class _RunningActivity: done: bool = False cancelled_by_request: bool = False cancelled_due_to_heartbeat_error: Optional[Exception] = None - cancelled_by_pause: bool = False + cancellation_details: Optional[temporalio.activity.ActivityCancellationDetails] = ( + None + ) def cancel( self, *, cancelled_by_request: bool = False, cancelled_due_to_heartbeat_error: Optional[Exception] = None, - cancelled_by_pause: bool = False, ) -> None: self.cancelled_by_request = cancelled_by_request self.cancelled_due_to_heartbeat_error = cancelled_due_to_heartbeat_error - self.cancelled_by_pause = cancelled_by_pause if self.cancelled_event: self.cancelled_event.set() if not self.done: # If there's a thread raiser, use it if self.cancel_thread_raiser: - if self.cancelled_by_pause: - self.cancel_thread_raiser.raise_in_thread( - temporalio.exceptions.ActivityPausedError - ) - else: - self.cancel_thread_raiser.raise_in_thread( - temporalio.exceptions.CancelledError - ) + self.cancel_thread_raiser.raise_in_thread( + temporalio.exceptions.CancelledError + ) # If not sync and there's a task, cancel it if not self.sync and self.task: # TODO(cretz): Check that Python >= 3.9 and set msg? @@ -687,6 +667,7 @@ async def execute_activity(self, input: ExecuteActivityInput) -> Any: # can set the initializer on the executor). ctx = temporalio.activity._Context.current() info = ctx.info() + cancellation_details = ctx.cancellation_details # Heartbeat calls internally use a data converter which is async so # they need to be called on the event loop @@ -745,6 +726,7 @@ async def heartbeat_with_context(*details: Any) -> None: worker_shutdown_event.thread_event, payload_converter_class_or_instance, ctx.runtime_metric_meter, + cancellation_details, input.fn, *input.args, ] @@ -760,7 +742,6 @@ async def heartbeat_with_context(*details: Any) -> None: finally: if shared_manager: await shared_manager.unregister_heartbeater(info.task_token) - # Otherwise for async activity, just run return await input.fn(*input.args) @@ -792,6 +773,7 @@ def _execute_sync_activity( temporalio.converter.PayloadConverter, ], runtime_metric_meter: Optional[temporalio.common.MetricMeter], + cancellation_details: Callable[[], Optional[temporalio.activity.ActivityCancellationDetails]], fn: Callable[..., Any], *args: Any, ) -> Any: @@ -823,6 +805,7 @@ def _execute_sync_activity( else cancel_thread_raiser.shielded, payload_converter_class_or_instance=payload_converter_class_or_instance, runtime_metric_meter=runtime_metric_meter, + # cancellation_details=None ) ) return fn(*args) diff --git a/tests/worker/test_workflow.py b/tests/worker/test_workflow.py index de39b51a3..1682799b3 100644 --- a/tests/worker/test_workflow.py +++ b/tests/worker/test_workflow.py @@ -38,6 +38,7 @@ from google.protobuf.timestamp_pb2 import Timestamp from typing_extensions import Literal, Protocol, runtime_checkable +import temporalio.activity import temporalio.worker import temporalio.workflow from temporalio import activity, workflow @@ -7624,32 +7625,51 @@ async def test_workflow_missing_local_activity_no_activities(client: Client): message_contains="Activity function say_hello is not registered on this worker, no available activities", ) @activity.defn -async def heartbeat_activity() -> str: +async def heartbeat_activity() -> ( + Optional[temporalio.activity.ActivityCancellationDetails] +): + while True: + try: + activity.heartbeat() + await asyncio.sleep(1) + except (CancelledError, asyncio.CancelledError): + return activity.cancellation_details() + +@activity.defn +async def sync_heartbeat_activity() -> ( + Optional[temporalio.activity.ActivityCancellationDetails] +): while True: try: activity.heartbeat() await asyncio.sleep(1) - except (ActivityPausedError, asyncio.CancelledError): - return "Paused" + except (CancelledError, asyncio.CancelledError): + return activity.cancellation_details() @workflow.defn class ActivityHeartbeatWorkflow: @workflow.run - async def run(self, activity_id: str) -> str: - await workflow.execute_activity( - heartbeat_activity, - activity_id=activity_id, - start_to_close_timeout=timedelta(seconds=10), - heartbeat_timeout=timedelta(seconds=2), - retry_policy=RetryPolicy(maximum_attempts=1), + async def run( + self, activity_id: str + ) -> list[Optional[temporalio.activity.ActivityCancellationDetails]]: + result = [] + result.append( + await workflow.execute_activity( + sync_heartbeat_activity, + activity_id=activity_id, + start_to_close_timeout=timedelta(seconds=10), + heartbeat_timeout=timedelta(seconds=2), + retry_policy=RetryPolicy(maximum_attempts=1), + ) ) - - result = await workflow.execute_activity( - heartbeat_activity, - activity_id=f"{activity_id}-2", - start_to_close_timeout=timedelta(seconds=10), - heartbeat_timeout=timedelta(seconds=2), - retry_policy=RetryPolicy(maximum_attempts=1), + result.append( + await workflow.execute_activity( + heartbeat_activity, + activity_id=f"{activity_id}-2", + start_to_close_timeout=timedelta(seconds=10), + heartbeat_timeout=timedelta(seconds=2), + retry_policy=RetryPolicy(maximum_attempts=1), + ) ) return result @@ -7677,7 +7697,7 @@ async def check_paused() -> bool: await assert_eventually(check_paused) async with new_worker( - client, ActivityHeartbeatWorkflow, activities=[heartbeat_activity] + client, ActivityHeartbeatWorkflow, activities=[heartbeat_activity, sync_heartbeat_activity] ) as worker: test_activity_id = f"heartbeat-activity-{uuid.uuid4()}" @@ -7705,4 +7725,6 @@ async def check_paused() -> bool: await pause_and_assert(client, handle, activity_info_2.activity_id) # Assert workflow returned "Paused" - assert await handle.result() == "Paused" + result = await handle.result() + assert result[0] == temporalio.activity.ActivityCancellationDetails(paused=True) + assert result[1] == temporalio.activity.ActivityCancellationDetails(paused=True) From c87a9482126ad2af21cc24d6bc14aff9367f69b2 Mon Sep 17 00:00:00 2001 From: Thomas Hardy Date: Sat, 3 May 2025 01:44:39 -0400 Subject: [PATCH 04/21] working for sync activities --- temporalio/worker/_activity.py | 2 +- tests/worker/test_workflow.py | 70 +++++++++++++++++++--------------- 2 files changed, 40 insertions(+), 32 deletions(-) diff --git a/temporalio/worker/_activity.py b/temporalio/worker/_activity.py index cfc47a5d7..b486a42ed 100644 --- a/temporalio/worker/_activity.py +++ b/temporalio/worker/_activity.py @@ -805,7 +805,7 @@ def _execute_sync_activity( else cancel_thread_raiser.shielded, payload_converter_class_or_instance=payload_converter_class_or_instance, runtime_metric_meter=runtime_metric_meter, - # cancellation_details=None + cancellation_details=cancellation_details ) ) return fn(*args) diff --git a/tests/worker/test_workflow.py b/tests/worker/test_workflow.py index 1682799b3..87c27f587 100644 --- a/tests/worker/test_workflow.py +++ b/tests/worker/test_workflow.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import concurrent.futures import dataclasses import json import logging @@ -7636,13 +7637,13 @@ async def heartbeat_activity() -> ( return activity.cancellation_details() @activity.defn -async def sync_heartbeat_activity() -> ( +def sync_heartbeat_activity() -> ( Optional[temporalio.activity.ActivityCancellationDetails] ): while True: try: activity.heartbeat() - await asyncio.sleep(1) + time.sleep(1) except (CancelledError, asyncio.CancelledError): return activity.cancellation_details() @@ -7696,35 +7697,42 @@ async def check_paused() -> bool: await assert_eventually(check_paused) - async with new_worker( - client, ActivityHeartbeatWorkflow, activities=[heartbeat_activity, sync_heartbeat_activity] - ) as worker: - test_activity_id = f"heartbeat-activity-{uuid.uuid4()}" - - handle = await client.start_workflow( - ActivityHeartbeatWorkflow.run, - test_activity_id, - id=f"test-activity-pause-{uuid.uuid4()}", - task_queue=worker.task_queue, - ) + with concurrent.futures.ThreadPoolExecutor() as executor: + async with Worker( + client, + task_queue=str(uuid.uuid4()), + workflows=[ActivityHeartbeatWorkflow], + activities=[heartbeat_activity, sync_heartbeat_activity], + activity_executor=executor + ) as worker: + test_activity_id = f"heartbeat-activity-{uuid.uuid4()}" - # Wait for first activity - activity_info_1 = await assert_pending_activity_exists_eventually(handle, test_activity_id) - # Assert not paused - assert not activity_info_1.paused - # Pause activity then assert it is paused - await pause_and_assert(client, handle, activity_info_1.activity_id) + handle = await client.start_workflow( + ActivityHeartbeatWorkflow.run, + test_activity_id, + id=f"test-activity-pause-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) - # Wait for second activity - activity_info_2 = await assert_pending_activity_exists_eventually( - handle, f"{test_activity_id}-2" - ) - # Assert not paused - assert not activity_info_2.paused - # Pause activity then assert it is paused - await pause_and_assert(client, handle, activity_info_2.activity_id) + # Wait for first activity + activity_info_1 = await assert_pending_activity_exists_eventually( + handle, test_activity_id + ) + # Assert not paused + assert not activity_info_1.paused + # Pause activity then assert it is paused + await pause_and_assert(client, handle, activity_info_1.activity_id) + + # Wait for second activity + activity_info_2 = await assert_pending_activity_exists_eventually( + handle, f"{test_activity_id}-2" + ) + # Assert not paused + assert not activity_info_2.paused + # Pause activity then assert it is paused + await pause_and_assert(client, handle, activity_info_2.activity_id) - # Assert workflow returned "Paused" - result = await handle.result() - assert result[0] == temporalio.activity.ActivityCancellationDetails(paused=True) - assert result[1] == temporalio.activity.ActivityCancellationDetails(paused=True) + # Assert workflow returned "Paused" + result = await handle.result() + assert result[0] == temporalio.activity.ActivityCancellationDetails(paused=True) + assert result[1] == temporalio.activity.ActivityCancellationDetails(paused=True) From 6db508af1d243edb921aa906772afffb7f4d959a Mon Sep 17 00:00:00 2001 From: Thomas Hardy Date: Sat, 3 May 2025 02:28:33 -0400 Subject: [PATCH 05/21] linting and cleanup --- temporalio/activity.py | 4 +++- temporalio/client.py | 2 -- temporalio/testing/_activity.py | 2 ++ temporalio/worker/_activity.py | 8 +++++--- tests/worker/test_workflow.py | 12 +++++++++--- 5 files changed, 19 insertions(+), 9 deletions(-) diff --git a/temporalio/activity.py b/temporalio/activity.py index d8e028da0..638c65172 100644 --- a/temporalio/activity.py +++ b/temporalio/activity.py @@ -140,6 +140,8 @@ def _logger_details(self) -> Mapping[str, Any]: @dataclass class ActivityCancellationDetails: + """Provides the reasons for the activity's cancellation""" + not_found: bool = False cancelled: bool = False paused: bool = False @@ -147,7 +149,7 @@ class ActivityCancellationDetails: worker_shutdown: bool = False @staticmethod - def fromProto( + def _fromProto( proto: temporalio.bridge.proto.activity_task.ActivityCancellationDetails, ) -> ActivityCancellationDetails: return ActivityCancellationDetails( diff --git a/temporalio/client.py b/temporalio/client.py index 375e1b4cc..be1de4b5e 100644 --- a/temporalio/client.py +++ b/temporalio/client.py @@ -6287,7 +6287,6 @@ async def heartbeat_async_activity( metadata=input.rpc_metadata, timeout=input.rpc_timeout, ) - # TODO(thomas): modify activity context (if applicable to async activities) if resp_by_id.cancel_requested or resp_by_id.activity_paused: raise AsyncActivityCancelledError() @@ -6303,7 +6302,6 @@ async def heartbeat_async_activity( metadata=input.rpc_metadata, timeout=input.rpc_timeout, ) - # TODO(thomas): modify activity context (if applicable to async activities) if resp.cancel_requested or resp.activity_paused: raise AsyncActivityCancelledError() diff --git a/temporalio/testing/_activity.py b/temporalio/testing/_activity.py index 19dd3819b..4bdb025b4 100644 --- a/temporalio/testing/_activity.py +++ b/temporalio/testing/_activity.py @@ -74,6 +74,7 @@ def __init__(self) -> None: self._cancelled = False self._worker_shutdown = False self._activities: Set[_Activity] = set() + self.cancellation_details = None def cancel(self) -> None: """Cancel the activity. @@ -154,6 +155,7 @@ def __init__( else self.cancel_thread_raiser.shielded, payload_converter_class_or_instance=env.payload_converter, runtime_metric_meter=env.metric_meter, + cancellation_details=lambda: env.cancellation_details, ) self.task: Optional[asyncio.Task] = None diff --git a/temporalio/worker/_activity.py b/temporalio/worker/_activity.py index b486a42ed..57bbdde5f 100644 --- a/temporalio/worker/_activity.py +++ b/temporalio/worker/_activity.py @@ -217,7 +217,7 @@ def _cancel( return logger.debug("Cancelling activity %s, reason: %s", task_token, cancel.reason) activity.cancellation_details = ( - temporalio.activity.ActivityCancellationDetails.fromProto(cancel.details) + temporalio.activity.ActivityCancellationDetails._fromProto(cancel.details) ) activity.cancel(cancelled_by_request=True) @@ -773,7 +773,9 @@ def _execute_sync_activity( temporalio.converter.PayloadConverter, ], runtime_metric_meter: Optional[temporalio.common.MetricMeter], - cancellation_details: Callable[[], Optional[temporalio.activity.ActivityCancellationDetails]], + cancellation_details: Callable[ + [], Optional[temporalio.activity.ActivityCancellationDetails] + ], fn: Callable[..., Any], *args: Any, ) -> Any: @@ -805,7 +807,7 @@ def _execute_sync_activity( else cancel_thread_raiser.shielded, payload_converter_class_or_instance=payload_converter_class_or_instance, runtime_metric_meter=runtime_metric_meter, - cancellation_details=cancellation_details + cancellation_details=cancellation_details, ) ) return fn(*args) diff --git a/tests/worker/test_workflow.py b/tests/worker/test_workflow.py index 87c27f587..e7d72f0bf 100644 --- a/tests/worker/test_workflow.py +++ b/tests/worker/test_workflow.py @@ -7636,6 +7636,7 @@ async def heartbeat_activity() -> ( except (CancelledError, asyncio.CancelledError): return activity.cancellation_details() + @activity.defn def sync_heartbeat_activity() -> ( Optional[temporalio.activity.ActivityCancellationDetails] @@ -7647,6 +7648,7 @@ def sync_heartbeat_activity() -> ( except (CancelledError, asyncio.CancelledError): return activity.cancellation_details() + @workflow.defn class ActivityHeartbeatWorkflow: @workflow.run @@ -7703,7 +7705,7 @@ async def check_paused() -> bool: task_queue=str(uuid.uuid4()), workflows=[ActivityHeartbeatWorkflow], activities=[heartbeat_activity, sync_heartbeat_activity], - activity_executor=executor + activity_executor=executor, ) as worker: test_activity_id = f"heartbeat-activity-{uuid.uuid4()}" @@ -7734,5 +7736,9 @@ async def check_paused() -> bool: # Assert workflow returned "Paused" result = await handle.result() - assert result[0] == temporalio.activity.ActivityCancellationDetails(paused=True) - assert result[1] == temporalio.activity.ActivityCancellationDetails(paused=True) + assert result[0] == temporalio.activity.ActivityCancellationDetails( + paused=True + ) + assert result[1] == temporalio.activity.ActivityCancellationDetails( + paused=True + ) From 04c598c8884992c5c7a1e4051da461e18e5b6624 Mon Sep 17 00:00:00 2001 From: Thomas Hardy Date: Sat, 3 May 2025 02:33:25 -0400 Subject: [PATCH 06/21] add cancellation details arg to testing ActivityEnvironment cancel --- temporalio/testing/_activity.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/temporalio/testing/_activity.py b/temporalio/testing/_activity.py index 4bdb025b4..4d35bf716 100644 --- a/temporalio/testing/_activity.py +++ b/temporalio/testing/_activity.py @@ -74,16 +74,29 @@ def __init__(self) -> None: self._cancelled = False self._worker_shutdown = False self._activities: Set[_Activity] = set() - self.cancellation_details = None + self.cancellation_details: Optional[ + temporalio.activity.ActivityCancellationDetails + ] = None - def cancel(self) -> None: + def cancel( + self, + cancellation_details: Optional[ + temporalio.activity.ActivityCancellationDetails + ] = None, + ) -> None: """Cancel the activity. + Args: + cancellation_details: Optional details about the cancellation. When provided, these + will be accessible through temporalio.activity.cancellation_details() + in the activity after cancellation. + This only has an effect on the first call. """ if self._cancelled: return self._cancelled = True + self.cancellation_details = cancellation_details for act in self._activities: act.cancel() From 6f5104b5d5712a865ba0d4af3c5018c3306e237f Mon Sep 17 00:00:00 2001 From: Thomas Hardy Date: Sat, 3 May 2025 10:43:29 -0400 Subject: [PATCH 07/21] remove .vscode --- .vscode/settings.json | 7 ------- 1 file changed, 7 deletions(-) delete mode 100644 .vscode/settings.json diff --git a/.vscode/settings.json b/.vscode/settings.json deleted file mode 100644 index 98523efd5..000000000 --- a/.vscode/settings.json +++ /dev/null @@ -1,7 +0,0 @@ -{ - "rust-analyzer.linkedProjects": [ - // Add the path to the Cargo.toml file of your Rust project - // relative to the workspace root - "temporalio/bridge/sdk-core/Cargo.toml" - ] -} \ No newline at end of file From ea5b0554829a0404563146f9ef9cfb0483947758 Mon Sep 17 00:00:00 2001 From: Thomas Hardy Date: Sat, 3 May 2025 11:01:05 -0400 Subject: [PATCH 08/21] formatting --- tests/worker/test_workflow.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/worker/test_workflow.py b/tests/worker/test_workflow.py index e7d72f0bf..39fa2be8a 100644 --- a/tests/worker/test_workflow.py +++ b/tests/worker/test_workflow.py @@ -138,6 +138,7 @@ with workflow.unsafe.imports_passed_through(): import pytest + @workflow.defn class HelloWorkflow: @workflow.run @@ -7485,6 +7486,7 @@ async def test_expose_root_execution(client: Client, env: WorkflowEnvironment): assert child_wf_info_root.workflow_id == parent_desc.id assert child_wf_info_root.run_id == parent_desc.run_id + @workflow.defn(dynamic=True) class WorkflowDynamicConfigFnFailure: @workflow.dynamic_config From fcb7b33f0f77bac515f576a531603af32a45e7d0 Mon Sep 17 00:00:00 2001 From: Thomas Hardy Date: Sat, 3 May 2025 13:35:15 -0400 Subject: [PATCH 09/21] use object reference instead of function, picklable --- temporalio/activity.py | 4 ++-- temporalio/testing/_activity.py | 2 +- temporalio/worker/_activity.py | 18 +++++++----------- 3 files changed, 10 insertions(+), 14 deletions(-) diff --git a/temporalio/activity.py b/temporalio/activity.py index 638c65172..9cf7f757b 100644 --- a/temporalio/activity.py +++ b/temporalio/activity.py @@ -174,7 +174,7 @@ class _Context: temporalio.converter.PayloadConverter, ] runtime_metric_meter: Optional[temporalio.common.MetricMeter] - cancellation_details: Callable[[], Optional[ActivityCancellationDetails]] + cancellation_details: Optional[ActivityCancellationDetails] = None _logger_details: Optional[Mapping[str, Any]] = None _payload_converter: Optional[temporalio.converter.PayloadConverter] = None _metric_meter: Optional[temporalio.common.MetricMeter] = None @@ -289,7 +289,7 @@ def info() -> Info: def cancellation_details() -> Optional[ActivityCancellationDetails]: """Cancellation details of the currenct activity, if any""" - return _Context.current().cancellation_details() + return _Context.current().cancellation_details def heartbeat(*details: Any) -> None: diff --git a/temporalio/testing/_activity.py b/temporalio/testing/_activity.py index 4d35bf716..cb5b0ff9c 100644 --- a/temporalio/testing/_activity.py +++ b/temporalio/testing/_activity.py @@ -168,7 +168,7 @@ def __init__( else self.cancel_thread_raiser.shielded, payload_converter_class_or_instance=env.payload_converter, runtime_metric_meter=env.metric_meter, - cancellation_details=lambda: env.cancellation_details, + cancellation_details=env.cancellation_details, ) self.task: Optional[asyncio.Task] = None diff --git a/temporalio/worker/_activity.py b/temporalio/worker/_activity.py index 57bbdde5f..336352e09 100644 --- a/temporalio/worker/_activity.py +++ b/temporalio/worker/_activity.py @@ -15,7 +15,7 @@ import warnings from abc import ABC, abstractmethod from contextlib import contextmanager -from dataclasses import dataclass +from dataclasses import dataclass, field from datetime import datetime, timedelta, timezone from typing import ( Any, @@ -216,9 +216,8 @@ def _cancel( warnings.warn(f"Cannot find activity to cancel for token {task_token!r}") return logger.debug("Cancelling activity %s, reason: %s", task_token, cancel.reason) - activity.cancellation_details = ( - temporalio.activity.ActivityCancellationDetails._fromProto(cancel.details) - ) + activity.cancellation_details.cancelled = cancel.details.is_cancelled + activity.cancellation_details.paused = cancel.details.is_paused activity.cancel(cancelled_by_request=True) def _heartbeat(self, task_token: bytes, *details: Any) -> None: @@ -527,8 +526,7 @@ async def _execute_activity( else running_activity.cancel_thread_raiser.shielded, payload_converter_class_or_instance=self._data_converter.payload_converter, runtime_metric_meter=None if sync_non_threaded else self._metric_meter, - # Function reference to the running activity's cancellation details - cancellation_details=lambda: running_activity.cancellation_details, + cancellation_details=running_activity.cancellation_details, ) ) temporalio.activity.logger.debug("Starting activity") @@ -575,8 +573,8 @@ class _RunningActivity: done: bool = False cancelled_by_request: bool = False cancelled_due_to_heartbeat_error: Optional[Exception] = None - cancellation_details: Optional[temporalio.activity.ActivityCancellationDetails] = ( - None + cancellation_details: temporalio.activity.ActivityCancellationDetails = field( + default_factory=temporalio.activity.ActivityCancellationDetails ) def cancel( @@ -773,9 +771,7 @@ def _execute_sync_activity( temporalio.converter.PayloadConverter, ], runtime_metric_meter: Optional[temporalio.common.MetricMeter], - cancellation_details: Callable[ - [], Optional[temporalio.activity.ActivityCancellationDetails] - ], + cancellation_details: Optional[temporalio.activity.ActivityCancellationDetails], fn: Callable[..., Any], *args: Any, ) -> Any: From 7c0556cfa0fd35920ee3a8febb47c4daa6bd64a5 Mon Sep 17 00:00:00 2001 From: Thomas Hardy Date: Mon, 5 May 2025 11:23:45 -0400 Subject: [PATCH 10/21] nits --- temporalio/activity.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/temporalio/activity.py b/temporalio/activity.py index 9cf7f757b..6091e24ec 100644 --- a/temporalio/activity.py +++ b/temporalio/activity.py @@ -143,7 +143,7 @@ class ActivityCancellationDetails: """Provides the reasons for the activity's cancellation""" not_found: bool = False - cancelled: bool = False + cancelled_requested: bool = False paused: bool = False timed_out: bool = False worker_shutdown: bool = False @@ -154,7 +154,7 @@ def _fromProto( ) -> ActivityCancellationDetails: return ActivityCancellationDetails( not_found=proto.is_not_found, - cancelled=proto.is_cancelled, + cancelled_requested=proto.is_cancelled, paused=proto.is_paused, timed_out=proto.is_timed_out, worker_shutdown=proto.is_worker_shutdown, @@ -288,7 +288,7 @@ def info() -> Info: def cancellation_details() -> Optional[ActivityCancellationDetails]: - """Cancellation details of the currenct activity, if any""" + """Cancellation details of the current activity, if any""" return _Context.current().cancellation_details From 5d3339b72ce4ee44d34d347f17b14ca0fe5acca6 Mon Sep 17 00:00:00 2001 From: Thomas Hardy Date: Mon, 5 May 2025 12:43:37 -0400 Subject: [PATCH 11/21] use holder --- temporalio/activity.py | 18 +++++++++++++++--- temporalio/testing/_activity.py | 9 +++++---- temporalio/worker/_activity.py | 11 ++++++----- 3 files changed, 26 insertions(+), 12 deletions(-) diff --git a/temporalio/activity.py b/temporalio/activity.py index 6091e24ec..9b469f062 100644 --- a/temporalio/activity.py +++ b/temporalio/activity.py @@ -139,6 +139,18 @@ def _logger_details(self) -> Mapping[str, Any]: @dataclass +class _ActivityCancellationDetailsHolder: + _details: Optional[ActivityCancellationDetails] = None + + def set_details(self, details: ActivityCancellationDetails) -> None: + self._details = details + + @property + def details(self) -> Optional[ActivityCancellationDetails]: + return self._details + + +@dataclass(frozen=True) class ActivityCancellationDetails: """Provides the reasons for the activity's cancellation""" @@ -149,7 +161,7 @@ class ActivityCancellationDetails: worker_shutdown: bool = False @staticmethod - def _fromProto( + def _from_proto( proto: temporalio.bridge.proto.activity_task.ActivityCancellationDetails, ) -> ActivityCancellationDetails: return ActivityCancellationDetails( @@ -174,7 +186,7 @@ class _Context: temporalio.converter.PayloadConverter, ] runtime_metric_meter: Optional[temporalio.common.MetricMeter] - cancellation_details: Optional[ActivityCancellationDetails] = None + cancellation_details: _ActivityCancellationDetailsHolder _logger_details: Optional[Mapping[str, Any]] = None _payload_converter: Optional[temporalio.converter.PayloadConverter] = None _metric_meter: Optional[temporalio.common.MetricMeter] = None @@ -289,7 +301,7 @@ def info() -> Info: def cancellation_details() -> Optional[ActivityCancellationDetails]: """Cancellation details of the current activity, if any""" - return _Context.current().cancellation_details + return _Context.current().cancellation_details.details def heartbeat(*details: Any) -> None: diff --git a/temporalio/testing/_activity.py b/temporalio/testing/_activity.py index cb5b0ff9c..08a467af6 100644 --- a/temporalio/testing/_activity.py +++ b/temporalio/testing/_activity.py @@ -74,9 +74,9 @@ def __init__(self) -> None: self._cancelled = False self._worker_shutdown = False self._activities: Set[_Activity] = set() - self.cancellation_details: Optional[ - temporalio.activity.ActivityCancellationDetails - ] = None + self.cancellation_details: ( + temporalio.activity._ActivityCancellationDetailsHolder + ) def cancel( self, @@ -96,7 +96,8 @@ def cancel( if self._cancelled: return self._cancelled = True - self.cancellation_details = cancellation_details + if cancellation_details: + self.cancellation_details.set_details(cancellation_details) for act in self._activities: act.cancel() diff --git a/temporalio/worker/_activity.py b/temporalio/worker/_activity.py index 336352e09..8d659c121 100644 --- a/temporalio/worker/_activity.py +++ b/temporalio/worker/_activity.py @@ -216,8 +216,9 @@ def _cancel( warnings.warn(f"Cannot find activity to cancel for token {task_token!r}") return logger.debug("Cancelling activity %s, reason: %s", task_token, cancel.reason) - activity.cancellation_details.cancelled = cancel.details.is_cancelled - activity.cancellation_details.paused = cancel.details.is_paused + activity.cancellation_details.set_details( + temporalio.activity.ActivityCancellationDetails._from_proto(cancel.details) + ) activity.cancel(cancelled_by_request=True) def _heartbeat(self, task_token: bytes, *details: Any) -> None: @@ -573,8 +574,8 @@ class _RunningActivity: done: bool = False cancelled_by_request: bool = False cancelled_due_to_heartbeat_error: Optional[Exception] = None - cancellation_details: temporalio.activity.ActivityCancellationDetails = field( - default_factory=temporalio.activity.ActivityCancellationDetails + cancellation_details: temporalio.activity._ActivityCancellationDetailsHolder = ( + field(default_factory=temporalio.activity._ActivityCancellationDetailsHolder) ) def cancel( @@ -771,7 +772,7 @@ def _execute_sync_activity( temporalio.converter.PayloadConverter, ], runtime_metric_meter: Optional[temporalio.common.MetricMeter], - cancellation_details: Optional[temporalio.activity.ActivityCancellationDetails], + cancellation_details: temporalio.activity._ActivityCancellationDetailsHolder, fn: Callable[..., Any], *args: Any, ) -> Any: From 330dd414f0e9a63e0fb8dfe0388983a225fdfc9d Mon Sep 17 00:00:00 2001 From: Thomas Hardy Date: Mon, 5 May 2025 12:46:55 -0400 Subject: [PATCH 12/21] docstrings --- temporalio/activity.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/temporalio/activity.py b/temporalio/activity.py index 9b469f062..6c3ab7567 100644 --- a/temporalio/activity.py +++ b/temporalio/activity.py @@ -152,7 +152,7 @@ def details(self) -> Optional[ActivityCancellationDetails]: @dataclass(frozen=True) class ActivityCancellationDetails: - """Provides the reasons for the activity's cancellation""" + """Provides the reasons for the activity's cancellation. Cancellation details are set once and do not change once set.""" not_found: bool = False cancelled_requested: bool = False @@ -300,7 +300,7 @@ def info() -> Info: def cancellation_details() -> Optional[ActivityCancellationDetails]: - """Cancellation details of the current activity, if any""" + """Cancellation details of the current activity, if any. Once set, cancellation details do not change.""" return _Context.current().cancellation_details.details From 21c4e6e3308184a45e256d30f029c5adafee6ec2 Mon Sep 17 00:00:00 2001 From: Thomas Hardy Date: Tue, 6 May 2025 13:04:51 -0400 Subject: [PATCH 13/21] add test for pause/unpause --- temporalio/activity.py | 13 +-- temporalio/bridge/src/client.rs | 5 ++ temporalio/testing/_activity.py | 8 +- temporalio/worker/_activity.py | 22 ++++- tests/helpers/__init__.py | 48 ++++++++++- tests/worker/test_workflow.py | 139 +++++++++++++++++++++++++------- 6 files changed, 187 insertions(+), 48 deletions(-) diff --git a/temporalio/activity.py b/temporalio/activity.py index 6c3ab7567..4a0914bc2 100644 --- a/temporalio/activity.py +++ b/temporalio/activity.py @@ -140,14 +140,7 @@ def _logger_details(self) -> Mapping[str, Any]: @dataclass class _ActivityCancellationDetailsHolder: - _details: Optional[ActivityCancellationDetails] = None - - def set_details(self, details: ActivityCancellationDetails) -> None: - self._details = details - - @property - def details(self) -> Optional[ActivityCancellationDetails]: - return self._details + details: Optional[ActivityCancellationDetails] = None @dataclass(frozen=True) @@ -155,7 +148,7 @@ class ActivityCancellationDetails: """Provides the reasons for the activity's cancellation. Cancellation details are set once and do not change once set.""" not_found: bool = False - cancelled_requested: bool = False + cancel_requested: bool = False paused: bool = False timed_out: bool = False worker_shutdown: bool = False @@ -166,7 +159,7 @@ def _from_proto( ) -> ActivityCancellationDetails: return ActivityCancellationDetails( not_found=proto.is_not_found, - cancelled_requested=proto.is_cancelled, + cancel_requested=proto.is_cancelled, paused=proto.is_paused, timed_out=proto.is_timed_out, worker_shutdown=proto.is_worker_shutdown, diff --git a/temporalio/bridge/src/client.rs b/temporalio/bridge/src/client.rs index e4aeba3a2..594eaf22f 100644 --- a/temporalio/bridge/src/client.rs +++ b/temporalio/bridge/src/client.rs @@ -324,8 +324,13 @@ impl ClientRef { "terminate_workflow_execution" => { rpc_call!(retry_client, call, terminate_workflow_execution) } +<<<<<<< HEAD "trigger_workflow_rule" => { rpc_call!(retry_client, call, trigger_workflow_rule) +======= + "unpause_activity" => { + rpc_call!(retry_client, call, unpause_activity) +>>>>>>> 7e4b2d9 (add test for pause/unpause) } "update_namespace" => { rpc_call_on_trait!(retry_client, call, WorkflowService, update_namespace) diff --git a/temporalio/testing/_activity.py b/temporalio/testing/_activity.py index 08a467af6..700bdd102 100644 --- a/temporalio/testing/_activity.py +++ b/temporalio/testing/_activity.py @@ -74,8 +74,8 @@ def __init__(self) -> None: self._cancelled = False self._worker_shutdown = False self._activities: Set[_Activity] = set() - self.cancellation_details: ( - temporalio.activity._ActivityCancellationDetailsHolder + self._cancellation_details = ( + temporalio.activity._ActivityCancellationDetailsHolder() ) def cancel( @@ -97,7 +97,7 @@ def cancel( return self._cancelled = True if cancellation_details: - self.cancellation_details.set_details(cancellation_details) + self._cancellation_details.details = cancellation_details for act in self._activities: act.cancel() @@ -169,7 +169,7 @@ def __init__( else self.cancel_thread_raiser.shielded, payload_converter_class_or_instance=env.payload_converter, runtime_metric_meter=env.metric_meter, - cancellation_details=env.cancellation_details, + cancellation_details=env._cancellation_details, ) self.task: Optional[asyncio.Task] = None diff --git a/temporalio/worker/_activity.py b/temporalio/worker/_activity.py index 8d659c121..6a5297661 100644 --- a/temporalio/worker/_activity.py +++ b/temporalio/worker/_activity.py @@ -216,10 +216,10 @@ def _cancel( warnings.warn(f"Cannot find activity to cancel for token {task_token!r}") return logger.debug("Cancelling activity %s, reason: %s", task_token, cancel.reason) - activity.cancellation_details.set_details( + activity.cancellation_details.details = ( temporalio.activity.ActivityCancellationDetails._from_proto(cancel.details) ) - activity.cancel(cancelled_by_request=True) + activity.cancel(cancelled_by_request=cancel.details.is_cancelled) def _heartbeat(self, task_token: bytes, *details: Any) -> None: # We intentionally make heartbeating non-async, but since the data @@ -306,6 +306,23 @@ async def _run_activity( await self._data_converter.encode_failure( err, completion.result.failed.failure ) + elif ( + isinstance( + err, + (asyncio.CancelledError, temporalio.exceptions.CancelledError), + ) + and running_activity.cancellation_details.details + and running_activity.cancellation_details.details.paused + ): + temporalio.activity.logger.warning( + f"Completing as failure due to unhandled cancel error produced by activity pause", + ) + await self._data_converter.encode_failure( + temporalio.exceptions.ApplicationError( + "Unhandled activity cancel error produced by activity pause" + ), + completion.result.failed.failure, + ) elif ( isinstance( err, @@ -339,7 +356,6 @@ async def _run_activity( await self._data_converter.encode_failure( err, completion.result.failed.failure ) - # For broken executors, we have to fail the entire worker if isinstance(err, concurrent.futures.BrokenExecutor): self._fail_worker_exception_queue.put_nowait(err) diff --git a/tests/helpers/__init__.py b/tests/helpers/__init__.py index e99811125..dd3f49f26 100644 --- a/tests/helpers/__init__.py +++ b/tests/helpers/__init__.py @@ -13,8 +13,12 @@ ListSearchAttributesRequest, ) from temporalio.api.update.v1 import UpdateRef -from temporalio.api.workflowservice.v1 import PollWorkflowExecutionUpdateRequest from temporalio.api.workflow.v1 import PendingActivityInfo +from temporalio.api.workflowservice.v1 import ( + PauseActivityRequest, + PollWorkflowExecutionUpdateRequest, + UnpauseActivityRequest, +) from temporalio.client import BuildIdOpAddNewDefault, Client, WorkflowHandle from temporalio.common import SearchAttributeKey from temporalio.service import RPCError, RPCStatusCode @@ -228,3 +232,45 @@ async def check() -> Optional[PendingActivityInfo]: activity_info = await assert_eventually(check, timeout=timeout) return cast(PendingActivityInfo, activity_info) + + +async def pause_and_assert(client: Client, handle: WorkflowHandle, activity_id: str): + """Pause the given activity and assert it becomes paused.""" + desc = await handle.describe() + req = PauseActivityRequest( + namespace=client.namespace, + execution=WorkflowExecution( + workflow_id=desc.raw_description.workflow_execution_info.execution.workflow_id, + run_id=desc.raw_description.workflow_execution_info.execution.run_id, + ), + id=activity_id, + ) + await client.workflow_service.pause_activity(req) + + # Assert eventually paused + async def check_paused() -> bool: + info = await assert_pending_activity_exists_eventually(handle, activity_id) + return info.paused + + await assert_eventually(check_paused) + + +async def unpause_and_assert(client: Client, handle: WorkflowHandle, activity_id: str): + """Unpause the given activity and assert it is not paused.""" + desc = await handle.describe() + req = UnpauseActivityRequest( + namespace=client.namespace, + execution=WorkflowExecution( + workflow_id=desc.raw_description.workflow_execution_info.execution.workflow_id, + run_id=desc.raw_description.workflow_execution_info.execution.run_id, + ), + id=activity_id, + ) + await client.workflow_service.unpause_activity(req) + + # Assert eventually not paused + async def check_unpaused() -> bool: + info = await assert_pending_activity_exists_eventually(handle, activity_id) + return not info.paused + + await assert_eventually(check_unpaused) \ No newline at end of file diff --git a/tests/worker/test_workflow.py b/tests/worker/test_workflow.py index 39fa2be8a..e40bc613a 100644 --- a/tests/worker/test_workflow.py +++ b/tests/worker/test_workflow.py @@ -126,6 +126,8 @@ assert_pending_activity_exists_eventually, find_free_port, new_worker, + pause_and_assert, + unpause_and_assert, workflow_update_exists, ) from tests.helpers.external_stack_trace import ( @@ -7627,27 +7629,42 @@ async def test_workflow_missing_local_activity_no_activities(client: Client): handle, message_contains="Activity function say_hello is not registered on this worker, no available activities", ) + + @activity.defn async def heartbeat_activity() -> ( Optional[temporalio.activity.ActivityCancellationDetails] ): +async def heartbeat_activity( + catch_err: bool = True, +) -> Optional[temporalio.activity.ActivityCancellationDetails]: while True: try: activity.heartbeat() + # If we are on the second attempt, we have retried due to pause/unpause. + if activity.info().attempt > 1: + return activity.cancellation_details() await asyncio.sleep(1) - except (CancelledError, asyncio.CancelledError): + except (CancelledError, asyncio.CancelledError) as err: + if not catch_err: + raise err return activity.cancellation_details() @activity.defn -def sync_heartbeat_activity() -> ( - Optional[temporalio.activity.ActivityCancellationDetails] -): +def sync_heartbeat_activity( + catch_err: bool = True, +) -> Optional[temporalio.activity.ActivityCancellationDetails]: while True: try: activity.heartbeat() + # If we are on the second attempt, we have retried due to pause/unpause. + if activity.info().attempt > 1: + return activity.cancellation_details() time.sleep(1) - except (CancelledError, asyncio.CancelledError): + except (CancelledError, asyncio.CancelledError) as err: + if not catch_err: + raise err return activity.cancellation_details() @@ -7678,29 +7695,8 @@ async def run( ) return result -async def test_activity_pause(client: Client, env: WorkflowEnvironment): - async def pause_and_assert( - client: Client, handle: WorkflowHandle, activity_id: str - ): - """Pause the given activity and assert it becomes paused.""" - desc = await handle.describe() - req = PauseActivityRequest( - namespace=client.namespace, - execution=WorkflowExecution( - workflow_id=desc.raw_description.workflow_execution_info.execution.workflow_id, - run_id=desc.raw_description.workflow_execution_info.execution.run_id, - ), - id=activity_id, - ) - await client.workflow_service.pause_activity(req) - - # Assert eventually paused - async def check_paused() -> bool: - info = await assert_pending_activity_exists_eventually(handle, activity_id) - return info.paused - - await assert_eventually(check_paused) +async def test_activity_pause_cancellation_details(client: Client): with concurrent.futures.ThreadPoolExecutor() as executor: async with Worker( client, @@ -7718,7 +7714,7 @@ async def check_paused() -> bool: task_queue=worker.task_queue, ) - # Wait for first activity + # Wait for sync activity activity_info_1 = await assert_pending_activity_exists_eventually( handle, test_activity_id ) @@ -7727,7 +7723,7 @@ async def check_paused() -> bool: # Pause activity then assert it is paused await pause_and_assert(client, handle, activity_info_1.activity_id) - # Wait for second activity + # Wait for async activity activity_info_2 = await assert_pending_activity_exists_eventually( handle, f"{test_activity_id}-2" ) @@ -7736,7 +7732,8 @@ async def check_paused() -> bool: # Pause activity then assert it is paused await pause_and_assert(client, handle, activity_info_2.activity_id) - # Assert workflow returned "Paused" + # Assert workflow return value for paused activities that caught the + # cancel error result = await handle.result() assert result[0] == temporalio.activity.ActivityCancellationDetails( paused=True @@ -7744,3 +7741,85 @@ async def check_paused() -> bool: assert result[1] == temporalio.activity.ActivityCancellationDetails( paused=True ) + + +@workflow.defn +class ActivityHeartbeatPauseUnpauseWorkflow: + @workflow.run + async def run( + self, activity_id: str + ) -> list[Optional[temporalio.activity.ActivityCancellationDetails]]: + results = [] + results.append( + await workflow.execute_activity( + sync_heartbeat_activity, + False, + activity_id=activity_id, + start_to_close_timeout=timedelta(seconds=10), + heartbeat_timeout=timedelta(seconds=2), + retry_policy=RetryPolicy(maximum_attempts=2), + ) + ) + results.append( + await workflow.execute_activity( + heartbeat_activity, + False, + activity_id=f"{activity_id}-2", + start_to_close_timeout=timedelta(seconds=10), + heartbeat_timeout=timedelta(seconds=2), + retry_policy=RetryPolicy(maximum_attempts=2), + ) + ) + return results + + +async def test_activity_pause_unpause(client: Client): + with concurrent.futures.ThreadPoolExecutor() as executor: + async with Worker( + client, + task_queue=str(uuid.uuid4()), + workflows=[ActivityHeartbeatPauseUnpauseWorkflow], + activities=[heartbeat_activity, sync_heartbeat_activity], + activity_executor=executor, + ) as worker: + test_activity_id = f"heartbeat-activity-{uuid.uuid4()}" + + handle = await client.start_workflow( + ActivityHeartbeatPauseUnpauseWorkflow.run, + test_activity_id, + id=f"test-activity-pause-unpause-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + + # Wait for sync activity + activity_info_1 = await assert_pending_activity_exists_eventually( + handle, test_activity_id + ) + # Assert not paused + assert not activity_info_1.paused + # Pause activity then assert it is paused + await pause_and_assert(client, handle, activity_info_1.activity_id) + + # Wait for next heartbeat to propagate the cancellation. Unpausing before the heartbeat + # will show activity as unpaused to core. Consequently, it will *not* issue an activity cancel. + time.sleep(2) + + # Unpause activity + await unpause_and_assert(client, handle, activity_info_1.activity_id) + # Expect second activity to have started now + activity_info_2 = await assert_pending_activity_exists_eventually( + handle, f"{test_activity_id}-2" + ) + # Assert not paused + assert not activity_info_2.paused + # Pause activity then assert it is paused + await pause_and_assert(client, handle, activity_info_2.activity_id) + # Wait for next heartbeat to propagate the cancellation. + time.sleep(2) + # Unpause activity + await unpause_and_assert(client, handle, activity_info_2.activity_id) + + # Check workflow complete + result = await handle.result() + assert result[0] == None + assert result[1] == None From 8ffff73d37e3e62790c2ca8c315401416a742ff0 Mon Sep 17 00:00:00 2001 From: Thomas Hardy Date: Tue, 6 May 2025 13:35:53 -0400 Subject: [PATCH 14/21] linting, reduce heartbeat timeouts for faster test --- tests/helpers/__init__.py | 2 +- tests/worker/test_workflow.py | 15 ++++++++------- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/tests/helpers/__init__.py b/tests/helpers/__init__.py index dd3f49f26..cfa3c2df9 100644 --- a/tests/helpers/__init__.py +++ b/tests/helpers/__init__.py @@ -273,4 +273,4 @@ async def check_unpaused() -> bool: info = await assert_pending_activity_exists_eventually(handle, activity_id) return not info.paused - await assert_eventually(check_unpaused) \ No newline at end of file + await assert_eventually(check_unpaused) diff --git a/tests/worker/test_workflow.py b/tests/worker/test_workflow.py index e40bc613a..079868294 100644 --- a/tests/worker/test_workflow.py +++ b/tests/worker/test_workflow.py @@ -7644,7 +7644,7 @@ async def heartbeat_activity( # If we are on the second attempt, we have retried due to pause/unpause. if activity.info().attempt > 1: return activity.cancellation_details() - await asyncio.sleep(1) + await asyncio.sleep(0.1) except (CancelledError, asyncio.CancelledError) as err: if not catch_err: raise err @@ -7661,7 +7661,7 @@ def sync_heartbeat_activity( # If we are on the second attempt, we have retried due to pause/unpause. if activity.info().attempt > 1: return activity.cancellation_details() - time.sleep(1) + time.sleep(0.1) except (CancelledError, asyncio.CancelledError) as err: if not catch_err: raise err @@ -7756,7 +7756,7 @@ async def run( False, activity_id=activity_id, start_to_close_timeout=timedelta(seconds=10), - heartbeat_timeout=timedelta(seconds=2), + heartbeat_timeout=timedelta(seconds=1), retry_policy=RetryPolicy(maximum_attempts=2), ) ) @@ -7766,7 +7766,7 @@ async def run( False, activity_id=f"{activity_id}-2", start_to_close_timeout=timedelta(seconds=10), - heartbeat_timeout=timedelta(seconds=2), + heartbeat_timeout=timedelta(seconds=1), retry_policy=RetryPolicy(maximum_attempts=2), ) ) @@ -7781,6 +7781,8 @@ async def test_activity_pause_unpause(client: Client): workflows=[ActivityHeartbeatPauseUnpauseWorkflow], activities=[heartbeat_activity, sync_heartbeat_activity], activity_executor=executor, + max_heartbeat_throttle_interval=timedelta(milliseconds=300), + default_heartbeat_throttle_interval=timedelta(milliseconds=300), ) as worker: test_activity_id = f"heartbeat-activity-{uuid.uuid4()}" @@ -7802,8 +7804,7 @@ async def test_activity_pause_unpause(client: Client): # Wait for next heartbeat to propagate the cancellation. Unpausing before the heartbeat # will show activity as unpaused to core. Consequently, it will *not* issue an activity cancel. - time.sleep(2) - + time.sleep(0.3) # Unpause activity await unpause_and_assert(client, handle, activity_info_1.activity_id) # Expect second activity to have started now @@ -7815,7 +7816,7 @@ async def test_activity_pause_unpause(client: Client): # Pause activity then assert it is paused await pause_and_assert(client, handle, activity_info_2.activity_id) # Wait for next heartbeat to propagate the cancellation. - time.sleep(2) + time.sleep(0.3) # Unpause activity await unpause_and_assert(client, handle, activity_info_2.activity_id) From 8cd18d503e7c24a0200f9ad3a922f1d1e9a08f35 Mon Sep 17 00:00:00 2001 From: Thomas Hardy Date: Tue, 6 May 2025 16:53:01 -0400 Subject: [PATCH 15/21] make cancellation details non-optional for testing activity env --- temporalio/testing/_activity.py | 11 ++++------- tests/testing/test_activity.py | 22 +++++++++++++++++----- 2 files changed, 21 insertions(+), 12 deletions(-) diff --git a/temporalio/testing/_activity.py b/temporalio/testing/_activity.py index 700bdd102..ae57c16fb 100644 --- a/temporalio/testing/_activity.py +++ b/temporalio/testing/_activity.py @@ -80,15 +80,13 @@ def __init__(self) -> None: def cancel( self, - cancellation_details: Optional[ - temporalio.activity.ActivityCancellationDetails - ] = None, + cancellation_details: temporalio.activity.ActivityCancellationDetails, ) -> None: """Cancel the activity. Args: - cancellation_details: Optional details about the cancellation. When provided, these - will be accessible through temporalio.activity.cancellation_details() + cancellation_details: details about the cancellation. These will + be accessible through temporalio.activity.cancellation_details() in the activity after cancellation. This only has an effect on the first call. @@ -96,8 +94,7 @@ def cancel( if self._cancelled: return self._cancelled = True - if cancellation_details: - self._cancellation_details.details = cancellation_details + self._cancellation_details.details = cancellation_details for act in self._activities: act.cancel() diff --git a/tests/testing/test_activity.py b/tests/testing/test_activity.py index 29b66c772..ff281d722 100644 --- a/tests/testing/test_activity.py +++ b/tests/testing/test_activity.py @@ -26,7 +26,11 @@ async def via_create_task(): await asyncio.Future() raise RuntimeError("Unreachable") except asyncio.CancelledError: - activity.heartbeat("cancelled") + cancellation_details = activity.cancellation_details() + if cancellation_details: + activity.heartbeat( + f"cancelled={cancellation_details.cancel_requested}", + ) return "done" env = ActivityEnvironment() @@ -37,9 +41,11 @@ async def via_create_task(): task = asyncio.create_task(env.run(do_stuff, "param1")) await waiting.wait() # Cancel and confirm done - env.cancel() + env.cancel( + cancellation_details=activity.ActivityCancellationDetails(cancel_requested=True) + ) assert "done" == await task - assert heartbeats == ["param: param1", "task, type: unknown", "cancelled"] + assert heartbeats == ["param: param1", "task, type: unknown", "cancelled=True"] def test_activity_env_sync(): @@ -72,7 +78,11 @@ def via_thread(): raise RuntimeError("Unexpected") except CancelledError: nonlocal properly_cancelled - properly_cancelled = True + cancellation_details = activity.cancellation_details() + if cancellation_details: + properly_cancelled = cancellation_details.cancel_requested + else: + properly_cancelled = False env = ActivityEnvironment() # Set heartbeat handler to add to list @@ -84,7 +94,9 @@ def via_thread(): waiting.wait() # Cancel and confirm done time.sleep(1) - env.cancel() + env.cancel( + cancellation_details=activity.ActivityCancellationDetails(cancel_requested=True) + ) thread.join() assert heartbeats == ["param: param1", "task, type: unknown"] assert properly_cancelled From 824517c08ec46d1dd6962318af7280295f7c4f9c Mon Sep 17 00:00:00 2001 From: Thomas Hardy Date: Fri, 16 May 2025 10:56:30 -0400 Subject: [PATCH 16/21] address pr suggestion --- temporalio/bridge/src/client.rs | 5 ++--- temporalio/testing/_activity.py | 2 +- tests/helpers/__init__.py | 2 +- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/temporalio/bridge/src/client.rs b/temporalio/bridge/src/client.rs index 594eaf22f..49210c063 100644 --- a/temporalio/bridge/src/client.rs +++ b/temporalio/bridge/src/client.rs @@ -231,6 +231,7 @@ impl ClientRef { } "list_workflow_rules" => { rpc_call!(retry_client, call, list_workflow_rules) + } "patch_schedule" => { rpc_call!(retry_client, call, patch_schedule) } @@ -324,13 +325,11 @@ impl ClientRef { "terminate_workflow_execution" => { rpc_call!(retry_client, call, terminate_workflow_execution) } -<<<<<<< HEAD "trigger_workflow_rule" => { rpc_call!(retry_client, call, trigger_workflow_rule) -======= + } "unpause_activity" => { rpc_call!(retry_client, call, unpause_activity) ->>>>>>> 7e4b2d9 (add test for pause/unpause) } "update_namespace" => { rpc_call_on_trait!(retry_client, call, WorkflowService, update_namespace) diff --git a/temporalio/testing/_activity.py b/temporalio/testing/_activity.py index ae57c16fb..3dbdc13fc 100644 --- a/temporalio/testing/_activity.py +++ b/temporalio/testing/_activity.py @@ -80,7 +80,7 @@ def __init__(self) -> None: def cancel( self, - cancellation_details: temporalio.activity.ActivityCancellationDetails, + cancellation_details: temporalio.activity.ActivityCancellationDetails = temporalio.activity.ActivityCancellationDetails(cancel_requested=True), ) -> None: """Cancel the activity. diff --git a/tests/helpers/__init__.py b/tests/helpers/__init__.py index cfa3c2df9..dd3f49f26 100644 --- a/tests/helpers/__init__.py +++ b/tests/helpers/__init__.py @@ -273,4 +273,4 @@ async def check_unpaused() -> bool: info = await assert_pending_activity_exists_eventually(handle, activity_id) return not info.paused - await assert_eventually(check_unpaused) + await assert_eventually(check_unpaused) \ No newline at end of file From 4fa5bfa08febed63c12701000fadc620fe30edfe Mon Sep 17 00:00:00 2001 From: Thomas Hardy Date: Fri, 16 May 2025 11:31:01 -0400 Subject: [PATCH 17/21] rebase conflict fixes --- temporalio/testing/_activity.py | 4 +++- tests/helpers/__init__.py | 7 +++++-- tests/worker/test_workflow.py | 8 ++------ 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/temporalio/testing/_activity.py b/temporalio/testing/_activity.py index 3dbdc13fc..3694dfdc7 100644 --- a/temporalio/testing/_activity.py +++ b/temporalio/testing/_activity.py @@ -80,7 +80,9 @@ def __init__(self) -> None: def cancel( self, - cancellation_details: temporalio.activity.ActivityCancellationDetails = temporalio.activity.ActivityCancellationDetails(cancel_requested=True), + cancellation_details: temporalio.activity.ActivityCancellationDetails = temporalio.activity.ActivityCancellationDetails( + cancel_requested=True + ), ) -> None: """Cancel the activity. diff --git a/tests/helpers/__init__.py b/tests/helpers/__init__.py index dd3f49f26..90aa9b8e2 100644 --- a/tests/helpers/__init__.py +++ b/tests/helpers/__init__.py @@ -223,12 +223,15 @@ async def assert_pending_activity_exists_eventually( timeout: timedelta = timedelta(seconds=5), ) -> PendingActivityInfo: """Wait until a pending activity with the given ID exists and return it.""" + async def check() -> Optional[PendingActivityInfo]: desc = await handle.describe() for act in desc.raw_description.pending_activities: if act.activity_id == activity_id: return act - raise AssertionError(f"Activity with ID {activity_id} not found in pending activities") + raise AssertionError( + f"Activity with ID {activity_id} not found in pending activities" + ) activity_info = await assert_eventually(check, timeout=timeout) return cast(PendingActivityInfo, activity_info) @@ -273,4 +276,4 @@ async def check_unpaused() -> bool: info = await assert_pending_activity_exists_eventually(handle, activity_id) return not info.paused - await assert_eventually(check_unpaused) \ No newline at end of file + await assert_eventually(check_unpaused) diff --git a/tests/worker/test_workflow.py b/tests/worker/test_workflow.py index 079868294..7422a42a3 100644 --- a/tests/worker/test_workflow.py +++ b/tests/worker/test_workflow.py @@ -49,8 +49,8 @@ from temporalio.api.sdk.v1 import EnhancedStackTrace from temporalio.api.workflowservice.v1 import ( GetWorkflowExecutionHistoryRequest, + PauseActivityRequest, ResetStickyTaskQueueRequest, - PauseActivityRequest ) from temporalio.bridge.proto.workflow_activation import WorkflowActivation from temporalio.bridge.proto.workflow_completion import WorkflowActivationCompletion @@ -95,7 +95,6 @@ TemporalError, TimeoutError, WorkflowAlreadyStartedError, - ActivityPausedError ) from temporalio.runtime import ( BUFFERED_METRIC_KIND_COUNTER, @@ -120,10 +119,10 @@ admitted_update_task, assert_eq_eventually, assert_eventually, + assert_pending_activity_exists_eventually, assert_task_fail_eventually, assert_workflow_exists_eventually, ensure_search_attributes_present, - assert_pending_activity_exists_eventually, find_free_port, new_worker, pause_and_assert, @@ -7632,9 +7631,6 @@ async def test_workflow_missing_local_activity_no_activities(client: Client): @activity.defn -async def heartbeat_activity() -> ( - Optional[temporalio.activity.ActivityCancellationDetails] -): async def heartbeat_activity( catch_err: bool = True, ) -> Optional[temporalio.activity.ActivityCancellationDetails]: From a48f77a63996423d8f28fa3ff25c3e4939c1013c Mon Sep 17 00:00:00 2001 From: Thomas Hardy Date: Fri, 16 May 2025 12:29:21 -0400 Subject: [PATCH 18/21] include is_worker_shutdown as reason for requested cancellation, test fix --- temporalio/worker/_activity.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/temporalio/worker/_activity.py b/temporalio/worker/_activity.py index 6a5297661..b05f3f6e9 100644 --- a/temporalio/worker/_activity.py +++ b/temporalio/worker/_activity.py @@ -219,7 +219,10 @@ def _cancel( activity.cancellation_details.details = ( temporalio.activity.ActivityCancellationDetails._from_proto(cancel.details) ) - activity.cancel(cancelled_by_request=cancel.details.is_cancelled) + activity.cancel( + cancelled_by_request=cancel.details.is_cancelled + or cancel.details.is_worker_shutdown + ) def _heartbeat(self, task_token: bytes, *details: Any) -> None: # We intentionally make heartbeating non-async, but since the data @@ -319,7 +322,8 @@ async def _run_activity( ) await self._data_converter.encode_failure( temporalio.exceptions.ApplicationError( - "Unhandled activity cancel error produced by activity pause" + type="ActivityPause", + message="Unhandled activity cancel error produced by activity pause", ), completion.result.failed.failure, ) From 998540a319b6f306c667852d4346a0760d309ca7 Mon Sep 17 00:00:00 2001 From: Thomas Hardy Date: Fri, 16 May 2025 18:04:33 -0400 Subject: [PATCH 19/21] skip if time-skipping server (does not support pause/unpause yet) --- tests/worker/test_workflow.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/tests/worker/test_workflow.py b/tests/worker/test_workflow.py index 7422a42a3..fab6870bb 100644 --- a/tests/worker/test_workflow.py +++ b/tests/worker/test_workflow.py @@ -7692,7 +7692,11 @@ async def run( return result -async def test_activity_pause_cancellation_details(client: Client): +async def test_activity_pause_cancellation_details( + client: Client, env: WorkflowEnvironment +): + if env.supports_time_skipping: + pytest.skip("Time-skipping server does not support pause API yet") with concurrent.futures.ThreadPoolExecutor() as executor: async with Worker( client, @@ -7769,7 +7773,9 @@ async def run( return results -async def test_activity_pause_unpause(client: Client): +async def test_activity_pause_unpause(client: Client, env: WorkflowEnvironment): + if env.supports_time_skipping: + pytest.skip("Time-skipping server does not support pause API yet") with concurrent.futures.ThreadPoolExecutor() as executor: async with Worker( client, From 56913c414d0af48f9f9f08f40bc536a8100ed3f2 Mon Sep 17 00:00:00 2001 From: Thomas Hardy Date: Tue, 27 May 2025 15:37:32 -0700 Subject: [PATCH 20/21] remove sleep calls from tests, add cancellation details to async cancellation errors from external activities --- temporalio/client.py | 20 +++++++-- tests/helpers/__init__.py | 46 ++++++++++++++++--- tests/worker/test_workflow.py | 85 ++++++++++++++++++++++++++++++++--- 3 files changed, 134 insertions(+), 17 deletions(-) diff --git a/temporalio/client.py b/temporalio/client.py index be1de4b5e..1c7514c5e 100644 --- a/temporalio/client.py +++ b/temporalio/client.py @@ -56,6 +56,7 @@ import temporalio.runtime import temporalio.service import temporalio.workflow +from temporalio.activity import ActivityCancellationDetails from temporalio.service import ( HttpConnectProxyConfig, KeepAliveConfig, @@ -5145,9 +5146,12 @@ def __init__(self) -> None: class AsyncActivityCancelledError(temporalio.exceptions.TemporalError): """Error that occurs when async activity attempted heartbeat but was cancelled.""" - def __init__(self) -> None: + details: Optional[ActivityCancellationDetails] = None + + def __init__(self, details: Optional[ActivityCancellationDetails] = None) -> None: """Create async activity cancelled error.""" super().__init__("Activity cancelled") + self.details = details class ScheduleAlreadyRunningError(temporalio.exceptions.TemporalError): @@ -6288,7 +6292,12 @@ async def heartbeat_async_activity( timeout=input.rpc_timeout, ) if resp_by_id.cancel_requested or resp_by_id.activity_paused: - raise AsyncActivityCancelledError() + raise AsyncActivityCancelledError( + details=ActivityCancellationDetails( + cancel_requested=resp_by_id.cancel_requested, + paused=resp_by_id.activity_paused, + ) + ) else: resp = await self._client.workflow_service.record_activity_task_heartbeat( @@ -6303,7 +6312,12 @@ async def heartbeat_async_activity( timeout=input.rpc_timeout, ) if resp.cancel_requested or resp.activity_paused: - raise AsyncActivityCancelledError() + raise AsyncActivityCancelledError( + details=ActivityCancellationDetails( + cancel_requested=resp.cancel_requested, + paused=resp.activity_paused, + ) + ) async def complete_async_activity(self, input: CompleteAsyncActivityInput) -> None: result = ( diff --git a/tests/helpers/__init__.py b/tests/helpers/__init__.py index 90aa9b8e2..99c0f1cbb 100644 --- a/tests/helpers/__init__.py +++ b/tests/helpers/__init__.py @@ -224,17 +224,49 @@ async def assert_pending_activity_exists_eventually( ) -> PendingActivityInfo: """Wait until a pending activity with the given ID exists and return it.""" - async def check() -> Optional[PendingActivityInfo]: - desc = await handle.describe() - for act in desc.raw_description.pending_activities: - if act.activity_id == activity_id: - return act + async def check() -> PendingActivityInfo: + act_info = await _get_pending_activity_info(handle, activity_id) + if act_info is not None: + return act_info raise AssertionError( f"Activity with ID {activity_id} not found in pending activities" ) - activity_info = await assert_eventually(check, timeout=timeout) - return cast(PendingActivityInfo, activity_info) + return await assert_eventually(check, timeout=timeout) + + +async def wait_for_next_heartbeat_cycle( + handle: WorkflowHandle, + activity_id: str, + initial_heartbeat_time: Any, + timeout: timedelta = timedelta(seconds=5), +) -> None: + """Wait for the next heartbeat cycle by monitoring last_heartbeat_time changes.""" + + async def check_heartbeat_changed() -> None: + current_info = await _get_pending_activity_info(handle, activity_id) + if current_info is None: + raise AssertionError( + f"Activity with ID {activity_id} not found in pending activities" + ) + if current_info.last_heartbeat_time == initial_heartbeat_time: + raise AssertionError( + f"Activity with ID {activity_id} has not heartbeated yet" + ) + + await assert_eventually(check_heartbeat_changed, timeout=timeout) + + +async def _get_pending_activity_info( + handle: WorkflowHandle, + activity_id: str, +) -> Optional[PendingActivityInfo]: + """Get pending activity info by ID, or None if not found.""" + desc = await handle.describe() + for act in desc.raw_description.pending_activities: + if act.activity_id == activity_id: + return act + return None async def pause_and_assert(client: Client, handle: WorkflowHandle, activity_id: str): diff --git a/tests/worker/test_workflow.py b/tests/worker/test_workflow.py index fab6870bb..589e3bc3e 100644 --- a/tests/worker/test_workflow.py +++ b/tests/worker/test_workflow.py @@ -49,12 +49,12 @@ from temporalio.api.sdk.v1 import EnhancedStackTrace from temporalio.api.workflowservice.v1 import ( GetWorkflowExecutionHistoryRequest, - PauseActivityRequest, ResetStickyTaskQueueRequest, ) from temporalio.bridge.proto.workflow_activation import WorkflowActivation from temporalio.bridge.proto.workflow_completion import WorkflowActivationCompletion from temporalio.client import ( + AsyncActivityCancelledError, Client, RPCError, RPCStatusCode, @@ -127,6 +127,7 @@ new_worker, pause_and_assert, unpause_and_assert, + wait_for_next_heartbeat_cycle, workflow_update_exists, ) from tests.helpers.external_stack_trace import ( @@ -7637,14 +7638,16 @@ async def heartbeat_activity( while True: try: activity.heartbeat() - # If we are on the second attempt, we have retried due to pause/unpause. - if activity.info().attempt > 1: + # If we have heartbeat details, we are on the second attempt, we have retried due to pause/unpause. + if activity.info().heartbeat_details: return activity.cancellation_details() await asyncio.sleep(0.1) except (CancelledError, asyncio.CancelledError) as err: if not catch_err: raise err return activity.cancellation_details() + finally: + activity.heartbeat("finally-complete") @activity.defn @@ -7654,14 +7657,16 @@ def sync_heartbeat_activity( while True: try: activity.heartbeat() - # If we are on the second attempt, we have retried due to pause/unpause. - if activity.info().attempt > 1: + # If we have heartbeat details, we are on the second attempt, we have retried due to pause/unpause. + if activity.info().heartbeat_details: return activity.cancellation_details() time.sleep(0.1) except (CancelledError, asyncio.CancelledError) as err: if not catch_err: raise err return activity.cancellation_details() + finally: + activity.heartbeat("finally-complete") @workflow.defn @@ -7806,7 +7811,10 @@ async def test_activity_pause_unpause(client: Client, env: WorkflowEnvironment): # Wait for next heartbeat to propagate the cancellation. Unpausing before the heartbeat # will show activity as unpaused to core. Consequently, it will *not* issue an activity cancel. - time.sleep(0.3) + await wait_for_next_heartbeat_cycle( + handle, activity_info_1.activity_id, activity_info_1.last_heartbeat_time + ) + # Unpause activity await unpause_and_assert(client, handle, activity_info_1.activity_id) # Expect second activity to have started now @@ -7818,7 +7826,9 @@ async def test_activity_pause_unpause(client: Client, env: WorkflowEnvironment): # Pause activity then assert it is paused await pause_and_assert(client, handle, activity_info_2.activity_id) # Wait for next heartbeat to propagate the cancellation. - time.sleep(0.3) + await wait_for_next_heartbeat_cycle( + handle, activity_info_2.activity_id, activity_info_2.last_heartbeat_time + ) # Unpause activity await unpause_and_assert(client, handle, activity_info_2.activity_id) @@ -7826,3 +7836,64 @@ async def test_activity_pause_unpause(client: Client, env: WorkflowEnvironment): result = await handle.result() assert result[0] == None assert result[1] == None + + +@activity.defn +async def external_activity_heartbeat() -> None: + activity.raise_complete_async() + + +@workflow.defn +class ExternalActivityWorkflow: + @workflow.run + async def run(self, activity_id: str) -> None: + await workflow.execute_activity( + external_activity_heartbeat, + activity_id=activity_id, + start_to_close_timeout=timedelta(seconds=10), + heartbeat_timeout=timedelta(seconds=1), + retry_policy=RetryPolicy(maximum_attempts=2), + ) + + +async def test_external_activity_cancellation_details( + client: Client, env: WorkflowEnvironment +): + if env.supports_time_skipping: + pytest.skip("Time-skipping server does not support pause API yet") + async with Worker( + client, + task_queue=str(uuid.uuid4()), + workflows=[ExternalActivityWorkflow], + activities=[external_activity_heartbeat], + ) as worker: + test_activity_id = f"heartbeat-activity-{uuid.uuid4()}" + + wf_handle = await client.start_workflow( + ExternalActivityWorkflow.run, + test_activity_id, + id=f"test-external-activity-pause-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + wf_desc = await wf_handle.describe() + + # Wait for external activity + activity_info = await assert_pending_activity_exists_eventually( + wf_handle, test_activity_id + ) + # Assert not paused + assert not activity_info.paused + + external_activity_handle = client.get_async_activity_handle( + workflow_id=wf_desc.id, run_id=wf_desc.run_id, activity_id=test_activity_id + ) + + # Pause activity then assert it is paused + await pause_and_assert(client, wf_handle, activity_info.activity_id) + + try: + await external_activity_handle.heartbeat() + except AsyncActivityCancelledError as err: + assert err.details == temporalio.activity.ActivityCancellationDetails( + paused=True + ) From bb1b0570d6fd089ba65784f5f5452dc0722aa5af Mon Sep 17 00:00:00 2001 From: Thomas Hardy Date: Wed, 28 May 2025 19:56:18 -0400 Subject: [PATCH 21/21] replace racy heartbeat check with heartbeat details check --- temporalio/client.py | 2 -- tests/helpers/__init__.py | 26 ++------------------------ tests/worker/test_workflow.py | 34 ++++++++++++++++++++++++++-------- 3 files changed, 28 insertions(+), 34 deletions(-) diff --git a/temporalio/client.py b/temporalio/client.py index 1c7514c5e..f46297eb9 100644 --- a/temporalio/client.py +++ b/temporalio/client.py @@ -5146,8 +5146,6 @@ def __init__(self) -> None: class AsyncActivityCancelledError(temporalio.exceptions.TemporalError): """Error that occurs when async activity attempted heartbeat but was cancelled.""" - details: Optional[ActivityCancellationDetails] = None - def __init__(self, details: Optional[ActivityCancellationDetails] = None) -> None: """Create async activity cancelled error.""" super().__init__("Activity cancelled") diff --git a/tests/helpers/__init__.py b/tests/helpers/__init__.py index 99c0f1cbb..a352877d5 100644 --- a/tests/helpers/__init__.py +++ b/tests/helpers/__init__.py @@ -225,7 +225,7 @@ async def assert_pending_activity_exists_eventually( """Wait until a pending activity with the given ID exists and return it.""" async def check() -> PendingActivityInfo: - act_info = await _get_pending_activity_info(handle, activity_id) + act_info = await get_pending_activity_info(handle, activity_id) if act_info is not None: return act_info raise AssertionError( @@ -235,29 +235,7 @@ async def check() -> PendingActivityInfo: return await assert_eventually(check, timeout=timeout) -async def wait_for_next_heartbeat_cycle( - handle: WorkflowHandle, - activity_id: str, - initial_heartbeat_time: Any, - timeout: timedelta = timedelta(seconds=5), -) -> None: - """Wait for the next heartbeat cycle by monitoring last_heartbeat_time changes.""" - - async def check_heartbeat_changed() -> None: - current_info = await _get_pending_activity_info(handle, activity_id) - if current_info is None: - raise AssertionError( - f"Activity with ID {activity_id} not found in pending activities" - ) - if current_info.last_heartbeat_time == initial_heartbeat_time: - raise AssertionError( - f"Activity with ID {activity_id} has not heartbeated yet" - ) - - await assert_eventually(check_heartbeat_changed, timeout=timeout) - - -async def _get_pending_activity_info( +async def get_pending_activity_info( handle: WorkflowHandle, activity_id: str, ) -> Optional[PendingActivityInfo]: diff --git a/tests/worker/test_workflow.py b/tests/worker/test_workflow.py index 589e3bc3e..d75eee51a 100644 --- a/tests/worker/test_workflow.py +++ b/tests/worker/test_workflow.py @@ -124,10 +124,10 @@ assert_workflow_exists_eventually, ensure_search_attributes_present, find_free_port, + get_pending_activity_info, new_worker, pause_and_assert, unpause_and_assert, - wait_for_next_heartbeat_cycle, workflow_update_exists, ) from tests.helpers.external_stack_trace import ( @@ -7781,6 +7781,19 @@ async def run( async def test_activity_pause_unpause(client: Client, env: WorkflowEnvironment): if env.supports_time_skipping: pytest.skip("Time-skipping server does not support pause API yet") + + async def check_heartbeat_details_exist( + handle: WorkflowHandle, + activity_id: str, + ) -> None: + act_info = await get_pending_activity_info(handle, activity_id) + if act_info is None: + raise AssertionError(f"Activity with ID {activity_id} not found.") + if len(act_info.heartbeat_details.payloads) == 0: + raise AssertionError( + f"Activity with ID {activity_id} has no heartbeat details" + ) + with concurrent.futures.ThreadPoolExecutor() as executor: async with Worker( client, @@ -7809,10 +7822,12 @@ async def test_activity_pause_unpause(client: Client, env: WorkflowEnvironment): # Pause activity then assert it is paused await pause_and_assert(client, handle, activity_info_1.activity_id) - # Wait for next heartbeat to propagate the cancellation. Unpausing before the heartbeat - # will show activity as unpaused to core. Consequently, it will *not* issue an activity cancel. - await wait_for_next_heartbeat_cycle( - handle, activity_info_1.activity_id, activity_info_1.last_heartbeat_time + # Wait for heartbeat details to exist. At this point, the activity has finished executing + # due to cancellation from the pause. + await assert_eventually( + lambda: check_heartbeat_details_exist( + handle, activity_info_1.activity_id + ) ) # Unpause activity @@ -7825,9 +7840,12 @@ async def test_activity_pause_unpause(client: Client, env: WorkflowEnvironment): assert not activity_info_2.paused # Pause activity then assert it is paused await pause_and_assert(client, handle, activity_info_2.activity_id) - # Wait for next heartbeat to propagate the cancellation. - await wait_for_next_heartbeat_cycle( - handle, activity_info_2.activity_id, activity_info_2.last_heartbeat_time + # Wait for heartbeat details to exist. At this point, the activity has finished executing + # due to cancellation from the pause. + await assert_eventually( + lambda: check_heartbeat_details_exist( + handle, activity_info_2.activity_id + ) ) # Unpause activity await unpause_and_assert(client, handle, activity_info_2.activity_id)