Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions src/aws_durable_execution_sdk_python/concurrency/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def resubmitter(executable_with_state: ExecutableWithState) -> None:
ThreadPoolExecutor(max_workers=max_workers) as thread_executor,
):

def submit_task(executable_with_state: ExecutableWithState) -> None:
def submit_task(executable_with_state: ExecutableWithState) -> Future:
"""Submit task to the thread executor and mark its state as started."""
future = thread_executor.submit(
self._execute_item_in_child_context,
Expand All @@ -216,14 +216,21 @@ def on_done(future: Future) -> None:
self._on_task_complete(executable_with_state, future, scheduler)

future.add_done_callback(on_done)
return future

# Submit initial tasks
for exe_state in self.executables_with_state:
submit_task(exe_state)
futures = [
submit_task(exe_state) for exe_state in self.executables_with_state
]

# Wait for completion
self._completion_event.wait()

# Cancel remaining futures so
# that we don't wait for them to join.
for future in futures:
future.cancel()

# Suspend execution if everything done and at least one of the tasks raised a suspend exception.
if self._suspend_exception:
raise self._suspend_exception
Expand Down Expand Up @@ -275,6 +282,11 @@ def _on_task_complete(
scheduler: TimerScheduler,
) -> None:
"""Handle task completion, suspension, or failure."""

if future.cancelled():
exe_state.suspend()
return

try:
result = future.result()
exe_state.complete(result)
Expand Down
47 changes: 45 additions & 2 deletions tests/concurrency_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -968,6 +968,7 @@ def execute_item(self, child_context, executable):
exe_state = ExecutableWithState(executables[0])
future = Mock()
future.result.side_effect = TimedSuspendExecution("test message", time.time() + 1)
future.cancelled.return_value = False

scheduler = Mock()
scheduler.schedule_resume = Mock()
Expand Down Expand Up @@ -1040,6 +1041,7 @@ def execute_item(self, child_context, executable):
exe_state = ExecutableWithState(executables[0])
future = Mock()
future.result.side_effect = ValueError("Test error")
future.cancelled.return_value = False

scheduler = Mock()

Expand All @@ -1049,14 +1051,16 @@ def execute_item(self, child_context, executable):
assert isinstance(exe_state.error, ValueError)


def test_concurrent_executor_create_result_with_failed_branches():
def test_concurrent_executor_create_result_with_early_exit():
"""Test ConcurrentExecutor with failed branches using public execute method."""

class TestExecutor(ConcurrentExecutor):
def execute_item(self, child_context, executable):
if executable.index == 0:
return f"result_{executable.index}"
msg = "Test error"
# giving space to terminate early with
time.sleep(0.5)
raise ValueError(msg)

def success_callable():
Expand All @@ -1067,7 +1071,8 @@ def failure_callable():

executables = [Executable(0, success_callable), Executable(1, failure_callable)]
completion_config = CompletionConfig(
min_successful=1,
# setting min successful to None to execute all children and avoid early stopping
min_successful=None,
tolerated_failure_count=None,
tolerated_failure_percentage=None,
)
Expand Down Expand Up @@ -2795,3 +2800,41 @@ def execute_item(self, child_context, executable):
statuses = {item.index: item.status for item in result.all}
assert statuses[0] == BatchItemStatus.FAILED
assert statuses[1] == BatchItemStatus.SUCCEEDED


def test_executor_terminates_quickly_when_impossible_to_succeed():
"""Test that executor terminates when min_successful becomes impossible."""
executed_count = {"value": 0}

def task_func(ctx, item, idx, items):
executed_count["value"] += 1
if idx < 2:
raise Exception(f"fail_{idx}") # noqa EM102 TRY002
time.sleep(0.05)
return f"ok_{idx}"

items = list(range(100))
config = MapConfig(
max_concurrency=10, completion_config=CompletionConfig(min_successful=99)
)

executor = MapExecutor.from_items(items=items, func=task_func, config=config)

execution_state = Mock()
execution_state.create_checkpoint = Mock()
executor_context = Mock()
executor_context._create_step_id_for_logical_step = lambda *args: "1" # noqa SLF001
executor_context.create_child_context = lambda *args: Mock()

result = executor.execute(execution_state, executor_context)

# With concurrency=1, only 2 tasks should execute before terminating
# min_successful(99) + failure_count(2) = 101 > total_tasks(100)
assert executed_count["value"] < 100
assert (
result.completion_reason == CompletionReason.FAILURE_TOLERANCE_EXCEEDED
), executed_count
assert sum(1 for item in result.all if item.status == BatchItemStatus.FAILED) == 2
assert (
sum(1 for item in result.all if item.status == BatchItemStatus.SUCCEEDED) < 98
)