Skip to content

Commit e2565f2

Browse files
committed
fix(sdk): pass item_serdes to executor
1. Pass item_serdes to executor factory methods. 2. Add tests to verify fallback and default behaviour.
1 parent a950699 commit e2565f2

4 files changed

Lines changed: 253 additions & 0 deletions

File tree

src/aws_durable_execution_sdk_python/operation/map.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ def from_items(
8282
name_prefix="map-item-",
8383
serdes=config.serdes,
8484
summary_generator=config.summary_generator,
85+
item_serdes=config.item_serdes,
8586
)
8687

8788
def execute_item(self, child_context, executable: Executable[Callable]) -> R:

src/aws_durable_execution_sdk_python/operation/parallel.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ def from_callables(
6969
name_prefix="parallel-branch-",
7070
serdes=config.serdes,
7171
summary_generator=config.summary_generator,
72+
item_serdes=config.item_serdes,
7273
)
7374

7475
def execute_item(self, child_context, executable: Executable[Callable]) -> R: # noqa: PLR6301

tests/operation/map_test.py

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
from unittest.mock import Mock, patch
44

5+
import pytest
6+
57
# Mock the executor.execute method
68
from aws_durable_execution_sdk_python.concurrency import (
79
BatchItem,
@@ -15,6 +17,7 @@
1517
ItemBatcher,
1618
MapConfig,
1719
)
20+
from aws_durable_execution_sdk_python.context import DurableContext
1821
from aws_durable_execution_sdk_python.identifier import OperationIdentifier
1922
from aws_durable_execution_sdk_python.lambda_service import OperationSubType
2023
from aws_durable_execution_sdk_python.operation.map import MapExecutor, map_handler
@@ -750,3 +753,128 @@ def get_checkpoint_result(self, operation_id):
750753
# Verify replay was called, execute was not
751754
mock_replay.assert_called_once()
752755
mock_execute.assert_not_called()
756+
757+
758+
@pytest.mark.parametrize(
759+
("item_serdes", "batch_serdes"),
760+
[
761+
(Mock(), Mock()),
762+
(None, Mock()),
763+
(Mock(), None),
764+
],
765+
)
766+
@patch("aws_durable_execution_sdk_python.operation.child.serialize")
767+
def test_map_item_serialize(mock_serialize, item_serdes, batch_serdes):
768+
"""Test map serializes items with item_serdes or fallback."""
769+
mock_serialize.return_value = '"serialized"'
770+
771+
parent_checkpoint = Mock()
772+
parent_checkpoint.is_succeeded.return_value = False
773+
parent_checkpoint.is_failed.return_value = False
774+
parent_checkpoint.is_started.return_value = False
775+
parent_checkpoint.is_existent.return_value = True
776+
parent_checkpoint.is_replay_children.return_value = False
777+
778+
child_checkpoint = Mock()
779+
child_checkpoint.is_succeeded.return_value = False
780+
child_checkpoint.is_failed.return_value = False
781+
child_checkpoint.is_started.return_value = False
782+
child_checkpoint.is_existent.return_value = True
783+
child_checkpoint.is_replay_children.return_value = False
784+
785+
def get_checkpoint(op_id):
786+
return child_checkpoint if op_id.startswith("child-") else parent_checkpoint
787+
788+
mock_state = Mock()
789+
mock_state.durable_execution_arn = "arn:test"
790+
mock_state.get_checkpoint_result = Mock(side_effect=get_checkpoint)
791+
mock_state.create_checkpoint = Mock()
792+
793+
context_map = {}
794+
795+
def create_id(self, i):
796+
ctx_id = id(self)
797+
if ctx_id not in context_map:
798+
context_map[ctx_id] = []
799+
context_map[ctx_id].append(i)
800+
return (
801+
"parent"
802+
if len(context_map) == 1 and len(context_map[ctx_id]) == 1
803+
else f"child-{i}"
804+
)
805+
806+
with patch.object(DurableContext, "_create_step_id_for_logical_step", create_id):
807+
context = DurableContext(state=mock_state)
808+
context.map(
809+
["a", "b"],
810+
lambda ctx, item, idx, items: item,
811+
config=MapConfig(serdes=batch_serdes, item_serdes=item_serdes),
812+
)
813+
814+
expected = item_serdes or batch_serdes
815+
assert mock_serialize.call_args_list[0][1]["serdes"] is expected
816+
assert mock_serialize.call_args_list[0][1]["operation_id"] == "child-0"
817+
assert mock_serialize.call_args_list[1][1]["serdes"] is expected
818+
assert mock_serialize.call_args_list[1][1]["operation_id"] == "child-1"
819+
assert mock_serialize.call_args_list[2][1]["serdes"] is batch_serdes
820+
assert mock_serialize.call_args_list[2][1]["operation_id"] == "parent"
821+
822+
823+
@pytest.mark.parametrize(
824+
("item_serdes", "batch_serdes"),
825+
[
826+
(Mock(), Mock()),
827+
(None, Mock()),
828+
(Mock(), None),
829+
],
830+
)
831+
@patch("aws_durable_execution_sdk_python.operation.child.deserialize")
832+
def test_map_item_deserialize(mock_deserialize, item_serdes, batch_serdes):
833+
"""Test map deserializes items with item_serdes or fallback."""
834+
mock_deserialize.return_value = "deserialized"
835+
836+
parent_checkpoint = Mock()
837+
parent_checkpoint.is_succeeded.return_value = False
838+
parent_checkpoint.is_failed.return_value = False
839+
parent_checkpoint.is_existent.return_value = False
840+
841+
child_checkpoint = Mock()
842+
child_checkpoint.is_succeeded.return_value = True
843+
child_checkpoint.is_failed.return_value = False
844+
child_checkpoint.is_replay_children.return_value = False
845+
child_checkpoint.result = '"cached"'
846+
847+
def get_checkpoint(op_id):
848+
return child_checkpoint if op_id.startswith("child-") else parent_checkpoint
849+
850+
mock_state = Mock()
851+
mock_state.durable_execution_arn = "arn:test"
852+
mock_state.get_checkpoint_result = Mock(side_effect=get_checkpoint)
853+
mock_state.create_checkpoint = Mock()
854+
855+
context_map = {}
856+
857+
def create_id(self, i):
858+
ctx_id = id(self)
859+
if ctx_id not in context_map:
860+
context_map[ctx_id] = []
861+
context_map[ctx_id].append(i)
862+
return (
863+
"parent"
864+
if len(context_map) == 1 and len(context_map[ctx_id]) == 1
865+
else f"child-{i}"
866+
)
867+
868+
with patch.object(DurableContext, "_create_step_id_for_logical_step", create_id):
869+
context = DurableContext(state=mock_state)
870+
context.map(
871+
["a", "b"],
872+
lambda ctx, item, idx, items: item,
873+
config=MapConfig(serdes=batch_serdes, item_serdes=item_serdes),
874+
)
875+
876+
expected = item_serdes or batch_serdes
877+
assert mock_deserialize.call_args_list[0][1]["serdes"] is expected
878+
assert mock_deserialize.call_args_list[0][1]["operation_id"] == "child-0"
879+
assert mock_deserialize.call_args_list[1][1]["serdes"] is expected
880+
assert mock_deserialize.call_args_list[1][1]["operation_id"] == "child-1"

