Conversation
6a71ce0 to
24488ab
Compare
|
Note: JSON payload size could be optimized with type token compression (e.g., "datetime" → "d") but kept readable for maintainability. |
ghost
left a comment
There was a problem hiding this comment.
Overall this looks good to me, but dicts should allow primitive types as keys. As long as we can encode them and decode them back, it should be OK.
There was a problem hiding this comment.
I actually really liked the ideas from your 1st iteration - with the individual Codecs being responsible for serialising each type. This new change set does a lot more, but it's a lot more complicated 😅
How about just tweaking your elegant 1st design a bit so it handles nested types? Something like this (I wrote this in a hurry, so it needs some cleaning up: tighter type signatures, and the correct DurableExecution exceptions rather than just raising built-ins):
from __future__ import annotations
import json
import uuid
import base64
from decimal import Decimal
from datetime import datetime, date
from typing import Any, Protocol
from enum import StrEnum
from dataclasses import dataclass
#region TypeTag
class TypeTag(StrEnum):
NONE = "n"
STR = "s"
INT = "i"
FLOAT = "f"
BOOL = "b"
BYTES = "B"
UUID = "u"
DECIMAL = "d"
DATETIME = "dt"
DATE = "D"
TUPLE = "t"
LIST = "l"
DICT = "m"
#endregion
#region EncodedValue
@dataclass(frozen=True)
class EncodedValue:
    """Immutable pairing of a type tag with the payload encoded for it."""

    # Tells the decoder which type the payload represents.
    tag: TypeTag
    # The encoded payload; its shape depends on the tag.
    value: Any
#endregion
#region Codec Protocol
class Codec(Protocol):
    """Structural interface shared by the type-specific codecs."""

    def encode(self, obj: Any) -> EncodedValue:
        """Turn *obj* into a tagged ``EncodedValue``."""
        ...

    def decode(self, tag: TypeTag, value: Any) -> Any:
        """Rebuild the Python object described by *tag* and *value*."""
        ...
#endregion
#region PrimitiveCodec
class PrimitiveCodec(Codec):
    """Codec for ``None``, ``str``, ``bool``, ``int`` and ``float``."""

    def encode(self, obj: Any) -> EncodedValue:
        """Tag a primitive; the payload is the object itself.

        Raises:
            TypeError: if *obj* is not one of the supported primitive types.
        """
        if obj is None:
            return EncodedValue(TypeTag.NONE, None)
        if isinstance(obj, str):
            return EncodedValue(TypeTag.STR, obj)
        # bool is a subclass of int, so it has to be tested before int.
        if isinstance(obj, bool):
            return EncodedValue(TypeTag.BOOL, obj)
        if isinstance(obj, int):
            return EncodedValue(TypeTag.INT, obj)
        if isinstance(obj, float):
            return EncodedValue(TypeTag.FLOAT, obj)
        raise TypeError(f"Unsupported primitive type: {type(obj)}")

    def decode(self, tag: TypeTag, value: Any) -> Any:
        """Coerce *value* back to the primitive type named by *tag*.

        Raises:
            ValueError: if *tag* is not a primitive tag.
        """
        if tag == TypeTag.NONE:
            return None
        if tag == TypeTag.STR:
            return str(value)
        if tag == TypeTag.INT:
            return int(value)
        if tag == TypeTag.FLOAT:
            return float(value)
        if tag == TypeTag.BOOL:
            return bool(value)
        raise ValueError(f"Unknown primitive tag: {tag}")
#endregion
#region BytesCodec
class BytesCodec(Codec):
    """Codec that carries ``bytes`` as base64 text."""

    def encode(self, obj: bytes) -> EncodedValue:
        """Base64-encode *obj* and tag the resulting string as BYTES."""
        text = base64.b64encode(obj).decode("utf-8")
        return EncodedValue(TypeTag.BYTES, text)

    def decode(self, tag: TypeTag, value: Any) -> bytes:
        """Decode a base64 string back to ``bytes``.

        Raises:
            ValueError: if *tag* is not ``TypeTag.BYTES``.
        """
        if tag == TypeTag.BYTES:
            raw = value.encode("utf-8")
            return base64.b64decode(raw)
        raise ValueError(f"Invalid bytes tag: {tag}")
