Skip to content

Commit

Permalink
Fix missing type store for overloads (#16803)
Browse files Browse the repository at this point in the history
Add missing call to store inferred types if an overload match is found
early. All other code paths already do that.


### Some background on the issue this fixes
I recently saw an interesting pattern in `aiohttp` to type values in an
`dict[str, Any]` by subclassing dict.

```py
T = TypeVar("T")
U = TypeVar("U")

class Key(Generic[T]):
    ...

class CustomDict(dict[Key[Any] | str, Any]):
    @overload  # type: ignore[override]
    def get(self, __key: Key[T]) -> T | None:
        ...

    @overload
    def get(self, __key: Key[T], __default: U) -> T | U:
        ...

    @overload
    def get(self, __key: str) -> Any | None:
        ...

    @overload
    def get(self, __key: str, __default: Any) -> Any:
        ...

    def get(self, __key: Key[Any] | str, __default: Any = None) -> Any:
        """Forward to super implementation."""
        return super().get(__key, __default)

    # overloads for __getitem__, setdefault, pop
    # ...

    @overload  # type: ignore[override]
    def __setitem__(self, key: Key[T], value: T) -> None:
        ...

    @overload
    def __setitem__(self, key: str, value: Any) -> None:
        ...

    def __setitem__(self, key: Key[Any] | str, value: Any) -> None:
        """Forward to super implementation."""
        return super().__setitem__(key, value)
```

With the exception that these overloads aren't technically compatible
with the supertype, they do the job.
```py
d = CustomDict()
key = Key[int]()
other_key = "other"
assert_type(d.get(key), int | None)
assert_type(d.get("other"), Any | None)
```

The issue exists for the `__setitem__` case. Without this PR the
following would create an issue. Here `var` would be inferred as
`dict[Never, Never]`, even though it should be `dict[Any, Any]` which is
the case for non-subclassed dicts.
```py
def a2(d: CustomDict) -> None:
    if (var := d.get("arg")) is None:
        var = d["arg"] = {}
        reveal_type(var)
```
  • Loading branch information
cdce8p authored Jan 31, 2024
1 parent 55247c4 commit 5bf7742
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 0 deletions.
1 change: 1 addition & 0 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2825,6 +2825,7 @@ def infer_overload_return_type(
# Return early if possible; otherwise record info, so we can
# check for ambiguity due to 'Any' below.
if not args_contain_any:
self.chk.store_types(m)
return ret_type, infer_type
p_infer_type = get_proper_type(infer_type)
if isinstance(p_infer_type, CallableType):
Expand Down
24 changes: 24 additions & 0 deletions test-data/unit/check-generics.test
Original file line number Diff line number Diff line change
Expand Up @@ -1480,6 +1480,30 @@ if int():
b = f(b)
[builtins fixtures/list.pyi]

[case testGenericDictWithOverload]
from typing import Dict, Generic, TypeVar, Any, overload
T = TypeVar("T")

class Key(Generic[T]): ...
class CustomDict(dict):
@overload # type: ignore[override]
def __setitem__(self, key: Key[T], value: T) -> None: ...
@overload
def __setitem__(self, key: str, value: Any) -> None: ...
def __setitem__(self, key, value):
return super().__setitem__(key, value)

def a1(d: Dict[str, Any]) -> None:
if (var := d.get("arg")) is None:
var = d["arg"] = {}
reveal_type(var) # N: Revealed type is "builtins.dict[Any, Any]"

def a2(d: CustomDict) -> None:
if (var := d.get("arg")) is None:
var = d["arg"] = {}
reveal_type(var) # N: Revealed type is "builtins.dict[Any, Any]"
[builtins fixtures/dict.pyi]


-- Type variable scoping
-- ---------------------
Expand Down
21 changes: 21 additions & 0 deletions test-data/unit/typexport-basic.test
Original file line number Diff line number Diff line change
Expand Up @@ -1236,6 +1236,27 @@ LambdaExpr(10) : def (x: builtins.int) -> builtins.int
LambdaExpr(12) : def (y: builtins.str) -> builtins.str
LambdaExpr(13) : def (x: builtins.str) -> builtins.str

[case testExportOverloadArgTypeDict]
## DictExpr
from typing import TypeVar, Generic, Any, overload, Dict
T = TypeVar("T")
class Key(Generic[T]): ...
@overload
def f(x: Key[T], y: T) -> T: ...
@overload
def f(x: int, y: Any) -> Any: ...
def f(x, y): ...
d: Dict = {}
d.get(
"", {})
f(
2, {})
[builtins fixtures/dict.pyi]
[out]
DictExpr(10) : builtins.dict[Any, Any]
DictExpr(12) : builtins.dict[Any, Any]
DictExpr(14) : builtins.dict[Any, Any]

-- TODO
--
-- test expressions
Expand Down

0 comments on commit 5bf7742

Please sign in to comment.