Skip to content

Commit

Permalink
[PEP 695] Implement new scoping rules for type parameters (python#17258)
Browse files Browse the repository at this point in the history
Type parameters get a separate scope with some special features.

Work on python#15238.
  • Loading branch information
JukkaL authored May 17, 2024
1 parent 5fb8d62 commit 3b97e6e
Show file tree
Hide file tree
Showing 4 changed files with 323 additions and 45 deletions.
10 changes: 7 additions & 3 deletions mypy/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2502,7 +2502,7 @@ class TypeVarLikeExpr(SymbolNode, Expression):
Note that they are constructed by the semantic analyzer.
"""

__slots__ = ("_name", "_fullname", "upper_bound", "default", "variance")
__slots__ = ("_name", "_fullname", "upper_bound", "default", "variance", "is_new_style")

_name: str
_fullname: str
Expand All @@ -2525,13 +2525,15 @@ def __init__(
upper_bound: mypy.types.Type,
default: mypy.types.Type,
variance: int = INVARIANT,
is_new_style: bool = False,
) -> None:
super().__init__()
self._name = name
self._fullname = fullname
self.upper_bound = upper_bound
self.default = default
self.variance = variance
self.is_new_style = is_new_style

@property
def name(self) -> str:
Expand Down Expand Up @@ -2570,8 +2572,9 @@ def __init__(
upper_bound: mypy.types.Type,
default: mypy.types.Type,
variance: int = INVARIANT,
is_new_style: bool = False,
) -> None:
super().__init__(name, fullname, upper_bound, default, variance)
super().__init__(name, fullname, upper_bound, default, variance, is_new_style)
self.values = values

def accept(self, visitor: ExpressionVisitor[T]) -> T:
Expand Down Expand Up @@ -2648,8 +2651,9 @@ def __init__(
tuple_fallback: mypy.types.Instance,
default: mypy.types.Type,
variance: int = INVARIANT,
is_new_style: bool = False,
) -> None:
super().__init__(name, fullname, upper_bound, default, variance)
super().__init__(name, fullname, upper_bound, default, variance, is_new_style)
self.tuple_fallback = tuple_fallback

def accept(self, visitor: ExpressionVisitor[T]) -> T:
Expand Down
118 changes: 90 additions & 28 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,14 @@
CORE_BUILTIN_CLASSES: Final = ["object", "bool", "function"]


# Python has several different scope/namespace kinds with subtly different semantics.
SCOPE_GLOBAL: Final = 0 # Module top level
SCOPE_CLASS: Final = 1 # Class body
SCOPE_FUNC: Final = 2 # Function or lambda
SCOPE_COMPREHENSION: Final = 3 # Comprehension or generator expression
SCOPE_ANNOTATION: Final = 4 # Annotation scopes for type parameters and aliases (PEP 695)


# Used for tracking incomplete references
Tag: _TypeAlias = int

Expand All @@ -342,8 +350,8 @@ class SemanticAnalyzer(
nonlocal_decls: list[set[str]]
# Local names of function scopes; None for non-function scopes.
locals: list[SymbolTable | None]
# Whether each scope is a comprehension scope.
is_comprehension_stack: list[bool]
# Type of each scope (SCOPE_*, indexes match locals)
scope_stack: list[int]
# Nested block depths of scopes
block_depth: list[int]
# TypeInfo of directly enclosing class (or None)
Expand Down Expand Up @@ -417,7 +425,7 @@ def __init__(
errors: Report analysis errors using this instance
"""
self.locals = [None]
self.is_comprehension_stack = [False]
self.scope_stack = [SCOPE_GLOBAL]
# Saved namespaces from previous iteration. Every top-level function/method body is
# analyzed in several iterations until all names are resolved. We need to save
# the local namespaces for the top level function and all nested functions between
Expand Down Expand Up @@ -880,6 +888,7 @@ def analyze_func_def(self, defn: FuncDef) -> None:
# Don't store not ready types (including placeholders).
if self.found_incomplete_ref(tag) or has_placeholder(result):
self.defer(defn)
# TODO: pop type args
return
assert isinstance(result, ProperType)
if isinstance(result, CallableType):
Expand Down Expand Up @@ -1645,6 +1654,8 @@ def push_type_args(
) -> list[tuple[str, TypeVarLikeExpr]] | None:
if not type_args:
return []
self.locals.append(SymbolTable())
self.scope_stack.append(SCOPE_ANNOTATION)
tvs: list[tuple[str, TypeVarLikeExpr]] = []
for p in type_args:
tv = self.analyze_type_param(p)
Expand All @@ -1653,10 +1664,23 @@ def push_type_args(
tvs.append((p.name, tv))

for name, tv in tvs:
self.add_symbol(name, tv, context, no_progress=True)
if self.is_defined_type_param(name):
self.fail(f'"{name}" already defined as a type parameter', context)
else:
self.add_symbol(name, tv, context, no_progress=True, type_param=True)

return tvs

def is_defined_type_param(self, name: str) -> bool:
for names in self.locals:
if names is None:
continue
if name in names:
node = names[name].node
if isinstance(node, TypeVarLikeExpr):
return True
return False

def analyze_type_param(self, type_param: TypeParam) -> TypeVarLikeExpr | None:
fullname = self.qualified_name(type_param.name)
if type_param.upper_bound:
Expand All @@ -1681,10 +1705,15 @@ def analyze_type_param(self, type_param: TypeParam) -> TypeVarLikeExpr | None:
upper_bound=upper_bound,
default=default,
variance=VARIANCE_NOT_READY,
is_new_style=True,
)
elif type_param.kind == PARAM_SPEC_KIND:
return ParamSpecExpr(
name=type_param.name, fullname=fullname, upper_bound=upper_bound, default=default
name=type_param.name,
fullname=fullname,
upper_bound=upper_bound,
default=default,
is_new_style=True,
)
else:
assert type_param.kind == TYPE_VAR_TUPLE_KIND
Expand All @@ -1696,14 +1725,14 @@ def analyze_type_param(self, type_param: TypeParam) -> TypeVarLikeExpr | None:
upper_bound=tuple_fallback.copy_modified(),
tuple_fallback=tuple_fallback,
default=default,
is_new_style=True,
)

