Skip to content

Commit 24488ab

Browse files
author
Rares Polenciuc
committed
feat: add envelope serializer for extended python types
Implement EnvelopeSerDes to handle datetime, Decimal, bytes, UUID, tuple types using wrapper envelope format. Maintains backward compatibility with existing JSON serializer while providing comprehensive type support. - Add TypeHandler chain architecture for extensible serialization - Support datetime/date with ISO format encoding - Handle Decimal with string representation - Support bytes/bytearray/memoryview with base64 encoding - Add UUID serialization with string format - Implement tuple/list/dict container handling - Provide clear error messages for unsupported types - Add comprehensive test coverage for all supported types
1 parent 6beb550 commit 24488ab

2 files changed

Lines changed: 727 additions & 17 deletions

File tree

src/aws_durable_execution_sdk_python/serdes.py

Lines changed: 215 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,19 @@
11
"""Serialization and deserialization"""
22

3+
from __future__ import annotations
4+
5+
import base64
36
import json
47
import logging
8+
import uuid
59
from abc import ABC, abstractmethod
610
from dataclasses import dataclass
7-
from typing import Generic, TypeVar
11+
from datetime import date, datetime
12+
from decimal import Decimal
13+
from typing import TYPE_CHECKING, Any, Generic, Protocol, TypeVar
14+
15+
if TYPE_CHECKING:
16+
from collections.abc import Callable
817

918
from aws_durable_execution_sdk_python.exceptions import FatalError
1019

@@ -13,6 +22,18 @@
1322
T = TypeVar("T")
1423

1524

25+
class TypeEncoder(Protocol):
26+
"""Protocol for objects that can encode and decode types."""
27+
28+
def encode(self, obj: Any) -> dict[str, Any]:
29+
"""Encode an object to a dictionary representation."""
30+
...
31+
32+
def decode(self, tag: str, value: Any) -> Any:
33+
"""Decode a tagged value back to an object."""
34+
...
35+
36+
1637
@dataclass(frozen=True)
1738
class SerDesContext:
1839
operation_id: str
@@ -37,38 +58,215 @@ def deserialize(self, data: str, _: SerDesContext) -> T: # noqa: PLR6301
3758
return json.loads(data)
3859

3960

