diff --git a/mypy/constraints.py b/mypy/constraints.py index 49a2aea8fa05..8ff643d2824c 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -1076,6 +1076,37 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]: # but not vice versa. # TODO: infer more from prefixes when possible. if unpack_present is not None and not cactual.param_spec(): + # if there's anything that would get ignored later, handle them now. + # (assumes that if there's a kwarg on template, it should get matched. + # ... which isn't always a right assumption) + for arg in template.formal_arguments(): + if arg.pos: + continue + + # this arg will get dropped in `repack_callable_args` later; + # handle it instead! ... this isn't very thorough though + other = cactual.argument_by_name(arg.name) + if not other: + continue + + # for now, simplify the problem: if `other` isn't at the end, + # or kw-only, give up + if ( + other.pos is not None + and other.pos + 1 != cactual.max_possible_positional_args() + ): + continue + + cactual = cactual.copy_modified( + cactual.arg_types, + [ + k if i != other.pos else ArgKind.ARG_NAMED + for (i, k) in enumerate(cactual.arg_kinds) + ], + cactual.arg_names, + ) + res.extend(infer_constraints(arg.typ, other.typ, self.direction)) + # We need to re-normalize args to the form they appear in tuples, # for callables we always pack the suffix inside another tuple. unpack = template.arg_types[unpack_present] @@ -1426,7 +1457,9 @@ def repack_callable_args(callable: CallableType, tuple_type: TypeInfo) -> list[T in e.g. a TupleType). """ if ARG_STAR not in callable.arg_kinds: - return callable.arg_types + return [ + t for (t, k) in zip(callable.arg_types, callable.arg_kinds) if k != ArgKind.ARG_NAMED + ] star_index = callable.arg_kinds.index(ARG_STAR) arg_types = callable.arg_types[:star_index] star_type = callable.arg_types[star_index] diff --git a/mypy/subtypes.py b/mypy/subtypes.py index c76b3569fdd4..45e940315390 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -1603,6 +1603,17 @@ def are_parameters_compatible( return True trivial_suffix = is_trivial_suffix(right) and not is_proper_subtype + # def _(*a: Unpack[tuple[object, ...]]) allows any number of arguments, not just infinite. + if right_star and isinstance(right_star.typ, UnpackType): + right_star_inner_type = get_proper_type(right_star.typ.type) + trivial_varargs = ( + isinstance(right_star_inner_type, Instance) + and right_star_inner_type.type.fullname == "builtins.tuple" + and len(right_star_inner_type.args) == 1 + ) + else: + trivial_varargs = False + if ( right.arg_kinds == [ARG_STAR] and isinstance(get_proper_type(right.arg_types[0]), AnyType) @@ -1640,14 +1651,17 @@ def are_parameters_compatible( # Furthermore, if we're checking for compatibility in all cases, # we confirm that if R accepts an infinite number of arguments, # L must accept the same. - def _incompatible(left_arg: FormalArgument | None, right_arg: FormalArgument | None) -> bool: + def _incompatible( + left_arg: FormalArgument | None, right_arg: FormalArgument | None, varargs: bool + ) -> bool: if right_arg is None: return False if left_arg is None: - return not allow_partial_overlap and not trivial_suffix + return not (allow_partial_overlap or trivial_suffix or (varargs and trivial_varargs)) + return not is_compat(right_arg.typ, left_arg.typ) - if _incompatible(left_star, right_star) or _incompatible(left_star2, right_star2): + if _incompatible(left_star, right_star, True) or _incompatible(left_star2, right_star2, False): return False # Phase 1b: Check non-star args: for every arg right can accept, left must @@ -1672,8 +1686,7 @@ def _incompatible(left_arg: FormalArgument | None, right_arg: FormalArgument | N # Phase 1c: Check var args. Right has an infinite series of optional positional # arguments. Get all further positional args of left, and make sure # they're more general than the corresponding member in right. - # TODO: are we handling UnpackType correctly here? - if right_star is not None and not trivial_suffix: + if right_star is not None and not trivial_suffix and not trivial_varargs: # Synthesize an anonymous formal argument for the right right_by_position = right.try_synthesizing_arg_from_vararg(None) assert right_by_position is not None