def pop_type_args(self, type_args: list[TypeParam] | None) -> None:
if not type_args:
return
for tv in type_args:
names = self.current_symbol_table()
del names[tv.name]
self.locals.pop()
self.scope_stack.pop()

def analyze_class(self, defn: ClassDef) -> None:
fullname = self.qualified_name(defn.name)
Expand Down Expand Up @@ -1785,8 +1814,18 @@ def analyze_class(self, defn: ClassDef) -> None:
defn.info.is_protocol = is_protocol
self.recalculate_metaclass(defn, declared_metaclass)
defn.info.runtime_protocol = False

if defn.type_args:
# PEP 695 type parameters are not in scope in class decorators, so
# temporarily disable type parameter namespace.
type_params_names = self.locals.pop()
self.scope_stack.pop()
for decorator in defn.decorators:
self.analyze_class_decorator(defn, decorator)
if defn.type_args:
self.locals.append(type_params_names)
self.scope_stack.append(SCOPE_ANNOTATION)

self.analyze_class_body_common(defn)

def setup_type_vars(self, defn: ClassDef, tvar_defs: list[TypeVarLikeType]) -> None:
Expand Down Expand Up @@ -1938,7 +1977,7 @@ def enter_class(self, info: TypeInfo) -> None:
# Remember previous active class
self.type_stack.append(self.type)
self.locals.append(None) # Add class scope
self.is_comprehension_stack.append(False)
self.scope_stack.append(SCOPE_CLASS)
self.block_depth.append(-1) # The class body increments this to 0
self.loop_depth.append(0)
self._type = info
Expand All @@ -1949,7 +1988,7 @@ def leave_class(self) -> None:
self.block_depth.pop()
self.loop_depth.pop()
self.locals.pop()
self.is_comprehension_stack.pop()
self.scope_stack.pop()
self._type = self.type_stack.pop()
self.missing_names.pop()

Expand Down Expand Up @@ -2923,8 +2962,8 @@ class C:
[(j := i) for i in [1, 2, 3]]
is a syntax error that is not enforced by Python parser, but at later steps.
"""
for i, is_comprehension in enumerate(reversed(self.is_comprehension_stack)):
if not is_comprehension and i < len(self.locals) - 1:
for i, scope_type in enumerate(reversed(self.scope_stack)):
if scope_type != SCOPE_COMPREHENSION and i < len(self.locals) - 1:
if self.locals[-1 - i] is None:
self.fail(
"Assignment expression within a comprehension"
Expand Down Expand Up @@ -5188,8 +5227,14 @@ def visit_nonlocal_decl(self, d: NonlocalDecl) -> None:
self.fail("nonlocal declaration not allowed at module level", d)
else:
for name in d.names:
for table in reversed(self.locals[:-1]):
for table, scope_type in zip(
reversed(self.locals[:-1]), reversed(self.scope_stack[:-1])
):
if table is not None and name in table:
if scope_type == SCOPE_ANNOTATION:
self.fail(
f'nonlocal binding not allowed for type parameter "{name}"', d
)
break
else:
self.fail(f'No binding for nonlocal "{name}" found', d)
Expand Down Expand Up @@ -5350,7 +5395,7 @@ def visit_star_expr(self, expr: StarExpr) -> None:
def visit_yield_from_expr(self, e: YieldFromExpr) -> None:
if not self.is_func_scope():
self.fail('"yield from" outside function', e, serious=True, blocker=True)
elif self.is_comprehension_stack[-1]:
elif self.scope_stack[-1] == SCOPE_COMPREHENSION:
self.fail(
'"yield from" inside comprehension or generator expression',
e,
Expand Down Expand Up @@ -5848,7 +5893,7 @@ def visit__promote_expr(self, expr: PromoteExpr) -> None:
def visit_yield_expr(self, e: YieldExpr) -> None:
if not self.is_func_scope():
self.fail('"yield" outside function', e, serious=True, blocker=True)
elif self.is_comprehension_stack[-1]:
elif self.scope_stack[-1] == SCOPE_COMPREHENSION:
self.fail(
'"yield" inside comprehension or generator expression',
e,
Expand Down Expand Up @@ -6281,6 +6326,7 @@ def add_symbol(
can_defer: bool = True,
escape_comprehensions: bool = False,
no_progress: bool = False,
type_param: bool = False,
) -> bool:
"""Add symbol to the currently active symbol table.
Expand All @@ -6303,7 +6349,7 @@ def add_symbol(
kind, node, module_public=module_public, module_hidden=module_hidden
)
return self.add_symbol_table_node(
name, symbol, context, can_defer, escape_comprehensions, no_progress
name, symbol, context, can_defer, escape_comprehensions, no_progress, type_param
)

