Skip to content

Commit

Permalink
Subtyping and inference of user defined variadic types (#16076)
Browse files Browse the repository at this point in the history
The second part of support for user defined variadic types comes as a
single PR, it was hard to split into smaller parts. This part covers
subtyping and inference (and relies on the first part: type analysis,
normalization, and expansion, concluded by
#15991). Note btw that the third (and
last) part that covers actually using all the stuff in `checkexpr.py`
will likely come as several smaller PRs.

Some comments on this PR:
* First good news: it looks like instances subtyping/inference can be
handled in a really simple way, we just need to find correct type
arguments mapping for each type variable, and perform procedures
argument by argument (note this heavily relies on the normalization).
Also callable subtyping inference for variadic items effectively defers
to corresponding tuple types. This way all code paths will ultimately go
through variadic tuple subtyping/inference (there is still a bunch of
boilerplate to do the mapping, but it is quite simple).
* Second some bad news: a lot of edge cases involving `*tuple[X, ...]`
were missing everywhere (even couple cases in the code I touched
before). I added all that were either simple or important. We can handle
more if users will ask, since it is quite tricky.
* Note that I handle variadic tuples essentially as infinite unions, the
core of the logic for this (and for most of this PR FWIW) is in
`variadic_tuple_subtype()`.
* Previously `Foo[*tuple[int, ...]]` was considered a subtype of
`Foo[int, int]`. I think this is wrong. I didn't find where this is
required in the PEP (see one case below however), and mypy currently
considers `tuple[int, ...]` not a subtype of `tuple[int, int]` (vice
versa are subtypes), and similarly `(*args: int)` vs `(x: int, y: int)`
for callables. Because of the logic I described in the first comment,
the same logic now uniformly applies to instances as well.
* Note however the PEP requires special casing of `Foo[*tuple[Any,
...]]` (equivalent to bare `Foo`), and I agree we should do this. I
added a minimal special case for this. Note we also do this for
callables as well (`*args: Any` is very different from `*args: object`).
And I think we should special case `tuple[Any, ...] <: tuple[int, int]`
as well. In the future we can even extend the special casing to
`tuple[int, *tuple[Any, ...], int]` in the spirit of
#15913
* In this PR I specifically only handle the PEP required item from above
for instances. For plain tuples I left a TODO, @hauntsaninja may
implement it since it is needed for other unrelated PR.
* I make the default upper bound for `TypeVarTupleType` to be
`tuple[object, ...]`. I think it can never be `object` (and this
simplifies some subtyping corner cases).
* TBH I didn't look into callables subtyping/inference very deeply
(unlike instances and tuples), if needed we can improve their handling
later.
* Note I remove some failing unit tests because they test non-nomralized
forms that should never appear now. We should probably add some more unit
tests, but TBH I am quite tired now.
  • Loading branch information
ilevkivskyi authored Sep 13, 2023
1 parent 66fbf5b commit b327557
Show file tree
Hide file tree
Showing 19 changed files with 943 additions and 515 deletions.
231 changes: 109 additions & 122 deletions mypy/constraints.py

Large diffs are not rendered by default.

11 changes: 10 additions & 1 deletion mypy/erasetype.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,16 @@ def visit_deleted_type(self, t: DeletedType) -> ProperType:
return t

def visit_instance(self, t: Instance) -> ProperType:
return Instance(t.type, [AnyType(TypeOfAny.special_form)] * len(t.args), t.line)
args: list[Type] = []
for tv in t.type.defn.type_vars:
# Valid erasure for *Ts is *tuple[Any, ...], not just Any.
if isinstance(tv, TypeVarTupleType):
args.append(
tv.tuple_fallback.copy_modified(args=[AnyType(TypeOfAny.special_form)])
)
else:
args.append(AnyType(TypeOfAny.special_form))
return Instance(t.type, args, t.line)

def visit_type_var(self, t: TypeVarType) -> ProperType:
return AnyType(TypeOfAny.special_form)
Expand Down
3 changes: 2 additions & 1 deletion mypy/expandtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,8 @@ def visit_param_spec(self, t: ParamSpecType) -> Type:
variables=[*t.prefix.variables, *repl.variables],
)
else:
# TODO: replace this with "assert False"
# We could encode Any as trivial parameters etc., but it would be too verbose.
# TODO: assert this is a trivial type, like Any, Never, or object.
return repl

