Skip to content

Commit

Permalink
[PEP 695] Support recursive type aliases (#17268)
Browse files Browse the repository at this point in the history
The implementation follows the approach used for old-style type aliases.

Work on #15238.
  • Loading branch information
JukkaL committed May 30, 2024
1 parent 7032f8c commit 0820e95
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 5 deletions.
4 changes: 3 additions & 1 deletion mypy/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1647,19 +1647,21 @@ def accept(self, visitor: StatementVisitor[T]) -> T:


class TypeAliasStmt(Statement):
__slots__ = ("name", "type_args", "value")
__slots__ = ("name", "type_args", "value", "invalid_recursive_alias")

__match_args__ = ("name", "type_args", "value")

name: NameExpr
type_args: list[TypeParam]
value: Expression # Will get translated into a type
invalid_recursive_alias: bool

def __init__(self, name: NameExpr, type_args: list[TypeParam], value: Expression) -> None:
super().__init__()
self.name = name
self.type_args = type_args
self.value = value
self.invalid_recursive_alias = False

def accept(self, visitor: StatementVisitor[T]) -> T:
return visitor.visit_type_alias_stmt(self)
Expand Down
32 changes: 28 additions & 4 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -3961,7 +3961,7 @@ def check_and_set_up_type_alias(self, s: AssignmentStmt) -> bool:
alias_node.normalized = rvalue.node.normalized
current_node = existing.node if existing else alias_node
assert isinstance(current_node, TypeAlias)
self.disable_invalid_recursive_aliases(s, current_node)
self.disable_invalid_recursive_aliases(s, current_node, s.rvalue)
if self.is_class_scope():
assert self.type is not None
if self.type.is_protocol:
Expand Down Expand Up @@ -4057,7 +4057,7 @@ def analyze_type_alias_type_params(
return declared_tvars, all_declared_tvar_names

def disable_invalid_recursive_aliases(
self, s: AssignmentStmt, current_node: TypeAlias
self, s: AssignmentStmt | TypeAliasStmt, current_node: TypeAlias, ctx: Context
) -> None:
"""Prohibit and fix recursive type aliases that are invalid/unsupported."""
messages = []
Expand All @@ -4074,7 +4074,7 @@ def disable_invalid_recursive_aliases(
current_node.target = AnyType(TypeOfAny.from_error)
s.invalid_recursive_alias = True
for msg in messages:
self.fail(msg, s.rvalue)
self.fail(msg, ctx)

def analyze_lvalue(
self,
Expand Down Expand Up @@ -5304,6 +5304,8 @@ def visit_match_stmt(self, s: MatchStmt) -> None:
self.visit_block(s.bodies[i])

def visit_type_alias_stmt(self, s: TypeAliasStmt) -> None:
if s.invalid_recursive_alias:
return
self.statement = s
type_params = self.push_type_args(s.type_args, s)
if type_params is None:
Expand Down Expand Up @@ -5369,10 +5371,32 @@ def visit_type_alias_stmt(self, s: TypeAliasStmt) -> None:
and isinstance(existing.node, (PlaceholderNode, TypeAlias))
and existing.node.line == s.line
):
existing.node = alias_node
updated = False
if isinstance(existing.node, TypeAlias):
if existing.node.target != res:
# Copy expansion to the existing alias, this matches how we update base classes
# for a TypeInfo _in place_ if there are nested placeholders.
existing.node.target = res
existing.node.alias_tvars = alias_tvars
updated = True
else:
# Otherwise just replace existing placeholder with type alias.
existing.node = alias_node
updated = True

if updated:
if self.final_iteration:
self.cannot_resolve_name(s.name.name, "name", s)
return
else:
# We need to defer so that this change can get propagated to base classes.
self.defer(s, force_progress=True)
else:
self.add_symbol(s.name.name, alias_node, s)

current_node = existing.node if existing else alias_node
assert isinstance(current_node, TypeAlias)
self.disable_invalid_recursive_aliases(s, current_node, s.value)
finally:
self.pop_type_args(s.type_args)

Expand Down
43 changes: 43 additions & 0 deletions test-data/unit/check-python312.test
Original file line number Diff line number Diff line change
Expand Up @@ -1162,6 +1162,49 @@ def decorator(x: str) -> Any: ...
class C[T]:
pass

[case testPEP695RecursiceTypeAlias]
# mypy: enable-incomplete-feature=NewGenericSyntax

type A = str | list[A]
a: A
reveal_type(a) # N: Revealed type is "Union[builtins.str, builtins.list[...]]"

class C[T]: pass

type B[T] = C[T] | list[B[T]]
b: B[int]
reveal_type(b) # N: Revealed type is "Union[__main__.C[builtins.int], builtins.list[...]]"

[case testPEP695BadRecursiveTypeAlias]
# mypy: enable-incomplete-feature=NewGenericSyntax

type A = A # E: Cannot resolve name "A" (possible cyclic definition)
type B = B | int # E: Invalid recursive alias: a union item of itself
a: A
reveal_type(a) # N: Revealed type is "Any"
b: B
reveal_type(b) # N: Revealed type is "Any"

[case testPEP695RecursiveTypeAliasForwardReference]
# mypy: enable-incomplete-feature=NewGenericSyntax

def f(a: A) -> None:
if isinstance(a, str):
reveal_type(a) # N: Revealed type is "builtins.str"
else:
reveal_type(a) # N: Revealed type is "__main__.C[Union[builtins.str, __main__.C[...]]]"

type A = str | C[A]

class C[T]: pass

f('x')
f(C[str]())
f(C[C[str]]())
f(1) # E: Argument 1 to "f" has incompatible type "int"; expected "A"
f(C[int]()) # E: Argument 1 to "f" has incompatible type "C[int]"; expected "A"
[builtins fixtures/isinstance.pyi]

[case testPEP695InvalidGenericOrProtocolBaseClass]
# mypy: enable-incomplete-feature=NewGenericSyntax
from typing import Generic, Protocol, TypeVar
Expand Down

0 comments on commit 0820e95

Please sign in to comment.