Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions mashumaro/core/meta/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
"is_dataclass_dict_mixin_subclass",
"collect_type_params",
"resolve_type_params",
"resolve_type_alias_type",
"substitute_type_params",
"get_generic_name",
"get_name_error_name",
Expand Down Expand Up @@ -643,6 +644,20 @@ def substitute_type_params(typ: Type, substitutions: dict[Type, Type]) -> Type:
return typ


def resolve_type_alias_type(typ: Type) -> Type:
while True:
if is_type_alias_type(typ):
typ = typ.__value__
elif is_type_alias_type(get_type_origin(typ)):
origin = get_type_origin(typ)
type_params = getattr(origin, "__type_params__", ())
args = get_args(typ)
param_map = dict(zip(type_params, args))
typ = substitute_type_params(origin.__value__, param_map)
Comment on lines +661 to +683

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Replaced the silent truncation in 829f7c8: missing trailing parameters now fall back to their PEP 696 defaults when defined (TypeAliasType subscription doesn't auto-fill defaults, so Pair[int] with TypeVar("V", default=str) previously leaked an unsubstituted ~V into the resolved type), and a genuine mismatch raises TypeError: Too few/Too many arguments for ...; actual N, expected M in the same format as the existing generics arity check. The check is skipped for TypeVarTuple/ParamSpec parameters since their subscription arity is flexible. One behavior note: invalid annotations like Pair[int] without a default used to serialize with a half-substituted value and now fail at class creation — happy to relax that if you'd rather keep them lenient.

else:
return typ
Comment on lines +647 to +685

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added cycle protection in 829f7c8: the resolver tracks seen values and bounds the number of resolution steps, raising TypeError: Cannot resolve recursive type alias ... for both exact cycles (type A = B; type B = A) and non-repeating chains (type G[T] = G[list[T]]). The seen set tracks values rather than id()s because each substitution builds a new (equal) object for parameterized cycles; is_hashable() guards unhashable values and the step bound is the backstop. Raising a targeted error seemed more useful than returning the unresolved alias, which would just push the non-termination into the packer/unpacker dispatch. Tests cover both shapes, including at dataclass creation time.



def get_name_error_name(e: NameError) -> str:
return e.name # type: ignore

Expand Down
17 changes: 8 additions & 9 deletions mashumaro/core/meta/types/pack.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
is_union,
is_unpack,
not_none_type_arg,
resolve_type_alias_type,
resolve_type_params,
substitute_type_params,
type_name,
Expand Down Expand Up @@ -325,6 +326,7 @@ def pack_union(
packers: list[str] = []
packer_arg_types: dict[str, list[type]] = {}
for type_arg in args:
type_arg = resolve_type_alias_type(type_arg)
packer = PackerRegistry.get(
spec.copy(type=type_arg, expression="value", owner=spec.type)
)
Expand Down Expand Up @@ -549,15 +551,12 @@ def pack_special_typing_primitive(spec: ValueSpec) -> Expression | None:
evaluated = evaluate_forward_ref(spec.type)
if evaluated is not None:
return PackerRegistry.get(spec.copy(type=evaluated))
elif is_type_alias_type(spec.type):
return PackerRegistry.get(spec.copy(type=spec.type.__value__))
elif is_type_alias_type(get_type_origin(spec.type)):
origin = get_type_origin(spec.type)
type_params = getattr(origin, "__type_params__", ())
args = get_args(spec.type)
param_map = dict(zip(type_params, args))
resolved = substitute_type_params(origin.__value__, param_map)
return PackerRegistry.get(spec.copy(type=resolved))
elif is_type_alias_type(spec.type) or is_type_alias_type(
get_type_origin(spec.type)
):
return PackerRegistry.get(
spec.copy(type=resolve_type_alias_type(spec.type))
)
elif is_readonly(spec.type):
return PackerRegistry.get(spec.copy(type=get_args(spec.type)[0]))
raise UnserializableDataError(
Expand Down
17 changes: 8 additions & 9 deletions mashumaro/core/meta/types/unpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
is_unpack,
iter_all_subclasses,
not_none_type_arg,
resolve_type_alias_type,
resolve_type_params,
substitute_type_params,
type_name,
Expand Down Expand Up @@ -195,6 +196,7 @@ def _add_body(self, spec: ValueSpec, lines: CodeLines) -> None:
type_arg_unpackers = []
type_match_statements = 0
for type_arg in self.union_args:
type_arg = resolve_type_alias_type(type_arg)
unpacker = UnpackerRegistry.get(
spec.copy(type=type_arg, expression="value", owner=spec.type)
)
Expand Down Expand Up @@ -874,15 +876,12 @@ def unpack_special_typing_primitive(spec: ValueSpec) -> Expression | None:
evaluated = evaluate_forward_ref(spec.type)
if evaluated is not None:
return UnpackerRegistry.get(spec.copy(type=evaluated))
elif is_type_alias_type(spec.type):
return UnpackerRegistry.get(spec.copy(type=spec.type.__value__))
elif is_type_alias_type(get_type_origin(spec.type)):
origin = get_type_origin(spec.type)
type_params = getattr(origin, "__type_params__", ())
args = get_args(spec.type)
param_map = dict(zip(type_params, args))
resolved = substitute_type_params(origin.__value__, param_map)
return UnpackerRegistry.get(spec.copy(type=resolved))
elif is_type_alias_type(spec.type) or is_type_alias_type(
get_type_origin(spec.type)
):
return UnpackerRegistry.get(
spec.copy(type=resolve_type_alias_type(spec.type))
)
elif is_readonly(spec.type):
return UnpackerRegistry.get(spec.copy(type=get_args(spec.type)[0]))
raise UnserializableDataError(
Expand Down
71 changes: 71 additions & 0 deletions tests/test_pep_695.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from mashumaro import DataClassDictMixin
from mashumaro.codecs import BasicDecoder, BasicEncoder
from mashumaro.config import TO_DICT_ADD_OMIT_NONE_FLAG, BaseConfig
from mashumaro.exceptions import MissingField
from tests.entities_pep_695 import (
Boxed,
Expand Down Expand Up @@ -144,3 +145,73 @@ def test_recursive_generic_alias_with_serializable_type():
DataClassWithRecursiveGenericAlias.from_dict({"x": ("hello", 7)})
== obj2
)


def test_type_alias_type_nested_in_union():
# https://github.com/Fatal1ty/mashumaro/issues/330
type UniqueIdentifier = str
type UniqueIdentifierList = list[UniqueIdentifier]
type UniqueIdentifierOrList = UniqueIdentifier | UniqueIdentifierList

@dataclass
class MyClass(DataClassDictMixin):
x: UniqueIdentifierOrList | None = None

assert MyClass(x="a").to_dict() == {"x": "a"}
assert MyClass(x=["a", "b"]).to_dict() == {"x": ["a", "b"]}
assert MyClass().to_dict() == {"x": None}
assert MyClass.from_dict({"x": "a"}) == MyClass(x="a")
assert MyClass.from_dict({"x": ["a", "b"]}) == MyClass(x=["a", "b"])
assert MyClass.from_dict({"x": None}) == MyClass()


@pytest.mark.parametrize("lazy", [False, True])
def test_type_alias_type_nested_in_union_with_omit_none_flag(lazy):
type UniqueIdentifier = str
type UniqueIdentifierList = list[UniqueIdentifier]
type UniqueIdentifierOrList = UniqueIdentifier | UniqueIdentifierList

@dataclass
class MyClass(DataClassDictMixin):
x: UniqueIdentifierOrList | None = None

class Config(BaseConfig):
code_generation_options = [TO_DICT_ADD_OMIT_NONE_FLAG]
lazy_compilation = lazy

assert MyClass(x="a").to_dict(omit_none=True) == {"x": "a"}
assert MyClass(x=["a", "b"]).to_dict(omit_none=True) == {"x": ["a", "b"]}
assert MyClass().to_dict(omit_none=True) == {}
assert MyClass.from_dict({"x": ["a", "b"]}) == MyClass(x=["a", "b"])


def test_type_alias_type_nested_in_union_with_codecs():
type UniqueIdentifier = str
type UniqueIdentifierList = list[UniqueIdentifier]
type UniqueIdentifierOrList = UniqueIdentifier | UniqueIdentifierList

decoder = BasicDecoder(UniqueIdentifierOrList)
encoder = BasicEncoder(UniqueIdentifierOrList)

assert decoder.decode("a") == "a"
assert decoder.decode(["a", "b"]) == ["a", "b"]
assert encoder.encode("a") == "a"
assert encoder.encode(["a", "b"]) == ["a", "b"]


def test_parameterized_type_alias_type_in_union():
type Identity[T] = T
type ListOf[T] = list[T]

@dataclass
class MyClass(DataClassDictMixin):
x: date | ListOf[int]
y: int | Identity[str]

obj1 = MyClass(x=date(2024, 4, 15), y=42)
assert obj1.to_dict() == {"x": "2024-04-15", "y": 42}
assert MyClass.from_dict({"x": "2024-04-15", "y": 42}) == obj1

obj2 = MyClass(x=[1, 2, 3], y="a")
assert obj2.to_dict() == {"x": [1, 2, 3], "y": "a"}
assert MyClass.from_dict({"x": [1, 2, 3], "y": "a"}) == obj2