diff --git a/mypy/subtypes.py b/mypy/subtypes.py index a3b28a3e24de..c591730e431e 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -14,6 +14,7 @@ # Circular import; done in the function instead. # import mypy.solve from mypy.nodes import ( + ARG_POS, ARG_STAR, ARG_STAR2, CONTRAVARIANT, @@ -1692,11 +1693,30 @@ def unify_generic_callable( return_constraint_direction = mypy.constraints.SUBTYPE_OF constraints: list[mypy.constraints.Constraint] = [] - for arg_type, target_arg_type in zip(type.arg_types, target.arg_types): - c = mypy.constraints.infer_constraints( - arg_type, target_arg_type, mypy.constraints.SUPERTYPE_OF - ) - constraints.extend(c) + # check by names + argument_names_map = {} + + for i in range(len(target.arg_types)): + if target.arg_names[i] and target.arg_kinds[i] != ARG_POS: + argument_names_map[target.arg_names[i]] = target.arg_types[i] + + for i in range(len(type.arg_types)): + if type.arg_names[i] and type.arg_names[i] in argument_names_map: + c = mypy.constraints.infer_constraints( + argument_names_map[type.arg_names[i]], + type.arg_types[i], + mypy.constraints.SUPERTYPE_OF, + ) + constraints.extend(c) + + # check pos-only arguments + for arg, target_arg in zip(type.formal_arguments(), target.formal_arguments()): + if arg.pos is not None and target_arg.pos is not None: + c = mypy.constraints.infer_constraints( + arg.typ, target_arg.typ, mypy.constraints.SUPERTYPE_OF + ) + constraints.extend(c) + if not ignore_return: c = mypy.constraints.infer_constraints( type.ret_type, target.ret_type, return_constraint_direction diff --git a/test-data/unit/check-overloading.test b/test-data/unit/check-overloading.test index 4209f4ec9164..f8afd9057ba4 100644 --- a/test-data/unit/check-overloading.test +++ b/test-data/unit/check-overloading.test @@ -6545,3 +6545,48 @@ class Snafu(object): reveal_type(Snafu().snafu('123')) # N: Revealed type is "builtins.str" reveal_type(Snafu.snafu('123')) # N: Revealed type is "builtins.str" [builtins fixtures/staticmethod.pyi] + +[case testOverloadedFunctionWithTypevarMissing] +import typing + +class A: ... + +T = typing.TypeVar("T", bound=A) + +@typing.overload +def f(a: T) -> T: ... + +@typing.overload +def f(*, copy: bool = False) -> None: ... + +def f(a: T = ..., *, copy: bool = False) -> T: + ... + +reveal_type(f) # N: Revealed type is "Overload(def [T <: __main__.A] (a: T`-1) -> T`-1, def (*, copy: builtins.bool =))" + +[case testOverloadingWithTypeVarWhereTargetIsPosOnly] +from typing import overload, TypeVar, Union + +T1 = TypeVar("T1") +T2 = TypeVar("T2") + +@overload +def f( + *, + x: T1, +) -> None: + ... + +@overload +def f( # type: ignore + *, + x: T2, +) -> None: + ... + +def f( + x: Union[T1, T2] +) -> None: + ... + +reveal_type(f) # N: Revealed type is "Overload(def [T1] (*, x: T1`-1), def [T2] (*, x: T2`-1))"