61+
class TypeHandler(TypeEncoder, ABC):
62+
def __init__(self, next_handler: TypeEncoder) -> None:
63+
self._next: TypeEncoder = next_handler
64+
65+
66+
class UnsupportedHandler(TypeEncoder):
67+
def encode(self, obj: Any) -> dict[str, Any]: # noqa: PLR6301
68+
msg = f"Unsupported type: {type(obj)!r}"
69+
raise TypeError(msg)
70+
71+
def decode(self, tag: str, value: Any) -> Any: # noqa: PLR6301, ARG002
72+
msg = f"Unknown type tag: {tag!r}"
73+
raise ValueError(msg)
74+
75+
76+
class BytesHandler(TypeHandler):
77+
def encode(self, obj: Any) -> dict[str, Any]:
78+
if isinstance(obj, bytes | bytearray | memoryview):
79+
encoded: str = base64.b64encode(bytes(obj)).decode()
80+
return {"_": {"t": "bytes", "v": encoded}}
81+
return self._next.encode(obj)
82+
83+
def decode(self, tag: str, value: Any) -> Any:
84+
if tag == "bytes":
85+
return base64.b64decode(value)
86+
return self._next.decode(tag, value)
87+
88+
89+
class UuidHandler(TypeHandler):
90+
def encode(self, obj: Any) -> dict[str, Any]:
91+
if isinstance(obj, uuid.UUID):
92+
return {"_": {"t": "uuid", "v": str(obj)}}
93+
return self._next.encode(obj)
94+
95+
def decode(self, tag: str, value: Any) -> Any:
96+
if tag == "uuid":
97+
return uuid.UUID(value)
98+
return self._next.decode(tag, value)
99+
100+
101+
class DecimalHandler(TypeHandler):
102+
def encode(self, obj: Any) -> dict[str, Any]:
103+
if isinstance(obj, Decimal):
104+
return {"_": {"t": "decimal", "v": str(obj)}}
105+
return self._next.encode(obj)
106+
107+
def decode(self, tag: str, value: Any) -> Any:
108+
if tag == "decimal":
109+
return Decimal(value)
110+
return self._next.decode(tag, value)
111+
112+
113+
class DateTimeHandler(TypeHandler):
114+
def encode(self, obj: Any) -> dict[str, Any]:
115+
if isinstance(obj, datetime):
116+
return {"_": {"t": "datetime", "v": obj.isoformat()}}
117+
if isinstance(obj, date):
118+
return {"_": {"t": "date", "v": obj.isoformat()}}
119+
return self._next.encode(obj)
120+
121+
def decode(self, tag: str, value: Any) -> Any:
122+
if tag == "datetime":
123+
return datetime.fromisoformat(value)
124+
if tag == "date":
125+
return date.fromisoformat(value)
126+
return self._next.decode(tag, value)
127+
128+
129+
class ContainerHandler(TypeHandler):
130+
def __init__(self, next_handler: TypeEncoder) -> None:
131+
super().__init__(next_handler)
132+
self._dispatch_encode: Callable[[Any], dict[str, Any]] | None = None
133+
self._dispatch_decode: Callable[[str, Any], Any] | None = None
134+
135+
def _enc(self, obj: Any) -> dict[str, Any]:
136+
if self._dispatch_encode is None:
137+
msg = "ContainerHandler not initialized with encode dispatcher."
138+
raise RuntimeError(msg)
139+
return self._dispatch_encode(obj)
140+
141+
def _dec(self, tag: str, value: Any) -> Any:
142+
if self._dispatch_decode is None:
143+
msg = "ContainerHandler not initialized with decode dispatcher."
144+
raise RuntimeError(msg)
145+
return self._dispatch_decode(tag, value)
146+
147+
def encode(self, obj: Any) -> dict[str, Any]:
148+
if isinstance(obj, tuple):
149+
items: list[dict[str, Any]] = [self._enc(x) for x in obj]
150+
return {"_": {"t": "tuple", "v": items}}
151+
if isinstance(obj, list):
152+
items_list: list[dict[str, Any]] = [self._enc(x) for x in obj]
153+
return {"_": {"t": "list", "v": items_list}}
154+
if isinstance(obj, dict):
155+
self._validate_dict_keys(obj)
156+
wrapped: dict[str, dict[str, Any]] = {
157+
k: self._enc(v) for k, v in obj.items()
158+
}
159+
return {"_": {"t": "dict", "v": wrapped}}
160+
return self._next.encode(obj)
161+
162+
def decode(self, tag: str, value: Any) -> Any:
163+
if tag == "tuple":
164+
if not isinstance(value, list):
165+
msg = 'Malformed envelope: "tuple" expects array value.'
166+
raise TypeError(msg)
167+
return tuple(self._dec(v["_"]["t"], v["_"]["v"]) for v in value)
168+
if tag == "list":
169+
if not isinstance(value, list):
170+
msg = 'Malformed envelope: "list" expects array value.'
171+
raise TypeError(msg)
172+
return [self._dec(v["_"]["t"], v["_"]["v"]) for v in value]
173+
if tag == "dict":
174+
if not isinstance(value, dict):
175+
msg = 'Malformed envelope: "dict" expects object value.'
176+
raise TypeError(msg)
177+
return {k: self._dec(v["_"]["t"], v["_"]["v"]) for k, v in value.items()}
178+
return self._next.decode(tag, value)
179+
180+
@staticmethod
181+
def _validate_dict_keys(mapping: dict[Any, Any]) -> None:
182+
bad: list[Any] = [k for k in mapping if not isinstance(k, str)]
183+
if bad:
184+
ex: Any = bad[0]
185+
msg = f"Unsupported mapping key type: {type(ex)!r}. JSON object keys must be strings."
186+
raise TypeError(msg)
187+
188+
189+
class PrimitiveHandler(TypeHandler):
190+
def encode(self, obj: Any) -> dict[str, Any]:
191+
if obj is None or isinstance(obj, str | int | float | bool):
192+
tag: str = type(obj).__name__
193+
return {"_": {"t": tag, "v": obj}}
194+
return self._next.encode(obj)
195+
196+
def decode(self, tag: str, value: Any) -> Any:
197+
if tag == "NoneType":
198+
return None
199+
if tag in {"str", "int", "float", "bool"}:
200+
return value
201+
return self._next.decode(tag, value)
202+
203+
204+
@dataclass(frozen=True)
205+
class HandlerChain:
206+
root: TypeHandler
207+
container: ContainerHandler
208+
209+
@classmethod
210+
def create(cls) -> HandlerChain:
211+
unsupported: UnsupportedHandler = UnsupportedHandler()
212+
bytes_h: BytesHandler = BytesHandler(unsupported)
213+
uuid_h: UuidHandler = UuidHandler(bytes_h)
214+
decimal_h: DecimalHandler = DecimalHandler(uuid_h)
215+
dt_h: DateTimeHandler = DateTimeHandler(decimal_h)
216+
container_h: ContainerHandler = ContainerHandler(dt_h)
217+
primitive_h: PrimitiveHandler = PrimitiveHandler(container_h)
218+
219+
# Wire dispatchers to always go through the root
220+
container_h._dispatch_encode = primitive_h.encode # noqa: SLF001
221+
container_h._dispatch_decode = primitive_h.decode # noqa: SLF001
222+
223+
return cls(root=primitive_h, container=container_h)
224+
225+
226+
class EnvelopeSerDes(SerDes[T]):
227+
def __init__(self) -> None:
228+
self._chain: HandlerChain = HandlerChain.create()
229+
230+
def serialize(self, value: T, _: SerDesContext) -> str:
231+
wrapped: dict[str, Any] = self._chain.root.encode(value)
232+
return json.dumps(wrapped, separators=(",", ":"))
233+
234+
def deserialize(self, data: str, _: SerDesContext) -> T:
235+
obj: Any = json.loads(data)
236+
if not (isinstance(obj, dict) and "_" in obj and isinstance(obj["_"], dict)):
237+
msg = 'Malformed envelope: root must be {"_": {"t": ..., "v": ...}}.'
238+
raise TypeError(msg)
239+
inner: dict[str, Any] = obj["_"]
240+
if not (isinstance(inner, dict) and "t" in inner and "v" in inner):
241+
msg = 'Malformed envelope: missing "t" or "v" at root.'
242+
raise TypeError(msg)
243+
return self._chain.root.decode(inner["t"], inner["v"])
244+
245+
40246
_DEFAULT_JSON_SERDES: SerDes = JsonSerDes()
41247