#endregion
#region UUIDCodec
class UuidCodec(Codec):
    """Codec that carries ``uuid.UUID`` values as their ``str()`` form."""

    def encode(self, obj: uuid.UUID) -> EncodedValue:
        """Tag the string form of *obj* as UUID."""
        text = str(obj)
        return EncodedValue(TypeTag.UUID, text)

    def decode(self, tag: TypeTag, value: Any) -> uuid.UUID:
        """Parse *value* back into a ``uuid.UUID``.

        Raises:
            ValueError: if *tag* is not ``TypeTag.UUID``.
        """
        if tag == TypeTag.UUID:
            return uuid.UUID(value)
        raise ValueError(f"Invalid UUID tag: {tag}")
#endregion
#region DecimalCodec
class DecimalCodec(Codec):
    """Codec that carries ``Decimal`` values as their ``str()`` form."""

    def encode(self, obj: Decimal) -> EncodedValue:
        """Tag the string form of *obj* as DECIMAL."""
        text = str(obj)
        return EncodedValue(TypeTag.DECIMAL, text)

    def decode(self, tag: TypeTag, value: Any) -> Decimal:
        """Build a ``Decimal`` from *value*.

        Raises:
            ValueError: if *tag* is not ``TypeTag.DECIMAL``.
        """
        if tag == TypeTag.DECIMAL:
            return Decimal(value)
        raise ValueError(f"Invalid decimal tag: {tag}")
#endregion
#region DateTimeCodec
class DateTimeCodec(Codec):
    """Codec that carries ``datetime`` and ``date`` values as ISO-format text."""

    def encode(self, obj: datetime | date) -> EncodedValue:
        """Tag ``obj.isoformat()`` as DATETIME or DATE.

        Raises:
            TypeError: if *obj* is neither a ``datetime`` nor a ``date``.
        """
        # datetime is a subclass of date, so it has to be tested first.
        if isinstance(obj, datetime):
            return EncodedValue(TypeTag.DATETIME, obj.isoformat())
        if isinstance(obj, date):
            return EncodedValue(TypeTag.DATE, obj.isoformat())
        raise TypeError(f"Unsupported date/time type: {type(obj)}")

    def decode(self, tag: TypeTag, value: Any) -> datetime | date:
        """Parse ISO-format text back into the type named by *tag*.

        Raises:
            ValueError: if *tag* is neither DATETIME nor DATE.
        """
        if tag == TypeTag.DATETIME:
            return datetime.fromisoformat(value)
        if tag == TypeTag.DATE:
            return date.fromisoformat(value)
        raise ValueError(f"Invalid date/time tag: {tag}")
