Skip to content

Commit a4d39c9

Browse files
committed
Add tests of asyncio.Lock and asyncio.Semaphore usage
1 parent 7ac4445 commit a4d39c9

File tree

2 files changed

+322
-2
lines changed

2 files changed

+322
-2
lines changed

tests/helpers/__init__.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import uuid
55
from contextlib import closing
66
from datetime import timedelta
7-
from typing import Awaitable, Callable, Optional, Sequence, Type, TypeVar
7+
from typing import Any, Awaitable, Callable, Optional, Sequence, Type, TypeVar
88

99
from temporalio.api.common.v1 import WorkflowExecution
1010
from temporalio.api.enums.v1 import IndexedValueType
@@ -14,11 +14,12 @@
1414
)
1515
from temporalio.api.update.v1 import UpdateRef
1616
from temporalio.api.workflowservice.v1 import PollWorkflowExecutionUpdateRequest
17-
from temporalio.client import BuildIdOpAddNewDefault, Client
17+
from temporalio.client import BuildIdOpAddNewDefault, Client, WorkflowHandle
1818
from temporalio.common import SearchAttributeKey
1919
from temporalio.service import RPCError, RPCStatusCode
2020
from temporalio.worker import Worker, WorkflowRunner
2121
from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner
22+
from temporalio.workflow import UpdateMethodMultiParam
2223

2324

