Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions src/aws_durable_execution_sdk_python/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,11 @@ def map_in_child_context() -> BatchResult[R]:
operation_identifier=operation_identifier,
config=ChildConfig(
sub_type=OperationSubType.MAP,
serdes=config.serdes if config is not None else None,
serdes=getattr(config, "serdes", None),
# child_handler should only know the parent's serdes here;
# the item serdes is passed when each branch is actually executed
# within its own child_handler.
item_serdes=None,
),
)

Expand Down Expand Up @@ -380,7 +384,11 @@ def parallel_in_child_context() -> BatchResult[T]:
operation_identifier=operation_identifier,
config=ChildConfig(
sub_type=OperationSubType.PARALLEL,
serdes=config.serdes if config is not None else None,
serdes=getattr(config, "serdes", None),
# child_handler should only know the parent's serdes here;
# the item serdes is passed when each branch is actually executed
# within its own child_handler.
item_serdes=None,
),
)

Expand Down
1 change: 1 addition & 0 deletions src/aws_durable_execution_sdk_python/operation/map.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def from_items(
name_prefix="map-item-",
serdes=config.serdes,
summary_generator=config.summary_generator,
item_serdes=config.item_serdes,
)

def execute_item(self, child_context, executable: Executable[Callable]) -> R:
Expand Down
1 change: 1 addition & 0 deletions src/aws_durable_execution_sdk_python/operation/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def from_callables(
name_prefix="parallel-branch-",
serdes=config.serdes,
summary_generator=config.summary_generator,
item_serdes=config.item_serdes,
)

def execute_item(self, child_context, executable: Executable[Callable]) -> R: # noqa: PLR6301
Expand Down
128 changes: 128 additions & 0 deletions tests/operation/map_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from unittest.mock import Mock, patch

import pytest

# Mock the executor.execute method
from aws_durable_execution_sdk_python.concurrency import (
BatchItem,
Expand All @@ -15,6 +17,7 @@
ItemBatcher,
MapConfig,
)
from aws_durable_execution_sdk_python.context import DurableContext
from aws_durable_execution_sdk_python.identifier import OperationIdentifier
from aws_durable_execution_sdk_python.lambda_service import OperationSubType
from aws_durable_execution_sdk_python.operation.map import MapExecutor, map_handler
Expand Down Expand Up @@ -750,3 +753,128 @@ def get_checkpoint_result(self, operation_id):
# Verify replay was called, execute was not
mock_replay.assert_called_once()
mock_execute.assert_not_called()