#endregion
#region ContainerCodec
class ContainerCodec(Codec):
def __init__(self):
self._dispatcher: TypeCodec | None = None
def set_dispatcher(self, dispatcher: TypeCodec) -> None:
self._dispatcher = dispatcher
@property
def dispatcher(self) -> TypeCodec:
if self._dispatcher is None:
raise RuntimeError("ContainerCodec not linked to a TypeCodec dispatcher.")
return self._dispatcher
def encode(self, obj: list | tuple | dict) -> EncodedValue:
match obj:
case list():
return EncodedValue(TypeTag.LIST, [self._wrap(v) for v in obj])
case tuple():
return EncodedValue(TypeTag.TUPLE, [self._wrap(v) for v in obj])
case dict():
return EncodedValue(TypeTag.DICT, {self._wrap(k).value: self._wrap(v) for k, v in obj.items()})
case _:
raise TypeError(f"Unsupported container type: {type(obj)}")
def decode(self, tag: TypeTag, value: Any) -> Any:
match tag:
case TypeTag.LIST:
return [self._unwrap(v) for v in value]
case TypeTag.TUPLE:
return tuple(self._unwrap(v) for v in value)
case TypeTag.DICT:
return {k: self._unwrap(v) for k, v in value.items()}
case _:
raise ValueError(f"Unknown container tag: {tag}")
def _wrap(self, obj: Any) -> EncodedValue:
return self.dispatcher.encode(obj)
def _unwrap(self, obj: Any) -> Any:
match obj:
case EncodedValue():
return self.dispatcher.decode(obj.tag, obj.value)
case dict() if "t" in obj and "v" in obj:
tag = TypeTag(obj["t"])
return self.dispatcher.decode(tag, obj["v"])
case _:
return obj
#endregion
#region TypeCodec
class TypeCodec:
    """Dispatcher that routes a value (or a tag) to the codec responsible for it."""

    def __init__(self):
        containers = ContainerCodec()
        # The container codec encodes its elements through this dispatcher.
        containers.set_dispatcher(self)

        self.primitive_codec = PrimitiveCodec()
        self.bytes_codec = BytesCodec()
        self.uuid_codec = UuidCodec()
        self.decimal_codec = DecimalCodec()
        self.datetime_codec = DateTimeCodec()
        self.container_codec = containers

    def encode(self, obj: Any) -> EncodedValue:
        """Encode *obj* with the codec matching its type.

        Raises:
            TypeError: if no codec handles ``type(obj)``.
        """
        if obj is None or isinstance(obj, (str, bool, int, float)):
            return self.primitive_codec.encode(obj)
        if isinstance(obj, (bytes, bytearray, memoryview)):
            # All bytes-like inputs are normalised to bytes before encoding.
            return self.bytes_codec.encode(bytes(obj))
        if isinstance(obj, uuid.UUID):
            return self.uuid_codec.encode(obj)
        if isinstance(obj, Decimal):
            return self.decimal_codec.encode(obj)
        if isinstance(obj, (datetime, date)):
            return self.datetime_codec.encode(obj)
        if isinstance(obj, (list, tuple, dict)):
            return self.container_codec.encode(obj)
        raise TypeError(f"Unsupported type: {type(obj)}")

    def decode(self, tag: TypeTag, value: Any) -> Any:
        """Decode *value* with the codec that owns *tag*.

        Raises:
            ValueError: if *tag* is not a known type tag.
        """
        if tag in (TypeTag.NONE, TypeTag.STR, TypeTag.BOOL, TypeTag.INT, TypeTag.FLOAT):
            return self.primitive_codec.decode(tag, value)
        if tag == TypeTag.BYTES:
            return self.bytes_codec.decode(tag, value)
        if tag == TypeTag.UUID:
            return self.uuid_codec.decode(tag, value)
        if tag == TypeTag.DECIMAL:
            return self.decimal_codec.decode(tag, value)
        if tag in (TypeTag.DATETIME, TypeTag.DATE):
            return self.datetime_codec.decode(tag, value)
        if tag in (TypeTag.LIST, TypeTag.TUPLE, TypeTag.DICT):
            return self.container_codec.decode(tag, value)
        raise ValueError(f"Unknown type tag: {tag}")
#endregion
# Module-level dispatcher instance; ExtendedTypeSerDes stores it as its codec.
TYPE_CODEC = TypeCodec()
#region ExtendedTypeSerDes
class SerDesContext:
    """Empty context type; the serializer methods in this file accept and ignore it."""
class ExtendedTypeSerDes:
    """Serialize values to compact JSON ``{"t": tag, "v": payload}`` envelopes and back."""

    def __init__(self):
        self._codec = TYPE_CODEC

    def serialize(self, value: Any, _: SerDesContext = None) -> str:
        """Encode *value* and dump it as JSON without extra whitespace.

        The context argument is accepted but not used.
        """
        envelope = self._to_json_serializable(self._codec.encode(value))
        return json.dumps(envelope, separators=(",", ":"))

    def _to_json_serializable(self, obj: Any) -> Any:
        """Recursively replace every ``EncodedValue`` with a ``{"t", "v"}`` dict."""
        if isinstance(obj, EncodedValue):
            return {"t": obj.tag, "v": self._to_json_serializable(obj.value)}
        if isinstance(obj, list):
            return [self._to_json_serializable(item) for item in obj]
        if isinstance(obj, dict):
            return {key: self._to_json_serializable(item) for key, item in obj.items()}
        return obj

    def deserialize(self, data: str, _: SerDesContext = None) -> Any:
        """Parse a JSON envelope and decode it back to the original value.

        The context argument is accepted but not used.

        Raises:
            TypeError: if the JSON root is not an object with "t" and "v" keys.
        """
        envelope = json.loads(data)
        well_formed = isinstance(envelope, dict) and "t" in envelope and "v" in envelope
        if not well_formed:
            raise TypeError('Malformed envelope: missing "t" or "v" at root.')
        return self._codec.decode(TypeTag(envelope["t"]), envelope["v"])
