Skip to content

Commit

Permalink
Support unions in functools.partial (#17284)
Browse files Browse the repository at this point in the history
Co-authored-by: cdce8p
  • Loading branch information
hauntsaninja authored May 25, 2024
1 parent 43a605f commit 3ddc009
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 3 deletions.
17 changes: 16 additions & 1 deletion mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from typing_extensions import TypeAlias as _TypeAlias

import mypy.checkexpr
from mypy import errorcodes as codes, message_registry, nodes, operators
from mypy import errorcodes as codes, join, message_registry, nodes, operators
from mypy.binder import ConditionalTypeBinder, Frame, get_declaration
from mypy.checkmember import (
MemberContext,
Expand Down Expand Up @@ -699,6 +699,21 @@ def extract_callable_type(self, inner_type: Type | None, ctx: Context) -> Callab
)
if isinstance(inner_call, CallableType):
outer_type = inner_call
elif isinstance(inner_type, UnionType):
union_type = make_simplified_union(inner_type.items)
if isinstance(union_type, UnionType):
items = []
for item in union_type.items:
callable_item = self.extract_callable_type(item, ctx)
if callable_item is None:
break
items.append(callable_item)
else:
joined_type = get_proper_type(join.join_type_list(items))
if isinstance(joined_type, CallableType):
outer_type = joined_type
else:
return self.extract_callable_type(union_type, ctx)
if outer_type is None:
self.msg.not_callable(inner_type, ctx)
return outer_type
Expand Down
4 changes: 2 additions & 2 deletions mypy/join.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from __future__ import annotations

from typing import overload
from typing import Sequence, overload

import mypy.typeops
from mypy.maptype import map_instance_to_supertype
Expand Down Expand Up @@ -853,7 +853,7 @@ def object_or_any_from_type(typ: ProperType) -> ProperType:
return AnyType(TypeOfAny.implementation_artifact)


def join_type_list(types: list[Type]) -> Type:
def join_type_list(types: Sequence[Type]) -> Type:
if not types:
# This is a little arbitrary but reasonable. Any empty tuple should be compatible
# with all variable length tuples, and this makes it possible.
Expand Down
22 changes: 22 additions & 0 deletions test-data/unit/check-functools.test
Original file line number Diff line number Diff line change
Expand Up @@ -324,3 +324,25 @@ p(bar, 1, "a", 3.0) # OK
p(bar, 1, "a", 3.0, kwarg="asdf") # OK
p(bar, 1, "a", "b") # E: Argument 1 to "foo" has incompatible type "Callable[[int, str, float], None]"; expected "Callable[[int, str, str], None]"
[builtins fixtures/dict.pyi]

[case testFunctoolsPartialUnion]
import functools
from typing import Any, Callable, Union

cls1: Any
cls2: Union[Any, Any]
reveal_type(functools.partial(cls1, 2)()) # N: Revealed type is "Any"
reveal_type(functools.partial(cls2, 2)()) # N: Revealed type is "Any"

fn1: Union[Callable[[int], int], Callable[[int], int]]
reveal_type(functools.partial(fn1, 2)()) # N: Revealed type is "builtins.int"

fn2: Union[Callable[[int], int], Callable[[int], str]]
reveal_type(functools.partial(fn2, 2)()) # N: Revealed type is "builtins.object"

fn3: Union[Callable[[int], int], str]
reveal_type(functools.partial(fn3, 2)()) # E: "str" not callable \
# E: "Union[Callable[[int], int], str]" not callable \
# N: Revealed type is "builtins.int" \
# E: Argument 1 to "partial" has incompatible type "Union[Callable[[int], int], str]"; expected "Callable[..., int]"
[builtins fixtures/tuple.pyi]

0 comments on commit 3ddc009

Please sign in to comment.