From ba5c2793b1f9bd253c0415492dffb703eb523306 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Tue, 18 Jun 2024 00:45:00 +0100 Subject: [PATCH] Allow new-style self-types in classmethods (#17381) Fixes https://github.com/python/mypy/issues/16547 Fixes https://github.com/python/mypy/issues/16410 Fixes https://github.com/python/mypy/issues/5570 From the upvotes on the issue it looks like an important use case. From what I see this is an omission in the original implementation, I don't see any additional unsafety (except for the same that exists for instance methods/variables). I also incorporate a small refactoring and remove couple unused `get_proper_type()` calls. The fix uncovered an unrelated issue with unions in descriptors, so I fix that one as well. --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- mypy/checkexpr.py | 4 +- mypy/checkmember.py | 78 ++++++++++++++--------- test-data/unit/check-classes.test | 35 ++++++++++ test-data/unit/check-recursive-types.test | 2 +- test-data/unit/check-selftype.test | 61 ++++++++++++++++++ 5 files changed, 149 insertions(+), 31 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 4fd1a308e560..1cea4f6c19e6 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -3261,7 +3261,9 @@ def analyze_ordinary_member_access(self, e: MemberExpr, is_lvalue: bool) -> Type if isinstance(base, RefExpr) and isinstance(base.node, MypyFile): module_symbol_table = base.node.names if isinstance(base, RefExpr) and isinstance(base.node, Var): - is_self = base.node.is_self + # This is needed to special case self-types, so we don't need to track + # these flags separately in checkmember.py. + is_self = base.node.is_self or base.node.is_cls else: is_self = False diff --git a/mypy/checkmember.py b/mypy/checkmember.py index 7525db25d9cd..0f117f5475ed 100644 --- a/mypy/checkmember.py +++ b/mypy/checkmember.py @@ -638,7 +638,7 @@ def analyze_descriptor_access(descriptor_type: Type, mx: MemberContext) -> Type: Return: The return type of the appropriate ``__get__`` overload for the descriptor. """ - instance_type = get_proper_type(mx.original_type) + instance_type = get_proper_type(mx.self_type) orig_descriptor_type = descriptor_type descriptor_type = get_proper_type(descriptor_type) @@ -647,16 +647,6 @@ def analyze_descriptor_access(descriptor_type: Type, mx: MemberContext) -> Type: return make_simplified_union( [analyze_descriptor_access(typ, mx) for typ in descriptor_type.items] ) - elif isinstance(instance_type, UnionType): - # map over the instance types - return make_simplified_union( - [ - analyze_descriptor_access( - descriptor_type, mx.copy_modified(original_type=original_type) - ) - for original_type in instance_type.relevant_items() - ] - ) elif not isinstance(descriptor_type, Instance): return orig_descriptor_type @@ -777,23 +767,10 @@ def analyze_var( if mx.is_lvalue and var.is_classvar: mx.msg.cant_assign_to_classvar(name, mx.context) t = freshen_all_functions_type_vars(typ) - if not (mx.is_self or mx.is_super) or supported_self_type( - get_proper_type(mx.original_type) - ): - t = expand_self_type(var, t, mx.original_type) - elif ( - mx.is_self - and original_itype.type != var.info - # If an attribute with Self-type was defined in a supertype, we need to - # rebind the Self type variable to Self type variable of current class... - and original_itype.type.self_type is not None - # ...unless `self` has an explicit non-trivial annotation. - and original_itype == mx.chk.scope.active_self_type() - ): - t = expand_self_type(var, t, original_itype.type.self_type) - t = get_proper_type(expand_type_by_instance(t, itype)) + t = expand_self_type_if_needed(t, mx, var, original_itype) + t = expand_type_by_instance(t, itype) freeze_all_type_vars(t) - result: Type = t + result = t typ = get_proper_type(typ) call_type: ProperType | None = None @@ -857,6 +834,50 @@ def analyze_var( return result +def expand_self_type_if_needed( + t: Type, mx: MemberContext, var: Var, itype: Instance, is_class: bool = False +) -> Type: + """Expand special Self type in a backwards compatible manner. + + This should ensure that mixing old-style and new-style self-types work + seamlessly. Also, re-bind new style self-types in subclasses if needed. + """ + original = get_proper_type(mx.self_type) + if not (mx.is_self or mx.is_super): + repl = mx.self_type + if is_class: + if isinstance(original, TypeType): + repl = original.item + elif isinstance(original, CallableType): + # Problematic access errors should have been already reported. + repl = erase_typevars(original.ret_type) + else: + repl = itype + return expand_self_type(var, t, repl) + elif supported_self_type( + # Support compatibility with plain old style T -> T and Type[T] -> T only. + get_proper_type(mx.self_type), + allow_instances=False, + allow_callable=False, + ): + repl = mx.self_type + if is_class and isinstance(original, TypeType): + repl = original.item + return expand_self_type(var, t, repl) + elif ( + mx.is_self + and itype.type != var.info + # If an attribute with Self-type was defined in a supertype, we need to + # rebind the Self type variable to Self type variable of current class... + and itype.type.self_type is not None + # ...unless `self` has an explicit non-trivial annotation. + and itype == mx.chk.scope.active_self_type() + ): + return expand_self_type(var, t, itype.type.self_type) + else: + return t + + def freeze_all_type_vars(member_type: Type) -> None: member_type.accept(FreezeTypeVarsVisitor()) @@ -1059,12 +1080,11 @@ def analyze_class_attribute_access( else: message = message_registry.GENERIC_INSTANCE_VAR_CLASS_ACCESS mx.msg.fail(message, mx.context) - + t = expand_self_type_if_needed(t, mx, node.node, itype, is_class=True) # Erase non-mapped variables, but keep mapped ones, even if there is an error. # In the above example this means that we infer following types: # C.x -> Any # C[int].x -> int - t = get_proper_type(expand_self_type(node.node, t, itype)) t = erase_typevars(expand_type_by_instance(t, isuper), {tv.id for tv in def_vars}) is_classmethod = (is_decorated and cast(Decorator, node.node).func.is_class) or ( diff --git a/test-data/unit/check-classes.test b/test-data/unit/check-classes.test index 983cb8454a05..f37b0dd1dc41 100644 --- a/test-data/unit/check-classes.test +++ b/test-data/unit/check-classes.test @@ -1950,6 +1950,41 @@ class B: def foo(x: Union[A, B]) -> None: reveal_type(x.attr) # N: Revealed type is "builtins.str" +[case testDescriptorGetUnionRestricted] +from typing import Any, Union + +class getter: + def __get__(self, instance: X1, owner: Any) -> str: ... + +class X1: + prop = getter() + +class X2: + prop: str + +def foo(x: Union[X1, X2]) -> None: + reveal_type(x.prop) # N: Revealed type is "builtins.str" + +[case testDescriptorGetUnionType] +from typing import Any, Union, Type, overload + +class getter: + @overload + def __get__(self, instance: None, owner: Any) -> getter: ... + @overload + def __get__(self, instance: object, owner: Any) -> str: ... + def __get__(self, instance, owner): + ... + +class X1: + prop = getter() +class X2: + prop = getter() + +def foo(x: Type[Union[X1, X2]]) -> None: + reveal_type(x.prop) # N: Revealed type is "__main__.getter" + + -- _promote decorators -- ------------------- diff --git a/test-data/unit/check-recursive-types.test b/test-data/unit/check-recursive-types.test index 33cb9ccad9af..d5c8acd1bc15 100644 --- a/test-data/unit/check-recursive-types.test +++ b/test-data/unit/check-recursive-types.test @@ -440,7 +440,7 @@ from typing import NamedTuple, TypeVar, Tuple NT = NamedTuple("NT", [("x", NT), ("y", int)]) nt: NT reveal_type(nt) # N: Revealed type is "Tuple[..., builtins.int, fallback=__main__.NT]" -reveal_type(nt.x) # N: Revealed type is "Tuple[Tuple[..., builtins.int, fallback=__main__.NT], builtins.int, fallback=__main__.NT]" +reveal_type(nt.x) # N: Revealed type is "Tuple[..., builtins.int, fallback=__main__.NT]" reveal_type(nt[0]) # N: Revealed type is "Tuple[Tuple[..., builtins.int, fallback=__main__.NT], builtins.int, fallback=__main__.NT]" y: str if nt.x is not None: diff --git a/test-data/unit/check-selftype.test b/test-data/unit/check-selftype.test index e99b859bbcd0..fdd628b0271b 100644 --- a/test-data/unit/check-selftype.test +++ b/test-data/unit/check-selftype.test @@ -2071,3 +2071,64 @@ p: Partial reveal_type(p()) # N: Revealed type is "Never" p2: Partial2 reveal_type(p2(42)) # N: Revealed type is "builtins.int" + +[case testAccessingSelfClassVarInClassMethod] +from typing import Self, ClassVar, Type, TypeVar + +T = TypeVar("T", bound="Foo") + +class Foo: + instance: ClassVar[Self] + @classmethod + def get_instance(cls) -> Self: + return reveal_type(cls.instance) # N: Revealed type is "Self`0" + @classmethod + def get_instance_old(cls: Type[T]) -> T: + return reveal_type(cls.instance) # N: Revealed type is "T`-1" + +class Bar(Foo): + extra: int + + @classmethod + def get_instance(cls) -> Self: + reveal_type(cls.instance.extra) # N: Revealed type is "builtins.int" + return cls.instance + + @classmethod + def other(cls) -> None: + reveal_type(cls.instance) # N: Revealed type is "Self`0" + reveal_type(cls.instance.extra) # N: Revealed type is "builtins.int" + +reveal_type(Bar.instance) # N: Revealed type is "__main__.Bar" +[builtins fixtures/classmethod.pyi] + +[case testAccessingSelfClassVarInClassMethodTuple] +from typing import Self, ClassVar, Tuple + +class C(Tuple[int, str]): + x: Self + y: ClassVar[Self] + + @classmethod + def bar(cls) -> None: + reveal_type(cls.y) # N: Revealed type is "Self`0" + @classmethod + def bar_self(self) -> Self: + return reveal_type(self.y) # N: Revealed type is "Self`0" + +c: C +reveal_type(c.x) # N: Revealed type is "Tuple[builtins.int, builtins.str, fallback=__main__.C]" +reveal_type(c.y) # N: Revealed type is "Tuple[builtins.int, builtins.str, fallback=__main__.C]" +reveal_type(C.y) # N: Revealed type is "Tuple[builtins.int, builtins.str, fallback=__main__.C]" +C.x # E: Access to generic instance variables via class is ambiguous +[builtins fixtures/classmethod.pyi] + +[case testAccessingTypingSelfUnion] +from typing import Self, Union + +class C: + x: Self +class D: + x: int +x: Union[C, D] +reveal_type(x.x) # N: Revealed type is "Union[__main__.C, builtins.int]"