#endregion
from decimal import Decimal
from datetime import datetime, date
from decimal import Decimal
from datetime import datetime, date
import uuid
# create a single serializer instance
serdes = ExtendedTypeSerDes()
# context object passed to every serialize/deserialize call in the tests below
ctx = SerDesContext()
# --- 1) Primitive types at root ---
def test_primitives():
    """Round-trip each supported scalar at the root; check both value and exact type."""
    primitives = [
        123,
        'hello',
        True,
        False,
        None,
        3.14,
        Decimal('10.5'),
        uuid.UUID("12345678-1234-5678-1234-567812345678"),
        b"bytes here",
        date(2025, 10, 22),
        datetime(2025, 10, 22, 15, 30, 0),
    ]
    for val in primitives:
        serialized = serdes.serialize(val, ctx)
        deserialized = serdes.deserialize(serialized, ctx)
        assert deserialized == val
        # `==` alone cannot tell True from 1 (or 1 from 1.0), so a codec that
        # loses the scalar's type would still pass; pin the exact type too.
        assert type(deserialized) is type(val)
# --- 2) Nested arrays/lists ---
def test_nested_arrays():
    """A list mixing scalars, bytes, a nested list and a dict survives a round trip."""
    original = [1, "two", [3, {"four": 4}], True, b"hi"]
    payload = serdes.serialize(original, ctx)
    restored = serdes.deserialize(payload, ctx)
    assert restored == original
# --- 3) Nested dicts ---
def test_nested_dicts():
    """A dict holding lists, dicts, bytes and a UUID survives a round trip."""
    original = {
        "a": 1,
        "b": [2, 3, {"c": 4}],
        "d": {"e": "f", "g": [5, 6]},
        "h": b"bytes in dict",
        "i": uuid.UUID("12345678-1234-5678-1234-567812345678"),
    }
    payload = serdes.serialize(original, ctx)
    restored = serdes.deserialize(payload, ctx)
    assert restored == original
# --- 4) Dict with t/v keys ---
def test_user_dict_with_t_v_keys():
    """A user dict whose own keys are "t" and "v" is not mistaken for an envelope."""
    original = {"t": "user t value", "v": "user v value"}
    payload = serdes.serialize(original, ctx)
    restored = serdes.deserialize(payload, ctx)
    assert restored == original
# --- 5) Deep nested structures ---
def test_complex_nested_structure():
    """A deep mix of lists, a tuple, dicts and extended scalars survives a round trip."""
    original = {
        "list": [1, 2, [3, 4], {"nested_bytes": b"abc"}],
        "tuple": (Decimal("3.14"), True),
        "dict": {
            "uuid": uuid.UUID("12345678-1234-5678-1234-567812345678"),
            "date": date(2025, 10, 22),
            "datetime": datetime(2025, 10, 22, 15, 30),
        },
        "t_v": {"t": "inner t", "v": [1, b"bytes"]},
    }
    payload = serdes.serialize(original, ctx)
    restored = serdes.deserialize(payload, ctx)
    assert restored == original
# --- 6) All nested t/v dicts ---
def test_all_t_v_nested_dicts():
    """User dicts shaped like envelopes at every nesting level survive a round trip."""
    original = {
        "t": {"t": "s", "v": "outer t"},
        "v": {
            "t": {"t": "s", "v": "inner t"},
            "v": {"t": {"t": "s", "v": "deep t"}, "v": "deep v"},
        },
    }
    payload = serdes.serialize(original, ctx)
    restored = serdes.deserialize(payload, ctx)
    assert restored == original
# --- Run all tests ---
if __name__ == "__main__":
    # Run every scenario in order; a failed assertion stops the run before
    # the summary line is printed.
    for scenario in (
        test_primitives,
        test_nested_arrays,
        test_nested_dicts,
        test_user_dict_with_t_v_keys,
        test_complex_nested_structure,
        test_all_t_v_nested_dicts,
    ):
        scenario()
    print("All tests passed!")
This uses recursion, which in Python is capped at a depth of 1000 by default. For the purposes of this SerDes that's probably fine. We could recode it to traverse with an explicit stack, but that's more complicated and something we can keep for another day, if it's even necessary.
Implement ExtendedTypesSerDes - Plain JSON for primitives (str, int, float, bool, None) - Plain JSON for simple lists containing only primitives - Extended format for complex types (datetime, Decimal, UUID, etc.) - Support datetime/date with ISO format encoding - Handle Decimal with string representation
Implement EnvelopeSerDes
Issue #, if available:
closes #46
Description of changes:
By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.