diff --git a/src/aws_durable_execution_sdk_python/context.py b/src/aws_durable_execution_sdk_python/context.py index 2938ca9..f64892e 100644 --- a/src/aws_durable_execution_sdk_python/context.py +++ b/src/aws_durable_execution_sdk_python/context.py @@ -343,7 +343,11 @@ def map_in_child_context() -> BatchResult[R]: operation_identifier=operation_identifier, config=ChildConfig( sub_type=OperationSubType.MAP, - serdes=config.serdes if config is not None else None, + serdes=getattr(config, "serdes", None), + # child_handler should only know the serdes of the parent serdes, + # the item serdes will be passed when we are actually executing + # the branch within its own child_handler. + item_serdes=None, ), ) @@ -380,7 +384,11 @@ def parallel_in_child_context() -> BatchResult[T]: operation_identifier=operation_identifier, config=ChildConfig( sub_type=OperationSubType.PARALLEL, - serdes=config.serdes if config is not None else None, + serdes=getattr(config, "serdes", None), + # child_handler should only know the serdes of the parent serdes, + # the item serdes will be passed when we are actually executing + # the branch within its own child_handler. + item_serdes=None, ), ) diff --git a/src/aws_durable_execution_sdk_python/operation/map.py b/src/aws_durable_execution_sdk_python/operation/map.py index ed76bb4..4d0c2e5 100644 --- a/src/aws_durable_execution_sdk_python/operation/map.py +++ b/src/aws_durable_execution_sdk_python/operation/map.py @@ -82,6 +82,7 @@ def from_items( name_prefix="map-item-", serdes=config.serdes, summary_generator=config.summary_generator, + item_serdes=config.item_serdes, ) def execute_item(self, child_context, executable: Executable[Callable]) -> R: diff --git a/src/aws_durable_execution_sdk_python/operation/parallel.py b/src/aws_durable_execution_sdk_python/operation/parallel.py index e81499f..39bebe0 100644 --- a/src/aws_durable_execution_sdk_python/operation/parallel.py +++ b/src/aws_durable_execution_sdk_python/operation/parallel.py @@ -69,6 +69,7 @@ def from_callables( name_prefix="parallel-branch-", serdes=config.serdes, summary_generator=config.summary_generator, + item_serdes=config.item_serdes, ) def execute_item(self, child_context, executable: Executable[Callable]) -> R: # noqa: PLR6301 diff --git a/tests/operation/map_test.py b/tests/operation/map_test.py index eb099d1..edfceec 100644 --- a/tests/operation/map_test.py +++ b/tests/operation/map_test.py @@ -2,6 +2,8 @@ from unittest.mock import Mock, patch +import pytest + # Mock the executor.execute method from aws_durable_execution_sdk_python.concurrency import ( BatchItem, @@ -15,6 +17,7 @@ ItemBatcher, MapConfig, ) +from aws_durable_execution_sdk_python.context import DurableContext from aws_durable_execution_sdk_python.identifier import OperationIdentifier from aws_durable_execution_sdk_python.lambda_service import OperationSubType from aws_durable_execution_sdk_python.operation.map import MapExecutor, map_handler @@ -750,3 +753,128 @@ def get_checkpoint_result(self, operation_id): # Verify replay was called, execute was not mock_replay.assert_called_once() mock_execute.assert_not_called() + + +@pytest.mark.parametrize( + ("item_serdes", "batch_serdes"), + [ + (Mock(), Mock()), + (None, Mock()), + (Mock(), None), + ], +) +@patch("aws_durable_execution_sdk_python.operation.child.serialize") +def test_map_item_serialize(mock_serialize, item_serdes, batch_serdes): + """Test map serializes items with item_serdes or fallback.""" + mock_serialize.return_value = '"serialized"' + + parent_checkpoint = Mock() + parent_checkpoint.is_succeeded.return_value = False + parent_checkpoint.is_failed.return_value = False + parent_checkpoint.is_started.return_value = False + parent_checkpoint.is_existent.return_value = True + parent_checkpoint.is_replay_children.return_value = False + + child_checkpoint = Mock() + child_checkpoint.is_succeeded.return_value = False + child_checkpoint.is_failed.return_value = False + child_checkpoint.is_started.return_value = False + child_checkpoint.is_existent.return_value = True + child_checkpoint.is_replay_children.return_value = False + + def get_checkpoint(op_id): + return child_checkpoint if op_id.startswith("child-") else parent_checkpoint + + mock_state = Mock() + mock_state.durable_execution_arn = "arn:test" + mock_state.get_checkpoint_result = Mock(side_effect=get_checkpoint) + mock_state.create_checkpoint = Mock() + + context_map = {} + + def create_id(self, i): + ctx_id = id(self) + if ctx_id not in context_map: + context_map[ctx_id] = [] + context_map[ctx_id].append(i) + return ( + "parent" + if len(context_map) == 1 and len(context_map[ctx_id]) == 1 + else f"child-{i}" + ) + + with patch.object(DurableContext, "_create_step_id_for_logical_step", create_id): + context = DurableContext(state=mock_state) + context.map( + ["a", "b"], + lambda ctx, item, idx, items: item, + config=MapConfig(serdes=batch_serdes, item_serdes=item_serdes), + ) + + expected = item_serdes or batch_serdes + assert mock_serialize.call_args_list[0][1]["serdes"] is expected + assert mock_serialize.call_args_list[0][1]["operation_id"] == "child-0" + assert mock_serialize.call_args_list[1][1]["serdes"] is expected + assert mock_serialize.call_args_list[1][1]["operation_id"] == "child-1" + assert mock_serialize.call_args_list[2][1]["serdes"] is batch_serdes + assert mock_serialize.call_args_list[2][1]["operation_id"] == "parent" + + +@pytest.mark.parametrize( + ("item_serdes", "batch_serdes"), + [ + (Mock(), Mock()), + (None, Mock()), + (Mock(), None), + ], +) +@patch("aws_durable_execution_sdk_python.operation.child.deserialize") +def test_map_item_deserialize(mock_deserialize, item_serdes, batch_serdes): + """Test map deserializes items with item_serdes or fallback.""" + mock_deserialize.return_value = "deserialized" + + parent_checkpoint = Mock() + parent_checkpoint.is_succeeded.return_value = False + parent_checkpoint.is_failed.return_value = False + parent_checkpoint.is_existent.return_value = False + + child_checkpoint = Mock() + child_checkpoint.is_succeeded.return_value = True + child_checkpoint.is_failed.return_value = False + child_checkpoint.is_replay_children.return_value = False + child_checkpoint.result = '"cached"' + + def get_checkpoint(op_id): + return child_checkpoint if op_id.startswith("child-") else parent_checkpoint + + mock_state = Mock() + mock_state.durable_execution_arn = "arn:test" + mock_state.get_checkpoint_result = Mock(side_effect=get_checkpoint) + mock_state.create_checkpoint = Mock() + + context_map = {} + + def create_id(self, i): + ctx_id = id(self) + if ctx_id not in context_map: + context_map[ctx_id] = [] + context_map[ctx_id].append(i) + return ( + "parent" + if len(context_map) == 1 and len(context_map[ctx_id]) == 1 + else f"child-{i}" + ) + + with patch.object(DurableContext, "_create_step_id_for_logical_step", create_id): + context = DurableContext(state=mock_state) + context.map( + ["a", "b"], + lambda ctx, item, idx, items: item, + config=MapConfig(serdes=batch_serdes, item_serdes=item_serdes), + ) + + expected = item_serdes or batch_serdes + assert mock_deserialize.call_args_list[0][1]["serdes"] is expected + assert mock_deserialize.call_args_list[0][1]["operation_id"] == "child-0" + assert mock_deserialize.call_args_list[1][1]["serdes"] is expected + assert mock_deserialize.call_args_list[1][1]["operation_id"] == "child-1" diff --git a/tests/operation/parallel_test.py b/tests/operation/parallel_test.py index 54f2229..b2f3cf4 100644 --- a/tests/operation/parallel_test.py +++ b/tests/operation/parallel_test.py @@ -14,6 +14,7 @@ Executable, ) from aws_durable_execution_sdk_python.config import CompletionConfig, ParallelConfig +from aws_durable_execution_sdk_python.context import DurableContext from aws_durable_execution_sdk_python.identifier import OperationIdentifier from aws_durable_execution_sdk_python.lambda_service import OperationSubType from aws_durable_execution_sdk_python.operation.parallel import ( @@ -734,3 +735,126 @@ def get_checkpoint_result(self, operation_id): # Verify replay was called, execute was not mock_replay.assert_called_once() mock_execute.assert_not_called() + + +@pytest.mark.parametrize( + ("item_serdes", "batch_serdes"), + [ + (Mock(), Mock()), + (None, Mock()), + (Mock(), None), + ], +) +@patch("aws_durable_execution_sdk_python.operation.child.serialize") +def test_parallel_item_serialize(mock_serialize, item_serdes, batch_serdes): + """Test parallel serializes branches with item_serdes or fallback.""" + mock_serialize.return_value = '"serialized"' + + parent_checkpoint = Mock() + parent_checkpoint.is_succeeded.return_value = False + parent_checkpoint.is_failed.return_value = False + parent_checkpoint.is_started.return_value = False + parent_checkpoint.is_existent.return_value = True + parent_checkpoint.is_replay_children.return_value = False + + child_checkpoint = Mock() + child_checkpoint.is_succeeded.return_value = False + child_checkpoint.is_failed.return_value = False + child_checkpoint.is_started.return_value = False + child_checkpoint.is_existent.return_value = True + child_checkpoint.is_replay_children.return_value = False + + def get_checkpoint(op_id): + return child_checkpoint if op_id.startswith("child-") else parent_checkpoint + + mock_state = Mock() + mock_state.durable_execution_arn = "arn:test" + mock_state.get_checkpoint_result = Mock(side_effect=get_checkpoint) + mock_state.create_checkpoint = Mock() + + context_map = {} + + def create_id(self, i): + ctx_id = id(self) + if ctx_id not in context_map: + context_map[ctx_id] = [] + context_map[ctx_id].append(i) + return ( + "parent" + if len(context_map) == 1 and len(context_map[ctx_id]) == 1 + else f"child-{i}" + ) + + with patch.object(DurableContext, "_create_step_id_for_logical_step", create_id): + context = DurableContext(state=mock_state) + context.parallel( + [lambda ctx: "a", lambda ctx: "b"], + config=ParallelConfig(serdes=batch_serdes, item_serdes=item_serdes), + ) + + expected = item_serdes or batch_serdes + assert mock_serialize.call_args_list[0][1]["serdes"] is expected + assert mock_serialize.call_args_list[0][1]["operation_id"] == "child-0" + assert mock_serialize.call_args_list[1][1]["serdes"] is expected + assert mock_serialize.call_args_list[1][1]["operation_id"] == "child-1" + assert mock_serialize.call_args_list[2][1]["serdes"] is batch_serdes + assert mock_serialize.call_args_list[2][1]["operation_id"] == "parent" + + +@pytest.mark.parametrize( + ("item_serdes", "batch_serdes"), + [ + (Mock(), Mock()), + (None, Mock()), + (Mock(), None), + ], +) +@patch("aws_durable_execution_sdk_python.operation.child.deserialize") +def test_parallel_item_deserialize(mock_deserialize, item_serdes, batch_serdes): + """Test parallel deserializes branches with item_serdes or fallback.""" + mock_deserialize.return_value = "deserialized" + + parent_checkpoint = Mock() + parent_checkpoint.is_succeeded.return_value = False + parent_checkpoint.is_failed.return_value = False + parent_checkpoint.is_existent.return_value = False + + child_checkpoint = Mock() + child_checkpoint.is_succeeded.return_value = True + child_checkpoint.is_failed.return_value = False + child_checkpoint.is_replay_children.return_value = False + child_checkpoint.result = '"cached"' + + def get_checkpoint(op_id): + return child_checkpoint if op_id.startswith("child-") else parent_checkpoint + + mock_state = Mock() + mock_state.durable_execution_arn = "arn:test" + mock_state.get_checkpoint_result = Mock(side_effect=get_checkpoint) + mock_state.create_checkpoint = Mock() + + context_map = {} + + def create_id(self, i): + ctx_id = id(self) + if ctx_id not in context_map: + context_map[ctx_id] = [] + context_map[ctx_id].append(i) + return ( + "parent" + if len(context_map) == 1 and len(context_map[ctx_id]) == 1 + else f"child-{i}" + ) + + with patch.object(DurableContext, "_create_step_id_for_logical_step", create_id): + context = DurableContext(state=mock_state) + context.parallel( + [lambda ctx: "a", lambda ctx: "b"], + config=ParallelConfig(serdes=batch_serdes, item_serdes=item_serdes), + ) + + expected = item_serdes or batch_serdes + assert mock_deserialize.call_args_list[0][1]["serdes"] is expected + assert mock_deserialize.call_args_list[0][1]["operation_id"] == "child-0" + assert mock_deserialize.call_args_list[1][1]["serdes"] is expected + assert mock_deserialize.call_args_list[1][1]["operation_id"] == "child-1"