@pytest.mark.parametrize(
    ("item_serdes", "batch_serdes"),
    [
        (Mock(), Mock()),
        (None, Mock()),
        (Mock(), None),
    ],
)
@patch("aws_durable_execution_sdk_python.operation.child.serialize")
def test_map_item_serialize(mock_serialize, item_serdes, batch_serdes):
    """Test map serializes items with item_serdes or fallback.

    Each map item must be serialized with ``item_serdes`` when one is
    configured, falling back to the batch-level ``serdes`` otherwise. The
    parent (batch) result must always be serialized with the batch-level
    ``serdes``.
    """
    mock_serialize.return_value = '"serialized"'

    # Set up the state so that we actually reach the children and execute
    # them: neither the parent nor the children report a finished, failed or
    # started operation, and none of them asks for its children to be replayed.
    parent_checkpoint = Mock()
    parent_checkpoint.is_succeeded.return_value = False
    parent_checkpoint.is_failed.return_value = False
    parent_checkpoint.is_started.return_value = False
    parent_checkpoint.is_existent.return_value = True
    parent_checkpoint.is_replay_children.return_value = False

    child_checkpoint = Mock()
    child_checkpoint.is_succeeded.return_value = False
    child_checkpoint.is_failed.return_value = False
    child_checkpoint.is_started.return_value = False
    child_checkpoint.is_existent.return_value = True
    child_checkpoint.is_replay_children.return_value = False

    def get_checkpoint(op_id):
        # Route the lookup by the synthetic ids produced by create_id below.
        return child_checkpoint if op_id.startswith("child-") else parent_checkpoint

    mock_state = Mock()
    mock_state.durable_execution_arn = "arn:test"
    mock_state.get_checkpoint_result = Mock(side_effect=get_checkpoint)
    mock_state.create_checkpoint = Mock()

    # Records, per context instance, every logical step index it was asked for.
    context_map = {}

    def create_id(self, i):
        ctx_id = id(self)
        context_map.setdefault(ctx_id, []).append(i)
        # Only the very first id requested by the very first context is the
        # parent operation; every later request is treated as a child.
        return (
            "parent"
            if len(context_map) == 1 and len(context_map[ctx_id]) == 1
            else f"child-{i}"
        )

    with patch.object(DurableContext, "_create_step_id_for_logical_step", create_id):
        context = DurableContext(state=mock_state)
        context.map(
            ["a", "b"],
            lambda ctx, item, idx, items: item,
            config=MapConfig(serdes=batch_serdes, item_serdes=item_serdes),
        )

    # Explicit None check so the fallback does not depend on truthiness.
    expected = item_serdes if item_serdes is not None else batch_serdes

    # Map produces a tree of operations: we expect the N children first and
    # then one operation for the parent.
    calls = mock_serialize.call_args_list
    assert len(calls) >= 3
    child_calls, parent_call = calls[:2], calls[2]

    # Do not depend on which child is serialized first: if the branches run
    # concurrently, their relative order is up to the scheduler.
    child_ids = sorted(call.kwargs["operation_id"] for call in child_calls)
    assert child_ids == ["child-0", "child-1"]
    for child_call in child_calls:
        assert child_call.kwargs["serdes"] is expected

    assert parent_call.kwargs["operation_id"] == "parent"
    assert parent_call.kwargs["serdes"] is batch_serdes
Comment on lines +815 to +820
Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Test that we are actually calling serialize.

Map and Parallel produce trees of operations. We expect to see N children first, and then 1 operation for the parent.



@pytest.mark.parametrize(
    ("item_serdes", "batch_serdes"),
    [
        (Mock(), Mock()),
        (None, Mock()),
        (Mock(), None),
    ],
)
@patch("aws_durable_execution_sdk_python.operation.child.deserialize")
def test_map_item_deserialize(mock_deserialize, item_serdes, batch_serdes):
    """Test map deserializes items with item_serdes or fallback.

    Each already-succeeded map item must be deserialized with ``item_serdes``
    when one is configured, falling back to the batch-level ``serdes``
    otherwise.
    """
    mock_deserialize.return_value = "deserialized"

    # Context and state setup: the parent has no finished checkpoint, while
    # every child reports success with a cached result to deserialize.
    parent_checkpoint = Mock()
    parent_checkpoint.is_succeeded.return_value = False
    parent_checkpoint.is_failed.return_value = False
    parent_checkpoint.is_existent.return_value = False

    child_checkpoint = Mock()
    child_checkpoint.is_succeeded.return_value = True
    child_checkpoint.is_failed.return_value = False
    child_checkpoint.is_replay_children.return_value = False
    child_checkpoint.result = '"cached"'

    def get_checkpoint(op_id):
        # Route the lookup by the synthetic ids produced by create_id below.
        return child_checkpoint if op_id.startswith("child-") else parent_checkpoint

    mock_state = Mock()
    mock_state.durable_execution_arn = "arn:test"
    mock_state.get_checkpoint_result = Mock(side_effect=get_checkpoint)
    mock_state.create_checkpoint = Mock()

    # Records, per context instance, every logical step index it was asked for.
    context_map = {}

    def create_id(self, i):
        ctx_id = id(self)
        context_map.setdefault(ctx_id, []).append(i)
        # Only the very first id requested by the very first context is the
        # parent operation; every later request is treated as a child.
        return (
            "parent"
            if len(context_map) == 1 and len(context_map[ctx_id]) == 1
            else f"child-{i}"
        )

    with patch.object(DurableContext, "_create_step_id_for_logical_step", create_id):
        context = DurableContext(state=mock_state)
        context.map(
            ["a", "b"],
            lambda ctx, item, idx, items: item,
            config=MapConfig(serdes=batch_serdes, item_serdes=item_serdes),
        )

    # Explicit None check so the fallback does not depend on truthiness.
    expected = item_serdes if item_serdes is not None else batch_serdes

    # Verify that we are calling deserialize on the children.
    calls = mock_deserialize.call_args_list
    assert len(calls) >= 2
    child_calls = calls[:2]

    # Do not depend on which child is deserialized first: if the branches run
    # concurrently, their relative order is up to the scheduler.
    child_ids = sorted(call.kwargs["operation_id"] for call in child_calls)
    assert child_ids == ["child-0", "child-1"]
    for child_call in child_calls:
        assert child_call.kwargs["serdes"] is expected
