diff --git a/src/aws_durable_execution_sdk_python/concurrency.py b/src/aws_durable_execution_sdk_python/concurrency.py index b5af8c6..fb43c03 100644 --- a/src/aws_durable_execution_sdk_python/concurrency.py +++ b/src/aws_durable_execution_sdk_python/concurrency.py @@ -7,6 +7,7 @@ import threading import time from abc import ABC, abstractmethod +from collections import Counter from concurrent.futures import Future, ThreadPoolExecutor from dataclasses import dataclass from enum import Enum @@ -98,16 +99,69 @@ class BatchResult(Generic[R], BatchResultProtocol[R]): # noqa: PYI059 completion_reason: CompletionReason @classmethod - def from_dict(cls, data: dict) -> BatchResult[R]: + def from_dict( + cls, data: dict, completion_config: CompletionConfig | None = None + ) -> BatchResult[R]: batch_items: list[BatchItem[R]] = [ BatchItem.from_dict(item) for item in data["all"] ] - # TODO: is this valid? assuming completion reason is ALL_COMPLETED? - completion_reason = CompletionReason( - data.get("completionReason", "ALL_COMPLETED") - ) + + completion_reason_value = data.get("completionReason") + if completion_reason_value is None: + # Infer completion reason from batch item statuses and completion config + # This aligns with the TypeScript implementation that uses completion config + # to accurately reconstruct the completion reason during replay + result = cls.from_items(batch_items, completion_config) + logger.warning( + "Missing completionReason in BatchResult deserialization, " + "inferred '%s' from batch item statuses. " + "This may indicate incomplete serialization data.", + result.completion_reason.value, + ) + return result + + completion_reason = CompletionReason(completion_reason_value) return cls(batch_items, completion_reason) + @classmethod + def from_items( + cls, + items: list[BatchItem[R]], + completion_config: CompletionConfig | None = None, + ): + """ + Infer completion reason based on batch item statuses and completion config. + + This follows the same logic as the TypeScript implementation: + - If all items completed: ALL_COMPLETED + - If minSuccessful threshold met and not all completed: MIN_SUCCESSFUL_REACHED + - Otherwise: FAILURE_TOLERANCE_EXCEEDED + """ + + statuses = (item.status for item in items) + counts = Counter(statuses) + succeeded_count = counts.get(BatchItemStatus.SUCCEEDED, 0) + failed_count = counts.get(BatchItemStatus.FAILED, 0) + started_count = counts.get(BatchItemStatus.STARTED, 0) + + completed_count = succeeded_count + failed_count + total_count = started_count + completed_count + + # If all items completed (no started items), it's ALL_COMPLETED + if completed_count == total_count: + completion_reason = CompletionReason.ALL_COMPLETED + elif ( # If we have completion config and minSuccessful threshold is met + completion_config + and (min_successful := completion_config.min_successful) is not None + and succeeded_count >= min_successful + ): + completion_reason = CompletionReason.MIN_SUCCESSFUL_REACHED + else: + # Otherwise, assume failure tolerance was exceeded + completion_reason = CompletionReason.FAILURE_TOLERANCE_EXCEEDED + + return cls(items, completion_reason) + def to_dict(self) -> dict: return { "all": [item.to_dict() for item in self.all], @@ -163,19 +217,15 @@ def get_errors(self) -> list[ErrorObject]: @property def success_count(self) -> int: - return len( - [item for item in self.all if item.status is BatchItemStatus.SUCCEEDED] - ) + return sum(1 for item in self.all if item.status is BatchItemStatus.SUCCEEDED) @property def failure_count(self) -> int: - return len([item for item in self.all if item.status is BatchItemStatus.FAILED]) + return sum(1 for item in self.all if item.status is BatchItemStatus.FAILED) @property def started_count(self) -> int: - return len( - [item for item in self.all if item.status is BatchItemStatus.STARTED] - ) + return sum(1 for item in self.all if item.status is BatchItemStatus.STARTED) @property def total_count(self) -> int: @@ -336,25 +386,63 @@ def fail_task(self) -> None: with self._lock: self.failure_count += 1 - def should_complete(self) -> bool: - """Check if execution should complete.""" + def should_continue(self) -> bool: + """ + Check if we should continue starting new tasks (based on failure tolerance). + Matches TypeScript shouldContinue() logic. + """ with self._lock: - # Success condition - if self.success_count >= self.min_successful: - return True + # If no completion config, only continue if no failures + if ( + self.tolerated_failure_count is None + and self.tolerated_failure_percentage is None + ): + return self.failure_count == 0 - # Failure conditions - if self._is_failure_condition_reached( - tolerated_count=self.tolerated_failure_count, - tolerated_percentage=self.tolerated_failure_percentage, - failure_count=self.failure_count, + # Check failure count tolerance + if ( + self.tolerated_failure_count is not None + and self.failure_count > self.tolerated_failure_count ): - return True + return False - # Impossible to succeed condition - # TODO: should this keep running? TS doesn't currently handle this either. - remaining_tasks = self.total_tasks - self.success_count - self.failure_count - return self.success_count + remaining_tasks < self.min_successful + # Check failure percentage tolerance + if self.tolerated_failure_percentage is not None and self.total_tasks > 0: + failure_percentage = (self.failure_count / self.total_tasks) * 100 + if failure_percentage > self.tolerated_failure_percentage: + return False + + return True + + def is_complete(self) -> bool: + """ + Check if execution should complete (based on completion criteria). + Matches TypeScript isComplete() logic. + """ + with self._lock: + completed_count = self.success_count + self.failure_count + + # All tasks completed + if completed_count == self.total_tasks: + # Complete if no failure tolerance OR no failures OR min successful reached + return ( + ( + self.tolerated_failure_count is None + and self.tolerated_failure_percentage is None + ) + or self.failure_count == 0 + or self.success_count >= self.min_successful + ) + + # when we breach min successful, we've completed + return self.success_count >= self.min_successful + + def should_complete(self) -> bool: + """ + Check if execution should complete. + Combines TypeScript shouldContinue() and isComplete() logic. + """ + return self.is_complete() or not self.should_continue() def is_all_completed(self) -> bool: """True if all tasks completed successfully.""" @@ -640,40 +728,46 @@ def _on_task_complete( self._completion_event.set() def _create_result(self) -> BatchResult[ResultType]: - """Build the final BatchResult.""" - batch_items: list[BatchItem[ResultType]] = [] - completed_branches: list[ExecutableWithState] = [] - failed_branches: list[ExecutableWithState] = [] + """ + Build the final BatchResult. + When this function executes, we've terminated the upper/parent context for whatever reason. + It follows that our items can be only in 3 states, Completed, Failed and Started (in all of the possible forms). + We tag each branch based on its observed value at the time of completion of the parent / upper context, and pass the + results to BatchResult. + + Any inference wrt completion reason is left up to BatchResult, keeping the logic inference isolated. + """ + batch_items: list[BatchItem[ResultType]] = [] for executable in self.executables_with_state: - if executable.status is BranchStatus.COMPLETED: - completed_branches.append(executable) - batch_items.append( - BatchItem( - executable.index, BatchItemStatus.SUCCEEDED, executable.result + match executable.status: + case BranchStatus.COMPLETED: + batch_items.append( + BatchItem( + executable.index, + BatchItemStatus.SUCCEEDED, + executable.result, + ) ) - ) - elif executable.status is BranchStatus.FAILED: - failed_branches.append(executable) - batch_items.append( - BatchItem( - executable.index, - BatchItemStatus.FAILED, - error=ErrorObject.from_exception(executable.error), + case BranchStatus.FAILED: + batch_items.append( + BatchItem( + executable.index, + BatchItemStatus.FAILED, + error=ErrorObject.from_exception(executable.error), + ) + ) + case ( + BranchStatus.PENDING + | BranchStatus.RUNNING + | BranchStatus.SUSPENDED + | BranchStatus.SUSPENDED_WITH_TIMEOUT + ): + batch_items.append( + BatchItem(executable.index, BatchItemStatus.STARTED) ) - ) - - completion_reason: CompletionReason = ( - CompletionReason.ALL_COMPLETED - if self.counters.is_all_completed() - else ( - CompletionReason.MIN_SUCCESSFUL_REACHED - if self.counters.is_min_successful_reached() - else CompletionReason.FAILURE_TOLERANCE_EXCEEDED - ) - ) - return BatchResult(batch_items, completion_reason) + return BatchResult.from_items(batch_items, self.completion_config) def _execute_item_in_child_context( self, diff --git a/tests/concurrency_test.py b/tests/concurrency_test.py index 38dc9d6..d3d090b 100644 --- a/tests/concurrency_test.py +++ b/tests/concurrency_test.py @@ -319,8 +319,178 @@ def test_batch_result_from_dict_default_completion_reason(): # No completionReason provided } - result = BatchResult.from_dict(data) - assert result.completion_reason == CompletionReason.ALL_COMPLETED + with patch("aws_durable_execution_sdk_python.concurrency.logger") as mock_logger: + result = BatchResult.from_dict(data) + assert result.completion_reason == CompletionReason.ALL_COMPLETED + # Verify warning was logged + mock_logger.warning.assert_called_once() + assert "Missing completionReason" in mock_logger.warning.call_args[0][0] + + +def test_batch_result_from_dict_infer_all_completed_all_succeeded(): + """Test BatchResult from_dict infers ALL_COMPLETED when all items succeeded.""" + data = { + "all": [ + {"index": 0, "status": "SUCCEEDED", "result": "result1", "error": None}, + {"index": 1, "status": "SUCCEEDED", "result": "result2", "error": None}, + ], + # No completionReason provided + } + + with patch("aws_durable_execution_sdk_python.concurrency.logger") as mock_logger: + result = BatchResult.from_dict(data) + assert result.completion_reason == CompletionReason.ALL_COMPLETED + mock_logger.warning.assert_called_once() + + +def test_batch_result_from_dict_infer_failure_tolerance_exceeded_all_failed(): + """Test BatchResult from_dict infers FAILURE_TOLERANCE_EXCEEDED when all items failed.""" + error_data = { + "message": "Test error", + "type": "TestError", + "data": None, + "stackTrace": None, + } + data = { + "all": [ + {"index": 0, "status": "FAILED", "result": None, "error": error_data}, + {"index": 1, "status": "FAILED", "result": None, "error": error_data}, + ], + # No completionReason provided + } + + # even if everything has failed, if we've completed all items, then we've finished as ALL_COMPLETED + # https://github.com/aws/aws-durable-execution-sdk-js/blob/f20396f24afa9d6539d8e5056ee851ac7ef62301/packages/aws-durable-execution-sdk-js/src/handlers/concurrent-execution-handler/concurrent-execution-handler.ts#L324-L335 + with patch("aws_durable_execution_sdk_python.concurrency.logger") as mock_logger: + result = BatchResult.from_dict(data) + assert result.completion_reason == CompletionReason.ALL_COMPLETED + mock_logger.warning.assert_called_once() + + +def test_batch_result_from_dict_infer_all_completed_mixed_success_failure(): + """Test BatchResult from_dict infers ALL_COMPLETED when mix of success/failure but no started items.""" + error_data = { + "message": "Test error", + "type": "TestError", + "data": None, + "stackTrace": None, + } + data = { + "all": [ + {"index": 0, "status": "SUCCEEDED", "result": "result1", "error": None}, + {"index": 1, "status": "FAILED", "result": None, "error": error_data}, + {"index": 2, "status": "SUCCEEDED", "result": "result2", "error": None}, + ], + # No completionReason provided + } + + # the logic is that when \every item i: hasCompleted(i) then terminate due to all_completed + with patch("aws_durable_execution_sdk_python.concurrency.logger") as mock_logger: + result = BatchResult.from_dict(data) + assert result.completion_reason == CompletionReason.ALL_COMPLETED + mock_logger.warning.assert_called_once() + + +def test_batch_result_from_dict_infer_min_successful_reached_has_started(): + """Test BatchResult from_dict infers MIN_SUCCESSFUL_REACHED when items are still started.""" + data = { + "all": [ + {"index": 0, "status": "SUCCEEDED", "result": "result1", "error": None}, + {"index": 1, "status": "STARTED", "result": None, "error": None}, + {"index": 2, "status": "SUCCEEDED", "result": "result2", "error": None}, + ], + # No completionReason provided + } + + with patch("aws_durable_execution_sdk_python.concurrency.logger") as mock_logger: + result = BatchResult.from_dict(data, CompletionConfig(1)) + assert result.completion_reason == CompletionReason.MIN_SUCCESSFUL_REACHED + mock_logger.warning.assert_called_once() + + +def test_batch_result_from_dict_infer_empty_items(): + """Test BatchResult from_dict infers ALL_COMPLETED for empty items.""" + data = { + "all": [], + # No completionReason provided + } + + with patch("aws_durable_execution_sdk_python.concurrency.logger") as mock_logger: + result = BatchResult.from_dict(data) + assert result.completion_reason == CompletionReason.ALL_COMPLETED + mock_logger.warning.assert_called_once() + + +def test_batch_result_from_dict_with_explicit_completion_reason(): + """Test BatchResult from_dict uses explicit completionReason when provided.""" + data = { + "all": [ + {"index": 0, "status": "SUCCEEDED", "result": "result1", "error": None} + ], + "completionReason": "MIN_SUCCESSFUL_REACHED", + } + + with patch("aws_durable_execution_sdk_python.concurrency.logger") as mock_logger: + result = BatchResult.from_dict(data) + assert result.completion_reason == CompletionReason.MIN_SUCCESSFUL_REACHED + # No warning should be logged when completionReason is provided + mock_logger.warning.assert_not_called() + + +def test_batch_result_infer_completion_reason_edge_cases(): + """Test _infer_completion_reason method with various edge cases.""" + # Test with only started items + started_items = [ + BatchItem(0, BatchItemStatus.STARTED).to_dict(), + BatchItem(1, BatchItemStatus.STARTED).to_dict(), + ] + items = {"all": started_items} + batch = BatchResult.from_dict(items, CompletionConfig(0)) # SLF001 + # this state is not possible with CompletionConfig(0) + assert batch.completion_reason == CompletionReason.MIN_SUCCESSFUL_REACHED + + # Test with only started items + started_items = [ + BatchItem(0, BatchItemStatus.STARTED).to_dict(), + BatchItem(1, BatchItemStatus.STARTED).to_dict(), + ] + items = {"all": started_items} + batch = BatchResult.from_dict(items) # SLF001 + # this state is not possible with CompletionConfig(0) + assert batch.completion_reason == CompletionReason.FAILURE_TOLERANCE_EXCEEDED + + # Test with only failed items + failed_items = [ + BatchItem( + 0, BatchItemStatus.FAILED, error=ErrorObject("msg", "Error", None, None) + ).to_dict(), + BatchItem( + 1, BatchItemStatus.FAILED, error=ErrorObject("msg", "Error", None, None) + ).to_dict(), + ] + failed_items = {"all": failed_items} + batch = BatchResult.from_dict(failed_items) # SLF001 + assert batch.completion_reason == CompletionReason.ALL_COMPLETED + + # Test with only succeeded items + succeeded_items = [ + BatchItem(0, BatchItemStatus.SUCCEEDED, "result1").to_dict(), + BatchItem(1, BatchItemStatus.SUCCEEDED, "result2").to_dict(), + ] + succeeded_items = {"all": succeeded_items} + batch = BatchResult.from_dict(succeeded_items) # SLF001 + assert batch.completion_reason == CompletionReason.ALL_COMPLETED + + # Test with mixed but no started (all completed) + mixed_items = [ + BatchItem(0, BatchItemStatus.SUCCEEDED, "result1"), + BatchItem( + 1, BatchItemStatus.FAILED, error=ErrorObject("msg", "Error", None, None) + ), + ] + + batch = BatchResult.from_items(mixed_items) # SLF001 + assert batch.completion_reason == CompletionReason.ALL_COMPLETED def test_batch_result_get_results_empty(): @@ -924,7 +1094,10 @@ def mock_run_in_child_context(func, name, config): assert len(result.all) == 2 assert result.all[0].status == BatchItemStatus.SUCCEEDED assert result.all[1].status == BatchItemStatus.FAILED - assert result.completion_reason == CompletionReason.MIN_SUCCESSFUL_REACHED + # WHEN all items complete, THEN completion reason is ALL_COMPLETED. + # we don't consider thresholds and limits. + # https://github.com/aws/aws-durable-execution-sdk-js/blob/ff8b72ef888dd47a840f36d4eb0ee84dd3b55a30/packages/aws-durable-execution-sdk-js/src/handlers/concurrent-execution-handler/concurrent-execution-handler.test.ts#L630-L655 + assert result.completion_reason == CompletionReason.ALL_COMPLETED def test_concurrent_executor_execute_item_in_child_context(): @@ -1007,7 +1180,10 @@ def mock_run_in_child_context(func, name, config): return func(Mock()) result = executor.execute(execution_state, mock_run_in_child_context) - assert result.completion_reason == CompletionReason.FAILURE_TOLERANCE_EXCEEDED + # WHEN all items complete, THEN completion reason is ALL_COMPLETED. + # we don't consider thresholds and limits. + # https://github.com/aws/aws-durable-execution-sdk-js/blob/ff8b72ef888dd47a840f36d4eb0ee84dd3b55a30/packages/aws-durable-execution-sdk-js/src/handlers/concurrent-execution-handler/concurrent-execution-handler.test.ts#L630-L655 + assert result.completion_reason == CompletionReason.ALL_COMPLETED def test_single_task_suspend_bubbles_up(): @@ -1675,6 +1851,482 @@ def mock_run_in_child_context(func, name, config): executor.execute(execution_state, mock_run_in_child_context) +# Tests for _create_result method match statement branches +def test_create_result_completed_branch(): + """Test _create_result with COMPLETED status branch.""" + + class TestExecutor(ConcurrentExecutor): + def execute_item(self, child_context, executable): + return f"result_{executable.index}" + + executables = [Executable(0, lambda: "test")] + completion_config = CompletionConfig(min_successful=1) + + executor = TestExecutor( + executables=executables, + max_concurrency=1, + completion_config=completion_config, + sub_type_top="TOP", + sub_type_iteration="ITER", + name_prefix="test_", + serdes=None, + ) + + # Create executable with COMPLETED status + exe_state = ExecutableWithState(executables[0]) + exe_state.complete("test_result") + executor.executables_with_state = [exe_state] + + result = executor._create_result() # noqa: SLF001 + + assert len(result.all) == 1 + assert result.all[0].status == BatchItemStatus.SUCCEEDED + assert result.all[0].result == "test_result" + assert result.all[0].error is None + assert result.all[0].index == 0 + + +def test_create_result_failed_branch(): + """Test _create_result with FAILED status branch.""" + + class TestExecutor(ConcurrentExecutor): + def execute_item(self, child_context, executable): + return f"result_{executable.index}" + + executables = [Executable(0, lambda: "test")] + completion_config = CompletionConfig(min_successful=1) + + executor = TestExecutor( + executables=executables, + max_concurrency=1, + completion_config=completion_config, + sub_type_top="TOP", + sub_type_iteration="ITER", + name_prefix="test_", + serdes=None, + ) + + # Create executable with FAILED status + exe_state = ExecutableWithState(executables[0]) + test_error = ValueError("Test error message") + exe_state.fail(test_error) + executor.executables_with_state = [exe_state] + + result = executor._create_result() # noqa: SLF001 + + assert len(result.all) == 1 + assert result.all[0].status == BatchItemStatus.FAILED + assert result.all[0].result is None + assert result.all[0].error is not None + assert result.all[0].error.message == "Test error message" + assert result.all[0].error.type == "ValueError" + assert result.all[0].index == 0 + + +def test_create_result_pending_branch(): + """Test _create_result with PENDING status branch.""" + + class TestExecutor(ConcurrentExecutor): + def execute_item(self, child_context, executable): + return f"result_{executable.index}" + + executables = [Executable(0, lambda: "test")] + completion_config = CompletionConfig(min_successful=1) + + executor = TestExecutor( + executables=executables, + max_concurrency=1, + completion_config=completion_config, + sub_type_top="TOP", + sub_type_iteration="ITER", + name_prefix="test_", + serdes=None, + ) + + # Create executable with PENDING status (default state) + exe_state = ExecutableWithState(executables[0]) + # PENDING is the default state, no need to change it + executor.executables_with_state = [exe_state] + + result = executor._create_result() # noqa: SLF001 + + assert len(result.all) == 1 + assert result.all[0].status == BatchItemStatus.STARTED + assert result.all[0].result is None + assert result.all[0].error is None + assert result.all[0].index == 0 + # By default, if we've terminated the reasoning is failure tolerance exceeded + # according to the spec + assert result.completion_reason == CompletionReason.FAILURE_TOLERANCE_EXCEEDED + + +def test_create_result_running_branch(): + """Test _create_result with RUNNING status branch.""" + + class TestExecutor(ConcurrentExecutor): + def execute_item(self, child_context, executable): + return f"result_{executable.index}" + + executables = [Executable(0, lambda: "test")] + completion_config = CompletionConfig(min_successful=1) + + executor = TestExecutor( + executables=executables, + max_concurrency=1, + completion_config=completion_config, + sub_type_top="TOP", + sub_type_iteration="ITER", + name_prefix="test_", + serdes=None, + ) + + # Create executable with RUNNING status + exe_state = ExecutableWithState(executables[0]) + future = Future() + exe_state.run(future) + executor.executables_with_state = [exe_state] + + result = executor._create_result() # noqa: SLF001 + + assert len(result.all) == 1 + assert result.all[0].status == BatchItemStatus.STARTED + assert result.all[0].result is None + assert result.all[0].error is None + assert result.all[0].index == 0 + # By default, if we've terminated the reasoning is failure tolerance exceeded + # according to the spec + assert result.completion_reason == CompletionReason.FAILURE_TOLERANCE_EXCEEDED + + +def test_create_result_suspended_branch(): + """Test _create_result with SUSPENDED status branch.""" + + class TestExecutor(ConcurrentExecutor): + def execute_item(self, child_context, executable): + return f"result_{executable.index}" + + executables = [Executable(0, lambda: "test")] + completion_config = CompletionConfig(min_successful=1) + + executor = TestExecutor( + executables=executables, + max_concurrency=1, + completion_config=completion_config, + sub_type_top="TOP", + sub_type_iteration="ITER", + name_prefix="test_", + serdes=None, + ) + + # Create executable with SUSPENDED status + exe_state = ExecutableWithState(executables[0]) + exe_state.suspend() + executor.executables_with_state = [exe_state] + + result = executor._create_result() # noqa: SLF001 + + assert len(result.all) == 1 + assert result.all[0].status == BatchItemStatus.STARTED + assert result.all[0].result is None + assert result.all[0].error is None + assert result.all[0].index == 0 + # By default, if we've terminated the reasoning is failure tolerance exceeded + # according to the spec + assert result.completion_reason == CompletionReason.FAILURE_TOLERANCE_EXCEEDED + + +def test_create_result_suspended_with_timeout_branch(): + """Test _create_result with SUSPENDED_WITH_TIMEOUT status branch.""" + + class TestExecutor(ConcurrentExecutor): + def execute_item(self, child_context, executable): + return f"result_{executable.index}" + + executables = [Executable(0, lambda: "test")] + completion_config = CompletionConfig(min_successful=1) + + executor = TestExecutor( + executables=executables, + max_concurrency=1, + completion_config=completion_config, + sub_type_top="TOP", + sub_type_iteration="ITER", + name_prefix="test_", + serdes=None, + ) + + # Create executable with SUSPENDED_WITH_TIMEOUT status + exe_state = ExecutableWithState(executables[0]) + future_time = time.time() + 10 + exe_state.suspend_with_timeout(future_time) + executor.executables_with_state = [exe_state] + + result = executor._create_result() # noqa: SLF001 + + assert len(result.all) == 1 + assert result.all[0].status == BatchItemStatus.STARTED + assert result.all[0].result is None + assert result.all[0].error is None + assert result.all[0].index == 0 + # By default, if we've terminated the reasoning is failure tolerance exceeded + # according to the spec + assert result.completion_reason == CompletionReason.FAILURE_TOLERANCE_EXCEEDED + + +def test_create_result_mixed_statuses(): + """Test _create_result with mixed executable statuses covering all branches.""" + + class TestExecutor(ConcurrentExecutor): + def execute_item(self, child_context, executable): + return f"result_{executable.index}" + + executables = [ + Executable(0, lambda: "test0"), # Will be COMPLETED + Executable(1, lambda: "test1"), # Will be FAILED + Executable(2, lambda: "test2"), # Will be PENDING + Executable(3, lambda: "test3"), # Will be RUNNING + Executable(4, lambda: "test4"), # Will be SUSPENDED + Executable(5, lambda: "test5"), # Will be SUSPENDED_WITH_TIMEOUT + ] + completion_config = CompletionConfig(min_successful=1) + + executor = TestExecutor( + executables=executables, + max_concurrency=6, + completion_config=completion_config, + sub_type_top="TOP", + sub_type_iteration="ITER", + name_prefix="test_", + serdes=None, + ) + + # Create executables with different statuses + exe_states = [ExecutableWithState(exe) for exe in executables] + + # COMPLETED + exe_states[0].complete("completed_result") + + # FAILED + exe_states[1].fail(RuntimeError("Test failure")) + + # PENDING (default state, no change needed) + + # RUNNING + future = Future() + exe_states[3].run(future) + + # SUSPENDED + exe_states[4].suspend() + + # SUSPENDED_WITH_TIMEOUT + exe_states[5].suspend_with_timeout(time.time() + 10) + + executor.executables_with_state = exe_states + + result = executor._create_result() # noqa: SLF001 + + assert len(result.all) == 6 + + # Check COMPLETED -> SUCCEEDED + assert result.all[0].status == BatchItemStatus.SUCCEEDED + assert result.all[0].result == "completed_result" + assert result.all[0].error is None + + # Check FAILED -> FAILED + assert result.all[1].status == BatchItemStatus.FAILED + assert result.all[1].result is None + assert result.all[1].error is not None + assert result.all[1].error.message == "Test failure" + + # Check PENDING -> STARTED + assert result.all[2].status == BatchItemStatus.STARTED + assert result.all[2].result is None + assert result.all[2].error is None + + # Check RUNNING -> STARTED + assert result.all[3].status == BatchItemStatus.STARTED + assert result.all[3].result is None + assert result.all[3].error is None + + # Check SUSPENDED -> STARTED + assert result.all[4].status == BatchItemStatus.STARTED + assert result.all[4].result is None + assert result.all[4].error is None + + # Check SUSPENDED_WITH_TIMEOUT -> STARTED + assert result.all[5].status == BatchItemStatus.STARTED + assert result.all[5].result is None + assert result.all[5].error is None + + # we've a min succ set to 1. + assert result.completion_reason == CompletionReason.MIN_SUCCESSFUL_REACHED + + +def test_create_result_multiple_completed(): + """Test _create_result with multiple COMPLETED executables.""" + + class TestExecutor(ConcurrentExecutor): + def execute_item(self, child_context, executable): + return f"result_{executable.index}" + + executables = [ + Executable(0, lambda: "test0"), + Executable(1, lambda: "test1"), + Executable(2, lambda: "test2"), + ] + completion_config = CompletionConfig(min_successful=3) + + executor = TestExecutor( + executables=executables, + max_concurrency=3, + completion_config=completion_config, + sub_type_top="TOP", + sub_type_iteration="ITER", + name_prefix="test_", + serdes=None, + ) + + # Create all executables with COMPLETED status + exe_states = [ExecutableWithState(exe) for exe in executables] + exe_states[0].complete("result_0") + exe_states[1].complete("result_1") + exe_states[2].complete("result_2") + + executor.executables_with_state = exe_states + + result = executor._create_result() # noqa: SLF001 + + assert len(result.all) == 3 + assert all(item.status == BatchItemStatus.SUCCEEDED for item in result.all) + assert result.all[0].result == "result_0" + assert result.all[1].result == "result_1" + assert result.all[2].result == "result_2" + assert result.completion_reason == CompletionReason.ALL_COMPLETED + + +def test_create_result_multiple_failed(): + """Test _create_result with multiple FAILED executables.""" + + class TestExecutor(ConcurrentExecutor): + def execute_item(self, child_context, executable): + return f"result_{executable.index}" + + executables = [ + Executable(0, lambda: "test0"), + Executable(1, lambda: "test1"), + Executable(2, lambda: "test2"), + ] + completion_config = CompletionConfig(min_successful=1) + + executor = TestExecutor( + executables=executables, + max_concurrency=3, + completion_config=completion_config, + sub_type_top="TOP", + sub_type_iteration="ITER", + name_prefix="test_", + serdes=None, + ) + + # Create all executables with FAILED status + exe_states = [ExecutableWithState(exe) for exe in executables] + exe_states[0].fail(ValueError("Error 0")) + exe_states[1].fail(RuntimeError("Error 1")) + exe_states[2].fail(TypeError("Error 2")) + + executor.executables_with_state = exe_states + + result = executor._create_result() # noqa: SLF001 + + assert len(result.all) == 3 + assert all(item.status == BatchItemStatus.FAILED for item in result.all) + assert result.all[0].error.message == "Error 0" + assert result.all[1].error.message == "Error 1" + assert result.all[2].error.message == "Error 2" + assert result.completion_reason == CompletionReason.ALL_COMPLETED + + +def test_create_result_multiple_started_states(): + """Test _create_result with multiple executables in STARTED states.""" + + class TestExecutor(ConcurrentExecutor): + def execute_item(self, child_context, executable): + return f"result_{executable.index}" + + executables = [ + Executable(0, lambda: "test0"), # PENDING + Executable(1, lambda: "test1"), # RUNNING + Executable(2, lambda: "test2"), # SUSPENDED + Executable(3, lambda: "test3"), # SUSPENDED_WITH_TIMEOUT + ] + completion_config = CompletionConfig(min_successful=1) + + executor = TestExecutor( + executables=executables, + max_concurrency=4, + completion_config=completion_config, + sub_type_top="TOP", + sub_type_iteration="ITER", + name_prefix="test_", + serdes=None, + ) + + # Create executables with different STARTED states + exe_states = [ExecutableWithState(exe) for exe in executables] + + # PENDING (default state) + + # RUNNING + future = Future() + exe_states[1].run(future) + + # SUSPENDED + exe_states[2].suspend() + + # SUSPENDED_WITH_TIMEOUT + exe_states[3].suspend_with_timeout(time.time() + 5) + + executor.executables_with_state = exe_states + + result = executor._create_result() # noqa: SLF001 + + assert len(result.all) == 4 + assert all(item.status == BatchItemStatus.STARTED for item in result.all) + assert all(item.result is None for item in result.all) + assert all(item.error is None for item in result.all) + # With completion config min_successful=1 and no completed items, + # this should be FAILURE_TOLERANCE_EXCEEDED + assert result.completion_reason == CompletionReason.FAILURE_TOLERANCE_EXCEEDED + + +def test_create_result_empty_executables(): + """Test _create_result with no executables.""" + + class TestExecutor(ConcurrentExecutor): + def execute_item(self, child_context, executable): + return f"result_{executable.index}" + + executables = [] + completion_config = CompletionConfig(min_successful=0) + + executor = TestExecutor( + executables=executables, + max_concurrency=1, + completion_config=completion_config, + sub_type_top="TOP", + sub_type_iteration="ITER", + name_prefix="test_", + serdes=None, + ) + + executor.executables_with_state = [] + + result = executor._create_result() # noqa: SLF001 + + assert len(result.all) == 0 + assert result.completion_reason == CompletionReason.ALL_COMPLETED + + def test_timer_scheduler_future_time_condition_false(): """Test TimerScheduler when scheduled time is in future (434->433 branch).""" callback = Mock() @@ -1692,3 +2344,94 @@ def test_timer_scheduler_future_time_condition_false(): # Callback should not be called since time is in future callback.assert_not_called() + + +def test_batch_result_from_dict_with_completion_config(): + """Test BatchResult from_dict with completion config parameter.""" + data = { + "all": [ + {"index": 0, "status": "SUCCEEDED", "result": "result1", "error": None}, + {"index": 1, "status": "STARTED", "result": None, "error": None}, + ], + # No completionReason provided + } + + # With started items, should infer MIN_SUCCESSFUL_REACHED + completion_config = CompletionConfig(min_successful=1) + + with patch("aws_durable_execution_sdk_python.concurrency.logger") as mock_logger: + result = BatchResult.from_dict(data, completion_config) + assert result.completion_reason == CompletionReason.MIN_SUCCESSFUL_REACHED + mock_logger.warning.assert_called_once() + + +def test_batch_result_from_dict_all_completed(): + """Test BatchResult from_dict infers ALL_COMPLETED when all items are completed.""" + data = { + "all": [ + {"index": 0, "status": "SUCCEEDED", "result": "result1", "error": None}, + { + "index": 1, + "status": "FAILED", + "result": None, + "error": { + "message": "error", + "type": "Error", + "data": None, + "stackTrace": None, + }, + }, + ], + # No completionReason provided + } + + with patch("aws_durable_execution_sdk_python.concurrency.logger") as mock_logger: + result = BatchResult.from_dict(data) + assert result.completion_reason == CompletionReason.ALL_COMPLETED + mock_logger.warning.assert_called_once() + + +def test_batch_result_from_dict_backward_compatibility(): + """Test BatchResult from_dict maintains backward compatibility when no completion_config provided.""" + data = { + "all": [ + {"index": 0, "status": "SUCCEEDED", "result": "result1", "error": None} + ], + "completionReason": "MIN_SUCCESSFUL_REACHED", + } + + # Should work without completion_config parameter + result = BatchResult.from_dict(data) + assert result.completion_reason == CompletionReason.MIN_SUCCESSFUL_REACHED + + # Should also work with None completion_config + result2 = BatchResult.from_dict(data, None) + assert result2.completion_reason == CompletionReason.MIN_SUCCESSFUL_REACHED + + +def test_batch_result_infer_completion_reason_basic_cases(): + """Test _infer_completion_reason method with basic scenarios.""" + # Test with started items - should be MIN_SUCCESSFUL_REACHED + items = { + "all": [ + BatchItem(0, BatchItemStatus.SUCCEEDED, "result1").to_dict(), + BatchItem(1, BatchItemStatus.STARTED).to_dict(), + ] + } + batch = BatchResult.from_dict(items, CompletionConfig(1)) + assert batch.completion_reason == CompletionReason.MIN_SUCCESSFUL_REACHED + + # Test with all completed items - should be ALL_COMPLETED + completed_items = [ + BatchItem(0, BatchItemStatus.SUCCEEDED, "result1").to_dict(), + BatchItem( + 1, BatchItemStatus.FAILED, error=ErrorObject("msg", "Error", None, None) + ).to_dict(), + ] + completed_items = {"all": completed_items} + batch = BatchResult.from_dict(completed_items, CompletionConfig(1)) + assert batch.completion_reason == CompletionReason.ALL_COMPLETED + + # Test empty items - should be ALL_COMPLETED + batch = BatchResult.from_dict({"all": []}, CompletionConfig(1)) + assert batch.completion_reason == CompletionReason.ALL_COMPLETED