diff --git a/src/aws_durable_execution_sdk_python/concurrency.py b/src/aws_durable_execution_sdk_python/concurrency.py index b5af8c6..3cf2aa1 100644 --- a/src/aws_durable_execution_sdk_python/concurrency.py +++ b/src/aws_durable_execution_sdk_python/concurrency.py @@ -7,6 +7,7 @@ import threading import time from abc import ABC, abstractmethod +from collections import Counter from concurrent.futures import Future, ThreadPoolExecutor from dataclasses import dataclass from enum import Enum @@ -22,7 +23,7 @@ from aws_durable_execution_sdk_python.types import BatchResult as BatchResultProtocol if TYPE_CHECKING: - from collections.abc import Callable + from collections.abc import Callable, Iterable from aws_durable_execution_sdk_python.config import CompletionConfig from aws_durable_execution_sdk_python.lambda_service import OperationSubType @@ -67,6 +68,42 @@ def suspend(exception: SuspendExecution) -> SuspendResult: return SuspendResult(should_suspend=True, exception=exception) +def _get_completion_reason( + items: Iterable[BatchItemStatus], completion_config: CompletionConfig | None = None +) -> CompletionReason: + """ + Infer completion reason based on batch item statuses and completion config. + + This follows the same logic as the TypeScript implementation: + - If all items completed: ALL_COMPLETED + - If minSuccessful threshold met and not all completed: MIN_SUCCESSFUL_REACHED + - Otherwise: FAILURE_TOLERANCE_EXCEEDED + """ + + counts = Counter(items) + succeeded_count = counts.get(BatchItemStatus.SUCCEEDED, 0) + failed_count = counts.get(BatchItemStatus.FAILED, 0) + started_count = counts.get(BatchItemStatus.STARTED, 0) + + completed_count = succeeded_count + failed_count + total_count = started_count + completed_count + + # If all items completed (no started items), it's ALL_COMPLETED + if completed_count == total_count: + return CompletionReason.ALL_COMPLETED + + # If we have completion config and minSuccessful threshold is met + if ( + completion_config + and (min_successful := completion_config.min_successful) is not None + and succeeded_count >= min_successful + ): + return CompletionReason.MIN_SUCCESSFUL_REACHED + + # Otherwise, assume failure tolerance was exceeded + return CompletionReason.FAILURE_TOLERANCE_EXCEEDED + + @dataclass(frozen=True) class BatchItem(Generic[R]): index: int @@ -98,16 +135,44 @@ class BatchResult(Generic[R], BatchResultProtocol[R]): # noqa: PYI059 completion_reason: CompletionReason @classmethod - def from_dict(cls, data: dict) -> BatchResult[R]: + def from_dict( + cls, data: dict, completion_config: CompletionConfig | None = None + ) -> BatchResult[R]: batch_items: list[BatchItem[R]] = [ BatchItem.from_dict(item) for item in data["all"] ] - # TODO: is this valid? assuming completion reason is ALL_COMPLETED? - completion_reason = CompletionReason( - data.get("completionReason", "ALL_COMPLETED") - ) + + completion_reason_value = data.get("completionReason") + if completion_reason_value is None: + # Infer completion reason from batch item statuses and completion config + # This aligns with the TypeScript implementation that uses completion config + # to accurately reconstruct the completion reason during replay + result = cls.from_items(batch_items, completion_config) + logger.warning( + "Missing completionReason in BatchResult deserialization, " + "inferred '%s' from batch item statuses. " + "This may indicate incomplete serialization data.", + result.completion_reason.value, + ) + return result + + completion_reason = CompletionReason(completion_reason_value) return cls(batch_items, completion_reason) + @classmethod + def from_items( + cls, + items: Iterable[BatchItem[R]], + completion_config: CompletionConfig | None = None, + ): + items = list(items) + # Infer completion reason from batch item statuses and completion config + # This aligns with the TypeScript implementation that uses completion config + # to accurately reconstruct the completion reason during replay + statuses = (item.status for item in items) + completion_reason = _get_completion_reason(statuses, completion_config) + return cls(items, completion_reason) + def to_dict(self) -> dict: return { "all": [item.to_dict() for item in self.all], @@ -163,19 +228,15 @@ def get_errors(self) -> list[ErrorObject]: @property def success_count(self) -> int: - return len( - [item for item in self.all if item.status is BatchItemStatus.SUCCEEDED] - ) + return sum(1 for item in self.all if item.status is BatchItemStatus.SUCCEEDED) @property def failure_count(self) -> int: - return len([item for item in self.all if item.status is BatchItemStatus.FAILED]) + return sum(1 for item in self.all if item.status is BatchItemStatus.FAILED) @property def started_count(self) -> int: - return len( - [item for item in self.all if item.status is BatchItemStatus.STARTED] - ) + return sum(1 for item in self.all if item.status is BatchItemStatus.STARTED) @property def total_count(self) -> int: @@ -336,25 +397,66 @@ def fail_task(self) -> None: with self._lock: self.failure_count += 1 - def should_complete(self) -> bool: - """Check if execution should complete.""" + def should_continue(self) -> bool: + """ + Check if we should continue starting new tasks (based on failure tolerance). + Matches TypeScript shouldContinue() logic. + """ with self._lock: - # Success condition - if self.success_count >= self.min_successful: - return True + # If no completion config, only continue if no failures + if ( + self.tolerated_failure_count is None + and self.tolerated_failure_percentage is None + ): + return self.failure_count == 0 - # Failure conditions - if self._is_failure_condition_reached( - tolerated_count=self.tolerated_failure_count, - tolerated_percentage=self.tolerated_failure_percentage, - failure_count=self.failure_count, + # Check failure count tolerance + if ( + self.tolerated_failure_count is not None + and self.failure_count > self.tolerated_failure_count ): + return False + + # Check failure percentage tolerance + if self.tolerated_failure_percentage is not None and self.total_tasks > 0: + failure_percentage = (self.failure_count / self.total_tasks) * 100 + if failure_percentage > self.tolerated_failure_percentage: + return False + + return True + + def is_complete(self) -> bool: + """ + Check if execution should complete (based on completion criteria). + Matches TypeScript isComplete() logic. + """ + with self._lock: + completed_count = self.success_count + self.failure_count + + # All tasks completed + if completed_count == self.total_tasks: + # Complete if no failure tolerance OR no failures OR min successful reached + return ( + ( + self.tolerated_failure_count is None + and self.tolerated_failure_percentage is None + ) + or self.failure_count == 0 + or self.success_count >= self.min_successful + ) + + # Min successful reached (even if not all tasks completed) + if self.success_count >= self.min_successful: return True - # Impossible to succeed condition - # TODO: should this keep running? TS doesn't currently handle this either. - remaining_tasks = self.total_tasks - self.success_count - self.failure_count - return self.success_count + remaining_tasks < self.min_successful + return False + + def should_complete(self) -> bool: + """ + Check if execution should complete. + Combines TypeScript shouldContinue() and isComplete() logic. + """ + return self.is_complete() or not self.should_continue() def is_all_completed(self) -> bool: """True if all tasks completed successfully.""" @@ -663,15 +765,9 @@ def _create_result(self) -> BatchResult[ResultType]: ) ) - completion_reason: CompletionReason = ( - CompletionReason.ALL_COMPLETED - if self.counters.is_all_completed() - else ( - CompletionReason.MIN_SUCCESSFUL_REACHED - if self.counters.is_min_successful_reached() - else CompletionReason.FAILURE_TOLERANCE_EXCEEDED - ) - ) + # Use the same completion reason logic as _get_completion_reason for consistency + statuses = (item.status for item in batch_items) + completion_reason = _get_completion_reason(statuses, self.completion_config) return BatchResult(batch_items, completion_reason) diff --git a/tests/concurrency_test.py b/tests/concurrency_test.py index 38dc9d6..11c2d64 100644 --- a/tests/concurrency_test.py +++ b/tests/concurrency_test.py @@ -319,8 +319,178 @@ def test_batch_result_from_dict_default_completion_reason(): # No completionReason provided } - result = BatchResult.from_dict(data) - assert result.completion_reason == CompletionReason.ALL_COMPLETED + with patch("aws_durable_execution_sdk_python.concurrency.logger") as mock_logger: + result = BatchResult.from_dict(data) + assert result.completion_reason == CompletionReason.ALL_COMPLETED + # Verify warning was logged + mock_logger.warning.assert_called_once() + assert "Missing completionReason" in mock_logger.warning.call_args[0][0] + + +def test_batch_result_from_dict_infer_all_completed_all_succeeded(): + """Test BatchResult from_dict infers ALL_COMPLETED when all items succeeded.""" + data = { + "all": [ + {"index": 0, "status": "SUCCEEDED", "result": "result1", "error": None}, + {"index": 1, "status": "SUCCEEDED", "result": "result2", "error": None}, + ], + # No completionReason provided + } + + with patch("aws_durable_execution_sdk_python.concurrency.logger") as mock_logger: + result = BatchResult.from_dict(data) + assert result.completion_reason == CompletionReason.ALL_COMPLETED + mock_logger.warning.assert_called_once() + + +def test_batch_result_from_dict_infer_failure_tolerance_exceeded_all_failed(): + """Test BatchResult from_dict infers FAILURE_TOLERANCE_EXCEEDED when all items failed.""" + error_data = { + "message": "Test error", + "type": "TestError", + "data": None, + "stackTrace": None, + } + data = { + "all": [ + {"index": 0, "status": "FAILED", "result": None, "error": error_data}, + {"index": 1, "status": "FAILED", "result": None, "error": error_data}, + ], + # No completionReason provided + } + + # even if everything has failed, if we've completed all items, then we've finished as ALL_COMPLETED + # https://github.com/aws/aws-durable-execution-sdk-js/blob/f20396f24afa9d6539d8e5056ee851ac7ef62301/packages/aws-durable-execution-sdk-js/src/handlers/concurrent-execution-handler/concurrent-execution-handler.ts#L324-L335 + with patch("aws_durable_execution_sdk_python.concurrency.logger") as mock_logger: + result = BatchResult.from_dict(data) + assert result.completion_reason == CompletionReason.ALL_COMPLETED + mock_logger.warning.assert_called_once() + + +def test_batch_result_from_dict_infer_all_completed_mixed_success_failure(): + """Test BatchResult from_dict infers ALL_COMPLETED when mix of success/failure but no started items.""" + error_data = { + "message": "Test error", + "type": "TestError", + "data": None, + "stackTrace": None, + } + data = { + "all": [ + {"index": 0, "status": "SUCCEEDED", "result": "result1", "error": None}, + {"index": 1, "status": "FAILED", "result": None, "error": error_data}, + {"index": 2, "status": "SUCCEEDED", "result": "result2", "error": None}, + ], + # No completionReason provided + } + + # the logic is that when \every item i: hasCompleted(i) then terminate due to all_completed + with patch("aws_durable_execution_sdk_python.concurrency.logger") as mock_logger: + result = BatchResult.from_dict(data) + assert result.completion_reason == CompletionReason.ALL_COMPLETED + mock_logger.warning.assert_called_once() + + +def test_batch_result_from_dict_infer_min_successful_reached_has_started(): + """Test BatchResult from_dict infers MIN_SUCCESSFUL_REACHED when items are still started.""" + data = { + "all": [ + {"index": 0, "status": "SUCCEEDED", "result": "result1", "error": None}, + {"index": 1, "status": "STARTED", "result": None, "error": None}, + {"index": 2, "status": "SUCCEEDED", "result": "result2", "error": None}, + ], + # No completionReason provided + } + + with patch("aws_durable_execution_sdk_python.concurrency.logger") as mock_logger: + result = BatchResult.from_dict(data, CompletionConfig(1)) + assert result.completion_reason == CompletionReason.MIN_SUCCESSFUL_REACHED + mock_logger.warning.assert_called_once() + + +def test_batch_result_from_dict_infer_empty_items(): + """Test BatchResult from_dict infers ALL_COMPLETED for empty items.""" + data = { + "all": [], + # No completionReason provided + } + + with patch("aws_durable_execution_sdk_python.concurrency.logger") as mock_logger: + result = BatchResult.from_dict(data) + assert result.completion_reason == CompletionReason.ALL_COMPLETED + mock_logger.warning.assert_called_once() + + +def test_batch_result_from_dict_with_explicit_completion_reason(): + """Test BatchResult from_dict uses explicit completionReason when provided.""" + data = { + "all": [ + {"index": 0, "status": "SUCCEEDED", "result": "result1", "error": None} + ], + "completionReason": "MIN_SUCCESSFUL_REACHED", + } + + with patch("aws_durable_execution_sdk_python.concurrency.logger") as mock_logger: + result = BatchResult.from_dict(data) + assert result.completion_reason == CompletionReason.MIN_SUCCESSFUL_REACHED + # No warning should be logged when completionReason is provided + mock_logger.warning.assert_not_called() + + +def test_batch_result_infer_completion_reason_edge_cases(): + """Test _infer_completion_reason method with various edge cases.""" + # Test with only started items + started_items = [ + BatchItem(0, BatchItemStatus.STARTED).to_dict(), + BatchItem(1, BatchItemStatus.STARTED).to_dict(), + ] + items = {"all": started_items} + batch = BatchResult.from_dict(items, CompletionConfig(0)) # SLF001 + # this state is not possible with CompletionConfig(0) + assert batch.completion_reason == CompletionReason.MIN_SUCCESSFUL_REACHED + + # Test with only started items + started_items = [ + BatchItem(0, BatchItemStatus.STARTED).to_dict(), + BatchItem(1, BatchItemStatus.STARTED).to_dict(), + ] + items = {"all": started_items} + batch = BatchResult.from_dict(items) # SLF001 + # this state is not possible with CompletionConfig(0) + assert batch.completion_reason == CompletionReason.FAILURE_TOLERANCE_EXCEEDED + + # Test with only failed items + failed_items = [ + BatchItem( + 0, BatchItemStatus.FAILED, error=ErrorObject("msg", "Error", None, None) + ).to_dict(), + BatchItem( + 1, BatchItemStatus.FAILED, error=ErrorObject("msg", "Error", None, None) + ).to_dict(), + ] + failed_items = {"all": failed_items} + batch = BatchResult.from_dict(failed_items) # SLF001 + assert batch.completion_reason == CompletionReason.ALL_COMPLETED + + # Test with only succeeded items + succeeded_items = [ + BatchItem(0, BatchItemStatus.SUCCEEDED, "result1").to_dict(), + BatchItem(1, BatchItemStatus.SUCCEEDED, "result2").to_dict(), + ] + succeeded_items = {"all": succeeded_items} + batch = BatchResult.from_dict(succeeded_items) # SLF001 + assert batch.completion_reason == CompletionReason.ALL_COMPLETED + + # Test with mixed but no started (all completed) + mixed_items = [ + BatchItem(0, BatchItemStatus.SUCCEEDED, "result1"), + BatchItem( + 1, BatchItemStatus.FAILED, error=ErrorObject("msg", "Error", None, None) + ), + ] + + batch = BatchResult.from_items(mixed_items) # SLF001 + assert batch.completion_reason == CompletionReason.ALL_COMPLETED def test_batch_result_get_results_empty(): @@ -924,7 +1094,10 @@ def mock_run_in_child_context(func, name, config): assert len(result.all) == 2 assert result.all[0].status == BatchItemStatus.SUCCEEDED assert result.all[1].status == BatchItemStatus.FAILED - assert result.completion_reason == CompletionReason.MIN_SUCCESSFUL_REACHED + # WHEN all items complete, THEN completion reason is ALL_COMPLETED. + # we don't consider thresholds and limits. + # https://github.com/aws/aws-durable-execution-sdk-js/blob/ff8b72ef888dd47a840f36d4eb0ee84dd3b55a30/packages/aws-durable-execution-sdk-js/src/handlers/concurrent-execution-handler/concurrent-execution-handler.test.ts#L630-L655 + assert result.completion_reason == CompletionReason.ALL_COMPLETED def test_concurrent_executor_execute_item_in_child_context(): @@ -1007,7 +1180,10 @@ def mock_run_in_child_context(func, name, config): return func(Mock()) result = executor.execute(execution_state, mock_run_in_child_context) - assert result.completion_reason == CompletionReason.FAILURE_TOLERANCE_EXCEEDED + # WHEN all items complete, THEN completion reason is ALL_COMPLETED. + # we don't consider thresholds and limits. + # https://github.com/aws/aws-durable-execution-sdk-js/blob/ff8b72ef888dd47a840f36d4eb0ee84dd3b55a30/packages/aws-durable-execution-sdk-js/src/handlers/concurrent-execution-handler/concurrent-execution-handler.test.ts#L630-L655 + assert result.completion_reason == CompletionReason.ALL_COMPLETED def test_single_task_suspend_bubbles_up(): @@ -1692,3 +1868,94 @@ def test_timer_scheduler_future_time_condition_false(): # Callback should not be called since time is in future callback.assert_not_called() + + +def test_batch_result_from_dict_with_completion_config(): + """Test BatchResult from_dict with completion config parameter.""" + data = { + "all": [ + {"index": 0, "status": "SUCCEEDED", "result": "result1", "error": None}, + {"index": 1, "status": "STARTED", "result": None, "error": None}, + ], + # No completionReason provided + } + + # With started items, should infer MIN_SUCCESSFUL_REACHED + completion_config = CompletionConfig(min_successful=1) + + with patch("aws_durable_execution_sdk_python.concurrency.logger") as mock_logger: + result = BatchResult.from_dict(data, completion_config) + assert result.completion_reason == CompletionReason.MIN_SUCCESSFUL_REACHED + mock_logger.warning.assert_called_once() + + +def test_batch_result_from_dict_all_completed(): + """Test BatchResult from_dict infers ALL_COMPLETED when all items are completed.""" + data = { + "all": [ + {"index": 0, "status": "SUCCEEDED", "result": "result1", "error": None}, + { + "index": 1, + "status": "FAILED", + "result": None, + "error": { + "message": "error", + "type": "Error", + "data": None, + "stackTrace": None, + }, + }, + ], + # No completionReason provided + } + + with patch("aws_durable_execution_sdk_python.concurrency.logger") as mock_logger: + result = BatchResult.from_dict(data) + assert result.completion_reason == CompletionReason.ALL_COMPLETED + mock_logger.warning.assert_called_once() + + +def test_batch_result_from_dict_backward_compatibility(): + """Test BatchResult from_dict maintains backward compatibility when no completion_config provided.""" + data = { + "all": [ + {"index": 0, "status": "SUCCEEDED", "result": "result1", "error": None} + ], + "completionReason": "MIN_SUCCESSFUL_REACHED", + } + + # Should work without completion_config parameter + result = BatchResult.from_dict(data) + assert result.completion_reason == CompletionReason.MIN_SUCCESSFUL_REACHED + + # Should also work with None completion_config + result2 = BatchResult.from_dict(data, None) + assert result2.completion_reason == CompletionReason.MIN_SUCCESSFUL_REACHED + + +def test_batch_result_infer_completion_reason_basic_cases(): + """Test _infer_completion_reason method with basic scenarios.""" + # Test with started items - should be MIN_SUCCESSFUL_REACHED + items = { + "all": [ + BatchItem(0, BatchItemStatus.SUCCEEDED, "result1").to_dict(), + BatchItem(1, BatchItemStatus.STARTED).to_dict(), + ] + } + batch = BatchResult.from_dict(items, CompletionConfig(1)) + assert batch.completion_reason == CompletionReason.MIN_SUCCESSFUL_REACHED + + # Test with all completed items - should be ALL_COMPLETED + completed_items = [ + BatchItem(0, BatchItemStatus.SUCCEEDED, "result1").to_dict(), + BatchItem( + 1, BatchItemStatus.FAILED, error=ErrorObject("msg", "Error", None, None) + ).to_dict(), + ] + completed_items = {"all": completed_items} + batch = BatchResult.from_dict(completed_items, CompletionConfig(1)) + assert batch.completion_reason == CompletionReason.ALL_COMPLETED + + # Test empty items - should be ALL_COMPLETED + batch = BatchResult.from_dict({"all": []}, CompletionConfig(1)) + assert batch.completion_reason == CompletionReason.ALL_COMPLETED