Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix special cases of kwargs + TypeVarTuple #17512

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
36 changes: 35 additions & 1 deletion mypy/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -1076,6 +1076,38 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]:
# but not vice versa.
# TODO: infer more from prefixes when possible.
if unpack_present is not None and not cactual.param_spec():
# if there's anything that would get ignored later, handle them now.
# (assumes that if there's a kwarg on template, it should get matched.
# ... which isn't always a right assumption)
for arg in template.formal_arguments():
if arg.pos:
continue

# this arg will get dropped in `repack_callable_args` later;
# handle it instead! ... this isn't very thorough though
other = cactual.argument_by_name(arg.name)
assert not other or arg.required
if not other:
continue

# for now, simplify the problem: if `other` isn't at the end,
# or kw-only, give up
if (
other.pos is not None
and other.pos + 1 != cactual.max_possible_positional_args()
):
continue

cactual = cactual.copy_modified(
cactual.arg_types,
[
k if i != other.pos else ArgKind.ARG_NAMED
for (i, k) in enumerate(cactual.arg_kinds)
],
cactual.arg_names,
)
res.extend(infer_constraints(arg.typ, other.typ, self.direction))

# We need to re-normalize args to the form they appear in tuples,
# for callables we always pack the suffix inside another tuple.
unpack = template.arg_types[unpack_present]
Expand Down Expand Up @@ -1426,7 +1458,9 @@ def repack_callable_args(callable: CallableType, tuple_type: TypeInfo) -> list[T
in e.g. a TupleType).
"""
if ARG_STAR not in callable.arg_kinds:
return callable.arg_types
return [
t for (t, k) in zip(callable.arg_types, callable.arg_kinds) if k != ArgKind.ARG_NAMED
]
star_index = callable.arg_kinds.index(ARG_STAR)
arg_types = callable.arg_types[:star_index]
star_type = callable.arg_types[star_index]
Expand Down
4 changes: 3 additions & 1 deletion mypy/erasetype.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def visit_instance(self, t: Instance) -> ProperType:
args: list[Type] = []
for tv in t.type.defn.type_vars:
# Valid erasure for *Ts is *tuple[Any, ...], not just Any.
# TODO: try updating this to use TupleType
if isinstance(tv, TypeVarTupleType):
args.append(
UnpackType(
Expand Down Expand Up @@ -212,7 +213,8 @@ def visit_tuple_type(self, t: TupleType) -> Type:

def visit_type_var_tuple(self, t: TypeVarTupleType) -> Type:
if self.erase_id(t.id):
return t.tuple_fallback.copy_modified(args=[self.replacement])
# TODO: should t.tuple_fallback become a TupleType?
return TupleType([], t.tuple_fallback, erased_typevartuple=True)
return t

def visit_param_spec(self, t: ParamSpecType) -> Type:
Expand Down
45 changes: 29 additions & 16 deletions mypy/expandtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,15 @@ def expand_type_by_instance(typ: Type, instance: Instance) -> Type:
)
tvar = tvars_middle[0]
assert isinstance(tvar, TypeVarTupleType)
variables = {tvar.id: TupleType(list(args_middle), tvar.tuple_fallback)}

tvar_value: Type = TupleType(list(args_middle), tvar.tuple_fallback)
if len(args_middle) == 1:
# prevent nested Unpacks
middle_arg = get_proper_type(args_middle[0])
if isinstance(middle_arg, UnpackType):
tvar_value = middle_arg.type

variables = {tvar.id: tvar_value}
instance_args = args_prefix + args_suffix
tvars = tvars_prefix + tvars_suffix
else:
Expand Down Expand Up @@ -207,7 +215,7 @@ def visit_erased_type(self, t: ErasedType) -> Type:
return t

def visit_instance(self, t: Instance) -> Type:
args = self.expand_types_with_unpack(list(t.args))
args = self.expand_types_with_unpack(list(t.args))[0]
if t.type.fullname == "builtins.tuple":
# Normalize Tuple[*Tuple[X, ...], ...] -> Tuple[X, ...]
arg = args[0]
Expand Down Expand Up @@ -291,23 +299,24 @@ def visit_unpack_type(self, t: UnpackType) -> Type:
# example is non-normalized types when called from semanal.py.
return UnpackType(t.type.accept(self))

def expand_unpack(self, t: UnpackType) -> list[Type]:
# TODO: there must be a cleaner way to not discard `erased_typevartuple`
def expand_unpack(self, t: UnpackType) -> tuple[list[Type], bool]:
assert isinstance(t.type, TypeVarTupleType)
repl = get_proper_type(self.variables.get(t.type.id, t.type))
if isinstance(repl, UnpackType):
repl = get_proper_type(repl.type)
if isinstance(repl, TupleType):
return repl.items
return repl.items, repl.erased_typevartuple
elif (
isinstance(repl, Instance)
and repl.type.fullname == "builtins.tuple"
or isinstance(repl, TypeVarTupleType)
):
return [UnpackType(typ=repl)]
return [UnpackType(typ=repl)], False
elif isinstance(repl, (AnyType, UninhabitedType)):
# Replace *Ts = Any with *Ts = *tuple[Any, ...] and same for Never.
# These types may appear here as a result of user error or failed inference.
return [UnpackType(t.type.tuple_fallback.copy_modified(args=[repl]))]
return [UnpackType(t.type.tuple_fallback.copy_modified(args=[repl]))], False
else:
raise RuntimeError(f"Invalid type replacement to expand: {repl}")

Expand All @@ -329,13 +338,12 @@ def interpolate_args_for_unpack(self, t: CallableType, var_arg: UnpackType) -> l
expanded_tuple = var_arg_type.accept(self)
assert isinstance(expanded_tuple, ProperType) and isinstance(expanded_tuple, TupleType)
expanded_items = expanded_tuple.items
fallback = var_arg_type.partial_fallback
new_unpack = UnpackType(TupleType(expanded_items, fallback))
new_unpack = UnpackType(var_arg_type.copy_modified(items=expanded_items))
elif isinstance(var_arg_type, TypeVarTupleType):
# We have plain Unpack[Ts]
fallback = var_arg_type.tuple_fallback
expanded_items = self.expand_unpack(var_arg)
new_unpack = UnpackType(TupleType(expanded_items, fallback))
expanded_items, etv = self.expand_unpack(var_arg)
new_unpack = UnpackType(TupleType(expanded_items, fallback, erased_typevartuple=etv))
else:
# We have invalid type in Unpack. This can happen when expanding aliases
# to Callable[[*Invalid], Ret]
Expand Down Expand Up @@ -415,18 +423,23 @@ def visit_overloaded(self, t: Overloaded) -> Type:
items.append(new_item)
return Overloaded(items)

def expand_types_with_unpack(self, typs: Sequence[Type]) -> list[Type]:
def expand_types_with_unpack(self, typs: Sequence[Type]) -> tuple[list[Type], bool]:
"""Expands a list of types that has an unpack."""
items: list[Type] = []
met_erased_typevartuple = False # not sure this is the right behavior.
for item in typs:
if isinstance(item, UnpackType) and isinstance(item.type, TypeVarTupleType):
items.extend(self.expand_unpack(item))
its, etv = self.expand_unpack(item)
met_erased_typevartuple = met_erased_typevartuple or etv
items.extend(its)
else:
items.append(item.accept(self))
return items

assert not met_erased_typevartuple or len(typs) == 1
return items, met_erased_typevartuple

def visit_tuple_type(self, t: TupleType) -> Type:
items = self.expand_types_with_unpack(t.items)
items, etv = self.expand_types_with_unpack(t.items)
if len(items) == 1:
# Normalize Tuple[*Tuple[X, ...]] -> Tuple[X, ...]
item = items[0]
Expand All @@ -441,7 +454,7 @@ def visit_tuple_type(self, t: TupleType) -> Type:
return unpacked
fallback = t.partial_fallback.accept(self)
assert isinstance(fallback, ProperType) and isinstance(fallback, Instance)
return t.copy_modified(items=items, fallback=fallback)
return t.copy_modified(items=items, fallback=fallback, erased_typevartuple=etv)

def visit_typeddict_type(self, t: TypedDictType) -> Type:
fallback = t.fallback.accept(self)
Expand Down Expand Up @@ -480,7 +493,7 @@ def visit_type_type(self, t: TypeType) -> Type:
def visit_type_alias_type(self, t: TypeAliasType) -> Type:
# Target of the type alias cannot contain type variables (not bound by the type
# alias itself), so we just expand the arguments.
args = self.expand_types_with_unpack(t.args)
args = self.expand_types_with_unpack(t.args)[0]
# TODO: normalize if target is Tuple, and args are [*tuple[X, ...]]?
return t.copy_modified(args=args)

Expand Down
20 changes: 18 additions & 2 deletions mypy/subtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,9 @@ def visit_instance(self, left: Instance) -> bool:
if isinstance(right, TupleType) and right.partial_fallback.type.is_enum:
return self._is_subtype(left, mypy.typeops.tuple_fallback(right))
if isinstance(right, TupleType):
if right.erased_typevartuple:
return True # treat it like Any

if len(right.items) == 1:
# Non-normalized Tuple type (may be left after semantic analysis
# because semanal_typearg visitor is not a type translator).
Expand Down Expand Up @@ -784,6 +787,8 @@ def visit_tuple_type(self, left: TupleType) -> bool:
return True
return False
elif isinstance(right, TupleType):
if right.erased_typevartuple:
return True # treat it like Any
# If right has a variadic unpack this needs special handling. If there is a TypeVarTuple
# unpack, item count must coincide. If the left has variadic unpack but right
# doesn't have one, we will fall through to False down the line.
Expand Down Expand Up @@ -824,6 +829,8 @@ def variadic_tuple_subtype(self, left: TupleType, right: TupleType) -> bool:
right_unpack = right.items[right_unpack_index]
assert isinstance(right_unpack, UnpackType)
right_unpacked = get_proper_type(right_unpack.type)
if isinstance(right_unpacked, TupleType) and right_unpacked.erased_typevartuple:
return True # treat it as Any
if not isinstance(right_unpacked, Instance):
# This case should be handled by the caller.
return False
Expand Down Expand Up @@ -1602,6 +1609,15 @@ def are_parameters_compatible(
if are_trivial_parameters(right) and not is_proper_subtype:
return True
trivial_suffix = is_trivial_suffix(right) and not is_proper_subtype
# erased typevartuples, like erased paramspecs or erased typevars are trivial
if right_star and isinstance(right_star.typ, UnpackType):
right_star_inner_type = get_proper_type(right_star.typ.type)
trivial_varargs = (
isinstance(right_star_inner_type, TupleType)
and right_star_inner_type.erased_typevartuple
)
else:
trivial_varargs = False

if (
right.arg_kinds == [ARG_STAR]
Expand Down Expand Up @@ -1644,7 +1660,7 @@ def _incompatible(left_arg: FormalArgument | None, right_arg: FormalArgument | N
if right_arg is None:
return False
if left_arg is None:
return not allow_partial_overlap and not trivial_suffix
return not allow_partial_overlap and not trivial_suffix and not trivial_varargs
return not is_compat(right_arg.typ, left_arg.typ)

if _incompatible(left_star, right_star) or _incompatible(left_star2, right_star2):
Expand Down Expand Up @@ -1673,7 +1689,7 @@ def _incompatible(left_arg: FormalArgument | None, right_arg: FormalArgument | N
# arguments. Get all further positional args of left, and make sure
# they're more general than the corresponding member in right.
# TODO: are we handling UnpackType correctly here?
if right_star is not None and not trivial_suffix:
if right_star is not None and not trivial_suffix and not trivial_varargs:
# Synthesize an anonymous formal argument for the right
right_by_position = right.try_synthesizing_arg_from_vararg(None)
assert right_by_position is not None
Expand Down
2 changes: 1 addition & 1 deletion mypy/test/testtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1560,7 +1560,7 @@ def make_call(*items: tuple[str, str | None]) -> CallExpr:
class TestExpandTypeLimitGetProperType(TestCase):
# WARNING: do not increase this number unless absolutely necessary,
# and you understand what you are doing.
ALLOWED_GET_PROPER_TYPES = 9
ALLOWED_GET_PROPER_TYPES = 10
A5rocks marked this conversation as resolved.
Show resolved Hide resolved

@skipUnless(mypy.expandtype.__file__.endswith(".py"), "Skip for compiled mypy")
def test_count_get_proper_type(self) -> None:
Expand Down
37 changes: 31 additions & 6 deletions mypy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2126,6 +2126,8 @@ def with_normalized_var_args(self) -> Self:
# this should be done once in semanal_typeargs.py for user-defined types,
# and we ourselves rarely construct such type.
return self
if unpacked.erased_typevartuple:
return self
unpack_index = find_unpack_in_list(unpacked.items)
if unpack_index == 0 and len(unpacked.items) > 1:
# Already normalized.
Expand Down Expand Up @@ -2352,13 +2354,16 @@ class TupleType(ProperType):
a tuple base class. Use mypy.typeops.tuple_fallback to calculate the
precise fallback type derived from item types.
implicit: If True, derived from a tuple expression (t,....) instead of Tuple[t, ...]
erased_typevartuple: If True, this came from a (now-erased) TypeVarTuple. This
indicates that this tuple should act more like an Any.
"""

