Skip to content

Commit

Permalink
Streamline some elements of variadic types support (#15924)
Browse files Browse the repository at this point in the history
Fixes #13981
Fixes #15241
Fixes #15495
Fixes #15633
Fixes #15667
Fixes #15897
Fixes #15929

OK, I started following the plan outlined in
#15879. In this PR I focused mostly
on "kinematics". Here are some notes (in random order):
* I decided to normalize `TupleType` and `Instance` items in
`semanal_typeargs.py` (not in the type constructors, like for unions).
It looks like a simpler way to normalize for now. After this, we can
rely on the fact that only non-trivial (more below on what is trivial)
variadic items in a type list is either `*Ts` or `*tuple[X, ...]`. A
single top-level `TupleType` can appear in an unpack only as type of
`*args`.
* Callables turned out to be tricky. There is certain tight coupling
between `FuncDef.type` and `FuncDef.arguments` that makes it hard to
normalize prefix to be expressed as individual arguments _at
definition_. I faced exactly the same problem when I implemented `**`
unpacking for TypedDicts. So we have two choices: either handle prefixes
everywhere, or use normalization helper in relevant code. I propose to
go with the latter (it worked well for `**` unpacking).
* I decided to switch `Unpack` to be disallowed by default in
`typeanal.py`, only very specific potions are allowed now. Although this
required plumbing `allow_unpack` all the way from `semanal.py`,
conceptually it is simple. This is similar to how `ParamSpec` is
handled.
* This PR fixes all currently open crash issues (some intentionally,
some accidentally) plus a bunch of TODOs I found in the tests (but not
all).
* I decided to simplify `TypeAliasExpr` (and made it simple reference to
the `SymbolNode`, like e.g. `TypedDictExpr` and `NamedTupleExpr`). This
is not strictly necessary for this PR, but it makes some parts of it a
bit simpler, and I wanted to do it for long time.

Here is a more detailed plan of what I am leaving for future PRs (in
rough order of priority):
* Close non-crash open issues (it looks like there are only three, and
all seem to be straightforward)
* Handle trivial items in `UnpackType` gracefully. These are `<nothing>`
and `Any` (and also potentially `object`). They can appear e.g. after a
user error. Currently they can cause assert crashes. (Not sure what is
the best way to do this).
* Go over current places where `Unpack` is handled, and verify both
possible variadic items are handled.
* Audit variadic `Instance` constrains and subtyping (the latter is
probably OK, but the former may be broken).
* Audit `Callable` and `Tuple` subtyping for variadic-related edge cases
(constraints seem OK for these).
* Figure out story about `map_instance_to_supertype()` (if no changes
are needed, add tests for subclassing).
* Clear most remaining TODOs.
* Go once more over the large scale picture and check whether we have
some important parts missing (or unhandled interactions between those).
* Verify various "advanced" typing features work well with
`TypeVarTuple`s (and add some support if missing but looks important).
* Enable this feature by default.

I hope to finish these in next few weeks.
  • Loading branch information
ilevkivskyi authored Aug 23, 2023
1 parent 48835a3 commit 6f650cf
Show file tree
Hide file tree
Showing 22 changed files with 439 additions and 229 deletions.
5 changes: 1 addition & 4 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4665,10 +4665,7 @@ def analyze_iterable_item_type(self, expr: Expression) -> tuple[Type, Type]:
isinstance(iterable, TupleType)
and iterable.partial_fallback.type.fullname == "builtins.tuple"
):
joined: Type = UninhabitedType()
for item in iterable.items:
joined = join_types(joined, item)
return iterator, joined
return iterator, tuple_fallback(iterable).args[0]
else:
# Non-tuple iterable.
return iterator, echk.check_method_call_by_name("__next__", iterator, [], [], expr)[0]
Expand Down
11 changes: 7 additions & 4 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@
UninhabitedType,
UnionType,
UnpackType,
flatten_nested_tuples,
find_unpack_in_list,
flatten_nested_unions,
get_proper_type,
get_proper_types,
Expand All @@ -185,7 +185,6 @@
)
from mypy.typestate import type_state
from mypy.typevars import fill_typevars
from mypy.typevartuples import find_unpack_in_list
from mypy.util import split_module_names
from mypy.visitor import ExpressionVisitor

