diff --git a/src/aws_durable_execution_sdk_python/config.py b/src/aws_durable_execution_sdk_python/config.py
index 08747f2..348e480 100644
--- a/src/aws_durable_execution_sdk_python/config.py
+++ b/src/aws_durable_execution_sdk_python/config.py
@@ -104,12 +104,13 @@ class CheckpointMode(Enum):
 
 
 @dataclass(frozen=True)
-class ChildConfig:
+class ChildConfig(Generic[T]):
     """Options when running inside a child context."""
 
     # checkpoint_mode: CheckpointMode = CheckpointMode.CHECKPOINT_AT_START_AND_FINISH
     serdes: SerDes | None = None
     sub_type: OperationSubType | None = None
+    summary_generator: Callable[[T], str] | None = None
 
 
 class ItemsPerBatchUnit(Enum):
diff --git a/src/aws_durable_execution_sdk_python/lambda_service.py b/src/aws_durable_execution_sdk_python/lambda_service.py
index 8f99f5d..cb103e3 100644
--- a/src/aws_durable_execution_sdk_python/lambda_service.py
+++ b/src/aws_durable_execution_sdk_python/lambda_service.py
@@ -344,7 +344,11 @@ def create_context_start(
 
     @classmethod
     def create_context_succeed(
-        cls, identifier: OperationIdentifier, payload: str, sub_type: OperationSubType
+        cls,
+        identifier: OperationIdentifier,
+        payload: str,
+        sub_type: OperationSubType,
+        context_options: ContextOptions | None = None,
     ) -> OperationUpdate:
         """Create an instance of OperationUpdate for type: CONTEXT, action: SUCCEED."""
         return cls(
@@ -355,6 +359,7 @@ def create_context_succeed(
             action=OperationAction.SUCCEED,
             name=identifier.name,
             payload=payload,
+            context_options=context_options,
         )
 
     @classmethod
diff --git a/src/aws_durable_execution_sdk_python/operation/child.py b/src/aws_durable_execution_sdk_python/operation/child.py
index d5ef8f7..1af4b0b 100644
--- a/src/aws_durable_execution_sdk_python/operation/child.py
+++ b/src/aws_durable_execution_sdk_python/operation/child.py
@@ -8,6 +8,7 @@
 from aws_durable_execution_sdk_python.config import ChildConfig
 from aws_durable_execution_sdk_python.exceptions import FatalError, SuspendExecution
 from aws_durable_execution_sdk_python.lambda_service import (
+    ContextOptions,
     ErrorObject,
     OperationSubType,
     OperationUpdate,
@@ -24,6 +25,9 @@
 
 T = TypeVar("T")
 
+# Checkpoint payload size limit in bytes (256 KiB), measured on the UTF-8 encoding
+CHECKPOINT_SIZE_LIMIT = 256 * 1024
+
 
 def child_handler(
     func: Callable[[], T],
@@ -40,9 +44,11 @@ def child_handler(
     if not config:
         config = ChildConfig()
 
-    # TODO: ReplayChildren
     checkpointed_result = state.get_checkpoint_result(operation_identifier.operation_id)
-    if checkpointed_result.is_succeeded():
+    if (
+        checkpointed_result.is_succeeded()
+        and not checkpointed_result.is_replay_children()
+    ):
         logger.debug(
             "Child context already completed, skipping execution for id: %s, name: %s",
             operation_identifier.operation_id,
@@ -71,17 +77,36 @@ def child_handler(
 
     try:
         raw_result: T = func()
+        if checkpointed_result.is_replay_children():
+            logger.debug(
+                "ReplayChildren mode: Executed child context again on replay due to large payload. Exiting child context without creating another checkpoint. id: %s, name: %s",
+                operation_identifier.operation_id,
+                operation_identifier.name,
+            )
+            return raw_result
         serialized_result: str = serialize(
             serdes=config.serdes,
             value=raw_result,
             operation_id=operation_identifier.operation_id,
             durable_execution_arn=state.durable_execution_arn,
         )
+        replay_children: bool = False
+        if len(serialized_result.encode("utf-8")) > CHECKPOINT_SIZE_LIMIT:
+            logger.debug(
+                "Large payload detected, using ReplayChildren mode: id: %s, name: %s",
+                operation_identifier.operation_id,
+                operation_identifier.name,
+            )
+            replay_children = True
+            serialized_result = (
+                config.summary_generator(raw_result) if config.summary_generator else ""
+            )
 
         success_operation = OperationUpdate.create_context_succeed(
             identifier=operation_identifier,
             payload=serialized_result,
             sub_type=sub_type,
+            context_options=ContextOptions(replay_children=replay_children),
         )
         state.create_checkpoint(operation_update=success_operation)
 
diff --git a/src/aws_durable_execution_sdk_python/state.py b/src/aws_durable_execution_sdk_python/state.py
index ff43e87..3691ece 100644
--- a/src/aws_durable_execution_sdk_python/state.py
+++ b/src/aws_durable_execution_sdk_python/state.py
@@ -113,6 +113,12 @@ def is_timed_out(self) -> bool:
             return False
         return op.status is OperationStatus.TIMED_OUT
 
+    def is_replay_children(self) -> bool:
+        op = self.operation
+        if not op:
+            return False
+        return op.context_details.replay_children if op.context_details else False
+
     def raise_callable_error(self) -> None:
         if self.error is None:
             msg: str = "Attempted to throw exception, but no ErrorObject exists on the Checkpoint Operation."
diff --git a/tests/operation/child_test.py b/tests/operation/child_test.py
index e284518..3ae9e7d 100644
--- a/tests/operation/child_test.py
+++ b/tests/operation/child_test.py
@@ -41,6 +41,7 @@ def test_child_handler_not_started(
     mock_result.is_succeeded.return_value = False
     mock_result.is_failed.return_value = False
     mock_result.is_started.return_value = False
+    mock_result.is_replay_children.return_value = False
     mock_state.get_checkpoint_result.return_value = mock_result
     mock_callable = Mock(return_value="fresh_result")
 
@@ -80,6 +81,7 @@ def test_child_handler_already_succeeded():
     mock_state.durable_execution_arn = "test_arn"
     mock_result = Mock()
     mock_result.is_succeeded.return_value = True
+    mock_result.is_replay_children.return_value = False
     mock_result.result = json.dumps("cached_result")
     mock_state.get_checkpoint_result.return_value = mock_result
     mock_callable = Mock()
@@ -99,6 +101,7 @@ def test_child_handler_already_succeeded_none_result():
     mock_state.durable_execution_arn = "test_arn"
     mock_result = Mock()
     mock_result.is_succeeded.return_value = True
+    mock_result.is_replay_children.return_value = False
     mock_result.result = None
     mock_state.get_checkpoint_result.return_value = mock_result
     mock_callable = Mock()
@@ -155,6 +158,7 @@ def test_child_handler_already_started(
     mock_result.is_succeeded.return_value = False
     mock_result.is_failed.return_value = False
     mock_result.is_started.return_value = True
+    mock_result.is_replay_children.return_value = False
     mock_state.get_checkpoint_result.return_value = mock_result
     mock_callable = Mock(return_value="started_result")
 
@@ -281,6 +285,7 @@ def test_child_handler_default_serialization():
     mock_result.is_succeeded.return_value = False
     mock_result.is_failed.return_value = False
     mock_result.is_started.return_value = False
+    mock_result.is_replay_children.return_value = False
     mock_state.get_checkpoint_result.return_value = mock_result
     complex_result = {"key": "value", "number": 42, "list": [1, 2, 3]}
     mock_callable = Mock(return_value=complex_result)
@@ -306,6 +311,7 @@ def test_child_handler_custom_serdes_not_start():
     mock_result.is_succeeded.return_value = False
     mock_result.is_failed.return_value = False
     mock_result.is_started.return_value = False
+    mock_result.is_replay_children.return_value = False
     mock_state.get_checkpoint_result.return_value = mock_result
     complex_result = {"key": "value", "number": 42, "list": [1, 2, 3]}
     mock_callable = Mock(return_value=complex_result)
@@ -334,6 +340,7 @@ def test_child_handler_custom_serdes_already_succeeded():
     mock_result.is_succeeded.return_value = True
     mock_result.is_failed.return_value = False
     mock_result.is_started.return_value = False
+    mock_result.is_replay_children.return_value = False
     mock_result.result = '{"key": "VALUE", "number": "84", "list": [1, 2, 3]}'
     mock_state.get_checkpoint_result.return_value = mock_result
     mock_callable = Mock()
@@ -352,3 +359,94 @@ def test_child_handler_custom_serdes_already_succeeded():
 
 
 # endregion child_handler
+
+
+# large payload with summary generator
+def test_child_handler_large_payload_with_summary_generator():
+    """Test child_handler with large payload and summary generator."""
+    mock_state = Mock(spec=ExecutionState)
+    mock_state.durable_execution_arn = "test_arn"
+    mock_result = Mock()
+    mock_result.is_succeeded.return_value = False
+    mock_result.is_failed.return_value = False
+    mock_result.is_started.return_value = False
+    mock_result.is_replay_children.return_value = False
+    mock_state.get_checkpoint_result.return_value = mock_result
+    large_result = "large" * 256 * 1024
+    mock_callable = Mock(return_value=large_result)
+
+    def my_summary(result: str) -> str:
+        return "summary"
+
+    child_config: ChildConfig = ChildConfig[str](summary_generator=my_summary)
+
+    actual_result = child_handler(
+        mock_callable,
+        mock_state,
+        OperationIdentifier("op9", None, "test_name"),
+        child_config,
+    )
+
+    assert large_result == actual_result
+    success_call = mock_state.create_checkpoint.call_args_list[1]
+    success_operation = success_call[1]["operation_update"]
+    assert success_operation.context_options.replay_children
+    expected_checkpointed_result = "summary"
+    assert success_operation.payload == expected_checkpointed_result
+
+
+# large payload without summary generator
+def test_child_handler_large_payload_without_summary_generator():
+    """Test child_handler with large payload and no summary generator."""
+    mock_state = Mock(spec=ExecutionState)
+    mock_state.durable_execution_arn = "test_arn"
+    mock_result = Mock()
+    mock_result.is_succeeded.return_value = False
+    mock_result.is_failed.return_value = False
+    mock_result.is_started.return_value = False
+    mock_result.is_replay_children.return_value = False
+    mock_state.get_checkpoint_result.return_value = mock_result
+    large_result = "large" * 256 * 1024
+    mock_callable = Mock(return_value=large_result)
+    child_config: ChildConfig = ChildConfig()
+
+    actual_result = child_handler(
+        mock_callable,
+        mock_state,
+        OperationIdentifier("op9", None, "test_name"),
+        child_config,
+    )
+
+    assert large_result == actual_result
+    success_call = mock_state.create_checkpoint.call_args_list[1]
+    success_operation = success_call[1]["operation_update"]
+    assert success_operation.context_options.replay_children
+    expected_checkpointed_result = ""
+    assert success_operation.payload == expected_checkpointed_result
+
+
+# mocked children replay mode execute the function again
+def test_child_handler_replay_children_mode():
+    """Test child_handler in ReplayChildren mode."""
+    mock_state = Mock(spec=ExecutionState)
+    mock_state.durable_execution_arn = "test_arn"
+    mock_result = Mock()
+    mock_result.is_succeeded.return_value = True
+    mock_result.is_failed.return_value = False
+    mock_result.is_started.return_value = True
+    mock_result.is_replay_children.return_value = True
+    mock_state.get_checkpoint_result.return_value = mock_result
+    complex_result = {"key": "value", "number": 42, "list": [1, 2, 3]}
+    mock_callable = Mock(return_value=complex_result)
+    child_config: ChildConfig = ChildConfig()
+
+    actual_result = child_handler(
+        mock_callable,
+        mock_state,
+        OperationIdentifier("op9", None, "test_name"),
+        child_config,
+    )
+
+    assert actual_result == complex_result
+
+    mock_state.create_checkpoint.assert_not_called()