Skip to content

Commit 58d6951

Browse files
authored
Access current update info with ID inside update handler (#544)
Fixes #542
1 parent 2d65d82 commit 58d6951

File tree

3 files changed

+117
-0
lines changed

3 files changed

+117
-0
lines changed

temporalio/worker/_workflow_instance.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -463,6 +463,11 @@ def _apply_do_update(
463463
# inside the task, since the update may not be defined until after we have started the workflow - for example
464464
# if an update is in the first WFT & is also registered dynamically at the top of workflow code.
465465
async def run_update() -> None:
466+
# Set the current update for the life of this task
467+
temporalio.workflow._set_current_update_info(
468+
temporalio.workflow.UpdateInfo(id=job.id, name=job.name)
469+
)
470+
466471
command = self._add_command()
467472
command.update_response.protocol_instance_id = job.protocol_instance_id
468473
past_validation = False

temporalio/workflow.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from __future__ import annotations
44

55
import asyncio
6+
import contextvars
67
import inspect
78
import logging
89
import threading
@@ -424,6 +425,17 @@ class ParentInfo:
424425
workflow_id: str
425426

426427

428+
@dataclass(frozen=True)
429+
class UpdateInfo:
430+
"""Information about a workflow update."""
431+
432+
id: str
433+
"""Update ID."""
434+
435+
name: str
436+
"""Update type name."""
437+
438+
427439
class _Runtime(ABC):
428440
@staticmethod
429441
def current() -> _Runtime:
@@ -654,6 +666,31 @@ async def workflow_wait_condition(
654666
...
655667

656668

669+
_current_update_info: contextvars.ContextVar[UpdateInfo] = contextvars.ContextVar(
670+
"__temporal_current_update_info"
671+
)
672+
673+
674+
def _set_current_update_info(info: UpdateInfo) -> None:
675+
_current_update_info.set(info)
676+
677+
678+
def current_update_info() -> Optional[UpdateInfo]:
679+
"""Info for the current update if any.
680+
681+
This is powered by :py:mod:`contextvars` so it is only valid within the
682+
update handler and coroutines/tasks it has started.
683+
684+
.. warning::
685+
This API is experimental
686+
687+
Returns:
688+
Info for the current update handler the code calling this is executing
689+
within if any.
690+
"""
691+
return _current_update_info.get(None)
692+
693+
657694
def deprecate_patch(id: str) -> None:
658695
"""Mark a patch as deprecated.
659696

tests/worker/test_workflow.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4927,3 +4927,78 @@ async def test_workflow_wait_utility(client: Client):
49274927
task_queue=worker.task_queue,
49284928
)
49294929
assert len(result) == 10
4930+
4931+
4932+
@workflow.defn
4933+
class CurrentUpdateWorkflow:
4934+
def __init__(self) -> None:
4935+
self._pending_get_update_id_tasks: List[asyncio.Task[str]] = []
4936+
4937+
@workflow.run
4938+
async def run(self) -> List[str]:
4939+
# Confirm no update info
4940+
assert not workflow.current_update_info()
4941+
4942+
# Wait for all tasks to come in, then return the full set
4943+
await workflow.wait_condition(
4944+
lambda: len(self._pending_get_update_id_tasks) == 5
4945+
)
4946+
assert not workflow.current_update_info()
4947+
return list(await asyncio.gather(*self._pending_get_update_id_tasks))
4948+
4949+
@workflow.update
4950+
async def do_update(self) -> str:
4951+
# Check that simple helper awaited has the ID
4952+
info = workflow.current_update_info()
4953+
assert info
4954+
assert info.name == "do_update"
4955+
assert info.id == await self.get_update_id()
4956+
4957+
# Also schedule the task and wait for it in the main workflow to confirm
4958+
# it still gets the update ID
4959+
self._pending_get_update_id_tasks.append(
4960+
asyncio.create_task(self.get_update_id())
4961+
)
4962+
4963+
# Re-fetch and return
4964+
info = workflow.current_update_info()
4965+
assert info
4966+
return info.id
4967+
4968+
@do_update.validator
4969+
def do_update_validator(self) -> None:
4970+
info = workflow.current_update_info()
4971+
assert info
4972+
assert info.name == "do_update"
4973+
4974+
async def get_update_id(self) -> str:
4975+
await asyncio.sleep(0.01)
4976+
info = workflow.current_update_info()
4977+
assert info
4978+
return info.id
4979+
4980+
4981+
async def test_workflow_current_update(client: Client, env: WorkflowEnvironment):
4982+
if env.supports_time_skipping:
4983+
pytest.skip(
4984+
"Java test server: https://github.com/temporalio/sdk-java/issues/1903"
4985+
)
4986+
async with new_worker(client, CurrentUpdateWorkflow) as worker:
4987+
handle = await client.start_workflow(
4988+
CurrentUpdateWorkflow.run,
4989+
id=f"wf-{uuid.uuid4()}",
4990+
task_queue=worker.task_queue,
4991+
)
4992+
update_ids = await asyncio.gather(
4993+
handle.execute_update(CurrentUpdateWorkflow.do_update, id="update1"),
4994+
handle.execute_update(CurrentUpdateWorkflow.do_update, id="update2"),
4995+
handle.execute_update(CurrentUpdateWorkflow.do_update, id="update3"),
4996+
handle.execute_update(CurrentUpdateWorkflow.do_update, id="update4"),
4997+
handle.execute_update(CurrentUpdateWorkflow.do_update, id="update5"),
4998+
)
4999+
assert {"update1", "update2", "update3", "update4", "update5"} == set(
5000+
update_ids
5001+
)
5002+
assert {"update1", "update2", "update3", "update4", "update5"} == set(
5003+
await handle.result()
5004+
)

0 commit comments

Comments
 (0)