Skip to content

Commit

Permalink
Merge branch 'pep-604'
Browse files Browse the repository at this point in the history
  • Loading branch information
Fatal1ty committed Jan 4, 2022
2 parents 705de2f + 870fbcf commit 3e47d46
Show file tree
Hide file tree
Showing 5 changed files with 158 additions and 16 deletions.
7 changes: 5 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,10 @@ for special primitives from the [`typing`](https://docs.python.org/3/library/typ
* [`Union`](https://docs.python.org/3/library/typing.html#typing.Union)
* [`TypeVar`](https://docs.python.org/3/library/typing.html#typing.TypeVar)

for standard interpreter types from [`types`](https://docs.python.org/3/library/types.html#standard-interpreter-types) module:
* [`NoneType`](https://docs.python.org/3/library/types.html#types.NoneType)
* [`UnionType`](https://docs.python.org/3/library/types.html#types.UnionType)

for enumerations based on classes from the standard [`enum`](https://docs.python.org/3/library/enum.html) module:
* [`Enum`](https://docs.python.org/3/library/enum.html#enum.Enum)
* [`IntEnum`](https://docs.python.org/3/library/enum.html#enum.IntEnum)
Expand Down Expand Up @@ -163,8 +167,7 @@ for other less popular built-in types:
* [`ipaddress.IPv4Interface`](https://docs.python.org/3/library/ipaddress.html#ipaddress.IPv4Interface)
* [`ipaddress.IPv6Interface`](https://docs.python.org/3/library/ipaddress.html#ipaddress.IPv6Interface)

for specific types like:
* [`NoneType`](https://docs.python.org/3/library/constants.html#None)
for arbitrary types:
* [user-defined classes](#serializabletype-interface)
* [user-defined generic types](#user-defined-generic-types)

Expand Down
43 changes: 36 additions & 7 deletions mashumaro/meta/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,14 @@ def _get_args_str(
short: bool,
type_vars: typing.Dict[str, typing.Any] = None,
limit: typing.Optional[int] = None,
none_type_as_none: bool = False,
sep: str = ", ",
):
args = get_args(t)[:limit]
return ", ".join(type_name(arg, short, type_vars) for arg in args)
return sep.join(
type_name(arg, short, type_vars, none_type_as_none=none_type_as_none)
for arg in args
)


def _typing_name(t: str, short: bool = False):
Expand All @@ -84,16 +89,19 @@ def type_name(
short: bool = False,
type_vars: typing.Dict[str, typing.Any] = None,
is_type_origin: bool = False,
none_type_as_none: bool = False,
) -> str:
if type_vars is None:
type_vars = {}
if t is typing.Any:
if t is NoneType and none_type_as_none:
return "None"
elif t is typing.Any:
return _typing_name("Any", short)
elif is_optional(t):
args_str = _get_args_str(t, short, type_vars, 1)
elif is_optional(t, type_vars):
args_str = type_name(not_none_type_arg(get_args(t), type_vars), short)
return f"{_typing_name('Optional', short)}[{args_str}]"
elif is_union(t):
args_str = _get_args_str(t, short, type_vars)
args_str = _get_args_str(t, short, type_vars, none_type_as_none=True)
return f"{_typing_name('Union', short)}[{args_str}]"
elif is_generic(t) and not is_type_origin:
args_str = _get_args_str(t, short, type_vars)
Expand Down Expand Up @@ -183,16 +191,36 @@ def is_named_tuple(t):

def is_union(t):
try:
if PY_310_MIN and isinstance(t, types.UnionType):
return True
return t.__origin__ is typing.Union
except AttributeError:
return False


def is_optional(t):
def is_optional(t, type_vars: typing.Dict[str, typing.Any] = None):
if type_vars is None:
type_vars = {}
if not is_union(t):
return False
args = get_args(t)
return len(args) == 2 and args[1] == NoneType
if len(args) != 2:
return False
for arg in args:
if type_vars.get(arg, arg) is NoneType:
return True
return False


def not_none_type_arg(
args: typing.Tuple[typing.Any, ...],
type_vars: typing.Dict[str, typing.Any] = None,
):
if type_vars is None:
type_vars = {}
for arg in args:
if type_vars.get(arg, arg) is not NoneType:
return arg


def is_type_var(t):
Expand Down Expand Up @@ -329,6 +357,7 @@ def is_dialect_subclass(t) -> bool:
"is_named_tuple",
"is_optional",
"is_union",
"not_none_type_arg",
"is_type_var",
"is_type_var_any",
"is_class_var",
Expand Down
17 changes: 11 additions & 6 deletions mashumaro/serializer/base/metaprogramming.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
is_type_var_any,
is_typed_dict,
is_union,
not_none_type_arg,
resolve_type_vars,
type_name,
)
Expand Down Expand Up @@ -767,9 +768,11 @@ def _pack_value(
if overridden:
return overridden
args = get_args(ftype)
if is_optional(ftype):
field_type_vars = self._get_field_type_vars(fname)
if is_optional(ftype, field_type_vars):
arg = not_none_type_arg(args, field_type_vars)
pv = self._pack_value(
fname, args[0], parent, value_name, metadata=metadata
fname, arg, parent, value_name, metadata=metadata
)
if could_be_none:
return f"{pv} if {value_name} is not None else None"
Expand Down Expand Up @@ -824,7 +827,7 @@ def _pack_value(
return overridden or f"int({value_name})"
elif origin_type is float:
return overridden or f"float({value_name})"
elif origin_type in (bool, NoneType):
elif origin_type in (bool, NoneType, None):
return overridden or value_name
elif origin_type in (datetime.datetime, datetime.date, datetime.time):
if overridden:
Expand Down Expand Up @@ -1116,9 +1119,11 @@ def _unpack_field_value(
if overridden:
return overridden
args = get_args(ftype)
if is_optional(ftype):
field_type_vars = self._get_field_type_vars(fname)
if is_optional(ftype, field_type_vars):
arg = not_none_type_arg(args, field_type_vars)
ufv = self._unpack_field_value(
fname, args[0], parent, value_name, metadata=metadata
fname, arg, parent, value_name, metadata=metadata
)
if could_be_none:
return f"{ufv} if {value_name} is not None else None"
Expand Down Expand Up @@ -1173,7 +1178,7 @@ def _unpack_field_value(
return overridden or f"int({value_name})"
elif origin_type is float:
return overridden or f"float({value_name})"
elif origin_type in (bool, NoneType):
elif origin_type in (bool, NoneType, None):
return overridden or value_name
elif origin_type in (datetime.datetime, datetime.date, datetime.time):
if overridden:
Expand Down
12 changes: 12 additions & 0 deletions tests/test_data_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1439,3 +1439,15 @@ class DataClass(DataClassDictMixin):
obj = DataClass(x=MyUntypedNamedTupleWithDefaults(i="1"))
assert DataClass.from_dict({"x": ["1"]}) == obj
assert obj.to_dict() == {"x": ["1", 2.0]}


def test_data_class_with_none():
@dataclass
class DataClass(DataClassDictMixin):
x: None
y: NoneType
z: List[None]

obj = DataClass(x=None, y=None, z=[None])
assert DataClass.from_dict({"x": None, "y": None, "z": [None]}) == obj
assert obj.to_dict() == {"x": None, "y": None, "z": [None]}
95 changes: 94 additions & 1 deletion tests/test_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from mashumaro import DataClassDictMixin, DataClassJSONMixin
from mashumaro.dialect import Dialect
from mashumaro.meta.helpers import (
get_args,
get_class_that_defines_field,
get_class_that_defines_method,
get_generic_name,
Expand All @@ -19,11 +20,20 @@
is_dialect_subclass,
is_generic,
is_init_var,
is_optional,
is_type_var_any,
is_union,
not_none_type_arg,
resolve_type_vars,
type_name,
)
from mashumaro.meta.macros import PEP_585_COMPATIBLE, PY_37, PY_37_MIN, PY_38
from mashumaro.meta.macros import (
PEP_585_COMPATIBLE,
PY_37,
PY_37_MIN,
PY_38,
PY_310_MIN,
)
from mashumaro.serializer.base.metaprogramming import CodeBuilder

from .entities import (
Expand All @@ -36,6 +46,8 @@
TIntStr,
)

NoneType = type(None)

TMyDataClass = typing.TypeVar("TMyDataClass", bound=MyDataClass)


Expand Down Expand Up @@ -177,7 +189,22 @@ def test_type_name():
== "typing.OrderedDict[int, int]"
)
assert type_name(typing.Optional[int]) == "typing.Optional[int]"
assert type_name(typing.Union[None, int]) == "typing.Optional[int]"
assert type_name(typing.Union[int, None]) == "typing.Optional[int]"
assert type_name(None) == "None"
assert type_name(NoneType) == "NoneType"
assert type_name(NoneType, none_type_as_none=True) == "None"
assert type_name(typing.List[NoneType]) == "typing.List[NoneType]"
assert (
type_name(typing.Union[int, str, None])
== "typing.Union[int, str, None]"
)
assert type_name(typing.Optional[NoneType]) == "NoneType"

if PY_310_MIN:
assert type_name(int | None) == "typing.Optional[int]"
assert type_name(None | int) == "typing.Optional[int]"
assert type_name(int | str) == "typing.Union[int, str]"


@pytest.mark.skipif(not PEP_585_COMPATIBLE, reason="requires python 3.9+")
Expand Down Expand Up @@ -251,7 +278,22 @@ def test_type_name_short():
== "OrderedDict[int, int]"
)
assert type_name(typing.Optional[int], short=True) == "Optional[int]"
assert type_name(typing.Union[None, int], short=True) == "Optional[int]"
assert type_name(typing.Union[int, None], short=True) == "Optional[int]"
assert type_name(None, short=True) == "None"
assert type_name(NoneType, short=True) == "NoneType"
assert type_name(NoneType, short=True, none_type_as_none=True) == "None"
assert type_name(typing.List[NoneType], short=True) == "List[NoneType]"
assert (
type_name(typing.Union[int, str, None], short=True)
== "Union[int, str, None]"
)
assert type_name(typing.Optional[NoneType], short=True) == "NoneType"

if PY_310_MIN:
assert type_name(int | None, short=True) == "Optional[int]"
assert type_name(None | int, short=True) == "Optional[int]"
assert type_name(int | str, short=True) == "Union[int, str]"


@pytest.mark.skipif(not PEP_585_COMPATIBLE, reason="requires python 3.9+")
Expand Down Expand Up @@ -339,3 +381,54 @@ class MyDialect(Dialect):
assert is_dialect_subclass(Dialect)
assert is_dialect_subclass(MyDialect)
assert not is_dialect_subclass(123)


def test_is_union():
t = typing.Optional[str]
assert is_union(t)
assert get_args(t) == (str, NoneType)
t = typing.Union[str, None]
assert is_union(t)
assert get_args(t) == (str, NoneType)
t = typing.Union[None, str]
assert is_union(t)
assert get_args(t) == (NoneType, str)


@pytest.mark.skipif(not PY_310_MIN, reason="requires python 3.10+")
def test_is_union_pep_604():
t = str | None
assert is_union(t)
assert get_args(t) == (str, NoneType)
t = None | str
assert is_union(t)
assert get_args(t) == (NoneType, str)


def test_is_optional():
t = typing.Optional[str]
assert is_optional(t)
assert get_args(t) == (str, NoneType)
t = typing.Union[str, None]
assert is_optional(t)
assert get_args(t) == (str, NoneType)
t = typing.Union[None, str]
assert is_optional(t)
assert get_args(t) == (NoneType, str)


@pytest.mark.skipif(not PY_310_MIN, reason="requires python 3.10+")
def test_is_optional_pep_604():
t = str | None
assert is_optional(t)
assert get_args(t) == (str, NoneType)
t = None | str
assert is_optional(t)
assert get_args(t) == (NoneType, str)


def test_not_non_type_arg():
assert not_none_type_arg((str, int)) == str
assert not_none_type_arg((NoneType, int)) == int
assert not_none_type_arg((str, NoneType)) == str
assert not_none_type_arg((T, int), {T: NoneType}) == int

0 comments on commit 3e47d46

Please sign in to comment.