From 746ffd8556437d2eaffc879a2a96a60a8e711c83 Mon Sep 17 00:00:00 2001 From: A5rocks Date: Mon, 29 May 2023 16:00:07 +0900 Subject: [PATCH 1/4] Reapply overloading fixes --- mypy/subtypes.py | 29 ++++++++++++++++++++++++----- 1 file changed, 24 insertions(+), 5 deletions(-) diff --git a/mypy/subtypes.py b/mypy/subtypes.py index a3b28a3e24de..902d2c824a48 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -1692,11 +1692,30 @@ def unify_generic_callable( return_constraint_direction = mypy.constraints.SUBTYPE_OF constraints: list[mypy.constraints.Constraint] = [] - for arg_type, target_arg_type in zip(type.arg_types, target.arg_types): - c = mypy.constraints.infer_constraints( - arg_type, target_arg_type, mypy.constraints.SUPERTYPE_OF - ) - constraints.extend(c) + # check by names + argument_names_map = {} + + for i in range(len(type.arg_types)): + if type.arg_names[i]: + argument_names_map[type.arg_names[i]] = type.arg_types[i] + + for i in range(len(target.arg_types)): + if target.arg_names[i] and target.arg_names[i] in argument_names_map: + c = mypy.constraints.infer_constraints( + argument_names_map[target.arg_names[i]], + target.arg_types[i], + mypy.constraints.SUPERTYPE_OF, + ) + constraints.extend(c) + + # check pos-only arguments + for arg, target_arg in zip(type.formal_arguments(), target.formal_arguments()): + if arg.pos is not None and target_arg.pos is not None: + c = mypy.constraints.infer_constraints( + arg.typ, target_arg.typ, mypy.constraints.SUPERTYPE_OF + ) + constraints.extend(c) + if not ignore_return: c = mypy.constraints.infer_constraints( type.ret_type, target.ret_type, return_constraint_direction From d067542870f8bc5cb7bd68f8b3d8022fa3018cba Mon Sep 17 00:00:00 2001 From: A5rocks Date: Mon, 29 May 2023 16:05:27 +0900 Subject: [PATCH 2/4] Defensively account for pos-only arguments --- mypy/subtypes.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mypy/subtypes.py b/mypy/subtypes.py index 902d2c824a48..bff1d93bc363 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -14,6 +14,7 @@ # Circular import; done in the function instead. # import mypy.solve from mypy.nodes import ( + ARG_POS, ARG_STAR, ARG_STAR2, CONTRAVARIANT, @@ -1696,7 +1697,7 @@ def unify_generic_callable( argument_names_map = {} for i in range(len(type.arg_types)): - if type.arg_names[i]: + if type.arg_names[i] and type.arg_kinds[i] != ARG_POS: argument_names_map[type.arg_names[i]] = type.arg_types[i] for i in range(len(target.arg_types)): From 712c88ed3c4b7fc09ded15a4b616b25ea8037d43 Mon Sep 17 00:00:00 2001 From: EXPLOSION Date: Wed, 14 Jun 2023 23:24:44 +0000 Subject: [PATCH 3/4] Add a trivial test --- test-data/unit/check-overloading.test | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/test-data/unit/check-overloading.test b/test-data/unit/check-overloading.test index 4209f4ec9164..4ae38c06728f 100644 --- a/test-data/unit/check-overloading.test +++ b/test-data/unit/check-overloading.test @@ -6545,3 +6545,21 @@ class Snafu(object): reveal_type(Snafu().snafu('123')) # N: Revealed type is "builtins.str" reveal_type(Snafu.snafu('123')) # N: Revealed type is "builtins.str" [builtins fixtures/staticmethod.pyi] + +[case testOverloadedFunctionWithTypevarMissing] +import typing + +class A: ... + +T = typing.TypeVar("T", bound=A) + +@typing.overload +def f(a: T) -> T: ... + +@typing.overload +def f(*, copy: bool = False) -> None: ... + +def f(a: T = ..., *, copy: bool = False) -> T: + ... + +reveal_type(f) # N: Revealed type is "Overload(def [T <: __main__.A] (a: T`-1) -> T`-1, def (*, copy: builtins.bool =))" From 173ca78bda8cd9dfe5076cda916b8b2c7f10630c Mon Sep 17 00:00:00 2001 From: A5rocks Date: Sun, 18 Jun 2023 10:24:25 +0900 Subject: [PATCH 4/4] Fix confusion about `target` --- mypy/subtypes.py | 14 +++++++------- test-data/unit/check-overloading.test | 27 +++++++++++++++++++++++++++ 2 files changed, 34 insertions(+), 7 deletions(-) diff --git a/mypy/subtypes.py b/mypy/subtypes.py index bff1d93bc363..c591730e431e 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -1696,15 +1696,15 @@ def unify_generic_callable( # check by names argument_names_map = {} - for i in range(len(type.arg_types)): - if type.arg_names[i] and type.arg_kinds[i] != ARG_POS: - argument_names_map[type.arg_names[i]] = type.arg_types[i] - for i in range(len(target.arg_types)): - if target.arg_names[i] and target.arg_names[i] in argument_names_map: + if target.arg_names[i] and target.arg_kinds[i] != ARG_POS: + argument_names_map[target.arg_names[i]] = target.arg_types[i] + + for i in range(len(type.arg_types)): + if type.arg_names[i] and type.arg_names[i] in argument_names_map: c = mypy.constraints.infer_constraints( - argument_names_map[target.arg_names[i]], - target.arg_types[i], + argument_names_map[type.arg_names[i]], + type.arg_types[i], mypy.constraints.SUPERTYPE_OF, ) constraints.extend(c) diff --git a/test-data/unit/check-overloading.test b/test-data/unit/check-overloading.test index 4ae38c06728f..f8afd9057ba4 100644 --- a/test-data/unit/check-overloading.test +++ b/test-data/unit/check-overloading.test @@ -6563,3 +6563,30 @@ def f(a: T = ..., *, copy: bool = False) -> T: ... reveal_type(f) # N: Revealed type is "Overload(def [T <: __main__.A] (a: T`-1) -> T`-1, def (*, copy: builtins.bool =))" + +[case testOverloadingWithTypeVarWhereTargetIsPosOnly] +from typing import overload, TypeVar, Union + +T1 = TypeVar("T1") +T2 = TypeVar("T2") + +@overload +def f( + *, + x: T1, +) -> None: + ... + +@overload +def f( # type: ignore + *, + x: T2, +) -> None: + ... + +def f( + x: Union[T1, T2] +) -> None: + ... + +reveal_type(f) # N: Revealed type is "Overload(def [T1] (*, x: T1`-1), def [T2] (*, x: T2`-1))"