def add_symbol_skip_local(self, name: str, node: SymbolNode) -> None:
Expand Down Expand Up @@ -6336,6 +6382,7 @@ def add_symbol_table_node(
can_defer: bool = True,
escape_comprehensions: bool = False,
no_progress: bool = False,
type_param: bool = False,
) -> bool:
"""Add symbol table node to the currently active symbol table.
Expand All @@ -6355,7 +6402,9 @@ def add_symbol_table_node(
can_defer: if True, defer current target if adding a placeholder
context: error context (see above about None value)
"""
names = self.current_symbol_table(escape_comprehensions=escape_comprehensions)
names = self.current_symbol_table(
escape_comprehensions=escape_comprehensions, type_param=type_param
)
existing = names.get(name)
if isinstance(symbol.node, PlaceholderNode) and can_defer:
if context is not None:
Expand Down Expand Up @@ -6673,7 +6722,7 @@ def enter(
names = self.saved_locals.setdefault(function, SymbolTable())
self.locals.append(names)
is_comprehension = isinstance(function, (GeneratorExpr, DictionaryComprehension))
self.is_comprehension_stack.append(is_comprehension)
self.scope_stack.append(SCOPE_FUNC if not is_comprehension else SCOPE_COMPREHENSION)
self.global_decls.append(set())
self.nonlocal_decls.append(set())
# -1 since entering block will increment this to 0.
Expand All @@ -6684,19 +6733,22 @@ def enter(
yield
finally:
self.locals.pop()
self.is_comprehension_stack.pop()
self.scope_stack.pop()
self.global_decls.pop()
self.nonlocal_decls.pop()
self.block_depth.pop()
self.loop_depth.pop()
self.missing_names.pop()

def is_func_scope(self) -> bool:
return self.locals[-1] is not None
scope_type = self.scope_stack[-1]
if scope_type == SCOPE_ANNOTATION:
scope_type = self.scope_stack[-2]
return scope_type in (SCOPE_FUNC, SCOPE_COMPREHENSION)

def is_nested_within_func_scope(self) -> bool:
"""Are we underneath a function scope, even if we are in a nested class also?"""
return any(l is not None for l in self.locals)
return any(s in (SCOPE_FUNC, SCOPE_COMPREHENSION) for s in self.scope_stack)

def is_class_scope(self) -> bool:
return self.type is not None and not self.is_func_scope()
Expand All @@ -6713,14 +6765,24 @@ def current_symbol_kind(self) -> int:
kind = GDEF
return kind

def current_symbol_table(self, escape_comprehensions: bool = False) -> SymbolTable:
if self.is_func_scope():
assert self.locals[-1] is not None
def current_symbol_table(
self, escape_comprehensions: bool = False, type_param: bool = False
) -> SymbolTable:
if type_param and self.scope_stack[-1] == SCOPE_ANNOTATION:
n = self.locals[-1]
assert n is not None
return n
elif self.is_func_scope():
if self.scope_stack[-1] == SCOPE_ANNOTATION:
n = self.locals[-2]
else:
n = self.locals[-1]
assert n is not None
if escape_comprehensions:
assert len(self.locals) == len(self.is_comprehension_stack)
assert len(self.locals) == len(self.scope_stack)
# Retrieve the symbol table from the enclosing non-comprehension scope.
for i, is_comprehension in enumerate(reversed(self.is_comprehension_stack)):
if not is_comprehension:
for i, scope_type in enumerate(reversed(self.scope_stack)):
if scope_type != SCOPE_COMPREHENSION:
if i == len(self.locals) - 1: # The last iteration.
# The caller of the comprehension is in the global space.
names = self.globals
Expand All @@ -6734,7 +6796,7 @@ def current_symbol_table(self, escape_comprehensions: bool = False) -> SymbolTab
else:
assert False, "Should have at least one non-comprehension scope"
else:
names = self.locals[-1]
names = n
assert names is not None
elif self.type is not None:
names = self.type.names
Expand Down
Loading

0 comments on commit 3b97e6e

Please sign in to comment.