diff --git a/mypy/nodes.py b/mypy/nodes.py index e52618fcdae6..dbde3ddf4f1b 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -1647,19 +1647,21 @@ def accept(self, visitor: StatementVisitor[T]) -> T: class TypeAliasStmt(Statement): - __slots__ = ("name", "type_args", "value") + __slots__ = ("name", "type_args", "value", "invalid_recursive_alias") __match_args__ = ("name", "type_args", "value") name: NameExpr type_args: list[TypeParam] value: Expression # Will get translated into a type + invalid_recursive_alias: bool def __init__(self, name: NameExpr, type_args: list[TypeParam], value: Expression) -> None: super().__init__() self.name = name self.type_args = type_args self.value = value + self.invalid_recursive_alias = False def accept(self, visitor: StatementVisitor[T]) -> T: return visitor.visit_type_alias_stmt(self) diff --git a/mypy/semanal.py b/mypy/semanal.py index 320ae72d99f9..0689d5416efe 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -3961,7 +3961,7 @@ def check_and_set_up_type_alias(self, s: AssignmentStmt) -> bool: alias_node.normalized = rvalue.node.normalized current_node = existing.node if existing else alias_node assert isinstance(current_node, TypeAlias) - self.disable_invalid_recursive_aliases(s, current_node) + self.disable_invalid_recursive_aliases(s, current_node, s.rvalue) if self.is_class_scope(): assert self.type is not None if self.type.is_protocol: @@ -4057,7 +4057,7 @@ def analyze_type_alias_type_params( return declared_tvars, all_declared_tvar_names def disable_invalid_recursive_aliases( - self, s: AssignmentStmt, current_node: TypeAlias + self, s: AssignmentStmt | TypeAliasStmt, current_node: TypeAlias, ctx: Context ) -> None: """Prohibit and fix recursive type aliases that are invalid/unsupported.""" messages = [] @@ -4074,7 +4074,7 @@ def disable_invalid_recursive_aliases( current_node.target = AnyType(TypeOfAny.from_error) s.invalid_recursive_alias = True for msg in messages: - self.fail(msg, s.rvalue) + self.fail(msg, ctx) def analyze_lvalue( self, @@ -5304,6 +5304,8 @@ def visit_match_stmt(self, s: MatchStmt) -> None: self.visit_block(s.bodies[i]) def visit_type_alias_stmt(self, s: TypeAliasStmt) -> None: + if s.invalid_recursive_alias: + return self.statement = s type_params = self.push_type_args(s.type_args, s) if type_params is None: @@ -5369,10 +5371,32 @@ def visit_type_alias_stmt(self, s: TypeAliasStmt) -> None: and isinstance(existing.node, (PlaceholderNode, TypeAlias)) and existing.node.line == s.line ): - existing.node = alias_node + updated = False + if isinstance(existing.node, TypeAlias): + if existing.node.target != res: + # Copy expansion to the existing alias, this matches how we update base classes + # for a TypeInfo _in place_ if there are nested placeholders. + existing.node.target = res + existing.node.alias_tvars = alias_tvars + updated = True + else: + # Otherwise just replace existing placeholder with type alias. + existing.node = alias_node + updated = True + + if updated: + if self.final_iteration: + self.cannot_resolve_name(s.name.name, "name", s) + return + else: + # We need to defer so that this change can get propagated to base classes. + self.defer(s, force_progress=True) else: self.add_symbol(s.name.name, alias_node, s) + current_node = existing.node if existing else alias_node + assert isinstance(current_node, TypeAlias) + self.disable_invalid_recursive_aliases(s, current_node, s.value) finally: self.pop_type_args(s.type_args) diff --git a/test-data/unit/check-python312.test b/test-data/unit/check-python312.test index f5d9fd195f04..6dd61351d7a8 100644 --- a/test-data/unit/check-python312.test +++ b/test-data/unit/check-python312.test @@ -1162,6 +1162,49 @@ def decorator(x: str) -> Any: ... class C[T]: pass +[case testPEP695RecursiceTypeAlias] +# mypy: enable-incomplete-feature=NewGenericSyntax + +type A = str | list[A] +a: A +reveal_type(a) # N: Revealed type is "Union[builtins.str, builtins.list[...]]" + +class C[T]: pass + +type B[T] = C[T] | list[B[T]] +b: B[int] +reveal_type(b) # N: Revealed type is "Union[__main__.C[builtins.int], builtins.list[...]]" + +[case testPEP695BadRecursiveTypeAlias] +# mypy: enable-incomplete-feature=NewGenericSyntax + +type A = A # E: Cannot resolve name "A" (possible cyclic definition) +type B = B | int # E: Invalid recursive alias: a union item of itself +a: A +reveal_type(a) # N: Revealed type is "Any" +b: B +reveal_type(b) # N: Revealed type is "Any" + +[case testPEP695RecursiveTypeAliasForwardReference] +# mypy: enable-incomplete-feature=NewGenericSyntax + +def f(a: A) -> None: + if isinstance(a, str): + reveal_type(a) # N: Revealed type is "builtins.str" + else: + reveal_type(a) # N: Revealed type is "__main__.C[Union[builtins.str, __main__.C[...]]]" + +type A = str | C[A] + +class C[T]: pass + +f('x') +f(C[str]()) +f(C[C[str]]()) +f(1) # E: Argument 1 to "f" has incompatible type "int"; expected "A" +f(C[int]()) # E: Argument 1 to "f" has incompatible type "C[int]"; expected "A" +[builtins fixtures/isinstance.pyi] + [case testPEP695InvalidGenericOrProtocolBaseClass] # mypy: enable-incomplete-feature=NewGenericSyntax from typing import Generic, Protocol, TypeVar