Expand Down Expand Up @@ -1600,7 +1599,7 @@ def check_callable_call(
See the docstring of check_call for more information.
"""
# Always unpack **kwargs before checking a call.
callee = callee.with_unpacked_kwargs()
callee = callee.with_unpacked_kwargs().with_normalized_var_args()
if callable_name is None and callee.name:
callable_name = callee.name
ret_type = get_proper_type(callee.ret_type)
Expand Down Expand Up @@ -2409,7 +2408,12 @@ def check_argument_types(
+ unpacked_type.items[inner_unpack_index + 1 :]
)
callee_arg_kinds = [ARG_POS] * len(actuals)
elif isinstance(unpacked_type, TypeVarTupleType):
callee_arg_types = [orig_callee_arg_type]
callee_arg_kinds = [ARG_STAR]
else:
# TODO: Any and <nothing> can appear in Unpack (as a result of user error),
# fail gracefully here and elsewhere (and/or normalize them away).
assert isinstance(unpacked_type, Instance)
assert unpacked_type.type.fullname == "builtins.tuple"
callee_arg_types = [unpacked_type.args[0]] * len(actuals)
Expand Down Expand Up @@ -4451,7 +4455,6 @@ class C(Generic[T, Unpack[Ts]]): ...

prefix = next(i for (i, v) in enumerate(vars) if isinstance(v, TypeVarTupleType))
suffix = len(vars) - prefix - 1
args = flatten_nested_tuples(args)
if len(args) < len(vars) - 1:
self.msg.incompatible_type_application(len(vars), len(args), ctx)
return [AnyType(TypeOfAny.from_error)] * len(vars)
Expand Down
46 changes: 37 additions & 9 deletions mypy/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
UninhabitedType,
UnionType,
UnpackType,
find_unpack_in_list,
get_proper_type,
has_recursive_types,
has_type_vars,
Expand All @@ -57,7 +58,7 @@
)
from mypy.types_utils import is_union_with_any
from mypy.typestate import type_state
from mypy.typevartuples import extract_unpack, find_unpack_in_list, split_with_mapped_and_template
from mypy.typevartuples import extract_unpack, split_with_mapped_and_template

if TYPE_CHECKING:
from mypy.infer import ArgumentInferContext
Expand Down Expand Up @@ -155,16 +156,33 @@ def infer_constraints_for_callable(
# not to hold we can always handle the prefixes too.
inner_unpack = unpacked_type.items[0]
assert isinstance(inner_unpack, UnpackType)
inner_unpacked_type = inner_unpack.type
assert isinstance(inner_unpacked_type, TypeVarTupleType)
inner_unpacked_type = get_proper_type(inner_unpack.type)
suffix_len = len(unpacked_type.items) - 1
constraints.append(
Constraint(
inner_unpacked_type,
SUPERTYPE_OF,
TupleType(actual_types[:-suffix_len], inner_unpacked_type.tuple_fallback),
if isinstance(inner_unpacked_type, TypeVarTupleType):
# Variadic item can be either *Ts...
constraints.append(
Constraint(
inner_unpacked_type,
SUPERTYPE_OF,
TupleType(
actual_types[:-suffix_len], inner_unpacked_type.tuple_fallback
),
)
)
)
else:
# ...or it can be a homogeneous tuple.
assert (
isinstance(inner_unpacked_type, Instance)
and inner_unpacked_type.type.fullname == "builtins.tuple"
)
for at in actual_types[:-suffix_len]:
constraints.extend(
infer_constraints(inner_unpacked_type.args[0], at, SUPERTYPE_OF)
)
# Now handle the suffix (if any).
if suffix_len:
for tt, at in zip(unpacked_type.items[1:], actual_types[-suffix_len:]):
constraints.extend(infer_constraints(tt, at, SUPERTYPE_OF))
else:
assert False, "mypy bug: unhandled constraint inference case"
else:
Expand Down Expand Up @@ -863,6 +881,16 @@ def visit_instance(self, template: Instance) -> list[Constraint]:
and self.direction == SUPERTYPE_OF
):
for item in actual.items:
if isinstance(item, UnpackType):
unpacked = get_proper_type(item.type)
if isinstance(unpacked, TypeVarType):
# Cannot infer anything for T from [T, ...] <: *Ts
continue
assert (
isinstance(unpacked, Instance)
and unpacked.type.fullname == "builtins.tuple"
)
item = unpacked.args[0]
cb = infer_constraints(template.args[0], item, SUPERTYPE_OF)
res.extend(cb)
return res
Expand Down
111 changes: 18 additions & 93 deletions mypy/expandtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing import Final, Iterable, Mapping, Sequence, TypeVar, cast, overload

from mypy.nodes import ARG_POS, ARG_STAR, ArgKind, Var
from mypy.nodes import ARG_STAR, Var
from mypy.state import state
from mypy.types import (
ANY_STRATEGY,
Expand Down Expand Up @@ -35,12 +35,11 @@
UninhabitedType,
UnionType,
UnpackType,
flatten_nested_tuples,
flatten_nested_unions,
get_proper_type,
split_with_prefix_and_suffix,
)
from mypy.typevartuples import find_unpack_in_list, split_with_instance
from mypy.typevartuples import split_with_instance

# Solving the import cycle:
import mypy.type_visitor # ruff: isort: skip
Expand Down Expand Up @@ -294,101 +293,30 @@ def expand_unpack(self, t: UnpackType) -> list[Type] | AnyType | UninhabitedType
def visit_parameters(self, t: Parameters) -> Type:
return t.copy_modified(arg_types=self.expand_types(t.arg_types))

# TODO: can we simplify this method? It is too long.
def interpolate_args_for_unpack(
self, t: CallableType, var_arg: UnpackType
) -> tuple[list[str | None], list[ArgKind], list[Type]]:
def interpolate_args_for_unpack(self, t: CallableType, var_arg: UnpackType) -> list[Type]:
star_index = t.arg_kinds.index(ARG_STAR)
prefix = self.expand_types(t.arg_types[:star_index])
suffix = self.expand_types(t.arg_types[star_index + 1 :])

var_arg_type = get_proper_type(var_arg.type)
# We have something like Unpack[Tuple[Unpack[Ts], X1, X2]]
if isinstance(var_arg_type, TupleType):
expanded_tuple = var_arg_type.accept(self)
assert isinstance(expanded_tuple, ProperType) and isinstance(expanded_tuple, TupleType)
expanded_items = expanded_tuple.items
fallback = var_arg_type.partial_fallback
else:
# We have plain Unpack[Ts]
assert isinstance(var_arg_type, TypeVarTupleType)
fallback = var_arg_type.tuple_fallback
expanded_items_res = self.expand_unpack(var_arg)
if isinstance(expanded_items_res, list):
expanded_items = expanded_items_res
else:
# We got Any or <nothing>
arg_types = (
t.arg_types[:star_index] + [expanded_items_res] + t.arg_types[star_index + 1 :]
)
return t.arg_names, t.arg_kinds, arg_types

expanded_unpack_index = find_unpack_in_list(expanded_items)
# This is the case where we just have Unpack[Tuple[X1, X2, X3]]
# (for example if either the tuple had no unpacks, or the unpack in the
# tuple got fully expanded to something with fixed length)
if expanded_unpack_index is None:
arg_names = (
t.arg_names[:star_index]
+ [None] * len(expanded_items)
+ t.arg_names[star_index + 1 :]
)
arg_kinds = (
t.arg_kinds[:star_index]
+ [ARG_POS] * len(expanded_items)
+ t.arg_kinds[star_index + 1 :]
)
arg_types = (
self.expand_types(t.arg_types[:star_index])
+ expanded_items
+ self.expand_types(t.arg_types[star_index + 1 :])
)
else:
# If Unpack[Ts] simplest form still has an unpack or is a
# homogenous tuple, then only the prefix can be represented as
# positional arguments, and we pass Tuple[Unpack[Ts-1], Y1, Y2]
# as the star arg, for example.
expanded_unpack = expanded_items[expanded_unpack_index]
assert isinstance(expanded_unpack, UnpackType)

# Extract the TypeVarTuple, so we can get a tuple fallback from it.
expanded_unpacked_tvt = expanded_unpack.type
if isinstance(expanded_unpacked_tvt, TypeVarTupleType):
fallback = expanded_unpacked_tvt.tuple_fallback
else:
# This can happen when tuple[Any, ...] is used to "patch" a variadic
# generic type without type arguments provided, or when substitution is
# homogeneous tuple.
assert isinstance(expanded_unpacked_tvt, ProperType)
assert isinstance(expanded_unpacked_tvt, Instance)
assert expanded_unpacked_tvt.type.fullname == "builtins.tuple"
fallback = expanded_unpacked_tvt

prefix_len = expanded_unpack_index
arg_names = t.arg_names[:star_index] + [None] * prefix_len + t.arg_names[star_index:]
arg_kinds = (
t.arg_kinds[:star_index] + [ARG_POS] * prefix_len + t.arg_kinds[star_index:]
)
if (
len(expanded_items) == 1
and isinstance(expanded_unpack.type, ProperType)
and isinstance(expanded_unpack.type, Instance)
):
assert expanded_unpack.type.type.fullname == "builtins.tuple"
# Normalize *args: *tuple[X, ...] -> *args: X
arg_types = (
self.expand_types(t.arg_types[:star_index])
+ [expanded_unpack.type.args[0]]
+ self.expand_types(t.arg_types[star_index + 1 :])
)
else:
arg_types = (
self.expand_types(t.arg_types[:star_index])
+ expanded_items[:prefix_len]
# Constructing the Unpack containing the tuple without the prefix.
+ [
UnpackType(TupleType(expanded_items[prefix_len:], fallback))
if len(expanded_items) - prefix_len > 1
else expanded_items[prefix_len]
]
+ self.expand_types(t.arg_types[star_index + 1 :])
)
return arg_names, arg_kinds, arg_types
return prefix + [expanded_items_res] + suffix
new_unpack = UnpackType(TupleType(expanded_items, fallback))
return prefix + [new_unpack] + suffix

def visit_callable_type(self, t: CallableType) -> CallableType:
param_spec = t.param_spec()
Expand Down Expand Up @@ -427,20 +355,20 @@ def visit_callable_type(self, t: CallableType) -> CallableType:
)

var_arg = t.var_arg()
needs_normalization = False
if var_arg is not None and isinstance(var_arg.typ, UnpackType):
arg_names, arg_kinds, arg_types = self.interpolate_args_for_unpack(t, var_arg.typ)
needs_normalization = True
arg_types = self.interpolate_args_for_unpack(t, var_arg.typ)
else:
arg_names = t.arg_names
arg_kinds = t.arg_kinds
arg_types = self.expand_types(t.arg_types)

return t.copy_modified(
expanded = t.copy_modified(
arg_types=arg_types,
arg_names=arg_names,
arg_kinds=arg_kinds,
ret_type=t.ret_type.accept(self),
type_guard=(t.type_guard.accept(self) if t.type_guard is not None else None),
)
if needs_normalization:
return expanded.with_normalized_var_args()
return expanded

def visit_overloaded(self, t: Overloaded) -> Type:
items: list[CallableType] = []
Expand All @@ -460,9 +388,6 @@ def expand_types_with_unpack(
indicates use of Any or some error occurred earlier. In this case callers should
simply propagate the resulting type.
"""
# TODO: this will cause a crash on aliases like A = Tuple[int, Unpack[A]].
# Although it is unlikely anyone will write this, we should fail gracefully.
typs = flatten_nested_tuples(typs)
items: list[Type] = []
for item in typs:
if isinstance(item, UnpackType) and isinstance(item.type, TypeVarTupleType):
Expand Down
3 changes: 2 additions & 1 deletion mypy/message_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,8 @@ def with_additional_msg(self, info: str) -> ErrorMessage:
IMPLICIT_GENERIC_ANY_BUILTIN: Final = (
'Implicit generic "Any". Use "{}" and specify generic parameters'
)
INVALID_UNPACK = "{} cannot be unpacked (must be tuple or TypeVarTuple)"
INVALID_UNPACK: Final = "{} cannot be unpacked (must be tuple or TypeVarTuple)"
INVALID_UNPACK_POSITION: Final = "Unpack is only valid in a variadic position"

# TypeVar
INCOMPATIBLE_TYPEVAR_VALUE: Final = 'Value of type variable "{}" of {} cannot be {}'
Expand Down
2 changes: 1 addition & 1 deletion mypy/mixedtraverser.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def visit_class_def(self, o: ClassDef) -> None:
def visit_type_alias_expr(self, o: TypeAliasExpr) -> None:
super().visit_type_alias_expr(o)
self.in_type_alias_expr = True
o.type.accept(self)
o.node.target.accept(self)
self.in_type_alias_expr = False

def visit_type_var_expr(self, o: TypeVarExpr) -> None:
Expand Down
17 changes: 2 additions & 15 deletions mypy/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2625,27 +2625,14 @@ def deserialize(cls, data: JsonDict) -> TypeVarTupleExpr:
class TypeAliasExpr(Expression):
"""Type alias expression (rvalue)."""

__slots__ = ("type", "tvars", "no_args", "node")
__slots__ = ("node",)

__match_args__ = ("type", "tvars", "no_args", "node")
__match_args__ = ("node",)

# The target type.
type: mypy.types.Type
# Names of type variables used to define the alias
tvars: list[str]
# Whether this alias was defined in bare form. Used to distinguish
# between
# A = List
# and
# A = List[Any]
no_args: bool
node: TypeAlias

def __init__(self, node: TypeAlias) -> None:
super().__init__()
self.type = node.target
self.tvars = [v.name for v in node.alias_tvars]
self.no_args = node.no_args
self.node = node

def accept(self, visitor: ExpressionVisitor[T]) -> T:
Expand Down
Loading

0 comments on commit 6f650cf

Please sign in to comment.