Skip to content

Interrupt heartbeating activity on pause #854

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 22 commits into from
May 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions temporalio/activity.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@
overload,
)

import temporalio.bridge
import temporalio.bridge.proto
import temporalio.bridge.proto.activity_task
import temporalio.common
import temporalio.converter

Expand Down Expand Up @@ -135,6 +138,34 @@ def _logger_details(self) -> Mapping[str, Any]:
_current_context: contextvars.ContextVar[_Context] = contextvars.ContextVar("activity")


@dataclass
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this should be frozen

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We mutate the fields in this object to reflect changes across running activity & _context

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

changed to frozen class + holder

class _ActivityCancellationDetailsHolder:
details: Optional[ActivityCancellationDetails] = None


@dataclass(frozen=True)
class ActivityCancellationDetails:
"""Provides the reasons for the activity's cancellation. Cancellation details are set once and do not change once set."""

not_found: bool = False
cancel_requested: bool = False
paused: bool = False
timed_out: bool = False
worker_shutdown: bool = False

@staticmethod
def _from_proto(
proto: temporalio.bridge.proto.activity_task.ActivityCancellationDetails,
) -> ActivityCancellationDetails:
return ActivityCancellationDetails(
not_found=proto.is_not_found,
cancel_requested=proto.is_cancelled,
paused=proto.is_paused,
timed_out=proto.is_timed_out,
worker_shutdown=proto.is_worker_shutdown,
)


@dataclass
class _Context:
info: Callable[[], Info]
Expand All @@ -148,6 +179,7 @@ class _Context:
temporalio.converter.PayloadConverter,
]
runtime_metric_meter: Optional[temporalio.common.MetricMeter]
cancellation_details: _ActivityCancellationDetailsHolder
_logger_details: Optional[Mapping[str, Any]] = None
_payload_converter: Optional[temporalio.converter.PayloadConverter] = None
_metric_meter: Optional[temporalio.common.MetricMeter] = None
Expand Down Expand Up @@ -260,6 +292,11 @@ def info() -> Info:
return _Context.current().info()


def cancellation_details() -> Optional[ActivityCancellationDetails]:
"""Cancellation details of the current activity, if any. Once set, cancellation details do not change."""
return _Context.current().cancellation_details.details


def heartbeat(*details: Any) -> None:
"""Send a heartbeat for the current activity.

Expand Down
6 changes: 6 additions & 0 deletions temporalio/bridge/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,9 @@ impl ClientRef {
"patch_schedule" => {
rpc_call!(retry_client, call, patch_schedule)
}
"pause_activity" => {
rpc_call!(retry_client, call, pause_activity)
}
"poll_activity_task_queue" => {
rpc_call!(retry_client, call, poll_activity_task_queue)
}
Expand Down Expand Up @@ -325,6 +328,9 @@ impl ClientRef {
"trigger_workflow_rule" => {
rpc_call!(retry_client, call, trigger_workflow_rule)
}
"unpause_activity" => {
rpc_call!(retry_client, call, unpause_activity)
}
"update_namespace" => {
rpc_call_on_trait!(retry_client, call, WorkflowService, update_namespace)
}
Expand Down
23 changes: 18 additions & 5 deletions temporalio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
import temporalio.runtime
import temporalio.service
import temporalio.workflow
from temporalio.activity import ActivityCancellationDetails
from temporalio.service import (
HttpConnectProxyConfig,
KeepAliveConfig,
Expand Down Expand Up @@ -5145,9 +5146,10 @@ def __init__(self) -> None:
class AsyncActivityCancelledError(temporalio.exceptions.TemporalError):
"""Error that occurs when async activity attempted heartbeat but was cancelled."""

def __init__(self) -> None:
def __init__(self, details: Optional[ActivityCancellationDetails] = None) -> None:
"""Create async activity cancelled error."""
super().__init__("Activity cancelled")
self.details = details


class ScheduleAlreadyRunningError(temporalio.exceptions.TemporalError):
Expand Down Expand Up @@ -6287,8 +6289,14 @@ async def heartbeat_async_activity(
metadata=input.rpc_metadata,
timeout=input.rpc_timeout,
)
if resp_by_id.cancel_requested:
raise AsyncActivityCancelledError()
if resp_by_id.cancel_requested or resp_by_id.activity_paused:
raise AsyncActivityCancelledError(
details=ActivityCancellationDetails(
cancel_requested=resp_by_id.cancel_requested,
paused=resp_by_id.activity_paused,
)
)

else:
resp = await self._client.workflow_service.record_activity_task_heartbeat(
temporalio.api.workflowservice.v1.RecordActivityTaskHeartbeatRequest(
Expand All @@ -6301,8 +6309,13 @@ async def heartbeat_async_activity(
metadata=input.rpc_metadata,
timeout=input.rpc_timeout,
)
if resp.cancel_requested:
raise AsyncActivityCancelledError()
if resp.cancel_requested or resp.activity_paused:
raise AsyncActivityCancelledError(
details=ActivityCancellationDetails(
cancel_requested=resp.cancel_requested,
paused=resp.activity_paused,
)
)

async def complete_async_activity(self, input: CompleteAsyncActivityInput) -> None:
result = (
Expand Down
17 changes: 16 additions & 1 deletion temporalio/testing/_activity.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,15 +74,29 @@ def __init__(self) -> None:
self._cancelled = False
self._worker_shutdown = False
self._activities: Set[_Activity] = set()
self._cancellation_details = (
temporalio.activity._ActivityCancellationDetailsHolder()
)

def cancel(self) -> None:
def cancel(
self,
cancellation_details: temporalio.activity.ActivityCancellationDetails = temporalio.activity.ActivityCancellationDetails(
cancel_requested=True
),
) -> None:
"""Cancel the activity.

