Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 22 additions & 11 deletions durabletask/internal/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,11 +178,16 @@ def new_create_timer_action(id: int, fire_at: datetime) -> pb.OrchestratorAction
return pb.OrchestratorAction(id=id, createTimer=pb.CreateTimerAction(fireAt=timestamp))


def new_schedule_task_action(id: int, name: str, encoded_input: Optional[str]) -> pb.OrchestratorAction:
return pb.OrchestratorAction(id=id, scheduleTask=pb.ScheduleTaskAction(
name=name,
input=get_string_value(encoded_input)
))
def new_schedule_task_action(id: int, name: str, encoded_input: Optional[str], router: Optional[pb.TaskRouter] = None) -> pb.OrchestratorAction:
return pb.OrchestratorAction(
id=id,
scheduleTask=pb.ScheduleTaskAction(
name=name,
input=get_string_value(encoded_input),
router=router,
),
router=router,
)


def new_timestamp(dt: datetime) -> timestamp_pb2.Timestamp:
Expand All @@ -195,12 +200,18 @@ def new_create_sub_orchestration_action(
id: int,
name: str,
instance_id: Optional[str],
encoded_input: Optional[str]) -> pb.OrchestratorAction:
return pb.OrchestratorAction(id=id, createSubOrchestration=pb.CreateSubOrchestrationAction(
name=name,
instanceId=instance_id,
input=get_string_value(encoded_input)
))
encoded_input: Optional[str],
router: Optional[pb.TaskRouter] = None) -> pb.OrchestratorAction:
return pb.OrchestratorAction(
id=id,
createSubOrchestration=pb.CreateSubOrchestrationAction(
name=name,
instanceId=instance_id,
input=get_string_value(encoded_input),
router=router,
),
router=router,
)


def is_empty(v: wrappers_pb2.StringValue):
Expand Down
10 changes: 8 additions & 2 deletions durabletask/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,8 @@ def create_timer(self, fire_at: Union[datetime, timedelta]) -> Task:
@abstractmethod
def call_activity(self, activity: Union[Activity[TInput, TOutput], str], *,
input: Optional[TInput] = None,
retry_policy: Optional[RetryPolicy] = None) -> Task[TOutput]:
retry_policy: Optional[RetryPolicy] = None,
app_id: Optional[str] = None) -> Task[TOutput]:
"""Schedule an activity for execution.

Parameters
Expand All @@ -111,6 +112,8 @@ def call_activity(self, activity: Union[Activity[TInput, TOutput], str], *,
The JSON-serializable input (or None) to pass to the activity.
retry_policy: Optional[RetryPolicy]
The retry policy to use for this activity call.
app_id: Optional[str]
The app ID that will execute the activity. If not specified, the activity will be executed by the same app as the orchestrator.

Returns
-------
Expand All @@ -123,7 +126,8 @@ def call_activity(self, activity: Union[Activity[TInput, TOutput], str], *,
def call_sub_orchestrator(self, orchestrator: Orchestrator[TInput, TOutput], *,
input: Optional[TInput] = None,
instance_id: Optional[str] = None,
retry_policy: Optional[RetryPolicy] = None) -> Task[TOutput]:
retry_policy: Optional[RetryPolicy] = None,
app_id: Optional[str] = None) -> Task[TOutput]:
"""Schedule sub-orchestrator function for execution.

