Skip to content

Commit

Permalink
Support ParamSpec mapping with functools.partial (#17355)
Browse files Browse the repository at this point in the history
Follow-up for #17323, resolving a false positive discovered there.

Closes #17960.

This enables use of `functools.partial` to bind some `*args` or
`**kwargs` on a callable typed with `ParamSpec`.

---------

Co-authored-by: Shantanu Jain <[email protected]>
  • Loading branch information
sterliakov and hauntsaninja authored Oct 26, 2024
1 parent e7db89c commit a706914
Show file tree
Hide file tree
Showing 4 changed files with 174 additions and 7 deletions.
6 changes: 5 additions & 1 deletion mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2377,7 +2377,11 @@ def check_argument_count(
# Positional argument when expecting a keyword argument.
self.msg.too_many_positional_arguments(callee, context)
ok = False
elif callee.param_spec() is not None and not formal_to_actual[i]:
elif (
callee.param_spec() is not None
and not formal_to_actual[i]
and callee.special_sig != "partial"
):
self.msg.too_few_arguments(callee, context, actual_names)
ok = False
return ok
Expand Down
47 changes: 43 additions & 4 deletions mypy/plugins/functools.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,16 @@
import mypy.plugin
import mypy.semanal
from mypy.argmap import map_actuals_to_formals
from mypy.nodes import ARG_POS, ARG_STAR2, ArgKind, Argument, CallExpr, FuncItem, Var
from mypy.nodes import ARG_POS, ARG_STAR2, ArgKind, Argument, CallExpr, FuncItem, NameExpr, Var
from mypy.plugins.common import add_method_to_class
from mypy.typeops import get_all_type_vars
from mypy.types import (
AnyType,
CallableType,
Instance,
Overloaded,
ParamSpecFlavor,
ParamSpecType,
Type,
TypeOfAny,
TypeVarType,
Expand Down Expand Up @@ -202,6 +204,7 @@ def handle_partial_with_callee(ctx: mypy.plugin.FunctionContext, callee: Type) -
continue
can_infer_ids.update({tv.id for tv in get_all_type_vars(arg_type)})

# special_sig="partial" allows omission of args/kwargs typed with ParamSpec
defaulted = fn_type.copy_modified(
arg_kinds=[
(
Expand All @@ -218,6 +221,7 @@ def handle_partial_with_callee(ctx: mypy.plugin.FunctionContext, callee: Type) -
# Keep TypeVarTuple/ParamSpec to avoid spurious errors on empty args.
if tv.id in can_infer_ids or not isinstance(tv, TypeVarType)
],
special_sig="partial",
)
if defaulted.line < 0:
# Make up a line number if we don't have one
Expand Down Expand Up @@ -296,10 +300,19 @@ def handle_partial_with_callee(ctx: mypy.plugin.FunctionContext, callee: Type) -
arg_kinds=partial_kinds,
arg_names=partial_names,
ret_type=ret_type,
special_sig="partial",
)

ret = ctx.api.named_generic_type(PARTIAL, [ret_type])
ret = ret.copy_with_extra_attr("__mypy_partial", partially_applied)
if partially_applied.param_spec():
assert ret.extra_attrs is not None # copy_with_extra_attr above ensures this
attrs = ret.extra_attrs.copy()
if ArgKind.ARG_STAR in actual_arg_kinds:
attrs.immutable.add("__mypy_partial_paramspec_args_bound")
if ArgKind.ARG_STAR2 in actual_arg_kinds:
attrs.immutable.add("__mypy_partial_paramspec_kwargs_bound")
ret.extra_attrs = attrs
return ret


Expand All @@ -314,7 +327,8 @@ def partial_call_callback(ctx: mypy.plugin.MethodContext) -> Type:
):
return ctx.default_return_type

partial_type = ctx.type.extra_attrs.attrs["__mypy_partial"]
extra_attrs = ctx.type.extra_attrs
partial_type = get_proper_type(extra_attrs.attrs["__mypy_partial"])
if len(ctx.arg_types) != 2: # *args, **kwargs
return ctx.default_return_type

Expand All @@ -332,11 +346,36 @@ def partial_call_callback(ctx: mypy.plugin.MethodContext) -> Type:
actual_arg_kinds.append(ctx.arg_kinds[i][j])
actual_arg_names.append(ctx.arg_names[i][j])

result = ctx.api.expr_checker.check_call(
result, _ = ctx.api.expr_checker.check_call(
callee=partial_type,
args=actual_args,
arg_kinds=actual_arg_kinds,
arg_names=actual_arg_names,
context=ctx.context,
)
return result[0]
if not isinstance(partial_type, CallableType) or partial_type.param_spec() is None:
return result

args_bound = "__mypy_partial_paramspec_args_bound" in extra_attrs.immutable
kwargs_bound = "__mypy_partial_paramspec_kwargs_bound" in extra_attrs.immutable

passed_paramspec_parts = [
arg.node.type
for arg in actual_args
if isinstance(arg, NameExpr)
and isinstance(arg.node, Var)
and isinstance(arg.node.type, ParamSpecType)
]
# ensure *args: P.args
args_passed = any(part.flavor == ParamSpecFlavor.ARGS for part in passed_paramspec_parts)
if not args_bound and not args_passed:
ctx.api.expr_checker.msg.too_few_arguments(partial_type, ctx.context, actual_arg_names)
elif args_bound and args_passed:
ctx.api.expr_checker.msg.too_many_arguments(partial_type, ctx.context)

# ensure **kwargs: P.kwargs
kwargs_passed = any(part.flavor == ParamSpecFlavor.KWARGS for part in passed_paramspec_parts)
if not kwargs_bound and not kwargs_passed:
ctx.api.expr_checker.msg.too_few_arguments(partial_type, ctx.context, actual_arg_names)

return result
4 changes: 2 additions & 2 deletions mypy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1827,8 +1827,8 @@ class CallableType(FunctionLike):
"implicit", # Was this type implicitly generated instead of explicitly
# specified by the user?
"special_sig", # Non-None for signatures that require special handling
# (currently only value is 'dict' for a signature similar to
# 'dict')
# (currently only values are 'dict' for a signature similar to
# 'dict' and 'partial' for a `functools.partial` evaluation)
"from_type_type", # Was this callable generated by analyzing Type[...]
# instantiation?
"bound_args", # Bound type args, mostly unused but may be useful for
Expand Down
124 changes: 124 additions & 0 deletions test-data/unit/check-parameter-specification.test
Original file line number Diff line number Diff line change
Expand Up @@ -2338,3 +2338,127 @@ reveal_type(handle_reversed(Child())) # N: Revealed type is "builtins.str"
reveal_type(handle_reversed(NotChild())) # N: Revealed type is "builtins.str"

[builtins fixtures/paramspec.pyi]

[case testBindPartial]
from functools import partial
from typing_extensions import ParamSpec
from typing import Callable, TypeVar

P = ParamSpec("P")
T = TypeVar("T")

def run(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
func2 = partial(func, **kwargs)
return func2(*args)

def run2(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
func2 = partial(func, *args)
return func2(**kwargs)

def run3(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
func2 = partial(func, *args, **kwargs)
return func2()

def run4(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
func2 = partial(func, *args, **kwargs)
return func2(**kwargs)

def run_bad(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
func2 = partial(func, *args, **kwargs)
return func2(*args) # E: Too many arguments

def run_bad2(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
func2 = partial(func, **kwargs)
return func2(**kwargs) # E: Too few arguments

def run_bad3(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
func2 = partial(func, *args)
return func2() # E: Too few arguments

[builtins fixtures/paramspec.pyi]

[case testBindPartialConcatenate]
from functools import partial
from typing_extensions import Concatenate, ParamSpec
from typing import Callable, TypeVar

P = ParamSpec("P")
T = TypeVar("T")

def run(func: Callable[Concatenate[int, P], T], *args: P.args, **kwargs: P.kwargs) -> T:
func2 = partial(func, 1, **kwargs)
return func2(*args)

def run2(func: Callable[Concatenate[int, P], T], *args: P.args, **kwargs: P.kwargs) -> T:
func2 = partial(func, **kwargs)
p = [""]
func2(1, *p) # E: Too few arguments \
# E: Argument 2 has incompatible type "*List[str]"; expected "P.args"
func2(1, 2, *p) # E: Too few arguments \
# E: Argument 2 has incompatible type "int"; expected "P.args" \
# E: Argument 3 has incompatible type "*List[str]"; expected "P.args"
func2(1, *args, *p) # E: Argument 3 has incompatible type "*List[str]"; expected "P.args"
func2(1, *p, *args) # E: Argument 2 has incompatible type "*List[str]"; expected "P.args"
return func2(1, *args)

def run3(func: Callable[Concatenate[int, P], T], *args: P.args, **kwargs: P.kwargs) -> T:
func2 = partial(func, 1, *args)
d = {"":""}
func2(**d) # E: Too few arguments \
# E: Argument 1 has incompatible type "**Dict[str, str]"; expected "P.kwargs"
return func2(**kwargs)

def run4(func: Callable[Concatenate[int, P], T], *args: P.args, **kwargs: P.kwargs) -> T:
func2 = partial(func, 1)
return func2(*args, **kwargs)

def run5(func: Callable[Concatenate[int, P], T], *args: P.args, **kwargs: P.kwargs) -> T:
func2 = partial(func, 1, *args, **kwargs)
func2()
return func2(**kwargs)

def run_bad(func: Callable[Concatenate[int, P], T], *args: P.args, **kwargs: P.kwargs) -> T:
func2 = partial(func, *args) # E: Argument 1 has incompatible type "*P.args"; expected "int"
return func2(1, **kwargs) # E: Argument 1 has incompatible type "int"; expected "P.args"

def run_bad2(func: Callable[Concatenate[int, P], T], *args: P.args, **kwargs: P.kwargs) -> T:
func2 = partial(func, 1, *args)
func2() # E: Too few arguments
func2(*args, **kwargs) # E: Too many arguments
return func2(1, **kwargs) # E: Argument 1 has incompatible type "int"; expected "P.args"

def run_bad3(func: Callable[Concatenate[int, P], T], *args: P.args, **kwargs: P.kwargs) -> T:
func2 = partial(func, 1, **kwargs)
func2() # E: Too few arguments
return func2(1, *args) # E: Argument 1 has incompatible type "int"; expected "P.args"

def run_bad4(func: Callable[Concatenate[int, P], T], *args: P.args, **kwargs: P.kwargs) -> T:
func2 = partial(func, 1)
func2() # E: Too few arguments
func2(*args) # E: Too few arguments
func2(1, *args) # E: Too few arguments \
# E: Argument 1 has incompatible type "int"; expected "P.args"
func2(1, **kwargs) # E: Too few arguments \
# E: Argument 1 has incompatible type "int"; expected "P.args"
return func2(**kwargs) # E: Too few arguments

[builtins fixtures/paramspec.pyi]

[case testOtherVarArgs]
from functools import partial
from typing_extensions import Concatenate, ParamSpec
from typing import Callable, TypeVar, Tuple

P = ParamSpec("P")
T = TypeVar("T")

def run(func: Callable[Concatenate[int, str, P], T], *args: P.args, **kwargs: P.kwargs) -> T:
func2 = partial(func, **kwargs)
args_prefix: Tuple[int, str] = (1, 'a')
func2(*args_prefix) # E: Too few arguments
func2(*args, *args_prefix) # E: Argument 1 has incompatible type "*P.args"; expected "int" \
# E: Argument 1 has incompatible type "*P.args"; expected "str" \
# E: Argument 2 has incompatible type "*Tuple[int, str]"; expected "P.args"
return func2(*args_prefix, *args)

[builtins fixtures/paramspec.pyi]

0 comments on commit a706914

Please sign in to comment.