Skip to content

Commit

Permalink
More codecs changes
Browse files Browse the repository at this point in the history
  • Loading branch information
Fatal1ty committed Nov 14, 2023
1 parent d44e015 commit 91a34c4
Show file tree
Hide file tree
Showing 14 changed files with 1,684 additions and 447 deletions.
25 changes: 17 additions & 8 deletions mashumaro/codecs/_builder.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,24 @@
from dataclasses import is_dataclass
from types import new_class
from typing import Any, Callable, Optional, Type

from mashumaro.core.meta.code.builder import CodeBuilder
from mashumaro.core.meta.helpers import is_optional, is_type_var_any
from mashumaro.core.meta.types.common import FieldContext, ValueSpec
from mashumaro.core.meta.types.common import (
AttrsHolder,
FieldContext,
ValueSpec,
)
from mashumaro.core.meta.types.pack import PackerRegistry
from mashumaro.core.meta.types.unpack import UnpackerRegistry


class NameSpace:
pass


class CodecCodeBuilder(CodeBuilder):
@classmethod
def new(cls, **kwargs: Any) -> "CodecCodeBuilder":
return cls(new_class("root", bases=(NameSpace,)), **kwargs)
if "attrs" not in kwargs:
kwargs["attrs"] = AttrsHolder()
return cls(AttrsHolder("__root__"), **kwargs)

