diff --git a/src/aws_durable_execution_sdk_python/exceptions.py b/src/aws_durable_execution_sdk_python/exceptions.py index 46f04e9..f961dba 100644 --- a/src/aws_durable_execution_sdk_python/exceptions.py +++ b/src/aws_durable_execution_sdk_python/exceptions.py @@ -168,3 +168,7 @@ def __str__(self) -> str: A string in the format "type: message" """ return f"{self.type}: {self.message}" + + +class SerDesError(DurableExecutionsError): + """Raised when serialization fails.""" diff --git a/src/aws_durable_execution_sdk_python/serdes.py b/src/aws_durable_execution_sdk_python/serdes.py index 8972800..e17cd56 100644 --- a/src/aws_durable_execution_sdk_python/serdes.py +++ b/src/aws_durable_execution_sdk_python/serdes.py @@ -1,22 +1,341 @@ -"""Serialization and deserialization""" +"""Codec-based serialization and deserialization for Python types. +This module provides comprehensive serialization support using a codec-based +architecture with recursive encoding/decoding for nested structures. + +Key Features: +- Plain JSON for primitives and simple lists (performance optimization) +- Envelope format with type tags for complex types +- Modular codec architecture +- Recursive handling of nested structures + +Serialization Strategy: +- Primitives (None, str, int, float, bool): Plain JSON +- Simple lists containing only primitives: Plain JSON +- Everything else: Envelope format with type tags + +Wire Formats: + Plain JSON: 42, "hello", [1, 2, 3] + Envelope: {"t": "TAG", "v": VALUE}, e.g. {"t": "d", "v": "3.14"} for Decimal("3.14") +""" + +from __future__ import annotations + +import base64 import json import logging +import uuid from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Generic, TypeVar +from datetime import date, datetime +from decimal import Decimal +from enum import StrEnum +from typing import Any, Generic, Protocol, TypeVar -from aws_durable_execution_sdk_python.exceptions import FatalError +from aws_durable_execution_sdk_python.exceptions import ( + DurableExecutionsError, + 
class PrimitiveCodec:
    """Codec for ``None``, ``str``, ``bool``, ``int`` and ``float`` values."""

    def encode(self, obj: Any) -> EncodedValue:  # noqa: PLR6301
        """Pair a primitive with the tag of its type.

        Args:
            obj: The value to encode.

        Returns:
            An ``EncodedValue`` whose payload is ``obj`` itself.

        Raises:
            SerDesError: If ``obj`` is not one of the supported primitives.
        """
        if obj is None:
            return EncodedValue(TypeTag.NONE, None)

        # Order matters: bool is a subclass of int, so it is tested first.
        candidates = (
            (str, TypeTag.STR),
            (bool, TypeTag.BOOL),
            (int, TypeTag.INT),
            (float, TypeTag.FLOAT),
        )
        for python_type, type_tag in candidates:
            if isinstance(obj, python_type):
                return EncodedValue(type_tag, obj)

        msg = f"Unsupported primitive type: {type(obj)!r}"
        raise SerDesError(msg)

    def decode(self, tag: TypeTag, value: Any) -> Any:  # noqa: PLR6301
        """Convert a tagged payload back into the primitive named by ``tag``.

        Args:
            tag: The type tag stored alongside the payload.
            value: The payload to convert.

        Returns:
            ``None`` for the NONE tag, otherwise ``value`` passed through the
            constructor matching the tag (``str``, ``bool``, ``int``, ``float``).

        Raises:
            SerDesError: If ``tag`` is not a primitive tag.
        """
        if tag == TypeTag.NONE:
            return None

        converters = (
            (TypeTag.STR, str),
            (TypeTag.BOOL, bool),
            (TypeTag.INT, int),
            (TypeTag.FLOAT, float),
        )
        for known_tag, convert in converters:
            if tag == known_tag:
                return convert(value)

        msg = f"Unknown primitive tag: {tag}"
        raise SerDesError(msg)
base64.b64decode(value.encode("utf-8")) + + +class UuidCodec: + """Codec for UUID objects.""" + + def encode(self, obj: Any) -> EncodedValue: # noqa: PLR6301 + return EncodedValue(TypeTag.UUID, str(obj)) + + def decode(self, tag: TypeTag, value: Any) -> Any: # noqa: PLR6301 + if tag != TypeTag.UUID: + msg = f"Expected UUID tag, got {tag}" + raise SerDesError(msg) + return uuid.UUID(value) + + +class DecimalCodec: + """Codec for Decimal objects.""" + + def encode(self, obj: Any) -> EncodedValue: # noqa: PLR6301 + return EncodedValue(TypeTag.DECIMAL, str(obj)) + + def decode(self, tag: TypeTag, value: Any) -> Any: # noqa: PLR6301 + if tag != TypeTag.DECIMAL: + msg = f"Expected DECIMAL tag, got {tag}" + raise SerDesError(msg) + return Decimal(value) + + +class DateTimeCodec: + """Codec for datetime and date objects.""" + + def encode(self, obj: Any) -> EncodedValue: # noqa: PLR6301 + match obj: + case datetime(): + return EncodedValue(TypeTag.DATETIME, obj.isoformat()) + case date(): + return EncodedValue(TypeTag.DATE, obj.isoformat()) + case _: + msg = f"Unsupported datetime type: {type(obj)!r}" + raise SerDesError(msg) + + def decode(self, tag: TypeTag, value: Any) -> Any: # noqa: PLR6301 + match tag: + case TypeTag.DATETIME: + # Handle Z suffix for UTC + s = value + if isinstance(s, str) and s.endswith("Z"): + s = s[:-1] + "+00:00" + return datetime.fromisoformat(s) + case TypeTag.DATE: + return date.fromisoformat(value) + case _: + msg = f"Unknown datetime tag: {tag}" + raise SerDesError(msg) + + +class ContainerCodec(Codec): + """Codec for container types with recursive encoding/decoding.""" + + def __init__(self) -> None: + self._dispatcher: TypeCodec | None = None + + def set_dispatcher(self, dispatcher) -> None: + """Set the main codec dispatcher for recursive encoding.""" + self._dispatcher = dispatcher + + @property + def dispatcher(self): + """Get the dispatcher, raising error if not set.""" + if self._dispatcher is None: + msg = "ContainerCodec not linked 
to a TypeCodec dispatcher." + raise DurableExecutionsError(msg) + return self._dispatcher + + def encode(self, obj: Any) -> EncodedValue: + """Encode container using dispatcher for recursive elements.""" + match obj: + case list(): + return EncodedValue( + TypeTag.LIST, [self._wrap(v, self.dispatcher) for v in obj] + ) + case tuple(): + return EncodedValue( + TypeTag.TUPLE, [self._wrap(v, self.dispatcher) for v in obj] + ) + case dict(): + for k in obj: + if isinstance(k, tuple): + msg = "Tuple keys not supported" + raise SerDesError(msg) + return EncodedValue( + TypeTag.DICT, + {k: self._wrap(v, self.dispatcher) for k, v in obj.items()}, + ) + case _: + msg = f"Unsupported container type: {type(obj)!r}" + raise SerDesError(msg) + + def decode(self, tag: TypeTag, value: Any) -> Any: + """Decode container using dispatcher for recursive elements.""" + match tag: + case TypeTag.LIST: + if not isinstance(value, list): + msg = f"Expected list, got {type(value)}" + raise SerDesError(msg) + return [self._unwrap(v, self.dispatcher) for v in value] + case TypeTag.TUPLE: + if not isinstance(value, list): + msg = f"Expected list, got {type(value)}" + raise SerDesError(msg) + return tuple(self._unwrap(v, self.dispatcher) for v in value) + case TypeTag.DICT: + if not isinstance(value, dict): + msg = f"Expected dict, got {type(value)}" + raise SerDesError(msg) + return {k: self._unwrap(v, self.dispatcher) for k, v in value.items()} + case _: + msg = f"Unknown container tag: {tag}" + raise SerDesError(msg) + + @staticmethod + def _wrap(obj: Any, dispatcher) -> EncodedValue: + """Wrap object using dispatcher.""" + return dispatcher.encode(obj) + + @staticmethod + def _unwrap(obj: Any, dispatcher) -> Any: + """Unwrap object using dispatcher.""" + match obj: + case EncodedValue(): + return dispatcher.decode(obj.tag, obj.value) + case dict() if TYPE_TOKEN in obj and VALUE_TOKEN in obj: + tag = TypeTag(obj[TYPE_TOKEN]) + return dispatcher.decode(tag, obj[VALUE_TOKEN]) + case _: + 
return obj + + +class TypeCodec(Codec): + """Main codec dispatcher.""" + + def __init__(self): + self.primitive_codec = PrimitiveCodec() + self.bytes_codec = BytesCodec() + self.uuid_codec = UuidCodec() + self.decimal_codec = DecimalCodec() + self.datetime_codec = DateTimeCodec() + self.container_codec = ContainerCodec() + self.container_codec.set_dispatcher(self) + + def encode(self, obj: Any) -> EncodedValue: + match obj: + case None | str() | bool() | int() | float(): + return self.primitive_codec.encode(obj) + case bytes() | bytearray() | memoryview(): + return self.bytes_codec.encode(bytes(obj)) + case uuid.UUID(): + return self.uuid_codec.encode(obj) + case Decimal(): + return self.decimal_codec.encode(obj) + case datetime() | date(): + return self.datetime_codec.encode(obj) + case list() | tuple() | dict(): + return self.container_codec.encode(obj) + case _: + msg = f"Unsupported type: {type(obj)}" + raise SerDesError(msg) + + def decode(self, tag: TypeTag, value: Any) -> Any: + match tag: + case ( + TypeTag.NONE + | TypeTag.STR + | TypeTag.BOOL + | TypeTag.INT + | TypeTag.FLOAT + ): + return self.primitive_codec.decode(tag, value) + case TypeTag.BYTES: + return self.bytes_codec.decode(tag, value) + case TypeTag.UUID: + return self.uuid_codec.decode(tag, value) + case TypeTag.DECIMAL: + return self.decimal_codec.decode(tag, value) + case TypeTag.DATETIME | TypeTag.DATE: + return self.datetime_codec.decode(tag, value) + case TypeTag.LIST | TypeTag.TUPLE | TypeTag.DICT: + return self.container_codec.decode(tag, value) + case _: + msg = f"Unknown type tag: {tag}" + raise SerDesError(msg) + + +TYPE_CODEC = TypeCodec() + + +# endregion + @dataclass(frozen=True) class SerDesContext: - operation_id: str - durable_execution_arn: str + """Context for serialization operations.""" + + operation_id: str = "" + + durable_execution_arn: str = "" class SerDes(ABC, Generic[T]): @@ -28,6 +347,15 @@ def serialize(self, value: T, serdes_context: SerDesContext) -> str: def 
deserialize(self, data: str, serdes_context: SerDesContext) -> T: pass + @staticmethod + def is_primitive(obj: Any) -> bool: + """Check if object contains only JSON-serializable primitives.""" + if obj is None or isinstance(obj, str | int | float | bool): + return True + if isinstance(obj, list): + return all(SerDes.is_primitive(item) for item in obj) + return False + class JsonSerDes(SerDes[T]): def serialize(self, value: T, _: SerDesContext) -> str: # noqa: PLR6301 @@ -37,38 +365,108 @@ def deserialize(self, data: str, _: SerDesContext) -> T: # noqa: PLR6301 return json.loads(data) -_DEFAULT_JSON_SERDES: SerDes = JsonSerDes() +class ExtendedTypeSerDes(SerDes[T]): + """Main serializer class.""" + + def __init__(self): + self._codec = TYPE_CODEC + + def serialize(self, value: Any, context: SerDesContext | None = None) -> str: # noqa: ARG002 + """Serialize value to JSON string.""" + # Fast path for primitives + if SerDes.is_primitive(value): + return json.dumps(value, separators=(",", ":")) + + encoded = self._codec.encode(value) + wrapped = self._to_json_serializable(encoded) + return json.dumps(wrapped, separators=(",", ":")) + + def deserialize(self, data: str, context: SerDesContext | None = None) -> Any: # noqa: ARG002 + """Deserialize JSON string to Python object.""" + obj = json.loads(data) + + # Fast path for primitives + if SerDes.is_primitive(obj): + return obj + + if not (isinstance(obj, dict) and TYPE_TOKEN in obj and VALUE_TOKEN in obj): + msg = 'Malformed envelope: missing "t" or "v" at root.' 
+ raise SerDesError(msg) + if obj[TYPE_TOKEN] not in TypeTag: + msg = f'Unknown type tag: "{obj[TYPE_TOKEN]}"' + raise SerDesError(msg) + tag = TypeTag(obj[TYPE_TOKEN]) + return self._codec.decode(tag, obj[VALUE_TOKEN]) + + def _to_json_serializable(self, obj: Any) -> Any: + """Convert EncodedValue objects to JSON-serializable format.""" + match obj: + case EncodedValue(): + return { + TYPE_TOKEN: obj.tag, + VALUE_TOKEN: self._to_json_serializable(obj.value), + } + case list(): + return [self._to_json_serializable(x) for x in obj] + case dict(): + return {k: self._to_json_serializable(v) for k, v in obj.items()} + case _: + return obj + + +_DEFAULT_JSON_SERDES: SerDes[Any] = JsonSerDes() +_EXTENDED_TYPES_SERDES: SerDes[Any] = ExtendedTypeSerDes() def serialize( serdes: SerDes[T] | None, value: T, operation_id: str, durable_execution_arn: str ) -> str: + """Serialize value using provided or default serializer. + + Args: + serdes: Custom serializer or None for default + value: Object to serialize + operation_id: Unique operation identifier + durable_execution_arn: ARN of durable execution + + Returns: + Serialized string representation + + Raises: + FatalError: If serialization fails + """ serdes_context: SerDesContext = SerDesContext(operation_id, durable_execution_arn) - if serdes is None: - serdes = _DEFAULT_JSON_SERDES + active_serdes: SerDes[T] = serdes or _EXTENDED_TYPES_SERDES try: - return serdes.serialize(value, serdes_context) + return active_serdes.serialize(value, serdes_context) except Exception as e: - logger.exception( - "⚠️ Serialization failed for id: %s", - operation_id, - ) - msg = f"Serialization failed for id: {operation_id}, error: {e}." 
def deserialize(
    serdes: SerDes[T] | None, data: str, operation_id: str, durable_execution_arn: str
) -> T:
    """Deserialize data using provided or default serializer.

    Args:
        serdes: Custom serializer or None for default
        data: Serialized string data
        operation_id: Unique operation identifier
        durable_execution_arn: ARN of durable execution

    Returns:
        Deserialized Python object

    Raises:
        FatalError: If deserialization fails
    """
    serdes_context: SerDesContext = SerDesContext(operation_id, durable_execution_arn)
    # Only None selects the default. A truthiness test would silently discard
    # a custom serializer whose instance happens to evaluate as falsy (for
    # example one that defines __len__ or __bool__).
    active_serdes: SerDes[T] = _EXTENDED_TYPES_SERDES if serdes is None else serdes
    try:
        return active_serdes.deserialize(data, serdes_context)
    except Exception as e:
        logger.exception("⚠️ Deserialization failed for id: %s", operation_id)
        msg = f"Deserialization failed for id: {operation_id}"
        raise FatalError(msg) from e
_rec_deserialize(self, value: Any) -> Any: return value +# region Abstract SerDes Tests def test_serdes_abstract(): """Test SerDes abstract base class.""" @@ -94,6 +115,10 @@ def test_serdes_abstract_methods_coverage(): SerDes.deserialize(None, None, None) # Covers line 104 +# endregion + + +# region JsonSerDes Tests def test_serialize_invalid_json(): circular_ref = {"a": 1} circular_ref["self"] = circular_ref @@ -112,21 +137,26 @@ def test_deserialize_invalid_json(): def test_none_serdes_context(): data = {"test": "value"} result = serialize(None, data, None, None) - assert json.loads(result) == data + # Dict uses envelope format, so roundtrip through deserialize + deserialized = deserialize(None, result, None, None) + assert deserialized == data def test_default_json_serialization(): data = {"name": "test", "value": 123} serialized = serialize(None, data, "test-op", "test-arn") assert isinstance(serialized, str) - assert json.loads(serialized) == data + # Dict uses envelope format, so roundtrip through deserialize + deserialized = deserialize(None, serialized, "test-op", "test-arn") + assert deserialized == data def test_default_json_deserialization(): - data = '{"name": "test", "value": 123}' + # Use a simple list that can be plain JSON + data = "[1, 2, 3]" deserialized = deserialize(None, data, "test-op", "test-arn") - assert isinstance(deserialized, dict) - assert deserialized == {"name": "test", "value": 123} + assert isinstance(deserialized, list) + assert deserialized == [1, 2, 3] def test_default_json_roundtrip(): @@ -136,6 +166,10 @@ def test_default_json_roundtrip(): assert deserialized == original +# endregion + + +# region Custom SerDes Tests def test_custom_str_serdes_serialization(): result = serialize(CustomStrSerDes(), "hello world", "test-op", "test-arn") assert result == "HELLO WORLD" @@ -197,3 +231,666 @@ def deserialize(self, data: str, serdes_context: SerDesContext) -> str: assert serialized == "data" + "test-arn" deserialized = 
def test_envelope_float_roundtrip():
    """Finite and infinite floats survive a round trip as equal floats.

    NaN is left out on purpose: it never compares equal to itself, so it has
    its own test (test_envelope_float_nan_roundtrip).
    """
    values = [0.0, 1.5, -math.pi, 1e10, -1e-10, float("inf"), float("-inf")]
    for val in values:
        result = _roundtrip_envelope(val)
        # An int such as 0 compares equal to 0.0, so check the type as well.
        assert isinstance(result, float)
        assert result == val
def test_envelope_bytearray_roundtrip():
    """A bytearray round-trips to an equal value of type ``bytes``."""
    val = bytearray(b"hello world")
    result = _roundtrip_envelope(val)
    # bytearray compares equal to bytes, so equality alone cannot show that
    # the decoded value is bytes rather than bytearray.
    assert type(result) is bytes
    assert result == b"hello world"
test_envelope_deeply_nested_structure(): + complex_data = { + "user": { + "id": uuid.uuid4(), + "created": datetime.now(UTC), + "balance": Decimal("1234.56"), + "metadata": b"binary_data", + "coordinates": (40.7128, -74.0060), + "tags": ["premium", "verified"], + "settings": { + "notifications": True, + "theme": "dark", + "limits": { + "daily": Decimal("500.00"), + "monthly": Decimal("10000.00"), + }, + }, + }, + "session": { + "started": datetime.now(UTC), + "expires": date.today(), # noqa: DTZ011 + "token": uuid.uuid4(), + }, + } + assert _roundtrip_envelope(complex_data) == complex_data + + +def test_envelope_mixed_type_collections(): + mixed_list = [ + None, + True, + 42, + math.pi, + "string", + datetime.now(UTC), + Decimal("99.99"), + uuid.uuid4(), + b"bytes", + (1, 2, 3), + [4, 5, 6], + {"key": "value"}, + ] + assert _roundtrip_envelope(mixed_list) == mixed_list + + +def test_envelope_tuple_with_all_types(): + all_types_tuple = ( + None, + True, + 42, + math.pi, + "string", + datetime(2024, 1, 1, tzinfo=UTC), + date(2024, 1, 1), + Decimal("123.45"), + uuid.uuid4(), + b"binary", + [1, 2, 3], + {"nested": "dict"}, + ) + assert _roundtrip_envelope(all_types_tuple) == all_types_tuple + + +# endregion + + +# region EnvelopeSerDes Error Cases +def test_envelope_unsupported_type_error(): + serdes = ExtendedTypeSerDes() + context = SerDesContext("test-op", "test-arn") + with pytest.raises(SerDesError, match="Unsupported type: "): + serdes.serialize(object(), context) + + +# endregion + + +# region EnvelopeSerDes Format Validation +def test_envelope_format_structure(): + serdes = ExtendedTypeSerDes() + context = SerDesContext("test-op", "test-arn") + # Dict will use envelope format, primitives use plain JSON + serialized = serdes.serialize({"test": "value"}, context) + parsed = json.loads(serialized) + + # Verify envelope structure + assert "t" in parsed + assert "v" in parsed + assert parsed["t"] == "m" # dict tag + assert parsed["v"]["test"]["v"] == "value" + + 
+def test_envelope_compact_json_output(): + serdes = ExtendedTypeSerDes() + context = SerDesContext("test-op", "test-arn") + serialized = serdes.serialize({"key": "value"}, context) + # Should not contain extra whitespace + assert " " not in serialized + assert "\n" not in serialized + + +def test_envelope_bytes_base64_encoding(): + serdes = ExtendedTypeSerDes() + context = SerDesContext("test-op", "test-arn") + test_bytes = b"hello world" + serialized = serdes.serialize(test_bytes, context) + parsed = json.loads(serialized) + + # Verify base64 encoding + encoded_value = parsed["v"] + assert base64.b64decode(encoded_value) == test_bytes + + +# endregion + + +# region EnvelopeSerDes Integration Tests +def test_envelope_with_main_api(): + """Test EnvelopeSerDes works with main serialize/deserialize functions.""" + envelope_serdes = ExtendedTypeSerDes() + + test_data = { + "id": uuid.uuid4(), + "timestamp": datetime.now(UTC), + "amount": Decimal("123.45"), + "data": b"binary_data", + "coordinates": (40.7128, -74.0060), + "tags": ["important", "verified"], + } + + # Serialize with EnvelopeSerDes + serialized = serialize(envelope_serdes, test_data, "test-op", "test-arn") + + # Deserialize with EnvelopeSerDes + deserialized = deserialize(envelope_serdes, serialized, "test-op", "test-arn") + + assert deserialized == test_data + + +def test_envelope_vs_json_serdes_compatibility(): + """Test that EnvelopeSerDes and JsonSerDes can coexist.""" + json_serdes = JsonSerDes() + envelope_serdes = ExtendedTypeSerDes() + + # Simple data that both can handle + simple_data = {"name": "test", "value": 123, "active": True} + + # Both should serialize successfully + json_result = serialize(json_serdes, simple_data, "test-op", "test-arn") + envelope_result = serialize(envelope_serdes, simple_data, "test-op", "test-arn") + + # Results should be different (envelope has wrapper) + assert json_result != envelope_result + + # Both should deserialize to same data + json_deserialized = 
deserialize(json_serdes, json_result, "test-op", "test-arn") + envelope_deserialized = deserialize( + envelope_serdes, envelope_result, "test-op", "test-arn" + ) + + assert json_deserialized == simple_data + assert envelope_deserialized == simple_data + + +def test_envelope_handles_json_incompatible_types(): + """Test that EnvelopeSerDes handles types that JsonSerDes cannot.""" + json_serdes = JsonSerDes() + envelope_serdes = ExtendedTypeSerDes() + + # Data with types JsonSerDes cannot handle + complex_data = { + "uuid": uuid.uuid4(), + "decimal": Decimal("123.45"), + "bytes": b"binary", + "tuple": (1, 2, 3), + } + + # JsonSerDes should fail + with pytest.raises(FatalError): + serialize(json_serdes, complex_data, "test-op", "test-arn") + + # EnvelopeSerDes should succeed + serialized = serialize(envelope_serdes, complex_data, "test-op", "test-arn") + deserialized = deserialize(envelope_serdes, serialized, "test-op", "test-arn") + + assert deserialized == complex_data + + +def test_envelope_error_handling_with_main_api(): + """Test error handling when using EnvelopeSerDes with main API.""" + envelope_serdes = ExtendedTypeSerDes() + + # Test serialization error + with pytest.raises(FatalError, match="Serialization failed"): + serialize(envelope_serdes, object(), "test-op", "test-arn") + + # Test deserialization error + with pytest.raises(FatalError, match="Deserialization failed"): + deserialize(envelope_serdes, "invalid json", "test-op", "test-arn") + + +def test_primitive_codec_errors(): + """Test PrimitiveCodec error cases.""" + primitive_codec = PrimitiveCodec() + with pytest.raises(SerDesError, match="Unsupported primitive type"): + primitive_codec.encode(object()) + + with pytest.raises(SerDesError, match="Unknown primitive tag"): + primitive_codec.decode(TypeTag.BYTES, "test") + + +def test_bytes_codec_errors(): + """Test BytesCodec error cases.""" + bytes_codec = BytesCodec() + with pytest.raises(SerDesError, match="Expected BYTES tag, got"): + 
bytes_codec.decode(TypeTag.STR, "test") + + +def test_uuid_codec_errors(): + """Test UuidCodec error cases.""" + uuid_codec = UuidCodec() + with pytest.raises(SerDesError, match="Expected UUID tag, got"): + uuid_codec.decode(TypeTag.STR, "test") + + +def test_decimal_codec_errors(): + """Test DecimalCodec error cases.""" + + decimal_codec = DecimalCodec() + with pytest.raises(SerDesError, match="Expected DECIMAL tag, got"): + decimal_codec.decode(TypeTag.STR, "test") + + +def test_datetime_codec_errors(): + """Test DateTimeCodec error cases.""" + datetime_codec = DateTimeCodec() + with pytest.raises(SerDesError, match="Unsupported datetime type"): + datetime_codec.encode("not a datetime") + + with pytest.raises(SerDesError, match="Unknown datetime tag"): + datetime_codec.decode(TypeTag.BYTES, "test") + + +def test_datetime_codec_z_suffix(): + """Test DateTimeCodec Z suffix handling.""" + datetime_codec = DateTimeCodec() + result = datetime_codec.decode(TypeTag.DATETIME, "2024-01-01T00:00:00Z") + expected = datetime.fromisoformat("2024-01-01T00:00:00+00:00") + assert result == expected + + +def test_container_codec_errors(): + """Test ContainerCodec error cases.""" + container_codec = ContainerCodec() + type_codec = TypeCodec() + container_codec.set_dispatcher(type_codec) + + with pytest.raises(SerDesError, match="Unsupported container type"): + container_codec.encode("not a container") + + with pytest.raises(SerDesError, match="Unknown container tag"): + container_codec.decode(TypeTag.BYTES, "test") + + with pytest.raises(SerDesError, match="Tuple keys not supported"): + container_codec.encode({(1, 2): "value"}) + + # Test without dispatcher + container_codec_no_dispatcher = ContainerCodec() + with pytest.raises( + DurableExecutionsError, + match="ContainerCodec not linked to a TypeCodec dispatcher", + ): + _ = container_codec_no_dispatcher.dispatcher + + # Test decode with wrong value types + with pytest.raises(SerDesError, match="Expected list, got"): + 
container_codec.decode(TypeTag.LIST, "not a list") + + with pytest.raises(SerDesError, match="Expected list, got"): + container_codec.decode(TypeTag.TUPLE, "not a list") + + with pytest.raises(SerDesError, match="Expected dict, got"): + container_codec.decode(TypeTag.DICT, "not a dict") + + # Test _unwrap with plain object (case _ branch) + result = ContainerCodec._unwrap("plain_string", type_codec) # noqa: SLF001 + assert result == "plain_string" + + # Test _unwrap with EncodedValue (case EncodedValue branch) + encoded_val = EncodedValue(TypeTag.STR, "test") + result = ContainerCodec._unwrap(encoded_val, type_codec) # noqa: SLF001 + assert result == "test" + + +def test_type_codec_errors(): + """Test TypeCodec error cases.""" + type_codec = TypeCodec() + + with pytest.raises(SerDesError, match="Unsupported type"): + type_codec.encode(object()) + + class MockTag: + def __str__(self): + return "unknown" + + with pytest.raises(SerDesError, match="Unknown type tag"): + type_codec.decode(MockTag(), "test") + + +def test_extended_serdes_errors(): + """Test ExtendedTypesSerDes error cases.""" + serdes = ExtendedTypeSerDes() + + with pytest.raises( + SerDesError, match='Malformed envelope: missing "t" or "v" at root' + ): + serdes.deserialize('{"invalid": "envelope"}', None) + + with pytest.raises(SerDesError, match='Unknown type tag: "unknown"'): + serdes.deserialize('{"t": "unknown", "v": "test"}', None) + + +# endregion + + +# region EnvelopeSerDes Performance and Edge Cases +def test_envelope_large_data_structure(): + """Test with reasonably large data.""" + large_list = list(range(1000)) + large_dict = {f"key_{i}": f"value_{i}" for i in range(100)} + large_tuple = tuple(range(500)) + + large_structure = { + "list": large_list, + "dict": large_dict, + "tuple": large_tuple, + } + + result = _roundtrip_envelope(large_structure) + assert result == large_structure + + +def test_envelope_empty_containers(): + empty_data = { + "empty_list": [], + "empty_dict": {}, + 
"empty_tuple": (), + "empty_string": "", + "empty_bytes": b"", + } + assert _roundtrip_envelope(empty_data) == empty_data + + +def test_envelope_type_preservation_after_roundtrip(): + original = { + "none": None, + "bool": True, + "int": 42, + "float": math.pi, + "str": "text", + "datetime": datetime.now(UTC), + "date": date.today(), # noqa: DTZ011 + "decimal": Decimal("123.45"), + "uuid": uuid.uuid4(), + "bytes": b"data", + "tuple": (1, 2, 3), + "list": [1, 2, 3], + "dict": {"nested": True}, + } + + result = _roundtrip_envelope(original) + + # Verify types are preserved + assert type(result["none"]) is type(None) + assert type(result["bool"]) is bool + assert type(result["int"]) is int + assert type(result["float"]) is float + assert type(result["str"]) is str + assert type(result["datetime"]) is datetime + assert type(result["date"]) is date + assert type(result["decimal"]) is Decimal + assert type(result["uuid"]) is uuid.UUID + assert type(result["bytes"]) is bytes + assert type(result["tuple"]) is tuple + assert type(result["list"]) is list + assert type(result["dict"]) is dict + + +def test_envelope_unicode_and_special_characters(): + unicode_data = { + "emoji": "🚀🌟💫", + "chinese": "你好世界", + "arabic": "مرحبا بالعالم", + "russian": "Привет мир", + "special": "\"'\\n\\t\\r", + "zero_width": "\u200b\u200c\u200d", + } + assert _roundtrip_envelope(unicode_data) == unicode_data + + +def test_primitives(): + primitives = [ + 123, + "hello", + True, + False, + None, + math.pi, + Decimal("10.5"), + uuid.UUID("12345678-1234-5678-1234-567812345678"), + b"bytes here", + date(2025, 10, 22), + datetime(2025, 10, 22, 15, 30, 0), # noqa: DTZ001 + ] + serdes = ExtendedTypeSerDes() + ctx = SerDesContext("test-op", "test-arn") + for val in primitives: + serialized = serdes.serialize(val, ctx) + deserialized = serdes.deserialize(serialized, ctx) + assert deserialized == val + + +def test_nested_arrays(): + serdes = ExtendedTypeSerDes() + ctx = SerDesContext("test-op", "test-arn") 
+ val = [1, "two", [3, {"four": 4}], True, b"hi"] + serialized = serdes.serialize(val, ctx) + deserialized = serdes.deserialize(serialized, ctx) + assert deserialized == val + + +def test_nested_dicts(): + val = { + "a": 1, + "b": [2, 3, {"c": 4}], + "d": {"e": "f", "g": [5, 6]}, + "h": b"bytes in dict", + "i": uuid.UUID("12345678-1234-5678-1234-567812345678"), + } + serdes = ExtendedTypeSerDes() + ctx = SerDesContext("test-op", "test-arn") + serialized = serdes.serialize(val, ctx) + deserialized = serdes.deserialize(serialized, ctx) + assert deserialized == val + + +def test_user_dict_with_t_v_keys(): + val = {"t": "user t value", "v": "user v value"} + serdes = ExtendedTypeSerDes() + ctx = SerDesContext("test-op", "test-arn") + serialized = serdes.serialize(val, ctx) + deserialized = serdes.deserialize(serialized, ctx) + assert deserialized == val + + +def test_complex_nested_structure(): + val = { + "list": [1, 2, [3, 4], {"nested_bytes": b"abc"}], + "tuple": (Decimal("3.14"), True), + "dict": { + "uuid": uuid.UUID("12345678-1234-5678-1234-567812345678"), + "date": date(2025, 10, 22), + "datetime": datetime(2025, 10, 22, 15, 30), # noqa: DTZ001 + }, + "t_v": {"t": "inner t", "v": [1, b"bytes"]}, + } + serdes = ExtendedTypeSerDes() + ctx = SerDesContext("test-op", "test-arn") + serialized = serdes.serialize(val, ctx) + deserialized = serdes.deserialize(serialized, ctx) + assert deserialized == val + + +def test_all_t_v_nested_dicts(): + val = { + "t": {"t": "s", "v": "outer t"}, + "v": { + "t": {"t": "s", "v": "inner t"}, + "v": {"t": {"t": "s", "v": "deep t"}, "v": "deep v"}, + }, + } + serdes = ExtendedTypeSerDes() + ctx = SerDesContext("test-op", "test-arn") + serialized = serdes.serialize(val, ctx) + deserialized = serdes.deserialize(serialized, ctx) + assert deserialized == val + + +# endregion