Skip to content

Commit d30b11b

Browse files
committed
feat: Batch Result serialization
- Adds serialization for batch result in the serdes module. Unfortunately we need to do an ad-hoc import as we are dealing with circular dependencies.
1 parent a950699 commit d30b11b

File tree

5 files changed

+288
-24
lines changed

5 files changed

+288
-24
lines changed

src/aws_durable_execution_sdk_python/serdes.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ class TypeTag(StrEnum):
6262
TUPLE = "t"
6363
LIST = "l"
6464
DICT = "m"
65+
BATCH_RESULT = "br"
6566

6667

6768
@dataclass(frozen=True)
@@ -206,7 +207,17 @@ def dispatcher(self):
206207

207208
def encode(self, obj: Any) -> EncodedValue:
208209
"""Encode container using dispatcher for recursive elements."""
210+
# Import here to avoid circular dependency
211+
# concurrency -> child_handler -> serdes -> concurrency
212+
from aws_durable_execution_sdk_python.concurrency import BatchResult
213+
209214
match obj:
215+
case BatchResult():
216+
# Encode BatchResult as dict with special tag
217+
return EncodedValue(
218+
TypeTag.BATCH_RESULT,
219+
self._wrap(obj.to_dict(), self.dispatcher).value,
220+
)
210221
case list():
211222
return EncodedValue(
212223
TypeTag.LIST, [self._wrap(v, self.dispatcher) for v in obj]
@@ -230,7 +241,15 @@ def encode(self, obj: Any) -> EncodedValue:
230241

231242
def decode(self, tag: TypeTag, value: Any) -> Any:
232243
"""Decode container using dispatcher for recursive elements."""
244+
# Import here to avoid circular dependency
245+
from aws_durable_execution_sdk_python.concurrency import BatchResult
246+
233247
match tag:
248+
case TypeTag.BATCH_RESULT:
249+
# Decode BatchResult from dict - value is already the dict structure
250+
# First decode it as a dict to unwrap all nested EncodedValues
251+
decoded_dict = self.decode(TypeTag.DICT, value)
252+
return BatchResult.from_dict(decoded_dict)
234253
case TypeTag.LIST:
235254
if not isinstance(value, list):
236255
msg = f"Expected list, got {type(value)}"
@@ -295,6 +314,11 @@ def encode(self, obj: Any) -> EncodedValue:
295314
case list() | tuple() | dict():
296315
return self.container_codec.encode(obj)
297316
case _:
317+
# Check if it's a BatchResult (handled by container_codec)
318+
from aws_durable_execution_sdk_python.concurrency import BatchResult
319+
320+
if isinstance(obj, BatchResult):
321+
return self.container_codec.encode(obj)
298322
msg = f"Unsupported type: {type(obj)}"
299323
raise SerDesError(msg)
300324

@@ -316,7 +340,7 @@ def decode(self, tag: TypeTag, value: Any) -> Any:
316340
return self.decimal_codec.decode(tag, value)
317341
case TypeTag.DATETIME | TypeTag.DATE:
318342
return self.datetime_codec.decode(tag, value)
319-
case TypeTag.LIST | TypeTag.TUPLE | TypeTag.DICT:
343+
case TypeTag.LIST | TypeTag.TUPLE | TypeTag.DICT | TypeTag.BATCH_RESULT:
320344
return self.container_codec.decode(tag, value)
321345
case _:
322346
msg = f"Unknown type tag: {tag}"

tests/concurrency_test.py

Lines changed: 77 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Tests for the concurrency module."""
22

3+
import json
34
import random
45
import threading
56
import time
@@ -101,29 +102,6 @@ def test_batch_item_from_dict():
101102
assert item.result == "success_result"
102103
assert item.error is None
103104

104-
105-
def test_batch_item_from_dict_with_error():
106-
"""Test BatchItem from_dict with error object."""
107-
error_data = {
108-
"message": "Test error",
109-
"type": "TestError",
110-
"data": None,
111-
"stackTrace": None,
112-
}
113-
data = {
114-
"index": 1,
115-
"status": "FAILED",
116-
"result": None,
117-
"error": error_data,
118-
}
119-
120-
item = BatchItem.from_dict(data)
121-
assert item.index == 1
122-
assert item.status == BatchItemStatus.FAILED
123-
assert item.result is None
124-
assert item.error is not None
125-
126-
127105
def test_batch_result_creation():
128106
"""Test BatchResult creation."""
129107
items = [
@@ -2676,3 +2654,79 @@ def mock_get_checkpoint_result(operation_id):
26762654
assert len(result.all) == 1
26772655
assert result.all[0].status == BatchItemStatus.SUCCEEDED
26782656
assert result.all[0].result == "re_executed_result"
2657+
2658+
2659+
def test_batch_item_from_dict_with_error():
2660+
"""Test BatchItem.from_dict() with error."""
2661+
data = {
2662+
"index": 3,
2663+
"status": "FAILED",
2664+
"result": None,
2665+
"error": {
2666+
"ErrorType": "ValueError",
2667+
"ErrorMessage": "bad value",
2668+
"StackTrace": [],
2669+
},
2670+
}
2671+
2672+
item = BatchItem.from_dict(data)
2673+
2674+
assert item.index == 3
2675+
assert item.status == BatchItemStatus.FAILED
2676+
assert item.error.type == "ValueError"
2677+
assert item.error.message == "bad value"
2678+
2679+
2680+
def test_batch_result_with_mixed_statuses():
2681+
"""Test BatchResult serialization with mixed item statuses."""
2682+
result = BatchResult(
2683+
all=[
2684+
BatchItem(0, BatchItemStatus.SUCCEEDED, result="success"),
2685+
BatchItem(
2686+
1,
2687+
BatchItemStatus.FAILED,
2688+
error=ErrorObject(message="msg", type="E", data=None, stack_trace=[]),
2689+
),
2690+
BatchItem(2, BatchItemStatus.STARTED),
2691+
],
2692+
completion_reason=CompletionReason.FAILURE_TOLERANCE_EXCEEDED,
2693+
)
2694+
2695+
serialized = json.dumps(result.to_dict())
2696+
deserialized = BatchResult.from_dict(json.loads(serialized))
2697+
2698+
assert len(deserialized.all) == 3
2699+
assert deserialized.all[0].status == BatchItemStatus.SUCCEEDED
2700+
assert deserialized.all[1].status == BatchItemStatus.FAILED
2701+
assert deserialized.all[2].status == BatchItemStatus.STARTED
2702+
assert deserialized.completion_reason == CompletionReason.FAILURE_TOLERANCE_EXCEEDED
2703+
2704+
2705+
def test_batch_result_empty_list():
2706+
"""Test BatchResult serialization with empty items list."""
2707+
result = BatchResult(all=[], completion_reason=CompletionReason.ALL_COMPLETED)
2708+
2709+
serialized = json.dumps(result.to_dict())
2710+
deserialized = BatchResult.from_dict(json.loads(serialized))
2711+
2712+
assert len(deserialized.all) == 0
2713+
assert deserialized.completion_reason == CompletionReason.ALL_COMPLETED
2714+
2715+
2716+
def test_batch_result_complex_nested_data():
2717+
"""Test BatchResult with complex nested data structures."""
2718+
complex_result = {
2719+
"users": [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}],
2720+
"metadata": {"count": 2, "timestamp": "2025-10-31"},
2721+
}
2722+
2723+
result = BatchResult(
2724+
all=[BatchItem(0, BatchItemStatus.SUCCEEDED, result=complex_result)],
2725+
completion_reason=CompletionReason.ALL_COMPLETED,
2726+
)
2727+
2728+
serialized = json.dumps(result.to_dict())
2729+
deserialized = BatchResult.from_dict(json.loads(serialized))
2730+
2731+
assert deserialized.all[0].result == complex_result
2732+
assert deserialized.all[0].result["users"][0]["name"] == "Alice"

tests/operation/map_test.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Tests for map operation."""
22

3+
import json
34
from unittest.mock import Mock, patch
45

56
# Mock the executor.execute method
@@ -750,3 +751,46 @@ def get_checkpoint_result(self, operation_id):
750751
# Verify replay was called, execute was not
751752
mock_replay.assert_called_once()
752753
mock_execute.assert_not_called()
754+
755+
756+
757+
def test_map_result_serialization_roundtrip():
758+
"""Test that map operation BatchResult can be serialized and deserialized."""
759+
760+
items = ["a", "b", "c"]
761+
762+
def func(ctx, item, idx, items):
763+
return {"item": item.upper(), "index": idx}
764+
765+
class MockExecutionState:
766+
durable_execution_arn = "arn:test"
767+
768+
def get_checkpoint_result(self, operation_id):
769+
mock_result = Mock()
770+
mock_result.is_succeeded.return_value = False
771+
return mock_result
772+
773+
execution_state = MockExecutionState()
774+
map_context = Mock()
775+
map_context._create_step_id_for_logical_step = Mock(side_effect=["1", "2", "3"]) # noqa SLF001
776+
map_context.create_child_context = Mock(return_value=Mock())
777+
operation_identifier = OperationIdentifier("test_op", "parent", "test_map")
778+
779+
# Execute map
780+
result = map_handler(
781+
items, func, MapConfig(), execution_state, map_context, operation_identifier
782+
)
783+
784+
# Serialize the BatchResult
785+
serialized = json.dumps(result.to_dict())
786+
787+
# Deserialize
788+
deserialized = BatchResult.from_dict(json.loads(serialized))
789+
790+
# Verify all data preserved
791+
assert len(deserialized.all) == 3
792+
assert deserialized.all[0].result == {"item": "A", "index": 0}
793+
assert deserialized.all[1].result == {"item": "B", "index": 1}
794+
assert deserialized.all[2].result == {"item": "C", "index": 2}
795+
assert deserialized.completion_reason == result.completion_reason
796+
assert all(item.status == BatchItemStatus.SUCCEEDED for item in deserialized.all)

tests/operation/parallel_test.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Tests for the parallel operation module."""
22

3+
import json
34
from unittest.mock import Mock, patch
45

56
import pytest
@@ -734,3 +735,58 @@ def get_checkpoint_result(self, operation_id):
734735
# Verify replay was called, execute was not
735736
mock_replay.assert_called_once()
736737
mock_execute.assert_not_called()
738+
739+
740+
741+
def test_parallel_result_serialization_roundtrip():
742+
"""Test that parallel operation BatchResult can be serialized and deserialized."""
743+
744+
def func1(ctx):
745+
return [1, 2, 3]
746+
747+
def func2(ctx):
748+
return {"status": "complete", "count": 42}
749+
750+
def func3(ctx):
751+
return "simple string"
752+
753+
callables = [func1, func2, func3]
754+
755+
class MockExecutionState:
756+
durable_execution_arn = "arn:test"
757+
758+
def get_checkpoint_result(self, operation_id):
759+
mock_result = Mock()
760+
mock_result.is_succeeded.return_value = False
761+
return mock_result
762+
763+
execution_state = MockExecutionState()
764+
parallel_context = Mock()
765+
parallel_context._create_step_id_for_logical_step = Mock( # noqa SLF001
766+
side_effect=["1", "2", "3"]
767+
)
768+
parallel_context.create_child_context = Mock(return_value=Mock())
769+
operation_identifier = OperationIdentifier("test_op", "parent", "test_parallel")
770+
771+
# Execute parallel
772+
result = parallel_handler(
773+
callables,
774+
ParallelConfig(),
775+
execution_state,
776+
parallel_context,
777+
operation_identifier,
778+
)
779+
780+
# Serialize the BatchResult
781+
serialized = json.dumps(result.to_dict())
782+
783+
# Deserialize
784+
deserialized = BatchResult.from_dict(json.loads(serialized))
785+
786+
# Verify all data preserved
787+
assert len(deserialized.all) == 3
788+
assert deserialized.all[0].result == [1, 2, 3]
789+
assert deserialized.all[1].result == {"status": "complete", "count": 42}
790+
assert deserialized.all[2].result == "simple string"
791+
assert deserialized.completion_reason == result.completion_reason
792+
assert all(item.status == BatchItemStatus.SUCCEEDED for item in deserialized.all)

0 commit comments

Comments
 (0)