def add_decode_method(
self,
Expand All @@ -42,11 +45,14 @@ def add_decode_method(
builder=self,
field_ctx=FieldContext(name="", metadata={}),
could_be_none=could_be_none,
is_root=True,
)
)
self.add_line(f"return {unpacked_value}")
self.add_line("setattr(decoder_obj, 'decode', decode)")
if pre_decoder_func is None and is_dataclass(shape_type):
method_name = unpacked_value.partition("(")[0]
self.lines.reset()
self.add_line(f"setattr(decoder_obj, 'decode', {method_name})")
self.ensure_object_imported(decoder_obj, "decoder_obj")
self.ensure_object_imported(self.cls, "cls")
self.compile()
Expand Down Expand Up @@ -76,7 +82,6 @@ def add_encode_method(
no_copy_collections=self._get_dialect_or_config_option(
"no_copy_collections", ()
),
is_root=True,
)
)
if post_encoder_func:
Expand All @@ -85,6 +90,10 @@ def add_encode_method(
else:
self.add_line(f"return {packed_value}")
self.add_line("setattr(encoder_obj, 'encode', encode)")
if post_encoder_func is None and is_dataclass(shape_type):
method_name = packed_value.partition("(")[0]
self.lines.reset()
self.add_line(f"setattr(encoder_obj, 'encode', {method_name})")
self.ensure_object_imported(encoder_obj, "encoder_obj")
self.ensure_object_imported(self.cls, "cls")
self.ensure_object_imported(self.cls, "self")
Expand Down
2 changes: 1 addition & 1 deletion mashumaro/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class BaseConfig:
code_generation_options: List[CodeGenerationOption] = [] # type: ignore
serialization_strategy: Dict[Any, SerializationStrategyValueType] = {}
aliases: Dict[str, str] = {}
serialize_by_alias: bool = False
serialize_by_alias: bool = Sentinel.MISSING
namedtuple_as_dict: bool = False
allow_postponed_evaluation: bool = True
dialect: Optional[Type[Dialect]] = None
Expand Down
68 changes: 49 additions & 19 deletions mashumaro/core/meta/code/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,10 @@ def __init__(
encoder: typing.Optional[typing.Any] = None,
encoder_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None,
default_dialect: typing.Optional[typing.Type[Dialect]] = None,
attrs: typing.Any = None,
attrs_registry: typing.Optional[
typing.Dict[typing.Any, typing.Any]
] = None,
):
self.cls = cls
self.lines: CodeLines = CodeLines()
Expand All @@ -129,6 +133,15 @@ def __init__(
self.encoder = encoder
self.encoder_kwargs = encoder_kwargs or {}

if attrs is not None:
self.attrs = attrs
else:
self.attrs = cls
if attrs_registry is not None:
self.attrs_registry = attrs_registry
else:
self.attrs_registry = {}

def reset(self) -> None:
self.lines.reset()
self.globals = globals().copy()
Expand All @@ -145,6 +158,10 @@ def namespace(self) -> typing.Mapping[typing.Any, typing.Any]:
def annotations(self) -> typing.Dict[str, typing.Any]:
return self.namespace.get("__annotations__", {})

@property
def is_nailed(self) -> bool:
return self.attrs is self.cls

def __get_field_types(
self, recursive: bool = True, include_extras: bool = False
) -> typing.Dict[str, typing.Any]:
Expand Down Expand Up @@ -492,7 +509,7 @@ def add_unpack_method(self) -> None:
with self.indent(f"if not '{cache_name}' in cls.__dict__:"):
self.add_line(f"cls.{cache_name} = {{}}")

if self.dialect is None:
if self.dialect is None and self.is_nailed:
self.add_line("@classmethod")
self._add_unpack_method_definition(method_name)
with self.indent():
Expand All @@ -503,14 +520,7 @@ def add_unpack_method(self) -> None:
self._add_unpack_method_with_dialect_lines(method_name)
else:
self._add_unpack_method_lines(method_name)
if self.dialect is None:
self.add_line(f"setattr(cls, '{method_name}', {method_name})")
if is_dataclass_dict_mixin_subclass(self.cls):
self.add_line(
f"setattr(cls, '{method_name.public}', {method_name})"
)
else:
self.add_line(f"cls.{cache_name}[dialect] = {method_name}")
self._add_setattr_method(method_name, cache_name)
self.compile()

def _add_unpack_method_definition(self, method_name: str) -> None:
Expand All @@ -520,7 +530,11 @@ def _add_unpack_method_definition(self, method_name: str) -> None:
)
if default_kwargs:
kwargs += f", {default_kwargs}"
self.add_line(f"def {method_name}(cls, d{kwargs}):")

if self.is_nailed:
self.add_line(f"def {method_name}(cls, d{kwargs}):")
else:
self.add_line(f"def {method_name}(d{kwargs}):")

def _unpack_method_set_value(
self,
Expand Down Expand Up @@ -708,7 +722,9 @@ def get_pack_method_default_flag_values(
TO_DICT_ADD_BY_ALIAS_FLAG, cls
)
if by_alias_feature:
serialize_by_alias = self.get_config(cls).serialize_by_alias
serialize_by_alias = self._get_dialect_or_config_option(
"serialize_by_alias", False, cls
)
kw_param_names.append("by_alias")
kw_param_values.append("True" if serialize_by_alias else "False")

Expand Down Expand Up @@ -861,7 +877,9 @@ def _add_pack_method_lines(self, method_name: str) -> None:
omit_none_feature = self.is_code_generation_option_enabled(
TO_DICT_ADD_OMIT_NONE_FLAG
)
serialize_by_alias = self.get_config().serialize_by_alias
serialize_by_alias = self._get_dialect_or_config_option(
"serialize_by_alias", False
)
omit_none = self._get_dialect_or_config_option("omit_none", False)
omit_default = self._get_dialect_or_config_option(
"omit_default", False
Expand Down Expand Up @@ -1053,7 +1071,9 @@ def __pack_method_set_value(
with self.indent("else:"):
self.add_line(f"kwargs['{fname}'] = {packed_value}")
else:
serialize_by_alias = self.get_config().serialize_by_alias
serialize_by_alias = self._get_dialect_or_config_option(
"serialize_by_alias", False
)
if serialize_by_alias and alias is not None:
fname_or_alias = alias
else:
Expand Down Expand Up @@ -1136,15 +1156,25 @@ def add_pack_method(self) -> None:
self._add_pack_method_with_dialect_lines(method_name)
else:
self._add_pack_method_lines(method_name)
self._add_setattr_method(method_name, cache_name)
self.compile()

def _add_setattr_method(
self, method_name: InternalMethodName, cache_name: str
) -> None:
if self.dialect is None:
self.add_line(f"setattr(cls, '{method_name}', {method_name})")
if is_dataclass_dict_mixin_subclass(self.cls):
self.add_line(
f"setattr(cls, '{method_name.public}', {method_name})"
)
if not self.is_nailed:
self.ensure_object_imported(self.attrs, "_cls")
self.ensure_object_imported(self.cls, "cls")
self.add_line(f"setattr(_cls, '{method_name}', {method_name})")
else:
self.add_line(f"setattr(cls, '{method_name}', {method_name})")
if is_dataclass_dict_mixin_subclass(self.cls):
self.add_line(
f"setattr(cls, '{method_name.public}', {method_name})"
)
else:
self.add_line(f"cls.{cache_name}[dialect] = {method_name}")
self.compile()

def _get_field_packer(
self,
Expand Down
3 changes: 2 additions & 1 deletion mashumaro/core/meta/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@
"hash_type_args",
"iter_all_subclasses",
"is_hashable",
"is_hashable_type",
"evaluate_forward_ref",
]

Expand Down Expand Up @@ -167,7 +168,7 @@ def _get_literal_values_str(typ: Type, short: bool) -> str:
for value in get_literal_values(typ):
if isinstance(value, enum.Enum):
values_str.append(f"{type_name(type(value), short)}.{value.name}")
elif isinstance( # type: ignore
elif isinstance(
value, (int, str, bytes, bool, NoneType) # type: ignore
):
values_str.append(repr(value))
Expand Down
Loading

0 comments on commit 91a34c4

Please sign in to comment.