diff --git a/src/aws_durable_execution_sdk_python/concurrency.py b/src/aws_durable_execution_sdk_python/concurrency.py
index b5af8c6..e102c92 100644
--- a/src/aws_durable_execution_sdk_python/concurrency.py
+++ b/src/aws_durable_execution_sdk_python/concurrency.py
@@ -7,6 +7,7 @@
 import threading
 import time
 from abc import ABC, abstractmethod
+from collections import Counter
 from concurrent.futures import Future, ThreadPoolExecutor
 from dataclasses import dataclass
 from enum import Enum
@@ -98,16 +99,69 @@ class BatchResult(Generic[R], BatchResultProtocol[R]):  # noqa: PYI059
     completion_reason: CompletionReason
 
     @classmethod
-    def from_dict(cls, data: dict) -> BatchResult[R]:
+    def from_dict(
+        cls, data: dict, completion_config: CompletionConfig | None = None
+    ) -> BatchResult[R]:
         batch_items: list[BatchItem[R]] = [
             BatchItem.from_dict(item) for item in data["all"]
         ]
-        # TODO: is this valid? assuming completion reason is ALL_COMPLETED?
-        completion_reason = CompletionReason(
-            data.get("completionReason", "ALL_COMPLETED")
-        )
+
+        completion_reason_value = data.get("completionReason")
+        if completion_reason_value is None:
+            # Infer completion reason from batch item statuses and completion config
+            # This aligns with the TypeScript implementation that uses completion config
+            # to accurately reconstruct the completion reason during replay
+            result = cls.from_items(batch_items, completion_config)
+            logger.warning(
+                "Missing completionReason in BatchResult deserialization, "
+                "inferred '%s' from batch item statuses. "
+                "This may indicate incomplete serialization data.",
+                result.completion_reason.value,
+            )
+            return result
+
+        completion_reason = CompletionReason(completion_reason_value)
         return cls(batch_items, completion_reason)
 
