Skip to content

Commit

Permalink
Make Union deserialization algorithm more robust
Browse files Browse the repository at this point in the history
  • Loading branch information
Fatal1ty committed Nov 11, 2024
1 parent d93d8cf commit 872f668
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 32 deletions.
6 changes: 1 addition & 5 deletions benchmark/libs/mashumaro/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,13 @@
import pyperf

from benchmark.common import AbstractBenchmark
from mashumaro import field_options, pass_through
from mashumaro import field_options
from mashumaro.codecs import BasicDecoder, BasicEncoder
from mashumaro.dialect import Dialect


class DefaultDialect(Dialect):
serialize_by_alias = True
serialization_strategy = {
str: {"deserialize": str, "serialize": pass_through},
int: {"serialize": pass_through},
}


class IssueState(Enum):
Expand Down
2 changes: 1 addition & 1 deletion mashumaro/core/meta/code/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,7 @@ def _add_unpack_method_lines(self, method_name: str) -> None:
self.add_type_modules(ftype)
metadata = self.metadatas.get(fname, {})
field_block = FieldUnpackerCodeBlockBuilder(
self, self.lines.branch_off()
self, CodeLines()
).build(
fname=fname,
ftype=ftype,
Expand Down
8 changes: 2 additions & 6 deletions mashumaro/core/meta/code/lines.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ def append(self, line: str) -> None:
self._lines.append(f"{self._current_indent}{line}")

def extend(self, lines: "CodeLines") -> None:
self._lines.extend(lines._lines)
for line in lines._lines:
self._lines.append(f"{self._current_indent}{line}")

@contextmanager
def indent(
Expand All @@ -34,8 +35,3 @@ def as_text(self) -> str:
def reset(self) -> None:
self._lines = []
self._current_indent = ""

def branch_off(self) -> "CodeLines":
branch = CodeLines()
branch._current_indent = self._current_indent
return branch
7 changes: 6 additions & 1 deletion mashumaro/core/meta/types/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
Sequence,
Type,
TypeVar,
Union,
)

from typing_extensions import ParamSpec, TypeAlias
Expand All @@ -39,8 +40,12 @@
CodeBuilder = Any


class TypeMatchEligibleExpression(str):
pass


NoneType = type(None)
Expression: TypeAlias = str
Expression: TypeAlias = Union[str, TypeMatchEligibleExpression]

P = ParamSpec("P")
T = TypeVar("T")
Expand Down
59 changes: 40 additions & 19 deletions mashumaro/core/meta/types/unpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
ExpressionWrapper,
NoneType,
Registry,
TypeMatchEligibleExpression,
ValueSpec,
clean_id,
ensure_generic_collection,
Expand Down Expand Up @@ -170,28 +171,40 @@ def get_method_prefix(self) -> str:
return "union"

def _add_body(self, spec: ValueSpec, lines: CodeLines) -> None:
ambiguous_unpacker_types = []
unpackers = set()
fallback_unpackers = []
for type_arg in self.union_args:
condition = ""
unpacker = UnpackerRegistry.get(
spec.copy(type=type_arg, expression="value")
)
if type_arg in (bool, str) and unpacker == "value":
ambiguous_unpacker_types.append(type_arg)
if unpacker in unpackers:
do_try = unpacker != "value"
unpacker_block = CodeLines()
if isinstance(unpacker, TypeMatchEligibleExpression):
do_try = False
condition = f"type(value) is {type_arg.__name__}"
if (condition, unpacker) in unpackers:
continue
with unpacker_block.indent(f"if {condition}:"):
unpacker_block.append("return value")
if (condition, unpacker) not in unpackers:
fallback_unpackers.append(unpacker)
elif (condition, unpacker) in unpackers:
continue
else:
unpacker_block.append(f"return {unpacker}")

if do_try:
with lines.indent("try:"):
lines.extend(unpacker_block)
lines.append("except Exception: pass")
else:
lines.extend(unpacker_block)
unpackers.add((condition, unpacker))
for fallback_unpacker in fallback_unpackers:
with lines.indent("try:"):
lines.append(f"return {unpacker}")
lines.append(f"return {fallback_unpacker}")
lines.append("except Exception: pass")
unpackers.add(unpacker)
# if len(ambiguous_unpacker_types) >= 2:
# warnings.warn(
# f"{type_name(spec.builder.cls)}.{spec.field_ctx.name} "
# f"({type_name(spec.type)}): "
# "In the next release, data marked with Union type "
# "containing 'str' and 'bool' will be coerced to the value "
# "of the type specified first instead of passing it as is"
# )
field_type = spec.builder.get_type_name_identifier(
typ=spec.type,
resolved_type_params=spec.builder.get_field_resolved_type_params(
Expand Down Expand Up @@ -843,13 +856,21 @@ def unpack_special_typing_primitive(spec: ValueSpec) -> Optional[Expression]:
@register
def unpack_number(spec: ValueSpec) -> Optional[Expression]:
if spec.origin_type in (int, float):
return f"{type_name(spec.origin_type)}({spec.expression})"
return TypeMatchEligibleExpression(
f"{type_name(spec.origin_type)}({spec.expression})"
)


@register
def unpack_bool_and_none(spec: ValueSpec) -> Optional[Expression]:
if spec.origin_type in (bool, NoneType, None):
return spec.expression
def unpack_bool(spec: ValueSpec) -> Optional[Expression]:
if spec.origin_type is bool:
return TypeMatchEligibleExpression(f"bool({spec.expression})")


@register
def unpack_none(spec: ValueSpec) -> Optional[Expression]:
if spec.origin_type in (NoneType, None):
return TypeMatchEligibleExpression("None")


@register
Expand Down Expand Up @@ -1199,7 +1220,7 @@ def inner_expr(
spec.builder.ensure_object_imported(decodebytes)
return f"bytearray(decodebytes({spec.expression}.encode()))"
elif issubclass(spec.origin_type, str):
return spec.expression
return TypeMatchEligibleExpression(f"str({spec.expression})")
elif ensure_generic_collection_subclass(spec, List):
return f"[{inner_expr()} for value in {spec.expression}]"
elif ensure_generic_collection_subclass(spec, typing.Deque):
Expand Down

0 comments on commit 872f668

Please sign in to comment.