2425
def new_worker(
@@ -128,3 +129,24 @@ async def workflow_update_exists(
128129
if err.status != RPCStatusCode.NOT_FOUND:
129130
raise
130131
return False
132+
133+
134+
# TODO: type update return value
135+
async def admitted_update_task(
136+
client: Client,
137+
handle: WorkflowHandle,
138+
update_method: UpdateMethodMultiParam,
139+
id: str,
140+
**kwargs,
141+
) -> asyncio.Task:
142+
"""
143+
Return an asyncio.Task for an update after waiting for it to be admitted.
144+
"""
145+
update_task = asyncio.create_task(
146+
handle.execute_update(update_method, id=id, **kwargs)
147+
)
148+
await assert_eq_eventually(
149+
True,
150+
lambda: workflow_update_exists(client, handle.id, id),
151+
)
152+
return update_task

tests/worker/test_workflow.py

Lines changed: 298 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@
105105
WorkflowRunner,
106106
)
107107
from tests.helpers import (
108+
admitted_update_task,
108109
assert_eq_eventually,
109110
ensure_search_attributes_present,
110111
find_free_port,
@@ -5510,3 +5511,300 @@ def _unfinished_handler_warning_cls(self) -> Type:
55105511
"update": workflow.UnfinishedUpdateHandlersWarning,
55115512
"signal": workflow.UnfinishedSignalHandlersWarning,
55125513
}[self.handler_type]
5514+
5515+
5516+
# The following Lock and Semaphore tests test that asyncio concurrency primitives work as expected
5517+
# in workflow code. There is nothing Temporal-specific about the way that asyncio.Lock and
5518+
# asyncio.Semaphore are used here.
5519+
5520+
5521+
@activity.defn
5522+
async def noop_activity_for_lock_or_semaphore_tests() -> None:
5523+
return None
5524+
5525+
5526+
@dataclass
5527+
class LockOrSemaphoreWorkflowConcurrencySummary:
5528+
ever_in_critical_section: int
5529+
peak_in_critical_section: int
5530+
5531+
5532+
@dataclass
5533+
class UseLockOrSemaphoreWorkflowParameters:
5534+
n_coroutines: int = 0
5535+
semaphore_initial_value: Optional[int] = None
5536+
sleep: Optional[float] = None
5537+
timeout: Optional[float] = None
5538+
5539+
5540+
@workflow.defn
5541+
class CoroutinesUseLockWorkflow:
5542+
def __init__(self) -> None:
5543+
self.params: UseLockOrSemaphoreWorkflowParameters
5544+
self.lock_or_semaphore: Union[asyncio.Lock, asyncio.Semaphore]
5545+
self._currently_in_critical_section: set[str] = set()
5546+
self._ever_in_critical_section: set[str] = set()
5547+
self._peak_in_critical_section = 0
5548+
5549+
def init(self, params: UseLockOrSemaphoreWorkflowParameters):
5550+
self.params = params
5551+
if self.params.semaphore_initial_value is not None:
5552+
self.lock_or_semaphore = asyncio.Semaphore(
5553+
self.params.semaphore_initial_value
5554+
)
5555+
else:
5556+
self.lock_or_semaphore = asyncio.Lock()
5557+
5558+
@workflow.run
5559+
async def run(
5560+
self,
5561+
params: UseLockOrSemaphoreWorkflowParameters,
5562+
) -> LockOrSemaphoreWorkflowConcurrencySummary:
5563+
# TODO: Use workflow init method when it exists.
5564+
self.init(params)
5565+
await asyncio.gather(
5566+
*(self.coroutine(f"{i}") for i in range(self.params.n_coroutines))
5567+
)
5568+
assert not any(self._currently_in_critical_section)
5569+
return LockOrSemaphoreWorkflowConcurrencySummary(
5570+
len(self._ever_in_critical_section),
5571+
self._peak_in_critical_section,
5572+
)
5573+
5574+
async def coroutine(self, id: str):
5575+
if self.params.timeout:
5576+
try:
5577+
await asyncio.wait_for(
5578+
self.lock_or_semaphore.acquire(), self.params.timeout
5579+
)
5580+
except asyncio.TimeoutError:
5581+
return
5582+
else:
5583+
await self.lock_or_semaphore.acquire()
5584+
self._enters_critical_section(id)
5585+
try:
5586+
if self.params.sleep:
5587+
await asyncio.sleep(self.params.sleep)
5588+
else:
5589+
await workflow.execute_activity(
5590+
noop_activity_for_lock_or_semaphore_tests,
5591+
schedule_to_close_timeout=timedelta(seconds=30),
5592+
)
5593+
finally:
5594+
self.lock_or_semaphore.release()
5595+
self._exits_critical_section(id)
5596+
5597+
def _enters_critical_section(self, id: str) -> None:
5598+
self._currently_in_critical_section.add(id)
5599+
self._ever_in_critical_section.add(id)
5600+
self._peak_in_critical_section = max(
5601+
self._peak_in_critical_section,
5602+
len(self._currently_in_critical_section),
5603+
)
5604+
5605+
def _exits_critical_section(self, id: str) -> None:
5606+
self._currently_in_critical_section.remove(id)
5607+
5608+
5609+
@workflow.defn
5610+
class HandlerCoroutinesUseLockWorkflow(CoroutinesUseLockWorkflow):
5611+
def __init__(self) -> None:
5612+
super().__init__()
5613+
self.workflow_may_exit = False
5614+
5615+
@workflow.run
5616+
async def run(
5617+
self,
5618+
) -> LockOrSemaphoreWorkflowConcurrencySummary:
5619+
await workflow.wait_condition(lambda: self.workflow_may_exit)
5620+
return LockOrSemaphoreWorkflowConcurrencySummary(
5621+
len(self._ever_in_critical_section),
5622+
self._peak_in_critical_section,
5623+
)
5624+
5625+
@workflow.update
5626+
async def my_update(self, params: UseLockOrSemaphoreWorkflowParameters):
5627+
# TODO: Use workflow init method when it exists.
5628+
if not hasattr(self, "params"):
5629+
self.init(params)
5630+
assert (update_info := workflow.current_update_info())
5631+
await self.coroutine(update_info.id)
5632+
5633+
@workflow.signal
5634+
async def finish(self):
5635+
self.workflow_may_exit = True
5636+
5637+
5638+
async def _do_workflow_coroutines_lock_or_semaphore_test(
5639+
client: Client,
5640+
params: UseLockOrSemaphoreWorkflowParameters,
5641+
expectation: LockOrSemaphoreWorkflowConcurrencySummary,
5642+
):
5643+
async with new_worker(
5644+
client,
5645+
CoroutinesUseLockWorkflow,
5646+
activities=[noop_activity_for_lock_or_semaphore_tests],
5647+
) as worker:
5648+
summary = await client.execute_workflow(
5649+
CoroutinesUseLockWorkflow.run,
5650+
arg=params,
5651+
id=str(uuid.uuid4()),
5652+
task_queue=worker.task_queue,
5653+
)
5654+
assert summary == expectation
5655+
5656+
5657+
async def _do_update_handler_lock_or_semaphore_test(
5658+
client: Client,
5659+
env: WorkflowEnvironment,
5660+
params: UseLockOrSemaphoreWorkflowParameters,
5661+
n_updates: int,
5662+
expectation: LockOrSemaphoreWorkflowConcurrencySummary,
5663+
):
5664+
if env.supports_time_skipping:
5665+
pytest.skip(
5666+
"Java test server: https://github.com/temporalio/sdk-java/issues/1903"
5667+
)
5668+
5669+
task_queue = "tq"
5670+
handle = await client.start_workflow(
5671+
HandlerCoroutinesUseLockWorkflow.run,
5672+
id=f"wf-{str(uuid.uuid4())}",
5673+
task_queue=task_queue,
5674+
)
5675+
# Create updates in Admitted state, before the worker starts polling.
5676+
admitted_updates = [
5677+
await admitted_update_task(
5678+
client,
5679+
handle,
5680+
HandlerCoroutinesUseLockWorkflow.my_update,
5681+
arg=params,
5682+
id=f"update-{i}",
5683+
)
5684+
for i in range(n_updates)
5685+
]
5686+
async with new_worker(
5687+
client,
5688+
HandlerCoroutinesUseLockWorkflow,
5689+
activities=[noop_activity_for_lock_or_semaphore_tests],
5690+
task_queue=task_queue,
5691+
):
5692+
for update_task in admitted_updates:
5693+
await update_task
5694+
await handle.signal(HandlerCoroutinesUseLockWorkflow.finish)
5695+
summary = await handle.result()
5696+
assert summary == expectation
5697+
5698+
5699+
async def test_workflow_coroutines_can_use_lock(client: Client):
5700+
await _do_workflow_coroutines_lock_or_semaphore_test(
5701+
client,
5702+
UseLockOrSemaphoreWorkflowParameters(n_coroutines=5),
5703+
# The lock limits concurrency to 1
5704+
expectation=LockOrSemaphoreWorkflowConcurrencySummary(
5705+
ever_in_critical_section=5, peak_in_critical_section=1
5706+
),
5707+
)
5708+
5709+
5710+
async def test_update_handler_can_use_lock_to_serialize_handler_executions(
5711+
client: Client, env: WorkflowEnvironment
5712+
):
5713+
await _do_update_handler_lock_or_semaphore_test(
5714+
client,
5715+
env,
5716+
UseLockOrSemaphoreWorkflowParameters(),
5717+
n_updates=5,
5718+
# The lock limits concurrency to 1
5719+
expectation=LockOrSemaphoreWorkflowConcurrencySummary(
5720+
ever_in_critical_section=5, peak_in_critical_section=1
5721+
),
5722+
)
5723+
5724+
5725+
async def test_workflow_coroutines_lock_acquisition_respects_timeout(client: Client):
5726+
await _do_workflow_coroutines_lock_or_semaphore_test(
5727+
client,
5728+
UseLockOrSemaphoreWorkflowParameters(n_coroutines=5, sleep=0.5, timeout=0.1),
5729+
# Second and subsequent coroutines fail to acquire the lock due to the timeout.
5730+
expectation=LockOrSemaphoreWorkflowConcurrencySummary(
5731+
ever_in_critical_section=1, peak_in_critical_section=1
5732+
),
5733+
)
5734+
5735+
5736+
async def test_update_handler_lock_acquisition_respects_timeout(
5737+
client: Client, env: WorkflowEnvironment
5738+
):
5739+
await _do_update_handler_lock_or_semaphore_test(
5740+
client,
5741+
env,
5742+
# Second and subsequent handler executions fail to acquire the lock due to the timeout.
5743+
UseLockOrSemaphoreWorkflowParameters(sleep=0.5, timeout=0.1),
5744+
n_updates=5,
5745+
expectation=LockOrSemaphoreWorkflowConcurrencySummary(
5746+
ever_in_critical_section=1, peak_in_critical_section=1
5747+
),
5748+
)
5749+
5750+
5751+
async def test_workflow_coroutines_can_use_semaphore(client: Client):
5752+
await _do_workflow_coroutines_lock_or_semaphore_test(
5753+
client,
5754+
UseLockOrSemaphoreWorkflowParameters(n_coroutines=5, semaphore_initial_value=3),
5755+
# The semaphore limits concurrency to 3
5756+
expectation=LockOrSemaphoreWorkflowConcurrencySummary(
5757+
ever_in_critical_section=5, peak_in_critical_section=3
5758+
),
5759+
)
5760+
5761+
5762+
async def test_update_handler_can_use_semaphore_to_control_handler_execution_concurrency(
5763+
client: Client, env: WorkflowEnvironment
5764+
):
5765+
await _do_update_handler_lock_or_semaphore_test(
5766+
client,
5767+
env,
5768+
# The semaphore limits concurrency to 3
5769+
UseLockOrSemaphoreWorkflowParameters(semaphore_initial_value=3),
5770+
n_updates=5,
5771+
expectation=LockOrSemaphoreWorkflowConcurrencySummary(
5772+
ever_in_critical_section=5, peak_in_critical_section=3
5773+
),
5774+
)
5775+
5776+
5777+
async def test_workflow_coroutine_semaphore_acquisition_respects_timeout(
5778+
client: Client,
5779+
):
5780+
await _do_workflow_coroutines_lock_or_semaphore_test(
5781+
client,
5782+
UseLockOrSemaphoreWorkflowParameters(
5783+
n_coroutines=5, semaphore_initial_value=3, sleep=0.5, timeout=0.1
5784+
),
5785+
# Initial entry to the semaphore succeeds, but all subsequent attempts to acquire a semaphore
5786+
# slot fail.
5787+
expectation=LockOrSemaphoreWorkflowConcurrencySummary(
5788+
ever_in_critical_section=3, peak_in_critical_section=3
5789+
),
5790+
)
5791+
5792+
5793+
async def test_update_handler_semaphore_acquisition_respects_timeout(
5794+
client: Client, env: WorkflowEnvironment
5795+
):
5796+
await _do_update_handler_lock_or_semaphore_test(
5797+
client,
5798+
env,
5799+
# Initial entry to the semaphore succeeds, but all subsequent attempts to acquire a semaphore
5800+
# slot fail.
5801+
UseLockOrSemaphoreWorkflowParameters(
5802+
semaphore_initial_value=3,
5803+
sleep=0.5,
5804+
timeout=0.1,
5805+
),
5806+
n_updates=5,
5807+
expectation=LockOrSemaphoreWorkflowConcurrencySummary(
5808+
ever_in_critical_section=3, peak_in_critical_section=3
5809+
),
5810+
)

0 commit comments

Comments
 (0)