def visit_type_var_tuple(self, t: TypeVarTupleType) -> Type:
Expand Down
17 changes: 9 additions & 8 deletions mypy/fixup.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,17 @@ def visit_type_info(self, info: TypeInfo) -> None:
info.update_tuple_type(info.tuple_type)
if info.special_alias:
info.special_alias.alias_tvars = list(info.defn.type_vars)
for i, t in enumerate(info.defn.type_vars):
if isinstance(t, TypeVarTupleType):
info.special_alias.tvar_tuple_index = i
if info.typeddict_type:
info.typeddict_type.accept(self.type_fixer)
info.update_typeddict_type(info.typeddict_type)
if info.special_alias:
info.special_alias.alias_tvars = list(info.defn.type_vars)
for i, t in enumerate(info.defn.type_vars):
if isinstance(t, TypeVarTupleType):
info.special_alias.tvar_tuple_index = i
if info.declared_metaclass:
info.declared_metaclass.accept(self.type_fixer)
if info.metaclass_type:
Expand Down Expand Up @@ -166,11 +172,7 @@ def visit_decorator(self, d: Decorator) -> None:

def visit_class_def(self, c: ClassDef) -> None:
for v in c.type_vars:
if isinstance(v, TypeVarType):
for value in v.values:
value.accept(self.type_fixer)
v.upper_bound.accept(self.type_fixer)
v.default.accept(self.type_fixer)
v.accept(self.type_fixer)

def visit_type_var_expr(self, tv: TypeVarExpr) -> None:
for value in tv.values:
Expand All @@ -184,6 +186,7 @@ def visit_paramspec_expr(self, p: ParamSpecExpr) -> None:

def visit_type_var_tuple_expr(self, tv: TypeVarTupleExpr) -> None:
tv.upper_bound.accept(self.type_fixer)
tv.tuple_fallback.accept(self.type_fixer)
tv.default.accept(self.type_fixer)

def visit_var(self, v: Var) -> None:
Expand Down Expand Up @@ -314,6 +317,7 @@ def visit_param_spec(self, p: ParamSpecType) -> None:
p.default.accept(self)

def visit_type_var_tuple(self, t: TypeVarTupleType) -> None:
t.tuple_fallback.accept(self)
t.upper_bound.accept(self)
t.default.accept(self)

Expand All @@ -336,9 +340,6 @@ def visit_union_type(self, ut: UnionType) -> None:
for it in ut.items:
it.accept(self)

def visit_void(self, o: Any) -> None:
pass # Nothing to descend into.

def visit_type_type(self, t: TypeType) -> None:
t.item.accept(self)

Expand Down
154 changes: 148 additions & 6 deletions mypy/join.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,10 @@
UninhabitedType,
UnionType,
UnpackType,
find_unpack_in_list,
get_proper_type,
get_proper_types,
split_with_prefix_and_suffix,
)


