Skip to content

Commit

Permalink
Polymorphic inference: basic support for variadic types (#15879)
Browse files Browse the repository at this point in the history
This is the fifth PR in the series started by #15287, and a last one for
the foreseeable future. This completes polymorphic inference
sufficiently for extensive experimentation, and enabling polymorphic
fallback by default.

Remaining items for which I am going to open follow-up issues:
* Enable `--new-type-inference` by default (should be done before
everything else in this list).
* Use polymorphic inference during unification.
* Use polymorphic inference as primary an only mechanism, rather than a
fallback if basic inference fails in some way.
* Move `apply_poly()` logic from `checkexpr.py` to `applytype.py` (this
one depends on everything above).
* Experiment with backtracking in the new solver.
* Experiment with universal quantification for types other that
`Callable` (btw we already have a hacky support for capturing a generic
function in an instance with `ParamSpec`).

Now some comments on the PR proper. First of all I decided to do some
clean-up of `TypeVarTuple` support, but added only strictly necessary
parts of the cleanup here. Everything else will be in follow up PR(s).
The polymorphic inference/solver/application is practically trivial
here, so here is my view on how I see large-scale structure of
`TypeVarTuple` implementation:
* There should be no special-casing in `applytype.py`, so I deleted
everything from there (as I did for `ParamSpec`) and complemented
`visit_callable_type()` in `expandtype.py`. Basically, `applytype.py`
should have three simple steps: validate substitutions (upper bounds,
values, argument kinds etc.); call `expand_type()`; update callable type
variables (currently we just reduce the number, but in future we may
also add variables there, see TODO that I added).
* The only valid positions for a variadic item (a.k.a. `UnpackType`) are
inside `Instance`s, `TupleType`s, and `CallableType`s. I like how there
is an agreement that for callables there should never be a prefix, and
instead prefix should be represented with regular positional arguments.
I think that ideally we should enforce this with an `assert` in
`CallableType` constructor (similar to how I did this for `ParamSpec`).
* Completing `expand_type()` should be a priority (since it describes
basic semantics of `TypeVarLikeType`s). I think I made good progress in
this direction. IIUC the only valid substitution for `*Ts` are
`TupleType.items`, `*tuple[X, ...]`, `Any`, and `<nothing>`, so it was
not hard.
* I propose to only allow `TupleType` (mostly for `semanal.py`, see item
below), plain `TypeVarTupleType`, and a homogeneous `tuple` instances
inside `UnpackType`. Supporting unions of those is not specified by the
PEP and support will likely be quite tricky to implement. Also I propose
to even eagerly expand type aliases to tuples (since there is no point
in supporting recursive types like `A = Tuple[int, *A]`).
* I propose to forcefully flatten nested `TupleType`s, there should be
no things like `Tuple[X1, *Tuple[X2, *Ts, Y2], Y1]` etc after semantic
analysis. (Similarly to how we always flatten `Parameters` for
`ParamSpec`, and how we flatten nested unions in `UnionType`
_constructor_). Currently we do the flattening/normalization of tuples
in `expand_type()` etc.
* I suspect `build_constraints_for_unpack()` may be broken, at least
when it was used for tuples and callables it did something wrong in few
cases I tested (and there are other symptoms I mentioned in a TODO). I
therefore re-implemented logic for callables/tuples using a separate
dedicated helper. I will investigate more later.

As I mentioned above I only implemented strictly minimal amount of the
above plan to make my tests pass, but still wanted to write this out to
see if there are any objections (or maybe I don't understand something).
If there are no objections to this plan, I will continue it in separate
PR(s). Btw, I like how with this plan we will have clear logical
parallels between `TypeVarTuple` implementation and (recently updated)
`ParamSpec` implementation.

---------

Co-authored-by: Ivan Levkivskyi <[email protected]>
  • Loading branch information
ilevkivskyi and Ivan Levkivskyi authored Aug 18, 2023
1 parent fa84534 commit b02ddf1
Show file tree
Hide file tree
Showing 10 changed files with 440 additions and 209 deletions.
64 changes: 10 additions & 54 deletions mypy/applytype.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,13 @@
from typing import Callable, Sequence

import mypy.subtypes
from mypy.expandtype import expand_type, expand_unpack_with_variables
from mypy.nodes import ARG_STAR, Context
from mypy.expandtype import expand_type
from mypy.nodes import Context
from mypy.types import (
AnyType,
CallableType,
Instance,
ParamSpecType,
PartialType,
TupleType,
Type,
TypeVarId,
TypeVarLikeType,
Expand All @@ -21,7 +19,6 @@
UnpackType,
get_proper_type,
)
from mypy.typevartuples import find_unpack_in_list, replace_starargs


def get_target_type(
Expand Down Expand Up @@ -107,6 +104,8 @@ def apply_generic_arguments(
if target_type is not None:
id_to_type[tvar.id] = target_type

# TODO: validate arg_kinds/arg_names for ParamSpec and TypeVarTuple replacements,
# not just type variable bounds above.
param_spec = callable.param_spec()
if param_spec is not None:
nt = id_to_type.get(param_spec.id)
Expand All @@ -122,55 +121,9 @@ def apply_generic_arguments(
# Apply arguments to argument types.
var_arg = callable.var_arg()
if var_arg is not None and isinstance(var_arg.typ, UnpackType):
star_index = callable.arg_kinds.index(ARG_STAR)
callable = callable.copy_modified(
arg_types=(
[expand_type(at, id_to_type) for at in callable.arg_types[:star_index]]
+ [callable.arg_types[star_index]]
+ [expand_type(at, id_to_type) for at in callable.arg_types[star_index + 1 :]]
)
)

unpacked_type = get_proper_type(var_arg.typ.type)
if isinstance(unpacked_type, TupleType):
# Assuming for now that because we convert prefixes to positional arguments,
# the first argument is always an unpack.
expanded_tuple = expand_type(unpacked_type, id_to_type)
if isinstance(expanded_tuple, TupleType):
# TODO: handle the case where the tuple has an unpack. This will
# hit an assert below.
expanded_unpack = find_unpack_in_list(expanded_tuple.items)
if expanded_unpack is not None:
callable = callable.copy_modified(
arg_types=(
callable.arg_types[:star_index]
+ [expanded_tuple]
+ callable.arg_types[star_index + 1 :]
)
)
else:
callable = replace_starargs(callable, expanded_tuple.items)
else:
# TODO: handle the case for if we get a variable length tuple.
assert False, f"mypy bug: unimplemented case, {expanded_tuple}"
elif isinstance(unpacked_type, TypeVarTupleType):
expanded_tvt = expand_unpack_with_variables(var_arg.typ, id_to_type)
if isinstance(expanded_tvt, list):
for t in expanded_tvt:
assert not isinstance(t, UnpackType)
callable = replace_starargs(callable, expanded_tvt)
else:
assert isinstance(expanded_tvt, Instance)
assert expanded_tvt.type.fullname == "builtins.tuple"
callable = callable.copy_modified(
arg_types=(
callable.arg_types[:star_index]
+ [expanded_tvt.args[0]]
+ callable.arg_types[star_index + 1 :]
)
)
else:
assert False, "mypy bug: unhandled case applying unpack"
callable = expand_type(callable, id_to_type)
assert isinstance(callable, CallableType)
return callable.copy_modified(variables=[tv for tv in tvars if tv.id not in id_to_type])
else:
callable = callable.copy_modified(
arg_types=[expand_type(at, id_to_type) for at in callable.arg_types]
Expand All @@ -183,6 +136,9 @@ def apply_generic_arguments(
type_guard = None

# The callable may retain some type vars if only some were applied.
# TODO: move apply_poly() logic from checkexpr.py here when new inference
# becomes universally used (i.e. in all passes + in unification).
# With this new logic we can actually *add* some new free variables.
remaining_tvars = [tv for tv in tvars if tv.id not in id_to_type]

return callable.copy_modified(
Expand Down
24 changes: 17 additions & 7 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2373,11 +2373,15 @@ def check_argument_types(
]
actual_kinds = [nodes.ARG_STAR] + [nodes.ARG_POS] * (len(actuals) - 1)

assert isinstance(orig_callee_arg_type, TupleType)
assert orig_callee_arg_type.items
callee_arg_types = orig_callee_arg_type.items
# TODO: can we really assert this? What if formal is just plain Unpack[Ts]?
assert isinstance(orig_callee_arg_type, UnpackType)
assert isinstance(orig_callee_arg_type.type, ProperType) and isinstance(
orig_callee_arg_type.type, TupleType
)
assert orig_callee_arg_type.type.items
callee_arg_types = orig_callee_arg_type.type.items
callee_arg_kinds = [nodes.ARG_STAR] + [nodes.ARG_POS] * (
len(orig_callee_arg_type.items) - 1
len(orig_callee_arg_type.type.items) - 1
)
expanded_tuple = True

Expand Down Expand Up @@ -5853,8 +5857,9 @@ def visit_param_spec(self, t: ParamSpecType) -> Type:
return super().visit_param_spec(t)

def visit_type_var_tuple(self, t: TypeVarTupleType) -> Type:
# TODO: Support polymorphic apply for TypeVarTuple.
raise PolyTranslationError()
if t in self.poly_tvars and t not in self.bound_tvars:
raise PolyTranslationError()
return super().visit_type_var_tuple(t)

def visit_type_alias_type(self, t: TypeAliasType) -> Type:
if not t.args:
Expand Down Expand Up @@ -5888,7 +5893,6 @@ def visit_instance(self, t: Instance) -> Type:
return t.copy_modified(args=new_args)
# There is the same problem with callback protocols as with aliases
# (callback protocols are essentially more flexible aliases to callables).
# Note: consider supporting bindings in instances, e.g. LRUCache[[x: T], T].
if t.args and t.type.is_protocol and t.type.protocol_members == ["__call__"]:
if t.type in self.seen_aliases:
raise PolyTranslationError()
Expand Down Expand Up @@ -5923,6 +5927,12 @@ def __init__(self) -> None:
def visit_type_var(self, t: TypeVarType) -> bool:
return True

def visit_param_spec(self, t: ParamSpecType) -> bool:
return True

def visit_type_var_tuple(self, t: TypeVarTupleType) -> bool:
return True


def has_erased_component(t: Type | None) -> bool:
return t is not None and t.accept(HasErasedComponentsQuery())
Expand Down
Loading

0 comments on commit b02ddf1

Please sign in to comment.