From 8290bb81db80b139185a3543bd459f904841fe44 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Tue, 27 Jun 2023 00:25:59 +0100 Subject: [PATCH] Support flexible TypedDict creation/update (#15425) Fixes #9408 Fixes #4122 Fixes #6462 Supersedes #13353 This PR enables two similar technically unsafe behaviors for TypedDicts, as @JukkaL explained in https://github.com/python/mypy/issues/6462#issuecomment-466464229 allowing an "incomplete" TypedDict as an argument to `.update()` is technically unsafe (and a similar argument applies to `**` syntax in TypedDict literals). These are however very common patterns (judging from number of duplicates to above issues), so I think we should support them. Here is what I propose: * Always support cases that are safe (like passing the type itself to `update`) * Allow popular but technically unsafe cases _by default_ * Have a new flag (as part of `--strict`) to fall back to current behavior Note that unfortunately we can't use just a custom new error code, since we need to conditionally tweak some types in a plugin. Btw there are couple TODOs I add here: * First is for unsafe behavior for repeated TypedDict keys. This is not new, I just noticed it when working on this * Second is for tricky corner case involving multiple `**` items where we may have false-negatives in strict mode. Note that I don't test all the possible combinations here (since the phase space is huge), but I think I am testing all main ingredients (and I will be glad to add more if needed): * All syntax variants for TypedDicts creation are handled * Various shadowing/overrides scenarios * Required vs non-required keys handling * Union types (both as item and target types) * Inference for generic TypedDicts * New strictness flag More than half of the tests I took from the original PR #13353 --- docs/source/command_line.rst | 28 ++ mypy/checkexpr.py | 255 ++++++++++--- mypy/main.py | 13 +- mypy/messages.py | 18 + mypy/options.py | 6 +- mypy/plugins/default.py | 27 ++ mypy/semanal.py | 4 +- mypy/subtypes.py | 10 +- mypy/types.py | 4 + .../unit/check-parameter-specification.test | 2 +- test-data/unit/check-typeddict.test | 361 ++++++++++++++++++ 11 files changed, 659 insertions(+), 69 deletions(-) diff --git a/docs/source/command_line.rst b/docs/source/command_line.rst index 2809294092ab..d9de5cd8f9bd 100644 --- a/docs/source/command_line.rst +++ b/docs/source/command_line.rst @@ -612,6 +612,34 @@ of the above sections. assert text is not None # OK, check against None is allowed as a special case. +.. option:: --extra-checks + + This flag enables additional checks that are technically correct but may be + impractical in real code. In particular, it prohibits partial overlap in + ``TypedDict`` updates, and makes arguments prepended via ``Concatenate`` + positional-only. For example: + + .. code-block:: python + + from typing import TypedDict + + class Foo(TypedDict): + a: int + + class Bar(TypedDict): + a: int + b: int + + def test(foo: Foo, bar: Bar) -> None: + # This is technically unsafe since foo can have a subtype of Foo at + # runtime, where type of key "b" is incompatible with int, see below + bar.update(foo) + + class Bad(Foo): + b: str + bad: Bad = {"a": 0, "b": "no"} + test(bad, bar) + .. option:: --strict This flag mode enables all optional error checking flags. You can see the diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 43896171eadc..986e58c21762 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -4,8 +4,9 @@ import itertools import time +from collections import defaultdict from contextlib import contextmanager -from typing import Callable, ClassVar, Iterator, List, Optional, Sequence, cast +from typing import Callable, ClassVar, Iterable, Iterator, List, Optional, Sequence, cast from typing_extensions import Final, TypeAlias as _TypeAlias, overload import mypy.checker @@ -695,74 +696,183 @@ def check_typeddict_call( context: Context, orig_callee: Type | None, ) -> Type: - if args and all([ak == ARG_NAMED for ak in arg_kinds]): - # ex: Point(x=42, y=1337) - assert all(arg_name is not None for arg_name in arg_names) - item_names = cast(List[str], arg_names) - item_args = args - return self.check_typeddict_call_with_kwargs( - callee, dict(zip(item_names, item_args)), context, orig_callee - ) + if args and all([ak in (ARG_NAMED, ARG_STAR2) for ak in arg_kinds]): + # ex: Point(x=42, y=1337, **extras) + # This is a bit ugly, but this is a price for supporting all possible syntax + # variants for TypedDict constructors. + kwargs = zip([StrExpr(n) if n is not None else None for n in arg_names], args) + result = self.validate_typeddict_kwargs(kwargs=kwargs, callee=callee) + if result is not None: + validated_kwargs, always_present_keys = result + return self.check_typeddict_call_with_kwargs( + callee, validated_kwargs, context, orig_callee, always_present_keys + ) + return AnyType(TypeOfAny.from_error) if len(args) == 1 and arg_kinds[0] == ARG_POS: unique_arg = args[0] if isinstance(unique_arg, DictExpr): - # ex: Point({'x': 42, 'y': 1337}) + # ex: Point({'x': 42, 'y': 1337, **extras}) return self.check_typeddict_call_with_dict( - callee, unique_arg, context, orig_callee + callee, unique_arg.items, context, orig_callee ) if isinstance(unique_arg, CallExpr) and isinstance(unique_arg.analyzed, DictExpr): - # ex: Point(dict(x=42, y=1337)) + # ex: Point(dict(x=42, y=1337, **extras)) return self.check_typeddict_call_with_dict( - callee, unique_arg.analyzed, context, orig_callee + callee, unique_arg.analyzed.items, context, orig_callee ) if not args: # ex: EmptyDict() - return self.check_typeddict_call_with_kwargs(callee, {}, context, orig_callee) + return self.check_typeddict_call_with_kwargs(callee, {}, context, orig_callee, set()) self.chk.fail(message_registry.INVALID_TYPEDDICT_ARGS, context) return AnyType(TypeOfAny.from_error) - def validate_typeddict_kwargs(self, kwargs: DictExpr) -> dict[str, Expression] | None: - item_args = [item[1] for item in kwargs.items] - - item_names = [] # List[str] - for item_name_expr, item_arg in kwargs.items: - literal_value = None + def validate_typeddict_kwargs( + self, kwargs: Iterable[tuple[Expression | None, Expression]], callee: TypedDictType + ) -> tuple[dict[str, list[Expression]], set[str]] | None: + # All (actual or mapped from ** unpacks) expressions that can match given key. + result = defaultdict(list) + # Keys that are guaranteed to be present no matter what (e.g. for all items of a union) + always_present_keys = set() + # Indicates latest encountered ** unpack among items. + last_star_found = None + + for item_name_expr, item_arg in kwargs: if item_name_expr: key_type = self.accept(item_name_expr) values = try_getting_str_literals(item_name_expr, key_type) + literal_value = None if values and len(values) == 1: literal_value = values[0] - if literal_value is None: - key_context = item_name_expr or item_arg - self.chk.fail( - message_registry.TYPEDDICT_KEY_MUST_BE_STRING_LITERAL, - key_context, - code=codes.LITERAL_REQ, - ) - return None + if literal_value is None: + key_context = item_name_expr or item_arg + self.chk.fail( + message_registry.TYPEDDICT_KEY_MUST_BE_STRING_LITERAL, + key_context, + code=codes.LITERAL_REQ, + ) + return None + else: + # A directly present key unconditionally shadows all previously found + # values from ** items. + # TODO: for duplicate keys, type-check all values. + result[literal_value] = [item_arg] + always_present_keys.add(literal_value) else: - item_names.append(literal_value) - return dict(zip(item_names, item_args)) + last_star_found = item_arg + if not self.validate_star_typeddict_item( + item_arg, callee, result, always_present_keys + ): + return None + if self.chk.options.extra_checks and last_star_found is not None: + absent_keys = [] + for key in callee.items: + if key not in callee.required_keys and key not in result: + absent_keys.append(key) + if absent_keys: + # Having an optional key not explicitly declared by a ** unpacked + # TypedDict is unsafe, it may be an (incompatible) subtype at runtime. + # TODO: catch the cases where a declared key is overridden by a subsequent + # ** item without it (and not again overriden with complete ** item). + self.msg.non_required_keys_absent_with_star(absent_keys, last_star_found) + return result, always_present_keys + + def validate_star_typeddict_item( + self, + item_arg: Expression, + callee: TypedDictType, + result: dict[str, list[Expression]], + always_present_keys: set[str], + ) -> bool: + """Update keys/expressions from a ** expression in TypedDict constructor. + + Note `result` and `always_present_keys` are updated in place. Return true if the + expression `item_arg` may valid in `callee` TypedDict context. + """ + with self.chk.local_type_map(), self.msg.filter_errors(): + inferred = get_proper_type(self.accept(item_arg, type_context=callee)) + possible_tds = [] + if isinstance(inferred, TypedDictType): + possible_tds = [inferred] + elif isinstance(inferred, UnionType): + for item in get_proper_types(inferred.relevant_items()): + if isinstance(item, TypedDictType): + possible_tds.append(item) + elif not self.valid_unpack_fallback_item(item): + self.msg.unsupported_target_for_star_typeddict(item, item_arg) + return False + elif not self.valid_unpack_fallback_item(inferred): + self.msg.unsupported_target_for_star_typeddict(inferred, item_arg) + return False + all_keys: set[str] = set() + for td in possible_tds: + all_keys |= td.items.keys() + for key in all_keys: + arg = TempNode( + UnionType.make_union([td.items[key] for td in possible_tds if key in td.items]) + ) + arg.set_line(item_arg) + if all(key in td.required_keys for td in possible_tds): + always_present_keys.add(key) + # Always present keys override previously found values. This is done + # to support use cases like `Config({**defaults, **overrides})`, where + # some `overrides` types are narrower that types in `defaults`, and + # former are too wide for `Config`. + if result[key]: + first = result[key][0] + if not isinstance(first, TempNode): + # We must always preserve any non-synthetic values, so that + # we will accept them even if they are shadowed. + result[key] = [first, arg] + else: + result[key] = [arg] + else: + result[key] = [arg] + else: + # If this key is not required at least in some item of a union + # it may not shadow previous item, so we need to type check both. + result[key].append(arg) + return True + + def valid_unpack_fallback_item(self, typ: ProperType) -> bool: + if isinstance(typ, AnyType): + return True + if not isinstance(typ, Instance) or not typ.type.has_base("typing.Mapping"): + return False + mapped = map_instance_to_supertype(typ, self.chk.lookup_typeinfo("typing.Mapping")) + return all(isinstance(a, AnyType) for a in get_proper_types(mapped.args)) def match_typeddict_call_with_dict( - self, callee: TypedDictType, kwargs: DictExpr, context: Context + self, + callee: TypedDictType, + kwargs: list[tuple[Expression | None, Expression]], + context: Context, ) -> bool: - validated_kwargs = self.validate_typeddict_kwargs(kwargs=kwargs) - if validated_kwargs is not None: + result = self.validate_typeddict_kwargs(kwargs=kwargs, callee=callee) + if result is not None: + validated_kwargs, _ = result return callee.required_keys <= set(validated_kwargs.keys()) <= set(callee.items.keys()) else: return False def check_typeddict_call_with_dict( - self, callee: TypedDictType, kwargs: DictExpr, context: Context, orig_callee: Type | None + self, + callee: TypedDictType, + kwargs: list[tuple[Expression | None, Expression]], + context: Context, + orig_callee: Type | None, ) -> Type: - validated_kwargs = self.validate_typeddict_kwargs(kwargs=kwargs) - if validated_kwargs is not None: + result = self.validate_typeddict_kwargs(kwargs=kwargs, callee=callee) + if result is not None: + validated_kwargs, always_present_keys = result return self.check_typeddict_call_with_kwargs( - callee, kwargs=validated_kwargs, context=context, orig_callee=orig_callee + callee, + kwargs=validated_kwargs, + context=context, + orig_callee=orig_callee, + always_present_keys=always_present_keys, ) else: return AnyType(TypeOfAny.from_error) @@ -803,20 +913,37 @@ def typeddict_callable_from_context(self, callee: TypedDictType) -> CallableType def check_typeddict_call_with_kwargs( self, callee: TypedDictType, - kwargs: dict[str, Expression], + kwargs: dict[str, list[Expression]], context: Context, orig_callee: Type | None, + always_present_keys: set[str], ) -> Type: actual_keys = kwargs.keys() - if not (callee.required_keys <= actual_keys <= callee.items.keys()): - expected_keys = [ - key - for key in callee.items.keys() - if key in callee.required_keys or key in actual_keys - ] - self.msg.unexpected_typeddict_keys( - callee, expected_keys=expected_keys, actual_keys=list(actual_keys), context=context - ) + if not ( + callee.required_keys <= always_present_keys and actual_keys <= callee.items.keys() + ): + if not (actual_keys <= callee.items.keys()): + self.msg.unexpected_typeddict_keys( + callee, + expected_keys=[ + key + for key in callee.items.keys() + if key in callee.required_keys or key in actual_keys + ], + actual_keys=list(actual_keys), + context=context, + ) + if not (callee.required_keys <= always_present_keys): + self.msg.unexpected_typeddict_keys( + callee, + expected_keys=[ + key for key in callee.items.keys() if key in callee.required_keys + ], + actual_keys=[ + key for key in always_present_keys if key in callee.required_keys + ], + context=context, + ) if callee.required_keys > actual_keys: # found_set is a sub-set of the required_keys # This means we're missing some keys and as such, we can't @@ -839,7 +966,10 @@ def check_typeddict_call_with_kwargs( with self.msg.filter_errors(), self.chk.local_type_map(): orig_ret_type, _ = self.check_callable_call( infer_callee, - list(kwargs.values()), + # We use first expression for each key to infer type variables of a generic + # TypedDict. This is a bit arbitrary, but in most cases will work better than + # trying to infer a union or a join. + [args[0] for args in kwargs.values()], [ArgKind.ARG_NAMED] * len(kwargs), context, list(kwargs.keys()), @@ -856,17 +986,18 @@ def check_typeddict_call_with_kwargs( for item_name, item_expected_type in ret_type.items.items(): if item_name in kwargs: - item_value = kwargs[item_name] - self.chk.check_simple_assignment( - lvalue_type=item_expected_type, - rvalue=item_value, - context=item_value, - msg=ErrorMessage( - message_registry.INCOMPATIBLE_TYPES.value, code=codes.TYPEDDICT_ITEM - ), - lvalue_name=f'TypedDict item "{item_name}"', - rvalue_name="expression", - ) + item_values = kwargs[item_name] + for item_value in item_values: + self.chk.check_simple_assignment( + lvalue_type=item_expected_type, + rvalue=item_value, + context=item_value, + msg=ErrorMessage( + message_registry.INCOMPATIBLE_TYPES.value, code=codes.TYPEDDICT_ITEM + ), + lvalue_name=f'TypedDict item "{item_name}"', + rvalue_name="expression", + ) return orig_ret_type @@ -4382,7 +4513,7 @@ def check_typeddict_literal_in_context( self, e: DictExpr, typeddict_context: TypedDictType ) -> Type: orig_ret_type = self.check_typeddict_call_with_dict( - callee=typeddict_context, kwargs=e, context=e, orig_callee=None + callee=typeddict_context, kwargs=e.items, context=e, orig_callee=None ) ret_type = get_proper_type(orig_ret_type) if isinstance(ret_type, TypedDictType): @@ -4482,7 +4613,9 @@ def find_typeddict_context( for item in context.items: item_contexts = self.find_typeddict_context(item, dict_expr) for item_context in item_contexts: - if self.match_typeddict_call_with_dict(item_context, dict_expr, dict_expr): + if self.match_typeddict_call_with_dict( + item_context, dict_expr.items, dict_expr + ): items.append(item_context) return items # No TypedDict type in context. diff --git a/mypy/main.py b/mypy/main.py index b60c5b2a6bba..22ff3e32a718 100644 --- a/mypy/main.py +++ b/mypy/main.py @@ -826,10 +826,12 @@ def add_invertible_flag( ) add_invertible_flag( - "--strict-concatenate", + "--extra-checks", default=False, strict_flag=True, - help="Make arguments prepended via Concatenate be truly positional-only", + help="Enable additional checks that are technically correct but may be impractical " + "in real code. For example, this prohibits partial overlap in TypedDict updates, " + "and makes arguments prepended via Concatenate positional-only", group=strictness_group, ) @@ -1155,6 +1157,8 @@ def add_invertible_flag( parser.add_argument( "--disable-memoryview-promotion", action="store_true", help=argparse.SUPPRESS ) + # This flag is deprecated, it has been moved to --extra-checks + parser.add_argument("--strict-concatenate", action="store_true", help=argparse.SUPPRESS) # options specifying code to check code_group = parser.add_argument_group( @@ -1226,8 +1230,11 @@ def add_invertible_flag( parser.error(f"Cannot find config file '{config_file}'") options = Options() + strict_option_set = False def set_strict_flags() -> None: + nonlocal strict_option_set + strict_option_set = True for dest, value in strict_flag_assignments: setattr(options, dest, value) @@ -1379,6 +1386,8 @@ def set_strict_flags() -> None: "Warning: --enable-recursive-aliases is deprecated;" " recursive types are enabled by default" ) + if options.strict_concatenate and not strict_option_set: + print("Warning: --strict-concatenate is deprecated; use --extra-checks instead") # Set target. if special_opts.modules + special_opts.packages: diff --git a/mypy/messages.py b/mypy/messages.py index b74a795a4318..ea7923c59778 100644 --- a/mypy/messages.py +++ b/mypy/messages.py @@ -1757,6 +1757,24 @@ def need_annotation_for_var( def explicit_any(self, ctx: Context) -> None: self.fail('Explicit "Any" is not allowed', ctx) + def unsupported_target_for_star_typeddict(self, typ: Type, ctx: Context) -> None: + self.fail( + "Unsupported type {} for ** expansion in TypedDict".format( + format_type(typ, self.options) + ), + ctx, + code=codes.TYPEDDICT_ITEM, + ) + + def non_required_keys_absent_with_star(self, keys: list[str], ctx: Context) -> None: + self.fail( + "Non-required {} not explicitly found in any ** item".format( + format_key_list(keys, short=True) + ), + ctx, + code=codes.TYPEDDICT_ITEM, + ) + def unexpected_typeddict_keys( self, typ: TypedDictType, diff --git a/mypy/options.py b/mypy/options.py index f75734124eb0..e1d731c1124c 100644 --- a/mypy/options.py +++ b/mypy/options.py @@ -40,6 +40,7 @@ class BuildType: "disallow_untyped_defs", "enable_error_code", "enabled_error_codes", + "extra_checks", "follow_imports_for_stubs", "follow_imports", "ignore_errors", @@ -200,9 +201,12 @@ def __init__(self) -> None: # This makes 1 == '1', 1 in ['1'], and 1 is '1' errors. self.strict_equality = False - # Make arguments prepended via Concatenate be truly positional-only. + # Deprecated, use extra_checks instead. self.strict_concatenate = False + # Enable additional checks that are technically correct but impractical. + self.extra_checks = False + # Report an error for any branches inferred to be unreachable as a result of # type analysis. self.warn_unreachable = False diff --git a/mypy/plugins/default.py b/mypy/plugins/default.py index b83c0192a14b..f5dea0621177 100644 --- a/mypy/plugins/default.py +++ b/mypy/plugins/default.py @@ -31,7 +31,9 @@ TypedDictType, TypeOfAny, TypeVarType, + UnionType, get_proper_type, + get_proper_types, ) @@ -404,6 +406,31 @@ def typed_dict_update_signature_callback(ctx: MethodSigContext) -> CallableType: assert isinstance(arg_type, TypedDictType) arg_type = arg_type.as_anonymous() arg_type = arg_type.copy_modified(required_keys=set()) + if ctx.args and ctx.args[0]: + with ctx.api.msg.filter_errors(): + inferred = get_proper_type( + ctx.api.get_expression_type(ctx.args[0][0], type_context=arg_type) + ) + possible_tds = [] + if isinstance(inferred, TypedDictType): + possible_tds = [inferred] + elif isinstance(inferred, UnionType): + possible_tds = [ + t + for t in get_proper_types(inferred.relevant_items()) + if isinstance(t, TypedDictType) + ] + items = [] + for td in possible_tds: + item = arg_type.copy_modified( + required_keys=(arg_type.required_keys | td.required_keys) + & arg_type.items.keys() + ) + if not ctx.api.options.extra_checks: + item = item.copy_modified(item_names=list(td.items)) + items.append(item) + if items: + arg_type = make_simplified_union(items) return signature.copy_modified(arg_types=[arg_type]) return signature diff --git a/mypy/semanal.py b/mypy/semanal.py index 43960d972101..d18cc4298fed 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -5084,14 +5084,14 @@ def translate_dict_call(self, call: CallExpr) -> DictExpr | None: For other variants of dict(...), return None. """ - if not all(kind == ARG_NAMED for kind in call.arg_kinds): + if not all(kind in (ARG_NAMED, ARG_STAR2) for kind in call.arg_kinds): # Must still accept those args. for a in call.args: a.accept(self) return None expr = DictExpr( [ - (StrExpr(cast(str, key)), value) # since they are all ARG_NAMED + (StrExpr(key) if key is not None else None, value) for key, value in zip(call.arg_names, call.args) ] ) diff --git a/mypy/subtypes.py b/mypy/subtypes.py index a3b28a3e24de..c9de56edfa36 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -694,7 +694,9 @@ def visit_callable_type(self, left: CallableType) -> bool: right, is_compat=self._is_subtype, ignore_pos_arg_names=self.subtype_context.ignore_pos_arg_names, - strict_concatenate=self.options.strict_concatenate if self.options else True, + strict_concatenate=(self.options.extra_checks or self.options.strict_concatenate) + if self.options + else True, ) elif isinstance(right, Overloaded): return all(self._is_subtype(left, item) for item in right.items) @@ -858,7 +860,11 @@ def visit_overloaded(self, left: Overloaded) -> bool: else: # If this one overlaps with the supertype in any way, but it wasn't # an exact match, then it's a potential error. - strict_concat = self.options.strict_concatenate if self.options else True + strict_concat = ( + (self.options.extra_checks or self.options.strict_concatenate) + if self.options + else True + ) if left_index not in matched_overloads and ( is_callable_compatible( left_item, diff --git a/mypy/types.py b/mypy/types.py index 33673b58f775..131383790ec8 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -2437,6 +2437,7 @@ def copy_modified( *, fallback: Instance | None = None, item_types: list[Type] | None = None, + item_names: list[str] | None = None, required_keys: set[str] | None = None, ) -> TypedDictType: if fallback is None: @@ -2447,6 +2448,9 @@ def copy_modified( items = dict(zip(self.items, item_types)) if required_keys is None: required_keys = self.required_keys + if item_names is not None: + items = {k: v for (k, v) in items.items() if k in item_names} + required_keys &= set(item_names) return TypedDictType(items, required_keys, fallback, self.line, self.column) def create_anonymous_fallback(self) -> Instance: diff --git a/test-data/unit/check-parameter-specification.test b/test-data/unit/check-parameter-specification.test index cafcaca0a14c..bebbbf4b1676 100644 --- a/test-data/unit/check-parameter-specification.test +++ b/test-data/unit/check-parameter-specification.test @@ -570,7 +570,7 @@ reveal_type(f(n)) # N: Revealed type is "def (builtins.int, builtins.bytes) -> [builtins fixtures/paramspec.pyi] [case testParamSpecConcatenateNamedArgs] -# flags: --python-version 3.8 --strict-concatenate +# flags: --python-version 3.8 --extra-checks # this is one noticeable deviation from PEP but I believe it is for the better from typing_extensions import ParamSpec, Concatenate from typing import Callable, TypeVar diff --git a/test-data/unit/check-typeddict.test b/test-data/unit/check-typeddict.test index fc487d2d553d..4d2d64848515 100644 --- a/test-data/unit/check-typeddict.test +++ b/test-data/unit/check-typeddict.test @@ -2885,3 +2885,364 @@ d: A d[''] # E: TypedDict "A" has no key "" [builtins fixtures/dict.pyi] [typing fixtures/typing-typeddict.pyi] + +[case testTypedDictFlexibleUpdate] +from mypy_extensions import TypedDict + +A = TypedDict("A", {"foo": int, "bar": int}) +B = TypedDict("B", {"foo": int}) + +a = A({"foo": 1, "bar": 2}) +b = B({"foo": 2}) +a.update({"foo": 2}) +a.update(b) +a.update(a) +[builtins fixtures/dict.pyi] +[typing fixtures/typing-typeddict.pyi] + +[case testTypedDictStrictUpdate] +# flags: --extra-checks +from mypy_extensions import TypedDict + +A = TypedDict("A", {"foo": int, "bar": int}) +B = TypedDict("B", {"foo": int}) + +a = A({"foo": 1, "bar": 2}) +b = B({"foo": 2}) +a.update({"foo": 2}) # OK +a.update(b) # E: Argument 1 to "update" of "TypedDict" has incompatible type "B"; expected "TypedDict({'foo': int, 'bar'?: int})" +a.update(a) # OK +[builtins fixtures/dict.pyi] +[typing fixtures/typing-typeddict.pyi] + +[case testTypedDictFlexibleUpdateUnion] +from typing import Union +from mypy_extensions import TypedDict + +A = TypedDict("A", {"foo": int, "bar": int}) +B = TypedDict("B", {"foo": int}) +C = TypedDict("C", {"bar": int}) + +a = A({"foo": 1, "bar": 2}) +u: Union[B, C] +a.update(u) +[builtins fixtures/dict.pyi] +[typing fixtures/typing-typeddict.pyi] + +[case testTypedDictFlexibleUpdateUnionExtra] +from typing import Union +from mypy_extensions import TypedDict + +A = TypedDict("A", {"foo": int, "bar": int}) +B = TypedDict("B", {"foo": int, "extra": int}) +C = TypedDict("C", {"bar": int, "extra": int}) + +a = A({"foo": 1, "bar": 2}) +u: Union[B, C] +a.update(u) +[builtins fixtures/dict.pyi] +[typing fixtures/typing-typeddict.pyi] + +[case testTypedDictFlexibleUpdateUnionStrict] +# flags: --extra-checks +from typing import Union, NotRequired +from mypy_extensions import TypedDict + +A = TypedDict("A", {"foo": int, "bar": int}) +A1 = TypedDict("A1", {"foo": int, "bar": NotRequired[int]}) +A2 = TypedDict("A2", {"foo": NotRequired[int], "bar": int}) +B = TypedDict("B", {"foo": int}) +C = TypedDict("C", {"bar": int}) + +a = A({"foo": 1, "bar": 2}) +u: Union[B, C] +a.update(u) # E: Argument 1 to "update" of "TypedDict" has incompatible type "Union[B, C]"; expected "Union[TypedDict({'foo': int, 'bar'?: int}), TypedDict({'foo'?: int, 'bar': int})]" +u2: Union[A1, A2] +a.update(u2) # OK +[builtins fixtures/dict.pyi] +[typing fixtures/typing-typeddict.pyi] + +[case testTypedDictUnpackSame] +# flags: --extra-checks +from typing import TypedDict + +class Foo(TypedDict): + a: int + b: int + +foo1: Foo = {"a": 1, "b": 1} +foo2: Foo = {**foo1, "b": 2} +foo3 = Foo(**foo1, b=2) +foo4 = Foo({**foo1, "b": 2}) +foo5 = Foo(dict(**foo1, b=2)) +[builtins fixtures/dict.pyi] +[typing fixtures/typing-typeddict.pyi] + +[case testTypedDictUnpackCompatible] +# flags: --extra-checks +from typing import TypedDict + +class Foo(TypedDict): + a: int + +class Bar(TypedDict): + a: int + b: int + +foo: Foo = {"a": 1} +bar: Bar = {**foo, "b": 2} +[builtins fixtures/dict.pyi] +[typing fixtures/typing-typeddict.pyi] + +[case testTypedDictUnpackIncompatible] +from typing import TypedDict + +class Foo(TypedDict): + a: int + b: str + +class Bar(TypedDict): + a: int + b: int + +foo: Foo = {"a": 1, "b": "a"} +bar1: Bar = {**foo, "b": 2} # Incompatible item is overriden +bar2: Bar = {**foo, "a": 2} # E: Incompatible types (expression has type "str", TypedDict item "b" has type "int") +[builtins fixtures/dict.pyi] +[typing fixtures/typing-typeddict.pyi] + +[case testTypedDictUnpackNotRequiredKeyIncompatible] +from typing import TypedDict, NotRequired + +class Foo(TypedDict): + a: NotRequired[str] + +class Bar(TypedDict): + a: NotRequired[int] + +foo: Foo = {} +bar: Bar = {**foo} # E: Incompatible types (expression has type "str", TypedDict item "a" has type "int") +[builtins fixtures/dict.pyi] +[typing fixtures/typing-typeddict.pyi] + + +[case testTypedDictUnpackMissingOrExtraKey] +from typing import TypedDict + +class Foo(TypedDict): + a: int + +class Bar(TypedDict): + a: int + b: int + +foo1: Foo = {"a": 1} +bar1: Bar = {"a": 1, "b": 1} +foo2: Foo = {**bar1} # E: Extra key "b" for TypedDict "Foo" +bar2: Bar = {**foo1} # E: Missing key "b" for TypedDict "Bar" +[builtins fixtures/dict.pyi] +[typing fixtures/typing-typeddict.pyi] + +[case testTypedDictUnpackNotRequiredKeyExtra] +from typing import TypedDict, NotRequired + +class Foo(TypedDict): + a: int + +class Bar(TypedDict): + a: int + b: NotRequired[int] + +foo1: Foo = {"a": 1} +bar1: Bar = {"a": 1} +foo2: Foo = {**bar1} # E: Extra key "b" for TypedDict "Foo" +bar2: Bar = {**foo1} +[builtins fixtures/dict.pyi] +[typing fixtures/typing-typeddict.pyi] + +[case testTypedDictUnpackRequiredKeyMissing] +from typing import TypedDict, NotRequired + +class Foo(TypedDict): + a: NotRequired[int] + +class Bar(TypedDict): + a: int + +foo: Foo = {"a": 1} +bar: Bar = {**foo} # E: Missing key "a" for TypedDict "Bar" +[builtins fixtures/dict.pyi] +[typing fixtures/typing-typeddict.pyi] + +[case testTypedDictUnpackMultiple] +# flags: --extra-checks +from typing import TypedDict + +class Foo(TypedDict): + a: int + +class Bar(TypedDict): + b: int + +class Baz(TypedDict): + a: int + b: int + c: int + +foo: Foo = {"a": 1} +bar: Bar = {"b": 1} +baz: Baz = {**foo, **bar, "c": 1} +[builtins fixtures/dict.pyi] +[typing fixtures/typing-typeddict.pyi] + +[case testTypedDictUnpackNested] +from typing import TypedDict + +class Foo(TypedDict): + a: int + b: int + +class Bar(TypedDict): + c: Foo + d: int + +foo: Foo = {"a": 1, "b": 1} +bar: Bar = {"c": foo, "d": 1} +bar2: Bar = {**bar, "c": {**bar["c"], "b": 2}, "d": 2} +[builtins fixtures/dict.pyi] +[typing fixtures/typing-typeddict.pyi] + +[case testTypedDictUnpackNestedError] +from typing import TypedDict + +class Foo(TypedDict): + a: int + b: int + +class Bar(TypedDict): + c: Foo + d: int + +foo: Foo = {"a": 1, "b": 1} +bar: Bar = {"c": foo, "d": 1} +bar2: Bar = {**bar, "c": {**bar["c"], "b": "wrong"}, "d": 2} # E: Incompatible types (expression has type "str", TypedDict item "b" has type "int") +[builtins fixtures/dict.pyi] +[typing fixtures/typing-typeddict.pyi] + +[case testTypedDictUnpackOverrideRequired] +from mypy_extensions import TypedDict + +Details = TypedDict('Details', {'first_name': str, 'last_name': str}) +DetailsSubset = TypedDict('DetailsSubset', {'first_name': str, 'last_name': str}, total=False) +defaults: Details = {'first_name': 'John', 'last_name': 'Luther'} + +def generate(data: DetailsSubset) -> Details: + return {**defaults, **data} # OK +[builtins fixtures/dict.pyi] +[typing fixtures/typing-typeddict.pyi] + +[case testTypedDictUnpackUntypedDict] +from typing import Any, Dict, TypedDict + +class Bar(TypedDict): + pass + +foo: Dict[str, Any] = {} +bar: Bar = {**foo} # E: Unsupported type "Dict[str, Any]" for ** expansion in TypedDict +[builtins fixtures/dict.pyi] +[typing fixtures/typing-typeddict.pyi] + +[case testTypedDictUnpackIntoUnion] +from typing import TypedDict, Union + +class Foo(TypedDict): + a: int + +class Bar(TypedDict): + b: int + +foo: Foo = {'a': 1} +foo_or_bar: Union[Foo, Bar] = {**foo} +[builtins fixtures/dict.pyi] +[typing fixtures/typing-typeddict.pyi] + +[case testTypedDictUnpackFromUnion] +from typing import TypedDict, Union + +class Foo(TypedDict): + a: int + b: int + +class Bar(TypedDict): + b: int + +foo_or_bar: Union[Foo, Bar] = {'b': 1} +foo: Bar = {**foo_or_bar} # E: Extra key "a" for TypedDict "Bar" +[builtins fixtures/dict.pyi] +[typing fixtures/typing-typeddict.pyi] + +[case testTypedDictUnpackUnionRequiredMissing] +from typing import TypedDict, NotRequired, Union + +class Foo(TypedDict): + a: int + b: int + +class Bar(TypedDict): + a: int + b: NotRequired[int] + +foo_or_bar: Union[Foo, Bar] = {"a": 1} +foo: Foo = {**foo_or_bar} # E: Missing key "b" for TypedDict "Foo" +[builtins fixtures/dict.pyi] +[typing fixtures/typing-typeddict.pyi] + +[case testTypedDictUnpackInference] +from typing import TypedDict, Generic, TypeVar + +class Foo(TypedDict): + a: int + b: str + +T = TypeVar("T") +class TD(TypedDict, Generic[T]): + a: T + b: str + +foo: Foo +bar = TD(**foo) +reveal_type(bar) # N: Revealed type is "TypedDict('__main__.TD', {'a': builtins.int, 'b': builtins.str})" +[builtins fixtures/dict.pyi] +[typing fixtures/typing-typeddict.pyi] + +[case testTypedDictUnpackStrictMode] +# flags: --extra-checks +from typing import TypedDict, NotRequired + +class Foo(TypedDict): + a: int + +class Bar(TypedDict): + a: int + b: NotRequired[int] + +foo: Foo +bar: Bar = {**foo} # E: Non-required key "b" not explicitly found in any ** item +[builtins fixtures/dict.pyi] +[typing fixtures/typing-typeddict.pyi] + +[case testTypedDictUnpackAny] +from typing import Any, TypedDict, NotRequired, Dict, Union + +class Foo(TypedDict): + a: int + b: NotRequired[int] + +x: Any +y: Dict[Any, Any] +z: Union[Any, Dict[Any, Any]] +t1: Foo = {**x} # E: Missing key "a" for TypedDict "Foo" +t2: Foo = {**y} # E: Missing key "a" for TypedDict "Foo" +t3: Foo = {**z} # E: Missing key "a" for TypedDict "Foo" +[builtins fixtures/dict.pyi] +[typing fixtures/typing-typeddict.pyi]