Parameters
Expand All @@ -137,6 +141,8 @@ def call_sub_orchestrator(self, orchestrator: Orchestrator[TInput, TOutput], *,
random UUID will be used.
retry_policy: Optional[RetryPolicy]
The retry policy to use for this sub-orchestrator call.
app_id: Optional[str]
The app ID that will execute the sub-orchestrator. If not specified, the sub-orchestrator will be executed by the same app as the orchestrator.

Returns
-------
Expand Down
38 changes: 33 additions & 5 deletions durabletask/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,6 +584,7 @@ def __init__(self, instance_id: str):
self._sequence_number = 0
self._current_utc_datetime = datetime(1000, 1, 1)
self._instance_id = instance_id
self._app_id = None
self._completion_status: Optional[pb.OrchestrationStatus] = None
self._received_events: dict[str, list[Any]] = {}
self._pending_events: dict[str, list[task.CompletableTask]] = {}
Expand Down Expand Up @@ -705,6 +706,10 @@ def next_sequence_number(self) -> int:
self._sequence_number += 1
return self._sequence_number

@property
def app_id(self) -> str:
return self._app_id

@property
def instance_id(self) -> str:
return self._instance_id
Expand Down Expand Up @@ -752,31 +757,37 @@ def call_activity(
*,
input: Optional[TInput] = None,
retry_policy: Optional[task.RetryPolicy] = None,
app_id: Optional[str] = None,
) -> task.Task[TOutput]:
id = self.next_sequence_number()

self.call_activity_function_helper(
id, activity, input=input, retry_policy=retry_policy, is_sub_orch=False
id, activity, input=input, retry_policy=retry_policy, is_sub_orch=False, app_id=app_id
)
return self._pending_tasks.get(id, task.CompletableTask())

def call_sub_orchestrator(
self,
orchestrator: task.Orchestrator[TInput, TOutput],
orchestrator: Union[task.Orchestrator[TInput, TOutput], str],
*,
input: Optional[TInput] = None,
instance_id: Optional[str] = None,
retry_policy: Optional[task.RetryPolicy] = None,
app_id: Optional[str] = None,
) -> task.Task[TOutput]:
id = self.next_sequence_number()
orchestrator_name = task.get_name(orchestrator)
if isinstance(orchestrator, str):
orchestrator_name = orchestrator
else:
orchestrator_name = task.get_name(orchestrator)
self.call_activity_function_helper(
id,
orchestrator_name,
input=input,
retry_policy=retry_policy,
is_sub_orch=True,
instance_id=instance_id,
app_id=app_id,
)
return self._pending_tasks.get(id, task.CompletableTask())

Expand All @@ -790,10 +801,16 @@ def call_activity_function_helper(
is_sub_orch: bool = False,
instance_id: Optional[str] = None,
fn_task: Optional[task.CompletableTask[TOutput]] = None,
app_id: Optional[str] = None,
):
if id is None:
id = self.next_sequence_number()

router = pb.TaskRouter()
router.sourceAppID = self._app_id
if app_id is not None:
router.targetAppID = app_id

if fn_task is None:
encoded_input = shared.to_json(input) if input is not None else None
else:
Expand All @@ -806,15 +823,15 @@ def call_activity_function_helper(
if isinstance(activity_function, str)
else task.get_name(activity_function)
)
action = ph.new_schedule_task_action(id, name, encoded_input)
action = ph.new_schedule_task_action(id, name, encoded_input, router)
else:
if instance_id is None:
# Create a deteministic instance ID based on the parent instance ID
instance_id = f"{self.instance_id}:{id:04x}"
if not isinstance(activity_function, str):
raise ValueError("Orchestrator function name must be a string")
action = ph.new_create_sub_orchestration_action(
id, activity_function, instance_id, encoded_input
id, activity_function, instance_id, encoded_input, router
)
self._pending_actions[id] = action

Expand Down Expand Up @@ -953,6 +970,11 @@ def process_event(
if event.HasField("orchestratorStarted"):
ctx.current_utc_datetime = event.timestamp.ToDatetime()
elif event.HasField("executionStarted"):
if event.router.targetAppID:
ctx._app_id = event.router.targetAppID
else:
ctx._app_id = event.router.sourceAppID

# TODO: Check if we already started the orchestration
fn = self._registry.get_orchestrator(event.executionStarted.name)
if fn is None:
Expand Down Expand Up @@ -1010,6 +1032,11 @@ def process_event(
else:
cur_task = activity_action.createSubOrchestration
instance_id = cur_task.instanceId
if cur_task.router and cur_task.router.targetAppID:
target_app_id = cur_task.router.targetAppID
else:
target_app_id = None

ctx.call_activity_function_helper(
id=activity_action.id,
activity_function=cur_task.name,
Expand All @@ -1018,6 +1045,7 @@ def process_event(
is_sub_orch=timer_task._retryable_parent._is_sub_orch,
instance_id=instance_id,
fn_task=timer_task._retryable_parent,
app_id=target_app_id,
)
else:
ctx.resume()
Expand Down
128 changes: 128 additions & 0 deletions tests/durabletask/test_orchestration_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,70 @@ def orchestrator(ctx: task.OrchestrationContext, orchestrator_input):
assert actions[0].scheduleTask.input.value == encoded_input


def test_schedule_activity_actions_router_without_app_id():
"""Tests that scheduleTask action contains correct router fields when app_id is specified"""
def dummy_activity(ctx, _):
pass

def orchestrator(ctx: task.OrchestrationContext, _):
yield ctx.call_activity(dummy_activity, input=42)

registry = worker._Registry()
name = registry.add_orchestrator(orchestrator)

# Prepare execution started event with source app set on router
exec_evt = helpers.new_execution_started_event(name, TEST_INSTANCE_ID, encoded_input=None)
exec_evt.router.sourceAppID = "source-app"

new_events = [
helpers.new_orchestrator_started_event(),
exec_evt,
]

executor = worker._OrchestrationExecutor(registry, TEST_LOGGER)
result = executor.execute(TEST_INSTANCE_ID, [], new_events)
actions = result.actions

assert len(actions) == 1
action = actions[0]
assert action.router.sourceAppID == "source-app"
assert action.router.targetAppID == ''
assert action.scheduleTask.router.sourceAppID == "source-app"
assert action.scheduleTask.router.targetAppID == ''


def test_schedule_activity_actions_router_with_app_id():
"""Tests that scheduleTask action contains correct router fields when app_id is specified"""
def dummy_activity(ctx, _):
pass

def orchestrator(ctx: task.OrchestrationContext, _):
yield ctx.call_activity(dummy_activity, input=42, app_id="target-app")

registry = worker._Registry()
name = registry.add_orchestrator(orchestrator)

# Prepare execution started event with source app set on router
exec_evt = helpers.new_execution_started_event(name, TEST_INSTANCE_ID, encoded_input=None)
exec_evt.router.sourceAppID = "source-app"

new_events = [
helpers.new_orchestrator_started_event(),
exec_evt,
]

executor = worker._OrchestrationExecutor(registry, TEST_LOGGER)
result = executor.execute(TEST_INSTANCE_ID, [], new_events)
actions = result.actions

assert len(actions) == 1
action = actions[0]
assert action.router.sourceAppID == "source-app"
assert action.router.targetAppID == "target-app"
assert action.scheduleTask.router.sourceAppID == "source-app"
assert action.scheduleTask.router.targetAppID == "target-app"


def test_activity_task_completion():
"""Tests the successful completion of an activity task"""

Expand Down Expand Up @@ -561,6 +625,70 @@ def orchestrator(ctx: task.OrchestrationContext, _):
assert complete_action.result.value == "42"


def test_create_sub_orchestration_actions_router_without_app_id():
"""Tests that createSubOrchestration action contains correct router fields when app_id is specified"""
def suborchestrator(ctx: task.OrchestrationContext, _):
pass

def orchestrator(ctx: task.OrchestrationContext, _):
yield ctx.call_sub_orchestrator(suborchestrator, input=None)

registry = worker._Registry()
suborchestrator_name = registry.add_orchestrator(suborchestrator)
orchestrator_name = registry.add_orchestrator(orchestrator)

exec_evt = helpers.new_execution_started_event(orchestrator_name, TEST_INSTANCE_ID, encoded_input=None)
exec_evt.router.sourceAppID = "source-app"

new_events = [
helpers.new_orchestrator_started_event(),
exec_evt,
]

executor = worker._OrchestrationExecutor(registry, TEST_LOGGER)
result = executor.execute(TEST_INSTANCE_ID, [], new_events)
actions = result.actions

assert len(actions) == 1
action = actions[0]
assert action.router.sourceAppID == "source-app"
assert action.router.targetAppID == ''
assert action.createSubOrchestration.router.sourceAppID == "source-app"
assert action.createSubOrchestration.router.targetAppID == ''


def test_create_sub_orchestration_actions_router_with_app_id():
"""Tests that createSubOrchestration action contains correct router fields when app_id is specified"""
def suborchestrator(ctx: task.OrchestrationContext, _):
pass

def orchestrator(ctx: task.OrchestrationContext, _):
yield ctx.call_sub_orchestrator(suborchestrator, input=None, app_id="target-app")

registry = worker._Registry()
suborchestrator_name = registry.add_orchestrator(suborchestrator)
orchestrator_name = registry.add_orchestrator(orchestrator)

exec_evt = helpers.new_execution_started_event(orchestrator_name, TEST_INSTANCE_ID, encoded_input=None)
exec_evt.router.sourceAppID = "source-app"

new_events = [
helpers.new_orchestrator_started_event(),
exec_evt,
]

executor = worker._OrchestrationExecutor(registry, TEST_LOGGER)
result = executor.execute(TEST_INSTANCE_ID, [], new_events)
actions = result.actions

assert len(actions) == 1
action = actions[0]
assert action.router.sourceAppID == "source-app"
assert action.router.targetAppID == "target-app"
assert action.createSubOrchestration.router.sourceAppID == "source-app"
assert action.createSubOrchestration.router.targetAppID == "target-app"


def test_sub_orchestration_task_failed():
"""Tests that a sub-orchestration task is completed when the sub-orchestration fails"""
def suborchestrator(ctx: task.OrchestrationContext, _):
Expand Down
Loading