diff --git a/durabletask/internal/helpers.py b/durabletask/internal/helpers.py index 6b36586..48ab14b 100644 --- a/durabletask/internal/helpers.py +++ b/durabletask/internal/helpers.py @@ -178,11 +178,16 @@ def new_create_timer_action(id: int, fire_at: datetime) -> pb.OrchestratorAction return pb.OrchestratorAction(id=id, createTimer=pb.CreateTimerAction(fireAt=timestamp)) -def new_schedule_task_action(id: int, name: str, encoded_input: Optional[str]) -> pb.OrchestratorAction: - return pb.OrchestratorAction(id=id, scheduleTask=pb.ScheduleTaskAction( - name=name, - input=get_string_value(encoded_input) - )) +def new_schedule_task_action(id: int, name: str, encoded_input: Optional[str], router: Optional[pb.TaskRouter] = None) -> pb.OrchestratorAction: + return pb.OrchestratorAction( + id=id, + scheduleTask=pb.ScheduleTaskAction( + name=name, + input=get_string_value(encoded_input), + router=router, + ), + router=router, + ) def new_timestamp(dt: datetime) -> timestamp_pb2.Timestamp: @@ -195,12 +200,18 @@ def new_create_sub_orchestration_action( id: int, name: str, instance_id: Optional[str], - encoded_input: Optional[str]) -> pb.OrchestratorAction: - return pb.OrchestratorAction(id=id, createSubOrchestration=pb.CreateSubOrchestrationAction( - name=name, - instanceId=instance_id, - input=get_string_value(encoded_input) - )) + encoded_input: Optional[str], + router: Optional[pb.TaskRouter] = None) -> pb.OrchestratorAction: + return pb.OrchestratorAction( + id=id, + createSubOrchestration=pb.CreateSubOrchestrationAction( + name=name, + instanceId=instance_id, + input=get_string_value(encoded_input), + router=router, + ), + router=router, + ) def is_empty(v: wrappers_pb2.StringValue): diff --git a/durabletask/task.py b/durabletask/task.py index d319bf2..29af2c5 100644 --- a/durabletask/task.py +++ b/durabletask/task.py @@ -100,7 +100,8 @@ def create_timer(self, fire_at: Union[datetime, timedelta]) -> Task: @abstractmethod def call_activity(self, activity: Union[Activity[TInput, TOutput], str], *, input: Optional[TInput] = None, - retry_policy: Optional[RetryPolicy] = None) -> Task[TOutput]: + retry_policy: Optional[RetryPolicy] = None, + app_id: Optional[str] = None) -> Task[TOutput]: """Schedule an activity for execution. Parameters @@ -111,6 +112,8 @@ def call_activity(self, activity: Union[Activity[TInput, TOutput], str], *, The JSON-serializable input (or None) to pass to the activity. retry_policy: Optional[RetryPolicy] The retry policy to use for this activity call. + app_id: Optional[str] + The app ID that will execute the activity. If not specified, the activity will be executed by the same app as the orchestrator. Returns ------- @@ -123,7 +126,8 @@ def call_activity(self, activity: Union[Activity[TInput, TOutput], str], *, def call_sub_orchestrator(self, orchestrator: Orchestrator[TInput, TOutput], *, input: Optional[TInput] = None, instance_id: Optional[str] = None, - retry_policy: Optional[RetryPolicy] = None) -> Task[TOutput]: + retry_policy: Optional[RetryPolicy] = None, + app_id: Optional[str] = None) -> Task[TOutput]: """Schedule sub-orchestrator function for execution. Parameters @@ -137,6 +141,8 @@ def call_sub_orchestrator(self, orchestrator: Orchestrator[TInput, TOutput], *, random UUID will be used. retry_policy: Optional[RetryPolicy] The retry policy to use for this sub-orchestrator call. + app_id: Optional[str] + The app ID that will execute the sub-orchestrator. If not specified, the sub-orchestrator will be executed by the same app as the orchestrator. Returns ------- diff --git a/durabletask/worker.py b/durabletask/worker.py index 0922567..7a04649 100644 --- a/durabletask/worker.py +++ b/durabletask/worker.py @@ -584,6 +584,7 @@ def __init__(self, instance_id: str): self._sequence_number = 0 self._current_utc_datetime = datetime(1000, 1, 1) self._instance_id = instance_id + self._app_id = None self._completion_status: Optional[pb.OrchestrationStatus] = None self._received_events: dict[str, list[Any]] = {} self._pending_events: dict[str, list[task.CompletableTask]] = {} @@ -705,6 +706,10 @@ def next_sequence_number(self) -> int: self._sequence_number += 1 return self._sequence_number + @property + def app_id(self) -> str: + return self._app_id + @property def instance_id(self) -> str: return self._instance_id @@ -752,24 +757,29 @@ def call_activity( *, input: Optional[TInput] = None, retry_policy: Optional[task.RetryPolicy] = None, + app_id: Optional[str] = None, ) -> task.Task[TOutput]: id = self.next_sequence_number() self.call_activity_function_helper( - id, activity, input=input, retry_policy=retry_policy, is_sub_orch=False + id, activity, input=input, retry_policy=retry_policy, is_sub_orch=False, app_id=app_id ) return self._pending_tasks.get(id, task.CompletableTask()) def call_sub_orchestrator( self, - orchestrator: task.Orchestrator[TInput, TOutput], + orchestrator: Union[task.Orchestrator[TInput, TOutput], str], *, input: Optional[TInput] = None, instance_id: Optional[str] = None, retry_policy: Optional[task.RetryPolicy] = None, + app_id: Optional[str] = None, ) -> task.Task[TOutput]: id = self.next_sequence_number() - orchestrator_name = task.get_name(orchestrator) + if isinstance(orchestrator, str): + orchestrator_name = orchestrator + else: + orchestrator_name = task.get_name(orchestrator) self.call_activity_function_helper( id, orchestrator_name, @@ -777,6 +787,7 @@ def call_sub_orchestrator( retry_policy=retry_policy, is_sub_orch=True, instance_id=instance_id, + app_id=app_id, ) return self._pending_tasks.get(id, task.CompletableTask()) @@ -790,10 +801,16 @@ def call_activity_function_helper( is_sub_orch: bool = False, instance_id: Optional[str] = None, fn_task: Optional[task.CompletableTask[TOutput]] = None, + app_id: Optional[str] = None, ): if id is None: id = self.next_sequence_number() + router = pb.TaskRouter() + router.sourceAppID = self._app_id + if app_id is not None: + router.targetAppID = app_id + if fn_task is None: encoded_input = shared.to_json(input) if input is not None else None else: @@ -806,7 +823,7 @@ def call_activity_function_helper( if isinstance(activity_function, str) else task.get_name(activity_function) ) - action = ph.new_schedule_task_action(id, name, encoded_input) + action = ph.new_schedule_task_action(id, name, encoded_input, router) else: if instance_id is None: # Create a deteministic instance ID based on the parent instance ID @@ -814,7 +831,7 @@ def call_activity_function_helper( if not isinstance(activity_function, str): raise ValueError("Orchestrator function name must be a string") action = ph.new_create_sub_orchestration_action( - id, activity_function, instance_id, encoded_input + id, activity_function, instance_id, encoded_input, router ) self._pending_actions[id] = action @@ -953,6 +970,11 @@ def process_event( if event.HasField("orchestratorStarted"): ctx.current_utc_datetime = event.timestamp.ToDatetime() elif event.HasField("executionStarted"): + if event.router.targetAppID: + ctx._app_id = event.router.targetAppID + else: + ctx._app_id = event.router.sourceAppID + # TODO: Check if we already started the orchestration fn = self._registry.get_orchestrator(event.executionStarted.name) if fn is None: @@ -1010,6 +1032,11 @@ def process_event( else: cur_task = activity_action.createSubOrchestration instance_id = cur_task.instanceId + if cur_task.router and cur_task.router.targetAppID: + target_app_id = cur_task.router.targetAppID + else: + target_app_id = None + ctx.call_activity_function_helper( id=activity_action.id, activity_function=cur_task.name, @@ -1018,6 +1045,7 @@ def process_event( is_sub_orch=timer_task._retryable_parent._is_sub_orch, instance_id=instance_id, fn_task=timer_task._retryable_parent, + app_id=target_app_id, ) else: ctx.resume() diff --git a/tests/durabletask/test_orchestration_executor.py b/tests/durabletask/test_orchestration_executor.py index cb77c81..21f6c6c 100644 --- a/tests/durabletask/test_orchestration_executor.py +++ b/tests/durabletask/test_orchestration_executor.py @@ -171,6 +171,70 @@ def orchestrator(ctx: task.OrchestrationContext, orchestrator_input): assert actions[0].scheduleTask.input.value == encoded_input +def test_schedule_activity_actions_router_without_app_id(): + """Tests that scheduleTask action contains correct router fields when app_id is specified""" + def dummy_activity(ctx, _): + pass + + def orchestrator(ctx: task.OrchestrationContext, _): + yield ctx.call_activity(dummy_activity, input=42) + + registry = worker._Registry() + name = registry.add_orchestrator(orchestrator) + + # Prepare execution started event with source app set on router + exec_evt = helpers.new_execution_started_event(name, TEST_INSTANCE_ID, encoded_input=None) + exec_evt.router.sourceAppID = "source-app" + + new_events = [ + helpers.new_orchestrator_started_event(), + exec_evt, + ] + + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + result = executor.execute(TEST_INSTANCE_ID, [], new_events) + actions = result.actions + + assert len(actions) == 1 + action = actions[0] + assert action.router.sourceAppID == "source-app" + assert action.router.targetAppID == '' + assert action.scheduleTask.router.sourceAppID == "source-app" + assert action.scheduleTask.router.targetAppID == '' + + +def test_schedule_activity_actions_router_with_app_id(): + """Tests that scheduleTask action contains correct router fields when app_id is specified""" + def dummy_activity(ctx, _): + pass + + def orchestrator(ctx: task.OrchestrationContext, _): + yield ctx.call_activity(dummy_activity, input=42, app_id="target-app") + + registry = worker._Registry() + name = registry.add_orchestrator(orchestrator) + + # Prepare execution started event with source app set on router + exec_evt = helpers.new_execution_started_event(name, TEST_INSTANCE_ID, encoded_input=None) + exec_evt.router.sourceAppID = "source-app" + + new_events = [ + helpers.new_orchestrator_started_event(), + exec_evt, + ] + + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + result = executor.execute(TEST_INSTANCE_ID, [], new_events) + actions = result.actions + + assert len(actions) == 1 + action = actions[0] + assert action.router.sourceAppID == "source-app" + assert action.router.targetAppID == "target-app" + assert action.scheduleTask.router.sourceAppID == "source-app" + assert action.scheduleTask.router.targetAppID == "target-app" + + def test_activity_task_completion(): """Tests the successful completion of an activity task""" @@ -561,6 +625,70 @@ def orchestrator(ctx: task.OrchestrationContext, _): assert complete_action.result.value == "42" +def test_create_sub_orchestration_actions_router_without_app_id(): + """Tests that createSubOrchestration action contains correct router fields when app_id is specified""" + def suborchestrator(ctx: task.OrchestrationContext, _): + pass + + def orchestrator(ctx: task.OrchestrationContext, _): + yield ctx.call_sub_orchestrator(suborchestrator, input=None) + + registry = worker._Registry() + suborchestrator_name = registry.add_orchestrator(suborchestrator) + orchestrator_name = registry.add_orchestrator(orchestrator) + + exec_evt = helpers.new_execution_started_event(orchestrator_name, TEST_INSTANCE_ID, encoded_input=None) + exec_evt.router.sourceAppID = "source-app" + + new_events = [ + helpers.new_orchestrator_started_event(), + exec_evt, + ] + + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + result = executor.execute(TEST_INSTANCE_ID, [], new_events) + actions = result.actions + + assert len(actions) == 1 + action = actions[0] + assert action.router.sourceAppID == "source-app" + assert action.router.targetAppID == '' + assert action.createSubOrchestration.router.sourceAppID == "source-app" + assert action.createSubOrchestration.router.targetAppID == '' + + +def test_create_sub_orchestration_actions_router_with_app_id(): + """Tests that createSubOrchestration action contains correct router fields when app_id is specified""" + def suborchestrator(ctx: task.OrchestrationContext, _): + pass + + def orchestrator(ctx: task.OrchestrationContext, _): + yield ctx.call_sub_orchestrator(suborchestrator, input=None, app_id="target-app") + + registry = worker._Registry() + suborchestrator_name = registry.add_orchestrator(suborchestrator) + orchestrator_name = registry.add_orchestrator(orchestrator) + + exec_evt = helpers.new_execution_started_event(orchestrator_name, TEST_INSTANCE_ID, encoded_input=None) + exec_evt.router.sourceAppID = "source-app" + + new_events = [ + helpers.new_orchestrator_started_event(), + exec_evt, + ] + + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + result = executor.execute(TEST_INSTANCE_ID, [], new_events) + actions = result.actions + + assert len(actions) == 1 + action = actions[0] + assert action.router.sourceAppID == "source-app" + assert action.router.targetAppID == "target-app" + assert action.createSubOrchestration.router.sourceAppID == "source-app" + assert action.createSubOrchestration.router.targetAppID == "target-app" + + def test_sub_orchestration_task_failed(): """Tests that a sub-orchestration task is completed when the sub-orchestration fails""" def suborchestrator(ctx: task.OrchestrationContext, _):