Expand All @@ -67,7 +69,25 @@ def join_instances(self, t: Instance, s: Instance) -> ProperType:
args: list[Type] = []
# N.B: We use zip instead of indexing because the lengths might have
# mismatches during daemon reprocessing.
for ta, sa, type_var in zip(t.args, s.args, t.type.defn.type_vars):
if t.type.has_type_var_tuple_type:
# We handle joins of variadic instances by simply creating correct mapping
# for type arguments and compute the individual joins same as for regular
# instances. All the heavy lifting is done in the join of tuple types.
assert s.type.type_var_tuple_prefix is not None
assert s.type.type_var_tuple_suffix is not None
prefix = s.type.type_var_tuple_prefix
suffix = s.type.type_var_tuple_suffix
tvt = s.type.defn.type_vars[prefix]
assert isinstance(tvt, TypeVarTupleType)
fallback = tvt.tuple_fallback
s_prefix, s_middle, s_suffix = split_with_prefix_and_suffix(s.args, prefix, suffix)
t_prefix, t_middle, t_suffix = split_with_prefix_and_suffix(t.args, prefix, suffix)
s_args = s_prefix + (TupleType(list(s_middle), fallback),) + s_suffix
t_args = t_prefix + (TupleType(list(t_middle), fallback),) + t_suffix
else:
t_args = t.args
s_args = s.args
for ta, sa, type_var in zip(t_args, s_args, t.type.defn.type_vars):
ta_proper = get_proper_type(ta)
sa_proper = get_proper_type(sa)
new_type: Type | None = None
Expand All @@ -93,6 +113,18 @@ def join_instances(self, t: Instance, s: Instance) -> ProperType:
# If the types are different but equivalent, then an Any is involved
# so using a join in the contravariant case is also OK.
new_type = join_types(ta, sa, self)
elif isinstance(type_var, TypeVarTupleType):
new_type = get_proper_type(join_types(ta, sa, self))
# Put the joined arguments back into instance in the normal form:
# a) Tuple[X, Y, Z] -> [X, Y, Z]
# b) tuple[X, ...] -> [*tuple[X, ...]]
if isinstance(new_type, Instance):
assert new_type.type.fullname == "builtins.tuple"
new_type = UnpackType(new_type)
else:
assert isinstance(new_type, TupleType)
args.extend(new_type.items)
continue
else:
# ParamSpec type variables behave the same, independent of variance
if not is_equivalent(ta, sa):
Expand Down Expand Up @@ -440,6 +472,113 @@ def visit_overloaded(self, t: Overloaded) -> ProperType:
return join_types(t, call)
return join_types(t.fallback, s)

