Skip to content

Commit

Permalink
Merge pull request #178 from Fatal1ty/inspect-signature-with-forward-ref
Browse files Browse the repository at this point in the history
Use a ForwardRef object as a result of inspect.signature annotations
  • Loading branch information
Fatal1ty authored Nov 14, 2023
2 parents b7d6435 + 6cbc63c commit 14b9389
Show file tree
Hide file tree
Showing 6 changed files with 127 additions and 14 deletions.
5 changes: 3 additions & 2 deletions mashumaro/core/meta/code/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,8 @@ def evaluate_forward_ref(
typ: typing.ForwardRef,
owner: typing.Optional[typing.Type],
) -> typing.Optional[typing.Type]:
if not getattr(typ, "__forward_module__", None) and owner:
forward_module = getattr(typ, "__forward_module__", None)
if not forward_module and owner:
# We can't get the module in which ForwardRef's value is defined on
# Python < 3.10, ForwardRef evaluation might not work properly
# without this information, so we will consider the namespace of
Expand All @@ -304,7 +305,7 @@ def evaluate_forward_ref(
self.globals,
)
else:
globalns = self.globals
globalns = getattr(forward_module, "__dict__", self.globals)
return evaluate_forward_ref(typ, globalns, self.__dict__)

def get_declared_hook(self, method_name: str) -> typing.Any:
Expand Down
31 changes: 21 additions & 10 deletions mashumaro/core/meta/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -680,6 +680,10 @@ def get_function_arg_annotation(
annotation = parameter.annotation
if annotation is inspect.Signature.empty:
raise ValueError(f"Argument {arg_name} doesn't have annotation")
if isinstance(annotation, str):
annotation = str_to_forward_ref(
annotation, inspect.getmodule(function)
)
return annotation


Expand All @@ -689,6 +693,10 @@ def get_function_return_annotation(
annotation = inspect.signature(function).return_annotation
if annotation is inspect.Signature.empty:
raise ValueError("Function doesn't have return annotation")
if isinstance(annotation, str):
annotation = str_to_forward_ref(
annotation, inspect.getmodule(function)
)
return annotation


Expand Down Expand Up @@ -738,18 +746,21 @@ def is_hashable_type(typ: Any) -> bool:
return True


if PY_39_MIN:
def str_to_forward_ref(
annotation: str, module: Optional[types.ModuleType] = None
) -> ForwardRef:
if PY_39_MIN:
return ForwardRef(annotation, module=module) # type: ignore
else:
return ForwardRef(annotation)


def evaluate_forward_ref(
typ: ForwardRef, globalns: Any, localns: Any
) -> Optional[Type]:
def evaluate_forward_ref(
typ: ForwardRef, globalns: Any, localns: Any
) -> Optional[Type]:
if PY_39_MIN:
return typ._evaluate(
globalns, localns, frozenset()
) # type: ignore[call-arg]

else:

def evaluate_forward_ref(
typ: ForwardRef, globalns: Any, localns: Any
) -> Optional[Type]:
else:
return typ._evaluate(globalns, localns) # type: ignore[call-arg]
10 changes: 9 additions & 1 deletion mashumaro/core/meta/types/pack.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,12 @@ def _pack_with_annotated_serialization_strategy(
)
except (KeyError, ValueError):
value_type = Any
if isinstance(value_type, ForwardRef):
value_type = spec.builder.evaluate_forward_ref(
value_type, spec.origin_type
)
value_type = substitute_type_params(
value_type,
value_type, # type: ignore
resolve_type_params(strategy_type, get_args(spec.type))[strategy_type],
)
overridden_fn = f"__{spec.field_ctx.name}_serialize_{random_hex()}"
Expand Down Expand Up @@ -176,6 +180,10 @@ def _pack_annotated_serializable_type(
) from None
if is_self(value_type):
return f"{spec.expression}._serialize()"
if isinstance(value_type, ForwardRef):
value_type = spec.builder.evaluate_forward_ref(
value_type, spec.origin_type
)
value_type = substitute_type_params(
value_type,
resolve_type_params(spec.origin_type, get_args(spec.type))[
Expand Down
10 changes: 9 additions & 1 deletion mashumaro/core/meta/types/unpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,8 +467,12 @@ def _unpack_with_annotated_serialization_strategy(
)
except (KeyError, ValueError):
value_type = Any
if isinstance(value_type, ForwardRef):
value_type = spec.builder.evaluate_forward_ref(
value_type, spec.origin_type
)
value_type = substitute_type_params(
value_type,
value_type, # type: ignore
resolve_type_params(strategy_type, get_args(spec.type))[strategy_type],
)
overridden_fn = f"__{spec.field_ctx.name}_deserialize_{random_hex()}"
Expand Down Expand Up @@ -540,6 +544,10 @@ def _unpack_annotated_serializable_type(
) from None
if is_self(value_type):
return f"{type_name(spec.type)}._deserialize({spec.expression})"
if isinstance(value_type, ForwardRef):
value_type = spec.builder.evaluate_forward_ref(
value_type, spec.origin_type
)
value_type = substitute_type_params(
value_type,
resolve_type_params(spec.origin_type, get_args(spec.type))[
Expand Down
40 changes: 40 additions & 0 deletions tests/test_forward_refs/test_generic_serializable_type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from __future__ import annotations

from dataclasses import dataclass
from datetime import date
from typing import Dict, Generic, TypeVar

from mashumaro import DataClassDictMixin
from mashumaro.types import SerializableType

T = TypeVar("T")


class Foo(Generic[T], SerializableType, use_annotations=True):
a: T

def __init__(self, a: T) -> None:
self.a = a

@classmethod
def _deserialize(cls, value: Dict[str, T]) -> Foo[T]:
return cls(**value)

def _serialize(self) -> Dict[str, T]:
return {"a": self.a}

def __eq__(self, other: Foo) -> bool:
return self.a == other.a


@dataclass
class Bar(DataClassDictMixin):
x_str: Foo[str]
x_date: Foo[date]


def test_generic_serializable_type():
data = {"x_str": {"a": "2023-11-14"}, "x_date": {"a": "2023-11-14"}}
obj = Bar(Foo("2023-11-14"), Foo(date(2023, 11, 14)))
assert obj.to_dict() == data
assert Bar.from_dict(data) == obj
45 changes: 45 additions & 0 deletions tests/test_forward_refs/test_generic_serialization_strategy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from __future__ import annotations

from dataclasses import dataclass
from datetime import date
from typing import Dict, Generic, TypeVar

from mashumaro import DataClassDictMixin
from mashumaro.config import BaseConfig
from mashumaro.types import SerializationStrategy

T = TypeVar("T")


class Foo(Generic[T]):
a: T

def __init__(self, a: T) -> None:
self.a = a

def __eq__(self, other: Foo) -> bool:
return self.a == other.a


class FooStrategy(Generic[T], SerializationStrategy):
def deserialize(self, value: Dict[str, T]) -> Foo[T]:
return Foo(**value)

def serialize(self, value: Foo[T]) -> Dict[str, T]:
return {"a": value.a}


@dataclass
class Bar(DataClassDictMixin):
x_str: Foo[str]
x_date: Foo[date]

class Config(BaseConfig):
serialization_strategy = {Foo: FooStrategy()}


def test_generic_serialization_strategy():
data = {"x_str": {"a": "2023-11-14"}, "x_date": {"a": "2023-11-14"}}
obj = Bar(Foo("2023-11-14"), Foo(date(2023, 11, 14)))
assert obj.to_dict() == data
assert Bar.from_dict(data) == obj

0 comments on commit 14b9389

Please sign in to comment.