tests/operation/parallel_test.py

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
Executable,
1515
)
1616
from aws_durable_execution_sdk_python.config import CompletionConfig, ParallelConfig
17+
from aws_durable_execution_sdk_python.context import DurableContext
1718
from aws_durable_execution_sdk_python.identifier import OperationIdentifier
1819
from aws_durable_execution_sdk_python.lambda_service import OperationSubType
1920
from aws_durable_execution_sdk_python.operation.parallel import (
@@ -734,3 +735,125 @@ def get_checkpoint_result(self, operation_id):
734735
# Verify replay was called, execute was not
735736
mock_replay.assert_called_once()
736737
mock_execute.assert_not_called()
738+
739+
@pytest.mark.parametrize(
740+
("item_serdes", "batch_serdes"),
741+
[
742+
(Mock(), Mock()),
743+
(None, Mock()),
744+
(Mock(), None),
745+
],
746+
)
747+
@patch("aws_durable_execution_sdk_python.operation.child.serialize")
748+
def test_parallel_item_serialize(mock_serialize, item_serdes, batch_serdes):
749+
"""Test parallel serializes branches with item_serdes or fallback."""
750+
mock_serialize.return_value = '"serialized"'
751+
752+
parent_checkpoint = Mock()
753+
parent_checkpoint.is_succeeded.return_value = False
754+
parent_checkpoint.is_failed.return_value = False
755+
parent_checkpoint.is_started.return_value = False
756+
parent_checkpoint.is_existent.return_value = True
757+
parent_checkpoint.is_replay_children.return_value = False
758+
759+
child_checkpoint = Mock()
760+
child_checkpoint.is_succeeded.return_value = False
761+
child_checkpoint.is_failed.return_value = False
762+
child_checkpoint.is_started.return_value = False
763+
child_checkpoint.is_existent.return_value = True
764+
child_checkpoint.is_replay_children.return_value = False
765+
766+
def get_checkpoint(op_id):
767+
return child_checkpoint if op_id.startswith("child-") else parent_checkpoint
768+
769+
mock_state = Mock()
770+
mock_state.durable_execution_arn = "arn:test"
771+
mock_state.get_checkpoint_result = Mock(side_effect=get_checkpoint)
772+
mock_state.create_checkpoint = Mock()
773+
774+
context_map = {}
775+
776+
def create_id(self, i):
777+
ctx_id = id(self)
778+
if ctx_id not in context_map:
779+
context_map[ctx_id] = []
780+
context_map[ctx_id].append(i)
781+
return (
782+
"parent"
783+
if len(context_map) == 1 and len(context_map[ctx_id]) == 1
784+
else f"child-{i}"
785+
)
786+
787+
with patch.object(DurableContext, "_create_step_id_for_logical_step", create_id):
788+
context = DurableContext(state=mock_state)
789+
context.parallel(
790+
[lambda ctx: "a", lambda ctx: "b"],
791+
config=ParallelConfig(serdes=batch_serdes, item_serdes=item_serdes),
792+
)
793+
794+
expected = item_serdes or batch_serdes
795+
assert mock_serialize.call_args_list[0][1]["serdes"] is expected
796+
assert mock_serialize.call_args_list[0][1]["operation_id"] == "child-0"
797+
assert mock_serialize.call_args_list[1][1]["serdes"] is expected
798+
assert mock_serialize.call_args_list[1][1]["operation_id"] == "child-1"
799+
assert mock_serialize.call_args_list[2][1]["serdes"] is batch_serdes
800+
assert mock_serialize.call_args_list[2][1]["operation_id"] == "parent"
801+
802+
803+
@pytest.mark.parametrize(
804+
("item_serdes", "batch_serdes"),
805+
[
806+
(Mock(), Mock()),
807+
(None, Mock()),
808+
(Mock(), None),
809+
],
810+
)
811+
@patch("aws_durable_execution_sdk_python.operation.child.deserialize")
812+
def test_parallel_item_deserialize(mock_deserialize, item_serdes, batch_serdes):
813+
"""Test parallel deserializes branches with item_serdes or fallback."""
814+
mock_deserialize.return_value = "deserialized"
815+
816+
parent_checkpoint = Mock()
817+
parent_checkpoint.is_succeeded.return_value = False
818+
parent_checkpoint.is_failed.return_value = False
819+
parent_checkpoint.is_existent.return_value = False
820+
821+
child_checkpoint = Mock()
822+
child_checkpoint.is_succeeded.return_value = True
823+
child_checkpoint.is_failed.return_value = False
824+
child_checkpoint.is_replay_children.return_value = False
825+
child_checkpoint.result = '"cached"'
826+
827+
def get_checkpoint(op_id):
828+
return child_checkpoint if op_id.startswith("child-") else parent_checkpoint
829+
830+
mock_state = Mock()
831+
mock_state.durable_execution_arn = "arn:test"
832+
mock_state.get_checkpoint_result = Mock(side_effect=get_checkpoint)
833+
mock_state.create_checkpoint = Mock()
834+
835+
context_map = {}
836+
837+
def create_id(self, i):
838+
ctx_id = id(self)
839+
if ctx_id not in context_map:
840+
context_map[ctx_id] = []
841+
context_map[ctx_id].append(i)
842+
return (
843+
"parent"
844+
if len(context_map) == 1 and len(context_map[ctx_id]) == 1
845+
else f"child-{i}"
846+
)
847+
848+
with patch.object(DurableContext, "_create_step_id_for_logical_step", create_id):
849+
context = DurableContext(state=mock_state)
850+
context.parallel(
851+
[lambda ctx: "a", lambda ctx: "b"],
852+
config=ParallelConfig(serdes=batch_serdes, item_serdes=item_serdes),
853+
)
854+
855+
expected = item_serdes or batch_serdes
856+
assert mock_deserialize.call_args_list[0][1]["serdes"] is expected
857+
assert mock_deserialize.call_args_list[0][1]["operation_id"] == "child-0"
858+
assert mock_deserialize.call_args_list[1][1]["serdes"] is expected
859+
assert mock_deserialize.call_args_list[1][1]["operation_id"] == "child-1"

0 commit comments

Comments
 (0)