diff --git a/src/aws_durable_execution_sdk_python/concurrency.py b/src/aws_durable_execution_sdk_python/concurrency.py index baaa354..4797d05 100644 --- a/src/aws_durable_execution_sdk_python/concurrency.py +++ b/src/aws_durable_execution_sdk_python/concurrency.py @@ -824,5 +824,39 @@ def run_in_child_handler(): ), ) + def replay(self, execution_state: ExecutionState, executor_context: DurableContext): + """ + Replay rather than re-run children. + + if we are here, then we are in replay_children. + This will pre-generate all the operation ids for the children and collect the checkpointed + results. + """ + items: list[BatchItem[ResultType]] = [] + for executable in self.executables: + operation_id = executor_context._create_step_id_for_logical_step( # noqa: SLF001 + executable.index + ) + checkpoint = execution_state.get_checkpoint_result(operation_id) + + result: ResultType | None = None + error = None + status: BatchItemStatus + if checkpoint.is_succeeded(): + status = BatchItemStatus.SUCCEEDED + result = self._execute_item_in_child_context( + executor_context, executable + ) + + elif checkpoint.is_failed(): + error = checkpoint.error + status = BatchItemStatus.FAILED + else: + status = BatchItemStatus.STARTED + + batch_item = BatchItem(executable.index, status, result=result, error=error) + items.append(batch_item) + return BatchResult.from_items(items, self.completion_config) + # endregion concurrency logic diff --git a/src/aws_durable_execution_sdk_python/context.py b/src/aws_durable_execution_sdk_python/context.py index 10b9c08..2938ca9 100644 --- a/src/aws_durable_execution_sdk_python/context.py +++ b/src/aws_durable_execution_sdk_python/context.py @@ -317,7 +317,13 @@ def map( """Execute a callable for each item in parallel.""" map_name: str | None = self._resolve_step_name(name, func) - def map_in_child_context(map_context) -> BatchResult[R]: + operation_id = self._create_step_id() + operation_identifier = OperationIdentifier( + operation_id=operation_id, parent_id=self._parent_id, name=map_name + ) + map_context = self.create_child_context(parent_id=operation_id) + + def map_in_child_context() -> BatchResult[R]: # map_context is a child_context of the context upon which `.map` # was called. We are calling it `map_context` to make it explicit # that any operations happening from hereon are done on the context @@ -328,11 +334,13 @@ def map_in_child_context(map_context) -> BatchResult[R]: config=config, execution_state=self.state, map_context=map_context, + operation_identifier=operation_identifier, ) - return self.run_in_child_context( + return child_handler( func=map_in_child_context, - name=map_name, + state=self.state, + operation_identifier=operation_identifier, config=ChildConfig( sub_type=OperationSubType.MAP, serdes=config.serdes if config is not None else None, @@ -346,8 +354,14 @@ def parallel( config: ParallelConfig | None = None, ) -> BatchResult[T]: """Execute multiple callables in parallel.""" + # _create_step_id() is thread-safe. rest of method is safe, since using local copy of parent id + operation_id = self._create_step_id() + parallel_context = self.create_child_context(parent_id=operation_id) + operation_identifier = OperationIdentifier( + operation_id=operation_id, parent_id=self._parent_id, name=name + ) - def parallel_in_child_context(parallel_context) -> BatchResult[T]: + def parallel_in_child_context() -> BatchResult[T]: # parallel_context is a child_context of the context upon which `.map` # was called. We are calling it `parallel_context` to make it explicit # that any operations happening from hereon are done on the context @@ -357,11 +371,13 @@ def parallel_in_child_context(parallel_context) -> BatchResult[T]: config=config, execution_state=self.state, parallel_context=parallel_context, + operation_identifier=operation_identifier, ) - return self.run_in_child_context( + return child_handler( func=parallel_in_child_context, - name=name, + state=self.state, + operation_identifier=operation_identifier, config=ChildConfig( sub_type=OperationSubType.PARALLEL, serdes=config.serdes if config is not None else None, diff --git a/src/aws_durable_execution_sdk_python/operation/map.py b/src/aws_durable_execution_sdk_python/operation/map.py index d2d582c..ed76bb4 100644 --- a/src/aws_durable_execution_sdk_python/operation/map.py +++ b/src/aws_durable_execution_sdk_python/operation/map.py @@ -17,8 +17,12 @@ if TYPE_CHECKING: from aws_durable_execution_sdk_python.context import DurableContext + from aws_durable_execution_sdk_python.identifier import OperationIdentifier from aws_durable_execution_sdk_python.serdes import SerDes - from aws_durable_execution_sdk_python.state import ExecutionState + from aws_durable_execution_sdk_python.state import ( + CheckpointedResult, + ExecutionState, + ) from aws_durable_execution_sdk_python.types import SummaryGenerator logger = logging.getLogger(__name__) @@ -94,6 +98,7 @@ def map_handler( config: MapConfig | None, execution_state: ExecutionState, map_context: DurableContext, + operation_identifier: OperationIdentifier, ) -> BatchResult[R]: """Execute a callable for each item in parallel.""" # Summary Generator Construction (matches TypeScript implementation): @@ -107,6 +112,13 @@ def map_handler( func=func, config=config or MapConfig(summary_generator=MapSummaryGenerator()), ) + + checkpoint: CheckpointedResult = execution_state.get_checkpoint_result( + operation_identifier.operation_id + ) + if checkpoint.is_succeeded(): + # if we've reached this point, then not only is the step succeeded, but it is also `replay_children`. + return executor.replay(execution_state, map_context) # we are making it explicit that we are now executing within the map_context return executor.execute(execution_state, executor_context=map_context) diff --git a/src/aws_durable_execution_sdk_python/operation/parallel.py b/src/aws_durable_execution_sdk_python/operation/parallel.py index b58251a..e81499f 100644 --- a/src/aws_durable_execution_sdk_python/operation/parallel.py +++ b/src/aws_durable_execution_sdk_python/operation/parallel.py @@ -14,6 +14,7 @@ if TYPE_CHECKING: from aws_durable_execution_sdk_python.concurrency import BatchResult from aws_durable_execution_sdk_python.context import DurableContext + from aws_durable_execution_sdk_python.identifier import OperationIdentifier from aws_durable_execution_sdk_python.serdes import SerDes from aws_durable_execution_sdk_python.state import ExecutionState from aws_durable_execution_sdk_python.types import SummaryGenerator @@ -82,6 +83,7 @@ def parallel_handler( config: ParallelConfig | None, execution_state: ExecutionState, parallel_context: DurableContext, + operation_identifier: OperationIdentifier, ) -> BatchResult[R]: """Execute multiple operations in parallel.""" # Summary Generator Construction (matches TypeScript implementation): @@ -94,6 +96,12 @@ def parallel_handler( callables, config or ParallelConfig(summary_generator=ParallelSummaryGenerator()), ) + + checkpoint = execution_state.get_checkpoint_result( + operation_identifier.operation_id + ) + if checkpoint.is_succeeded(): + return executor.replay(execution_state, parallel_context) return executor.execute(execution_state, executor_context=parallel_context) diff --git a/tests/concurrency_test.py b/tests/concurrency_test.py index 5cb3d87..ea9c26f 100644 --- a/tests/concurrency_test.py +++ b/tests/concurrency_test.py @@ -22,7 +22,7 @@ ExecutionCounters, TimerScheduler, ) -from aws_durable_execution_sdk_python.config import CompletionConfig +from aws_durable_execution_sdk_python.config import CompletionConfig, MapConfig from aws_durable_execution_sdk_python.exceptions import ( CallableRuntimeError, InvalidStateError, @@ -30,6 +30,7 @@ TimedSuspendExecution, ) from aws_durable_execution_sdk_python.lambda_service import ErrorObject +from aws_durable_execution_sdk_python.operation.map import MapExecutor def test_batch_item_status_enum(): @@ -2535,3 +2536,143 @@ def create_child_context(operation_id): assert all( assoc1 == assoc2 for assoc1, assoc2 in combinations(associations_per_run, 2) ) + + +def test_concurrent_executor_replay_with_succeeded_operations(): + """Test ConcurrentExecutor replay method with succeeded operations.""" + + def func1(ctx, item, idx, items): + return f"result_{item}" + + items = ["a", "b"] + config = MapConfig() + + executor = MapExecutor.from_items( + items=items, + func=func1, + config=config, + ) + + # Mock execution state with succeeded operations + mock_execution_state = Mock() + mock_execution_state.durable_execution_arn = ( + "arn:aws:durable:us-east-1:123456789012:execution/test" + ) + + def mock_get_checkpoint_result(operation_id): + mock_result = Mock() + mock_result.is_succeeded.return_value = True + mock_result.is_failed.return_value = False + mock_result.is_replay_children.return_value = False + mock_result.is_existent.return_value = True + # Provide properly serialized JSON data + mock_result.result = f'"cached_result_{operation_id}"' # JSON string + return mock_result + + mock_execution_state.get_checkpoint_result = mock_get_checkpoint_result + + def mock_create_step_id_for_logical_step(step): + return f"op_{step}" + + # Mock executor context + mock_executor_context = Mock() + mock_executor_context._create_step_id_for_logical_step = ( # noqa + mock_create_step_id_for_logical_step + ) + + # Mock child context that has the same execution state + mock_child_context = Mock() + mock_child_context.state = mock_execution_state + mock_executor_context.create_child_context = Mock(return_value=mock_child_context) + mock_executor_context._parent_id = "parent_id" # noqa + + result = executor.replay(mock_execution_state, mock_executor_context) + + assert isinstance(result, BatchResult) + assert len(result.all) == 2 + assert result.all[0].status == BatchItemStatus.SUCCEEDED + assert result.all[0].result == "cached_result_op_0" + assert result.all[1].status == BatchItemStatus.SUCCEEDED + assert result.all[1].result == "cached_result_op_1" + + +def test_concurrent_executor_replay_with_failed_operations(): + """Test ConcurrentExecutor replay method with failed operations.""" + + def func1(ctx, item, idx, items): + return f"result_{item}" + + items = ["a"] + config = MapConfig() + + executor = MapExecutor.from_items( + items=items, + func=func1, + config=config, + ) + + # Mock execution state with failed operation + mock_execution_state = Mock() + + def mock_get_checkpoint_result(operation_id): + mock_result = Mock() + mock_result.is_succeeded.return_value = False + mock_result.is_failed.return_value = True + mock_result.error = Exception("Test error") + return mock_result + + mock_execution_state.get_checkpoint_result = mock_get_checkpoint_result + + # Mock executor context + mock_executor_context = Mock() + mock_executor_context._create_step_id_for_logical_step = Mock(return_value="op_1") # noqa: SLF001 + + result = executor.replay(mock_execution_state, mock_executor_context) + + assert isinstance(result, BatchResult) + assert len(result.all) == 1 + assert result.all[0].status == BatchItemStatus.FAILED + assert result.all[0].error is not None + + +def test_concurrent_executor_replay_with_replay_children(): + """Test ConcurrentExecutor replay method when children need re-execution.""" + + def func1(ctx, item, idx, items): + return f"result_{item}" + + items = ["a"] + config = MapConfig() + + executor = MapExecutor.from_items( + items=items, + func=func1, + config=config, + ) + + # Mock execution state with succeeded operation that needs replay + mock_execution_state = Mock() + + def mock_get_checkpoint_result(operation_id): + mock_result = Mock() + mock_result.is_succeeded.return_value = True + mock_result.is_failed.return_value = False + mock_result.is_replay_children.return_value = True + return mock_result + + mock_execution_state.get_checkpoint_result = mock_get_checkpoint_result + + # Mock executor context + mock_executor_context = Mock() + mock_executor_context._create_step_id_for_logical_step = Mock(return_value="op_1") # noqa: SLF001 + + # Mock _execute_item_in_child_context to return a result + with patch.object( + executor, "_execute_item_in_child_context", return_value="re_executed_result" + ): + result = executor.replay(mock_execution_state, mock_executor_context) + + assert isinstance(result, BatchResult) + assert len(result.all) == 1 + assert result.all[0].status == BatchItemStatus.SUCCEEDED + assert result.all[0].result == "re_executed_result" diff --git a/tests/context_test.py b/tests/context_test.py index 5c85709..3804ee4 100644 --- a/tests/context_test.py +++ b/tests/context_test.py @@ -1086,7 +1086,7 @@ def run_child_context(callable_func, name): # region map -@patch("aws_durable_execution_sdk_python.context.map_handler") +@patch("aws_durable_execution_sdk_python.context.child_handler") def test_map_basic(mock_handler): """Test map with basic parameters.""" mock_handler.return_value = "map_result" @@ -1100,22 +1100,19 @@ def test_function(context, item, index, items): inputs = [1, 2, 3] - with patch.object(DurableContext, "run_in_child_context") as mock_run_in_child: - mock_run_in_child.return_value = "map_result" - context = DurableContext(state=mock_state) + context = DurableContext(state=mock_state) - result = context.map(inputs, test_function) + result = context.map(inputs, test_function) - assert result == "map_result" - mock_run_in_child.assert_called_once() + assert result == "map_result" + mock_handler.assert_called_once() - # Verify the child context callable - call_args = mock_run_in_child.call_args - assert call_args[1]["name"] is None # name should be None - assert call_args[1]["config"].sub_type.value == "Map" + # Verify the child handler was called with correct parameters + call_args = mock_handler.call_args + assert call_args[1]["config"].sub_type.value == "Map" -@patch("aws_durable_execution_sdk_python.context.map_handler") +@patch("aws_durable_execution_sdk_python.context.child_handler") def test_map_with_name_and_config(mock_handler): """Test map with name and config.""" mock_handler.return_value = "configured_map_result" @@ -1132,18 +1129,18 @@ def test_function(context, item, index, items): inputs = ["a", "b", "c"] config = MapConfig() - with patch.object(DurableContext, "run_in_child_context") as mock_run_in_child: - mock_run_in_child.return_value = "configured_map_result" - context = DurableContext(state=mock_state) + context = DurableContext(state=mock_state) - result = context.map(inputs, test_function, name="custom_map", config=config) + result = context.map(inputs, test_function, name="custom_map", config=config) - assert result == "configured_map_result" - call_args = mock_run_in_child.call_args - assert call_args[1]["name"] == "custom_map" # name should be custom_map + assert result == "configured_map_result" + call_args = mock_handler.call_args + assert ( + call_args[1]["operation_identifier"].name == "custom_map" + ) # name should be custom_map -@patch("aws_durable_execution_sdk_python.context.map_handler") +@patch("aws_durable_execution_sdk_python.context.child_handler") def test_map_calls_handler_correctly(mock_handler): """Test map calls map_handler with correct parameters.""" mock_handler.return_value = "handler_result" @@ -1157,14 +1154,12 @@ def test_function(context, item, index, items): inputs = ["hello", "world"] - with patch.object(DurableContext, "run_in_child_context") as mock_run_in_child: - mock_run_in_child.return_value = "handler_result" - context = DurableContext(state=mock_state) + context = DurableContext(state=mock_state) - result = context.map(inputs, test_function) + result = context.map(inputs, test_function) - assert result == "handler_result" - mock_run_in_child.assert_called_once() + assert result == "handler_result" + mock_handler.assert_called_once() @patch("aws_durable_execution_sdk_python.context.map_handler") @@ -1217,7 +1212,7 @@ def test_function(context, item, index, items): # region parallel -@patch("aws_durable_execution_sdk_python.context.parallel_handler") +@patch("aws_durable_execution_sdk_python.context.child_handler") def test_parallel_basic(mock_handler): """Test parallel with basic parameters.""" mock_handler.return_value = "parallel_result" @@ -1234,22 +1229,19 @@ def task2(context): callables = [task1, task2] - with patch.object(DurableContext, "run_in_child_context") as mock_run_in_child: - mock_run_in_child.return_value = "parallel_result" - context = DurableContext(state=mock_state) + context = DurableContext(state=mock_state) - result = context.parallel(callables) + result = context.parallel(callables) - assert result == "parallel_result" - mock_run_in_child.assert_called_once() + assert result == "parallel_result" + mock_handler.assert_called_once() - # Verify the child context callable - call_args = mock_run_in_child.call_args - assert call_args[1]["name"] is None # name should be None - assert call_args[1]["config"].sub_type.value == "Parallel" + # Verify the child handler was called with correct parameters + call_args = mock_handler.call_args + assert call_args[1]["config"].sub_type.value == "Parallel" -@patch("aws_durable_execution_sdk_python.context.parallel_handler") +@patch("aws_durable_execution_sdk_python.context.child_handler") def test_parallel_with_name_and_config(mock_handler): """Test parallel with name and config.""" mock_handler.return_value = "configured_parallel_result" @@ -1267,20 +1259,18 @@ def task2(context): callables = [task1, task2] config = ParallelConfig() - with patch.object(DurableContext, "run_in_child_context") as mock_run_in_child: - mock_run_in_child.return_value = "configured_parallel_result" - context = DurableContext(state=mock_state) + context = DurableContext(state=mock_state) - result = context.parallel(callables, name="custom_parallel", config=config) + result = context.parallel(callables, name="custom_parallel", config=config) - assert result == "configured_parallel_result" - call_args = mock_run_in_child.call_args - assert ( - call_args[1]["name"] == "custom_parallel" - ) # name should be custom_parallel + assert result == "configured_parallel_result" + call_args = mock_handler.call_args + assert ( + call_args[1]["operation_identifier"].name == "custom_parallel" + ) # name should be custom_parallel -@patch("aws_durable_execution_sdk_python.context.parallel_handler") +@patch("aws_durable_execution_sdk_python.context.child_handler") def test_parallel_resolves_name_from_callable(mock_handler): """Test parallel resolves name from callable._original_name.""" mock_handler.return_value = "named_parallel_result" @@ -1301,23 +1291,21 @@ def task2(context): callables = [task1, task2] - with patch.object(DurableContext, "run_in_child_context") as mock_run_in_child: - mock_run_in_child.return_value = "named_parallel_result" - context = DurableContext(state=mock_state) + context = DurableContext(state=mock_state) - # Use _resolve_step_name to test name resolution - resolved_name = context._resolve_step_name(None, mock_callable) # noqa: SLF001 - assert resolved_name == "parallel_tasks" + # Use _resolve_step_name to test name resolution + resolved_name = context._resolve_step_name(None, mock_callable) # noqa: SLF001 + assert resolved_name == "parallel_tasks" - context.parallel(callables) + context.parallel(callables) - call_args = mock_run_in_child.call_args - assert ( - call_args[1]["name"] is None - ) # name should be None since callables don't have _original_name + call_args = mock_handler.call_args + assert ( + call_args[1]["operation_identifier"].name is None + ) # name should be None since callables don't have _original_name -@patch("aws_durable_execution_sdk_python.context.parallel_handler") +@patch("aws_durable_execution_sdk_python.context.child_handler") def test_parallel_calls_handler_correctly(mock_handler): """Test parallel calls parallel_handler with correct parameters.""" mock_handler.return_value = "handler_result" @@ -1334,14 +1322,12 @@ def task2(context): callables = [task1, task2] - with patch.object(DurableContext, "run_in_child_context") as mock_run_in_child: - mock_run_in_child.return_value = "handler_result" - context = DurableContext(state=mock_state) + context = DurableContext(state=mock_state) - result = context.parallel(callables) + result = context.parallel(callables) - assert result == "handler_result" - mock_run_in_child.assert_called_once() + assert result == "handler_result" + mock_handler.assert_called_once() @patch("aws_durable_execution_sdk_python.context.parallel_handler") @@ -1417,7 +1403,7 @@ def task(context): # region map -@patch("aws_durable_execution_sdk_python.context.map_handler") +@patch("aws_durable_execution_sdk_python.context.child_handler") def test_map_calls_handler(mock_handler): """Test map calls map_handler through run_in_child_context.""" mock_handler.return_value = "map_result" @@ -1432,17 +1418,15 @@ def test_function(context, item, index, items): inputs = ["a", "b", "c"] config = MapConfig() - with patch.object(DurableContext, "run_in_child_context") as mock_run_in_child: - mock_run_in_child.return_value = "map_result" - context = DurableContext(state=mock_state) + context = DurableContext(state=mock_state) - result = context.map(inputs, test_function, config=config) + result = context.map(inputs, test_function, config=config) - assert result == "map_result" - mock_run_in_child.assert_called_once() + assert result == "map_result" + mock_handler.assert_called_once() -@patch("aws_durable_execution_sdk_python.context.parallel_handler") +@patch("aws_durable_execution_sdk_python.context.child_handler") def test_parallel_calls_handler(mock_handler): """Test parallel calls parallel_handler through run_in_child_context.""" mock_handler.return_value = "parallel_result" @@ -1460,14 +1444,12 @@ def task2(context): callables = [task1, task2] config = ParallelConfig() - with patch.object(DurableContext, "run_in_child_context") as mock_run_in_child: - mock_run_in_child.return_value = "parallel_result" - context = DurableContext(state=mock_state) + context = DurableContext(state=mock_state) - result = context.parallel(callables, config=config) + result = context.parallel(callables, config=config) - assert result == "parallel_result" - mock_run_in_child.assert_called_once() + assert result == "parallel_result" + mock_handler.assert_called_once() # region wait_for_condition diff --git a/tests/operation/map_test.py b/tests/operation/map_test.py index 87d6369..eb099d1 100644 --- a/tests/operation/map_test.py +++ b/tests/operation/map_test.py @@ -15,6 +15,7 @@ ItemBatcher, MapConfig, ) +from aws_durable_execution_sdk_python.identifier import OperationIdentifier from aws_durable_execution_sdk_python.lambda_service import OperationSubType from aws_durable_execution_sdk_python.operation.map import MapExecutor, map_handler from tests.serdes_test import CustomStrSerDes @@ -116,13 +117,22 @@ def mock_run_in_child_context(func, name, config): # Create a minimal ExecutionState mock class MockExecutionState: - pass + def get_checkpoint_result(self, operation_id): + mock_result = Mock() + mock_result.is_succeeded.return_value = False + return mock_result execution_state = MockExecutionState() config = MapConfig() + operation_identifier = OperationIdentifier("test_op", "parent", "test_map") result = map_handler( - items, callable_func, config, execution_state, mock_run_in_child_context + items, + callable_func, + config, + execution_state, + mock_run_in_child_context, + operation_identifier, ) assert isinstance(result, BatchResult) @@ -139,15 +149,24 @@ def mock_run_in_child_context(func, name, config): return func("mock_context") class MockExecutionState: - pass + def get_checkpoint_result(self, operation_id): + mock_result = Mock() + mock_result.is_succeeded.return_value = False + return mock_result execution_state = MockExecutionState() + operation_identifier = OperationIdentifier("test_op", "parent", "test_map") # Since MapConfig() is called in map_handler when config is None, # we need to provide a valid config to avoid the NameError # This tests the behavior when config is provided instead result = map_handler( - items, callable_func, MapConfig(), execution_state, mock_run_in_child_context + items, + callable_func, + MapConfig(), + execution_state, + mock_run_in_child_context, + operation_identifier, ) assert isinstance(result, BatchResult) @@ -236,13 +255,22 @@ def callable_func(ctx, item, idx, items): ) as mock_execute: class MockExecutionState: - pass + def get_checkpoint_result(self, operation_id): + mock_result = Mock() + mock_result.is_succeeded.return_value = False + return mock_result execution_state = MockExecutionState() config = MapConfig() + operation_identifier = OperationIdentifier("test_op", "parent", "test_map") result = map_handler( - items, callable_func, config, execution_state, executor_context + items, + callable_func, + config, + execution_state, + executor_context, + operation_identifier, ) # Verify execute was called @@ -274,12 +302,21 @@ def callable_func(ctx, item, idx, items): executor_context.create_child_context = lambda *args: Mock() class MockExecutionState: - pass + def get_checkpoint_result(self, operation_id): + mock_result = Mock() + mock_result.is_succeeded.return_value = False + return mock_result execution_state = MockExecutionState() + operation_identifier = OperationIdentifier("test_op", "parent", "test_map") result = map_handler( - items, callable_func, None, execution_state, executor_context + items, + callable_func, + None, + execution_state, + executor_context, + operation_identifier, ) # Verify from_items was called with a MapConfig instance @@ -311,13 +348,22 @@ def callable_func(ctx, item, idx, items): executor_context.create_child_context = lambda *args: Mock() class MockExecutionState: - pass + def get_checkpoint_result(self, operation_id): + mock_result = Mock() + mock_result.is_succeeded.return_value = False + return mock_result execution_state = MockExecutionState() config = MapConfig(serdes=CustomStrSerDes()) + operation_identifier = OperationIdentifier("test_op", "parent", "test_map") result = map_handler( - items, callable_func, config, execution_state, executor_context + items, + callable_func, + config, + execution_state, + executor_context, + operation_identifier, ) # Verify execute was called @@ -341,12 +387,23 @@ def mock_summary_generator(result): executor_context.create_child_context = Mock(return_value=Mock()) class MockExecutionState: - pass + def get_checkpoint_result(self, operation_id): + mock_result = Mock() + mock_result.is_succeeded.return_value = False + return mock_result execution_state = MockExecutionState() + operation_identifier = OperationIdentifier("test_op", "parent", "test_map") # Call map_handler - map_handler(items, callable_func, config, execution_state, executor_context) + map_handler( + items, + callable_func, + config, + execution_state, + executor_context, + operation_identifier, + ) # Verify that create_child_context was called twice (N=2 items) assert executor_context.create_child_context.call_count == 2 @@ -388,12 +445,23 @@ def callable_func(ctx, item, idx, items): executor_context.create_child_context = Mock(return_value=Mock()) # SLF001 class MockExecutionState: - pass + def get_checkpoint_result(self, operation_id): + mock_result = Mock() + mock_result.is_succeeded.return_value = False + return mock_result execution_state = MockExecutionState() + operation_identifier = OperationIdentifier("test_op", "parent", "test_map") # Call map_handler with None config (should use default) - map_handler(items, callable_func, None, execution_state, executor_context) + map_handler( + items, + callable_func, + None, + execution_state, + executor_context, + operation_identifier, + ) # Verify that create_child_context was called once (N=1 item) assert executor_context.create_child_context.call_count == 1 @@ -438,9 +506,13 @@ def func(ctx, item, index, array): config = MapConfig(summary_generator=None) class MockExecutionState: - pass + def get_checkpoint_result(self, operation_id): + mock_result = Mock() + mock_result.is_succeeded.return_value = False + return mock_result execution_state = MockExecutionState() + operation_identifier = OperationIdentifier("test_op", "parent", "test_map") executor_context = Mock() executor_context._create_step_id_for_logical_step = Mock( # noqa: SLF001 @@ -455,17 +527,132 @@ class MockExecutionState: config=config, execution_state=execution_state, map_context=executor_context, + operation_identifier=operation_identifier, ) # Verify that create_child_context was called 3 times (N=3 items) assert executor_context.create_child_context.call_count == 3 - # Verify that _create_step_id_for_logical_step was called 3 times with unique values - assert executor_context._create_step_id_for_logical_step.call_count == 3 # noqa SLF001 - calls = executor_context._create_step_id_for_logical_step.call_args_list # noqa SLF001 - # Verify all calls have unique values - call_values = [call[0][0] for call in calls] - assert len(set(call_values)) == 3 # All unique + +def test_map_handler_replay_mechanism(): + """Test that map_handler uses replay when operation has already succeeded.""" + items = ["item1", "item2"] + + def callable_func(ctx, item, idx, items): + return f"result_{item}" + + # Mock execution state that indicates operation already succeeded + class MockExecutionState: + durable_execution_arn = "arn:aws:durable:us-east-1:123456789012:execution/test" + + def get_checkpoint_result(self, operation_id): + mock_result = Mock() + mock_result.is_succeeded.return_value = True + mock_result.is_replay_children.return_value = False + # Provide properly serialized JSON data + mock_result.result = f'"cached_result_{operation_id}"' # JSON string + return mock_result + + execution_state = MockExecutionState() + config = MapConfig() + operation_identifier = OperationIdentifier("test_op", "parent", "test_map") + + # Mock map context + map_context = Mock() + map_context._create_step_id_for_logical_step = Mock( # noqa: SLF001 + side_effect=["child_1", "child_2"] + ) + + # Mock the executor's replay method + with patch.object(MapExecutor, "replay") as mock_replay: + expected_batch_result = BatchResult( + all=[ + BatchItem( + index=0, + status=BatchItemStatus.SUCCEEDED, + result="cached_result_child_1", + ), + BatchItem( + index=1, + status=BatchItemStatus.SUCCEEDED, + result="cached_result_child_2", + ), + ], + completion_reason=CompletionReason.ALL_COMPLETED, + ) + mock_replay.return_value = expected_batch_result + + result = map_handler( + items, + callable_func, + config, + execution_state, + map_context, + operation_identifier, + ) + + # Verify replay was called instead of execute + mock_replay.assert_called_once_with(execution_state, map_context) + assert result == expected_batch_result + + +def test_map_handler_replay_with_replay_children(): + """Test map_handler replay when children need to be re-executed.""" + items = ["item1"] + + def callable_func(ctx, item, idx, items): + return f"result_{item}" + + # Mock execution state that indicates operation succeeded but children need replay + class MockExecutionState: + def get_checkpoint_result(self, operation_id): + mock_result = Mock() + if operation_id == "test_op": + mock_result.is_succeeded.return_value = True + else: # child operations + mock_result.is_succeeded.return_value = True + mock_result.is_replay_children.return_value = True + return mock_result + + execution_state = MockExecutionState() + config = MapConfig() + operation_identifier = OperationIdentifier("test_op", "parent", "test_map") + + # Mock map context + map_context = Mock() + map_context._create_step_id_for_logical_step = Mock(return_value="child_1") # noqa: SLF001 + + # Mock the executor's replay method and _execute_item_in_child_context + with ( + patch.object(MapExecutor, "replay") as mock_replay, + patch.object( + MapExecutor, "_execute_item_in_child_context" + ) as mock_execute_item, + ): + mock_execute_item.return_value = "re_executed_result" + expected_batch_result = BatchResult( + all=[ + BatchItem( + index=0, + status=BatchItemStatus.SUCCEEDED, + result="re_executed_result", + ) + ], + completion_reason=CompletionReason.ALL_COMPLETED, + ) + mock_replay.return_value = expected_batch_result + + result = map_handler( + items, + callable_func, + config, + execution_state, + map_context, + operation_identifier, + ) + + mock_replay.assert_called_once_with(execution_state, map_context) + assert result == expected_batch_result def test_map_config_with_explicit_none_summary_generator(): @@ -493,3 +680,73 @@ def test_map_config_default_summary_generator_behavior(): config.summary_generator("test_data") if config.summary_generator else "" ) assert test_result == "" # noqa PLC1901 + + +def test_map_handler_first_execution_then_replay_integration(): + """Test map_handler called twice - first calls execute, second calls replay.""" + + def test_func(ctx, item, idx, items): + return f"processed_{item}" + + items = ["a", "b"] + config = MapConfig() + operation_identifier = OperationIdentifier("test_op", "parent", "test_map") + + # Track whether we're in first or second execution + execution_count = 0 + + class MockExecutionState: + durable_execution_arn = "arn:aws:durable:us-east-1:123456789012:execution/test" + + def get_checkpoint_result(self, operation_id): + nonlocal execution_count + mock_result = Mock() + + if operation_id == "test_op": + # Main operation checkpoint + if execution_count == 0: + # First execution - operation not succeeded yet + mock_result.is_succeeded.return_value = False + else: + # Second execution - operation succeeded, trigger replay + mock_result.is_succeeded.return_value = True + + return mock_result + + execution_state = MockExecutionState() + map_context = Mock() + + with ( + patch( + "aws_durable_execution_sdk_python.operation.map.MapExecutor.execute" + ) as mock_execute, + patch( + "aws_durable_execution_sdk_python.operation.map.MapExecutor.replay" + ) as mock_replay, + ): + mock_execute.return_value = Mock() # Mock BatchResult + mock_replay.return_value = Mock() # Mock BatchResult + + # FIRST EXECUTION - should call execute + execution_count = 0 + map_handler( + items, test_func, config, execution_state, map_context, operation_identifier + ) + + # Verify execute was called, replay was not + mock_execute.assert_called_once() + mock_replay.assert_not_called() + + # Reset mocks for second call + mock_execute.reset_mock() + mock_replay.reset_mock() + + # SECOND EXECUTION - should call replay + execution_count = 1 + map_handler( + items, test_func, config, execution_state, map_context, operation_identifier + ) + + # Verify replay was called, execute was not + mock_replay.assert_called_once() + mock_execute.assert_not_called() diff --git a/tests/operation/parallel_test.py b/tests/operation/parallel_test.py index 1dcb5d8..54f2229 100644 --- a/tests/operation/parallel_test.py +++ b/tests/operation/parallel_test.py @@ -14,6 +14,7 @@ Executable, ) from aws_durable_execution_sdk_python.config import CompletionConfig, ParallelConfig +from aws_durable_execution_sdk_python.identifier import OperationIdentifier from aws_durable_execution_sdk_python.lambda_service import OperationSubType from aws_durable_execution_sdk_python.operation.parallel import ( ParallelExecutor, @@ -144,7 +145,15 @@ def func2(ctx): callables = [func1, func2] config = ParallelConfig(max_concurrency=2) - execution_state = Mock() + + class MockExecutionState: + def get_checkpoint_result(self, operation_id): + mock_result = Mock() + mock_result.is_succeeded.return_value = False + return mock_result + + execution_state = MockExecutionState() + operation_identifier = OperationIdentifier("test_op", "parent", "test_parallel") # Mock the run_in_child_context function def mock_run_in_child_context(callable_func, name, child_config): @@ -157,7 +166,11 @@ def mock_run_in_child_context(callable_func, name, child_config): with patch.object(ParallelExecutor, "execute", return_value=mock_batch_result): result = parallel_handler( - callables, config, execution_state, mock_run_in_child_context + callables, + config, + execution_state, + mock_run_in_child_context, + operation_identifier, ) assert result == mock_batch_result @@ -170,7 +183,15 @@ def func1(ctx): return "result1" callables = [func1] - execution_state = Mock() + + class MockExecutionState: + def get_checkpoint_result(self, operation_id): + mock_result = Mock() + mock_result.is_succeeded.return_value = False + return mock_result + + execution_state = MockExecutionState() + operation_identifier = OperationIdentifier("test_op", "parent", "test_parallel") def mock_run_in_child_context(callable_func, name, child_config): return callable_func("mock-context") @@ -182,7 +203,11 @@ def mock_run_in_child_context(callable_func, name, child_config): with patch.object(ParallelExecutor, "execute", return_value=mock_batch_result): result = parallel_handler( - callables, None, execution_state, mock_run_in_child_context + callables, + None, + execution_state, + mock_run_in_child_context, + operation_identifier, ) assert result == mock_batch_result @@ -196,7 +221,15 @@ def func1(ctx): callables = [func1] config = ParallelConfig(max_concurrency=5) - execution_state = Mock() + + class MockExecutionState: + def get_checkpoint_result(self, operation_id): + mock_result = Mock() + mock_result.is_succeeded.return_value = False + return mock_result + + execution_state = MockExecutionState() + operation_identifier = OperationIdentifier("test_op", "parent", "test_parallel") executor_context = Mock() executor_context._create_step_id_for_logical_step = lambda *args: "1" # noqa SLF001 @@ -208,7 +241,9 @@ def func1(ctx): mock_executor.execute.return_value = mock_batch_result mock_from_callables.return_value = mock_executor - result = parallel_handler(callables, config, execution_state, executor_context) + result = parallel_handler( + callables, config, execution_state, executor_context, operation_identifier + ) mock_from_callables.assert_called_once_with(callables, config) mock_executor.execute.assert_called_once_with( @@ -224,7 +259,15 @@ def func1(ctx): return "result1" callables = [func1] - execution_state = Mock() + + class MockExecutionState: + def get_checkpoint_result(self, operation_id): + mock_result = Mock() + mock_result.is_succeeded.return_value = False + return mock_result + + execution_state = MockExecutionState() + operation_identifier = OperationIdentifier("test_op", "parent", "test_parallel") executor_context = Mock() executor_context._create_step_id_for_logical_step = lambda *args: "1" # noqa SLF001 @@ -236,7 +279,9 @@ def func1(ctx): mock_executor.execute.return_value = mock_batch_result mock_from_callables.return_value = mock_executor - result = parallel_handler(callables, None, execution_state, executor_context) + result = parallel_handler( + callables, None, execution_state, executor_context, operation_identifier + ) assert result == mock_batch_result # Verify that a default ParallelConfig was created @@ -313,7 +358,15 @@ def func1(ctx): return "RESULT1" callables = [func1] - execution_state = Mock() + + class MockExecutionState: + def get_checkpoint_result(self, operation_id): + mock_result = Mock() + mock_result.is_succeeded.return_value = False + return mock_result + + execution_state = MockExecutionState() + operation_identifier = OperationIdentifier("test_op", "parent", "test_parallel") executor_context = Mock() executor_context._create_step_id_for_logical_step = lambda *args: "1" # noqa SLF001 @@ -324,6 +377,7 @@ def func1(ctx): ParallelConfig(serdes=CustomStrSerDes()), execution_state, executor_context, + operation_identifier, ) assert result.all[0].result == "RESULT1" @@ -340,14 +394,24 @@ def mock_summary_generator(result): callables = [func1] config = ParallelConfig(summary_generator=mock_summary_generator) - execution_state = Mock() + + class MockExecutionState: + def get_checkpoint_result(self, operation_id): + mock_result = Mock() + mock_result.is_succeeded.return_value = False + return mock_result + + execution_state = MockExecutionState() + operation_identifier = OperationIdentifier("test_op", "parent", "test_parallel") executor_context = Mock() executor_context._create_step_id_for_logical_step = Mock(return_value="1") # noqa SLF001 executor_context.create_child_context = Mock(return_value=Mock()) # Call parallel_handler - parallel_handler(callables, config, execution_state, executor_context) + parallel_handler( + callables, config, execution_state, executor_context, operation_identifier + ) # Verify that create_child_context was called once (N=1 job) assert executor_context.create_child_context.call_count == 1 @@ -384,14 +448,24 @@ def func2(ctx): return "result2" callables = [func1, func2] - execution_state = Mock() + + class MockExecutionState: + def get_checkpoint_result(self, operation_id): + mock_result = Mock() + mock_result.is_succeeded.return_value = False + return mock_result + + execution_state = MockExecutionState() + operation_identifier = OperationIdentifier("test_op", "parent", "test_parallel") executor_context = Mock() executor_context._create_step_id_for_logical_step = Mock(side_effect=["1", "2"]) # noqa SLF001 executor_context.create_child_context = Mock(return_value=Mock()) # Call parallel_handler with None config (should use default) - parallel_handler(callables, None, execution_state, executor_context) + parallel_handler( + callables, None, execution_state, executor_context, operation_identifier + ) # Verify that create_child_context was called twice (N=2 jobs) assert executor_context.create_child_context.call_count == 2 @@ -419,7 +493,14 @@ def func3(ctx): # Explicitly set summary_generator to None config = ParallelConfig(summary_generator=None) - execution_state = Mock() + class MockExecutionState: + def get_checkpoint_result(self, operation_id): + mock_result = Mock() + mock_result.is_succeeded.return_value = False + return mock_result + + execution_state = MockExecutionState() + operation_identifier = OperationIdentifier("test_op", "parent", "test_parallel") executor_context = Mock() executor_context._create_step_id_for_logical_step = Mock( # noqa: SLF001 @@ -433,17 +514,127 @@ def func3(ctx): config=config, execution_state=execution_state, parallel_context=executor_context, + operation_identifier=operation_identifier, ) # Verify that create_child_context was called 3 times (N=3 jobs) assert executor_context.create_child_context.call_count == 3 - # Verify that _create_step_id_for_logical_step was called 3 times with unique values - assert executor_context._create_step_id_for_logical_step.call_count == 3 # noqa SLF001 - calls = executor_context._create_step_id_for_logical_step.call_args_list # noqa SLF001 - # Verify all calls have unique values - call_values = [call[0][0] for call in calls] - assert len(set(call_values)) == 3 # All unique + +def test_parallel_handler_replay_mechanism(): + """Test that parallel_handler uses replay when operation has already succeeded.""" + + def func1(ctx): + return "result1" + + def func2(ctx): + return "result2" + + callables = [func1, func2] + + # Mock execution state that indicates operation already succeeded + class MockExecutionState: + durable_execution_arn = "arn:aws:durable:us-east-1:123456789012:execution/test" + + def get_checkpoint_result(self, operation_id): + mock_result = Mock() + mock_result.is_succeeded.return_value = True + mock_result.is_replay_children.return_value = False + # Provide properly serialized JSON data + mock_result.result = f'"cached_result_{operation_id}"' # JSON string + return mock_result + + execution_state = MockExecutionState() + config = ParallelConfig() + operation_identifier = OperationIdentifier("test_op", "parent", "test_parallel") + + # Mock parallel context + parallel_context = Mock() + parallel_context._create_step_id_for_logical_step = Mock( # noqa: SLF001 + side_effect=["child_1", "child_2"] + ) + + # Mock the executor's replay method + with patch.object(ParallelExecutor, "replay") as mock_replay: + expected_batch_result = BatchResult( + all=[ + BatchItem( + index=0, + status=BatchItemStatus.SUCCEEDED, + result="cached_result_child_1", + ), + BatchItem( + index=1, + status=BatchItemStatus.SUCCEEDED, + result="cached_result_child_2", + ), + ], + completion_reason=CompletionReason.ALL_COMPLETED, + ) + mock_replay.return_value = expected_batch_result + + result = parallel_handler( + callables, config, execution_state, parallel_context, operation_identifier + ) + + # Verify replay was called instead of execute + mock_replay.assert_called_once_with(execution_state, parallel_context) + assert result == expected_batch_result + + +def test_parallel_handler_replay_with_replay_children(): + """Test parallel_handler replay when children need to be re-executed.""" + + def func1(ctx): + return "result1" + + callables = [func1] + + # Mock execution state that indicates operation succeeded but children need replay + class MockExecutionState: + def get_checkpoint_result(self, operation_id): + mock_result = Mock() + if operation_id == "test_op": + mock_result.is_succeeded.return_value = True + else: # child operations + mock_result.is_succeeded.return_value = True + mock_result.is_replay_children.return_value = True + return mock_result + + execution_state = MockExecutionState() + config = ParallelConfig() + operation_identifier = OperationIdentifier("test_op", "parent", "test_parallel") + + # Mock parallel context + parallel_context = Mock() + parallel_context._create_step_id_for_logical_step = Mock(return_value="child_1") # noqa: SLF001 + + # Mock the executor's replay method and _execute_item_in_child_context + with ( + patch.object(ParallelExecutor, "replay") as mock_replay, + patch.object( + ParallelExecutor, "_execute_item_in_child_context" + ) as mock_execute_item, + ): + mock_execute_item.return_value = "re_executed_result" + expected_batch_result = BatchResult( + all=[ + BatchItem( + index=0, + status=BatchItemStatus.SUCCEEDED, + result="re_executed_result", + ) + ], + completion_reason=CompletionReason.ALL_COMPLETED, + ) + mock_replay.return_value = expected_batch_result + + result = parallel_handler( + callables, config, execution_state, parallel_context, operation_identifier + ) + + mock_replay.assert_called_once_with(execution_state, parallel_context) + assert result == expected_batch_result def test_parallel_config_with_explicit_none_summary_generator(): @@ -470,3 +661,76 @@ def test_parallel_config_default_summary_generator_behavior(): ) assert test_result == "" # noqa PLC1901 assert config.serdes is None + + +def test_parallel_handler_first_execution_then_replay(): + """Test parallel_handler called twice - first calls execute, second calls replay.""" + + def task1(ctx): + return "result1" + + def task2(ctx): + return "result2" + + callables = [task1, task2] + config = ParallelConfig() + operation_identifier = OperationIdentifier("test_op", "parent", "test_parallel") + + # Track whether we're in first or second execution + execution_count = 0 + + class MockExecutionState: + durable_execution_arn = "arn:aws:durable:us-east-1:123456789012:execution/test" + + def get_checkpoint_result(self, operation_id): + nonlocal execution_count + mock_result = Mock() + + if operation_id == "test_op": + # Main operation checkpoint + if execution_count == 0: + # First execution - operation not succeeded yet + mock_result.is_succeeded.return_value = False + else: + # Second execution - operation succeeded, trigger replay + mock_result.is_succeeded.return_value = True + + return mock_result + + execution_state = MockExecutionState() + parallel_context = Mock() + + with ( + patch( + "aws_durable_execution_sdk_python.operation.parallel.ParallelExecutor.execute" + ) as mock_execute, + patch( + "aws_durable_execution_sdk_python.operation.parallel.ParallelExecutor.replay" + ) as mock_replay, + ): + mock_execute.return_value = Mock() # Mock BatchResult + mock_replay.return_value = Mock() # Mock BatchResult + + # FIRST EXECUTION - should call execute + execution_count = 0 + parallel_handler( + callables, config, execution_state, parallel_context, operation_identifier + ) + + # Verify execute was called, replay was not + mock_execute.assert_called_once() + mock_replay.assert_not_called() + + # Reset mocks for second call + mock_execute.reset_mock() + mock_replay.reset_mock() + + # SECOND EXECUTION - should call replay + execution_count = 1 + parallel_handler( + callables, config, execution_state, parallel_context, operation_identifier + ) + + # Verify replay was called, execute was not + mock_replay.assert_called_once() + mock_execute.assert_not_called()