Comment on lines +875 to +880
Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Verify that we are calling deserialize on children.

124 changes: 124 additions & 0 deletions tests/operation/parallel_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
Executable,
)
from aws_durable_execution_sdk_python.config import CompletionConfig, ParallelConfig
from aws_durable_execution_sdk_python.context import DurableContext
from aws_durable_execution_sdk_python.identifier import OperationIdentifier
from aws_durable_execution_sdk_python.lambda_service import OperationSubType
from aws_durable_execution_sdk_python.operation.parallel import (
Expand Down Expand Up @@ -734,3 +735,126 @@ def get_checkpoint_result(self, operation_id):
# Verify replay was called, execute was not
mock_replay.assert_called_once()
mock_execute.assert_not_called()


@pytest.mark.parametrize(
    ("item_serdes", "batch_serdes"),
    [
        (Mock(), Mock()),
        (None, Mock()),
        (Mock(), None),
    ],
)
@patch("aws_durable_execution_sdk_python.operation.child.serialize")
def test_parallel_item_serialize(mock_serialize, item_serdes, batch_serdes):
    """Test parallel serializes branches with item_serdes or fallback.

    Each parallel branch must be serialized with ``item_serdes`` when one is
    configured, falling back to the batch-level ``serdes`` otherwise. The
    parent (batch) result must always be serialized with the batch-level
    ``serdes``.
    """
    mock_serialize.return_value = '"serialized"'

    # Set up the state: neither the parent nor the children report a
    # finished, failed or started operation, and none of them asks for its
    # children to be replayed.
    parent_checkpoint = Mock()
    parent_checkpoint.is_succeeded.return_value = False
    parent_checkpoint.is_failed.return_value = False
    parent_checkpoint.is_started.return_value = False
    parent_checkpoint.is_existent.return_value = True
    parent_checkpoint.is_replay_children.return_value = False

    child_checkpoint = Mock()
    child_checkpoint.is_succeeded.return_value = False
    child_checkpoint.is_failed.return_value = False
    child_checkpoint.is_started.return_value = False
    child_checkpoint.is_existent.return_value = True
    child_checkpoint.is_replay_children.return_value = False

    def get_checkpoint(op_id):
        # Route the lookup by the synthetic ids produced by create_id below.
        return child_checkpoint if op_id.startswith("child-") else parent_checkpoint

    mock_state = Mock()
    mock_state.durable_execution_arn = "arn:test"
    mock_state.get_checkpoint_result = Mock(side_effect=get_checkpoint)
    mock_state.create_checkpoint = Mock()

    # Records, per context instance, every logical step index it was asked for.
    context_map = {}

    def create_id(self, i):
        ctx_id = id(self)
        context_map.setdefault(ctx_id, []).append(i)
        # Only the very first id requested by the very first context is the
        # parent operation; every later request is treated as a child.
        return (
            "parent"
            if len(context_map) == 1 and len(context_map[ctx_id]) == 1
            else f"child-{i}"
        )

    with patch.object(DurableContext, "_create_step_id_for_logical_step", create_id):
        context = DurableContext(state=mock_state)
        context.parallel(
            [lambda ctx: "a", lambda ctx: "b"],
            config=ParallelConfig(serdes=batch_serdes, item_serdes=item_serdes),
        )

    # Explicit None check so the fallback does not depend on truthiness.
    expected = item_serdes if item_serdes is not None else batch_serdes

    # Verify that the children are serialized before the parent.
    calls = mock_serialize.call_args_list
    assert len(calls) >= 3
    child_calls, parent_call = calls[:2], calls[2]

    # Do not depend on which branch is serialized first: if the branches run
    # concurrently, their relative order is up to the scheduler.
    child_ids = sorted(call.kwargs["operation_id"] for call in child_calls)
    assert child_ids == ["child-0", "child-1"]
    for child_call in child_calls:
        assert child_call.kwargs["serdes"] is expected

    assert parent_call.kwargs["operation_id"] == "parent"
    assert parent_call.kwargs["serdes"] is batch_serdes
Comment on lines +796 to +801
Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Verify that the children and the parent appear in the correct order here, since we are executing recursively.



@pytest.mark.parametrize(
    ("item_serdes", "batch_serdes"),
    [
        (Mock(), Mock()),
        (None, Mock()),
        (Mock(), None),
    ],
)
@patch("aws_durable_execution_sdk_python.operation.child.deserialize")
def test_parallel_item_deserialize(mock_deserialize, item_serdes, batch_serdes):
    """Test parallel deserializes branches with item_serdes or fallback.

    Each already-succeeded parallel branch must be deserialized with
    ``item_serdes`` when one is configured, falling back to the batch-level
    ``serdes`` otherwise.
    """
    mock_deserialize.return_value = "deserialized"

    # The parent has no finished checkpoint, while every child reports
    # success with a cached result to deserialize.
    parent_checkpoint = Mock()
    parent_checkpoint.is_succeeded.return_value = False
    parent_checkpoint.is_failed.return_value = False
    parent_checkpoint.is_existent.return_value = False

    child_checkpoint = Mock()
    child_checkpoint.is_succeeded.return_value = True
    child_checkpoint.is_failed.return_value = False
    child_checkpoint.is_replay_children.return_value = False
    child_checkpoint.result = '"cached"'

    def get_checkpoint(op_id):
        # Route the lookup by the synthetic ids produced by create_id below.
        return child_checkpoint if op_id.startswith("child-") else parent_checkpoint

    mock_state = Mock()
    mock_state.durable_execution_arn = "arn:test"
    mock_state.get_checkpoint_result = Mock(side_effect=get_checkpoint)
    mock_state.create_checkpoint = Mock()

    # Records, per context instance, every logical step index it was asked for.
    context_map = {}

    def create_id(self, i):
        ctx_id = id(self)
        context_map.setdefault(ctx_id, []).append(i)
        # Only the very first id requested by the very first context is the
        # parent operation; every later request is treated as a child.
        return (
            "parent"
            if len(context_map) == 1 and len(context_map[ctx_id]) == 1
            else f"child-{i}"
        )

    with patch.object(DurableContext, "_create_step_id_for_logical_step", create_id):
        context = DurableContext(state=mock_state)
        context.parallel(
            [lambda ctx: "a", lambda ctx: "b"],
            config=ParallelConfig(serdes=batch_serdes, item_serdes=item_serdes),
        )

    # Explicit None check so the fallback does not depend on truthiness.
    expected = item_serdes if item_serdes is not None else batch_serdes

    calls = mock_deserialize.call_args_list
    assert len(calls) >= 2
    child_calls = calls[:2]

    # Do not depend on which branch is deserialized first: if the branches
    # run concurrently, their relative order is up to the scheduler.
    child_ids = sorted(call.kwargs["operation_id"] for call in child_calls)
    assert child_ids == ["child-0", "child-1"]
    for child_call in child_calls:
        assert child_call.kwargs["serdes"] is expected
Loading