Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
ikonst committed Jul 9, 2023
1 parent ebfea94 commit 02c7e03
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 81 deletions.
133 changes: 52 additions & 81 deletions mypy/plugins/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from __future__ import annotations

from typing import TYPE_CHECKING, Final, Iterator
from typing import TYPE_CHECKING, Final, Iterator, Literal

from mypy import errorcodes, message_registry
from mypy.expandtype import expand_type, expand_type_by_instance
Expand Down Expand Up @@ -86,7 +86,7 @@
field_specifiers=("dataclasses.Field", "dataclasses.field"),
)
_INTERNAL_REPLACE_SYM_NAME: Final = "__mypy-replace"
_INTERNAL_POST_INIT_SYM_NAME: Final = "__mypy-__post_init__"
_INTERNAL_POST_INIT_SYM_NAME: Final = "__mypy-post_init"


class DataclassAttribute:
Expand Down Expand Up @@ -118,14 +118,31 @@ def __init__(
self.is_neither_frozen_nor_nonfrozen = is_neither_frozen_nor_nonfrozen
self._api = api

def to_argument(self, current_info: TypeInfo) -> Argument:
def to_argument(self, current_info: TypeInfo, *, of: Literal['__init__', 'replace', '__post_init__']) -> Argument:
arg_kind = ARG_POS
if self.kw_only and self.has_default:
arg_kind = ARG_NAMED_OPT
elif self.kw_only and not self.has_default:
arg_kind = ARG_NAMED
elif not self.kw_only and self.has_default:
arg_kind = ARG_OPT
if of == '__init__':
if self.kw_only and self.has_default:
arg_kind = ARG_NAMED_OPT
elif self.kw_only and not self.has_default:
arg_kind = ARG_NAMED
elif not self.kw_only and self.has_default:
arg_kind = ARG_OPT
elif of == 'replace':
arg_kind = ARG_NAMED if self.is_init_var and not self.has_default else ARG_NAMED_OPT
elif of == '__post_init__':
# We always use `ARG_POS` without a default value, because it is practical.
# Consider this case:
#
# @dataclass
# class My:
# y: dataclasses.InitVar[str] = 'a'
# def __post_init__(self, y: str) -> None: ...
#
# We would be *required* to specify `y: str = ...` if default is added here.
# But, most people won't care about adding default values to `__post_init__`,
# because it is not designed to be called directly, and duplicating default values
# for the sake of type-checking is unpleasant.
arg_kind = ARG_POS
return Argument(
variable=self.to_var(current_info),
type_annotation=self.expand_type(current_info),
Expand Down Expand Up @@ -236,7 +253,7 @@ def transform(self) -> bool:
and attributes
):
args = [
attr.to_argument(info)
attr.to_argument(info, of='__init__')
for attr in attributes
if attr.is_in_init and not self._is_kw_only_type(attr.type)
]
Expand Down Expand Up @@ -375,70 +392,26 @@ def _add_internal_replace_method(self, attributes: list[DataclassAttribute]) ->
Stashes the signature of 'dataclasses.replace(...)' for this specific dataclass
to be used later whenever 'dataclasses.replace' is called for this dataclass.
"""
arg_types: list[Type] = []
arg_kinds = []
arg_names: list[str | None] = []

info = self._cls.info
for attr in attributes:
attr_type = attr.expand_type(info)
assert attr_type is not None
arg_types.append(attr_type)
arg_kinds.append(
ARG_NAMED if attr.is_init_var and not attr.has_default else ARG_NAMED_OPT
)
arg_names.append(attr.name)

signature = CallableType(
arg_types=arg_types,
arg_kinds=arg_kinds,
arg_names=arg_names,
ret_type=NoneType(),
fallback=self._api.named_type("builtins.function"),
)

info.names[_INTERNAL_REPLACE_SYM_NAME] = SymbolTableNode(
kind=MDEF, node=FuncDef(typ=signature), plugin_generated=True
add_method_to_class(
self._api,
self._cls,
_INTERNAL_REPLACE_SYM_NAME,
args=[attr.to_argument(self._cls.info, of="replace") for attr in attributes],
return_type=NoneType(),
is_staticmethod=True,
)

def _add_internal_post_init_method(self, attributes: list[DataclassAttribute]) -> None:
arg_types: list[Type] = [fill_typevars(self._cls.info)]
arg_kinds = [ARG_POS]
arg_names: list[str | None] = ["self"]

info = self._cls.info
for attr in attributes:
if not attr.is_init_var:
continue
attr_type = attr.expand_type(info)
assert attr_type is not None
arg_types.append(attr_type)
# We always use `ARG_POS` without a default value, because it is practical.
# Consider this case:
#
# @dataclass
# class My:
# y: dataclasses.InitVar[str] = 'a'
# def __post_init__(self, y: str) -> None: ...
#
# We would be *required* to specify `y: str = ...` if default is added here.
# But, most people won't care about adding default values to `__post_init__`,
# because it is not designed to be called directly, and duplicating default values
# for the sake of type-checking is unpleasant.
arg_kinds.append(ARG_POS)
arg_names.append(attr.name)

signature = CallableType(
arg_types=arg_types,
arg_kinds=arg_kinds,
arg_names=arg_names,
ret_type=NoneType(),
fallback=self._api.named_type("builtins.function"),
name="__post_init__",
)

info.names[_INTERNAL_POST_INIT_SYM_NAME] = SymbolTableNode(
kind=MDEF, node=FuncDef(typ=signature), plugin_generated=True
add_method_to_class(
self._api,
self._cls,
_INTERNAL_POST_INIT_SYM_NAME,
args=[
attr.to_argument(self._cls.info, of="__post_init__")
for attr in attributes
if attr.is_init_var
],
return_type=NoneType(),
)

def add_slots(
Expand Down Expand Up @@ -1113,20 +1086,18 @@ def is_processed_dataclass(info: TypeInfo | None) -> bool:
def check_post_init(api: TypeChecker, defn: FuncItem, info: TypeInfo) -> None:
if defn.type is None:
return

ideal_sig = info.get_method(_INTERNAL_POST_INIT_SYM_NAME)
if ideal_sig is None or ideal_sig.type is None:
return

# We set it ourself, so it is always fine:
assert isinstance(ideal_sig.type, ProperType)
assert isinstance(ideal_sig.type, FunctionLike)
# Type of `FuncItem` is always `FunctionLike`:
assert isinstance(defn.type, FunctionLike)

ideal_sig_method = info.get_method(_INTERNAL_POST_INIT_SYM_NAME)
assert ideal_sig_method is not None and ideal_sig_method.type is not None
ideal_sig = ideal_sig_method.type
assert isinstance(ideal_sig, ProperType) # we set it ourselves
assert isinstance(ideal_sig, CallableType)
ideal_sig = ideal_sig.copy_modified(name='__post_init__')

api.check_override(
override=defn.type,
original=ideal_sig.type,
original=ideal_sig,
name="__post_init__",
name_in_super="__post_init__",
supertype="dataclass",
Expand Down
12 changes: 12 additions & 0 deletions test-data/unit/check-dataclasses.test
Original file line number Diff line number Diff line change
Expand Up @@ -2420,3 +2420,15 @@ class Test(Protocol):
def reset(self) -> None:
self.x = DEFAULT
[builtins fixtures/dataclasses.pyi]

[case testProtocolNoCrashOnJoining]
from dataclasses import dataclass
from typing import Protocol

@dataclass
class MyDataclass(Protocol): ...

a: MyDataclass
b = [a, a] # trigger joining the types

[builtins fixtures/dataclasses.pyi]

0 comments on commit 02c7e03

Please sign in to comment.