diff --git a/src/aws_durable_execution_sdk_python/concurrency.py b/src/aws_durable_execution_sdk_python/concurrency.py index fb43c03..67c93be 100644 --- a/src/aws_durable_execution_sdk_python/concurrency.py +++ b/src/aws_durable_execution_sdk_python/concurrency.py @@ -29,7 +29,7 @@ from aws_durable_execution_sdk_python.lambda_service import OperationSubType from aws_durable_execution_sdk_python.serdes import SerDes from aws_durable_execution_sdk_python.state import ExecutionState - from aws_durable_execution_sdk_python.types import DurableContext + from aws_durable_execution_sdk_python.types import DurableContext, SummaryGenerator logger = logging.getLogger(__name__) @@ -566,13 +566,24 @@ def __init__( sub_type_iteration: OperationSubType, name_prefix: str, serdes: SerDes | None, + summary_generator: SummaryGenerator | None = None, ): + """Initialize ConcurrentExecutor. + + Args: + summary_generator: Optional function to generate compact summaries for large results. + When the serialized result exceeds 256KB, this generator creates a JSON summary + instead of checkpointing the full result. Used by map/parallel operations to + handle large BatchResult payloads efficiently. Matches TypeScript behavior in + run-in-child-context-handler.ts. 
+ """ self.executables = executables self.max_concurrency = max_concurrency self.completion_config = completion_config self.sub_type_top = sub_type_top self.sub_type_iteration = sub_type_iteration self.name_prefix = name_prefix + self.summary_generator = summary_generator # Event-driven state tracking for when the executor is done self._completion_event = threading.Event() @@ -785,7 +796,11 @@ def execute_in_child_context(child_context: DurableContext) -> ResultType: return run_in_child_context( execute_in_child_context, f"{self.name_prefix}{executable.index}", - ChildConfig(serdes=self.serdes, sub_type=self.sub_type_iteration), + ChildConfig( + serdes=self.serdes, + sub_type=self.sub_type_iteration, + summary_generator=self.summary_generator, + ), ) diff --git a/src/aws_durable_execution_sdk_python/config.py b/src/aws_durable_execution_sdk_python/config.py index 0ac2473..0f35579 100644 --- a/src/aws_durable_execution_sdk_python/config.py +++ b/src/aws_durable_execution_sdk_python/config.py @@ -19,6 +19,7 @@ from aws_durable_execution_sdk_python.lambda_service import OperationSubType from aws_durable_execution_sdk_python.serdes import SerDes + from aws_durable_execution_sdk_python.types import SummaryGenerator Numeric = int | float # deliberately leaving off complex @@ -39,6 +40,38 @@ class TerminationMode(Enum): @dataclass(frozen=True) class CompletionConfig: + """Configuration for determining when parallel/map operations complete. + + This class defines the success/failure criteria for operations that process + multiple items or branches concurrently. + + Args: + min_successful: Minimum number of successful completions required. + If None, no minimum is enforced. Use this to implement "at least N + must succeed" semantics. + + tolerated_failure_count: Maximum number of failures allowed before + the operation is considered failed. If None, no limit on failure count. + Use this to implement "fail fast after N failures" semantics. 
+ + tolerated_failure_percentage: Maximum percentage of failures allowed + (0.0 to 100.0). If None, no percentage limit is enforced. + Use this to implement "fail if more than X% fail" semantics. + + Note: + The operation completes when any of the completion criteria are met: + - Enough successes (min_successful reached) + - Too many failures (tolerated limits exceeded) + - All items/branches completed + + Example: + # Succeed if at least 3 succeed, fail if more than 2 fail + config = CompletionConfig( + min_successful=3, + tolerated_failure_count=2 + ) + """ + min_successful: int | None = None tolerated_failure_count: int | None = None tolerated_failure_percentage: int | float | None = None @@ -77,11 +110,47 @@ def all_successful(): @dataclass(frozen=True) class ParallelConfig: + """Configuration options for parallel execution operations. + + This class configures how parallel operations are executed, including + concurrency limits, completion criteria, and serialization behavior. + + Args: + max_concurrency: Maximum number of parallel branches to execute concurrently. + If None, no limit is imposed and all branches run concurrently. + Use this to control resource usage and prevent overwhelming the system. + + completion_config: Defines when the parallel operation should complete. + Controls success/failure criteria for the overall parallel operation. + Default is CompletionConfig.all_successful() which requires all branches + to succeed. Other options include first_successful() and all_completed(). + + serdes: Custom serialization/deserialization configuration for parallel results. + If None, uses the default serializer. This allows custom handling of + complex result types or optimization for large result sets. + + summary_generator: Function to generate compact summaries for large results (>256KB). + When the serialized result exceeds CHECKPOINT_SIZE_LIMIT, this generator + creates a JSON summary instead of checkpointing the full result. 
The operation + is marked with ReplayChildren=true to reconstruct the full result during replay. + + Used internally by map/parallel operations to handle large BatchResult payloads. + Signature: (result: T) -> str + + Example: + # Run at most 3 branches concurrently, succeed if any one succeeds + config = ParallelConfig( + max_concurrency=3, + completion_config=CompletionConfig.first_successful() + ) + """ + max_concurrency: int | None = None completion_config: CompletionConfig = field( default_factory=CompletionConfig.all_successful ) serdes: SerDes | None = None + summary_generator: SummaryGenerator | None = None class StepSemantics(Enum): @@ -106,12 +175,41 @@ class CheckpointMode(Enum): @dataclass(frozen=True) class ChildConfig(Generic[T]): - """Options when running inside a child context.""" + """Configuration options for child context operations. + + This class configures how child contexts are executed and checkpointed, + matching the TypeScript ChildConfig interface behavior. + + Args: + serdes: Custom serialization/deserialization configuration for the child context data. + If None, uses the default serializer. This allows different serialization + strategies for child operations vs parent operations. + + sub_type: Operation subtype identifier used for tracking and debugging. + Examples: OperationSubType.MAP_ITERATION, OperationSubType.PARALLEL_BRANCH. + Used internally by the execution engine for operation classification. + + summary_generator: Function to generate compact summaries for large results (>256KB). + When the serialized result exceeds CHECKPOINT_SIZE_LIMIT, this generator + creates a JSON summary instead of checkpointing the full result. The operation + is marked with ReplayChildren=true to reconstruct the full result during replay. + + Used internally by map/parallel operations to handle large BatchResult payloads. + Signature: (result: T) -> str + Note: + checkpoint_mode field is commented out as it's not currently implemented. 
+ When implemented, it will control when checkpoints are created: + - CHECKPOINT_AT_START_AND_FINISH: Checkpoint at both start and completion (default) + - CHECKPOINT_AT_FINISH: Only checkpoint when operation completes + - NO_CHECKPOINT: No automatic checkpointing + + See TypeScript reference: aws-durable-execution-sdk-js/src/types/index.ts + """ # checkpoint_mode: CheckpointMode = CheckpointMode.CHECKPOINT_AT_START_AND_FINISH serdes: SerDes | None = None sub_type: OperationSubType | None = None - summary_generator: Callable[[T], str] | None = None + summary_generator: SummaryGenerator | None = None class ItemsPerBatchUnit(Enum): @@ -121,6 +219,34 @@ class ItemsPerBatchUnit(Enum): @dataclass(frozen=True) class ItemBatcher(Generic[T]): + """Configuration for batching items in map operations. + + This class defines how individual items should be grouped together into batches + for more efficient processing in map operations. + + Args: + max_items_per_batch: Maximum number of items to include in a single batch. + If 0 (default), no item count limit is applied. Use this to control + batch size when processing many small items. + + max_item_bytes_per_batch: Maximum total size in bytes for items in a batch. + If 0 (default), no size limit is applied. Use this to control memory + usage when processing large items or when items vary significantly in size. + + batch_input: Additional data to include with each batch. + This data is passed to the processing function along with the batched items. + Useful for providing context or configuration that applies to all items + in the batch. 
+ + Example: + # Batch up to 100 items or 1MB, whichever comes first + batcher = ItemBatcher( + max_items_per_batch=100, + max_item_bytes_per_batch=1024*1024, + batch_input={"processing_mode": "fast"} + ) + """ + max_items_per_batch: int = 0 max_item_bytes_per_batch: int | float = 0 batch_input: T | None = None @@ -128,10 +254,51 @@ class ItemBatcher(Generic[T]): @dataclass(frozen=True) class MapConfig: + """Configuration options for map operations over collections. + + This class configures how map operations process collections of items, + including concurrency, batching, completion criteria, and serialization. + + Args: + max_concurrency: Maximum number of items to process concurrently. + If None, no limit is imposed and all items are processed concurrently. + Use this to control resource usage when processing large collections. + + item_batcher: Configuration for batching multiple items together for processing. + Allows grouping items by count or size to optimize processing efficiency. + Default is no batching (each item processed individually). + + completion_config: Defines when the map operation should complete. + Controls success/failure criteria for the overall map operation. + Default allows any number of failures. Use CompletionConfig.all_successful() + to require all items to succeed. + + serdes: Custom serialization/deserialization configuration for map results. + If None, uses the default serializer. This allows custom handling of + complex item types or optimization for large result collections. + + summary_generator: Function to generate compact summaries for large results (>256KB). + When the serialized result exceeds CHECKPOINT_SIZE_LIMIT, this generator + creates a JSON summary instead of checkpointing the full result. The operation + is marked with ReplayChildren=true to reconstruct the full result during replay. + + Used internally by map/parallel operations to handle large BatchResult payloads. 
+ Signature: (result: T) -> str + + Example: + # Process 5 items at a time, batch by count, require all to succeed + config = MapConfig( + max_concurrency=5, + item_batcher=ItemBatcher(max_items_per_batch=10), + completion_config=CompletionConfig.all_successful() + ) + """ + max_concurrency: int | None = None item_batcher: ItemBatcher = field(default_factory=ItemBatcher) completion_config: CompletionConfig = field(default_factory=CompletionConfig) serdes: SerDes | None = None + summary_generator: SummaryGenerator | None = None @dataclass diff --git a/src/aws_durable_execution_sdk_python/operation/child.py b/src/aws_durable_execution_sdk_python/operation/child.py index 9ccc3e9..e2e8c72 100644 --- a/src/aws_durable_execution_sdk_python/operation/child.py +++ b/src/aws_durable_execution_sdk_python/operation/child.py @@ -88,14 +88,28 @@ def child_handler( operation_id=operation_identifier.operation_id, durable_execution_arn=state.durable_execution_arn, ) + # Summary Generator Logic: + # When the serialized result exceeds 256KB, we use ReplayChildren mode to avoid + # checkpointing large payloads. Instead, we checkpoint a compact summary and mark + # the operation for replay. This matches the TypeScript implementation behavior. + # + # See TypeScript reference: + # - aws-durable-execution-sdk-js/src/handlers/run-in-child-context-handler/run-in-child-context-handler.ts (lines ~200-220) + # + # The summary generator creates a JSON summary with metadata (type, counts, status) + # instead of the full BatchResult. During replay, the child context is re-executed + # to reconstruct the full result rather than deserializing from the checkpoint. 
replay_children: bool = False if len(serialized_result) > CHECKPOINT_SIZE_LIMIT: logger.debug( - "Large payload detected, using ReplayChildren mode: id: %s, name: %s", + "Large payload detected, using ReplayChildren mode: id: %s, name: %s, payload_size: %d, limit: %d", operation_identifier.operation_id, operation_identifier.name, + len(serialized_result), + CHECKPOINT_SIZE_LIMIT, ) replay_children = True + # Use summary generator if provided, otherwise use empty string (matches TypeScript) serialized_result = ( config.summary_generator(raw_result) if config.summary_generator else "" ) diff --git a/src/aws_durable_execution_sdk_python/operation/map.py b/src/aws_durable_execution_sdk_python/operation/map.py index f910413..aae1e35 100644 --- a/src/aws_durable_execution_sdk_python/operation/map.py +++ b/src/aws_durable_execution_sdk_python/operation/map.py @@ -2,8 +2,10 @@ from __future__ import annotations +import json import logging from collections.abc import Callable, Sequence +from dataclasses import dataclass from typing import TYPE_CHECKING, Generic, TypeVar from aws_durable_execution_sdk_python.concurrency import ( @@ -18,7 +20,7 @@ from aws_durable_execution_sdk_python.config import ChildConfig from aws_durable_execution_sdk_python.serdes import SerDes from aws_durable_execution_sdk_python.state import ExecutionState - from aws_durable_execution_sdk_python.types import DurableContext + from aws_durable_execution_sdk_python.types import DurableContext, SummaryGenerator logger = logging.getLogger(__name__) @@ -40,6 +42,7 @@ def __init__( iteration_sub_type: OperationSubType, name_prefix: str, serdes: SerDes | None, + summary_generator: SummaryGenerator | None = None, ): super().__init__( executables=executables, @@ -49,6 +52,7 @@ def __init__( sub_type_iteration=iteration_sub_type, name_prefix=name_prefix, serdes=serdes, + summary_generator=summary_generator, ) self.items = items @@ -73,6 +77,7 @@ def from_items( iteration_sub_type=OperationSubType.MAP_ITERATION, 
name_prefix="map-item-", serdes=config.serdes, + summary_generator=config.summary_generator, ) def execute_item(self, child_context, executable: Executable[Callable]) -> R: @@ -93,7 +98,29 @@ def map_handler( ], ) -> BatchResult[R]: """Execute a callable for each item in parallel.""" + # Summary Generator Construction (matches TypeScript implementation): + # Construct the summary generator at the handler level, just like TypeScript does in map-handler.ts. + # This matches the pattern where handlers are responsible for configuring operation-specific behavior. + # + # See TypeScript reference: aws-durable-execution-sdk-js/src/handlers/map-handler/map-handler.ts (~line 79) + executor: MapExecutor[T, R] = MapExecutor.from_items( - items=items, func=func, config=config or MapConfig() + items=items, + func=func, + config=config or MapConfig(summary_generator=MapSummaryGenerator()), ) return executor.execute(execution_state, run_in_child_context) + + +@dataclass(frozen=True) +class MapSummaryGenerator: + def __call__(self, result: BatchResult) -> str: + fields = { + "totalCount": result.total_count, + "successCount": result.success_count, + "failureCount": result.failure_count, + "completionReason": result.completion_reason.value, + "status": result.status.value, + "type": "MapResult", + } + return json.dumps(fields) diff --git a/src/aws_durable_execution_sdk_python/operation/parallel.py b/src/aws_durable_execution_sdk_python/operation/parallel.py index aa17186..cbf9205 100644 --- a/src/aws_durable_execution_sdk_python/operation/parallel.py +++ b/src/aws_durable_execution_sdk_python/operation/parallel.py @@ -2,8 +2,10 @@ from __future__ import annotations +import json import logging from collections.abc import Callable, Sequence +from dataclasses import dataclass from typing import TYPE_CHECKING, TypeVar from aws_durable_execution_sdk_python.concurrency import ConcurrentExecutor, Executable @@ -15,7 +17,7 @@ from aws_durable_execution_sdk_python.config import ChildConfig 
from aws_durable_execution_sdk_python.serdes import SerDes from aws_durable_execution_sdk_python.state import ExecutionState - from aws_durable_execution_sdk_python.types import DurableContext + from aws_durable_execution_sdk_python.types import DurableContext, SummaryGenerator logger = logging.getLogger(__name__) @@ -33,6 +35,7 @@ def __init__( iteration_sub_type: OperationSubType, name_prefix: str, serdes: SerDes | None, + summary_generator: SummaryGenerator | None = None, ): super().__init__( executables=executables, @@ -42,6 +45,7 @@ def __init__( sub_type_iteration=iteration_sub_type, name_prefix=name_prefix, serdes=serdes, + summary_generator=summary_generator, ) @classmethod @@ -62,6 +66,7 @@ def from_callables( iteration_sub_type=OperationSubType.PARALLEL_BRANCH, name_prefix="parallel-branch-", serdes=config.serdes, + summary_generator=config.summary_generator, ) def execute_item(self, child_context, executable: Executable[Callable]) -> R: # noqa: PLR6301 @@ -80,5 +85,30 @@ def parallel_handler( ], ) -> BatchResult[R]: """Execute multiple operations in parallel.""" - executor = ParallelExecutor.from_callables(callables, config or ParallelConfig()) + # Summary Generator Construction (matches TypeScript implementation): + # Construct the summary generator at the handler level, just like TypeScript does in parallel-handler.ts. + # This matches the pattern where handlers are responsible for configuring operation-specific behavior. 
+ # + # See TypeScript reference: aws-durable-execution-sdk-js/src/handlers/parallel-handler/parallel-handler.ts (~line 112) + + executor = ParallelExecutor.from_callables( + callables, + config or ParallelConfig(summary_generator=ParallelSummaryGenerator()), + ) return executor.execute(execution_state, run_in_child_context) + + +@dataclass(frozen=True) +class ParallelSummaryGenerator: + def __call__(self, result: BatchResult) -> str: + fields = { + "totalCount": result.total_count, + "successCount": result.success_count, + "failureCount": result.failure_count, + "completionReason": result.completion_reason.value, + "status": result.status.value, + "startedCount": result.started_count, + "type": "ParallelResult", + } + + return json.dumps(fields) diff --git a/src/aws_durable_execution_sdk_python/types.py b/src/aws_durable_execution_sdk_python/types.py index acc5525..01f61ad 100644 --- a/src/aws_durable_execution_sdk_python/types.py +++ b/src/aws_durable_execution_sdk_python/types.py @@ -151,3 +151,23 @@ class LambdaContext(Protocol): # pragma: no cover def get_remaining_time_in_millis(self) -> int: ... def log(self, msg) -> None: ... + + +# region Summary + +"""Summary generators for concurrent operations. + +Summary generators create compact JSON representations of large BatchResult objects +when the serialized result exceeds the 256KB checkpoint size limit. This prevents +large payloads from being stored in checkpoints while maintaining operation metadata. + +When a summary is used, the operation is marked with ReplayChildren=true, causing +the child context to be re-executed during replay to reconstruct the full result. +""" + + +class SummaryGenerator(Protocol): + def __call__(self, result: BatchResult) -> str: ... 
+ + +# endregion Summary diff --git a/tests/config_test.py b/tests/config_test.py index a78d842..04c392f 100644 --- a/tests/config_test.py +++ b/tests/config_test.py @@ -174,6 +174,22 @@ def test_child_config_with_sub_type(): assert config.sub_type is sub_type +def test_child_config_with_summary_generator(): + """Test ChildConfig with summary_generator.""" + + def mock_summary_generator(result): + return f"Summary of {result}" + + config = ChildConfig(summary_generator=mock_summary_generator) + assert config.serdes is None + assert config.sub_type is None + assert config.summary_generator is mock_summary_generator + + # Test that the summary generator works + result = config.summary_generator("test_data") + assert result == "Summary of test_data" + + def test_items_per_batch_unit_enum(): """Test ItemsPerBatchUnit enum.""" assert ItemsPerBatchUnit.COUNT.value == ("COUNT",) diff --git a/tests/operation/child_test.py b/tests/operation/child_test.py index 04cbf31..803ae51 100644 --- a/tests/operation/child_test.py +++ b/tests/operation/child_test.py @@ -456,3 +456,76 @@ def test_child_handler_replay_children_mode() -> None: assert actual_result == complex_result mock_state.create_checkpoint.assert_not_called() + + +def test_small_payload_with_summary_generator(): + """Test: Small payload with summary_generator -> replay_children = False""" + mock_state = Mock(spec=ExecutionState) + mock_state.durable_execution_arn = "test_arn" + mock_result = Mock() + mock_result.is_succeeded.return_value = False + mock_result.is_failed.return_value = False + mock_result.is_started.return_value = False + mock_result.is_replay_children.return_value = False + mock_result.is_existent.return_value = False + mock_state.get_checkpoint_result.return_value = mock_result + + # Small payload (< 256KB) + small_result = "small_payload" + mock_callable = Mock(return_value=small_result) + + def my_summary(result: str) -> str: + return "summary_of_small_payload" + + child_config = 
ChildConfig[str](summary_generator=my_summary) + + actual_result = child_handler( + mock_callable, + mock_state, + OperationIdentifier("op1", None, "test_name"), + child_config, + ) + + assert actual_result == small_result + success_call = mock_state.create_checkpoint.call_args_list[1] + success_operation = success_call[1]["operation_update"] + + # Small payload should NOT trigger replay_children, even with summary_generator + assert not success_operation.context_options.replay_children + # Should checkpoint the actual result, not the summary + assert success_operation.payload == '"small_payload"' # JSON serialized + + +def test_small_payload_without_summary_generator(): + """Test: Small payload without summary_generator -> replay_children = False""" + mock_state = Mock(spec=ExecutionState) + mock_state.durable_execution_arn = "test_arn" + mock_result = Mock() + mock_result.is_succeeded.return_value = False + mock_result.is_failed.return_value = False + mock_result.is_started.return_value = False + mock_result.is_replay_children.return_value = False + mock_result.is_existent.return_value = False + mock_state.get_checkpoint_result.return_value = mock_result + + # Small payload (< 256KB) + small_result = "small_payload" + mock_callable = Mock(return_value=small_result) + + child_config = ChildConfig[str]() # No summary_generator + + actual_result = child_handler( + mock_callable, + mock_state, + OperationIdentifier("op2", None, "test_name"), + child_config, + ) + + assert actual_result == small_result + success_call = mock_state.create_checkpoint.call_args_list[1] + success_operation = success_call[1]["operation_update"] + + # Small payload should NOT trigger replay_children + assert not success_operation.context_options.replay_children + # Should checkpoint the actual result + assert success_operation.payload == '"small_payload"' # JSON serialized diff --git a/tests/operation/map_test.py b/tests/operation/map_test.py index 0a7bf46..814c4ca 100644 --- 
a/tests/operation/map_test.py +++ b/tests/operation/map_test.py @@ -2,8 +2,18 @@ from unittest.mock import Mock, patch -from aws_durable_execution_sdk_python.concurrency import BatchResult, Executable -from aws_durable_execution_sdk_python.config import CompletionConfig, MapConfig +from aws_durable_execution_sdk_python.concurrency import ( + BatchItem, + BatchItemStatus, + BatchResult, + CompletionReason, + Executable, +) +from aws_durable_execution_sdk_python.config import ( + CompletionConfig, + ItemBatcher, + MapConfig, +) from aws_durable_execution_sdk_python.lambda_service import OperationSubType from aws_durable_execution_sdk_python.operation.map import MapExecutor, map_handler from aws_durable_execution_sdk_python.serdes import serialize @@ -213,7 +223,17 @@ def callable_func(ctx, item, idx, items): return f"result_{item}" # Mock the executor.execute method - mock_batch_result = Mock(spec=BatchResult) + from aws_durable_execution_sdk_python.concurrency import ( + BatchItem, + BatchItemStatus, + BatchResult, + CompletionReason, + ) + + mock_batch_result = BatchResult( + all=[BatchItem(index=0, status=BatchItemStatus.SUCCEEDED, result="test")], + completion_reason=CompletionReason.ALL_COMPLETED, + ) with patch.object( MapExecutor, "execute", return_value=mock_batch_result @@ -247,7 +267,10 @@ def callable_func(ctx, item, idx, items): # Mock MapExecutor.from_items to verify it's called with default config with patch.object(MapExecutor, "from_items") as mock_from_items: mock_executor = Mock() - mock_batch_result = Mock(spec=BatchResult) + mock_batch_result = BatchResult( + all=[BatchItem(index=0, status=BatchItemStatus.SUCCEEDED, result="test")], + completion_reason=CompletionReason.ALL_COMPLETED, + ) mock_executor.execute.return_value = mock_batch_result mock_from_items.return_value = mock_executor @@ -309,3 +332,195 @@ class MockExecutionState: # Verify execute was called assert result.all[0].result == "RESULT_TEST_ITEM" + + +def 
test_map_handler_with_summary_generator(): + """Test that map_handler passes summary_generator to child config.""" + items = ["item1", "item2"] + + def callable_func(ctx, item, idx, items): + return f"large_result_{item}" * 1000 # Create a large result + + def mock_summary_generator(result): + return f"Summary of {len(result)} chars for map item" + + config = MapConfig(summary_generator=mock_summary_generator) + + # Track the child_config passed to run_in_child_context + captured_child_configs = [] + + def mock_run_in_child_context(callable_func, name, child_config): + captured_child_configs.append(child_config) + return callable_func("mock-context") + + class MockExecutionState: + pass + + execution_state = MockExecutionState() + + # Call map_handler with our mock run_in_child_context + result = map_handler( + items, callable_func, config, execution_state, mock_run_in_child_context + ) + + # Verify that the summary_generator was passed to the child config + assert len(captured_child_configs) > 0 + child_config = captured_child_configs[0] + assert child_config.summary_generator is mock_summary_generator + + # Test that the summary generator works + test_result = child_config.summary_generator("test" * 100) + assert test_result == "Summary of 400 chars for map item" + + +def test_map_executor_from_items_with_summary_generator(): + """Test MapExecutor.from_items preserves summary_generator.""" + items = ["item1"] + + def callable_func(ctx, item, idx, items): + return f"result_{item}" + + def mock_summary_generator(result): + return f"Map summary: {result}" + + config = MapConfig(summary_generator=mock_summary_generator) + + executor = MapExecutor.from_items(items, callable_func, config) + + # Verify that the summary_generator is preserved in the executor + assert executor.summary_generator is mock_summary_generator + + +def test_map_handler_default_summary_generator(): + """Test that map_handler uses default summary generator when config is None.""" + items = 
["item1"] + + def callable_func(ctx, item, idx, items): + return f"result_{item}" + + # Track the child_config passed to run_in_child_context + captured_child_configs = [] + + def mock_run_in_child_context(callable_func, name, child_config): + captured_child_configs.append(child_config) + return callable_func("mock-context") + + class MockExecutionState: + pass + + execution_state = MockExecutionState() + + # Call map_handler with None config (should use default) + result = map_handler( + items, callable_func, None, execution_state, mock_run_in_child_context + ) + + # Verify that a default summary_generator was provided + assert len(captured_child_configs) > 0 + child_config = captured_child_configs[0] + assert child_config.summary_generator is not None + + # Test that the default summary generator works + test_result = child_config.summary_generator( + BatchResult([], CompletionReason.ALL_COMPLETED) + ) + assert isinstance(test_result, str) + assert len(test_result) > 0 + + +def test_map_executor_init_with_summary_generator(): + """Test MapExecutor initialization with summary_generator.""" + items = ["item1"] + executables = [Executable(index=0, func=lambda: None)] + + def mock_summary_generator(result): + return f"Summary: {result}" + + executor = MapExecutor( + executables=executables, + items=items, + max_concurrency=2, + completion_config=CompletionConfig(), + top_level_sub_type=OperationSubType.MAP, + iteration_sub_type=OperationSubType.MAP_ITERATION, + name_prefix="test-", + serdes=None, + summary_generator=mock_summary_generator, + ) + + assert executor.summary_generator is mock_summary_generator + assert executor.items == items + assert executor.executables == executables + + +def test_map_handler_with_explicit_none_summary_generator(): + """Test that map_handler respects explicit None summary_generator.""" + + def func(ctx, item, index, array): + return f"processed_{item}" + + items = ["item1", "item2"] + # Explicitly set summary_generator to None + 
config = MapConfig(summary_generator=None) + + class MockExecutionState: + pass + + execution_state = MockExecutionState() + + # Capture the child configs passed to run_in_child_context + captured_child_configs = [] + + def mock_run_in_child_context(func, name, child_config): + captured_child_configs.append(child_config) + return func(Mock()) + + # Call map_handler with our mock run_in_child_context + result = map_handler( + items=items, + func=func, + config=config, + execution_state=execution_state, + run_in_child_context=mock_run_in_child_context, + ) + + # Verify that the summary_generator was set to None (not default) + assert len(captured_child_configs) > 0 + child_config = captured_child_configs[0] + assert child_config.summary_generator is None + + # Test that when None, it should result in empty string behavior + # This matches child.py: config.summary_generator(raw_result) if config.summary_generator else "" + test_result = ( + child_config.summary_generator("test_data") + if child_config.summary_generator + else "" + ) + assert test_result == "" + + +def test_map_config_with_explicit_none_summary_generator(): + """Test MapConfig with explicitly set None summary_generator.""" + config = MapConfig(summary_generator=None) + + assert config.summary_generator is None + assert config.max_concurrency is None + assert isinstance(config.item_batcher, ItemBatcher) + assert isinstance(config.completion_config, CompletionConfig) + assert config.serdes is None + + +def test_map_config_default_summary_generator_behavior(): + """Test MapConfig() with no summary_generator should result in empty string behavior.""" + # When creating MapConfig() with no summary_generator specified + config = MapConfig() + + # The summary_generator should be None by default + assert config.summary_generator is None + + # But when used in the actual child.py logic, it should result in empty string + # This matches child.py: config.summary_generator(raw_result) if config.summary_generator 
else "" + test_result = ( + config.summary_generator("test_data") if config.summary_generator else "" + ) + assert test_result == "" diff --git a/tests/operation/parallel_test.py b/tests/operation/parallel_test.py index 8c4dc61..1f54bad 100644 --- a/tests/operation/parallel_test.py +++ b/tests/operation/parallel_test.py @@ -5,7 +5,10 @@ import pytest from aws_durable_execution_sdk_python.concurrency import ( + BatchItem, + BatchItemStatus, BatchResult, + CompletionReason, ConcurrentExecutor, Executable, ) @@ -16,7 +19,6 @@ parallel_handler, ) from aws_durable_execution_sdk_python.serdes import serialize -from aws_durable_execution_sdk_python.state import ExecutionState from tests.serdes_test import CustomStrSerDes @@ -142,14 +144,24 @@ def func2(ctx): callables = [func1, func2] config = ParallelConfig(max_concurrency=2) - execution_state = Mock(spec=ExecutionState) + execution_state = Mock() # Mock the run_in_child_context function def mock_run_in_child_context(callable_func, name, child_config): return callable_func("mock-context") # Mock the executor.execute method to return a BatchResult - mock_batch_result = Mock(spec=BatchResult) + from aws_durable_execution_sdk_python.concurrency import ( + BatchItem, + BatchItemStatus, + BatchResult, + CompletionReason, + ) + + mock_batch_result = BatchResult( + all=[BatchItem(index=0, status=BatchItemStatus.SUCCEEDED, result="test")], + completion_reason=CompletionReason.ALL_COMPLETED, + ) with patch.object(ParallelExecutor, "execute", return_value=mock_batch_result): result = parallel_handler( @@ -166,12 +178,15 @@ def func1(ctx): return "result1" callables = [func1] - execution_state = Mock(spec=ExecutionState) + execution_state = Mock() def mock_run_in_child_context(callable_func, name, child_config): return callable_func("mock-context") - mock_batch_result = Mock(spec=BatchResult) + mock_batch_result = BatchResult( + all=[BatchItem(index=0, status=BatchItemStatus.SUCCEEDED, result="test")], + 
completion_reason=CompletionReason.ALL_COMPLETED, + ) with patch.object(ParallelExecutor, "execute", return_value=mock_batch_result): result = parallel_handler( @@ -189,7 +204,7 @@ def func1(ctx): callables = [func1] config = ParallelConfig(max_concurrency=5) - execution_state = Mock(spec=ExecutionState) + execution_state = Mock() def mock_run_in_child_context(callable_func, name, child_config): return callable_func("mock-context") @@ -218,7 +233,7 @@ def func1(ctx): return "result1" callables = [func1] - execution_state = Mock(spec=ExecutionState) + execution_state = Mock() def mock_run_in_child_context(callable_func, name, child_config): return callable_func("mock-context") @@ -308,7 +323,7 @@ def func1(ctx): return "result1" callables = [func1] - execution_state = Mock(spec=ExecutionState) + execution_state = Mock() def mock_run_in_child_context(callable_func, name, child_config): return serialize( @@ -326,3 +341,163 @@ def mock_run_in_child_context(callable_func, name, child_config): ) assert result.all[0].result == "RESULT1" + + +def test_parallel_handler_with_summary_generator(): + """Test that parallel_handler passes summary_generator to child config.""" + + def func1(ctx): + return "large_result" * 1000 # Create a large result + + def mock_summary_generator(result): + return f"Summary of {len(result)} chars" + + callables = [func1] + config = ParallelConfig(summary_generator=mock_summary_generator) + execution_state = Mock() + + # Track the child_config passed to run_in_child_context + captured_child_configs = [] + + def mock_run_in_child_context(callable_func, name, child_config): + captured_child_configs.append(child_config) + return callable_func("mock-context") + + # Call parallel_handler with our mock run_in_child_context + result = parallel_handler( + callables, config, execution_state, mock_run_in_child_context + ) + + # Verify that the summary_generator was passed to the child config + assert len(captured_child_configs) > 0 + child_config = 
captured_child_configs[0] + assert child_config.summary_generator is mock_summary_generator + + # Test that the summary generator works + test_result = child_config.summary_generator("test" * 100) + assert test_result == "Summary of 400 chars" + + +def test_parallel_executor_from_callables_with_summary_generator(): + """Test ParallelExecutor.from_callables preserves summary_generator.""" + + def func1(ctx): + return "result1" + + def mock_summary_generator(result): + return f"Summary: {result}" + + callables = [func1] + config = ParallelConfig(summary_generator=mock_summary_generator) + + executor = ParallelExecutor.from_callables(callables, config) + + # Verify that the summary_generator is preserved in the executor + assert executor.summary_generator is mock_summary_generator + + +def test_parallel_handler_default_summary_generator(): + """Test that parallel_handler uses default summary generator when config is None.""" + + def func1(ctx): + return "result1" + + callables = [func1] + execution_state = Mock() + + # Track the child_config passed to run_in_child_context + captured_child_configs = [] + + def mock_run_in_child_context(callable_func, name, child_config): + captured_child_configs.append(child_config) + return callable_func("mock-context") + + # Call parallel_handler with None config (should use default) + result = parallel_handler( + callables, None, execution_state, mock_run_in_child_context + ) + + # Verify that a default summary_generator was provided + assert len(captured_child_configs) > 0 + child_config = captured_child_configs[0] + assert child_config.summary_generator is not None + + # Test that the default summary generator works + test_result = child_config.summary_generator( + BatchResult.from_dict( + { + "all": [{"index": 0, "status": "SUCCEEDED", "result": "test"}], + "completionReason": "ALL_COMPLETED", + } + ) + ) + assert isinstance(test_result, str) + assert len(test_result) > 0 + + +def 
test_parallel_handler_with_explicit_none_summary_generator(): + """Test that parallel_handler respects explicit None summary_generator.""" + + def func1(ctx): + return "result1" + + callables = [func1] + # Explicitly set summary_generator to None + config = ParallelConfig(summary_generator=None) + + execution_state = Mock() + + # Capture the child configs passed to run_in_child_context + captured_child_configs = [] + + def mock_run_in_child_context(func, name, child_config): + captured_child_configs.append(child_config) + return func(Mock()) + + # Call parallel_handler with our mock run_in_child_context + result = parallel_handler( + callables=callables, + config=config, + execution_state=execution_state, + run_in_child_context=mock_run_in_child_context, + ) + + # Verify that the summary_generator was set to None (not default) + assert len(captured_child_configs) > 0 + child_config = captured_child_configs[0] + assert child_config.summary_generator is None + + # Test that when None, it should result in empty string behavior + # This matches child.py: config.summary_generator(raw_result) if config.summary_generator else "" + test_result = ( + child_config.summary_generator("test_data") + if child_config.summary_generator + else "" + ) + assert test_result == "" + + +def test_parallel_config_with_explicit_none_summary_generator(): + """Test ParallelConfig with explicitly set None summary_generator.""" + config = ParallelConfig(summary_generator=None) + + assert config.summary_generator is None + assert config.max_concurrency is None + assert isinstance(config.completion_config, CompletionConfig) + + +def test_parallel_config_default_summary_generator_behavior(): + """Test ParallelConfig() with no summary_generator should result in empty string behavior.""" + # When creating ParallelConfig() with no summary_generator specified + config = ParallelConfig() + + # The summary_generator should be None by default + assert config.summary_generator is None + + # But when used 
in the actual child.py logic, it should result in empty string + # This matches child.py: config.summary_generator(raw_result) if config.summary_generator else "" + test_result = ( + config.summary_generator("test_data") if config.summary_generator else "" + ) + assert test_result == "" + assert config.serdes is None