__slots__ = ("items", "partial_fallback", "implicit")
__slots__ = ("items", "partial_fallback", "implicit", "erased_typevartuple")

items: list[Type]
partial_fallback: Instance
implicit: bool
erased_typevartuple: bool

def __init__(
self,
Expand All @@ -2367,11 +2372,13 @@ def __init__(
line: int = -1,
column: int = -1,
implicit: bool = False,
erased_typevartuple: bool = False,
) -> None:
super().__init__(line, column)
self.partial_fallback = fallback
self.items = items
self.implicit = implicit
self.erased_typevartuple = erased_typevartuple

def can_be_true_default(self) -> bool:
if self.can_be_any_bool():
Expand Down Expand Up @@ -2412,19 +2419,24 @@ def accept(self, visitor: TypeVisitor[T]) -> T:
return visitor.visit_tuple_type(self)

def __hash__(self) -> int:
return hash((tuple(self.items), self.partial_fallback))
return hash((tuple(self.items), self.partial_fallback, self.erased_typevartuple))

def __eq__(self, other: object) -> bool:
if not isinstance(other, TupleType):
return NotImplemented
return self.items == other.items and self.partial_fallback == other.partial_fallback
return (
self.items == other.items
and self.partial_fallback == other.partial_fallback
and self.erased_typevartuple == other.erased_typevartuple
)

def serialize(self) -> JsonDict:
return {
".class": "TupleType",
"items": [t.serialize() for t in self.items],
"partial_fallback": self.partial_fallback.serialize(),
"implicit": self.implicit,
"erased_typevartuple": self.erased_typevartuple,
}

@classmethod
Expand All @@ -2434,16 +2446,25 @@ def deserialize(cls, data: JsonDict) -> TupleType:
[deserialize_type(t) for t in data["items"]],
Instance.deserialize(data["partial_fallback"]),
implicit=data["implicit"],
erased_typevartuple=data["erased_typevartuple"],
)

