Skip to content

Commit

Permalink
Merge branch 'forward-references'
Browse files Browse the repository at this point in the history
Fatal1ty committed Apr 19, 2021
2 parents 57df4d1 + 127d5aa commit edf9eb7
Showing 10 changed files with 181 additions and 57 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -24,7 +24,7 @@ Table of contents
* [Benchmark](#benchmark)
* [API](#api)
* [Customization](#customization)
* [Serializable Interface](#serializable-interface)
* [SerializableType Interface](#serializabletype-interface)
* [Field options](#field-options)
* [`serialize` option](#serialize-option)
* [`deserialize` option](#deserialize-option)
@@ -350,7 +350,7 @@ decoder_kwargs # keyword arguments for decoder function
Customization
--------------------------------------------------------------------------------

### Serializable Interface
### SerializableType Interface

If you already have a separate custom class, and you want to serialize
instances of it with *mashumaro*, you can achieve this by implementing
19 changes: 19 additions & 0 deletions mashumaro/exceptions.py
Original file line number Diff line number Diff line change
@@ -82,3 +82,22 @@ def __str__(self):

class BadHookSignature(TypeError):
pass


class ThirdPartyModuleNotFoundError(ModuleNotFoundError):
def __init__(self, module_name, field_name, holder_class):
self.module_name = module_name
self.field_name = field_name
self.holder_class = holder_class

@property
def holder_class_name(self):
return type_name(self.holder_class)

def __str__(self):
s = (
f'Install "{self.module_name}" to use it as the serialization '
f'method for the field "{self.field_name}" '
f"in {self.holder_class_name}"
)
return s
24 changes: 14 additions & 10 deletions mashumaro/meta/helpers.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,9 @@
import dataclasses
import types
import typing

from .macros import PY_36, PY_37, PY_38, PY_39


def get_imported_module_names():
# noinspection PyUnresolvedReferences
return {
value.__name__
for value in globals().values()
if isinstance(value, types.ModuleType)
}
DataClassDictMixinPath = "mashumaro.serializer.base.dict.DataClassDictMixin"


def get_type_origin(t):
@@ -91,8 +83,18 @@ def get_class_that_define_method(method_name, cls):
return cls


def is_dataclass_dict_mixin(t):
return type_name(t) == DataClassDictMixinPath


def is_dataclass_dict_mixin_subclass(t):
for cls in t.__mro__:
if is_dataclass_dict_mixin(cls):
return True
return False


__all__ = [
"get_imported_module_names",
"get_type_origin",
"type_name",
"is_special_typing_primitive",
@@ -102,4 +104,6 @@ def get_class_that_define_method(method_name, cls):
"is_class_var",
"is_init_var",
"get_class_that_define_method",
"is_dataclass_dict_mixin",
"is_dataclass_dict_mixin_subclass",
]
94 changes: 50 additions & 44 deletions mashumaro/serializer/base/metaprogramming.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
# noinspection PyUnresolvedReferences
import builtins # noqa
import collections
import collections.abc
import datetime
import enum
import importlib
import inspect
import ipaddress
import os
import pathlib
import sys
import types
import typing
import uuid

# noinspection PyUnresolvedReferences
from base64 import decodebytes, encodebytes # noqa
from contextlib import contextmanager, suppress

@@ -25,20 +25,20 @@
TO_DICT_ADD_OMIT_NONE_FLAG,
BaseConfig,
)

# noinspection PyUnresolvedReferences
from mashumaro.exceptions import ( # noqa
BadHookSignature,
InvalidFieldValue,
MissingField,
ThirdPartyModuleNotFoundError,
UnserializableDataError,
UnserializableField,
)
from mashumaro.meta.helpers import (
get_class_that_define_method,
get_imported_module_names,
get_type_origin,
is_class_var,
is_dataclass_dict_mixin,
is_dataclass_dict_mixin_subclass,
is_generic,
is_init_var,
is_special_typing_primitive,
@@ -51,20 +51,26 @@
from mashumaro.serializer.base.helpers import * # noqa
from mashumaro.types import SerializableType, SerializationStrategy

try:
import ciso8601
except ImportError: # pragma no cover
ciso8601: typing.Optional[types.ModuleType] = None # type: ignore
try:
import pendulum
except ImportError: # pragma no cover
pendulum: typing.Optional[types.ModuleType] = None # type: ignore

patch_fromisoformat()


NoneType = type(None)
INITIAL_MODULES = get_imported_module_names()


__PRE_SERIALIZE__ = "__pre_serialize__"
__PRE_DESERIALIZE__ = "__pre_deserialize__"
__POST_SERIALIZE__ = "__post_serialize__"
__POST_DESERIALIZE__ = "__post_deserialize__"

DataClassDictMixinPath = "mashumaro.serializer.base.dict.DataClassDictMixin"


class CodeLines:
def __init__(self):
@@ -94,13 +100,11 @@ class CodeBuilder:
def __init__(self, cls):
self.cls = cls
self.lines: CodeLines = CodeLines()
self.modules: typing.Set[str] = set()
self.globals: typing.Set[str] = set()
self.globals: typing.Dict[str, typing.Any] = {}

def reset(self) -> None:
self.lines.reset()
self.modules = INITIAL_MODULES.copy()
self.globals = set()
self.globals = globals().copy()

@property
def namespace(self) -> typing.Dict[typing.Any, typing.Any]:
@@ -114,7 +118,9 @@ def __get_field_types(
self, recursive=True
) -> typing.Dict[str, typing.Any]:
fields = {}
for fname, ftype in typing.get_type_hints(self.cls).items():
globalns = sys.modules[self.cls.__module__].__dict__.copy()
globalns[self.cls.__name__] = self.cls
for fname, ftype in typing.get_type_hints(self.cls, globalns).items():
if is_class_var(ftype) or is_init_var(ftype):
continue
if recursive or fname in self.annotations:
@@ -162,7 +168,7 @@ def metadatas(self) -> typing.Dict[str, typing.Mapping[str, typing.Any]]:

def _add_type_modules(self, *types_) -> None:
for t in types_:
module = getattr(t, "__module__", None)
module = inspect.getmodule(t)
if not module:
return
self.ensure_module_imported(module)
@@ -173,18 +179,10 @@ def _add_type_modules(self, *types_) -> None:
if constraints:
self._add_type_modules(*constraints)

def ensure_module_imported(self, module: str) -> None:
if module not in self.modules:
self.modules.add(module)
self.add_line(f"if '{module}' not in globals():")
with self.indent():
self.add_line(f"import {module}")
root_module = module.split(".")[0]
if root_module not in self.globals:
self.globals.add(root_module)
self.add_line("else:")
with self.indent():
self.add_line(f"global {root_module}")
def ensure_module_imported(self, module: types.ModuleType) -> None:
self.globals.setdefault(module.__name__, module)
package = module.__name__.split(".")[0]
self.globals.setdefault(package, importlib.import_module(package))

def add_line(self, line: str) -> None:
self.lines.append(line)
@@ -199,13 +197,13 @@ def compile(self) -> None:
if self.get_config().debug:
print(self.cls)
print(code)
exec(code, globals(), self.__dict__)
exec(code, self.globals, self.__dict__)

def get_declared_hook(self, method_name: str):
if not hasattr(self.cls, method_name):
return
cls = get_class_that_define_method(method_name, self.cls)
if type_name(cls) != DataClassDictMixinPath:
if not is_dataclass_dict_mixin(cls):
return cls.__dict__[method_name]

def add_from_dict(self) -> None:
@@ -477,10 +475,6 @@ def _pack_value(
)
overridden = f"self.__{fname}_serialize({value_name})"

if is_dataclass(ftype):
flags = self.get_to_dict_flags(ftype)
return overridden or f"{value_name}.to_dict({flags})"

with suppress(TypeError):
if issubclass(ftype, SerializableType):
return overridden or f"{value_name}._serialize()"
@@ -688,6 +682,9 @@ def inner_expr(arg_num=0, v_name="value", v_type=None):
elif issubclass(origin_type, enum.Enum):
specific = f"{value_name}.value"
return f"{value_name} if use_enum else {overridden or specific}"
elif is_dataclass_dict_mixin_subclass(ftype):
flags = self.get_to_dict_flags(ftype)
return overridden or f"{value_name}.to_dict({flags})"
elif overridden:
return overridden

@@ -718,12 +715,6 @@ def _unpack_field_value(
setattr(self.cls, f"__{fname}_deserialize", deserialize_option)
overridden = f"cls.__{fname}_deserialize({value_name})"

if is_dataclass(ftype):
return overridden or (
f"{type_name(ftype)}.from_dict({value_name}, "
f"use_bytes, use_enum, use_datetime)"
)

with suppress(TypeError):
if issubclass(ftype, SerializableType):
return (
@@ -772,11 +763,21 @@ def _unpack_field_value(
return f"{value_name} if use_datetime else {overridden}"
elif deserialize_option is not None:
if deserialize_option == "ciso8601":
self.ensure_module_imported("ciso8601")
datetime_parser = "ciso8601.parse_datetime"
if ciso8601:
self.ensure_module_imported(ciso8601)
datetime_parser = "ciso8601.parse_datetime"
else:
raise ThirdPartyModuleNotFoundError(
"ciso8601", fname, parent
) # pragma no cover
elif deserialize_option == "pendulum":
self.ensure_module_imported("pendulum")
datetime_parser = "pendulum.parse"
if pendulum:
self.ensure_module_imported(pendulum)
datetime_parser = "pendulum.parse"
else:
raise ThirdPartyModuleNotFoundError(
"pendulum", fname, parent
) # pragma no cover
else:
raise UnserializableField(
fname,
@@ -1021,6 +1022,11 @@ def inner_expr(arg_num=0, v_name="value", v_type=None):
elif issubclass(origin_type, enum.Enum):
specific = f"{type_name(origin_type)}({value_name})"
return f"{value_name} if use_enum else {overridden or specific}"
elif is_dataclass_dict_mixin_subclass(ftype):
return overridden or (
f"{type_name(ftype)}.from_dict({value_name}, "
f"use_bytes, use_enum, use_datetime)"
)
elif overridden:
return overridden

20 changes: 20 additions & 0 deletions tests/entities.py
Original file line number Diff line number Diff line change
@@ -95,3 +95,23 @@ def __init__(self, value):

def __eq__(self, other):
return isinstance(other, ThirdPartyType) and self.value == other.value


@dataclass
class DataClassWithoutMixin:
i: int


@dataclass
class SerializableTypeDataClass(SerializableType):
a: int
b: int

def _serialize(self):
return {"a": self.a + 1, "b": self.b + 1}

@classmethod
def _deserialize(cls, value):
a = value.get("a") - 1
b = value.get("b") - 1
return cls(a, b)
11 changes: 11 additions & 0 deletions tests/entities_forward_refs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Optional

from mashumaro import DataClassDictMixin


@dataclass
class Node(DataClassDictMixin):
next: Optional[Node] = None
20 changes: 20 additions & 0 deletions tests/test_data_types.py
Original file line number Diff line number Diff line change
@@ -60,6 +60,7 @@

from .entities import (
CustomPath,
DataClassWithoutMixin,
MutableString,
MyDataClass,
MyDataClassWithUnion,
@@ -68,6 +69,7 @@
MyIntEnum,
MyIntFlag,
MyStrEnum,
SerializableTypeDataClass,
)
from .utils import same_types

@@ -1022,3 +1024,21 @@ class DataClass(DataClassDictMixin):
)
assert same_types(instance_dumped, dumped)
assert same_types(instance_loaded.x, x_value)


def test_dataclass_field_without_mixin():
with pytest.raises(UnserializableField):

@dataclass
class _(DataClassDictMixin):
p: DataClassWithoutMixin


def test_serializable_type_dataclass():
@dataclass
class DataClass(DataClassDictMixin):
s: SerializableTypeDataClass

s_value = SerializableTypeDataClass(a=9, b=9)
assert DataClass.from_dict({"s": {"a": 10, "b": 10}}) == DataClass(s_value)
assert DataClass(s_value).to_dict() == {"s": {"a": 10, "b": 10}}
16 changes: 16 additions & 0 deletions tests/test_exceptions.py
Original file line number Diff line number Diff line change
@@ -3,6 +3,7 @@
from mashumaro.exceptions import (
InvalidFieldValue,
MissingField,
ThirdPartyModuleNotFoundError,
UnserializableField,
)

@@ -96,3 +97,18 @@ def test_invalid_field_value_with_msg_str():
str(exc) == 'Field "x" of type builtins.int in builtins.object '
"has invalid value 'y': test message"
)


def test_third_party_module_not_found_error_holder_class_name():
exc = ThirdPartyModuleNotFoundError("third_party", "x", object)
assert exc.holder_class_name == "builtins.object"
exc = ThirdPartyModuleNotFoundError("third_party", "x", List[int])
assert exc.holder_class_name == "typing.List[int]"


def test_third_party_module_not_found_error_str():
exc = ThirdPartyModuleNotFoundError("third_party", "x", object)
assert (
str(exc) == 'Install "third_party" to use it as the serialization '
'method for the field "x" in builtins.object'
)
13 changes: 13 additions & 0 deletions tests/test_forward_refs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import pytest

from mashumaro.meta.macros import PY_37_MIN


@pytest.mark.skipif(not PY_37_MIN, reason="requires python>=3.7")
def test_self_reference():
from .entities_forward_refs import Node

assert Node.from_dict({}) == Node()
assert Node.from_dict({"next": {}}) == Node(Node())
assert Node().to_dict() == {"next": None}
assert Node(Node()).to_dict() == {"next": {"next": None}}
17 changes: 16 additions & 1 deletion tests/test_meta.py
Original file line number Diff line number Diff line change
@@ -3,15 +3,19 @@

import pytest

from mashumaro import DataClassDictMixin
from mashumaro import DataClassDictMixin, DataClassJSONMixin
from mashumaro.meta.helpers import (
get_class_that_define_method,
is_class_var,
is_dataclass_dict_mixin,
is_dataclass_dict_mixin_subclass,
is_generic,
is_init_var,
)
from mashumaro.serializer.base.metaprogramming import CodeBuilder

from .entities import MyDataClass


def test_is_generic_unsupported_python():
with patch("mashumaro.meta.helpers.PY_36", False):
@@ -83,3 +87,14 @@ def foobar(self):
def test_get_unknown_declared_hook():
builder = CodeBuilder(object)
assert builder.get_declared_hook("unknown_name") is None


def test_is_dataclass_dict_mixin():
assert is_dataclass_dict_mixin(DataClassDictMixin)
assert not is_dataclass_dict_mixin(DataClassJSONMixin)


def test_is_dataclass_dict_mixin_subclass():
assert is_dataclass_dict_mixin_subclass(DataClassDictMixin)
assert is_dataclass_dict_mixin_subclass(DataClassJSONMixin)
assert is_dataclass_dict_mixin_subclass(MyDataClass)

0 comments on commit edf9eb7

Please sign in to comment.