From 5cdb753f74f5807887c40e8aee138291d1f5b920 Mon Sep 17 00:00:00 2001 From: STerliakov Date: Tue, 4 Jun 2024 03:02:44 +0200 Subject: [PATCH 1/6] Reject ParamSpec-typed callables calls with insufficient arguments --- mypy/checkexpr.py | 10 +++- .../unit/check-parameter-specification.test | 51 +++++++++++++++++++ 2 files changed, 60 insertions(+), 1 deletion(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 0a4af069ea17..7d300f3a2bcb 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -1745,7 +1745,11 @@ def check_callable_call( ) param_spec = callee.param_spec() - if param_spec is not None and arg_kinds == [ARG_STAR, ARG_STAR2]: + if ( + param_spec is not None + and arg_kinds == [ARG_STAR, ARG_STAR2] + and len(formal_to_actual) == 2 + ): arg1 = self.accept(args[0]) arg2 = self.accept(args[1]) if ( @@ -2351,6 +2355,10 @@ def check_argument_count( # Positional argument when expecting a keyword argument. self.msg.too_many_positional_arguments(callee, context) ok = False + elif callee.param_spec() is not None: + if not formal_to_actual[i]: + self.msg.too_few_arguments(callee, context, actual_names) + ok = False return ok def check_for_extra_actual_arguments( diff --git a/test-data/unit/check-parameter-specification.test b/test-data/unit/check-parameter-specification.test index cab7d2bf6819..63a5e9cb1777 100644 --- a/test-data/unit/check-parameter-specification.test +++ b/test-data/unit/check-parameter-specification.test @@ -2204,3 +2204,54 @@ parametrize(_test, Case(1, b=2), Case(3, b=4)) parametrize(_test, Case(1, 2), Case(3)) parametrize(_test, Case(1, 2), Case(3, b=4)) [builtins fixtures/paramspec.pyi] + +[case testRunParamSpecInsufficientArgs] +from typing_extensions import ParamSpec, Concatenate +from typing import Callable + +_P = ParamSpec("_P") + +def run(predicate: Callable[_P, str], *args: _P.args, **kwargs: _P.kwargs) -> None: + predicate() # E: Too few arguments + predicate(*args) # E: Too few arguments + predicate(**kwargs) # E: Too few arguments + predicate(*args, **kwargs) + +[builtins fixtures/paramspec.pyi] + +[case testRunParamSpecConcatenateInsufficientArgs] +from typing_extensions import ParamSpec, Concatenate +from typing import Callable + +_P = ParamSpec("_P") + +def run(predicate: Callable[Concatenate[int, _P], str], *args: _P.args, **kwargs: _P.kwargs) -> None: + predicate() # E: Too few arguments + predicate(1) # E: Too few arguments + predicate(1, *args) # E: Too few arguments + predicate(1, *args) # E: Too few arguments + predicate(1, **kwargs) # E: Too few arguments + predicate(*args, **kwargs) # E: Argument 1 has incompatible type "*_P.args"; expected "int" + predicate(1, *args, **kwargs) + +[builtins fixtures/paramspec.pyi] + +[case testRunParamSpecConcatenateInsufficientArgsInDecorator] +from typing_extensions import ParamSpec, Concatenate +from typing import Callable + +P = ParamSpec("P") + +def decorator(fn: Callable[Concatenate[str, P], None]) -> Callable[P, None]: + def inner(*args: P.args, **kwargs: P.kwargs) -> None: + fn("value") # E: Too few arguments + fn("value", *args) # E: Too few arguments + fn("value", **kwargs) # E: Too few arguments + fn(*args, **kwargs) # E: Argument 1 has incompatible type "*P.args"; expected "str" + fn("value", *args, **kwargs) + return inner + +@decorator +def foo(s: str, s2: str) -> None: ... + +[builtins fixtures/paramspec.pyi] From 3b2297f1b370549695dae895045282831b901a64 Mon Sep 17 00:00:00 2001 From: STerliakov Date: Tue, 4 Jun 2024 04:46:21 +0200 Subject: [PATCH 2/6] Reuse params preprocessing logic for generic functions --- mypy/checkexpr.py | 74 ++++++++++++------- .../unit/check-parameter-specification.test | 64 +++++++++++++++- 2 files changed, 109 insertions(+), 29 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 7d300f3a2bcb..fcf61d73579c 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -1716,33 +1716,9 @@ def check_callable_call( callee = callee.copy_modified(ret_type=fresh_ret_type) if callee.is_generic(): - need_refresh = any( - isinstance(v, (ParamSpecType, TypeVarTupleType)) for v in callee.variables + callee, formal_to_actual = self.adjust_generic_callable_params_mapping( + callee, args, arg_kinds, arg_names, formal_to_actual, context ) - callee = freshen_function_type_vars(callee) - callee = self.infer_function_type_arguments_using_context(callee, context) - if need_refresh: - # Argument kinds etc. may have changed due to - # ParamSpec or TypeVarTuple variables being replaced with an arbitrary - # number of arguments; recalculate actual-to-formal map - formal_to_actual = map_actuals_to_formals( - arg_kinds, - arg_names, - callee.arg_kinds, - callee.arg_names, - lambda i: self.accept(args[i]), - ) - callee = self.infer_function_type_arguments( - callee, args, arg_kinds, arg_names, formal_to_actual, need_refresh, context - ) - if need_refresh: - formal_to_actual = map_actuals_to_formals( - arg_kinds, - arg_names, - callee.arg_kinds, - callee.arg_names, - lambda i: self.accept(args[i]), - ) param_spec = callee.param_spec() if ( @@ -2633,7 +2609,7 @@ def check_overload_call( arg_types = self.infer_arg_types_in_empty_context(args) # Step 1: Filter call targets to remove ones where the argument counts don't match plausible_targets = self.plausible_overload_call_targets( - arg_types, arg_kinds, arg_names, callee + args, arg_types, arg_kinds, arg_names, callee, context ) # Step 2: If the arguments contain a union, we try performing union math first, @@ -2751,12 +2727,52 @@ def check_overload_call( self.chk.fail(message_registry.TOO_MANY_UNION_COMBINATIONS, context) return result + def adjust_generic_callable_params_mapping( + self, + callee: CallableType, + args: list[Expression], + arg_kinds: list[ArgKind], + arg_names: Sequence[str | None] | None, + formal_to_actual: list[list[int]], + context: Context, + ) -> tuple[CallableType, list[list[int]]]: + need_refresh = any( + isinstance(v, (ParamSpecType, TypeVarTupleType)) for v in callee.variables + ) + callee = freshen_function_type_vars(callee) + callee = self.infer_function_type_arguments_using_context(callee, context) + if need_refresh: + # Argument kinds etc. may have changed due to + # ParamSpec or TypeVarTuple variables being replaced with an arbitrary + # number of arguments; recalculate actual-to-formal map + formal_to_actual = map_actuals_to_formals( + arg_kinds, + arg_names, + callee.arg_kinds, + callee.arg_names, + lambda i: self.accept(args[i]), + ) + callee = self.infer_function_type_arguments( + callee, args, arg_kinds, arg_names, formal_to_actual, need_refresh, context + ) + if need_refresh: + formal_to_actual = map_actuals_to_formals( + arg_kinds, + arg_names, + callee.arg_kinds, + callee.arg_names, + lambda i: self.accept(args[i]), + ) + return callee, formal_to_actual + def plausible_overload_call_targets( self, + args: list[Expression], arg_types: list[Type], arg_kinds: list[ArgKind], arg_names: Sequence[str | None] | None, overload: Overloaded, + context: Context, ) -> list[CallableType]: """Returns all overload call targets that having matching argument counts. @@ -2790,6 +2806,10 @@ def has_shape(typ: Type) -> bool: formal_to_actual = map_actuals_to_formals( arg_kinds, arg_names, typ.arg_kinds, typ.arg_names, lambda i: arg_types[i] ) + if typ.is_generic(): + typ, formal_to_actual = self.adjust_generic_callable_params_mapping( + typ, args, arg_kinds, arg_names, formal_to_actual, context + ) with self.msg.filter_errors(): if self.check_argument_count( diff --git a/test-data/unit/check-parameter-specification.test b/test-data/unit/check-parameter-specification.test index 63a5e9cb1777..fc12780aa89f 100644 --- a/test-data/unit/check-parameter-specification.test +++ b/test-data/unit/check-parameter-specification.test @@ -2211,12 +2211,22 @@ from typing import Callable _P = ParamSpec("_P") -def run(predicate: Callable[_P, str], *args: _P.args, **kwargs: _P.kwargs) -> None: +def run(predicate: Callable[_P, None], *args: _P.args, **kwargs: _P.kwargs) -> None: # N: "run" defined here predicate() # E: Too few arguments predicate(*args) # E: Too few arguments predicate(**kwargs) # E: Too few arguments predicate(*args, **kwargs) +def fn() -> None: ... +def fn_args(x: int) -> None: ... +def fn_posonly(x: int, /) -> None: ... + +run(fn) +run(fn_args, 1) +run(fn_args, x=1) +run(fn_posonly, 1) +run(fn_posonly, x=1) # E: Unexpected keyword argument "x" for "run" + [builtins fixtures/paramspec.pyi] [case testRunParamSpecConcatenateInsufficientArgs] @@ -2225,7 +2235,7 @@ from typing import Callable _P = ParamSpec("_P") -def run(predicate: Callable[Concatenate[int, _P], str], *args: _P.args, **kwargs: _P.kwargs) -> None: +def run(predicate: Callable[Concatenate[int, _P], None], *args: _P.args, **kwargs: _P.kwargs) -> None: # N: "run" defined here predicate() # E: Too few arguments predicate(1) # E: Too few arguments predicate(1, *args) # E: Too few arguments @@ -2234,6 +2244,22 @@ def run(predicate: Callable[Concatenate[int, _P], str], *args: _P.args, **kwargs predicate(*args, **kwargs) # E: Argument 1 has incompatible type "*_P.args"; expected "int" predicate(1, *args, **kwargs) +def fn() -> None: ... +def fn_args(x: int, y: str) -> None: ... +def fn_posonly(x: int, /) -> None: ... +def fn_posonly_args(x: int, /, y: str) -> None: ... + +run(fn) # E: Argument 1 to "run" has incompatible type "Callable[[], None]"; expected "Callable[[int], None]" +run(fn_args, 1, 'a') # E: Too many arguments for "run" \ + # E: Argument 2 to "run" has incompatible type "int"; expected "str" +run(fn_args, y='a') +run(fn_args, 'a') +run(fn_posonly) +run(fn_posonly, x=1) # E: Unexpected keyword argument "x" for "run" +run(fn_posonly_args) # E: Missing positional argument "y" in call to "run" +run(fn_posonly_args, 'a') +run(fn_posonly_args, y='a') + [builtins fixtures/paramspec.pyi] [case testRunParamSpecConcatenateInsufficientArgsInDecorator] @@ -2255,3 +2281,37 @@ def decorator(fn: Callable[Concatenate[str, P], None]) -> Callable[P, None]: def foo(s: str, s2: str) -> None: ... [builtins fixtures/paramspec.pyi] + +[case testRunParamSpecOverload] +from typing_extensions import ParamSpec, Concatenate +from typing import Callable, overload, NoReturn, TypeVar, Union + +P = ParamSpec("P") +T = TypeVar("T") + +@overload +def capture( + sync_fn: Callable[P, NoReturn], + *args: P.args, + **kwargs: P.kwargs, +) -> int: ... +@overload +def capture( + sync_fn: Callable[P, T], + *args: P.args, + **kwargs: P.kwargs, +) -> Union[T, int]: ... +def capture( + sync_fn: Callable[P, T], + *args: P.args, + **kwargs: P.kwargs, +) -> Union[T, int]: + return sync_fn(*args, **kwargs) + +def fn() -> str: return '' +def err() -> NoReturn: ... + +reveal_type(capture(fn)) # N: Revealed type is "Union[builtins.str, builtins.int]" +reveal_type(capture(err)) # N: Revealed type is "builtins.int" + +[builtins fixtures/paramspec.pyi] From a32ad3f6618a93cd065429d9eabb6fa194204bd1 Mon Sep 17 00:00:00 2001 From: STerliakov Date: Tue, 4 Jun 2024 05:52:36 +0200 Subject: [PATCH 3/6] Only perform deep expansion on overloads when ParamSpec is present --- mypy/checkexpr.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index fcf61d73579c..25d9c808e2d2 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -2806,12 +2806,12 @@ def has_shape(typ: Type) -> bool: formal_to_actual = map_actuals_to_formals( arg_kinds, arg_names, typ.arg_kinds, typ.arg_names, lambda i: arg_types[i] ) - if typ.is_generic(): - typ, formal_to_actual = self.adjust_generic_callable_params_mapping( - typ, args, arg_kinds, arg_names, formal_to_actual, context - ) - with self.msg.filter_errors(): + if typ.is_generic() and typ.param_spec() is not None: + typ, formal_to_actual = self.adjust_generic_callable_params_mapping( + typ, args, arg_kinds, arg_names, formal_to_actual, context + ) + if self.check_argument_count( typ, arg_types, arg_kinds, arg_names, formal_to_actual, None ): From 63995e3178b2e8ce4f98ad52ea6848df173c33c8 Mon Sep 17 00:00:00 2001 From: STerliakov Date: Mon, 10 Jun 2024 02:31:38 +0200 Subject: [PATCH 4/6] Tidy up code a bit --- mypy/checkexpr.py | 7 +++---- test-data/unit/check-parameter-specification.test | 4 ++-- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 25d9c808e2d2..f3d21d24a428 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -2331,10 +2331,9 @@ def check_argument_count( # Positional argument when expecting a keyword argument. self.msg.too_many_positional_arguments(callee, context) ok = False - elif callee.param_spec() is not None: - if not formal_to_actual[i]: - self.msg.too_few_arguments(callee, context, actual_names) - ok = False + elif callee.param_spec() is not None and not formal_to_actual[i]: + self.msg.too_few_arguments(callee, context, actual_names) + ok = False return ok def check_for_extra_actual_arguments( diff --git a/test-data/unit/check-parameter-specification.test b/test-data/unit/check-parameter-specification.test index fc12780aa89f..16dcff7f630c 100644 --- a/test-data/unit/check-parameter-specification.test +++ b/test-data/unit/check-parameter-specification.test @@ -2283,8 +2283,8 @@ def foo(s: str, s2: str) -> None: ... [builtins fixtures/paramspec.pyi] [case testRunParamSpecOverload] -from typing_extensions import ParamSpec, Concatenate -from typing import Callable, overload, NoReturn, TypeVar, Union +from typing_extensions import ParamSpec +from typing import Callable, NoReturn, TypeVar, Union, overload P = ParamSpec("P") T = TypeVar("T") From 0a658a724dc87ae80f414b8b1d219371afb09f69 Mon Sep 17 00:00:00 2001 From: STerliakov Date: Wed, 21 Aug 2024 15:33:39 +0200 Subject: [PATCH 5/6] Always pick ParamSpec-containing overloads as plausible candidates --- mypy/checkexpr.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index d415195e433d..d9729384c465 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -2786,9 +2786,9 @@ def plausible_overload_call_targets( ) -> list[CallableType]: """Returns all overload call targets that having matching argument counts. - If the given args contains a star-arg (*arg or **kwarg argument), this method - will ensure all star-arg overloads appear at the start of the list, instead - of their usual location. + If the given args contains a star-arg (*arg or **kwarg argument, including + ParamSpec), this method will ensure all star-arg overloads appear at the start + of the list, instead of their usual location. The only exception is if the starred argument is something like a Tuple or a NamedTuple, which has a definitive "shape". If so, we don't move the corresponding @@ -2817,12 +2817,12 @@ def has_shape(typ: Type) -> bool: arg_kinds, arg_names, typ.arg_kinds, typ.arg_names, lambda i: arg_types[i] ) with self.msg.filter_errors(): - if typ.is_generic() and typ.param_spec() is not None: - typ, formal_to_actual = self.adjust_generic_callable_params_mapping( - typ, args, arg_kinds, arg_names, formal_to_actual, context - ) - - if self.check_argument_count( + if typ.param_spec() is not None: + # ParamSpec can be expanded in a lot of different ways. We may try + # to expand it here instead, but picking an impossible overload + # is safe: it will be filtered out later. + star_matches.append(typ) + elif self.check_argument_count( typ, arg_types, arg_kinds, arg_names, formal_to_actual, None ): if args_have_var_arg and typ.is_var_arg: From 63f9438bfd08f6303e0b67878b46063c011b8162 Mon Sep 17 00:00:00 2001 From: STerliakov Date: Wed, 21 Aug 2024 20:29:47 +0200 Subject: [PATCH 6/6] Remove parameters thaat are no longer used --- mypy/checkexpr.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index d9729384c465..000a17693e9d 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -2619,7 +2619,7 @@ def check_overload_call( arg_types = self.infer_arg_types_in_empty_context(args) # Step 1: Filter call targets to remove ones where the argument counts don't match plausible_targets = self.plausible_overload_call_targets( - args, arg_types, arg_kinds, arg_names, callee, context + arg_types, arg_kinds, arg_names, callee ) # Step 2: If the arguments contain a union, we try performing union math first, @@ -2777,12 +2777,10 @@ def adjust_generic_callable_params_mapping( def plausible_overload_call_targets( self, - args: list[Expression], arg_types: list[Type], arg_kinds: list[ArgKind], arg_names: Sequence[str | None] | None, overload: Overloaded, - context: Context, ) -> list[CallableType]: """Returns all overload call targets that having matching argument counts.