Skip to content

Commit

Permalink
Add support for functools.partial (#16939)
Browse files Browse the repository at this point in the history
Fixes #1484

Turns out that this is currently the second most popular mypy issue (and
first most popular is a type system feature request that would need a
PEP). I'm sure there's stuff missing, but this should handle most cases.
  • Loading branch information
hauntsaninja authored May 23, 2024
1 parent ca393dd commit 0871c93
Show file tree
Hide file tree
Showing 9 changed files with 454 additions and 27 deletions.
34 changes: 17 additions & 17 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1229,14 +1229,14 @@ def apply_function_plugin(
assert callback is not None # Assume that caller ensures this
return callback(
FunctionContext(
formal_arg_types,
formal_arg_kinds,
callee.arg_names,
formal_arg_names,
callee.ret_type,
formal_arg_exprs,
context,
self.chk,
arg_types=formal_arg_types,
arg_kinds=formal_arg_kinds,
callee_arg_names=callee.arg_names,
arg_names=formal_arg_names,
default_return_type=callee.ret_type,
args=formal_arg_exprs,
context=context,
api=self.chk,
)
)
else:
Expand All @@ -1246,15 +1246,15 @@ def apply_function_plugin(
object_type = get_proper_type(object_type)
return method_callback(
MethodContext(
object_type,
formal_arg_types,
formal_arg_kinds,
callee.arg_names,
formal_arg_names,
callee.ret_type,
formal_arg_exprs,
context,
self.chk,
type=object_type,
arg_types=formal_arg_types,
arg_kinds=formal_arg_kinds,
callee_arg_names=callee.arg_names,
arg_names=formal_arg_names,
default_return_type=callee.ret_type,
args=formal_arg_exprs,
context=context,
api=self.chk,
)
)

Expand Down
3 changes: 3 additions & 0 deletions mypy/fixup.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,9 @@ def visit_instance(self, inst: Instance) -> None:
a.accept(self)
if inst.last_known_value is not None:
inst.last_known_value.accept(self)
if inst.extra_attrs:
for v in inst.extra_attrs.attrs.values():
v.accept(self)

def visit_type_alias_type(self, t: TypeAliasType) -> None:
type_ref = t.type_ref
Expand Down
15 changes: 12 additions & 3 deletions mypy/plugins/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ def get_function_hook(self, fullname: str) -> Callable[[FunctionContext], Type]
return ctypes.array_constructor_callback
elif fullname == "functools.singledispatch":
return singledispatch.create_singledispatch_function_callback
elif fullname == "functools.partial":
import mypy.plugins.functools

return mypy.plugins.functools.partial_new_callback

return None

Expand Down Expand Up @@ -118,6 +122,10 @@ def get_method_hook(self, fullname: str) -> Callable[[MethodContext], Type] | No
return singledispatch.singledispatch_register_callback
elif fullname == singledispatch.REGISTER_CALLABLE_CALL_METHOD:
return singledispatch.call_singledispatch_function_after_register_argument
elif fullname == "functools.partial.__call__":
import mypy.plugins.functools

return mypy.plugins.functools.partial_call_callback
return None

def get_attribute_hook(self, fullname: str) -> Callable[[AttributeContext], Type] | None:
Expand Down Expand Up @@ -155,12 +163,13 @@ def get_class_decorator_hook(self, fullname: str) -> Callable[[ClassDefContext],
def get_class_decorator_hook_2(
self, fullname: str
) -> Callable[[ClassDefContext], bool] | None:
from mypy.plugins import attrs, dataclasses, functools
import mypy.plugins.functools
from mypy.plugins import attrs, dataclasses

if fullname in dataclasses.dataclass_makers:
return dataclasses.dataclass_class_maker_callback
elif fullname in functools.functools_total_ordering_makers:
return functools.functools_total_ordering_maker_callback
elif fullname in mypy.plugins.functools.functools_total_ordering_makers:
return mypy.plugins.functools.functools_total_ordering_maker_callback
elif fullname in attrs.attr_class_makers:
return attrs.attr_class_maker_callback
elif fullname in attrs.attr_dataclass_makers:
Expand Down
144 changes: 142 additions & 2 deletions mypy/plugins/functools.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,22 @@

from typing import Final, NamedTuple

import mypy.checker
import mypy.plugin
from mypy.nodes import ARG_POS, ARG_STAR2, Argument, FuncItem, Var
from mypy.argmap import map_actuals_to_formals
from mypy.nodes import ARG_POS, ARG_STAR2, ArgKind, Argument, FuncItem, Var
from mypy.plugins.common import add_method_to_class
from mypy.types import AnyType, CallableType, Type, TypeOfAny, UnboundType, get_proper_type
from mypy.types import (
AnyType,
CallableType,
Instance,
Overloaded,
Type,
TypeOfAny,
UnboundType,
UninhabitedType,
get_proper_type,
)

functools_total_ordering_makers: Final = {"functools.total_ordering"}

Expand Down Expand Up @@ -102,3 +114,131 @@ def _analyze_class(ctx: mypy.plugin.ClassDefContext) -> dict[str, _MethodInfo |
comparison_methods[name] = None

return comparison_methods


def partial_new_callback(ctx: mypy.plugin.FunctionContext) -> Type:
"""Infer a more precise return type for functools.partial"""
if not isinstance(ctx.api, mypy.checker.TypeChecker): # use internals
return ctx.default_return_type
if len(ctx.arg_types) != 3: # fn, *args, **kwargs
return ctx.default_return_type
if len(ctx.arg_types[0]) != 1:
return ctx.default_return_type

if isinstance(get_proper_type(ctx.arg_types[0][0]), Overloaded):
# TODO: handle overloads, just fall back to whatever the non-plugin code does
return ctx.default_return_type
fn_type = ctx.api.extract_callable_type(ctx.arg_types[0][0], ctx=ctx.default_return_type)
if fn_type is None:
return ctx.default_return_type

defaulted = fn_type.copy_modified(
arg_kinds=[
(
ArgKind.ARG_OPT
if k == ArgKind.ARG_POS
else (ArgKind.ARG_NAMED_OPT if k == ArgKind.ARG_NAMED else k)
)
for k in fn_type.arg_kinds
]
)
if defaulted.line < 0:
# Make up a line number if we don't have one
defaulted.set_line(ctx.default_return_type)

actual_args = [a for param in ctx.args[1:] for a in param]
actual_arg_kinds = [a for param in ctx.arg_kinds[1:] for a in param]
actual_arg_names = [a for param in ctx.arg_names[1:] for a in param]
actual_types = [a for param in ctx.arg_types[1:] for a in param]

_, bound = ctx.api.expr_checker.check_call(
callee=defaulted,
args=actual_args,
arg_kinds=actual_arg_kinds,
arg_names=actual_arg_names,
context=defaulted,
)
bound = get_proper_type(bound)
if not isinstance(bound, CallableType):
return ctx.default_return_type

formal_to_actual = map_actuals_to_formals(
actual_kinds=actual_arg_kinds,
actual_names=actual_arg_names,
formal_kinds=fn_type.arg_kinds,
formal_names=fn_type.arg_names,
actual_arg_type=lambda i: actual_types[i],
)

partial_kinds = []
partial_types = []
partial_names = []
# We need to fully apply any positional arguments (they cannot be respecified)
# However, keyword arguments can be respecified, so just give them a default
for i, actuals in enumerate(formal_to_actual):
if len(bound.arg_types) == len(fn_type.arg_types):
arg_type = bound.arg_types[i]
if isinstance(get_proper_type(arg_type), UninhabitedType):
arg_type = fn_type.arg_types[i] # bit of a hack
else:
# TODO: I assume that bound and fn_type have the same arguments. It appears this isn't
# true when PEP 646 things are happening. See testFunctoolsPartialTypeVarTuple
arg_type = fn_type.arg_types[i]

if not actuals or fn_type.arg_kinds[i] in (ArgKind.ARG_STAR, ArgKind.ARG_STAR2):
partial_kinds.append(fn_type.arg_kinds[i])
partial_types.append(arg_type)
partial_names.append(fn_type.arg_names[i])
elif actuals:
if any(actual_arg_kinds[j] == ArgKind.ARG_POS for j in actuals):
continue
kind = actual_arg_kinds[actuals[0]]
if kind == ArgKind.ARG_NAMED:
kind = ArgKind.ARG_NAMED_OPT
partial_kinds.append(kind)
partial_types.append(arg_type)
partial_names.append(fn_type.arg_names[i])

ret_type = bound.ret_type
if isinstance(get_proper_type(ret_type), UninhabitedType):
ret_type = fn_type.ret_type # same kind of hack as above

partially_applied = fn_type.copy_modified(
arg_types=partial_types,
arg_kinds=partial_kinds,
arg_names=partial_names,
ret_type=ret_type,
)

ret = ctx.api.named_generic_type("functools.partial", [ret_type])
ret = ret.copy_with_extra_attr("__mypy_partial", partially_applied)
return ret


def partial_call_callback(ctx: mypy.plugin.MethodContext) -> Type:
"""Infer a more precise return type for functools.partial.__call__."""
if (
not isinstance(ctx.api, mypy.checker.TypeChecker) # use internals
or not isinstance(ctx.type, Instance)
or ctx.type.type.fullname != "functools.partial"
or not ctx.type.extra_attrs
or "__mypy_partial" not in ctx.type.extra_attrs.attrs
):
return ctx.default_return_type

partial_type = ctx.type.extra_attrs.attrs["__mypy_partial"]
if len(ctx.arg_types) != 2: # *args, **kwargs
return ctx.default_return_type

args = [a for param in ctx.args for a in param]
arg_kinds = [a for param in ctx.arg_kinds for a in param]
arg_names = [a for param in ctx.arg_names for a in param]

result = ctx.api.expr_checker.check_call(
callee=partial_type,
args=args,
arg_kinds=arg_kinds,
arg_names=arg_names,
context=ctx.context,
)
return result[0]
9 changes: 9 additions & 0 deletions mypy/server/astdiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,11 +378,20 @@ def visit_deleted_type(self, typ: DeletedType) -> SnapshotItem:
return snapshot_simple_type(typ)

def visit_instance(self, typ: Instance) -> SnapshotItem:
extra_attrs: SnapshotItem
if typ.extra_attrs:
extra_attrs = (
tuple(sorted((k, v.accept(self)) for k, v in typ.extra_attrs.attrs.items())),
tuple(typ.extra_attrs.immutable),
)
else:
extra_attrs = ()
return (
"Instance",
encode_optional_str(typ.type.fullname),
snapshot_types(typ.args),
("None",) if typ.last_known_value is None else snapshot_type(typ.last_known_value),
extra_attrs,
)

def visit_type_var(self, typ: TypeVarType) -> SnapshotItem:
Expand Down
29 changes: 25 additions & 4 deletions mypy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1322,6 +1322,23 @@ def copy(self) -> ExtraAttrs:
def __repr__(self) -> str:
return f"ExtraAttrs({self.attrs!r}, {self.immutable!r}, {self.mod_name!r})"

def serialize(self) -> JsonDict:
return {
".class": "ExtraAttrs",
"attrs": {k: v.serialize() for k, v in self.attrs.items()},
"immutable": list(self.immutable),
"mod_name": self.mod_name,
}

@classmethod
def deserialize(cls, data: JsonDict) -> ExtraAttrs:
assert data[".class"] == "ExtraAttrs"
return ExtraAttrs(
{k: deserialize_type(v) for k, v in data["attrs"].items()},
set(data["immutable"]),
data["mod_name"],
)


class Instance(ProperType):
"""An instance type of form C[T1, ..., Tn].
Expand Down Expand Up @@ -1434,6 +1451,7 @@ def serialize(self) -> JsonDict | str:
data["args"] = [arg.serialize() for arg in self.args]
if self.last_known_value is not None:
data["last_known_value"] = self.last_known_value.serialize()
data["extra_attrs"] = self.extra_attrs.serialize() if self.extra_attrs else None
return data

@classmethod
Expand All @@ -1452,6 +1470,8 @@ def deserialize(cls, data: JsonDict | str) -> Instance:
inst.type_ref = data["type_ref"] # Will be fixed up by fixup.py later.
if "last_known_value" in data:
inst.last_known_value = LiteralType.deserialize(data["last_known_value"])
if data.get("extra_attrs") is not None:
inst.extra_attrs = ExtraAttrs.deserialize(data["extra_attrs"])
return inst

def copy_modified(
Expand All @@ -1461,13 +1481,14 @@ def copy_modified(
last_known_value: Bogus[LiteralType | None] = _dummy,
) -> Instance:
new = Instance(
self.type,
args if args is not _dummy else self.args,
self.line,
self.column,
typ=self.type,
args=args if args is not _dummy else self.args,
line=self.line,
column=self.column,
last_known_value=(
last_known_value if last_known_value is not _dummy else self.last_known_value
),
extra_attrs=self.extra_attrs,
)
# We intentionally don't copy the extra_attrs here, so they will be erased.
new.can_be_true = self.can_be_true
Expand Down
Loading

0 comments on commit 0871c93

Please sign in to comment.