+    @classmethod
+    def from_items(
+        cls,
+        items: list[BatchItem[R]],
+        completion_config: CompletionConfig | None = None,
+    ) -> BatchResult[R]:
+        """
+        Infer completion reason based on batch item statuses and completion config.
+ + This follows the same logic as the TypeScript implementation: + - If all items completed: ALL_COMPLETED + - If minSuccessful threshold met and not all completed: MIN_SUCCESSFUL_REACHED + - Otherwise: FAILURE_TOLERANCE_EXCEEDED + """ + + statuses = (item.status for item in items) + counts = Counter(statuses) + succeeded_count = counts.get(BatchItemStatus.SUCCEEDED, 0) + failed_count = counts.get(BatchItemStatus.FAILED, 0) + started_count = counts.get(BatchItemStatus.STARTED, 0) + + completed_count = succeeded_count + failed_count + total_count = started_count + completed_count + + # If all items completed (no started items), it's ALL_COMPLETED + if completed_count == total_count: + completion_reason = CompletionReason.ALL_COMPLETED + elif ( # If we have completion config and minSuccessful threshold is met + completion_config + and (min_successful := completion_config.min_successful) is not None + and succeeded_count >= min_successful + ): + completion_reason = CompletionReason.MIN_SUCCESSFUL_REACHED + else: + # Otherwise, assume failure tolerance was exceeded + completion_reason = CompletionReason.FAILURE_TOLERANCE_EXCEEDED + + return cls(items, completion_reason) + def to_dict(self) -> dict: return { "all": [item.to_dict() for item in self.all], @@ -163,19 +217,15 @@ def get_errors(self) -> list[ErrorObject]: @property def success_count(self) -> int: - return len( - [item for item in self.all if item.status is BatchItemStatus.SUCCEEDED] - ) + return sum(1 for item in self.all if item.status is BatchItemStatus.SUCCEEDED) @property def failure_count(self) -> int: - return len([item for item in self.all if item.status is BatchItemStatus.FAILED]) + return sum(1 for item in self.all if item.status is BatchItemStatus.FAILED) @property def started_count(self) -> int: - return len( - [item for item in self.all if item.status is BatchItemStatus.STARTED] - ) + return sum(1 for item in self.all if item.status is BatchItemStatus.STARTED) @property def total_count(self) -> 
int: diff --git a/tests/concurrency_test.py b/tests/concurrency_test.py index 38dc9d6..c024749 100644 --- a/tests/concurrency_test.py +++ b/tests/concurrency_test.py @@ -319,8 +319,178 @@ def test_batch_result_from_dict_default_completion_reason(): # No completionReason provided } - result = BatchResult.from_dict(data) - assert result.completion_reason == CompletionReason.ALL_COMPLETED + with patch("aws_durable_execution_sdk_python.concurrency.logger") as mock_logger: + result = BatchResult.from_dict(data) + assert result.completion_reason == CompletionReason.ALL_COMPLETED + # Verify warning was logged + mock_logger.warning.assert_called_once() + assert "Missing completionReason" in mock_logger.warning.call_args[0][0] + + +def test_batch_result_from_dict_infer_all_completed_all_succeeded(): + """Test BatchResult from_dict infers ALL_COMPLETED when all items succeeded.""" + data = { + "all": [ + {"index": 0, "status": "SUCCEEDED", "result": "result1", "error": None}, + {"index": 1, "status": "SUCCEEDED", "result": "result2", "error": None}, + ], + # No completionReason provided + } + + with patch("aws_durable_execution_sdk_python.concurrency.logger") as mock_logger: + result = BatchResult.from_dict(data) + assert result.completion_reason == CompletionReason.ALL_COMPLETED + mock_logger.warning.assert_called_once() + + +def test_batch_result_from_dict_infer_failure_tolerance_exceeded_all_failed(): + """Test BatchResult from_dict infers FAILURE_TOLERANCE_EXCEEDED when all items failed.""" + error_data = { + "message": "Test error", + "type": "TestError", + "data": None, + "stackTrace": None, + } + data = { + "all": [ + {"index": 0, "status": "FAILED", "result": None, "error": error_data}, + {"index": 1, "status": "FAILED", "result": None, "error": error_data}, + ], + # No completionReason provided + } + + # even if everything has failed, if we've completed all items, then we've finished as ALL_COMPLETED + # 
https://github.com/aws/aws-durable-execution-sdk-js/blob/f20396f24afa9d6539d8e5056ee851ac7ef62301/packages/aws-durable-execution-sdk-js/src/handlers/concurrent-execution-handler/concurrent-execution-handler.ts#L324-L335 + with patch("aws_durable_execution_sdk_python.concurrency.logger") as mock_logger: + result = BatchResult.from_dict(data) + assert result.completion_reason == CompletionReason.ALL_COMPLETED + mock_logger.warning.assert_called_once() + + +def test_batch_result_from_dict_infer_all_completed_mixed_success_failure(): + """Test BatchResult from_dict infers ALL_COMPLETED when mix of success/failure but no started items.""" + error_data = { + "message": "Test error", + "type": "TestError", + "data": None, + "stackTrace": None, + } + data = { + "all": [ + {"index": 0, "status": "SUCCEEDED", "result": "result1", "error": None}, + {"index": 1, "status": "FAILED", "result": None, "error": error_data}, + {"index": 2, "status": "SUCCEEDED", "result": "result2", "error": None}, + ], + # No completionReason provided + } + + # the logic is that when \every item i: hasCompleted(i) then terminate due to all_completed + with patch("aws_durable_execution_sdk_python.concurrency.logger") as mock_logger: + result = BatchResult.from_dict(data) + assert result.completion_reason == CompletionReason.ALL_COMPLETED + mock_logger.warning.assert_called_once() + + +def test_batch_result_from_dict_infer_min_successful_reached_has_started(): + """Test BatchResult from_dict infers MIN_SUCCESSFUL_REACHED when items are still started.""" + data = { + "all": [ + {"index": 0, "status": "SUCCEEDED", "result": "result1", "error": None}, + {"index": 1, "status": "STARTED", "result": None, "error": None}, + {"index": 2, "status": "SUCCEEDED", "result": "result2", "error": None}, + ], + # No completionReason provided + } + + with patch("aws_durable_execution_sdk_python.concurrency.logger") as mock_logger: + result = BatchResult.from_dict(data, CompletionConfig(1)) + assert 
result.completion_reason == CompletionReason.MIN_SUCCESSFUL_REACHED
+        mock_logger.warning.assert_called_once()
+
+
+def test_batch_result_from_dict_infer_empty_items():
+    """Test BatchResult from_dict infers ALL_COMPLETED for empty items."""
+    data = {
+        "all": [],
+        # No completionReason provided
+    }
+
+    with patch("aws_durable_execution_sdk_python.concurrency.logger") as mock_logger:
+        result = BatchResult.from_dict(data)
+        assert result.completion_reason == CompletionReason.ALL_COMPLETED
+        mock_logger.warning.assert_called_once()
+
+
+def test_batch_result_from_dict_with_explicit_completion_reason():
+    """Test BatchResult from_dict uses explicit completionReason when provided."""
+    data = {
+        "all": [
+            {"index": 0, "status": "SUCCEEDED", "result": "result1", "error": None}
+        ],
+        "completionReason": "MIN_SUCCESSFUL_REACHED",
+    }
+
+    with patch("aws_durable_execution_sdk_python.concurrency.logger") as mock_logger:
+        result = BatchResult.from_dict(data)
+        assert result.completion_reason == CompletionReason.MIN_SUCCESSFUL_REACHED
+        # No warning should be logged when completionReason is provided
+        mock_logger.warning.assert_not_called()
+
+
+def test_batch_result_infer_completion_reason_edge_cases():
+    """Test completion reason inference with various edge cases."""
+    # Test with only started items
+    started_items = [
+        BatchItem(0, BatchItemStatus.STARTED).to_dict(),
+        BatchItem(1, BatchItemStatus.STARTED).to_dict(),
+    ]
+    items = {"all": started_items}
+    batch = BatchResult.from_dict(items, CompletionConfig(0))
+    # min_successful=0 is met immediately, so MIN_SUCCESSFUL_REACHED is inferred (state impossible in a real run)
+    assert batch.completion_reason == CompletionReason.MIN_SUCCESSFUL_REACHED
+
+    # Test with only started items and no completion config
+    started_items = [
+        BatchItem(0, BatchItemStatus.STARTED).to_dict(),
+        BatchItem(1, BatchItemStatus.STARTED).to_dict(),
+    ]
+    items = {"all": started_items}
+    batch = BatchResult.from_dict(items)
+    # without a completion config, incomplete batches default to FAILURE_TOLERANCE_EXCEEDED
+    assert 
batch.completion_reason == CompletionReason.FAILURE_TOLERANCE_EXCEEDED
+
+    # Test with only failed items
+    failed_items = [
+        BatchItem(
+            0, BatchItemStatus.FAILED, error=ErrorObject("msg", "Error", None, None)
+        ).to_dict(),
+        BatchItem(
+            1, BatchItemStatus.FAILED, error=ErrorObject("msg", "Error", None, None)
+        ).to_dict(),
+    ]
+    failed_items = {"all": failed_items}
+    batch = BatchResult.from_dict(failed_items)
+    assert batch.completion_reason == CompletionReason.ALL_COMPLETED
+
+    # Test with only succeeded items
+    succeeded_items = [
+        BatchItem(0, BatchItemStatus.SUCCEEDED, "result1").to_dict(),
+        BatchItem(1, BatchItemStatus.SUCCEEDED, "result2").to_dict(),
+    ]
+    succeeded_items = {"all": succeeded_items}
+    batch = BatchResult.from_dict(succeeded_items)
+    assert batch.completion_reason == CompletionReason.ALL_COMPLETED
+
+    # Test with mixed but no started (all completed)
+    mixed_items = [
+        BatchItem(0, BatchItemStatus.SUCCEEDED, "result1"),
+        BatchItem(
+            1, BatchItemStatus.FAILED, error=ErrorObject("msg", "Error", None, None)
+        ),
+    ]
+
+    batch = BatchResult.from_items(mixed_items)
+    assert batch.completion_reason == CompletionReason.ALL_COMPLETED
 
 
 def test_batch_result_get_results_empty():
