Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PEP 695] Implement new scoping rules for type parameters #17258

Merged
merged 12 commits into from
May 17, 2024
10 changes: 7 additions & 3 deletions mypy/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2502,7 +2502,7 @@ class TypeVarLikeExpr(SymbolNode, Expression):
Note that they are constructed by the semantic analyzer.
"""

__slots__ = ("_name", "_fullname", "upper_bound", "default", "variance")
__slots__ = ("_name", "_fullname", "upper_bound", "default", "variance", "is_new_style")

_name: str
_fullname: str
Expand All @@ -2525,13 +2525,15 @@ def __init__(
upper_bound: mypy.types.Type,
default: mypy.types.Type,
variance: int = INVARIANT,
is_new_style: bool = False,
) -> None:
super().__init__()
self._name = name
self._fullname = fullname
self.upper_bound = upper_bound
self.default = default
self.variance = variance
self.is_new_style = is_new_style

@property
def name(self) -> str:
Expand Down Expand Up @@ -2570,8 +2572,9 @@ def __init__(
upper_bound: mypy.types.Type,
default: mypy.types.Type,
variance: int = INVARIANT,
is_new_style: bool = False,
) -> None:
super().__init__(name, fullname, upper_bound, default, variance)
super().__init__(name, fullname, upper_bound, default, variance, is_new_style)
self.values = values

def accept(self, visitor: ExpressionVisitor[T]) -> T:
Expand Down Expand Up @@ -2648,8 +2651,9 @@ def __init__(
tuple_fallback: mypy.types.Instance,
default: mypy.types.Type,
variance: int = INVARIANT,
is_new_style: bool = False,
) -> None:
super().__init__(name, fullname, upper_bound, default, variance)
super().__init__(name, fullname, upper_bound, default, variance, is_new_style)
self.tuple_fallback = tuple_fallback

def accept(self, visitor: ExpressionVisitor[T]) -> T:
Expand Down
118 changes: 90 additions & 28 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,14 @@
CORE_BUILTIN_CLASSES: Final = ["object", "bool", "function"]


# Python has several different scope/namespace kinds with subtly different semantics.
SCOPE_GLOBAL: Final = 0 # Module top level
SCOPE_CLASS: Final = 1 # Class body
SCOPE_FUNC: Final = 2 # Function or lambda
SCOPE_COMPREHENSION: Final = 3 # Comprehension or generator expression
SCOPE_ANNOTATION: Final = 4 # Annotation scopes for type parameters and aliases (PEP 695)


# Used for tracking incomplete references
Tag: _TypeAlias = int

Expand All @@ -342,8 +350,8 @@ class SemanticAnalyzer(
nonlocal_decls: list[set[str]]
# Local names of function scopes; None for non-function scopes.
locals: list[SymbolTable | None]
# Whether each scope is a comprehension scope.
is_comprehension_stack: list[bool]
# Type of each scope (SCOPE_*, indexes match locals)
scope_stack: list[int]
# Nested block depths of scopes
block_depth: list[int]
# TypeInfo of directly enclosing class (or None)
Expand Down Expand Up @@ -417,7 +425,7 @@ def __init__(
errors: Report analysis errors using this instance
"""
self.locals = [None]
self.is_comprehension_stack = [False]
self.scope_stack = [SCOPE_GLOBAL]
# Saved namespaces from previous iteration. Every top-level function/method body is
# analyzed in several iterations until all names are resolved. We need to save
# the local namespaces for the top level function and all nested functions between
Expand Down Expand Up @@ -880,6 +888,7 @@ def analyze_func_def(self, defn: FuncDef) -> None:
# Don't store not ready types (including placeholders).
if self.found_incomplete_ref(tag) or has_placeholder(result):
self.defer(defn)
# TODO: pop type args
return
assert isinstance(result, ProperType)
if isinstance(result, CallableType):
Expand Down Expand Up @@ -1645,6 +1654,8 @@ def push_type_args(
) -> list[tuple[str, TypeVarLikeExpr]] | None:
if not type_args:
return []
self.locals.append(SymbolTable())
self.scope_stack.append(SCOPE_ANNOTATION)
tvs: list[tuple[str, TypeVarLikeExpr]] = []
for p in type_args:
tv = self.analyze_type_param(p)
Expand All @@ -1653,10 +1664,23 @@ def push_type_args(
tvs.append((p.name, tv))

for name, tv in tvs:
self.add_symbol(name, tv, context, no_progress=True)
if self.is_defined_type_param(name):
self.fail(f'"{name}" already defined as a type parameter', context)
else:
self.add_symbol(name, tv, context, no_progress=True, type_param=True)

return tvs

def is_defined_type_param(self, name: str) -> bool:
for names in self.locals:
if names is None:
continue
if name in names:
node = names[name].node
if isinstance(node, TypeVarLikeExpr):
return True
return False

def analyze_type_param(self, type_param: TypeParam) -> TypeVarLikeExpr | None:
fullname = self.qualified_name(type_param.name)
if type_param.upper_bound:
Expand All @@ -1681,10 +1705,15 @@ def analyze_type_param(self, type_param: TypeParam) -> TypeVarLikeExpr | None:
upper_bound=upper_bound,
default=default,
variance=VARIANCE_NOT_READY,
is_new_style=True,
)
elif type_param.kind == PARAM_SPEC_KIND:
return ParamSpecExpr(
name=type_param.name, fullname=fullname, upper_bound=upper_bound, default=default
name=type_param.name,
fullname=fullname,
upper_bound=upper_bound,
default=default,
is_new_style=True,
)
else:
assert type_param.kind == TYPE_VAR_TUPLE_KIND
Expand All @@ -1696,14 +1725,14 @@ def analyze_type_param(self, type_param: TypeParam) -> TypeVarLikeExpr | None:
upper_bound=tuple_fallback.copy_modified(),
tuple_fallback=tuple_fallback,
default=default,
is_new_style=True,
)

