Skip to content

Commit

Permalink
Support better __post_init__ method signature for dataclasses (#1…
Browse files Browse the repository at this point in the history
…5503)

Now we use a similar approach to
#14849
First, we generate a private name to store in a metadata (with `-`, so -
no conflicts, ever).
Next, we check override to be compatible: we take the currect signature
and compare it to the ideal one we have.

Simple and it works :)

Closes #15498
Closes #9254

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Ivan Levkivskyi <[email protected]>
  • Loading branch information
3 people committed Jun 26, 2023
1 parent 9ad3f38 commit 9511daa
Show file tree
Hide file tree
Showing 6 changed files with 318 additions and 17 deletions.
9 changes: 8 additions & 1 deletion mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@
from mypy.options import Options
from mypy.patterns import AsPattern, StarredPattern
from mypy.plugin import CheckerPluginInterface, Plugin
from mypy.plugins import dataclasses as dataclasses_plugin
from mypy.scope import Scope
from mypy.semanal import is_trivial_body, refers_to_fullname, set_callable_name
from mypy.semanal_enum import ENUM_BASES, ENUM_SPECIAL_PROPS
Expand Down Expand Up @@ -1044,6 +1045,9 @@ def check_func_item(

if name == "__exit__":
self.check__exit__return_type(defn)
if name == "__post_init__":
if dataclasses_plugin.is_processed_dataclass(defn.info):
dataclasses_plugin.check_post_init(self, defn, defn.info)

@contextmanager
def enter_attribute_inference_context(self) -> Iterator[None]:
Expand Down Expand Up @@ -1851,7 +1855,7 @@ def check_method_or_accessor_override_for_base(
found_base_method = True

# Check the type of override.
if name not in ("__init__", "__new__", "__init_subclass__"):
if name not in ("__init__", "__new__", "__init_subclass__", "__post_init__"):
# Check method override
# (__init__, __new__, __init_subclass__ are special).
if self.check_method_override_for_base_with_name(defn, name, base):
Expand Down Expand Up @@ -2812,6 +2816,9 @@ def check_assignment(
if name == "__match_args__" and inferred is not None:
typ = self.expr_checker.accept(rvalue)
self.check_match_args(inferred, typ, lvalue)
if name == "__post_init__":
if dataclasses_plugin.is_processed_dataclass(self.scope.active_class()):
self.fail(message_registry.DATACLASS_POST_INIT_MUST_BE_A_FUNCTION, rvalue)

# Defer PartialType's super type checking.
if (
Expand Down
1 change: 1 addition & 0 deletions mypy/message_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,7 @@ def with_additional_msg(self, info: str) -> ErrorMessage:
DATACLASS_FIELD_ALIAS_MUST_BE_LITERAL: Final = (
'"alias" argument to dataclass field must be a string literal'
)
DATACLASS_POST_INIT_MUST_BE_A_FUNCTION: Final = '"__post_init__" method must be an instance method'

# fastparse
FAILED_TO_MERGE_OVERLOADS: Final = ErrorMessage(
Expand Down
27 changes: 15 additions & 12 deletions mypy/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -1253,18 +1253,21 @@ def argument_incompatible_with_supertype(
code=codes.OVERRIDE,
secondary_context=secondary_context,
)
self.note(
"This violates the Liskov substitution principle",
context,
code=codes.OVERRIDE,
secondary_context=secondary_context,
)
self.note(
"See https://mypy.readthedocs.io/en/stable/common_issues.html#incompatible-overrides",
context,
code=codes.OVERRIDE,
secondary_context=secondary_context,
)
if name != "__post_init__":
# `__post_init__` is special, it can be incompatible by design.
# So, this note is misleading.
self.note(
"This violates the Liskov substitution principle",
context,
code=codes.OVERRIDE,
secondary_context=secondary_context,
)
self.note(
"See https://mypy.readthedocs.io/en/stable/common_issues.html#incompatible-overrides",
context,
code=codes.OVERRIDE,
secondary_context=secondary_context,
)

if name == "__eq__" and type_name:
multiline_msg = self.comparison_method_example_msg(class_name=type_name)
Expand Down
86 changes: 82 additions & 4 deletions mypy/plugins/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from __future__ import annotations

from typing import Iterator, Optional
from typing import TYPE_CHECKING, Iterator, Optional
from typing_extensions import Final

from mypy import errorcodes, message_registry
Expand All @@ -26,6 +26,7 @@
DataclassTransformSpec,
Expression,
FuncDef,
FuncItem,
IfStmt,
JsonDict,
NameExpr,
Expand Down Expand Up @@ -55,6 +56,7 @@
from mypy.types import (
AnyType,
CallableType,
FunctionLike,
Instance,
LiteralType,
NoneType,
Expand All @@ -69,19 +71,23 @@
)
from mypy.typevars import fill_typevars

if TYPE_CHECKING:
from mypy.checker import TypeChecker

# The set of decorators that generate dataclasses.
dataclass_makers: Final = {"dataclass", "dataclasses.dataclass"}


SELF_TVAR_NAME: Final = "_DT"
_TRANSFORM_SPEC_FOR_DATACLASSES = DataclassTransformSpec(
_TRANSFORM_SPEC_FOR_DATACLASSES: Final = DataclassTransformSpec(
eq_default=True,
order_default=False,
kw_only_default=False,
frozen_default=False,
field_specifiers=("dataclasses.Field", "dataclasses.field"),
)
_INTERNAL_REPLACE_SYM_NAME = "__mypy-replace"
_INTERNAL_REPLACE_SYM_NAME: Final = "__mypy-replace"
_INTERNAL_POST_INIT_SYM_NAME: Final = "__mypy-__post_init__"


class DataclassAttribute:
Expand Down Expand Up @@ -350,6 +356,8 @@ def transform(self) -> bool:

if self._spec is _TRANSFORM_SPEC_FOR_DATACLASSES:
self._add_internal_replace_method(attributes)
if "__post_init__" in info.names:
self._add_internal_post_init_method(attributes)

info.metadata["dataclass"] = {
"attributes": [attr.serialize() for attr in attributes],
Expand Down Expand Up @@ -385,7 +393,47 @@ def _add_internal_replace_method(self, attributes: list[DataclassAttribute]) ->
fallback=self._api.named_type("builtins.function"),
)

self._cls.info.names[_INTERNAL_REPLACE_SYM_NAME] = SymbolTableNode(
info.names[_INTERNAL_REPLACE_SYM_NAME] = SymbolTableNode(
kind=MDEF, node=FuncDef(typ=signature), plugin_generated=True
)

def _add_internal_post_init_method(self, attributes: list[DataclassAttribute]) -> None:
arg_types: list[Type] = [fill_typevars(self._cls.info)]
arg_kinds = [ARG_POS]
arg_names: list[str | None] = ["self"]

info = self._cls.info
for attr in attributes:
if not attr.is_init_var:
continue
attr_type = attr.expand_type(info)
assert attr_type is not None
arg_types.append(attr_type)
# We always use `ARG_POS` without a default value, because it is practical.
# Consider this case:
#
# @dataclass
# class My:
# y: dataclasses.InitVar[str] = 'a'
# def __post_init__(self, y: str) -> None: ...
#
# We would be *required* to specify `y: str = ...` if default is added here.
# But, most people won't care about adding default values to `__post_init__`,
# because it is not designed to be called directly, and duplicating default values
# for the sake of type-checking is unpleasant.
arg_kinds.append(ARG_POS)
arg_names.append(attr.name)

signature = CallableType(
arg_types=arg_types,
arg_kinds=arg_kinds,
arg_names=arg_names,
ret_type=NoneType(),
fallback=self._api.named_type("builtins.function"),
name="__post_init__",
)

info.names[_INTERNAL_POST_INIT_SYM_NAME] = SymbolTableNode(
kind=MDEF, node=FuncDef(typ=signature), plugin_generated=True
)

Expand Down Expand Up @@ -1052,3 +1100,33 @@ def replace_function_sig_callback(ctx: FunctionSigContext) -> CallableType:
fallback=ctx.default_signature.fallback,
name=f"{ctx.default_signature.name} of {inst_type_str}",
)


def is_processed_dataclass(info: TypeInfo | None) -> bool:
return info is not None and "dataclass" in info.metadata


def check_post_init(api: TypeChecker, defn: FuncItem, info: TypeInfo) -> None:
if defn.type is None:
return

ideal_sig = info.get_method(_INTERNAL_POST_INIT_SYM_NAME)
if ideal_sig is None or ideal_sig.type is None:
return

# We set it ourself, so it is always fine:
assert isinstance(ideal_sig.type, ProperType)
assert isinstance(ideal_sig.type, FunctionLike)
# Type of `FuncItem` is always `FunctionLike`:
assert isinstance(defn.type, FunctionLike)

api.check_override(
override=defn.type,
original=ideal_sig.type,
name="__post_init__",
name_in_super="__post_init__",
supertype="dataclass",
original_class_or_static=False,
override_class_or_static=False,
node=defn,
)
Loading

0 comments on commit 9511daa

Please sign in to comment.