Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/aws_durable_execution_sdk_python/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ class ChildConfig:
# checkpoint_mode: CheckpointMode = CheckpointMode.CHECKPOINT_AT_START_AND_FINISH
serdes: SerDes | None = None
sub_type: OperationSubType | None = None
summary_generator: Callable[[T], str] | None = None


class ItemsPerBatchUnit(Enum):
Expand Down
7 changes: 6 additions & 1 deletion src/aws_durable_execution_sdk_python/lambda_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,11 @@ def create_context_start(

@classmethod
def create_context_succeed(
cls, identifier: OperationIdentifier, payload: str, sub_type: OperationSubType
cls,
identifier: OperationIdentifier,
payload: str,
sub_type: OperationSubType,
context_options: ContextOptions | None = None,
) -> OperationUpdate:
"""Create an instance of OperationUpdate for type: CONTEXT, action: SUCCEED."""
return cls(
Expand All @@ -355,6 +359,7 @@ def create_context_succeed(
action=OperationAction.SUCCEED,
name=identifier.name,
payload=payload,
context_options=context_options,
)

@classmethod
Expand Down
32 changes: 29 additions & 3 deletions src/aws_durable_execution_sdk_python/operation/child.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from aws_durable_execution_sdk_python.config import ChildConfig
from aws_durable_execution_sdk_python.exceptions import FatalError, SuspendExecution
from aws_durable_execution_sdk_python.lambda_service import (
ContextOptions,
ErrorObject,
OperationSubType,
OperationUpdate,
Expand All @@ -24,6 +25,9 @@

T = TypeVar("T")

# Checkpoint size limit in bytes (256KB)
# NOTE(review): child_handler compares this against len() of the serialized
# str, which counts characters rather than encoded bytes — confirm that
# payloads containing multi-byte characters cannot exceed the real limit.
CHECKPOINT_SIZE_LIMIT = 256 * 1024


def child_handler(
func: Callable[[], T],
Expand All @@ -40,9 +44,11 @@ def child_handler(
if not config:
config = ChildConfig()

# TODO: ReplayChildren
checkpointed_result = state.get_checkpoint_result(operation_identifier.operation_id)
if checkpointed_result.is_succeeded():
if (
checkpointed_result.is_succeeded()
and not checkpointed_result.is_replay_children()
):
logger.debug(
"Child context already completed, skipping execution for id: %s, name: %s",
operation_identifier.operation_id,
Expand Down Expand Up @@ -71,17 +77,37 @@ def child_handler(

try:
raw_result: T = func()
if checkpointed_result.is_replay_children():
logger.debug(
"ReplayChildren mode: Re-executing child context due to large payload: id: %s, name: %s",
operation_identifier.operation_id,
operation_identifier.name,
)
return raw_result
serialized_result: str = serialize(
serdes=config.serdes,
value=raw_result,
operation_id=operation_identifier.operation_id,
durable_execution_arn=state.durable_execution_arn,
)
payload_to_checkpoint = serialized_result
replay_children = False
if len(serialized_result) > CHECKPOINT_SIZE_LIMIT:
logger.debug(
"Large payload detected, using ReplayChildren mode: id: %s, name: %s",
operation_identifier.operation_id,
operation_identifier.name,
)
replay_children = True
payload_to_checkpoint = (
config.summary_generator(raw_result) if config.summary_generator else ""
)

success_operation = OperationUpdate.create_context_succeed(
identifier=operation_identifier,
payload=serialized_result,
payload=payload_to_checkpoint,
sub_type=sub_type,
context_options=ContextOptions(replay_children=replay_children),
)
state.create_checkpoint(operation_update=success_operation)

Expand Down
9 changes: 9 additions & 0 deletions src/aws_durable_execution_sdk_python/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,15 @@ def is_timed_out(self) -> bool:
return False
return op.status is OperationStatus.TIMED_OUT

def is_replay_children(self) -> bool:
    """Report whether the checkpointed operation carries the replay_children flag.

    Returns False when there is no operation, or when the operation has no
    context details; otherwise returns the ``replay_children`` value stored
    on the operation's context details.
    """
    operation = self.operation
    # Collapse both "missing" cases into one guard.
    details = operation.context_details if operation else None
    if not details:
        return False
    return details.replay_children

def raise_callable_error(self) -> None:
if self.error is None:
msg: str = "Attempted to throw exception, but no ErrorObject exists on the Checkpoint Operation."
Expand Down
94 changes: 94 additions & 0 deletions tests/operation/child_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def test_child_handler_not_started(
mock_result.is_succeeded.return_value = False
mock_result.is_failed.return_value = False
mock_result.is_started.return_value = False
mock_result.is_replay_children.return_value = False
mock_state.get_checkpoint_result.return_value = mock_result
mock_callable = Mock(return_value="fresh_result")

Expand Down Expand Up @@ -80,6 +81,7 @@ def test_child_handler_already_succeeded():
mock_state.durable_execution_arn = "test_arn"
mock_result = Mock()
mock_result.is_succeeded.return_value = True
mock_result.is_replay_children.return_value = False
mock_result.result = json.dumps("cached_result")
mock_state.get_checkpoint_result.return_value = mock_result
mock_callable = Mock()
Expand All @@ -99,6 +101,7 @@ def test_child_handler_already_succeeded_none_result():
mock_state.durable_execution_arn = "test_arn"
mock_result = Mock()
mock_result.is_succeeded.return_value = True
mock_result.is_replay_children.return_value = False
mock_result.result = None
mock_state.get_checkpoint_result.return_value = mock_result
mock_callable = Mock()
Expand Down Expand Up @@ -155,6 +158,7 @@ def test_child_handler_already_started(
mock_result.is_succeeded.return_value = False
mock_result.is_failed.return_value = False
mock_result.is_started.return_value = True
mock_result.is_replay_children.return_value = False
mock_state.get_checkpoint_result.return_value = mock_result
mock_callable = Mock(return_value="started_result")

Expand Down Expand Up @@ -281,6 +285,7 @@ def test_child_handler_default_serialization():
mock_result.is_succeeded.return_value = False
mock_result.is_failed.return_value = False
mock_result.is_started.return_value = False
mock_result.is_replay_children.return_value = False
mock_state.get_checkpoint_result.return_value = mock_result
complex_result = {"key": "value", "number": 42, "list": [1, 2, 3]}
mock_callable = Mock(return_value=complex_result)
Expand All @@ -306,6 +311,7 @@ def test_child_handler_custom_serdes_not_start():
mock_result.is_succeeded.return_value = False
mock_result.is_failed.return_value = False
mock_result.is_started.return_value = False
mock_result.is_replay_children.return_value = False
mock_state.get_checkpoint_result.return_value = mock_result
complex_result = {"key": "value", "number": 42, "list": [1, 2, 3]}
mock_callable = Mock(return_value=complex_result)
Expand Down Expand Up @@ -334,6 +340,7 @@ def test_child_handler_custom_serdes_already_succeeded():
mock_result.is_succeeded.return_value = True
mock_result.is_failed.return_value = False
mock_result.is_started.return_value = False
mock_result.is_replay_children.return_value = False
mock_result.result = '{"key": "VALUE", "number": "84", "list": [1, 2, 3]}'
mock_state.get_checkpoint_result.return_value = mock_result
mock_callable = Mock()
Expand All @@ -352,3 +359,90 @@ def test_child_handler_custom_serdes_already_succeeded():


# endregion child_handler


# large payload with summary generator
def test_child_handler_large_payload_with_summary_generator():
    """An oversized result is checkpointed as the generated summary.

    The handler must still return the full result to the caller, while the
    success checkpoint carries the summary text and replay_children set.
    """
    # Checkpoint lookup reports a brand-new, never-started operation.
    checkpoint = Mock()
    checkpoint.is_succeeded.return_value = False
    checkpoint.is_failed.return_value = False
    checkpoint.is_started.return_value = False
    checkpoint.is_replay_children.return_value = False

    state = Mock(spec=ExecutionState)
    state.durable_execution_arn = "test_arn"
    state.get_checkpoint_result.return_value = checkpoint

    oversized = "large" * 256 * 1024
    config: ChildConfig = ChildConfig(summary_generator=lambda x: "summary")

    returned = child_handler(
        Mock(return_value=oversized),
        state,
        OperationIdentifier("op9", None, "test_name"),
        config,
    )

    assert oversized == returned
    # The success checkpoint is expected to be the second create_checkpoint
    # call (presumably preceded by a start checkpoint — not visible here).
    _, success_kwargs = state.create_checkpoint.call_args_list[1]
    success_operation = success_kwargs["operation_update"]
    assert success_operation.context_options.replay_children
    assert success_operation.payload == "summary"


# large payload without summary generator
def test_child_handler_large_payload_without_summary_generator():
    """An oversized result with no summary generator checkpoints an empty payload.

    The handler must still return the full result to the caller, while the
    success checkpoint carries "" and replay_children set.
    """
    # Checkpoint lookup reports a brand-new, never-started operation.
    checkpoint = Mock()
    checkpoint.is_succeeded.return_value = False
    checkpoint.is_failed.return_value = False
    checkpoint.is_started.return_value = False
    checkpoint.is_replay_children.return_value = False

    state = Mock(spec=ExecutionState)
    state.durable_execution_arn = "test_arn"
    state.get_checkpoint_result.return_value = checkpoint

    oversized = "large" * 256 * 1024
    config: ChildConfig = ChildConfig()

    returned = child_handler(
        Mock(return_value=oversized),
        state,
        OperationIdentifier("op9", None, "test_name"),
        config,
    )

    assert oversized == returned
    # The success checkpoint is expected to be the second create_checkpoint
    # call (presumably preceded by a start checkpoint — not visible here).
    _, success_kwargs = state.create_checkpoint.call_args_list[1]
    success_operation = success_kwargs["operation_update"]
    assert success_operation.context_options.replay_children
    assert success_operation.payload == ""


# in (mocked) ReplayChildren mode the child function is executed again
def test_child_handler_replay_children_mode():
    """A succeeded checkpoint flagged replay_children re-runs the child.

    The handler returns the freshly computed result and writes no new
    checkpoint.
    """
    # Checkpoint lookup reports a started, succeeded operation that is
    # flagged for ReplayChildren.
    checkpoint = Mock()
    checkpoint.is_succeeded.return_value = True
    checkpoint.is_failed.return_value = False
    checkpoint.is_started.return_value = True
    checkpoint.is_replay_children.return_value = True

    state = Mock(spec=ExecutionState)
    state.durable_execution_arn = "test_arn"
    state.get_checkpoint_result.return_value = checkpoint

    fresh_value = {"key": "value", "number": 42, "list": [1, 2, 3]}
    config: ChildConfig = ChildConfig()

    returned = child_handler(
        Mock(return_value=fresh_value),
        state,
        OperationIdentifier("op9", None, "test_name"),
        config,
    )

    assert returned == fresh_value

    state.create_checkpoint.assert_not_called()
Loading