def pop_type_args(self, type_args: list[TypeParam] | None) -> None:
if not type_args:
return
for tv in type_args:
names = self.current_symbol_table()
del names[tv.name]
self.locals.pop()
self.scope_stack.pop()

def analyze_class(self, defn: ClassDef) -> None:
fullname = self.qualified_name(defn.name)
Expand Down Expand Up @@ -1785,8 +1814,18 @@ def analyze_class(self, defn: ClassDef) -> None:
defn.info.is_protocol = is_protocol
self.recalculate_metaclass(defn, declared_metaclass)
defn.info.runtime_protocol = False

if defn.type_args:
# PEP 695 type parameters are not in scope in class decorators, so
# temporarily disable type parameter namespace.
type_params_names = self.locals.pop()
self.scope_stack.pop()
for decorator in defn.decorators:
self.analyze_class_decorator(defn, decorator)
if defn.type_args:
self.locals.append(type_params_names)
self.scope_stack.append(SCOPE_ANNOTATION)

self.analyze_class_body_common(defn)

def setup_type_vars(self, defn: ClassDef, tvar_defs: list[TypeVarLikeType]) -> None:
Expand Down Expand Up @@ -1938,7 +1977,7 @@ def enter_class(self, info: TypeInfo) -> None:
# Remember previous active class
self.type_stack.append(self.type)
self.locals.append(None) # Add class scope
self.is_comprehension_stack.append(False)
self.scope_stack.append(SCOPE_CLASS)
self.block_depth.append(-1) # The class body increments this to 0
self.loop_depth.append(0)
self._type = info
Expand All @@ -1949,7 +1988,7 @@ def leave_class(self) -> None:
self.block_depth.pop()
self.loop_depth.pop()
self.locals.pop()
self.is_comprehension_stack.pop()
self.scope_stack.pop()
self._type = self.type_stack.pop()
self.missing_names.pop()

Expand Down Expand Up @@ -2923,8 +2962,8 @@ class C:
[(j := i) for i in [1, 2, 3]]
is a syntax error that is not enforced by Python parser, but at later steps.
"""
for i, is_comprehension in enumerate(reversed(self.is_comprehension_stack)):
if not is_comprehension and i < len(self.locals) - 1:
for i, scope_type in enumerate(reversed(self.scope_stack)):
if scope_type != SCOPE_COMPREHENSION and i < len(self.locals) - 1:
if self.locals[-1 - i] is None:
self.fail(
"Assignment expression within a comprehension"
Expand Down Expand Up @@ -5188,8 +5227,14 @@ def visit_nonlocal_decl(self, d: NonlocalDecl) -> None:
self.fail("nonlocal declaration not allowed at module level", d)
else:
for name in d.names:
for table in reversed(self.locals[:-1]):
for table, scope_type in zip(
reversed(self.locals[:-1]), reversed(self.scope_stack[:-1])
):
if table is not None and name in table:
if scope_type == SCOPE_ANNOTATION:
self.fail(
f'nonlocal binding not allowed for type parameter "{name}"', d
)
break
else:
self.fail(f'No binding for nonlocal "{name}" found', d)
Expand Down Expand Up @@ -5350,7 +5395,7 @@ def visit_star_expr(self, expr: StarExpr) -> None:
def visit_yield_from_expr(self, e: YieldFromExpr) -> None:
if not self.is_func_scope():
self.fail('"yield from" outside function', e, serious=True, blocker=True)
elif self.is_comprehension_stack[-1]:
elif self.scope_stack[-1] == SCOPE_COMPREHENSION:
self.fail(
'"yield from" inside comprehension or generator expression',
e,
Expand Down Expand Up @@ -5848,7 +5893,7 @@ def visit__promote_expr(self, expr: PromoteExpr) -> None:
def visit_yield_expr(self, e: YieldExpr) -> None:
if not self.is_func_scope():
self.fail('"yield" outside function', e, serious=True, blocker=True)
elif self.is_comprehension_stack[-1]:
elif self.scope_stack[-1] == SCOPE_COMPREHENSION:
self.fail(
'"yield" inside comprehension or generator expression',
e,
Expand Down Expand Up @@ -6281,6 +6326,7 @@ def add_symbol(
can_defer: bool = True,
escape_comprehensions: bool = False,
no_progress: bool = False,
type_param: bool = False,
) -> bool:
"""Add symbol to the currently active symbol table.

