From 8236c93d899fa5225eb23644db802cf1e09196a7 Mon Sep 17 00:00:00 2001 From: Nikita Sobolev Date: Mon, 23 Oct 2023 15:52:42 +0300 Subject: [PATCH] Add `|=` and `|` operators support for `TypedDict` (#16249) Please note that there are several problems with `__ror__` definitions. 1. `dict.__ror__` does not define support for `Mapping` types. For example: ```python >>> import types >>> {'a': 1} | types.MappingProxyType({'b': 2}) {'a': 1, 'b': 2} >>> ``` 2. `TypedDict.__ror__` also does not define this support. So I would like to defer this feature for the future; we need some discussion to happen first. However, this PR does fully solve the problem OP had. Closes https://github.com/python/mypy/issues/16244 --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- mypy/checker.py | 19 ++- mypy/checkexpr.py | 55 ++++++- mypy/plugins/default.py | 22 ++- test-data/unit/check-typeddict.test | 143 ++++++++++++++++++ test-data/unit/fixtures/dict.pyi | 19 ++- test-data/unit/fixtures/typing-async.pyi | 1 + test-data/unit/fixtures/typing-full.pyi | 1 + test-data/unit/fixtures/typing-medium.pyi | 1 + .../unit/fixtures/typing-typeddict-iror.pyi | 66 ++++++++ 9 files changed, 316 insertions(+), 11 deletions(-) create mode 100644 test-data/unit/fixtures/typing-typeddict-iror.pyi diff --git a/mypy/checker.py b/mypy/checker.py index 02bab37aa13f..64bbbfa0a55b 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -7783,14 +7783,25 @@ def infer_operator_assignment_method(typ: Type, operator: str) -> tuple[bool, st """ typ = get_proper_type(typ) method = operators.op_methods[operator] + existing_method = None if isinstance(typ, Instance): - if operator in operators.ops_with_inplace_method: - inplace_method = "__i" + method[2:] - if typ.type.has_readable_member(inplace_method): - return True, inplace_method + existing_method = _find_inplace_method(typ, method, operator) + elif isinstance(typ, TypedDictType): + existing_method = 
_find_inplace_method(typ.fallback, method, operator) + + if existing_method is not None: + return True, existing_method return False, method +def _find_inplace_method(inst: Instance, method: str, operator: str) -> str | None: + if operator in operators.ops_with_inplace_method: + inplace_method = "__i" + method[2:] + if inst.type.has_readable_member(inplace_method): + return inplace_method + return None + + def is_valid_inferred_type(typ: Type, is_lvalue_final: bool = False) -> bool: """Is an inferred type valid and needs no further refinement? diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 2dc5a93a1de9..18c1c570ba91 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -2,12 +2,13 @@ from __future__ import annotations +import enum import itertools import time from collections import defaultdict from contextlib import contextmanager from typing import Callable, ClassVar, Final, Iterable, Iterator, List, Optional, Sequence, cast -from typing_extensions import TypeAlias as _TypeAlias, overload +from typing_extensions import TypeAlias as _TypeAlias, assert_never, overload import mypy.checker import mypy.errorcodes as codes @@ -277,6 +278,20 @@ class Finished(Exception): """Raised if we can terminate overload argument check early (no match).""" +@enum.unique +class UseReverse(enum.Enum): + """Used in `visit_op_expr` to enable or disable reverse method checks.""" + + DEFAULT = 0 + ALWAYS = 1 + NEVER = 2 + + +USE_REVERSE_DEFAULT: Final = UseReverse.DEFAULT +USE_REVERSE_ALWAYS: Final = UseReverse.ALWAYS +USE_REVERSE_NEVER: Final = UseReverse.NEVER + + class ExpressionChecker(ExpressionVisitor[Type]): """Expression type checker. @@ -3371,6 +3386,24 @@ def visit_op_expr(self, e: OpExpr) -> Type: return proper_left_type.copy_modified( items=proper_left_type.items + [UnpackType(mapped)] ) + + use_reverse: UseReverse = USE_REVERSE_DEFAULT + if e.op == "|": + if is_named_instance(proper_left_type, "builtins.dict"): + # This is a special case for `dict | TypedDict`. 
+ # 1. Find `dict | TypedDict` case + # 2. Switch `dict.__or__` to `TypedDict.__ror__` (the same from both runtime and typing perspective) + proper_right_type = get_proper_type(self.accept(e.right)) + if isinstance(proper_right_type, TypedDictType): + use_reverse = USE_REVERSE_ALWAYS + if isinstance(proper_left_type, TypedDictType): + # This is the reverse case: `TypedDict | dict`, + # simply do not allow the reverse checking: + # do not call `__dict__.__ror__`. + proper_right_type = get_proper_type(self.accept(e.right)) + if is_named_instance(proper_right_type, "builtins.dict"): + use_reverse = USE_REVERSE_NEVER + if TYPE_VAR_TUPLE in self.chk.options.enable_incomplete_feature: # Handle tuple[X, ...] + tuple[Y, Z] = tuple[*tuple[X, ...], Y, Z]. if ( @@ -3390,7 +3423,25 @@ def visit_op_expr(self, e: OpExpr) -> Type: if e.op in operators.op_methods: method = operators.op_methods[e.op] - result, method_type = self.check_op(method, left_type, e.right, e, allow_reverse=True) + if use_reverse is UseReverse.DEFAULT or use_reverse is UseReverse.NEVER: + result, method_type = self.check_op( + method, + base_type=left_type, + arg=e.right, + context=e, + allow_reverse=use_reverse is UseReverse.DEFAULT, + ) + elif use_reverse is UseReverse.ALWAYS: + result, method_type = self.check_op( + # The reverse operator here gives better error messages: + operators.reverse_op_methods[method], + base_type=self.accept(e.right), + arg=e.left, + context=e, + allow_reverse=False, + ) + else: + assert_never(use_reverse) e.method_type = method_type return result else: diff --git a/mypy/plugins/default.py b/mypy/plugins/default.py index b60fc3873c04..ddcc37f465fe 100644 --- a/mypy/plugins/default.py +++ b/mypy/plugins/default.py @@ -74,12 +74,21 @@ def get_method_signature_hook( return typed_dict_setdefault_signature_callback elif fullname in {n + ".pop" for n in TPDICT_FB_NAMES}: return typed_dict_pop_signature_callback - elif fullname in {n + ".update" for n in TPDICT_FB_NAMES}: - return 
typed_dict_update_signature_callback elif fullname == "_ctypes.Array.__setitem__": return ctypes.array_setitem_callback elif fullname == singledispatch.SINGLEDISPATCH_CALLABLE_CALL_METHOD: return singledispatch.call_singledispatch_function_callback + + typed_dict_updates = set() + for n in TPDICT_FB_NAMES: + typed_dict_updates.add(n + ".update") + typed_dict_updates.add(n + ".__or__") + typed_dict_updates.add(n + ".__ror__") + typed_dict_updates.add(n + ".__ior__") + + if fullname in typed_dict_updates: + return typed_dict_update_signature_callback + return None def get_method_hook(self, fullname: str) -> Callable[[MethodContext], Type] | None: @@ -401,11 +410,16 @@ def typed_dict_delitem_callback(ctx: MethodContext) -> Type: def typed_dict_update_signature_callback(ctx: MethodSigContext) -> CallableType: - """Try to infer a better signature type for TypedDict.update.""" + """Try to infer a better signature type for methods that update `TypedDict`. + + This includes: `TypedDict.update`, `TypedDict.__or__`, `TypedDict.__ror__`, + and `TypedDict.__ior__`. + """ signature = ctx.default_signature if isinstance(ctx.type, TypedDictType) and len(signature.arg_types) == 1: arg_type = get_proper_type(signature.arg_types[0]) - assert isinstance(arg_type, TypedDictType) + if not isinstance(arg_type, TypedDictType): + return signature arg_type = arg_type.as_anonymous() arg_type = arg_type.copy_modified(required_keys=set()) if ctx.args and ctx.args[0]: diff --git a/test-data/unit/check-typeddict.test b/test-data/unit/check-typeddict.test index 7ee9ef0b708b..0e1d800e0468 100644 --- a/test-data/unit/check-typeddict.test +++ b/test-data/unit/check-typeddict.test @@ -3236,3 +3236,146 @@ def foo(x: int) -> Foo: ... 
f: Foo = {**foo("no")} # E: Argument 1 to "foo" has incompatible type "str"; expected "int" [builtins fixtures/dict.pyi] [typing fixtures/typing-typeddict.pyi] + + +[case testTypedDictWith__or__method] +from typing import Dict +from mypy_extensions import TypedDict + +class Foo(TypedDict): + key: int + +foo1: Foo = {'key': 1} +foo2: Foo = {'key': 2} + +reveal_type(foo1 | foo2) # N: Revealed type is "TypedDict('__main__.Foo', {'key': builtins.int})" +reveal_type(foo1 | {'key': 1}) # N: Revealed type is "TypedDict('__main__.Foo', {'key': builtins.int})" +reveal_type(foo1 | {'key': 'a'}) # N: Revealed type is "builtins.dict[builtins.str, builtins.object]" +reveal_type(foo1 | {}) # N: Revealed type is "TypedDict('__main__.Foo', {'key': builtins.int})" + +d1: Dict[str, int] +d2: Dict[int, str] + +reveal_type(foo1 | d1) # N: Revealed type is "builtins.dict[builtins.str, builtins.object]" +foo1 | d2 # E: Unsupported operand types for | ("Foo" and "Dict[int, str]") + + +class Bar(TypedDict): + key: int + value: str + +bar: Bar +reveal_type(bar | {}) # N: Revealed type is "TypedDict('__main__.Bar', {'key': builtins.int, 'value': builtins.str})" +reveal_type(bar | {'key': 1, 'value': 'v'}) # N: Revealed type is "TypedDict('__main__.Bar', {'key': builtins.int, 'value': builtins.str})" +reveal_type(bar | {'key': 1}) # N: Revealed type is "TypedDict('__main__.Bar', {'key': builtins.int, 'value': builtins.str})" +reveal_type(bar | {'value': 'v'}) # N: Revealed type is "TypedDict('__main__.Bar', {'key': builtins.int, 'value': builtins.str})" +reveal_type(bar | {'key': 'a'}) # N: Revealed type is "builtins.dict[builtins.str, builtins.object]" +reveal_type(bar | {'value': 1}) # N: Revealed type is "builtins.dict[builtins.str, builtins.object]" +reveal_type(bar | {'key': 'a', 'value': 1}) # N: Revealed type is "builtins.dict[builtins.str, builtins.object]" + +reveal_type(bar | foo1) # N: Revealed type is "TypedDict('__main__.Bar', {'key': builtins.int, 'value': builtins.str})" 
+reveal_type(bar | d1) # N: Revealed type is "builtins.dict[builtins.str, builtins.object]" +bar | d2 # E: Unsupported operand types for | ("Bar" and "Dict[int, str]") +[builtins fixtures/dict.pyi] +[typing fixtures/typing-typeddict-iror.pyi] + +[case testTypedDictWith__or__method_error] +from mypy_extensions import TypedDict + +class Foo(TypedDict): + key: int + +foo: Foo = {'key': 1} +foo | 1 + +class SubDict(dict): ... +foo | SubDict() +[out] +main:7: error: No overload variant of "__or__" of "TypedDict" matches argument type "int" +main:7: note: Possible overload variants: +main:7: note: def __or__(self, TypedDict({'key'?: int}), /) -> Foo +main:7: note: def __or__(self, Dict[str, Any], /) -> Dict[str, object] +main:10: error: No overload variant of "__ror__" of "dict" matches argument type "Foo" +main:10: note: Possible overload variants: +main:10: note: def __ror__(self, Dict[Any, Any], /) -> Dict[Any, Any] +main:10: note: def [T, T2] __ror__(self, Dict[T, T2], /) -> Dict[Union[Any, T], Union[Any, T2]] +[builtins fixtures/dict.pyi] +[typing fixtures/typing-typeddict-iror.pyi] + +[case testTypedDictWith__ror__method] +from typing import Dict +from mypy_extensions import TypedDict + +class Foo(TypedDict): + key: int + +foo: Foo = {'key': 1} + +reveal_type({'key': 1} | foo) # N: Revealed type is "TypedDict('__main__.Foo', {'key': builtins.int})" +reveal_type({'key': 'a'} | foo) # N: Revealed type is "builtins.dict[builtins.str, builtins.object]" +reveal_type({} | foo) # N: Revealed type is "TypedDict('__main__.Foo', {'key': builtins.int})" +{1: 'a'} | foo # E: Dict entry 0 has incompatible type "int": "str"; expected "str": "Any" + +d1: Dict[str, int] +d2: Dict[int, str] + +reveal_type(d1 | foo) # N: Revealed type is "builtins.dict[builtins.str, builtins.object]" +d2 | foo # E: Unsupported operand types for | ("Dict[int, str]" and "Foo") +1 | foo # E: Unsupported left operand type for | ("int") + + +class Bar(TypedDict): + key: int + value: str + +bar: Bar 
+reveal_type({} | bar) # N: Revealed type is "TypedDict('__main__.Bar', {'key': builtins.int, 'value': builtins.str})" +reveal_type({'key': 1, 'value': 'v'} | bar) # N: Revealed type is "TypedDict('__main__.Bar', {'key': builtins.int, 'value': builtins.str})" +reveal_type({'key': 1} | bar) # N: Revealed type is "TypedDict('__main__.Bar', {'key': builtins.int, 'value': builtins.str})" +reveal_type({'value': 'v'} | bar) # N: Revealed type is "TypedDict('__main__.Bar', {'key': builtins.int, 'value': builtins.str})" +reveal_type({'key': 'a'} | bar) # N: Revealed type is "builtins.dict[builtins.str, builtins.object]" +reveal_type({'value': 1} | bar) # N: Revealed type is "builtins.dict[builtins.str, builtins.object]" +reveal_type({'key': 'a', 'value': 1} | bar) # N: Revealed type is "builtins.dict[builtins.str, builtins.object]" + +reveal_type(d1 | bar) # N: Revealed type is "builtins.dict[builtins.str, builtins.object]" +d2 | bar # E: Unsupported operand types for | ("Dict[int, str]" and "Bar") +[builtins fixtures/dict.pyi] +[typing fixtures/typing-typeddict-iror.pyi] + +[case testTypedDictWith__ior__method] +from typing import Dict +from mypy_extensions import TypedDict + +class Foo(TypedDict): + key: int + +foo: Foo = {'key': 1} +foo |= {'key': 2} + +foo |= {} +foo |= {'key': 'a', 'b': 'a'} # E: Expected TypedDict key "key" but found keys ("key", "b") \ + # E: Incompatible types (expression has type "str", TypedDict item "key" has type "int") +foo |= {'b': 2} # E: Unexpected TypedDict key "b" + +d1: Dict[str, int] +d2: Dict[int, str] + +foo |= d1 # E: Argument 1 to "__ior__" of "TypedDict" has incompatible type "Dict[str, int]"; expected "TypedDict({'key'?: int})" +foo |= d2 # E: Argument 1 to "__ior__" of "TypedDict" has incompatible type "Dict[int, str]"; expected "TypedDict({'key'?: int})" + + +class Bar(TypedDict): + key: int + value: str + +bar: Bar +bar |= {} +bar |= {'key': 1, 'value': 'a'} +bar |= {'key': 'a', 'value': 'a', 'b': 'a'} # E: Expected TypedDict 
keys ("key", "value") but found keys ("key", "value", "b") \ + # E: Incompatible types (expression has type "str", TypedDict item "key" has type "int") + +bar |= foo +bar |= d1 # E: Argument 1 to "__ior__" of "TypedDict" has incompatible type "Dict[str, int]"; expected "TypedDict({'key'?: int, 'value'?: str})" +bar |= d2 # E: Argument 1 to "__ior__" of "TypedDict" has incompatible type "Dict[int, str]"; expected "TypedDict({'key'?: int, 'value'?: str})" +[builtins fixtures/dict.pyi] +[typing fixtures/typing-typeddict-iror.pyi] diff --git a/test-data/unit/fixtures/dict.pyi b/test-data/unit/fixtures/dict.pyi index 19d175ff79ab..7c0c8767f7d7 100644 --- a/test-data/unit/fixtures/dict.pyi +++ b/test-data/unit/fixtures/dict.pyi @@ -3,10 +3,12 @@ from _typeshed import SupportsKeysAndGetItem import _typeshed from typing import ( - TypeVar, Generic, Iterable, Iterator, Mapping, Tuple, overload, Optional, Union, Sequence + TypeVar, Generic, Iterable, Iterator, Mapping, Tuple, overload, Optional, Union, Sequence, + Self, ) T = TypeVar('T') +T2 = TypeVar('T2') KT = TypeVar('KT') VT = TypeVar('VT') @@ -34,6 +36,21 @@ class dict(Mapping[KT, VT]): def get(self, k: KT, default: Union[VT, T]) -> Union[VT, T]: pass def __len__(self) -> int: ... + # This was actually added in 3.9: + @overload + def __or__(self, __value: dict[KT, VT]) -> dict[KT, VT]: ... + @overload + def __or__(self, __value: dict[T, T2]) -> dict[Union[KT, T], Union[VT, T2]]: ... + @overload + def __ror__(self, __value: dict[KT, VT]) -> dict[KT, VT]: ... + @overload + def __ror__(self, __value: dict[T, T2]) -> dict[Union[KT, T], Union[VT, T2]]: ... + # dict.__ior__ should be kept roughly in line with MutableMapping.update() + @overload # type: ignore[misc] + def __ior__(self, __value: _typeshed.SupportsKeysAndGetItem[KT, VT]) -> Self: ... + @overload + def __ior__(self, __value: Iterable[Tuple[KT, VT]]) -> Self: ... 
+ class int: # for convenience def __add__(self, x: Union[int, complex]) -> int: pass def __radd__(self, x: int) -> int: pass diff --git a/test-data/unit/fixtures/typing-async.pyi b/test-data/unit/fixtures/typing-async.pyi index b207dd599c33..9897dfd0b270 100644 --- a/test-data/unit/fixtures/typing-async.pyi +++ b/test-data/unit/fixtures/typing-async.pyi @@ -24,6 +24,7 @@ ClassVar = 0 Final = 0 Literal = 0 NoReturn = 0 +Self = 0 T = TypeVar('T') T_co = TypeVar('T_co', covariant=True) diff --git a/test-data/unit/fixtures/typing-full.pyi b/test-data/unit/fixtures/typing-full.pyi index e9f0aa199bb4..ef903ace78af 100644 --- a/test-data/unit/fixtures/typing-full.pyi +++ b/test-data/unit/fixtures/typing-full.pyi @@ -30,6 +30,7 @@ Literal = 0 TypedDict = 0 NoReturn = 0 NewType = 0 +Self = 0 T = TypeVar('T') T_co = TypeVar('T_co', covariant=True) diff --git a/test-data/unit/fixtures/typing-medium.pyi b/test-data/unit/fixtures/typing-medium.pyi index 03be1d0a664d..c19c5d5d96e2 100644 --- a/test-data/unit/fixtures/typing-medium.pyi +++ b/test-data/unit/fixtures/typing-medium.pyi @@ -28,6 +28,7 @@ NoReturn = 0 NewType = 0 TypeAlias = 0 LiteralString = 0 +Self = 0 T = TypeVar('T') T_co = TypeVar('T_co', covariant=True) diff --git a/test-data/unit/fixtures/typing-typeddict-iror.pyi b/test-data/unit/fixtures/typing-typeddict-iror.pyi new file mode 100644 index 000000000000..e452c8497109 --- /dev/null +++ b/test-data/unit/fixtures/typing-typeddict-iror.pyi @@ -0,0 +1,66 @@ +# Test stub for typing module that includes TypedDict `|` operator. +# It only covers `__or__`, `__ror__`, and `__ior__`. +# +# We cannot define these methods in `typing-typeddict.pyi`, +# because they need `dict` with two type args, +# and not all tests using `[typing typing-typeddict.pyi]` have the proper +# `dict` stub. +# +# Keep in sync with `typeshed`'s definition. 
+from abc import ABCMeta + +cast = 0 +assert_type = 0 +overload = 0 +Any = 0 +Union = 0 +Optional = 0 +TypeVar = 0 +Generic = 0 +Protocol = 0 +Tuple = 0 +Callable = 0 +NamedTuple = 0 +Final = 0 +Literal = 0 +TypedDict = 0 +NoReturn = 0 +Required = 0 +NotRequired = 0 +Self = 0 + +T = TypeVar('T') +T_co = TypeVar('T_co', covariant=True) +V = TypeVar('V') + +# Note: definitions below are different from typeshed, variances are declared +# to silence the protocol variance checks. Maybe it is better to use type: ignore? + +class Sized(Protocol): + def __len__(self) -> int: pass + +class Iterable(Protocol[T_co]): + def __iter__(self) -> 'Iterator[T_co]': pass + +class Iterator(Iterable[T_co], Protocol): + def __next__(self) -> T_co: pass + +class Sequence(Iterable[T_co]): + # misc is for explicit Any. + def __getitem__(self, n: Any) -> T_co: pass # type: ignore[misc] + +class Mapping(Iterable[T], Generic[T, T_co], metaclass=ABCMeta): + pass + +# Fallback type for all typed dicts (does not exist at runtime). +class _TypedDict(Mapping[str, object]): + @overload + def __or__(self, __value: Self) -> Self: ... + @overload + def __or__(self, __value: dict[str, Any]) -> dict[str, object]: ... + @overload + def __ror__(self, __value: Self) -> Self: ... + @overload + def __ror__(self, __value: dict[str, Any]) -> dict[str, object]: ... + # supposedly incompatible definitions of __or__ and __ior__ + def __ior__(self, __value: Self) -> Self: ... # type: ignore[misc]