From 7a9418356082092d2cb1585acb816b2074cff43e Mon Sep 17 00:00:00 2001 From: Ilya Priven Date: Fri, 14 Jul 2023 02:58:23 -0700 Subject: [PATCH] Fix dataclass/protocol crash on joining types (#15629) The root cause is hacky creation of incomplete symbols; instead switching to `add_method_to_class` which does the necessary housekeeping. Fixes #15618. --- mypy/checker.py | 7 +- mypy/plugins/dataclasses.py | 137 +++++++++++--------------- test-data/unit/check-dataclasses.test | 23 +++++ test-data/unit/deps.test | 4 +- 4 files changed, 84 insertions(+), 87 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index e2ff8a6ec2a4..f2873c7d58e4 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -1200,9 +1200,10 @@ def check_func_def( elif isinstance(arg_type, TypeVarType): # Refuse covariant parameter type variables # TODO: check recursively for inner type variables - if arg_type.variance == COVARIANT and defn.name not in ( - "__init__", - "__new__", + if ( + arg_type.variance == COVARIANT + and defn.name not in ("__init__", "__new__", "__post_init__") + and not is_private(defn.name) # private methods are not inherited ): ctx: Context = arg_type if ctx.line < 0: diff --git a/mypy/plugins/dataclasses.py b/mypy/plugins/dataclasses.py index b1dc016a0279..a4babe7faf61 100644 --- a/mypy/plugins/dataclasses.py +++ b/mypy/plugins/dataclasses.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Final, Iterator +from typing import TYPE_CHECKING, Final, Iterator, Literal from mypy import errorcodes, message_registry from mypy.expandtype import expand_type, expand_type_by_instance @@ -86,7 +86,7 @@ field_specifiers=("dataclasses.Field", "dataclasses.field"), ) _INTERNAL_REPLACE_SYM_NAME: Final = "__mypy-replace" -_INTERNAL_POST_INIT_SYM_NAME: Final = "__mypy-__post_init__" +_INTERNAL_POST_INIT_SYM_NAME: Final = "__mypy-post_init" class DataclassAttribute: @@ -118,14 +118,33 @@ def __init__( self.is_neither_frozen_nor_nonfrozen = is_neither_frozen_nor_nonfrozen self._api = api - def to_argument(self, current_info: TypeInfo) -> Argument: - arg_kind = ARG_POS - if self.kw_only and self.has_default: - arg_kind = ARG_NAMED_OPT - elif self.kw_only and not self.has_default: - arg_kind = ARG_NAMED - elif not self.kw_only and self.has_default: - arg_kind = ARG_OPT + def to_argument( + self, current_info: TypeInfo, *, of: Literal["__init__", "replace", "__post_init__"] + ) -> Argument: + if of == "__init__": + arg_kind = ARG_POS + if self.kw_only and self.has_default: + arg_kind = ARG_NAMED_OPT + elif self.kw_only and not self.has_default: + arg_kind = ARG_NAMED + elif not self.kw_only and self.has_default: + arg_kind = ARG_OPT + elif of == "replace": + arg_kind = ARG_NAMED if self.is_init_var and not self.has_default else ARG_NAMED_OPT + elif of == "__post_init__": + # We always use `ARG_POS` without a default value, because it is practical. + # Consider this case: + # + # @dataclass + # class My: + # y: dataclasses.InitVar[str] = 'a' + # def __post_init__(self, y: str) -> None: ... + # + # We would be *required* to specify `y: str = ...` if default is added here. + # But, most people won't care about adding default values to `__post_init__`, + # because it is not designed to be called directly, and duplicating default values + # for the sake of type-checking is unpleasant. + arg_kind = ARG_POS return Argument( variable=self.to_var(current_info), type_annotation=self.expand_type(current_info), @@ -236,7 +255,7 @@ def transform(self) -> bool: and attributes ): args = [ - attr.to_argument(info) + attr.to_argument(info, of="__init__") for attr in attributes if attr.is_in_init and not self._is_kw_only_type(attr.type) ] @@ -375,70 +394,26 @@ def _add_internal_replace_method(self, attributes: list[DataclassAttribute]) -> Stashes the signature of 'dataclasses.replace(...)' for this specific dataclass to be used later whenever 'dataclasses.replace' is called for this dataclass. """ - arg_types: list[Type] = [] - arg_kinds = [] - arg_names: list[str | None] = [] - - info = self._cls.info - for attr in attributes: - attr_type = attr.expand_type(info) - assert attr_type is not None - arg_types.append(attr_type) - arg_kinds.append( - ARG_NAMED if attr.is_init_var and not attr.has_default else ARG_NAMED_OPT - ) - arg_names.append(attr.name) - - signature = CallableType( - arg_types=arg_types, - arg_kinds=arg_kinds, - arg_names=arg_names, - ret_type=NoneType(), - fallback=self._api.named_type("builtins.function"), - ) - - info.names[_INTERNAL_REPLACE_SYM_NAME] = SymbolTableNode( - kind=MDEF, node=FuncDef(typ=signature), plugin_generated=True + add_method_to_class( + self._api, + self._cls, + _INTERNAL_REPLACE_SYM_NAME, + args=[attr.to_argument(self._cls.info, of="replace") for attr in attributes], + return_type=NoneType(), + is_staticmethod=True, ) def _add_internal_post_init_method(self, attributes: list[DataclassAttribute]) -> None: - arg_types: list[Type] = [fill_typevars(self._cls.info)] - arg_kinds = [ARG_POS] - arg_names: list[str | None] = ["self"] - - info = self._cls.info - for attr in attributes: - if not attr.is_init_var: - continue - attr_type = attr.expand_type(info) - assert attr_type is not None - arg_types.append(attr_type) - # We always use `ARG_POS` without a default value, because it is practical. - # Consider this case: - # - # @dataclass - # class My: - # y: dataclasses.InitVar[str] = 'a' - # def __post_init__(self, y: str) -> None: ... - # - # We would be *required* to specify `y: str = ...` if default is added here. - # But, most people won't care about adding default values to `__post_init__`, - # because it is not designed to be called directly, and duplicating default values - # for the sake of type-checking is unpleasant. - arg_kinds.append(ARG_POS) - arg_names.append(attr.name) - - signature = CallableType( - arg_types=arg_types, - arg_kinds=arg_kinds, - arg_names=arg_names, - ret_type=NoneType(), - fallback=self._api.named_type("builtins.function"), - name="__post_init__", - ) - - info.names[_INTERNAL_POST_INIT_SYM_NAME] = SymbolTableNode( - kind=MDEF, node=FuncDef(typ=signature), plugin_generated=True + add_method_to_class( + self._api, + self._cls, + _INTERNAL_POST_INIT_SYM_NAME, + args=[ + attr.to_argument(self._cls.info, of="__post_init__") + for attr in attributes + if attr.is_init_var + ], + return_type=NoneType(), ) def add_slots( @@ -1120,20 +1095,18 @@ def is_processed_dataclass(info: TypeInfo | None) -> bool: def check_post_init(api: TypeChecker, defn: FuncItem, info: TypeInfo) -> None: if defn.type is None: return - - ideal_sig = info.get_method(_INTERNAL_POST_INIT_SYM_NAME) - if ideal_sig is None or ideal_sig.type is None: - return - - # We set it ourself, so it is always fine: - assert isinstance(ideal_sig.type, ProperType) - assert isinstance(ideal_sig.type, FunctionLike) - # Type of `FuncItem` is always `FunctionLike`: assert isinstance(defn.type, FunctionLike) + ideal_sig_method = info.get_method(_INTERNAL_POST_INIT_SYM_NAME) + assert ideal_sig_method is not None and ideal_sig_method.type is not None + ideal_sig = ideal_sig_method.type + assert isinstance(ideal_sig, ProperType) # we set it ourselves + assert isinstance(ideal_sig, CallableType) + ideal_sig = ideal_sig.copy_modified(name="__post_init__") + api.check_override( override=defn.type, - original=ideal_sig.type, + original=ideal_sig, name="__post_init__", name_in_super="__post_init__", supertype="dataclass", diff --git a/test-data/unit/check-dataclasses.test b/test-data/unit/check-dataclasses.test index 131521aa98e4..adcaa60a5b19 100644 --- a/test-data/unit/check-dataclasses.test +++ b/test-data/unit/check-dataclasses.test @@ -744,6 +744,17 @@ s: str = a.bar() # E: Incompatible types in assignment (expression has type "in [builtins fixtures/dataclasses.pyi] +[case testDataclassGenericCovariant] +from dataclasses import dataclass +from typing import Generic, TypeVar + +T_co = TypeVar("T_co", covariant=True) + +@dataclass +class MyDataclass(Generic[T_co]): + a: T_co + +[builtins fixtures/dataclasses.pyi] [case testDataclassUntypedGenericInheritance] # flags: --python-version 3.7 @@ -2449,3 +2460,15 @@ class Test(Protocol): def reset(self) -> None: self.x = DEFAULT [builtins fixtures/dataclasses.pyi] + +[case testProtocolNoCrashOnJoining] +from dataclasses import dataclass +from typing import Protocol + +@dataclass +class MyDataclass(Protocol): ... + +a: MyDataclass +b = [a, a] # trigger joining the types + +[builtins fixtures/dataclasses.pyi] diff --git a/test-data/unit/deps.test b/test-data/unit/deps.test index fe5107b1529d..b43a2ace5eed 100644 --- a/test-data/unit/deps.test +++ b/test-data/unit/deps.test @@ -1388,7 +1388,7 @@ class B(A): -> , m -> -> , m.B.__init__ - -> + -> , m.B.__mypy-replace -> -> -> @@ -1420,7 +1420,7 @@ class B(A): -> -> , m.B.__init__ -> - -> + -> , m.B.__mypy-replace -> -> ->