diff --git a/src/aws_durable_execution_sdk_python/concurrency.py b/src/aws_durable_execution_sdk_python/concurrency.py index b7ddc24..baaa354 100644 --- a/src/aws_durable_execution_sdk_python/concurrency.py +++ b/src/aws_durable_execution_sdk_python/concurrency.py @@ -19,17 +19,20 @@ SuspendExecution, TimedSuspendExecution, ) +from aws_durable_execution_sdk_python.identifier import OperationIdentifier from aws_durable_execution_sdk_python.lambda_service import ErrorObject +from aws_durable_execution_sdk_python.operation.child import child_handler from aws_durable_execution_sdk_python.types import BatchResult as BatchResultProtocol if TYPE_CHECKING: from collections.abc import Callable from aws_durable_execution_sdk_python.config import CompletionConfig + from aws_durable_execution_sdk_python.context import DurableContext from aws_durable_execution_sdk_python.lambda_service import OperationSubType from aws_durable_execution_sdk_python.serdes import SerDes from aws_durable_execution_sdk_python.state import ExecutionState - from aws_durable_execution_sdk_python.types import DurableContext, SummaryGenerator + from aws_durable_execution_sdk_python.types import SummaryGenerator logger = logging.getLogger(__name__) @@ -615,12 +618,7 @@ def execute_item( raise NotImplementedError def execute( - self, - execution_state: ExecutionState, - run_in_child_context: Callable[ - [Callable[[DurableContext], ResultType], str | None, ChildConfig | None], - ResultType, - ], + self, execution_state: ExecutionState, executor_context: DurableContext ) -> BatchResult[ResultType]: """Execute items concurrently with event-driven state management.""" logger.debug( @@ -649,7 +647,7 @@ def submit_task(executable_with_state: ExecutableWithState) -> None: """Submit task to the thread executor and mark its state as started.""" future = thread_executor.submit( self._execute_item_in_child_context, - run_in_child_context, + executor_context, executable_with_state.executable, ) executable_with_state.run(future) @@ -784,21 +782,42 @@ def _create_result(self) -> BatchResult[ResultType]: def _execute_item_in_child_context( self, - run_in_child_context: Callable[ - [Callable[[DurableContext], ResultType], str | None, ChildConfig | None], - ResultType, - ], + executor_context: DurableContext, executable: Executable[CallableType], ) -> ResultType: - """Execute a single item in a child context.""" + """ + Execute a single item in a derived child context. + + instead of relying on `executor_context.run_in_child_context` + we generate an operation_id for the child, and then call `child_handler` + directly. This avoids the hidden mutation of the context's internal counter. + we can do this because we explicitly control the generation of step_id and do it + using executable.index. + + + invariant: `operation_id` for a given executable is deterministic, + and execution order invariant. + """ + + operation_id = executor_context._create_step_id_for_logical_step( # noqa: SLF001 + executable.index + ) + name = f"{self.name_prefix}{executable.index}" + child_context = executor_context.create_child_context(operation_id) + operation_identifier = OperationIdentifier( + operation_id, + executor_context._parent_id, # noqa: SLF001 + name, + ) - def execute_in_child_context(child_context: DurableContext) -> ResultType: + def run_in_child_handler(): return self.execute_item(child_context, executable) - return run_in_child_context( - execute_in_child_context, - f"{self.name_prefix}{executable.index}", - ChildConfig( + return child_handler( + run_in_child_handler, + child_context.state, + operation_identifier=operation_identifier, + config=ChildConfig( serdes=self.item_serdes or self.serdes, sub_type=self.sub_type_iteration, summary_generator=self.summary_generator, diff --git a/src/aws_durable_execution_sdk_python/context.py b/src/aws_durable_execution_sdk_python/context.py index bd647e0..10b9c08 100644 --- a/src/aws_durable_execution_sdk_python/context.py +++ b/src/aws_durable_execution_sdk_python/context.py @@ -222,6 +222,15 @@ def set_logger(self, new_logger: LoggerInterface): info=self._log_info, ) + def _create_step_id_for_logical_step(self, step: int) -> str: + """ + Generate a step_id based on the given logical step. + This allows us to recover operation ids or even look + forward without changing the internal state of this context. + """ + step_id = f"{self._parent_id}-{step}" if self._parent_id else str(step) + return hashlib.blake2b(step_id.encode()).hexdigest()[:64] + def _create_step_id(self) -> str: """Generate a thread-safe step id, incrementing in order of invocation. @@ -229,10 +238,7 @@ def _create_step_id(self) -> str: the id generated by this method. It is subject to change without notice. """ new_counter: int = self._step_counter.increment() - step_id = ( - f"{self._parent_id}-{new_counter}" if self._parent_id else str(new_counter) - ) - return hashlib.blake2b(step_id.encode()).hexdigest()[:64] + return self._create_step_id_for_logical_step(new_counter) # region Operations @@ -311,13 +317,17 @@ def map( """Execute a callable for each item in parallel.""" map_name: str | None = self._resolve_step_name(name, func) - def map_in_child_context(child_context) -> BatchResult[R]: + def map_in_child_context(map_context) -> BatchResult[R]: + # map_context is a child_context of the context upon which `.map` + # was called. We are calling it `map_context` to make it explicit + # that any operations happening from hereon are done on the context + # that owns the branches return map_handler( items=inputs, func=func, config=config, execution_state=self.state, - run_in_child_context=child_context.run_in_child_context, + map_context=map_context, ) return self.run_in_child_context( @@ -337,12 +347,16 @@ def parallel( ) -> BatchResult[T]: """Execute multiple callables in parallel.""" - def parallel_in_child_context(child_context) -> BatchResult[T]: + def parallel_in_child_context(parallel_context) -> BatchResult[T]: + # parallel_context is a child_context of the context upon which `.map` + # was called. We are calling it `parallel_context` to make it explicit + # that any operations happening from hereon are done on the context + # that owns the branches return parallel_handler( callables=functions, config=config, execution_state=self.state, - run_in_child_context=child_context.run_in_child_context, + parallel_context=parallel_context, ) return self.run_in_child_context( diff --git a/src/aws_durable_execution_sdk_python/operation/map.py b/src/aws_durable_execution_sdk_python/operation/map.py index 0d851ef..d2d582c 100644 --- a/src/aws_durable_execution_sdk_python/operation/map.py +++ b/src/aws_durable_execution_sdk_python/operation/map.py @@ -16,10 +16,10 @@ from aws_durable_execution_sdk_python.lambda_service import OperationSubType if TYPE_CHECKING: - from aws_durable_execution_sdk_python.config import ChildConfig + from aws_durable_execution_sdk_python.context import DurableContext from aws_durable_execution_sdk_python.serdes import SerDes from aws_durable_execution_sdk_python.state import ExecutionState - from aws_durable_execution_sdk_python.types import DurableContext, SummaryGenerator + from aws_durable_execution_sdk_python.types import SummaryGenerator logger = logging.getLogger(__name__) @@ -93,9 +93,7 @@ def map_handler( func: Callable, config: MapConfig | None, execution_state: ExecutionState, - run_in_child_context: Callable[ - [Callable[[DurableContext], R], str | None, ChildConfig | None], R - ], + map_context: DurableContext, ) -> BatchResult[R]: """Execute a callable for each item in parallel.""" # Summary Generator Construction (matches TypeScript implementation): @@ -109,7 +107,8 @@ def map_handler( func=func, config=config or MapConfig(summary_generator=MapSummaryGenerator()), ) - return executor.execute(execution_state, run_in_child_context) + # we are making it explicit that we are now executing within the map_context + return executor.execute(execution_state, executor_context=map_context) class MapSummaryGenerator: diff --git a/src/aws_durable_execution_sdk_python/operation/parallel.py b/src/aws_durable_execution_sdk_python/operation/parallel.py index 78330a9..b58251a 100644 --- a/src/aws_durable_execution_sdk_python/operation/parallel.py +++ b/src/aws_durable_execution_sdk_python/operation/parallel.py @@ -13,10 +13,10 @@ if TYPE_CHECKING: from aws_durable_execution_sdk_python.concurrency import BatchResult - from aws_durable_execution_sdk_python.config import ChildConfig + from aws_durable_execution_sdk_python.context import DurableContext from aws_durable_execution_sdk_python.serdes import SerDes from aws_durable_execution_sdk_python.state import ExecutionState - from aws_durable_execution_sdk_python.types import DurableContext, SummaryGenerator + from aws_durable_execution_sdk_python.types import SummaryGenerator logger = logging.getLogger(__name__) @@ -81,9 +81,7 @@ def parallel_handler( callables: Sequence[Callable], config: ParallelConfig | None, execution_state: ExecutionState, - run_in_child_context: Callable[ - [Callable[[DurableContext], R], str | None, ChildConfig | None], R - ], + parallel_context: DurableContext, ) -> BatchResult[R]: """Execute multiple operations in parallel.""" # Summary Generator Construction (matches TypeScript implementation): @@ -96,7 +94,7 @@ def parallel_handler( callables, config or ParallelConfig(summary_generator=ParallelSummaryGenerator()), ) - return executor.execute(execution_state, run_in_child_context) + return executor.execute(execution_state, executor_context=parallel_context) class ParallelSummaryGenerator: diff --git a/tests/concurrency_test.py b/tests/concurrency_test.py index d3d090b..5cb3d87 100644 --- a/tests/concurrency_test.py +++ b/tests/concurrency_test.py @@ -1,8 +1,11 @@ """Tests for the concurrency module.""" +import random import threading import time from concurrent.futures import Future +from functools import partial +from itertools import combinations from unittest.mock import Mock, patch import pytest @@ -1086,10 +1089,11 @@ def failure_callable(): execution_state = Mock() execution_state.create_checkpoint = Mock() - def mock_run_in_child_context(func, name, config): - return func(Mock()) + executor_context = Mock() + executor_context._create_step_id_for_logical_step = lambda *args: "1" # noqa SLF001 + executor_context.create_child_context = lambda *args: Mock() - result = executor.execute(execution_state, mock_run_in_child_context) + result = executor.execute(execution_state, executor_context) assert len(result.all) == 2 assert result.all[0].status == BatchItemStatus.SUCCEEDED @@ -1124,11 +1128,12 @@ def execute_item(self, child_context, executable): serdes=None, ) - def mock_run_in_child_context(func, name, config): - return func(Mock()) + executor_context = Mock() + executor_context._create_step_id_for_logical_step = lambda *args: "1" # noqa SLF001 + executor_context.create_child_context = lambda *args: Mock() result = executor._execute_item_in_child_context( # noqa: SLF001 - mock_run_in_child_context, executables[0] + executor_context, executables[0] ) assert result == "result_0" @@ -1214,12 +1219,13 @@ def execute_item(self, child_context, executable): execution_state = Mock() execution_state.create_checkpoint = Mock() - def mock_run_in_child_context(func, name, config): - return func(Mock()) + executor_context = Mock() + executor_context._create_step_id_for_logical_step = lambda *args: "1" # noqa SLF001 + executor_context.create_child_context = lambda *args: Mock() # Should raise TimedSuspendExecution since no other tasks running with pytest.raises(TimedSuspendExecution): - executor.execute(execution_state, mock_run_in_child_context) + executor.execute(execution_state, executor_context) def test_multiple_tasks_one_suspends_execution_continues(): @@ -1259,12 +1265,13 @@ def execute_item(self, child_context, executable): execution_state = Mock() execution_state.create_checkpoint = Mock() - def mock_run_in_child_context(func, name, config): - return func(Mock()) + executor_context = Mock() + executor_context._create_step_id_for_logical_step = lambda *args: "1" # noqa SLF001 + executor_context.create_child_context = lambda *args: Mock() # Should raise TimedSuspendExecution after Task B completes with pytest.raises(TimedSuspendExecution): - executor.execute(execution_state, mock_run_in_child_context) + executor.execute(execution_state, executor_context) # Assert that Task B did complete before suspension assert executor.task_b_completed @@ -1303,12 +1310,13 @@ def execute_item(self, child_context, executable): execution_state = Mock() execution_state.create_checkpoint = Mock() - def mock_run_in_child_context(func, name, config): - return func(Mock()) + executor_context = Mock() + executor_context._create_step_id_for_logical_step = lambda *args: "1" # noqa SLF001 + executor_context.create_child_context = lambda *args: Mock() # Should raise TimedSuspendExecution since single task suspends with pytest.raises(TimedSuspendExecution): - executor.execute(execution_state, mock_run_in_child_context) + executor.execute(execution_state, executor_context) def test_concurrent_executor_with_timed_resubmit_while_other_task_running(): @@ -1375,11 +1383,12 @@ def execute_item(self, child_context, executable): execution_state = Mock() execution_state.create_checkpoint = Mock() - def mock_run_in_child_context(func, name, config): - return func(Mock()) + executor_context = Mock() + executor_context._create_step_id_for_logical_step = lambda *args: "1" # noqa SLF001 + executor_context.create_child_context = lambda *args: Mock() # Should complete successfully after B resubmits and both tasks finish - result = executor.execute(execution_state, mock_run_in_child_context) + result = executor.execute(execution_state, executor_context) # Verify results assert len(result.all) == 2 @@ -1517,10 +1526,11 @@ def failure_callable(): execution_state = Mock() execution_state.create_checkpoint = Mock() - def mock_run_in_child_context(func, name, config): - return func(Mock()) + executor_context = Mock() + executor_context._create_step_id_for_logical_step = lambda *args: "1" # noqa SLF001 + executor_context.create_child_context = lambda *args: Mock() - result = executor.execute(execution_state, mock_run_in_child_context) + result = executor.execute(execution_state, executor_context) assert len(result.all) == 1 assert result.all[0].status == BatchItemStatus.FAILED @@ -1741,10 +1751,11 @@ def failure_callable(): execution_state = Mock() execution_state.create_checkpoint = Mock() - def mock_run_in_child_context(func, name, config): - return func(Mock()) + executor_context = Mock() + executor_context._create_step_id_for_logical_step = lambda *args: "1" # noqa SLF001 + executor_context.create_child_context = lambda *args: Mock() - result = executor.execute(execution_state, mock_run_in_child_context) + result = executor.execute(execution_state, executor_context) assert len(result.all) == 1 assert result.all[0].status == BatchItemStatus.FAILED @@ -1802,10 +1813,11 @@ def success_callable(): execution_state = Mock() execution_state.create_checkpoint = Mock() - def mock_run_in_child_context(func, name, config): - return func(Mock()) + executor_context = Mock() + executor_context._create_step_id_for_logical_step = lambda *args: "1" # noqa SLF001 + executor_context.create_child_context = lambda *args: Mock() - result = executor.execute(execution_state, mock_run_in_child_context) + result = executor.execute(execution_state, executor_context) assert len(result.all) == 1 assert result.all[0].status == BatchItemStatus.SUCCEEDED @@ -1843,12 +1855,13 @@ def suspend_callable(): execution_state = Mock() execution_state.create_checkpoint = Mock() - def mock_run_in_child_context(func, name, config): - return func(Mock()) + executor_context = Mock() + executor_context._create_step_id_for_logical_step = lambda *args: "1" # noqa SLF001 + executor_context.create_child_context = lambda *args: Mock() # Should raise SuspendExecution since single task suspends with pytest.raises(SuspendExecution): - executor.execute(execution_state, mock_run_in_child_context) + executor.execute(execution_state, executor_context) # Tests for _create_result method match statement branches @@ -2435,3 +2448,90 @@ def test_batch_result_infer_completion_reason_basic_cases(): # Test empty items - should be ALL_COMPLETED batch = BatchResult.from_dict({"all": []}, CompletionConfig(1)) assert batch.completion_reason == CompletionReason.ALL_COMPLETED + + +def test_operation_id_determinism_across_shuffles(): + """Test that operation_id depends on Executable.index, not execution order.""" + + def index_based_function(index, ctx): + """Function that returns a result based on the executable index.""" + return f"result_for_index_{index}" + + class TestExecutor(ConcurrentExecutor): + """Custom executor for testing operation_id determinism.""" + + def execute_item(self, child_context, executable): + return executable.func(child_context) + + # Create executables with specific indices using partial + num_executables = 50 + funcs = [partial(index_based_function, i) for i in range(num_executables)] + + # Track operation_id -> result associations + captured_associations = [] + + def patched_child_handler(func, execution_state, operation_identifier, config): + """Patched child handler that captures operation_id -> result mapping.""" + result = func() # Execute the function + captured_associations.append((operation_identifier.operation_id, result)) + return result + + execution_state = Mock() + execution_state.create_checkpoint = Mock() + + completion_config = CompletionConfig(min_successful=num_executables) + + # Run multiple times with different shuffle orders + associations_per_run = [] + + for run in range(10): # Test 10 different shuffle orders + captured_associations.clear() + + # Create executables from shuffled functions + executables = [Executable(index=i, func=func) for i, func in enumerate(funcs)] + random.seed(run) # Different seed for each run + random.shuffle(executables) + + executor = TestExecutor( + executables=executables, + max_concurrency=2, + completion_config=completion_config, + sub_type_top="TEST", + sub_type_iteration="TEST_ITER", + name_prefix="test_", + serdes=None, + ) + + # Create executor context mock + executor_context = Mock() + executor_context._parent_id = "parent_123" # noqa SLF001 + + def create_step_id(index): + return f"step_{index}" + + executor_context._create_step_id_for_logical_step = create_step_id # noqa SLF001 + + def create_child_context(operation_id): + child_ctx = Mock() + child_ctx.state = execution_state + return child_ctx + + executor_context.create_child_context = create_child_context + + with patch( + "aws_durable_execution_sdk_python.concurrency.child_handler", + patched_child_handler, + ): + executor.execute(execution_state, executor_context) + + associations_per_run.append(captured_associations.copy()) + + # first we will verify the validity of the test by ensuring that there exist at least 2 runs with different ordering + assert any( + assoc1 != assoc2 for assoc1, assoc2 in combinations(associations_per_run, 2) + ) + # then we will verify the invariant of association between step_id and result + associations_per_run = [dict(assoc) for assoc in associations_per_run] + assert all( + assoc1 == assoc2 for assoc1, assoc2 in combinations(associations_per_run, 2) + ) diff --git a/tests/operation/map_test.py b/tests/operation/map_test.py index 1afa596..87d6369 100644 --- a/tests/operation/map_test.py +++ b/tests/operation/map_test.py @@ -17,7 +17,6 @@ ) from aws_durable_execution_sdk_python.lambda_service import OperationSubType from aws_durable_execution_sdk_python.operation.map import MapExecutor, map_handler -from aws_durable_execution_sdk_python.serdes import serialize from tests.serdes_test import CustomStrSerDes @@ -228,13 +227,14 @@ def callable_func(ctx, item, idx, items): completion_reason=CompletionReason.ALL_COMPLETED, ) + executor_context = Mock() + executor_context._create_step_id_for_logical_step = lambda *args: "1" # noqa SLF001 + executor_context.create_child_context = lambda *args: Mock() + with patch.object( MapExecutor, "execute", return_value=mock_batch_result ) as mock_execute: - def mock_run_in_child_context(func, name, config): - return func("mock_context") - class MockExecutionState: pass @@ -242,11 +242,13 @@ class MockExecutionState: config = MapConfig() result = map_handler( - items, callable_func, config, execution_state, mock_run_in_child_context + items, callable_func, config, execution_state, executor_context ) # Verify execute was called - mock_execute.assert_called_once_with(execution_state, mock_run_in_child_context) + mock_execute.assert_called_once_with( + execution_state, executor_context=executor_context + ) assert result == mock_batch_result @@ -267,8 +269,9 @@ def callable_func(ctx, item, idx, items): mock_executor.execute.return_value = mock_batch_result mock_from_items.return_value = mock_executor - def mock_run_in_child_context(func, name, config): - return func("mock_context") + executor_context = Mock() + executor_context._create_step_id_for_logical_step = lambda *args: "1" # noqa SLF001 + executor_context.create_child_context = lambda *args: Mock() class MockExecutionState: pass @@ -276,7 +279,7 @@ class MockExecutionState: execution_state = MockExecutionState() result = map_handler( - items, callable_func, None, execution_state, mock_run_in_child_context + items, callable_func, None, execution_state, executor_context ) # Verify from_items was called with a MapConfig instance @@ -297,21 +300,15 @@ class MockExecutionState: def test_map_handler_with_serdes(): - """Test that map_handler calls executor.execute method.""" + """Test that map_handler with serdes""" items = ["test_item"] def callable_func(ctx, item, idx, items): - return f"result_{item}" + return f"RESULT_{item.upper()}" - # Mock the executor.execute method - - def mock_run_in_child_context(func, name, config): - return serialize( - serdes=config.serdes, - value=func("mock_context"), - operation_id="op_id", - durable_execution_arn="durable_execution_arn", - ) + executor_context = Mock() + executor_context._create_step_id_for_logical_step = lambda *args: "1" # noqa SLF001 + executor_context.create_child_context = lambda *args: Mock() class MockExecutionState: pass @@ -320,7 +317,7 @@ class MockExecutionState: config = MapConfig(serdes=CustomStrSerDes()) result = map_handler( - items, callable_func, config, execution_state, mock_run_in_child_context + items, callable_func, config, execution_state, executor_context ) # Verify execute was called @@ -328,7 +325,7 @@ class MockExecutionState: def test_map_handler_with_summary_generator(): - """Test that map_handler passes summary_generator to child config.""" + """Test that map_handler calls executor_context methods correctly.""" items = ["item1", "item2"] def callable_func(ctx, item, idx, items): @@ -339,31 +336,26 @@ def mock_summary_generator(result): config = MapConfig(summary_generator=mock_summary_generator) - # Track the child_config passed to run_in_child_context - captured_child_configs = [] - - def mock_run_in_child_context(callable_func, name, child_config): - captured_child_configs.append(child_config) - return callable_func("mock-context") + executor_context = Mock() + executor_context._create_step_id_for_logical_step = Mock(side_effect=["1", "2"]) # noqa SLF001 + executor_context.create_child_context = Mock(return_value=Mock()) class MockExecutionState: pass execution_state = MockExecutionState() - # Call map_handler with our mock run_in_child_context - map_handler( - items, callable_func, config, execution_state, mock_run_in_child_context - ) + # Call map_handler + map_handler(items, callable_func, config, execution_state, executor_context) - # Verify that the summary_generator was passed to the child config - assert len(captured_child_configs) > 0 - child_config = captured_child_configs[0] - assert child_config.summary_generator is mock_summary_generator + # Verify that create_child_context was called twice (N=2 items) + assert executor_context.create_child_context.call_count == 2 - # Test that the summary generator works - test_result = child_config.summary_generator("test" * 100) - assert test_result == "Summary of 400 chars for map item" + # Verify that _create_step_id_for_logical_step was called twice with unique values + assert executor_context._create_step_id_for_logical_step.call_count == 2 # noqa SLF001 + calls = executor_context._create_step_id_for_logical_step.call_args_list # noqa SLF001 + # Verify unique values were passed + assert calls[0] != calls[1] def test_map_executor_from_items_with_summary_generator(): @@ -385,18 +377,15 @@ def mock_summary_generator(result): def test_map_handler_default_summary_generator(): - """Test that map_handler uses default summary generator when config is None.""" + """Test that map_handler calls executor_context methods correctly with default config.""" items = ["item1"] def callable_func(ctx, item, idx, items): return f"result_{item}" - # Track the child_config passed to run_in_child_context - captured_child_configs = [] - - def mock_run_in_child_context(callable_func, name, child_config): - captured_child_configs.append(child_config) - return callable_func("mock-context") + executor_context = Mock() + executor_context._create_step_id_for_logical_step = Mock(return_value="1") # noqa SLF001 + executor_context.create_child_context = Mock(return_value=Mock()) # SLF001 class MockExecutionState: pass @@ -404,19 +393,13 @@ class MockExecutionState: execution_state = MockExecutionState() # Call map_handler with None config (should use default) - map_handler(items, callable_func, None, execution_state, mock_run_in_child_context) + map_handler(items, callable_func, None, execution_state, executor_context) - # Verify that a default summary_generator was provided - assert len(captured_child_configs) > 0 - child_config = captured_child_configs[0] - assert child_config.summary_generator is not None + # Verify that create_child_context was called once (N=1 item) + assert executor_context.create_child_context.call_count == 1 - # Test that the default summary generator works - test_result = child_config.summary_generator( - BatchResult([], CompletionReason.ALL_COMPLETED) - ) - assert isinstance(test_result, str) - assert len(test_result) > 0 + # Verify that _create_step_id_for_logical_step was called once + assert executor_context._create_step_id_for_logical_step.call_count == 1 # noqa SLF001 def test_map_executor_init_with_summary_generator(): @@ -445,12 +428,12 @@ def mock_summary_generator(result): def test_map_handler_with_explicit_none_summary_generator(): - """Test that map_handler respects explicit None summary_generator.""" + """Test that map_handler calls executor_context methods correctly with explicit None summary_generator.""" def func(ctx, item, index, array): return f"processed_{item}" - items = ["item1", "item2"] + items = ["item1", "item2", "item3"] # Explicitly set summary_generator to None config = MapConfig(summary_generator=None) @@ -459,35 +442,30 @@ class MockExecutionState: execution_state = MockExecutionState() - # Capture the child configs passed to run_in_child_context - captured_child_configs = [] - - def mock_run_in_child_context(func, name, child_config): - captured_child_configs.append(child_config) - return func(Mock()) + executor_context = Mock() + executor_context._create_step_id_for_logical_step = Mock( # noqa: SLF001 + side_effect=["1", "2", "3"] + ) + executor_context.create_child_context = Mock(return_value=Mock()) - # Call map_handler with our mock run_in_child_context + # Call map_handler map_handler( items=items, func=func, config=config, execution_state=execution_state, - run_in_child_context=mock_run_in_child_context, + map_context=executor_context, ) - # Verify that the summary_generator was set to None (not default) - assert len(captured_child_configs) > 0 - child_config = captured_child_configs[0] - assert child_config.summary_generator is None + # Verify that create_child_context was called 3 times (N=3 items) + assert executor_context.create_child_context.call_count == 3 - # Test that when None, it should result in empty string behavior - # This matches child.py: config.summary_generator(raw_result) if config.summary_generator else "" - test_result = ( - child_config.summary_generator("test_data") - if child_config.summary_generator - else "" - ) - assert test_result == "" # noqa PLC1901 + # Verify that _create_step_id_for_logical_step was called 3 times with unique values + assert executor_context._create_step_id_for_logical_step.call_count == 3 # noqa SLF001 + calls = executor_context._create_step_id_for_logical_step.call_args_list # noqa SLF001 + # Verify all calls have unique values + call_values = [call[0][0] for call in calls] + assert len(set(call_values)) == 3 # All unique def test_map_config_with_explicit_none_summary_generator(): diff --git a/tests/operation/parallel_test.py b/tests/operation/parallel_test.py index 8dd6c66..1dcb5d8 100644 --- a/tests/operation/parallel_test.py +++ b/tests/operation/parallel_test.py @@ -19,7 +19,6 @@ ParallelExecutor, parallel_handler, ) -from aws_durable_execution_sdk_python.serdes import serialize from tests.serdes_test import CustomStrSerDes @@ -199,8 +198,9 @@ def func1(ctx): config = ParallelConfig(max_concurrency=5) execution_state = Mock() - def mock_run_in_child_context(callable_func, name, child_config): - return callable_func("mock-context") + executor_context = Mock() + executor_context._create_step_id_for_logical_step = lambda *args: "1" # noqa SLF001 + executor_context.create_child_context = lambda *args: Mock() with patch.object(ParallelExecutor, "from_callables") as mock_from_callables: mock_executor = Mock() @@ -208,13 +208,11 @@ def mock_run_in_child_context(callable_func, name, child_config): mock_executor.execute.return_value = mock_batch_result mock_from_callables.return_value = mock_executor - result = parallel_handler( - callables, config, execution_state, mock_run_in_child_context - ) + result = parallel_handler(callables, config, execution_state, executor_context) mock_from_callables.assert_called_once_with(callables, config) mock_executor.execute.assert_called_once_with( - execution_state, mock_run_in_child_context + execution_state, executor_context=executor_context ) assert result == mock_batch_result @@ -228,8 +226,9 @@ def func1(ctx): callables = [func1] execution_state = Mock() - def mock_run_in_child_context(callable_func, name, child_config): - return callable_func("mock-context") + executor_context = Mock() + executor_context._create_step_id_for_logical_step = lambda *args: "1" # noqa SLF001 + executor_context.create_child_context = lambda *args: Mock() with patch.object(ParallelExecutor, "from_callables") as mock_from_callables: mock_executor = Mock() @@ -237,9 +236,7 @@ def mock_run_in_child_context(callable_func, name, child_config): mock_executor.execute.return_value = mock_batch_result mock_from_callables.return_value = mock_executor - result = parallel_handler( - callables, None, execution_state, mock_run_in_child_context - ) + result = parallel_handler(callables, None, execution_state, executor_context) assert result == mock_batch_result # Verify that a default ParallelConfig was created @@ -313,31 +310,27 @@ def test_parallel_handler_with_serdes(): """Test that parallel_handler with serdes""" def func1(ctx): - return "result1" + return "RESULT1" callables = [func1] execution_state = Mock() - def mock_run_in_child_context(callable_func, name, child_config): - return serialize( - serdes=child_config.serdes, - value=callable_func("mock-context"), - operation_id="op_id", - durable_execution_arn="exec_arn", - ) + executor_context = Mock() + executor_context._create_step_id_for_logical_step = lambda *args: "1" # noqa SLF001 + executor_context.create_child_context = lambda *args: Mock() result = parallel_handler( callables, ParallelConfig(serdes=CustomStrSerDes()), execution_state, - mock_run_in_child_context, + executor_context, ) assert result.all[0].result == "RESULT1" def test_parallel_handler_with_summary_generator(): - """Test that parallel_handler passes summary_generator to child config.""" + """Test that parallel_handler calls executor_context methods correctly.""" def func1(ctx): return "large_result" * 1000 # Create a large result @@ -349,24 +342,18 @@ def mock_summary_generator(result): config = ParallelConfig(summary_generator=mock_summary_generator) execution_state = Mock() - # Track the child_config passed to run_in_child_context - captured_child_configs = [] - - def mock_run_in_child_context(callable_func, name, child_config): - captured_child_configs.append(child_config) - return callable_func("mock-context") + executor_context = Mock() + executor_context._create_step_id_for_logical_step = Mock(return_value="1") # noqa SLF001 + executor_context.create_child_context = Mock(return_value=Mock()) - # Call parallel_handler with our mock run_in_child_context - parallel_handler(callables, config, execution_state, mock_run_in_child_context) + # Call parallel_handler + parallel_handler(callables, config, execution_state, executor_context) - # Verify that the summary_generator was passed to the child config - assert len(captured_child_configs) > 0 - child_config = captured_child_configs[0] - assert child_config.summary_generator is mock_summary_generator + # Verify that create_child_context was called once (N=1 job) + assert executor_context.create_child_context.call_count == 1 - # Test that the summary generator works - test_result = child_config.summary_generator("test" * 100) - assert test_result == "Summary of 400 chars" + # Verify that _create_step_id_for_logical_step was called once with unique value + assert executor_context._create_step_id_for_logical_step.call_count == 1 # noqa SLF001 def test_parallel_executor_from_callables_with_summary_generator(): @@ -388,82 +375,75 @@ def mock_summary_generator(result): def test_parallel_handler_default_summary_generator(): - """Test that parallel_handler uses default summary generator when config is None.""" + """Test that parallel_handler calls executor_context methods correctly with default config.""" def func1(ctx): return "result1" - callables = [func1] - execution_state = Mock() + def func2(ctx): + return "result2" - # Track the child_config passed to run_in_child_context - captured_child_configs = [] + callables = [func1, func2] + execution_state = Mock() - def mock_run_in_child_context(callable_func, name, child_config): - captured_child_configs.append(child_config) - return callable_func("mock-context") + executor_context = Mock() + executor_context._create_step_id_for_logical_step = Mock(side_effect=["1", "2"]) # noqa SLF001 + executor_context.create_child_context = Mock(return_value=Mock()) # Call parallel_handler with None config (should use default) - parallel_handler(callables, None, execution_state, mock_run_in_child_context) - - # Verify that a default summary_generator was provided - assert len(captured_child_configs) > 0 - child_config = captured_child_configs[0] - assert child_config.summary_generator is not None - - # Test that the default summary generator works - test_result = child_config.summary_generator( - BatchResult.from_dict( - { - "all": [{"index": 0, "status": "SUCCEEDED", "result": "test"}], - "completionReason": "ALL_COMPLETED", - } - ) - ) - assert isinstance(test_result, str) - assert len(test_result) > 0 + parallel_handler(callables, None, execution_state, executor_context) + + # Verify that create_child_context was called twice (N=2 jobs) + assert executor_context.create_child_context.call_count == 2 + + # Verify that _create_step_id_for_logical_step was called twice with unique values + assert executor_context._create_step_id_for_logical_step.call_count == 2 # noqa SLF001 + calls = executor_context._create_step_id_for_logical_step.call_args_list # noqa SLF001 + # Verify unique values were passed + assert calls[0] != calls[1] def test_parallel_handler_with_explicit_none_summary_generator(): - """Test that parallel_handler respects explicit None summary_generator.""" + """Test that parallel_handler calls executor_context methods correctly with explicit None summary_generator.""" def func1(ctx): return "result1" - callables = [func1] + def func2(ctx): + return "result2" + + def func3(ctx): + return "result3" + + callables = [func1, func2, func3] # Explicitly set summary_generator to None config = ParallelConfig(summary_generator=None) execution_state = Mock() - # Capture the child configs passed to run_in_child_context - captured_child_configs = [] - - def mock_run_in_child_context(func, name, child_config): - captured_child_configs.append(child_config) - return func(Mock()) + executor_context = Mock() + executor_context._create_step_id_for_logical_step = Mock( # noqa: SLF001 + side_effect=["1", "2", "3"] + ) + executor_context.create_child_context = Mock(return_value=Mock()) - # Call parallel_handler with our mock run_in_child_context + # Call parallel_handler parallel_handler( callables=callables, config=config, execution_state=execution_state, - run_in_child_context=mock_run_in_child_context, + parallel_context=executor_context, ) - # Verify that the summary_generator was set to None (not default) - assert len(captured_child_configs) > 0 - child_config = captured_child_configs[0] - assert child_config.summary_generator is None + # Verify that create_child_context was called 3 times (N=3 jobs) + assert executor_context.create_child_context.call_count == 3 - # Test that when None, it should result in empty string behavior - # This matches child.py: config.summary_generator(raw_result) if config.summary_generator else "" - test_result = ( - child_config.summary_generator("test_data") - if child_config.summary_generator - else "" - ) - assert test_result == "" # noqa PLC1901 + # Verify that _create_step_id_for_logical_step was called 3 times with unique values + assert executor_context._create_step_id_for_logical_step.call_count == 3 # noqa SLF001 + calls = executor_context._create_step_id_for_logical_step.call_args_list # noqa SLF001 + # Verify all calls have unique values + call_values = [call[0][0] for call in calls] + assert len(set(call_values)) == 3 # All unique def test_parallel_config_with_explicit_none_summary_generator():