Skip to content

Commit

Permalink
Support TypeAliasType (#16926)
Browse files Browse the repository at this point in the history
Builds on top of and supersedes #16644

---------

Co-authored-by: sobolevn <[email protected]>
  • Loading branch information
hamdanal and sobolevn authored Mar 11, 2024
1 parent 16abf5c commit ea49e1f
Show file tree
Hide file tree
Showing 9 changed files with 373 additions and 69 deletions.
132 changes: 116 additions & 16 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@

from contextlib import contextmanager
from typing import Any, Callable, Collection, Final, Iterable, Iterator, List, TypeVar, cast
from typing_extensions import TypeAlias as _TypeAlias
from typing_extensions import TypeAlias as _TypeAlias, TypeGuard

from mypy import errorcodes as codes, message_registry
from mypy.constant_fold import constant_fold_expr
Expand Down Expand Up @@ -2018,34 +2018,35 @@ def analyze_class_typevar_declaration(self, base: Type) -> tuple[TypeVarLikeList

def analyze_unbound_tvar(self, t: Type) -> tuple[str, TypeVarLikeExpr] | None:
if isinstance(t, UnpackType) and isinstance(t.type, UnboundType):
return self.analyze_unbound_tvar_impl(t.type, allow_tvt=True)
return self.analyze_unbound_tvar_impl(t.type, is_unpacked=True)
if isinstance(t, UnboundType):
sym = self.lookup_qualified(t.name, t)
if sym and sym.fullname in ("typing.Unpack", "typing_extensions.Unpack"):
inner_t = t.args[0]
if isinstance(inner_t, UnboundType):
return self.analyze_unbound_tvar_impl(inner_t, allow_tvt=True)
return self.analyze_unbound_tvar_impl(inner_t, is_unpacked=True)
return None
return self.analyze_unbound_tvar_impl(t)
return None

def analyze_unbound_tvar_impl(
self, t: UnboundType, allow_tvt: bool = False
self, t: UnboundType, is_unpacked: bool = False, is_typealias_param: bool = False
) -> tuple[str, TypeVarLikeExpr] | None:
assert not is_unpacked or not is_typealias_param, "Mutually exclusive conditions"
sym = self.lookup_qualified(t.name, t)
if sym and isinstance(sym.node, PlaceholderNode):
self.record_incomplete_ref()
if not allow_tvt and sym and isinstance(sym.node, ParamSpecExpr):
if not is_unpacked and sym and isinstance(sym.node, ParamSpecExpr):
if sym.fullname and not self.tvar_scope.allow_binding(sym.fullname):
# It's bound by our type variable scope
return None
return t.name, sym.node
if allow_tvt and sym and isinstance(sym.node, TypeVarTupleExpr):
if (is_unpacked or is_typealias_param) and sym and isinstance(sym.node, TypeVarTupleExpr):
if sym.fullname and not self.tvar_scope.allow_binding(sym.fullname):
# It's bound by our type variable scope
return None
return t.name, sym.node
if sym is None or not isinstance(sym.node, TypeVarExpr) or allow_tvt:
if sym is None or not isinstance(sym.node, TypeVarExpr) or is_unpacked:
return None
elif sym.fullname and not self.tvar_scope.allow_binding(sym.fullname):
# It's bound by our type variable scope
Expand Down Expand Up @@ -3515,7 +3516,11 @@ def analyze_simple_literal_type(self, rvalue: Expression, is_final: bool) -> Typ
return typ

def analyze_alias(
self, name: str, rvalue: Expression, allow_placeholder: bool = False
self,
name: str,
rvalue: Expression,
allow_placeholder: bool = False,
declared_type_vars: TypeVarLikeList | None = None,
) -> tuple[Type | None, list[TypeVarLikeType], set[str], list[str], bool]:
"""Check if 'rvalue' is a valid type allowed for aliasing (e.g. not a type variable).
Expand All @@ -3540,9 +3545,10 @@ def analyze_alias(
found_type_vars = self.find_type_var_likes(typ)
tvar_defs: list[TypeVarLikeType] = []
namespace = self.qualified_name(name)
alias_type_vars = found_type_vars if declared_type_vars is None else declared_type_vars
last_tvar_name_with_default: str | None = None
with self.tvar_scope_frame(self.tvar_scope.class_frame(namespace)):
for name, tvar_expr in found_type_vars:
for name, tvar_expr in alias_type_vars:
tvar_expr.default = tvar_expr.default.accept(
TypeVarDefaultTranslator(self, tvar_expr.name, typ)
)
Expand All @@ -3567,6 +3573,7 @@ def analyze_alias(
in_dynamic_func=dynamic,
global_scope=global_scope,
allowed_alias_tvars=tvar_defs,
has_type_params=declared_type_vars is not None,
)

# There can be only one variadic variable at most, the error is reported elsewhere.
Expand All @@ -3579,7 +3586,7 @@ def analyze_alias(
variadic = True
new_tvar_defs.append(td)

qualified_tvars = [node.fullname for _name, node in found_type_vars]
qualified_tvars = [node.fullname for _name, node in alias_type_vars]
empty_tuple_index = typ.empty_tuple_index if isinstance(typ, UnboundType) else False
return analyzed, new_tvar_defs, depends_on, qualified_tvars, empty_tuple_index

Expand Down Expand Up @@ -3612,7 +3619,19 @@ def check_and_set_up_type_alias(self, s: AssignmentStmt) -> bool:
# unless using PEP 613 `cls: TypeAlias = A`
return False

if isinstance(s.rvalue, CallExpr) and s.rvalue.analyzed:
# It can be `A = TypeAliasType('A', ...)` call, in this case,
# we just take the second argument and analyze it:
type_params: TypeVarLikeList | None
if self.check_type_alias_type_call(s.rvalue, name=lvalue.name):
rvalue = s.rvalue.args[1]
pep_695 = True
type_params = self.analyze_type_alias_type_params(s.rvalue)
else:
rvalue = s.rvalue
pep_695 = False
type_params = None

if isinstance(rvalue, CallExpr) and rvalue.analyzed:
return False

existing = self.current_symbol_table().get(lvalue.name)
Expand All @@ -3638,7 +3657,7 @@ def check_and_set_up_type_alias(self, s: AssignmentStmt) -> bool:
return False

non_global_scope = self.type or self.is_func_scope()
if not pep_613 and isinstance(s.rvalue, RefExpr) and non_global_scope:
if not pep_613 and isinstance(rvalue, RefExpr) and non_global_scope:
# Fourth rule (special case): Non-subscripted right hand side creates a variable
# at class and function scopes. For example:
#
Expand All @@ -3650,8 +3669,7 @@ def check_and_set_up_type_alias(self, s: AssignmentStmt) -> bool:
# without this rule, this typical use case will require a lot of explicit
# annotations (see the second rule).
return False
rvalue = s.rvalue
if not pep_613 and not self.can_be_type_alias(rvalue):
if not pep_613 and not pep_695 and not self.can_be_type_alias(rvalue):
return False

if existing and not isinstance(existing.node, (PlaceholderNode, TypeAlias)):
Expand All @@ -3668,7 +3686,7 @@ def check_and_set_up_type_alias(self, s: AssignmentStmt) -> bool:
else:
tag = self.track_incomplete_refs()
res, alias_tvars, depends_on, qualified_tvars, empty_tuple_index = self.analyze_alias(
lvalue.name, rvalue, allow_placeholder=True
lvalue.name, rvalue, allow_placeholder=True, declared_type_vars=type_params
)
if not res:
return False
Expand Down Expand Up @@ -3698,13 +3716,15 @@ def check_and_set_up_type_alias(self, s: AssignmentStmt) -> bool:
# so we need to replace it with non-explicit Anys.
res = make_any_non_explicit(res)
# Note: with the new (lazy) type alias representation we only need to set no_args to True
# if the expected number of arguments is non-zero, so that aliases like A = List work.
# if the expected number of arguments is non-zero, so that aliases like `A = List` work
# but not aliases like `A = TypeAliasType("A", List)` as these need explicit type params.
# However, eagerly expanding aliases like Text = str is a nice performance optimization.
no_args = (
isinstance(res, ProperType)
and isinstance(res, Instance)
and not res.args
and not empty_tuple_index
and not pep_695
)
if isinstance(res, ProperType) and isinstance(res, Instance):
if not validate_instance(res, self.fail, empty_tuple_index):
Expand Down Expand Up @@ -3771,6 +3791,80 @@ def check_and_set_up_type_alias(self, s: AssignmentStmt) -> bool:
self.note("Use variable annotation syntax to define protocol members", s)
return True

def check_type_alias_type_call(self, rvalue: Expression, *, name: str) -> TypeGuard[CallExpr]:
if not isinstance(rvalue, CallExpr):
return False

names = ["typing_extensions.TypeAliasType"]
if self.options.python_version >= (3, 12):
names.append("typing.TypeAliasType")
if not refers_to_fullname(rvalue.callee, tuple(names)):
return False

return self.check_typevarlike_name(rvalue, name, rvalue)

def analyze_type_alias_type_params(self, rvalue: CallExpr) -> TypeVarLikeList:
if "type_params" in rvalue.arg_names:
type_params_arg = rvalue.args[rvalue.arg_names.index("type_params")]
if not isinstance(type_params_arg, TupleExpr):
self.fail(
"Tuple literal expected as the type_params argument to TypeAliasType",
type_params_arg,
)
return []
type_params = type_params_arg.items
else:
type_params = []

declared_tvars: TypeVarLikeList = []
have_type_var_tuple = False
for tp_expr in type_params:
if isinstance(tp_expr, StarExpr):
tp_expr.valid = False
self.analyze_type_expr(tp_expr)
try:
base = self.expr_to_unanalyzed_type(tp_expr)
except TypeTranslationError:
continue
if not isinstance(base, UnboundType):
continue

tag = self.track_incomplete_refs()
tvar = self.analyze_unbound_tvar_impl(base, is_typealias_param=True)
if tvar:
if isinstance(tvar[1], TypeVarTupleExpr):
if have_type_var_tuple:
self.fail(
"Can only use one TypeVarTuple in type_params argument to TypeAliasType",
base,
code=codes.TYPE_VAR,
)
have_type_var_tuple = True
continue
have_type_var_tuple = True
elif not self.found_incomplete_ref(tag):
self.fail(
"Free type variable expected in type_params argument to TypeAliasType",
base,
code=codes.TYPE_VAR,
)
sym = self.lookup_qualified(base.name, base)
if sym and sym.fullname in ("typing.Unpack", "typing_extensions.Unpack"):
self.note(
"Don't Unpack type variables in type_params", base, code=codes.TYPE_VAR
)
continue
if tvar in declared_tvars:
self.fail(
f'Duplicate type variable "{tvar[0]}" in type_params argument to TypeAliasType',
base,
code=codes.TYPE_VAR,
)
continue
if tvar:
declared_tvars.append(tvar)
return declared_tvars

def disable_invalid_recursive_aliases(
self, s: AssignmentStmt, current_node: TypeAlias
) -> None:
Expand Down Expand Up @@ -5187,6 +5281,12 @@ def visit_call_expr(self, expr: CallExpr) -> None:
expr.analyzed = OpExpr("divmod", expr.args[0], expr.args[1])
expr.analyzed.line = expr.line
expr.analyzed.accept(self)
elif refers_to_fullname(
expr.callee, ("typing.TypeAliasType", "typing_extensions.TypeAliasType")
):
with self.allow_unbound_tvars_set():
for a in expr.args:
a.accept(self)
else:
# Normal call expression.
for a in expr.args:
Expand Down
49 changes: 37 additions & 12 deletions mypy/typeanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ def analyze_type_alias(
in_dynamic_func: bool = False,
global_scope: bool = True,
allowed_alias_tvars: list[TypeVarLikeType] | None = None,
has_type_params: bool = False,
) -> tuple[Type, set[str]]:
"""Analyze r.h.s. of a (potential) type alias definition.
Expand All @@ -158,6 +159,7 @@ def analyze_type_alias(
allow_placeholder=allow_placeholder,
prohibit_self_type="type alias target",
allowed_alias_tvars=allowed_alias_tvars,
has_type_params=has_type_params,
)
analyzer.in_dynamic_func = in_dynamic_func
analyzer.global_scope = global_scope
Expand Down Expand Up @@ -210,6 +212,7 @@ def __init__(
prohibit_self_type: str | None = None,
allowed_alias_tvars: list[TypeVarLikeType] | None = None,
allow_type_any: bool = False,
has_type_params: bool = False,
) -> None:
self.api = api
self.fail_func = api.fail
Expand All @@ -231,6 +234,7 @@ def __init__(
if allowed_alias_tvars is None:
allowed_alias_tvars = []
self.allowed_alias_tvars = allowed_alias_tvars
self.has_type_params = has_type_params
# If false, record incomplete ref if we generate PlaceholderType.
self.allow_placeholder = allow_placeholder
# Are we in a context where Required[] is allowed?
Expand Down Expand Up @@ -325,7 +329,11 @@ def visit_unbound_type_nonoptional(self, t: UnboundType, defining_literal: bool)
if tvar_def is None:
if self.allow_unbound_tvars:
return t
self.fail(f'ParamSpec "{t.name}" is unbound', t, code=codes.VALID_TYPE)
if self.defining_alias and self.has_type_params:
msg = f'ParamSpec "{t.name}" is not included in type_params'
else:
msg = f'ParamSpec "{t.name}" is unbound'
self.fail(msg, t, code=codes.VALID_TYPE)
return AnyType(TypeOfAny.from_error)
assert isinstance(tvar_def, ParamSpecType)
if len(t.args) > 0:
Expand All @@ -349,11 +357,11 @@ def visit_unbound_type_nonoptional(self, t: UnboundType, defining_literal: bool)
and not defining_literal
and (tvar_def is None or tvar_def not in self.allowed_alias_tvars)
):
self.fail(
f'Can\'t use bound type variable "{t.name}" to define generic alias',
t,
code=codes.VALID_TYPE,
)
if self.has_type_params:
msg = f'Type variable "{t.name}" is not included in type_params'
else:
msg = f'Can\'t use bound type variable "{t.name}" to define generic alias'
self.fail(msg, t, code=codes.VALID_TYPE)
return AnyType(TypeOfAny.from_error)
if isinstance(sym.node, TypeVarExpr) and tvar_def is not None:
assert isinstance(tvar_def, TypeVarType)
Expand All @@ -368,17 +376,21 @@ def visit_unbound_type_nonoptional(self, t: UnboundType, defining_literal: bool)
and self.defining_alias
and tvar_def not in self.allowed_alias_tvars
):
self.fail(
f'Can\'t use bound type variable "{t.name}" to define generic alias',
t,
code=codes.VALID_TYPE,
)
if self.has_type_params:
msg = f'Type variable "{t.name}" is not included in type_params'
else:
msg = f'Can\'t use bound type variable "{t.name}" to define generic alias'
self.fail(msg, t, code=codes.VALID_TYPE)
return AnyType(TypeOfAny.from_error)
if isinstance(sym.node, TypeVarTupleExpr):
if tvar_def is None:
if self.allow_unbound_tvars:
return t
self.fail(f'TypeVarTuple "{t.name}" is unbound', t, code=codes.VALID_TYPE)
if self.defining_alias and self.has_type_params:
msg = f'TypeVarTuple "{t.name}" is not included in type_params'
else:
msg = f'TypeVarTuple "{t.name}" is unbound'
self.fail(msg, t, code=codes.VALID_TYPE)
return AnyType(TypeOfAny.from_error)
assert isinstance(tvar_def, TypeVarTupleType)
if not self.allow_type_var_tuple:
Expand Down Expand Up @@ -1267,6 +1279,19 @@ def analyze_callable_args_for_paramspec(
AnyType(TypeOfAny.explicit), ret_type=ret_type, fallback=fallback
)
return None
elif (
self.defining_alias
and self.has_type_params
and tvar_def not in self.allowed_alias_tvars
):
self.fail(
f'ParamSpec "{callable_args.name}" is not included in type_params',
callable_args,
code=codes.VALID_TYPE,
)
return callable_with_ellipsis(
AnyType(TypeOfAny.special_form), ret_type=ret_type, fallback=fallback
)

return CallableType(
[
Expand Down
Loading

0 comments on commit ea49e1f

Please sign in to comment.