-
Notifications
You must be signed in to change notification settings - Fork 15
fix(sdk): pass item_serdes in factory_method #123
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,6 +2,8 @@ | |
|
|
||
| from unittest.mock import Mock, patch | ||
|
|
||
| import pytest | ||
|
|
||
| # Mock the executor.execute method | ||
| from aws_durable_execution_sdk_python.concurrency import ( | ||
| BatchItem, | ||
|
|
@@ -15,6 +17,7 @@ | |
| ItemBatcher, | ||
| MapConfig, | ||
| ) | ||
| from aws_durable_execution_sdk_python.context import DurableContext | ||
| from aws_durable_execution_sdk_python.identifier import OperationIdentifier | ||
| from aws_durable_execution_sdk_python.lambda_service import OperationSubType | ||
| from aws_durable_execution_sdk_python.operation.map import MapExecutor, map_handler | ||
|
|
@@ -750,3 +753,128 @@ def get_checkpoint_result(self, operation_id): | |
| # Verify replay was called, execute was not | ||
| mock_replay.assert_called_once() | ||
| mock_execute.assert_not_called() | ||
|
|
||
|
|
||
| @pytest.mark.parametrize( | ||
| ("item_serdes", "batch_serdes"), | ||
| [ | ||
| (Mock(), Mock()), | ||
| (None, Mock()), | ||
| (Mock(), None), | ||
| ], | ||
| ) | ||
| @patch("aws_durable_execution_sdk_python.operation.child.serialize") | ||
| def test_map_item_serialize(mock_serialize, item_serdes, batch_serdes): | ||
| """Test map serializes items with item_serdes or fallback.""" | ||
| mock_serialize.return_value = '"serialized"' | ||
|
|
||
| parent_checkpoint = Mock() | ||
| parent_checkpoint.is_succeeded.return_value = False | ||
| parent_checkpoint.is_failed.return_value = False | ||
| parent_checkpoint.is_started.return_value = False | ||
| parent_checkpoint.is_existent.return_value = True | ||
| parent_checkpoint.is_replay_children.return_value = False | ||
|
|
||
| child_checkpoint = Mock() | ||
| child_checkpoint.is_succeeded.return_value = False | ||
| child_checkpoint.is_failed.return_value = False | ||
| child_checkpoint.is_started.return_value = False | ||
| child_checkpoint.is_existent.return_value = True | ||
| child_checkpoint.is_replay_children.return_value = False | ||
|
|
||
| def get_checkpoint(op_id): | ||
| return child_checkpoint if op_id.startswith("child-") else parent_checkpoint | ||
|
|
||
| mock_state = Mock() | ||
| mock_state.durable_execution_arn = "arn:test" | ||
| mock_state.get_checkpoint_result = Mock(side_effect=get_checkpoint) | ||
| mock_state.create_checkpoint = Mock() | ||
|
|
||
| context_map = {} | ||
|
|
||
| def create_id(self, i): | ||
| ctx_id = id(self) | ||
| if ctx_id not in context_map: | ||
| context_map[ctx_id] = [] | ||
| context_map[ctx_id].append(i) | ||
| return ( | ||
| "parent" | ||
| if len(context_map) == 1 and len(context_map[ctx_id]) == 1 | ||
| else f"child-{i}" | ||
| ) | ||
|
|
||
| with patch.object(DurableContext, "_create_step_id_for_logical_step", create_id): | ||
| context = DurableContext(state=mock_state) | ||
| context.map( | ||
| ["a", "b"], | ||
| lambda ctx, item, idx, items: item, | ||
| config=MapConfig(serdes=batch_serdes, item_serdes=item_serdes), | ||
| ) | ||
|
|
||
| expected = item_serdes or batch_serdes | ||
| assert mock_serialize.call_args_list[0][1]["serdes"] is expected | ||
| assert mock_serialize.call_args_list[0][1]["operation_id"] == "child-0" | ||
| assert mock_serialize.call_args_list[1][1]["serdes"] is expected | ||
| assert mock_serialize.call_args_list[1][1]["operation_id"] == "child-1" | ||
| assert mock_serialize.call_args_list[2][1]["serdes"] is batch_serdes | ||
| assert mock_serialize.call_args_list[2][1]["operation_id"] == "parent" | ||
|
Comment on lines
+815
to
+820
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Test that we are actually calling serialize. Map and Parallel produce trees of operations. We expect to see N children first, and then 1 operation for the parent. |
||
|
|
||
|
|
||
| @pytest.mark.parametrize( | ||
| ("item_serdes", "batch_serdes"), | ||
| [ | ||
| (Mock(), Mock()), | ||
| (None, Mock()), | ||
| (Mock(), None), | ||
| ], | ||
| ) | ||
| @patch("aws_durable_execution_sdk_python.operation.child.deserialize") | ||
| def test_map_item_deserialize(mock_deserialize, item_serdes, batch_serdes): | ||
| """Test map deserializes items with item_serdes or fallback.""" | ||
| mock_deserialize.return_value = "deserialized" | ||
|
|
||
| parent_checkpoint = Mock() | ||
| parent_checkpoint.is_succeeded.return_value = False | ||
| parent_checkpoint.is_failed.return_value = False | ||
| parent_checkpoint.is_existent.return_value = False | ||
|
|
||
| child_checkpoint = Mock() | ||
| child_checkpoint.is_succeeded.return_value = True | ||
| child_checkpoint.is_failed.return_value = False | ||
| child_checkpoint.is_replay_children.return_value = False | ||
| child_checkpoint.result = '"cached"' | ||
|
|
||
| def get_checkpoint(op_id): | ||
| return child_checkpoint if op_id.startswith("child-") else parent_checkpoint | ||
|
|
||
| mock_state = Mock() | ||
| mock_state.durable_execution_arn = "arn:test" | ||
| mock_state.get_checkpoint_result = Mock(side_effect=get_checkpoint) | ||
| mock_state.create_checkpoint = Mock() | ||
|
|
||
| context_map = {} | ||
|
|
||
| def create_id(self, i): | ||
| ctx_id = id(self) | ||
| if ctx_id not in context_map: | ||
| context_map[ctx_id] = [] | ||
| context_map[ctx_id].append(i) | ||
| return ( | ||
| "parent" | ||
| if len(context_map) == 1 and len(context_map[ctx_id]) == 1 | ||
| else f"child-{i}" | ||
| ) | ||
|
|
||
| with patch.object(DurableContext, "_create_step_id_for_logical_step", create_id): | ||
| context = DurableContext(state=mock_state) | ||
| context.map( | ||
| ["a", "b"], | ||
| lambda ctx, item, idx, items: item, | ||
| config=MapConfig(serdes=batch_serdes, item_serdes=item_serdes), | ||
| ) | ||
|
Comment on lines
+833
to
+874
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Context and state setup. |
||
|
|
||
| expected = item_serdes or batch_serdes | ||
| assert mock_deserialize.call_args_list[0][1]["serdes"] is expected | ||
| assert mock_deserialize.call_args_list[0][1]["operation_id"] == "child-0" | ||
| assert mock_deserialize.call_args_list[1][1]["serdes"] is expected | ||
| assert mock_deserialize.call_args_list[1][1]["operation_id"] == "child-1" | ||
|
Comment on lines
+875
to
+880
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Verify that we are calling deserialize on children. |
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -14,6 +14,7 @@ | |
| Executable, | ||
| ) | ||
| from aws_durable_execution_sdk_python.config import CompletionConfig, ParallelConfig | ||
| from aws_durable_execution_sdk_python.context import DurableContext | ||
| from aws_durable_execution_sdk_python.identifier import OperationIdentifier | ||
| from aws_durable_execution_sdk_python.lambda_service import OperationSubType | ||
| from aws_durable_execution_sdk_python.operation.parallel import ( | ||
|
|
@@ -734,3 +735,126 @@ def get_checkpoint_result(self, operation_id): | |
| # Verify replay was called, execute was not | ||
| mock_replay.assert_called_once() | ||
| mock_execute.assert_not_called() | ||
|
|
||
|
|
||
| @pytest.mark.parametrize( | ||
| ("item_serdes", "batch_serdes"), | ||
| [ | ||
| (Mock(), Mock()), | ||
| (None, Mock()), | ||
| (Mock(), None), | ||
| ], | ||
| ) | ||
| @patch("aws_durable_execution_sdk_python.operation.child.serialize") | ||
| def test_parallel_item_serialize(mock_serialize, item_serdes, batch_serdes): | ||
| """Test parallel serializes branches with item_serdes or fallback.""" | ||
| mock_serialize.return_value = '"serialized"' | ||
|
|
||
| parent_checkpoint = Mock() | ||
| parent_checkpoint.is_succeeded.return_value = False | ||
| parent_checkpoint.is_failed.return_value = False | ||
| parent_checkpoint.is_started.return_value = False | ||
| parent_checkpoint.is_existent.return_value = True | ||
| parent_checkpoint.is_replay_children.return_value = False | ||
|
|
||
| child_checkpoint = Mock() | ||
| child_checkpoint.is_succeeded.return_value = False | ||
| child_checkpoint.is_failed.return_value = False | ||
| child_checkpoint.is_started.return_value = False | ||
| child_checkpoint.is_existent.return_value = True | ||
| child_checkpoint.is_replay_children.return_value = False | ||
|
|
||
| def get_checkpoint(op_id): | ||
| return child_checkpoint if op_id.startswith("child-") else parent_checkpoint | ||
|
|
||
| mock_state = Mock() | ||
| mock_state.durable_execution_arn = "arn:test" | ||
| mock_state.get_checkpoint_result = Mock(side_effect=get_checkpoint) | ||
| mock_state.create_checkpoint = Mock() | ||
|
|
||
| context_map = {} | ||
|
|
||
| def create_id(self, i): | ||
| ctx_id = id(self) | ||
| if ctx_id not in context_map: | ||
| context_map[ctx_id] = [] | ||
| context_map[ctx_id].append(i) | ||
| return ( | ||
| "parent" | ||
| if len(context_map) == 1 and len(context_map[ctx_id]) == 1 | ||
| else f"child-{i}" | ||
| ) | ||
|
|
||
| with patch.object(DurableContext, "_create_step_id_for_logical_step", create_id): | ||
| context = DurableContext(state=mock_state) | ||
| context.parallel( | ||
| [lambda ctx: "a", lambda ctx: "b"], | ||
| config=ParallelConfig(serdes=batch_serdes, item_serdes=item_serdes), | ||
| ) | ||
|
Comment on lines
+748
to
+793
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Setup the state. |
||
|
|
||
| expected = item_serdes or batch_serdes | ||
| assert mock_serialize.call_args_list[0][1]["serdes"] is expected | ||
| assert mock_serialize.call_args_list[0][1]["operation_id"] == "child-0" | ||
| assert mock_serialize.call_args_list[1][1]["serdes"] is expected | ||
| assert mock_serialize.call_args_list[1][1]["operation_id"] == "child-1" | ||
| assert mock_serialize.call_args_list[2][1]["serdes"] is batch_serdes | ||
| assert mock_serialize.call_args_list[2][1]["operation_id"] == "parent" | ||
|
Comment on lines
+796
to
+801
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Verify that we are seeing children and parent in correct ordering here as we are executing recursively. |
||
|
|
||
|
|
||
| @pytest.mark.parametrize( | ||
| ("item_serdes", "batch_serdes"), | ||
| [ | ||
| (Mock(), Mock()), | ||
| (None, Mock()), | ||
| (Mock(), None), | ||
| ], | ||
| ) | ||
| @patch("aws_durable_execution_sdk_python.operation.child.deserialize") | ||
| def test_parallel_item_deserialize(mock_deserialize, item_serdes, batch_serdes): | ||
| """Test parallel deserializes branches with item_serdes or fallback.""" | ||
| mock_deserialize.return_value = "deserialized" | ||
|
|
||
| parent_checkpoint = Mock() | ||
| parent_checkpoint.is_succeeded.return_value = False | ||
| parent_checkpoint.is_failed.return_value = False | ||
| parent_checkpoint.is_existent.return_value = False | ||
|
|
||
| child_checkpoint = Mock() | ||
| child_checkpoint.is_succeeded.return_value = True | ||
| child_checkpoint.is_failed.return_value = False | ||
| child_checkpoint.is_replay_children.return_value = False | ||
| child_checkpoint.result = '"cached"' | ||
|
|
||
| def get_checkpoint(op_id): | ||
| return child_checkpoint if op_id.startswith("child-") else parent_checkpoint | ||
|
|
||
| mock_state = Mock() | ||
| mock_state.durable_execution_arn = "arn:test" | ||
| mock_state.get_checkpoint_result = Mock(side_effect=get_checkpoint) | ||
| mock_state.create_checkpoint = Mock() | ||
|
|
||
| context_map = {} | ||
|
|
||
| def create_id(self, i): | ||
| ctx_id = id(self) | ||
| if ctx_id not in context_map: | ||
| context_map[ctx_id] = [] | ||
| context_map[ctx_id].append(i) | ||
| return ( | ||
| "parent" | ||
| if len(context_map) == 1 and len(context_map[ctx_id]) == 1 | ||
| else f"child-{i}" | ||
| ) | ||
|
|
||
| with patch.object(DurableContext, "_create_step_id_for_logical_step", create_id): | ||
| context = DurableContext(state=mock_state) | ||
| context.parallel( | ||
| [lambda ctx: "a", lambda ctx: "b"], | ||
| config=ParallelConfig(serdes=batch_serdes, item_serdes=item_serdes), | ||
| ) | ||
|
|
||
| expected = item_serdes or batch_serdes | ||
| assert mock_deserialize.call_args_list[0][1]["serdes"] is expected | ||
| assert mock_deserialize.call_args_list[0][1]["operation_id"] == "child-0" | ||
| assert mock_deserialize.call_args_list[1][1]["serdes"] is expected | ||
| assert mock_deserialize.call_args_list[1][1]["operation_id"] == "child-1" | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Setting up the state to allow us to properly reach children and execute them.