Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for recursive Union #266

Merged
merged 2 commits into from
Nov 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions mashumaro/core/meta/types/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ def __init__(self, expression: str):
class FieldContext:
name: str
metadata: Mapping
packer: Optional[str] = None
unpacker: Optional[str] = None

def copy(self, **changes: Any) -> "FieldContext":
return replace(self, **changes)
Expand Down Expand Up @@ -181,8 +183,13 @@ def _get_call_expr(self, spec: ValueSpec, method_name: str) -> str:
def _before_build(self, spec: ValueSpec) -> None:
pass

def _get_existing_method(self, spec: ValueSpec) -> Optional[str]:
return None

def build(self, spec: ValueSpec) -> str:
self._before_build(spec)
if method := self._get_existing_method(spec):
return method
lines = CodeLines()
method_name = self._add_definition(spec, lines)
with lines.indent():
Expand Down
20 changes: 19 additions & 1 deletion mashumaro/core/meta/types/pack.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,11 +289,27 @@ def pack_any(spec: ValueSpec) -> Optional[Expression]:
def pack_union(
spec: ValueSpec, args: tuple[type, ...], prefix: str = "union"
) -> Expression:
if spec.type is spec.owner and spec.field_ctx.packer:
return spec.field_ctx.packer
lines = CodeLines()

method_name = (
f"__pack_{prefix}_{spec.builder.cls.__name__}_{spec.field_ctx.name}__"
f"{random_hex()}"
)

if not spec.field_ctx.packer:
method_args = ", ".join(
filter(None, ("value", spec.builder.get_pack_method_flags()))
)
if spec.builder.is_nailed:
union_packer = (
f"{spec.self_attrs_name}.{method_name}({method_args})"
)
else:
union_packer = f"{method_name}({method_args})"
spec.field_ctx.packer = union_packer

method_args = "self, value" if spec.builder.is_nailed else "value"
default_kwargs = spec.builder.get_pack_method_default_flag_values()
if default_kwargs:
Expand All @@ -304,7 +320,7 @@ def pack_union(
packer_arg_types: dict[str, list[type]] = {}
for type_arg in args:
packer = PackerRegistry.get(
spec.copy(type=type_arg, expression="value")
spec.copy(type=type_arg, expression="value", owner=spec.type)
)
if packer not in packers:
if packer == "value":
Expand Down Expand Up @@ -363,7 +379,9 @@ def pack_union(
if spec.builder.get_config().debug:
print(f"{type_name(spec.builder.cls)}:")
print(lines.as_text())

exec(lines.as_text(), spec.builder.globals, spec.builder.__dict__)

method_args = ", ".join(
filter(None, (spec.expression, spec.builder.get_pack_method_flags()))
)
Expand Down
16 changes: 15 additions & 1 deletion mashumaro/core/meta/types/unpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,11 +162,21 @@ def _get_call_expr(self, spec: ValueSpec, method_name: str) -> str:
class UnionUnpackerBuilder(AbstractUnpackerBuilder):
def __init__(self, args: tuple[type, ...]):
self.union_args = args
self.method_name: Optional[str] = None

def get_method_prefix(self) -> str:
return "union"

def _generate_method_name(self, spec: ValueSpec) -> str:
method_name = super()._generate_method_name(spec)
self.method_name = method_name
return method_name

def _add_body(self, spec: ValueSpec, lines: CodeLines) -> None:
if not spec.field_ctx.unpacker and self.method_name:
spec.field_ctx.unpacker = self._get_call_expr(
spec, self.method_name
)
orig_lines = lines
lines = CodeLines()
unpackers = set()
Expand All @@ -175,7 +185,7 @@ def _add_body(self, spec: ValueSpec, lines: CodeLines) -> None:
type_match_statements = 0
for type_arg in self.union_args:
unpacker = UnpackerRegistry.get(
spec.copy(type=type_arg, expression="value")
spec.copy(type=type_arg, expression="value", owner=spec.type)
)
type_arg_unpackers.append((type_arg, unpacker))
if isinstance(unpacker, TypeMatchEligibleExpression):
Expand Down Expand Up @@ -230,6 +240,10 @@ def _add_body(self, spec: ValueSpec, lines: CodeLines) -> None:
orig_lines.append("__value_type = type(value)")
orig_lines.extend(lines)

def _get_existing_method(self, spec: ValueSpec) -> Optional[str]:
if spec.owner is spec.type:
return spec.field_ctx.unpacker


class TypeVarUnpackerBuilder(UnionUnpackerBuilder):
def get_method_prefix(self) -> str:
Expand Down
1 change: 1 addition & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
collect_ignore = [
"test_generics_pep_695.py",
"test_pep_695.py",
"test_recursive_union.py",
]

if PY_313_MIN:
Expand Down
72 changes: 72 additions & 0 deletions tests/test_recursive_union.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from dataclasses import dataclass

from mashumaro import DataClassDictMixin
from mashumaro.codecs import BasicDecoder, BasicEncoder

type JSON = str | int | float | bool | dict[str, JSON] | list[JSON] | None


@dataclass
class MyClass:
x: str
y: JSON


def test_encoder_with_recursive_union():
encoder = BasicEncoder(JSON)
assert encoder.encode(
{"x": [{"x": {"x": [{"x": ["x", 1, 1.0, True, None]}]}}]}
) == {"x": [{"x": {"x": [{"x": ["x", 1, 1.0, True, None]}]}}]}


def test_encoder_with_recursive_union_in_dataclass():
encoder = BasicEncoder(MyClass)
assert encoder.encode(
MyClass(
x="x", y={"x": [{"x": {"x": [{"x": ["x", 1, 1.0, True, None]}]}}]}
)
) == {
"x": "x",
"y": {"x": [{"x": {"x": [{"x": ["x", 1, 1.0, True, None]}]}}]},
}


def test_decoder_with_recursive_union():
decoder = BasicDecoder(JSON)
assert decoder.decode(
{"x": [{"x": {"x": [{"x": ["x", 1, 1.0, True, None]}]}}]}
) == {"x": [{"x": {"x": [{"x": ["x", 1, 1.0, True, None]}]}}]}


def test_decoder_with_recursive_union_in_dataclass():
decoder = BasicDecoder(MyClass)
assert decoder.decode(
{
"x": "x",
"y": {"x": [{"x": {"x": [{"x": ["x", 1, 1.0, True, None]}]}}]},
}
) == MyClass(
x="x", y={"x": [{"x": {"x": [{"x": ["x", 1, 1.0, True, None]}]}}]}
)


def test_dataclass_dict_mixin_with_recursive_union():
@dataclass
class MyClassWithMixin(DataClassDictMixin):
x: str
y: JSON

assert MyClassWithMixin(
x="x", y={"x": [{"x": {"x": [{"x": ["x", 1, 1.0, True, None]}]}}]}
).to_dict() == {
"x": "x",
"y": {"x": [{"x": {"x": [{"x": ["x", 1, 1.0, True, None]}]}}]},
}
assert MyClassWithMixin.from_dict(
{
"x": "x",
"y": {"x": [{"x": {"x": [{"x": ["x", 1, 1.0, True, None]}]}}]},
}
) == MyClassWithMixin(
x="x", y={"x": [{"x": {"x": [{"x": ["x", 1, 1.0, True, None]}]}}]}
)
Loading