Args:
cancellation_details: details about the cancellation. These will
be accessible through temporalio.activity.cancellation_details()
in the activity after cancellation.

This only has an effect on the first call.
"""
if self._cancelled:
return
self._cancelled = True
self._cancellation_details.details = cancellation_details
for act in self._activities:
act.cancel()

Expand Down Expand Up @@ -154,6 +168,7 @@ def __init__(
else self.cancel_thread_raiser.shielded,
payload_converter_class_or_instance=env.payload_converter,
runtime_metric_meter=env.metric_meter,
cancellation_details=env._cancellation_details,
)
self.task: Optional[asyncio.Task] = None

Expand Down
38 changes: 34 additions & 4 deletions temporalio/worker/_activity.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import warnings
from abc import ABC, abstractmethod
from contextlib import contextmanager
from dataclasses import dataclass
from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone
from typing import (
Any,
Expand Down Expand Up @@ -216,7 +216,13 @@ def _cancel(
warnings.warn(f"Cannot find activity to cancel for token {task_token!r}")
return
logger.debug("Cancelling activity %s, reason: %s", task_token, cancel.reason)
activity.cancel(cancelled_by_request=True)
activity.cancellation_details.details = (
temporalio.activity.ActivityCancellationDetails._from_proto(cancel.details)
)
activity.cancel(
cancelled_by_request=cancel.details.is_cancelled
or cancel.details.is_worker_shutdown
)

def _heartbeat(self, task_token: bytes, *details: Any) -> None:
# We intentionally make heartbeating non-async, but since the data
Expand Down Expand Up @@ -303,6 +309,24 @@ async def _run_activity(
await self._data_converter.encode_failure(
err, completion.result.failed.failure
)
elif (
isinstance(
err,
(asyncio.CancelledError, temporalio.exceptions.CancelledError),
)
and running_activity.cancellation_details.details
and running_activity.cancellation_details.details.paused
):
temporalio.activity.logger.warning(
f"Completing as failure due to unhandled cancel error produced by activity pause",
)
await self._data_converter.encode_failure(
temporalio.exceptions.ApplicationError(
type="ActivityPause",
message="Unhandled activity cancel error produced by activity pause",
),
completion.result.failed.failure,
)
elif (
isinstance(
err,
Expand Down Expand Up @@ -336,7 +360,6 @@ async def _run_activity(
await self._data_converter.encode_failure(
err, completion.result.failed.failure
)

# For broken executors, we have to fail the entire worker
if isinstance(err, concurrent.futures.BrokenExecutor):
self._fail_worker_exception_queue.put_nowait(err)
Expand Down Expand Up @@ -524,6 +547,7 @@ async def _execute_activity(
else running_activity.cancel_thread_raiser.shielded,
payload_converter_class_or_instance=self._data_converter.payload_converter,
runtime_metric_meter=None if sync_non_threaded else self._metric_meter,
cancellation_details=running_activity.cancellation_details,
)
)
temporalio.activity.logger.debug("Starting activity")
Expand Down Expand Up @@ -570,6 +594,9 @@ class _RunningActivity:
done: bool = False
cancelled_by_request: bool = False
cancelled_due_to_heartbeat_error: Optional[Exception] = None
cancellation_details: temporalio.activity._ActivityCancellationDetailsHolder = (
field(default_factory=temporalio.activity._ActivityCancellationDetailsHolder)
)

def cancel(
self,
Expand Down Expand Up @@ -659,6 +686,7 @@ async def execute_activity(self, input: ExecuteActivityInput) -> Any:
# can set the initializer on the executor).
ctx = temporalio.activity._Context.current()
info = ctx.info()
cancellation_details = ctx.cancellation_details

# Heartbeat calls internally use a data converter which is async so
# they need to be called on the event loop
Expand Down Expand Up @@ -717,6 +745,7 @@ async def heartbeat_with_context(*details: Any) -> None:
worker_shutdown_event.thread_event,
payload_converter_class_or_instance,
ctx.runtime_metric_meter,
cancellation_details,
input.fn,
*input.args,
]
Expand All @@ -732,7 +761,6 @@ async def heartbeat_with_context(*details: Any) -> None:
finally:
if shared_manager:
await shared_manager.unregister_heartbeater(info.task_token)

# Otherwise for async activity, just run
return await input.fn(*input.args)

Expand Down Expand Up @@ -764,6 +792,7 @@ def _execute_sync_activity(
temporalio.converter.PayloadConverter,
],
runtime_metric_meter: Optional[temporalio.common.MetricMeter],
cancellation_details: temporalio.activity._ActivityCancellationDetailsHolder,
fn: Callable[..., Any],
*args: Any,
) -> Any:
Expand Down Expand Up @@ -795,6 +824,7 @@ def _execute_sync_activity(
else cancel_thread_raiser.shielded,
payload_converter_class_or_instance=payload_converter_class_or_instance,
runtime_metric_meter=runtime_metric_meter,
cancellation_details=cancellation_details,
)
)
return fn(*args)
Expand Down
2 changes: 2 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,8 @@ async def env(env_type: str) -> AsyncGenerator[WorkflowEnvironment, None]:
"frontend.workerVersioningDataAPIs=true",
"--dynamic-config-value",
"system.enableDeploymentVersions=true",
"--dynamic-config-value",
"frontend.activityAPIsEnabled=true",
],
dev_server_download_version=DEV_SERVER_DOWNLOAD_VERSION,
)
Expand Down
79 changes: 78 additions & 1 deletion tests/helpers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,12 @@
ListSearchAttributesRequest,
)
from temporalio.api.update.v1 import UpdateRef
from temporalio.api.workflowservice.v1 import PollWorkflowExecutionUpdateRequest
from temporalio.api.workflow.v1 import PendingActivityInfo
from temporalio.api.workflowservice.v1 import (
PauseActivityRequest,
PollWorkflowExecutionUpdateRequest,
UnpauseActivityRequest,
)
from temporalio.client import BuildIdOpAddNewDefault, Client, WorkflowHandle
from temporalio.common import SearchAttributeKey
from temporalio.service import RPCError, RPCStatusCode
Expand Down Expand Up @@ -210,3 +215,75 @@ async def check_workflow_exists() -> bool:
await assert_eq_eventually(True, check_workflow_exists)
assert handle is not None
return handle


async def assert_pending_activity_exists_eventually(
handle: WorkflowHandle,
activity_id: str,
timeout: timedelta = timedelta(seconds=5),
) -> PendingActivityInfo:
"""Wait until a pending activity with the given ID exists and return it."""

async def check() -> PendingActivityInfo:
act_info = await get_pending_activity_info(handle, activity_id)
if act_info is not None:
return act_info
raise AssertionError(
f"Activity with ID {activity_id} not found in pending activities"
)

return await assert_eventually(check, timeout=timeout)


async def get_pending_activity_info(
handle: WorkflowHandle,
activity_id: str,
) -> Optional[PendingActivityInfo]:
"""Get pending activity info by ID, or None if not found."""
desc = await handle.describe()
for act in desc.raw_description.pending_activities:
if act.activity_id == activity_id:
return act
return None


async def pause_and_assert(client: Client, handle: WorkflowHandle, activity_id: str):
"""Pause the given activity and assert it becomes paused."""
desc = await handle.describe()
req = PauseActivityRequest(
namespace=client.namespace,
execution=WorkflowExecution(
workflow_id=desc.raw_description.workflow_execution_info.execution.workflow_id,
run_id=desc.raw_description.workflow_execution_info.execution.run_id,
),
id=activity_id,
)
await client.workflow_service.pause_activity(req)

# Assert eventually paused
async def check_paused() -> bool:
info = await assert_pending_activity_exists_eventually(handle, activity_id)
return info.paused

await assert_eventually(check_paused)


async def unpause_and_assert(client: Client, handle: WorkflowHandle, activity_id: str):
"""Unpause the given activity and assert it is not paused."""
desc = await handle.describe()
req = UnpauseActivityRequest(
namespace=client.namespace,
execution=WorkflowExecution(
workflow_id=desc.raw_description.workflow_execution_info.execution.workflow_id,
run_id=desc.raw_description.workflow_execution_info.execution.run_id,
),
id=activity_id,
)
await client.workflow_service.unpause_activity(req)

# Assert eventually not paused
async def check_unpaused() -> bool:
info = await assert_pending_activity_exists_eventually(handle, activity_id)
return not info.paused

await assert_eventually(check_unpaused)
Loading
Loading