def join_tuples(self, s: TupleType, t: TupleType) -> list[Type] | None:
"""Join two tuple types while handling variadic entries.
This is surprisingly tricky, and we don't handle some tricky corner cases.
Most of the trickiness comes from the variadic tuple items like *tuple[X, ...]
since they can have arbitrary partial overlaps (while *Ts can't be split).
"""
s_unpack_index = find_unpack_in_list(s.items)
t_unpack_index = find_unpack_in_list(t.items)
if s_unpack_index is None and t_unpack_index is None:
if s.length() == t.length():
items: list[Type] = []
for i in range(t.length()):
items.append(join_types(t.items[i], s.items[i]))
return items
return None
if s_unpack_index is not None and t_unpack_index is not None:
# The most complex case: both tuples have an upack item.
s_unpack = s.items[s_unpack_index]
assert isinstance(s_unpack, UnpackType)
s_unpacked = get_proper_type(s_unpack.type)
t_unpack = t.items[t_unpack_index]
assert isinstance(t_unpack, UnpackType)
t_unpacked = get_proper_type(t_unpack.type)
if s.length() == t.length() and s_unpack_index == t_unpack_index:
# We can handle a case where arity is perfectly aligned, e.g.
# join(Tuple[X1, *tuple[Y1, ...], Z1], Tuple[X2, *tuple[Y2, ...], Z2]).
# We can essentially perform the join elementwise.
prefix_len = t_unpack_index
suffix_len = t.length() - t_unpack_index - 1
items = []
for si, ti in zip(s.items[:prefix_len], t.items[:prefix_len]):
items.append(join_types(si, ti))
joined = join_types(s_unpacked, t_unpacked)
if isinstance(joined, TypeVarTupleType):
items.append(UnpackType(joined))
elif isinstance(joined, Instance) and joined.type.fullname == "builtins.tuple":
items.append(UnpackType(joined))
else:
if isinstance(t_unpacked, Instance):
assert t_unpacked.type.fullname == "builtins.tuple"
tuple_instance = t_unpacked
else:
assert isinstance(t_unpacked, TypeVarTupleType)
tuple_instance = t_unpacked.tuple_fallback
items.append(
UnpackType(
tuple_instance.copy_modified(
args=[object_from_instance(tuple_instance)]
)
)
)
if suffix_len:
for si, ti in zip(s.items[-suffix_len:], t.items[-suffix_len:]):
items.append(join_types(si, ti))
return items
if s.length() == 1 or t.length() == 1:
# Another case we can handle is when one of tuple is purely variadic
# (i.e. a non-normalized form of tuple[X, ...]), in this case the join
# will be again purely variadic.
if not (isinstance(s_unpacked, Instance) and isinstance(t_unpacked, Instance)):
return None
assert s_unpacked.type.fullname == "builtins.tuple"
assert t_unpacked.type.fullname == "builtins.tuple"
mid_joined = join_types(s_unpacked.args[0], t_unpacked.args[0])
t_other = [a for i, a in enumerate(t.items) if i != t_unpack_index]
s_other = [a for i, a in enumerate(s.items) if i != s_unpack_index]
other_joined = join_type_list(s_other + t_other)
mid_joined = join_types(mid_joined, other_joined)
return [UnpackType(s_unpacked.copy_modified(args=[mid_joined]))]
# TODO: are there other case we can handle (e.g. both prefix/suffix are shorter)?
return None
if s_unpack_index is not None:
variadic = s
unpack_index = s_unpack_index
fixed = t
else:
assert t_unpack_index is not None
variadic = t
unpack_index = t_unpack_index
fixed = s
# Case where one tuple has variadic item and the other one doesn't. The join will
# be variadic, since fixed tuple is a subtype of variadic, but not vice versa.
unpack = variadic.items[unpack_index]
assert isinstance(unpack, UnpackType)
unpacked = get_proper_type(unpack.type)
if not isinstance(unpacked, Instance):
return None
if fixed.length() < variadic.length() - 1:
# There are no non-trivial types that are supertype of both.
return None
prefix_len = unpack_index
suffix_len = variadic.length() - prefix_len - 1
prefix, middle, suffix = split_with_prefix_and_suffix(
tuple(fixed.items), prefix_len, suffix_len
)
items = []
for fi, vi in zip(prefix, variadic.items[:prefix_len]):
items.append(join_types(fi, vi))
mid_joined = join_type_list(list(middle))
mid_joined = join_types(mid_joined, unpacked.args[0])
items.append(UnpackType(unpacked.copy_modified(args=[mid_joined])))
if suffix_len:
for fi, vi in zip(suffix, variadic.items[-suffix_len:]):
items.append(join_types(fi, vi))
return items

def visit_tuple_type(self, t: TupleType) -> ProperType:
# When given two fixed-length tuples:
# * If they have the same length, join their subtypes item-wise:
Expand All @@ -452,19 +591,22 @@ def visit_tuple_type(self, t: TupleType) -> ProperType:
# Tuple[int, bool] + Tuple[bool, ...] becomes Tuple[int, ...]
# * Joining with any Sequence also returns a Sequence:
# Tuple[int, bool] + List[bool] becomes Sequence[int]
if isinstance(self.s, TupleType) and self.s.length() == t.length():
if isinstance(self.s, TupleType):
if self.instance_joiner is None:
self.instance_joiner = InstanceJoiner()
fallback = self.instance_joiner.join_instances(
mypy.typeops.tuple_fallback(self.s), mypy.typeops.tuple_fallback(t)
)
assert isinstance(fallback, Instance)
if self.s.length() == t.length():
items: list[Type] = []
for i in range(t.length()):
items.append(join_types(t.items[i], self.s.items[i]))
items = self.join_tuples(self.s, t)
if items is not None:
return TupleType(items, fallback)
else:
# TODO: should this be a default fallback behaviour like for meet?
if is_proper_subtype(self.s, t):
return t
if is_proper_subtype(t, self.s):
return self.s
return fallback
else:
return join_types(self.s, mypy.typeops.tuple_fallback(t))
Expand Down
Loading

0 comments on commit b327557

Please sign in to comment.