11"""Serialization and deserialization"""
22
3+ from __future__ import annotations
4+
5+ import base64
36import json
47import logging
8+ import uuid
59from abc import ABC , abstractmethod
610from dataclasses import dataclass
7- from typing import Generic , TypeVar
11+ from datetime import date , datetime
12+ from decimal import Decimal
13+ from typing import TYPE_CHECKING , Any , Generic , Protocol , TypeVar
14+
15+ if TYPE_CHECKING :
16+ from collections .abc import Callable
817
918from aws_durable_execution_sdk_python .exceptions import FatalError
1019
1322T = TypeVar ("T" )
1423
1524
25+ class TypeEncoder (Protocol ):
26+ """Protocol for objects that can encode and decode types."""
27+
28+ def encode (self , obj : Any ) -> dict [str , Any ]:
29+ """Encode an object to a dictionary representation."""
30+ ...
31+
32+ def decode (self , tag : str , value : Any ) -> Any :
33+ """Decode a tagged value back to an object."""
34+ ...
35+
36+
1637@dataclass (frozen = True )
1738class SerDesContext :
1839 operation_id : str
@@ -37,38 +58,215 @@ def deserialize(self, data: str, _: SerDesContext) -> T: # noqa: PLR6301
3758 return json .loads (data )
3859
3960
61+ class TypeHandler (TypeEncoder , ABC ):
62+ def __init__ (self , next_handler : TypeEncoder ) -> None :
63+ self ._next : TypeEncoder = next_handler
64+
65+
66+ class UnsupportedHandler (TypeEncoder ):
67+ def encode (self , obj : Any ) -> dict [str , Any ]: # noqa: PLR6301
68+ msg = f"Unsupported type: { type (obj )!r} "
69+ raise TypeError (msg )
70+
71+ def decode (self , tag : str , value : Any ) -> Any : # noqa: PLR6301, ARG002
72+ msg = f"Unknown type tag: { tag !r} "
73+ raise ValueError (msg )
74+
75+
76+ class BytesHandler (TypeHandler ):
77+ def encode (self , obj : Any ) -> dict [str , Any ]:
78+ if isinstance (obj , bytes | bytearray | memoryview ):
79+ encoded : str = base64 .b64encode (bytes (obj )).decode ()
80+ return {"_" : {"t" : "bytes" , "v" : encoded }}
81+ return self ._next .encode (obj )
82+
83+ def decode (self , tag : str , value : Any ) -> Any :
84+ if tag == "bytes" :
85+ return base64 .b64decode (value )
86+ return self ._next .decode (tag , value )
87+
88+
89+ class UuidHandler (TypeHandler ):
90+ def encode (self , obj : Any ) -> dict [str , Any ]:
91+ if isinstance (obj , uuid .UUID ):
92+ return {"_" : {"t" : "uuid" , "v" : str (obj )}}
93+ return self ._next .encode (obj )
94+
95+ def decode (self , tag : str , value : Any ) -> Any :
96+ if tag == "uuid" :
97+ return uuid .UUID (value )
98+ return self ._next .decode (tag , value )
99+
100+
101+ class DecimalHandler (TypeHandler ):
102+ def encode (self , obj : Any ) -> dict [str , Any ]:
103+ if isinstance (obj , Decimal ):
104+ return {"_" : {"t" : "decimal" , "v" : str (obj )}}
105+ return self ._next .encode (obj )
106+
107+ def decode (self , tag : str , value : Any ) -> Any :
108+ if tag == "decimal" :
109+ return Decimal (value )
110+ return self ._next .decode (tag , value )
111+
112+
113+ class DateTimeHandler (TypeHandler ):
114+ def encode (self , obj : Any ) -> dict [str , Any ]:
115+ if isinstance (obj , datetime ):
116+ return {"_" : {"t" : "datetime" , "v" : obj .isoformat ()}}
117+ if isinstance (obj , date ):
118+ return {"_" : {"t" : "date" , "v" : obj .isoformat ()}}
119+ return self ._next .encode (obj )
120+
121+ def decode (self , tag : str , value : Any ) -> Any :
122+ if tag == "datetime" :
123+ return datetime .fromisoformat (value )
124+ if tag == "date" :
125+ return date .fromisoformat (value )
126+ return self ._next .decode (tag , value )
127+
128+
129+ class ContainerHandler (TypeHandler ):
130+ def __init__ (self , next_handler : TypeEncoder ) -> None :
131+ super ().__init__ (next_handler )
132+ self ._dispatch_encode : Callable [[Any ], dict [str , Any ]] | None = None
133+ self ._dispatch_decode : Callable [[str , Any ], Any ] | None = None
134+
135+ def _enc (self , obj : Any ) -> dict [str , Any ]:
136+ if self ._dispatch_encode is None :
137+ msg = "ContainerHandler not initialized with encode dispatcher."
138+ raise RuntimeError (msg )
139+ return self ._dispatch_encode (obj )
140+
141+ def _dec (self , tag : str , value : Any ) -> Any :
142+ if self ._dispatch_decode is None :
143+ msg = "ContainerHandler not initialized with decode dispatcher."
144+ raise RuntimeError (msg )
145+ return self ._dispatch_decode (tag , value )
146+
147+ def encode (self , obj : Any ) -> dict [str , Any ]:
148+ if isinstance (obj , tuple ):
149+ items : list [dict [str , Any ]] = [self ._enc (x ) for x in obj ]
150+ return {"_" : {"t" : "tuple" , "v" : items }}
151+ if isinstance (obj , list ):
152+ items_list : list [dict [str , Any ]] = [self ._enc (x ) for x in obj ]
153+ return {"_" : {"t" : "list" , "v" : items_list }}
154+ if isinstance (obj , dict ):
155+ self ._validate_dict_keys (obj )
156+ wrapped : dict [str , dict [str , Any ]] = {
157+ k : self ._enc (v ) for k , v in obj .items ()
158+ }
159+ return {"_" : {"t" : "dict" , "v" : wrapped }}
160+ return self ._next .encode (obj )
161+
162+ def decode (self , tag : str , value : Any ) -> Any :
163+ if tag == "tuple" :
164+ if not isinstance (value , list ):
165+ msg = 'Malformed envelope: "tuple" expects array value.'
166+ raise TypeError (msg )
167+ return tuple (self ._dec (v ["_" ]["t" ], v ["_" ]["v" ]) for v in value )
168+ if tag == "list" :
169+ if not isinstance (value , list ):
170+ msg = 'Malformed envelope: "list" expects array value.'
171+ raise TypeError (msg )
172+ return [self ._dec (v ["_" ]["t" ], v ["_" ]["v" ]) for v in value ]
173+ if tag == "dict" :
174+ if not isinstance (value , dict ):
175+ msg = 'Malformed envelope: "dict" expects object value.'
176+ raise TypeError (msg )
177+ return {k : self ._dec (v ["_" ]["t" ], v ["_" ]["v" ]) for k , v in value .items ()}
178+ return self ._next .decode (tag , value )
179+
180+ @staticmethod
181+ def _validate_dict_keys (mapping : dict [Any , Any ]) -> None :
182+ bad : list [Any ] = [k for k in mapping if not isinstance (k , str )]
183+ if bad :
184+ ex : Any = bad [0 ]
185+ msg = f"Unsupported mapping key type: { type (ex )!r} . JSON object keys must be strings."
186+ raise TypeError (msg )
187+
188+
189+ class PrimitiveHandler (TypeHandler ):
190+ def encode (self , obj : Any ) -> dict [str , Any ]:
191+ if obj is None or isinstance (obj , str | int | float | bool ):
192+ tag : str = type (obj ).__name__
193+ return {"_" : {"t" : tag , "v" : obj }}
194+ return self ._next .encode (obj )
195+
196+ def decode (self , tag : str , value : Any ) -> Any :
197+ if tag == "NoneType" :
198+ return None
199+ if tag in {"str" , "int" , "float" , "bool" }:
200+ return value
201+ return self ._next .decode (tag , value )
202+
203+
204+ @dataclass (frozen = True )
205+ class HandlerChain :
206+ root : TypeHandler
207+ container : ContainerHandler
208+
209+ @classmethod
210+ def create (cls ) -> HandlerChain :
211+ unsupported : UnsupportedHandler = UnsupportedHandler ()
212+ bytes_h : BytesHandler = BytesHandler (unsupported )
213+ uuid_h : UuidHandler = UuidHandler (bytes_h )
214+ decimal_h : DecimalHandler = DecimalHandler (uuid_h )
215+ dt_h : DateTimeHandler = DateTimeHandler (decimal_h )
216+ container_h : ContainerHandler = ContainerHandler (dt_h )
217+ primitive_h : PrimitiveHandler = PrimitiveHandler (container_h )
218+
219+ # Wire dispatchers to always go through the root
220+ container_h ._dispatch_encode = primitive_h .encode # noqa: SLF001
221+ container_h ._dispatch_decode = primitive_h .decode # noqa: SLF001
222+
223+ return cls (root = primitive_h , container = container_h )
224+
225+
226+ class EnvelopeSerDes (SerDes [T ]):
227+ def __init__ (self ) -> None :
228+ self ._chain : HandlerChain = HandlerChain .create ()
229+
230+ def serialize (self , value : T , _ : SerDesContext ) -> str :
231+ wrapped : dict [str , Any ] = self ._chain .root .encode (value )
232+ return json .dumps (wrapped , separators = ("," , ":" ))
233+
234+ def deserialize (self , data : str , _ : SerDesContext ) -> T :
235+ obj : Any = json .loads (data )
236+ if not (isinstance (obj , dict ) and "_" in obj and isinstance (obj ["_" ], dict )):
237+ msg = 'Malformed envelope: root must be {"_": {"t": ..., "v": ...}}.'
238+ raise TypeError (msg )
239+ inner : dict [str , Any ] = obj ["_" ]
240+ if not (isinstance (inner , dict ) and "t" in inner and "v" in inner ):
241+ msg = 'Malformed envelope: missing "t" or "v" at root.'
242+ raise TypeError (msg )
243+ return self ._chain .root .decode (inner ["t" ], inner ["v" ])
244+
245+
40246_DEFAULT_JSON_SERDES : SerDes = JsonSerDes ()
41247
42248
43249def serialize (
44250 serdes : SerDes [T ] | None , value : T , operation_id : str , durable_execution_arn : str
45251) -> str :
46252 serdes_context : SerDesContext = SerDesContext (operation_id , durable_execution_arn )
47- if serdes is None :
48- serdes = _DEFAULT_JSON_SERDES
253+ active_serdes : SerDes [T ] = serdes or _DEFAULT_JSON_SERDES
49254 try :
50- return serdes .serialize (value , serdes_context )
255+ return active_serdes .serialize (value , serdes_context )
51256 except Exception as e :
52- logger .exception (
53- "⚠️ Serialization failed for id: %s" ,
54- operation_id ,
55- )
56- msg = f"Serialization failed for id: { operation_id } , error: { e } ."
257+ logger .exception ("⚠️ Serialization failed for id: %s" , operation_id )
258+ msg : str = f"Serialization failed for id: { operation_id } , error: { e } ."
57259 raise FatalError (msg ) from e
58260
59261
60262def deserialize (
61263 serdes : SerDes [T ] | None , data : str , operation_id : str , durable_execution_arn : str
62264) -> T :
63265 serdes_context : SerDesContext = SerDesContext (operation_id , durable_execution_arn )
64- if serdes is None :
65- serdes = _DEFAULT_JSON_SERDES
266+ active_serdes : SerDes [T ] = serdes or _DEFAULT_JSON_SERDES
66267 try :
67- return serdes .deserialize (data , serdes_context )
268+ return active_serdes .deserialize (data , serdes_context )
68269 except Exception as e :
69- logger .exception (
70- "⚠️ Deserialization failed for id: %s" ,
71- operation_id ,
72- )
73- msg = f"Deserialization failed for id: { operation_id } "
270+ logger .exception ("⚠️ Deserialization failed for id: %s" , operation_id )
271+ msg : str = f"Deserialization failed for id: { operation_id } "
74272 raise FatalError (msg ) from e
0 commit comments