Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/aws_durable_execution_sdk_python/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,12 +104,13 @@ class CheckpointMode(Enum):


@dataclass(frozen=True)
class ChildConfig:
class ChildConfig(Generic[T]):
"""Options when running inside a child context."""

# checkpoint_mode: CheckpointMode = CheckpointMode.CHECKPOINT_AT_START_AND_FINISH
serdes: SerDes | None = None
sub_type: OperationSubType | None = None
summary_generator: Callable[[T], str] | None = None


class ItemsPerBatchUnit(Enum):
Expand Down
7 changes: 6 additions & 1 deletion src/aws_durable_execution_sdk_python/lambda_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,11 @@ def create_context_start(

@classmethod
def create_context_succeed(
cls, identifier: OperationIdentifier, payload: str, sub_type: OperationSubType
cls,
identifier: OperationIdentifier,
payload: str,
sub_type: OperationSubType,
context_options: ContextOptions | None = None,
) -> OperationUpdate:
"""Create an instance of OperationUpdate for type: CONTEXT, action: SUCCEED."""
return cls(
Expand All @@ -355,6 +359,7 @@ def create_context_succeed(
action=OperationAction.SUCCEED,
name=identifier.name,
payload=payload,
context_options=context_options,
)

@classmethod
Expand Down
29 changes: 27 additions & 2 deletions src/aws_durable_execution_sdk_python/operation/child.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from aws_durable_execution_sdk_python.config import ChildConfig
from aws_durable_execution_sdk_python.exceptions import FatalError, SuspendExecution
from aws_durable_execution_sdk_python.lambda_service import (
ContextOptions,
ErrorObject,
OperationSubType,
OperationUpdate,
Expand All @@ -24,6 +25,9 @@

T = TypeVar("T")

# Checkpoint size limit in bytes (256KB)
CHECKPOINT_SIZE_LIMIT = 256 * 1024


def child_handler(
func: Callable[[], T],
Expand All @@ -40,9 +44,11 @@ def child_handler(
if not config:
config = ChildConfig()

# TODO: ReplayChildren
checkpointed_result = state.get_checkpoint_result(operation_identifier.operation_id)
if checkpointed_result.is_succeeded():
if (
checkpointed_result.is_succeeded()
and not checkpointed_result.is_replay_children()
):
logger.debug(
"Child context already completed, skipping execution for id: %s, name: %s",
operation_identifier.operation_id,
Expand Down Expand Up @@ -71,17 +77,36 @@ def child_handler(

try:
raw_result: T = func()
if checkpointed_result.is_replay_children():
logger.debug(
"ReplayChildren mode: Executed child context again on replay due to large payload. Exiting child context without creating another checkpoint. id: %s, name: %s",
operation_identifier.operation_id,
operation_identifier.name,
)
return raw_result
serialized_result: str = serialize(
serdes=config.serdes,
value=raw_result,
operation_id=operation_identifier.operation_id,
durable_execution_arn=state.durable_execution_arn,
)
replay_children: bool = False
if len(serialized_result) > CHECKPOINT_SIZE_LIMIT:
logger.debug(
"Large payload detected, using ReplayChildren mode: id: %s, name: %s",
operation_identifier.operation_id,
operation_identifier.name,
)
replay_children = True
serialized_result = (
config.summary_generator(raw_result) if config.summary_generator else ""
)

success_operation = OperationUpdate.create_context_succeed(
identifier=operation_identifier,
payload=serialized_result,
sub_type=sub_type,
context_options=ContextOptions(replay_children=replay_children),
)
state.create_checkpoint(operation_update=success_operation)

Expand Down
6 changes: 6 additions & 0 deletions src/aws_durable_execution_sdk_python/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,12 @@ def is_timed_out(self) -> bool:
return False
return op.status is OperationStatus.TIMED_OUT

def is_replay_children(self) -> bool:
op = self.operation
if not op:
return False
return op.context_details.replay_children if op.context_details else False

def raise_callable_error(self) -> None:
if self.error is None:
msg: str = "Attempted to throw exception, but no ErrorObject exists on the Checkpoint Operation."
Expand Down
99 changes: 99 additions & 0 deletions tests/operation/child_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ def test_child_handler_not_started(
mock_result.is_succeeded.return_value = False
mock_result.is_failed.return_value = False
mock_result.is_started.return_value = False
mock_result.is_replay_children.return_value = False
mock_result.is_replay_children.return_value = False
mock_state.get_checkpoint_result.return_value = mock_result
mock_callable = Mock(return_value="fresh_result")

Expand Down Expand Up @@ -80,6 +82,7 @@ def test_child_handler_already_succeeded():
mock_state.durable_execution_arn = "test_arn"
mock_result = Mock()
mock_result.is_succeeded.return_value = True
mock_result.is_replay_children.return_value = False
mock_result.result = json.dumps("cached_result")
mock_state.get_checkpoint_result.return_value = mock_result
mock_callable = Mock()
Expand All @@ -99,6 +102,7 @@ def test_child_handler_already_succeeded_none_result():
mock_state.durable_execution_arn = "test_arn"
mock_result = Mock()
mock_result.is_succeeded.return_value = True
mock_result.is_replay_children.return_value = False
mock_result.result = None
mock_state.get_checkpoint_result.return_value = mock_result
mock_callable = Mock()
Expand Down Expand Up @@ -155,6 +159,7 @@ def test_child_handler_already_started(
mock_result.is_succeeded.return_value = False
mock_result.is_failed.return_value = False
mock_result.is_started.return_value = True
mock_result.is_replay_children.return_value = False
mock_state.get_checkpoint_result.return_value = mock_result
mock_callable = Mock(return_value="started_result")

Expand Down Expand Up @@ -281,6 +286,7 @@ def test_child_handler_default_serialization():
mock_result.is_succeeded.return_value = False
mock_result.is_failed.return_value = False
mock_result.is_started.return_value = False
mock_result.is_replay_children.return_value = False
mock_state.get_checkpoint_result.return_value = mock_result
complex_result = {"key": "value", "number": 42, "list": [1, 2, 3]}
mock_callable = Mock(return_value=complex_result)
Expand All @@ -306,6 +312,7 @@ def test_child_handler_custom_serdes_not_start():
mock_result.is_succeeded.return_value = False
mock_result.is_failed.return_value = False
mock_result.is_started.return_value = False
mock_result.is_replay_children.return_value = False
mock_state.get_checkpoint_result.return_value = mock_result
complex_result = {"key": "value", "number": 42, "list": [1, 2, 3]}
mock_callable = Mock(return_value=complex_result)
Expand Down Expand Up @@ -334,6 +341,7 @@ def test_child_handler_custom_serdes_already_succeeded():
mock_result.is_succeeded.return_value = True
mock_result.is_failed.return_value = False
mock_result.is_started.return_value = False
mock_result.is_replay_children.return_value = False
mock_result.result = '{"key": "VALUE", "number": "84", "list": [1, 2, 3]}'
mock_state.get_checkpoint_result.return_value = mock_result
mock_callable = Mock()
Expand All @@ -352,3 +360,94 @@ def test_child_handler_custom_serdes_already_succeeded():


# endregion child_handler


# large payload with summary generator
def test_child_handler_large_payload_with_summary_generator():
"""Test child_handler with large payload and summary generator."""
mock_state = Mock(spec=ExecutionState)
mock_state.durable_execution_arn = "test_arn"
mock_result = Mock()
mock_result.is_succeeded.return_value = False
mock_result.is_failed.return_value = False
mock_result.is_started.return_value = False
mock_result.is_replay_children.return_value = False
mock_state.get_checkpoint_result.return_value = mock_result
large_result = "large" * 256 * 1024
mock_callable = Mock(return_value=large_result)

def my_summary(result: str) -> str:
return "summary"

child_config: ChildConfig = ChildConfig[str](summary_generator=my_summary)

actual_result = child_handler(
mock_callable,
mock_state,
OperationIdentifier("op9", None, "test_name"),
child_config,
)

assert large_result == actual_result
success_call = mock_state.create_checkpoint.call_args_list[1]
success_operation = success_call[1]["operation_update"]
assert success_operation.context_options.replay_children
expected_checkpoointed_result = "summary"
assert success_operation.payload == expected_checkpoointed_result


# large payload without summary generator
def test_child_handler_large_payload_without_summary_generator():
"""Test child_handler with large payload and no summary generator."""
mock_state = Mock(spec=ExecutionState)
mock_state.durable_execution_arn = "test_arn"
mock_result = Mock()
mock_result.is_succeeded.return_value = False
mock_result.is_failed.return_value = False
mock_result.is_started.return_value = False
mock_result.is_replay_children.return_value = False
mock_state.get_checkpoint_result.return_value = mock_result
large_result = "large" * 256 * 1024
mock_callable = Mock(return_value=large_result)
child_config: ChildConfig = ChildConfig()

actual_result = child_handler(
mock_callable,
mock_state,
OperationIdentifier("op9", None, "test_name"),
child_config,
)

assert large_result == actual_result
success_call = mock_state.create_checkpoint.call_args_list[1]
success_operation = success_call[1]["operation_update"]
assert success_operation.context_options.replay_children
expected_checkpoointed_result = ""
assert success_operation.payload == expected_checkpoointed_result


# mocked children replay mode execute the function again
def test_child_handler_replay_children_mode():
"""Test child_handler in ReplayChildren mode."""
mock_state = Mock(spec=ExecutionState)
mock_state.durable_execution_arn = "test_arn"
mock_result = Mock()
mock_result.is_succeeded.return_value = True
mock_result.is_failed.return_value = False
mock_result.is_started.return_value = True
mock_result.is_replay_children.return_value = True
mock_state.get_checkpoint_result.return_value = mock_result
complex_result = {"key": "value", "number": 42, "list": [1, 2, 3]}
mock_callable = Mock(return_value=complex_result)
child_config: ChildConfig = ChildConfig()

actual_result = child_handler(
mock_callable,
mock_state,
OperationIdentifier("op9", None, "test_name"),
child_config,
)

assert actual_result == complex_result

mock_state.create_checkpoint.assert_not_called()
Loading