Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support flexible TypedDict creation/update #15425

Merged
merged 13 commits into from
Jun 26, 2023
255 changes: 194 additions & 61 deletions mypy/checkexpr.py

Large diffs are not rendered by default.

8 changes: 8 additions & 0 deletions mypy/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -833,6 +833,14 @@ def add_invertible_flag(
group=strictness_group,
)

add_invertible_flag(
"--strict-typeddict-update",
default=False,
strict_flag=True,
help="Disallow partial overlap in TypedDict update (including ** in constructor)",
group=strictness_group,
)

strict_help = "Strict mode; enables the following flags: {}".format(
", ".join(strict_flag_names)
)
Expand Down
18 changes: 18 additions & 0 deletions mypy/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -1754,6 +1754,24 @@ def need_annotation_for_var(
def explicit_any(self, ctx: Context) -> None:
self.fail('Explicit "Any" is not allowed', ctx)

def unsupported_target_for_star_typeddict(self, typ: Type, ctx: Context) -> None:
self.fail(
"Unsupported type {} for ** expansion in TypedDict".format(
format_type(typ, self.options)
),
ctx,
code=codes.TYPEDDICT_ITEM,
)

def non_required_keys_absent_with_star(self, keys: list[str], ctx: Context) -> None:
self.fail(
"Non-required {} not explicitly found in any ** item".format(
format_key_list(keys, short=True)
),
ctx,
code=codes.TYPEDDICT_ITEM,
)

def unexpected_typeddict_keys(
self,
typ: TypedDictType,
Expand Down
4 changes: 4 additions & 0 deletions mypy/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ class BuildType:
"strict_concatenate",
"strict_equality",
"strict_optional",
"strict_typeddict_update",
"warn_no_return",
"warn_return_any",
"warn_unreachable",
Expand Down Expand Up @@ -203,6 +204,9 @@ def __init__(self) -> None:
# Make arguments prepended via Concatenate be truly positional-only.
self.strict_concatenate = False

# Disallow partial overlap in TypedDict update (including ** in constructor).
self.strict_typeddict_update = False

# Report an error for any branches inferred to be unreachable as a result of
# type analysis.
self.warn_unreachable = False
Expand Down
27 changes: 27 additions & 0 deletions mypy/plugins/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@
TypedDictType,
TypeOfAny,
TypeVarType,
UnionType,
get_proper_type,
get_proper_types,
)


Expand Down Expand Up @@ -404,6 +406,31 @@ def typed_dict_update_signature_callback(ctx: MethodSigContext) -> CallableType:
assert isinstance(arg_type, TypedDictType)
arg_type = arg_type.as_anonymous()
arg_type = arg_type.copy_modified(required_keys=set())
if ctx.args and ctx.args[0]:
with ctx.api.msg.filter_errors():
inferred = get_proper_type(
ctx.api.get_expression_type(ctx.args[0][0], type_context=arg_type)
)
possible_tds = []
if isinstance(inferred, TypedDictType):
possible_tds = [inferred]
elif isinstance(inferred, UnionType):
possible_tds = [
t
for t in get_proper_types(inferred.relevant_items())
if isinstance(t, TypedDictType)
]
items = []
for td in possible_tds:
item = arg_type.copy_modified(
required_keys=(arg_type.required_keys | td.required_keys)
& arg_type.items.keys()
)
if not ctx.api.options.strict_typeddict_update:
item = item.copy_modified(item_names=list(td.items))
items.append(item)
if items:
arg_type = make_simplified_union(items)
return signature.copy_modified(arg_types=[arg_type])
return signature

Expand Down
4 changes: 2 additions & 2 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -5084,14 +5084,14 @@ def translate_dict_call(self, call: CallExpr) -> DictExpr | None:

For other variants of dict(...), return None.
"""
if not all(kind == ARG_NAMED for kind in call.arg_kinds):
if not all(kind in (ARG_NAMED, ARG_STAR2) for kind in call.arg_kinds):
# Must still accept those args.
for a in call.args:
a.accept(self)
return None
expr = DictExpr(
[
(StrExpr(cast(str, key)), value) # since they are all ARG_NAMED
(StrExpr(key) if key is not None else None, value)
for key, value in zip(call.arg_names, call.args)
]
)
Expand Down
4 changes: 4 additions & 0 deletions mypy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2433,6 +2433,7 @@ def copy_modified(
*,
fallback: Instance | None = None,
item_types: list[Type] | None = None,
item_names: list[str] | None = None,
required_keys: set[str] | None = None,
) -> TypedDictType:
if fallback is None:
Expand All @@ -2443,6 +2444,9 @@ def copy_modified(
items = dict(zip(self.items, item_types))
if required_keys is None:
required_keys = self.required_keys
if item_names is not None:
items = {k: v for (k, v) in items.items() if k in item_names}
required_keys &= set(item_names)
return TypedDictType(items, required_keys, fallback, self.line, self.column)

def create_anonymous_fallback(self) -> Instance:
Expand Down
Loading
Loading