Skip to content

Commit 57bf9e6

Browse files
authored
Minor type improvements (#239)
Fixes #236 Fixes #237 Fixes #234 Fixes #232
1 parent 22e8d92 commit 57bf9e6

File tree

7 files changed

+148
-13
lines changed

7 files changed

+148
-13
lines changed

README.md

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,9 @@ The default data converter supports converting multiple types including:
278278

279279
This notably doesn't include any `date`, `time`, or `datetime` objects as they may not work across SDKs.
280280

281+
Classes with generics may not have the generics properly resolved. The current implementation, similar to Pydantic, does
282+
not have generic type resolution. Users should use concrete types.
283+
281284
For converting from JSON, the workflow/activity type hint is taken into account to convert to the proper type. Care has
282285
been taken to support all common typings including `Optional`, `Union`, all forms of iterables and mappings, `NewType`,
283286
etc in addition to the regular JSON values mentioned before.
@@ -357,7 +360,10 @@ class GreetingWorkflow:
357360
# Wait for salutation update or complete signal (this can be
358361
# cancelled)
359362
await asyncio.wait(
360-
[self._greeting_info_update.wait(), self._complete.wait()],
363+
[
364+
asyncio.create_task(self._greeting_info_update.wait()),
365+
asyncio.create_task(self._complete.wait()),
366+
],
361367
return_when=asyncio.FIRST_COMPLETED,
362368
)
363369
if self._complete.is_set():

temporalio/client.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,7 @@ async def start_workflow(
320320
args: Sequence[Any] = [],
321321
id: str,
322322
task_queue: str,
323+
result_type: Optional[Type] = None,
323324
execution_timeout: Optional[timedelta] = None,
324325
run_timeout: Optional[timedelta] = None,
325326
task_timeout: Optional[timedelta] = None,
@@ -343,6 +344,7 @@ async def start_workflow(
343344
args: Sequence[Any] = [],
344345
id: str,
345346
task_queue: str,
347+
result_type: Optional[Type] = None,
346348
execution_timeout: Optional[timedelta] = None,
347349
run_timeout: Optional[timedelta] = None,
348350
task_timeout: Optional[timedelta] = None,
@@ -365,6 +367,8 @@ async def start_workflow(
365367
args: Multiple arguments to the workflow. Cannot be set if arg is.
366368
id: Unique identifier for the workflow execution.
367369
task_queue: Task queue to run the workflow on.
370+
result_type: For string workflows, this can set the specific result
371+
type hint to deserialize into.
368372
execution_timeout: Total workflow execution timeout including
369373
retries and continue as new.
370374
run_timeout: Timeout of a single workflow run.
@@ -390,13 +394,13 @@ async def start_workflow(
390394
"""
391395
# Use definition if callable
392396
name: str
393-
ret_type: Optional[Type] = None
394397
if isinstance(workflow, str):
395398
name = workflow
396399
elif callable(workflow):
397400
defn = temporalio.workflow._Definition.must_from_run_fn(workflow)
398401
name = defn.name
399-
ret_type = defn.ret_type
402+
if result_type is None:
403+
result_type = defn.ret_type
400404
else:
401405
raise TypeError("Workflow must be a string or callable")
402406

@@ -417,7 +421,7 @@ async def start_workflow(
417421
headers={},
418422
start_signal=start_signal,
419423
start_signal_args=start_signal_args,
420-
ret_type=ret_type,
424+
ret_type=result_type,
421425
rpc_metadata=rpc_metadata,
422426
rpc_timeout=rpc_timeout,
423427
)
@@ -506,6 +510,7 @@ async def execute_workflow(
506510
args: Sequence[Any] = [],
507511
id: str,
508512
task_queue: str,
513+
result_type: Optional[Type] = None,
509514
execution_timeout: Optional[timedelta] = None,
510515
run_timeout: Optional[timedelta] = None,
511516
task_timeout: Optional[timedelta] = None,
@@ -529,6 +534,7 @@ async def execute_workflow(
529534
args: Sequence[Any] = [],
530535
id: str,
531536
task_queue: str,
537+
result_type: Optional[Type] = None,
532538
execution_timeout: Optional[timedelta] = None,
533539
run_timeout: Optional[timedelta] = None,
534540
task_timeout: Optional[timedelta] = None,
@@ -555,6 +561,7 @@ async def execute_workflow(
555561
arg,
556562
args=args,
557563
task_queue=task_queue,
564+
result_type=result_type,
558565
id=id,
559566
execution_timeout=execution_timeout,
560567
run_timeout=run_timeout,

temporalio/common.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -222,11 +222,7 @@ def _type_hints_from_func(
222222
sig = inspect.signature(func)
223223
hints = get_type_hints(func)
224224
ret_hint = hints.get("return")
225-
ret = (
226-
ret_hint
227-
if inspect.isclass(ret_hint) and ret_hint is not inspect.Signature.empty
228-
else None
229-
)
225+
ret = ret_hint if ret_hint is not inspect.Signature.empty else None
230226
args: List[Type] = []
231227
for index, value in enumerate(sig.parameters.values()):
232228
# Ignore self on methods
@@ -244,7 +240,9 @@ def _type_hints_from_func(
244240
return (None, ret)
245241
# All params must have annotations or we consider none to have them
246242
arg_hint = hints.get(value.name)
247-
if not inspect.isclass(arg_hint) or arg_hint is inspect.Parameter.empty:
243+
if arg_hint is inspect.Parameter.empty:
248244
return (None, ret)
249-
args.append(arg_hint)
245+
# Ignoring type here because union/optional isn't really a type
246+
# necessarily
247+
args.append(arg_hint) # type: ignore
250248
return args, ret

temporalio/worker/_activity.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -447,6 +447,21 @@ async def _run_activity(
447447
temporalio.activity.logger.warning(
448448
"Completing activity as failed", exc_info=True
449449
)
450+
# In some cases, like worker shutdown of an sync activity,
451+
# this results in a CancelledError, but the server will fail
452+
# if you send a cancelled error outside of a requested
453+
# cancellation. So we wrap as a retryable application error.
454+
if isinstance(
455+
err,
456+
(asyncio.CancelledError, temporalio.exceptions.CancelledError),
457+
):
458+
new_err = temporalio.exceptions.ApplicationError(
459+
"Cancelled without request, possibly due to worker shutdown",
460+
type="CancelledError",
461+
)
462+
new_err.__traceback__ = err.__traceback__
463+
new_err.__cause__ = err.__cause__
464+
err = new_err
450465
await self._data_converter.encode_failure(
451466
err, completion.result.failed.failure
452467
)

tests/helpers/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ async def assert_eq_eventually(
3434
expected: T,
3535
fn: Callable[[], Awaitable[T]],
3636
*,
37-
timeout: timedelta = timedelta(seconds=3),
37+
timeout: timedelta = timedelta(seconds=10),
3838
interval: timedelta = timedelta(milliseconds=200),
3939
) -> None:
4040
start_sec = time.monotonic()

tests/worker/test_activity.py

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import logging.handlers
55
import multiprocessing
66
import queue
7+
import threading
78
import time
89
import uuid
910
from dataclasses import dataclass
@@ -12,13 +13,14 @@
1213

1314
import pytest
1415

15-
from temporalio import activity
16+
from temporalio import activity, workflow
1617
from temporalio.client import (
1718
AsyncActivityHandle,
1819
Client,
1920
WorkflowFailureError,
2021
WorkflowHandle,
2122
)
23+
from temporalio.common import RetryPolicy
2224
from temporalio.exceptions import (
2325
ActivityError,
2426
ApplicationError,
@@ -375,6 +377,67 @@ def wait_cancel() -> None:
375377
assert events == ["pre1", "pre2", "pre3", "post3", "post2"]
376378

377379

380+
sync_activity_waiting_cancel = threading.Event()
381+
382+
383+
@activity.defn
384+
def sync_activity_wait_cancel():
385+
sync_activity_waiting_cancel.set()
386+
while True:
387+
time.sleep(1)
388+
activity.heartbeat()
389+
390+
391+
# We don't sandbox because Python logging uses multiprocessing if it's present
392+
# which we don't want to get warnings about
393+
@workflow.defn(sandboxed=False)
394+
class CancelOnWorkerShutdownWorkflow:
395+
@workflow.run
396+
async def run(self) -> None:
397+
await workflow.execute_activity(
398+
sync_activity_wait_cancel,
399+
start_to_close_timeout=timedelta(hours=1),
400+
retry_policy=RetryPolicy(maximum_attempts=1),
401+
)
402+
403+
404+
# This test used to fail because we were sending a cancelled error and the
405+
# server doesn't allow that
406+
async def test_sync_activity_thread_cancel_on_worker_shutdown(client: Client):
407+
task_queue = f"tq-{uuid.uuid4()}"
408+
409+
def new_worker() -> Worker:
410+
return Worker(
411+
client,
412+
task_queue=task_queue,
413+
activities=[sync_activity_wait_cancel],
414+
workflows=[CancelOnWorkerShutdownWorkflow],
415+
activity_executor=executor,
416+
max_cached_workflows=0,
417+
)
418+
419+
with concurrent.futures.ThreadPoolExecutor() as executor:
420+
async with new_worker():
421+
# Start the workflow
422+
handle = await client.start_workflow(
423+
CancelOnWorkerShutdownWorkflow.run,
424+
id=f"workflow-{uuid.uuid4()}",
425+
task_queue=task_queue,
426+
)
427+
# Wait for activity to start
428+
assert await asyncio.get_running_loop().run_in_executor(
429+
executor, lambda: sync_activity_waiting_cancel.wait(20)
430+
)
431+
# Shut down the worker
432+
# Start the worker again and wait for result
433+
with pytest.raises(WorkflowFailureError) as err:
434+
async with new_worker():
435+
await handle.result()
436+
assert isinstance(err.value.cause, ActivityError)
437+
assert isinstance(err.value.cause.cause, ApplicationError)
438+
assert "due to worker shutdown" in err.value.cause.cause.message
439+
440+
378441
@activity.defn
379442
def picklable_activity_wait_cancel() -> str:
380443
while not activity.is_cancelled():

tests/worker/test_workflow.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2646,3 +2646,49 @@ async def test_workflow_custom_failure_converter(client: Client):
26462646
if not failure.HasField("cause"):
26472647
break
26482648
failure = failure.cause
2649+
2650+
2651+
@dataclass
2652+
class OptionalParam:
2653+
some_string: str
2654+
2655+
2656+
@workflow.defn
2657+
class OptionalParamWorkflow:
2658+
@workflow.run
2659+
async def run(
2660+
self, some_param: Optional[OptionalParam] = OptionalParam(some_string="default")
2661+
) -> Optional[OptionalParam]:
2662+
assert some_param is None or (
2663+
isinstance(some_param, OptionalParam)
2664+
and some_param.some_string in ["default", "foo"]
2665+
)
2666+
return some_param
2667+
2668+
2669+
async def test_workflow_optional_param(client: Client):
2670+
async with new_worker(client, OptionalParamWorkflow) as worker:
2671+
# Don't send a parameter and confirm it is defaulted
2672+
result1 = await client.execute_workflow(
2673+
"OptionalParamWorkflow",
2674+
id=f"workflow-{uuid.uuid4()}",
2675+
task_queue=worker.task_queue,
2676+
result_type=OptionalParam,
2677+
)
2678+
assert result1 == OptionalParam(some_string="default")
2679+
# Send None explicitly
2680+
result2 = await client.execute_workflow(
2681+
OptionalParamWorkflow.run,
2682+
None,
2683+
id=f"workflow-{uuid.uuid4()}",
2684+
task_queue=worker.task_queue,
2685+
)
2686+
assert result2 is None
2687+
# Send param explicitly
2688+
result3 = await client.execute_workflow(
2689+
OptionalParamWorkflow.run,
2690+
OptionalParam(some_string="foo"),
2691+
id=f"workflow-{uuid.uuid4()}",
2692+
task_queue=worker.task_queue,
2693+
)
2694+
assert result3 == OptionalParam(some_string="foo")

0 commit comments

Comments
 (0)