Skip to content

Commit

Permalink
Better support for variadic calls and indexing (#16131)
Browse files Browse the repository at this point in the history
This improves support for two features that were supported but only
partially: variadic calls, and variadic indexing. Some notes:
* I did not add dedicated support for slicing of tuples with homogeneous
variadic items (except for cases covered by TypeVarTuple support, i.e.
those not involving splitting of variadic item). This is tricky and it
is not clear what cases people actually want. I left a TODO for this.
* I prohibit multiple variadic items in a call expression. Technically,
we can support some situations involving these, but this is tricky, and
prohibiting this would be in the "spirit" of the PEP, IMO.
* I may have still missed some cases for the calls, since there are so
many options. If you have ideas for additional test cases, please let me
know.
* It was necessary to fix overload ambiguity logic to make some tests
pass. This goes beyond TypeVarTuple support, but I think this is a
correct change.
  • Loading branch information
ilevkivskyi authored Sep 28, 2023
1 parent d25d680 commit 0291ec9
Show file tree
Hide file tree
Showing 8 changed files with 306 additions and 57 deletions.
156 changes: 132 additions & 24 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1640,6 +1640,27 @@ def check_callable_call(
callee.type_object().name, abstract_attributes, context
)

var_arg = callee.var_arg()
if var_arg and isinstance(var_arg.typ, UnpackType):
# It is hard to support multiple variadic unpacks (except for old-style *args: int),
# fail gracefully to avoid crashes later.
seen_unpack = False
for arg, arg_kind in zip(args, arg_kinds):
if arg_kind != ARG_STAR:
continue
arg_type = get_proper_type(self.accept(arg))
if not isinstance(arg_type, TupleType) or any(
isinstance(t, UnpackType) for t in arg_type.items
):
if seen_unpack:
self.msg.fail(
"Passing multiple variadic unpacks in a call is not supported",
context,
code=codes.CALL_ARG,
)
return AnyType(TypeOfAny.from_error), callee
seen_unpack = True

formal_to_actual = map_actuals_to_formals(
arg_kinds,
arg_names,
Expand Down Expand Up @@ -2405,7 +2426,7 @@ def check_argument_types(
]
actual_kinds = [nodes.ARG_STAR] + [nodes.ARG_POS] * (len(actuals) - 1)

# TODO: can we really assert this? What if formal is just plain Unpack[Ts]?
# If we got here, the callee was previously inferred to have a suffix.
assert isinstance(orig_callee_arg_type, UnpackType)
assert isinstance(orig_callee_arg_type.type, ProperType) and isinstance(
orig_callee_arg_type.type, TupleType
Expand All @@ -2431,22 +2452,29 @@ def check_argument_types(
inner_unpack = unpacked_type.items[inner_unpack_index]
assert isinstance(inner_unpack, UnpackType)
inner_unpacked_type = get_proper_type(inner_unpack.type)
# We assume heterogenous tuples are desugared earlier
assert isinstance(inner_unpacked_type, Instance)
assert inner_unpacked_type.type.fullname == "builtins.tuple"
callee_arg_types = (
unpacked_type.items[:inner_unpack_index]
+ [inner_unpacked_type.args[0]]
* (len(actuals) - len(unpacked_type.items) + 1)
+ unpacked_type.items[inner_unpack_index + 1 :]
)
callee_arg_kinds = [ARG_POS] * len(actuals)
if isinstance(inner_unpacked_type, TypeVarTupleType):
# This branch mimics the expanded_tuple case above but for
# the case where caller passed a single * unpacked tuple argument.
callee_arg_types = unpacked_type.items
callee_arg_kinds = [
ARG_POS if i != inner_unpack_index else ARG_STAR
for i in range(len(unpacked_type.items))
]
else:
# We assume heterogeneous tuples are desugared earlier.
assert isinstance(inner_unpacked_type, Instance)
assert inner_unpacked_type.type.fullname == "builtins.tuple"
callee_arg_types = (
unpacked_type.items[:inner_unpack_index]
+ [inner_unpacked_type.args[0]]
* (len(actuals) - len(unpacked_type.items) + 1)
+ unpacked_type.items[inner_unpack_index + 1 :]
)
callee_arg_kinds = [ARG_POS] * len(actuals)
elif isinstance(unpacked_type, TypeVarTupleType):
callee_arg_types = [orig_callee_arg_type]
callee_arg_kinds = [ARG_STAR]
else:
# TODO: Any and Never can appear in Unpack (as a result of user error),
# fail gracefully here and elsewhere (and/or normalize them away).
assert isinstance(unpacked_type, Instance)
assert unpacked_type.type.fullname == "builtins.tuple"
callee_arg_types = [unpacked_type.args[0]] * len(actuals)
Expand All @@ -2458,8 +2486,10 @@ def check_argument_types(
assert len(actual_types) == len(actuals) == len(actual_kinds)

if len(callee_arg_types) != len(actual_types):
# TODO: Improve error message
self.chk.fail("Invalid number of arguments", context)
if len(actual_types) > len(callee_arg_types):
self.chk.msg.too_many_arguments(callee, context)
else:
self.chk.msg.too_few_arguments(callee, context, None)
continue

assert len(callee_arg_types) == len(actual_types)
Expand Down Expand Up @@ -2764,11 +2794,17 @@ def infer_overload_return_type(
)
is_match = not w.has_new_errors()
if is_match:
# Return early if possible; otherwise record info so we can
# Return early if possible; otherwise record info, so we can
# check for ambiguity due to 'Any' below.
if not args_contain_any:
return ret_type, infer_type
matches.append(typ)
p_infer_type = get_proper_type(infer_type)
if isinstance(p_infer_type, CallableType):
# Prefer inferred types if possible, this will avoid false triggers for
# Any-ambiguity caused by arguments with Any passed to generic overloads.
matches.append(p_infer_type)
else:
matches.append(typ)
return_types.append(ret_type)
inferred_types.append(infer_type)
type_maps.append(m)
Expand Down Expand Up @@ -4109,6 +4145,12 @@ def visit_index_with_type(
# Visit the index, just to make sure we have a type for it available
self.accept(index)

if isinstance(left_type, TupleType) and any(
isinstance(it, UnpackType) for it in left_type.items
):
# Normalize variadic tuples for consistency.
left_type = expand_type(left_type, {})

if isinstance(left_type, UnionType):
original_type = original_type or left_type
# Don't combine literal types, since we may need them for type narrowing.
Expand All @@ -4129,12 +4171,15 @@ def visit_index_with_type(
if ns is not None:
out = []
for n in ns:
if n < 0:
n += len(left_type.items)
if 0 <= n < len(left_type.items):
out.append(left_type.items[n])
item = self.visit_tuple_index_helper(left_type, n)
if item is not None:
out.append(item)
else:
self.chk.fail(message_registry.TUPLE_INDEX_OUT_OF_RANGE, e)
if any(isinstance(t, UnpackType) for t in left_type.items):
self.chk.note(
f"Variadic tuple can have length {left_type.length() - 1}", e
)
return AnyType(TypeOfAny.from_error)
return make_simplified_union(out)
else:
Expand All @@ -4158,6 +4203,46 @@ def visit_index_with_type(
e.method_type = method_type
return result

def visit_tuple_index_helper(self, left: TupleType, n: int) -> Type | None:
unpack_index = find_unpack_in_list(left.items)
if unpack_index is None:
if n < 0:
n += len(left.items)
if 0 <= n < len(left.items):
return left.items[n]
return None
unpack = left.items[unpack_index]
assert isinstance(unpack, UnpackType)
unpacked = get_proper_type(unpack.type)
if isinstance(unpacked, TypeVarTupleType):
# Usually we say that TypeVarTuple can't be split, be in case of
# indexing it seems benign to just return the fallback item, similar
# to what we do when indexing a regular TypeVar.
middle = unpacked.tuple_fallback.args[0]
else:
assert isinstance(unpacked, Instance)
assert unpacked.type.fullname == "builtins.tuple"
middle = unpacked.args[0]
if n >= 0:
if n < unpack_index:
return left.items[n]
if n >= len(left.items) - 1:
# For tuple[int, *tuple[str, ...], int] we allow either index 0 or 1,
# since variadic item may have zero items.
return None
return UnionType.make_union(
[middle] + left.items[unpack_index + 1 : n + 2], left.line, left.column
)
n += len(left.items)
if n <= 0:
# Similar to above, we only allow -1, and -2 for tuple[int, *tuple[str, ...], int]
return None
if n > unpack_index:
return left.items[n]
return UnionType.make_union(
left.items[n - 1 : unpack_index] + [middle], left.line, left.column
)

def visit_tuple_slice_helper(self, left_type: TupleType, slic: SliceExpr) -> Type:
begin: Sequence[int | None] = [None]
end: Sequence[int | None] = [None]
Expand All @@ -4183,7 +4268,11 @@ def visit_tuple_slice_helper(self, left_type: TupleType, slic: SliceExpr) -> Typ

items: list[Type] = []
for b, e, s in itertools.product(begin, end, stride):
items.append(left_type.slice(b, e, s))
item = left_type.slice(b, e, s)
if item is None:
self.chk.fail(message_registry.AMBIGUOUS_SLICE_OF_VARIADIC_TUPLE, slic)
return AnyType(TypeOfAny.from_error)
items.append(item)
return make_simplified_union(items)

def try_getting_int_literals(self, index: Expression) -> list[int] | None:
Expand All @@ -4192,7 +4281,7 @@ def try_getting_int_literals(self, index: Expression) -> list[int] | None:
Otherwise, returns None.
Specifically, this function is guaranteed to return a list with
one or more ints if one one the following is true:
one or more ints if one the following is true:
1. 'expr' is a IntExpr or a UnaryExpr backed by an IntExpr
2. 'typ' is a LiteralType containing an int
Expand Down Expand Up @@ -4223,11 +4312,30 @@ def try_getting_int_literals(self, index: Expression) -> list[int] | None:
def nonliteral_tuple_index_helper(self, left_type: TupleType, index: Expression) -> Type:
self.check_method_call_by_name("__getitem__", left_type, [index], [ARG_POS], context=index)
# We could return the return type from above, but unions are often better than the join
union = make_simplified_union(left_type.items)
union = self.union_tuple_fallback_item(left_type)
if isinstance(index, SliceExpr):
return self.chk.named_generic_type("builtins.tuple", [union])
return union

def union_tuple_fallback_item(self, left_type: TupleType) -> Type:
# TODO: this duplicates logic in typeops.tuple_fallback().
items = []
for item in left_type.items:
if isinstance(item, UnpackType):
unpacked_type = get_proper_type(item.type)
if isinstance(unpacked_type, TypeVarTupleType):
unpacked_type = get_proper_type(unpacked_type.upper_bound)
if (
isinstance(unpacked_type, Instance)
and unpacked_type.type.fullname == "builtins.tuple"
):
items.append(unpacked_type.args[0])
else:
raise NotImplementedError
else:
items.append(item)
return make_simplified_union(items)

def visit_typeddict_index_expr(
self, td_type: TypedDictType, index: Expression, setitem: bool = False
) -> Type:
Expand Down
31 changes: 22 additions & 9 deletions mypy/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,25 +137,38 @@ def infer_constraints_for_callable(
unpack_type = callee.arg_types[i]
assert isinstance(unpack_type, UnpackType)

# In this case we are binding all of the actuals to *args
# In this case we are binding all the actuals to *args,
# and we want a constraint that the typevar tuple being unpacked
# is equal to a type list of all the actuals.
actual_types = []

unpacked_type = get_proper_type(unpack_type.type)
if isinstance(unpacked_type, TypeVarTupleType):
tuple_instance = unpacked_type.tuple_fallback
elif isinstance(unpacked_type, TupleType):
tuple_instance = unpacked_type.partial_fallback
else:
assert False, "mypy bug: unhandled constraint inference case"

for actual in actuals:
actual_arg_type = arg_types[actual]
if actual_arg_type is None:
continue

actual_types.append(
mapper.expand_actual_type(
actual_arg_type,
arg_kinds[actual],
callee.arg_names[i],
callee.arg_kinds[i],
)
expanded_actual = mapper.expand_actual_type(
actual_arg_type, arg_kinds[actual], callee.arg_names[i], callee.arg_kinds[i]
)

unpacked_type = get_proper_type(unpack_type.type)
if arg_kinds[actual] != ARG_STAR or isinstance(
get_proper_type(actual_arg_type), TupleType
):
actual_types.append(expanded_actual)
else:
# If we are expanding an iterable inside * actual, append a homogeneous item instead
actual_types.append(
UnpackType(tuple_instance.copy_modified(args=[expanded_actual]))
)

if isinstance(unpacked_type, TypeVarTupleType):
constraints.append(
Constraint(
Expand Down
4 changes: 3 additions & 1 deletion mypy/erasetype.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,9 @@ def visit_instance(self, t: Instance) -> ProperType:
# Valid erasure for *Ts is *tuple[Any, ...], not just Any.
if isinstance(tv, TypeVarTupleType):
args.append(
tv.tuple_fallback.copy_modified(args=[AnyType(TypeOfAny.special_form)])
UnpackType(
tv.tuple_fallback.copy_modified(args=[AnyType(TypeOfAny.special_form)])
)
)
else:
args.append(AnyType(TypeOfAny.special_form))
Expand Down
1 change: 1 addition & 0 deletions mypy/message_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ def with_additional_msg(self, info: str) -> ErrorMessage:
INCOMPATIBLE_TYPES_IN_CAPTURE: Final = ErrorMessage("Incompatible types in capture pattern")
MUST_HAVE_NONE_RETURN_TYPE: Final = ErrorMessage('The return type of "{}" must be None')
TUPLE_INDEX_OUT_OF_RANGE: Final = ErrorMessage("Tuple index out of range")
AMBIGUOUS_SLICE_OF_VARIADIC_TUPLE: Final = ErrorMessage("Ambiguous slice of a variadic tuple")
INVALID_SLICE_INDEX: Final = ErrorMessage("Slice index must be an integer, SupportsIndex or None")
CANNOT_INFER_LAMBDA_TYPE: Final = ErrorMessage("Cannot infer type of lambda")
CANNOT_ACCESS_INIT: Final = (
Expand Down
55 changes: 47 additions & 8 deletions mypy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2416,14 +2416,53 @@ def copy_modified(
items = self.items
return TupleType(items, fallback, self.line, self.column)

def slice(self, begin: int | None, end: int | None, stride: int | None) -> TupleType:
return TupleType(
self.items[begin:end:stride],
self.partial_fallback,
self.line,
self.column,
self.implicit,
)
def slice(self, begin: int | None, end: int | None, stride: int | None) -> TupleType | None:
if any(isinstance(t, UnpackType) for t in self.items):
total = len(self.items)
unpack_index = find_unpack_in_list(self.items)
assert unpack_index is not None
if begin is None and end is None:
# We special-case this to support reversing variadic tuples.
# General support for slicing is tricky, so we handle only simple cases.
if stride == -1:
slice_items = self.items[::-1]
elif stride is None or stride == 1:
slice_items = self.items
else:
return None
elif (begin is None or unpack_index >= begin >= 0) and (
end is not None and unpack_index >= end >= 0
):
# Start and end are in the prefix, everything works in this case.
slice_items = self.items[begin:end:stride]
elif (begin is not None and unpack_index - total < begin < 0) and (
end is None or unpack_index - total < end < 0
):
# Start and end are in the suffix, everything works in this case.
slice_items = self.items[begin:end:stride]
elif (begin is None or unpack_index >= begin >= 0) and (
end is None or unpack_index - total < end < 0
):
# Start in the prefix, end in the suffix, we can support only trivial strides.
if stride is None or stride == 1:
slice_items = self.items[begin:end:stride]
else:
return None
elif (begin is not None and unpack_index - total < begin < 0) and (
end is not None and unpack_index >= end >= 0
):
# Start in the suffix, end in the prefix, we can support only trivial strides.
if stride is None or stride == -1:
slice_items = self.items[begin:end:stride]
else:
return None
else:
# TODO: there some additional cases we can support for homogeneous variadic
# items, we can "eat away" finite number of items.
return None
else:
slice_items = self.items[begin:end:stride]
return TupleType(slice_items, self.partial_fallback, self.line, self.column, self.implicit)


class TypedDictType(ProperType):
Expand Down
3 changes: 1 addition & 2 deletions test-data/unit/check-overloading.test
Original file line number Diff line number Diff line change
Expand Up @@ -6501,8 +6501,7 @@ eggs = lambda: 'eggs'
reveal_type(func(eggs)) # N: Revealed type is "def (builtins.str) -> builtins.str"

spam: Callable[..., str] = lambda x, y: 'baz'
reveal_type(func(spam)) # N: Revealed type is "def (*Any, **Any) -> Any"

reveal_type(func(spam)) # N: Revealed type is "def (*Any, **Any) -> builtins.str"
[builtins fixtures/paramspec.pyi]

[case testGenericOverloadOverlapWithType]
Expand Down
5 changes: 2 additions & 3 deletions test-data/unit/check-tuples.test
Original file line number Diff line number Diff line change
Expand Up @@ -1678,7 +1678,6 @@ def zip(*i: Iterable[Any]) -> Iterator[Tuple[Any, ...]]: ...
def zip(i): ...

def g(t: Tuple):
# Ideally, we'd infer that these are iterators of tuples
reveal_type(zip(*t)) # N: Revealed type is "typing.Iterator[Any]"
reveal_type(zip(t)) # N: Revealed type is "typing.Iterator[Any]"
reveal_type(zip(*t)) # N: Revealed type is "typing.Iterator[builtins.tuple[Any, ...]]"
reveal_type(zip(t)) # N: Revealed type is "typing.Iterator[Tuple[Any]]"
[builtins fixtures/tuple.pyi]
Loading

0 comments on commit 0291ec9

Please sign in to comment.