@@ -1692,3 +1862,94 @@ def test_timer_scheduler_future_time_condition_false():
 
     # Callback should not be called since time is in future
     callback.assert_not_called()
+
+
+def test_batch_result_from_dict_with_completion_config():
+    """Test BatchResult from_dict with completion config parameter."""
+    data = {
+        "all": [
+            {"index": 0, "status": "SUCCEEDED", "result": "result1", "error": None},
+            {"index": 1, "status": "STARTED", "result": None, "error": None},
+        ],
+        # No completionReason provided
+    }
+
+    # With started items, should infer MIN_SUCCESSFUL_REACHED
+    completion_config = CompletionConfig(min_successful=1)
+
+    with patch("aws_durable_execution_sdk_python.concurrency.logger") as mock_logger:
+        result = BatchResult.from_dict(data, 
completion_config)
+        assert result.completion_reason == CompletionReason.MIN_SUCCESSFUL_REACHED
+        mock_logger.warning.assert_called_once()
+
+
+def test_batch_result_from_dict_all_completed():
+    """Test BatchResult from_dict infers ALL_COMPLETED when all items are completed."""
+    data = {
+        "all": [
+            {"index": 0, "status": "SUCCEEDED", "result": "result1", "error": None},
+            {
+                "index": 1,
+                "status": "FAILED",
+                "result": None,
+                "error": {
+                    "message": "error",
+                    "type": "Error",
+                    "data": None,
+                    "stackTrace": None,
+                },
+            },
+        ],
+        # No completionReason provided
+    }
+
+    with patch("aws_durable_execution_sdk_python.concurrency.logger") as mock_logger:
+        result = BatchResult.from_dict(data)
+        assert result.completion_reason == CompletionReason.ALL_COMPLETED
+        mock_logger.warning.assert_called_once()
+
+
+def test_batch_result_from_dict_backward_compatibility():
+    """Test BatchResult from_dict maintains backward compatibility when no completion_config provided."""
+    data = {
+        "all": [
+            {"index": 0, "status": "SUCCEEDED", "result": "result1", "error": None}
+        ],
+        "completionReason": "MIN_SUCCESSFUL_REACHED",
+    }
+
+    # Should work without completion_config parameter
+    result = BatchResult.from_dict(data)
+    assert result.completion_reason == CompletionReason.MIN_SUCCESSFUL_REACHED
+
+    # Should also work with None completion_config
+    result2 = BatchResult.from_dict(data, None)
+    assert result2.completion_reason == CompletionReason.MIN_SUCCESSFUL_REACHED
+
+
+def test_batch_result_infer_completion_reason_basic_cases():
+    """Test completion reason inference with basic scenarios."""
+    # Test with started items - should be MIN_SUCCESSFUL_REACHED
+    items = {
+        "all": [
+            BatchItem(0, BatchItemStatus.SUCCEEDED, "result1").to_dict(),
+            BatchItem(1, BatchItemStatus.STARTED).to_dict(),
+        ]
+    }
+    batch = BatchResult.from_dict(items, CompletionConfig(1))
+    assert batch.completion_reason == CompletionReason.MIN_SUCCESSFUL_REACHED
+
+    # Test with all completed items - 
should be ALL_COMPLETED + completed_items = [ + BatchItem(0, BatchItemStatus.SUCCEEDED, "result1").to_dict(), + BatchItem( + 1, BatchItemStatus.FAILED, error=ErrorObject("msg", "Error", None, None) + ).to_dict(), + ] + completed_items = {"all": completed_items} + batch = BatchResult.from_dict(completed_items, CompletionConfig(1)) + assert batch.completion_reason == CompletionReason.ALL_COMPLETED + + # Test empty items - should be ALL_COMPLETED + batch = BatchResult.from_dict({"all": []}, CompletionConfig(1)) + assert batch.completion_reason == CompletionReason.ALL_COMPLETED