Expand All @@ -6303,7 +6349,7 @@ def add_symbol(
kind, node, module_public=module_public, module_hidden=module_hidden
)
return self.add_symbol_table_node(
name, symbol, context, can_defer, escape_comprehensions, no_progress
name, symbol, context, can_defer, escape_comprehensions, no_progress, type_param
)

def add_symbol_skip_local(self, name: str, node: SymbolNode) -> None:
Expand Down Expand Up @@ -6336,6 +6382,7 @@ def add_symbol_table_node(
can_defer: bool = True,
escape_comprehensions: bool = False,
no_progress: bool = False,
type_param: bool = False,
) -> bool:
"""Add symbol table node to the currently active symbol table.

Expand All @@ -6355,7 +6402,9 @@ def add_symbol_table_node(
can_defer: if True, defer current target if adding a placeholder
context: error context (see above about None value)
"""
names = self.current_symbol_table(escape_comprehensions=escape_comprehensions)
names = self.current_symbol_table(
escape_comprehensions=escape_comprehensions, type_param=type_param
)
existing = names.get(name)
if isinstance(symbol.node, PlaceholderNode) and can_defer:
if context is not None:
Expand Down Expand Up @@ -6673,7 +6722,7 @@ def enter(
names = self.saved_locals.setdefault(function, SymbolTable())
self.locals.append(names)
is_comprehension = isinstance(function, (GeneratorExpr, DictionaryComprehension))
self.is_comprehension_stack.append(is_comprehension)
self.scope_stack.append(SCOPE_FUNC if not is_comprehension else SCOPE_COMPREHENSION)
self.global_decls.append(set())
self.nonlocal_decls.append(set())
# -1 since entering block will increment this to 0.
Expand All @@ -6684,19 +6733,22 @@ def enter(
yield
finally:
self.locals.pop()
self.is_comprehension_stack.pop()
self.scope_stack.pop()
self.global_decls.pop()
self.nonlocal_decls.pop()
self.block_depth.pop()
self.loop_depth.pop()
self.missing_names.pop()

def is_func_scope(self) -> bool:
return self.locals[-1] is not None
scope_type = self.scope_stack[-1]
if scope_type == SCOPE_ANNOTATION:
scope_type = self.scope_stack[-2]
return scope_type in (SCOPE_FUNC, SCOPE_COMPREHENSION)

def is_nested_within_func_scope(self) -> bool:
"""Are we underneath a function scope, even if we are in a nested class also?"""
return any(l is not None for l in self.locals)
return any(s in (SCOPE_FUNC, SCOPE_COMPREHENSION) for s in self.scope_stack)

def is_class_scope(self) -> bool:
return self.type is not None and not self.is_func_scope()
Expand All @@ -6713,14 +6765,24 @@ def current_symbol_kind(self) -> int:
kind = GDEF
return kind

def current_symbol_table(self, escape_comprehensions: bool = False) -> SymbolTable:
if self.is_func_scope():
assert self.locals[-1] is not None
def current_symbol_table(
self, escape_comprehensions: bool = False, type_param: bool = False
) -> SymbolTable:
if type_param and self.scope_stack[-1] == SCOPE_ANNOTATION:
n = self.locals[-1]
assert n is not None
return n
elif self.is_func_scope():
if self.scope_stack[-1] == SCOPE_ANNOTATION:
n = self.locals[-2]
else:
n = self.locals[-1]
assert n is not None
if escape_comprehensions:
assert len(self.locals) == len(self.is_comprehension_stack)
assert len(self.locals) == len(self.scope_stack)
# Retrieve the symbol table from the enclosing non-comprehension scope.
for i, is_comprehension in enumerate(reversed(self.is_comprehension_stack)):
if not is_comprehension:
for i, scope_type in enumerate(reversed(self.scope_stack)):
if scope_type != SCOPE_COMPREHENSION:
if i == len(self.locals) - 1: # The last iteration.
# The caller of the comprehension is in the global space.
names = self.globals
Expand All @@ -6734,7 +6796,7 @@ def current_symbol_table(self, escape_comprehensions: bool = False) -> SymbolTab
else:
assert False, "Should have at least one non-comprehension scope"
else:
names = self.locals[-1]
names = n
assert names is not None
elif self.type is not None:
names = self.type.names
Expand Down
Loading
Loading