42248

43249
def serialize(
44250
serdes: SerDes[T] | None, value: T, operation_id: str, durable_execution_arn: str
45251
) -> str:
46252
serdes_context: SerDesContext = SerDesContext(operation_id, durable_execution_arn)
47-
if serdes is None:
48-
serdes = _DEFAULT_JSON_SERDES
253+
active_serdes: SerDes[T] = serdes or _DEFAULT_JSON_SERDES
49254
try:
50-
return serdes.serialize(value, serdes_context)
255+
return active_serdes.serialize(value, serdes_context)
51256
except Exception as e:
52-
logger.exception(
53-
"⚠️ Serialization failed for id: %s",
54-
operation_id,
55-
)
56-
msg = f"Serialization failed for id: {operation_id}, error: {e}."
257+
logger.exception("⚠️ Serialization failed for id: %s", operation_id)
258+
msg: str = f"Serialization failed for id: {operation_id}, error: {e}."
57259
raise FatalError(msg) from e
58260

59261

60262
def deserialize(
61263
serdes: SerDes[T] | None, data: str, operation_id: str, durable_execution_arn: str
62264
) -> T:
63265
serdes_context: SerDesContext = SerDesContext(operation_id, durable_execution_arn)
64-
if serdes is None:
65-
serdes = _DEFAULT_JSON_SERDES
266+
active_serdes: SerDes[T] = serdes or _DEFAULT_JSON_SERDES
66267
try:
67-
return serdes.deserialize(data, serdes_context)
268+
return active_serdes.deserialize(data, serdes_context)
68269
except Exception as e:
69-
logger.exception(
70-
"⚠️ Deserialization failed for id: %s",
71-
operation_id,
72-
)
73-
msg = f"Deserialization failed for id: {operation_id}"
270+
logger.exception("⚠️ Deserialization failed for id: %s", operation_id)
271+
msg: str = f"Deserialization failed for id: {operation_id}"
74272
raise FatalError(msg) from e

0 commit comments

Comments
 (0)