def copy_modified(
self, *, fallback: Instance | None = None, items: list[Type] | None = None
self,
*,
fallback: Instance | None = None,
items: list[Type] | None = None,
erased_typevartuple: bool | None = None,
) -> TupleType:
if fallback is None:
fallback = self.partial_fallback
if items is None:
items = self.items
return TupleType(items, fallback, self.line, self.column)
if erased_typevartuple is None:
erased_typevartuple = self.erased_typevartuple
return TupleType(
items, fallback, self.line, self.column, erased_typevartuple=erased_typevartuple
)

def slice(
self, begin: int | None, end: int | None, stride: int | None, *, fallback: Instance | None
Expand Down Expand Up @@ -2496,7 +2517,9 @@ def slice(
return None
else:
slice_items = self.items[begin:end:stride]
return TupleType(slice_items, fallback, self.line, self.column, self.implicit)
return TupleType(
slice_items, fallback, self.line, self.column, self.implicit, self.erased_typevartuple
)


class TypedDictType(ProperType):
Expand Down Expand Up @@ -3373,6 +3396,8 @@ def visit_overloaded(self, t: Overloaded) -> str:
return f"Overload({', '.join(a)})"

def visit_tuple_type(self, t: TupleType) -> str:
if t.erased_typevartuple:
return "tuple[...]"
s = self.list_str(t.items) or "()"
tuple_name = "tuple" if self.options.use_lowercase_names() else "Tuple"
if t.partial_fallback and t.partial_fallback.type:
Expand Down
1 change: 1 addition & 0 deletions mypy/typevars.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def fill_typevars_with_any(typ: TypeInfo) -> Instance | TupleType:
args: list[Type] = []
for tv in typ.defn.type_vars:
# Valid erasure for *Ts is *tuple[Any, ...], not just Any.
# TODO: use TupleType
if isinstance(tv, TypeVarTupleType):
args.append(
UnpackType(tv.tuple_fallback.copy_modified(args=[AnyType(TypeOfAny.special_form)]))
Expand Down
1 change: 1 addition & 0 deletions test-data/unit/check-typevar-tuple.test
Original file line number Diff line number Diff line change
Expand Up @@ -2136,6 +2136,7 @@ get_items(b) # E: Argument 1 to "get_items" has incompatible type "Bad"; expect
# N: def items(self) -> Tuple[Never, ...] \
# N: Got: \
# N: def items(self) -> List[int]
# TODO: this *should* work.
match(b) # E: Argument 1 to "match" has incompatible type "Bad"; expected "PC[Unpack[Tuple[Never, ...]]]" \
# N: Following member(s) of "Bad" have conflicts: \
# N: Expected: \
Expand Down