Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PEP 695] Support recursive type aliases #17268

Merged
merged 5 commits into from
May 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion mypy/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1647,19 +1647,21 @@ def accept(self, visitor: StatementVisitor[T]) -> T:


class TypeAliasStmt(Statement):
__slots__ = ("name", "type_args", "value")
__slots__ = ("name", "type_args", "value", "invalid_recursive_alias")

__match_args__ = ("name", "type_args", "value")

name: NameExpr
type_args: list[TypeParam]
value: Expression # Will get translated into a type
invalid_recursive_alias: bool

def __init__(self, name: NameExpr, type_args: list[TypeParam], value: Expression) -> None:
super().__init__()
self.name = name
self.type_args = type_args
self.value = value
self.invalid_recursive_alias = False

def accept(self, visitor: StatementVisitor[T]) -> T:
return visitor.visit_type_alias_stmt(self)
Expand Down
32 changes: 28 additions & 4 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -3961,7 +3961,7 @@ def check_and_set_up_type_alias(self, s: AssignmentStmt) -> bool:
alias_node.normalized = rvalue.node.normalized
current_node = existing.node if existing else alias_node
assert isinstance(current_node, TypeAlias)
self.disable_invalid_recursive_aliases(s, current_node)
self.disable_invalid_recursive_aliases(s, current_node, s.rvalue)
if self.is_class_scope():
assert self.type is not None
if self.type.is_protocol:
Expand Down Expand Up @@ -4057,7 +4057,7 @@ def analyze_type_alias_type_params(
return declared_tvars, all_declared_tvar_names

def disable_invalid_recursive_aliases(
self, s: AssignmentStmt, current_node: TypeAlias
self, s: AssignmentStmt | TypeAliasStmt, current_node: TypeAlias, ctx: Context
) -> None:
"""Prohibit and fix recursive type aliases that are invalid/unsupported."""
messages = []
Expand All @@ -4074,7 +4074,7 @@ def disable_invalid_recursive_aliases(
current_node.target = AnyType(TypeOfAny.from_error)
s.invalid_recursive_alias = True
for msg in messages:
self.fail(msg, s.rvalue)
self.fail(msg, ctx)

def analyze_lvalue(
self,
Expand Down Expand Up @@ -5304,6 +5304,8 @@ def visit_match_stmt(self, s: MatchStmt) -> None:
self.visit_block(s.bodies[i])

def visit_type_alias_stmt(self, s: TypeAliasStmt) -> None:
if s.invalid_recursive_alias:
return
self.statement = s
type_params = self.push_type_args(s.type_args, s)
if type_params is None:
Expand Down Expand Up @@ -5369,10 +5371,32 @@ def visit_type_alias_stmt(self, s: TypeAliasStmt) -> None:
and isinstance(existing.node, (PlaceholderNode, TypeAlias))
and existing.node.line == s.line
):
existing.node = alias_node
updated = False
if isinstance(existing.node, TypeAlias):
if existing.node.target != res:
# Copy expansion to the existing alias, this matches how we update base classes
# for a TypeInfo _in place_ if there are nested placeholders.
existing.node.target = res
existing.node.alias_tvars = alias_tvars
updated = True
else:
# Otherwise just replace existing placeholder with type alias.
existing.node = alias_node
updated = True

if updated:
if self.final_iteration:
self.cannot_resolve_name(s.name.name, "name", s)
return
else:
# We need to defer so that this change can get propagated to base classes.
self.defer(s, force_progress=True)
else:
self.add_symbol(s.name.name, alias_node, s)

current_node = existing.node if existing else alias_node
assert isinstance(current_node, TypeAlias)
self.disable_invalid_recursive_aliases(s, current_node, s.value)
finally:
self.pop_type_args(s.type_args)

Expand Down
43 changes: 43 additions & 0 deletions test-data/unit/check-python312.test
Original file line number Diff line number Diff line change
Expand Up @@ -1162,6 +1162,49 @@ def decorator(x: str) -> Any: ...
class C[T]:
pass

[case testPEP695RecursiceTypeAlias]
# mypy: enable-incomplete-feature=NewGenericSyntax

type A = str | list[A]
a: A
reveal_type(a) # N: Revealed type is "Union[builtins.str, builtins.list[...]]"

class C[T]: pass

type B[T] = C[T] | list[B[T]]
b: B[int]
reveal_type(b) # N: Revealed type is "Union[__main__.C[builtins.int], builtins.list[...]]"

[case testPEP695BadRecursiveTypeAlias]
# mypy: enable-incomplete-feature=NewGenericSyntax

type A = A # E: Cannot resolve name "A" (possible cyclic definition)
type B = B | int # E: Invalid recursive alias: a union item of itself
a: A
reveal_type(a) # N: Revealed type is "Any"
b: B
reveal_type(b) # N: Revealed type is "Any"

[case testPEP695RecursiveTypeAliasForwardReference]
# mypy: enable-incomplete-feature=NewGenericSyntax

def f(a: A) -> None:
if isinstance(a, str):
reveal_type(a) # N: Revealed type is "builtins.str"
else:
reveal_type(a) # N: Revealed type is "__main__.C[Union[builtins.str, __main__.C[...]]]"

type A = str | C[A]

class C[T]: pass

f('x')
f(C[str]())
f(C[C[str]]())
f(1) # E: Argument 1 to "f" has incompatible type "int"; expected "A"
f(C[int]()) # E: Argument 1 to "f" has incompatible type "C[int]"; expected "A"
[builtins fixtures/isinstance.pyi]

[case testPEP695InvalidGenericOrProtocolBaseClass]
# mypy: enable-incomplete-feature=NewGenericSyntax
from typing import Generic, Protocol, TypeVar
Expand Down
Loading