diff --git a/temporalio/client.py b/temporalio/client.py index 8f9e5095d..2289603a8 100644 --- a/temporalio/client.py +++ b/temporalio/client.py @@ -4047,6 +4047,7 @@ async def _to_proto( temporalio.converter.encode_search_attributes( untyped_not_in_typed, action.start_workflow.search_attributes ) + # TODO (dan): confirm whether this be `is not None` if self.typed_search_attributes: temporalio.converter.encode_search_attributes( self.typed_search_attributes, action.start_workflow.search_attributes @@ -4499,6 +4500,9 @@ class ScheduleUpdate: schedule: Schedule """Schedule to update.""" + search_attributes: Optional[temporalio.common.TypedSearchAttributes] = None + """Search attributes to update.""" + @dataclass class ScheduleListDescription: @@ -6520,14 +6524,20 @@ async def update_schedule(self, input: UpdateScheduleInput) -> None: if not update: return assert isinstance(update, ScheduleUpdate) + request = temporalio.api.workflowservice.v1.UpdateScheduleRequest( + namespace=self._client.namespace, + schedule_id=input.id, + schedule=await update.schedule._to_proto(self._client), + identity=self._client.identity, + request_id=str(uuid.uuid4()), + ) + if update.search_attributes is not None: + request.search_attributes.indexed_fields.clear() # Ensure that we at least create an empty map + temporalio.converter.encode_search_attributes( + update.search_attributes, request.search_attributes + ) await self._client.workflow_service.update_schedule( - temporalio.api.workflowservice.v1.UpdateScheduleRequest( - namespace=self._client.namespace, - schedule_id=input.id, - schedule=await update.schedule._to_proto(self._client), - identity=self._client.identity, - request_id=str(uuid.uuid4()), - ), + request, retry=True, metadata=input.rpc_metadata, timeout=input.rpc_timeout, diff --git a/tests/test_client.py b/tests/test_client.py index dc0128f38..43ec631af 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1188,7 +1188,7 @@ async def test_schedule_create_limited_actions_validation( assert "are remaining actions set" in str(err.value) -async def test_schedule_search_attribute_update( +async def test_schedule_workflow_search_attribute_update( client: Client, env: WorkflowEnvironment ): if env.supports_time_skipping: @@ -1196,10 +1196,8 @@ async def test_schedule_search_attribute_update( await assert_no_schedules(client) # Put search attribute on server - text_attr_key = SearchAttributeKey.for_text(f"python-test-schedule-text") - untyped_keyword_key = SearchAttributeKey.for_keyword( - f"python-test-schedule-keyword" - ) + text_attr_key = SearchAttributeKey.for_text("python-test-schedule-text") + untyped_keyword_key = SearchAttributeKey.for_keyword("python-test-schedule-keyword") await ensure_search_attributes_present(client, text_attr_key, untyped_keyword_key) # Create a schedule with search attributes on the schedule and on the @@ -1273,6 +1271,7 @@ def update_schedule_typed_attrs( # Check that it changed desc = await handle.describe() assert isinstance(desc.schedule.action, ScheduleActionStartWorkflow) + # Check that the workflow search attributes were changed # This assertion has changed since server 1.24. Now, even untyped search # attributes are given a type server side assert ( @@ -1283,6 +1282,148 @@ def update_schedule_typed_attrs( and desc.schedule.action.typed_search_attributes[untyped_keyword_key] == "some-untyped-attr1" ) + # Check that the schedule search attributes were not changed + assert desc.search_attributes[text_attr_key.name] == ["some-schedule-attr1"] + assert desc.typed_search_attributes[text_attr_key] == "some-schedule-attr1" + + await handle.delete() + await assert_no_schedules(client) + + +@pytest.mark.parametrize( + "test_case", + [ + "none-is-noop", + "empty-but-non-none-clears", + "all-new-values-overwrites", + "partial-new-values-overwrites-and-drops", + ], +) +async def test_schedule_search_attribute_update( + client: Client, env: WorkflowEnvironment, test_case: str +): + if env.supports_time_skipping: + pytest.skip("Java test server doesn't support schedules") + await assert_no_schedules(client) + + # Put search attributes on server + key_1 = SearchAttributeKey.for_text("python-test-schedule-sa-update-key-1") + key_2 = SearchAttributeKey.for_keyword("python-test-schedule-sa-update-key-2") + await ensure_search_attributes_present(client, key_1, key_2) + val_1 = "val-1" + val_2 = "val-2" + + # Create a schedule with search attributes + create_action = ScheduleActionStartWorkflow( + "some workflow", + [], + id=f"workflow-{uuid.uuid4()}", + task_queue=f"tq-{uuid.uuid4()}", + ) + handle = await client.create_schedule( + f"schedule-{uuid.uuid4()}", + Schedule(action=create_action, spec=ScheduleSpec()), + search_attributes=TypedSearchAttributes( + [ + SearchAttributePair(key_1, val_1), + SearchAttributePair(key_2, val_2), + ] + ), + ) + + def update_search_attributes( + input: ScheduleUpdateInput, + ) -> Optional[ScheduleUpdate]: + # Make sure the initial search attributes are present + assert input.description.search_attributes[key_1.name] == [val_1] + assert input.description.search_attributes[key_2.name] == [val_2] + assert input.description.typed_search_attributes[key_1] == val_1 + assert input.description.typed_search_attributes[key_2] == val_2 + + if test_case == "none-is-noop": + # Passing None makes no changes + return ScheduleUpdate(input.description.schedule, search_attributes=None) + elif test_case == "empty-but-non-none-clears": + # Pass empty but non-None to clear all attributes + return ScheduleUpdate( + input.description.schedule, + search_attributes=TypedSearchAttributes.empty, + ) + elif test_case == "all-new-values-overwrites": + # Pass all new values to overwrite existing + return ScheduleUpdate( + input.description.schedule, + search_attributes=input.description.typed_search_attributes.updated( + SearchAttributePair(key_1, val_1 + "-new"), + SearchAttributePair(key_2, val_2 + "-new"), + ), + ) + elif test_case == "partial-new-values-overwrites-and-drops": + # Only update key_1, which should drop key_2 + return ScheduleUpdate( + input.description.schedule, + search_attributes=TypedSearchAttributes( + [ + SearchAttributePair(key_1, val_1 + "-new"), + ] + ), + ) + else: + raise ValueError(f"Invalid test case: {test_case}") + + await handle.update(update_search_attributes) + + if test_case == "none-is-noop": + + async def expectation() -> bool: + desc = await handle.describe() + return ( + desc.search_attributes[key_1.name] == [val_1] + and desc.search_attributes[key_2.name] == [val_2] + and desc.typed_search_attributes[key_1] == val_1 + and desc.typed_search_attributes[key_2] == val_2 + ) + + await assert_eq_eventually(True, expectation) + elif test_case == "empty-but-non-none-clears": + + async def expectation() -> bool: + desc = await handle.describe() + return ( + len(desc.typed_search_attributes) == 0 + and len(desc.search_attributes) == 0 + ) + + await assert_eq_eventually(True, expectation) + elif test_case == "all-new-values-overwrites": + + async def expectation() -> bool: + desc = await handle.describe() + return ( + desc.search_attributes[key_1.name] == [val_1 + "-new"] + and desc.search_attributes[key_2.name] == [val_2 + "-new"] + and desc.typed_search_attributes[key_1] == val_1 + "-new" + and desc.typed_search_attributes[key_2] == val_2 + "-new" + ) + + await assert_eq_eventually(True, expectation) + elif test_case == "partial-new-values-overwrites-and-drops": + + async def expectation() -> bool: + desc = await handle.describe() + return ( + desc.search_attributes[key_1.name] == [val_1 + "-new"] + and desc.typed_search_attributes[key_1] == val_1 + "-new" + and key_2.name not in desc.search_attributes + and key_2 not in desc.typed_search_attributes + ) + + await assert_eq_eventually(True, expectation) + else: + raise ValueError(f"Invalid test case: {test_case}") + + await handle.delete() + await assert_no_schedules(client) async def assert_no_schedules(client: Client) -> None: diff --git a/tests/worker/test_workflow.py b/tests/worker/test_workflow.py index 52ab83474..665a5393e 100644 --- a/tests/worker/test_workflow.py +++ b/tests/worker/test_workflow.py @@ -38,6 +38,7 @@ from typing_extensions import Literal, Protocol, runtime_checkable import temporalio.worker +import temporalio.workflow from temporalio import activity, workflow from temporalio.api.common.v1 import Payload, Payloads, WorkflowExecution from temporalio.api.enums.v1 import EventType @@ -5040,61 +5041,58 @@ async def run_scenario( update_scenario=scenario, ) - # Run all tasks concurrently - await asyncio.gather( # When unconfigured completely, confirm task fails as normal - run_scenario( + await run_scenario( FailureTypesUnconfiguredWorkflow, FailureTypesScenario.THROW_CUSTOM_EXCEPTION, expect_task_fail=True, - ), - run_scenario( + ) + await run_scenario( FailureTypesUnconfiguredWorkflow, FailureTypesScenario.CAUSE_NON_DETERMINISM, expect_task_fail=True, - ), + ) # When configured at the worker level explicitly, confirm not task fail # but rather expected exceptions - run_scenario( + await run_scenario( FailureTypesUnconfiguredWorkflow, FailureTypesScenario.THROW_CUSTOM_EXCEPTION, worker_level_failure_exception_type=FailureTypesCustomException, - ), - run_scenario( + ) + await run_scenario( FailureTypesUnconfiguredWorkflow, FailureTypesScenario.CAUSE_NON_DETERMINISM, - worker_level_failure_exception_type=workflow.NondeterminismError, - ), + worker_level_failure_exception_type=temporalio.workflow.NondeterminismError, + ) # When configured at the worker level inherited - run_scenario( + await run_scenario( FailureTypesUnconfiguredWorkflow, FailureTypesScenario.THROW_CUSTOM_EXCEPTION, worker_level_failure_exception_type=Exception, - ), - run_scenario( + ) + await run_scenario( FailureTypesUnconfiguredWorkflow, FailureTypesScenario.CAUSE_NON_DETERMINISM, worker_level_failure_exception_type=Exception, - ), + ) # When configured at the workflow level explicitly - run_scenario( + await run_scenario( FailureTypesConfiguredExplicitlyWorkflow, FailureTypesScenario.THROW_CUSTOM_EXCEPTION, - ), - run_scenario( + ) + await run_scenario( FailureTypesConfiguredExplicitlyWorkflow, FailureTypesScenario.CAUSE_NON_DETERMINISM, - ), + ) # When configured at the workflow level inherited - run_scenario( + await run_scenario( FailureTypesConfiguredInheritedWorkflow, FailureTypesScenario.THROW_CUSTOM_EXCEPTION, - ), - run_scenario( + ) + await run_scenario( FailureTypesConfiguredInheritedWorkflow, FailureTypesScenario.CAUSE_NON_DETERMINISM, - ), - ) + ) @workflow.defn(failure_exception_types=[Exception])