Skip to content

Commit

Permalink
Add ReadOnly support for TypedDicts
Browse files Browse the repository at this point in the history
  • Loading branch information
sobolevn committed Aug 5, 2024
1 parent b56f357 commit 9a43f2e
Show file tree
Hide file tree
Showing 26 changed files with 473 additions and 82 deletions.
13 changes: 8 additions & 5 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -986,6 +986,9 @@ def check_typeddict_call_with_kwargs(
always_present_keys: set[str],
) -> Type:
actual_keys = kwargs.keys()
assigned_readonly_keys = actual_keys & callee.readonly_keys
if assigned_readonly_keys:
self.msg.readonly_keys_mutated(assigned_readonly_keys, context=context)
if not (
callee.required_keys <= always_present_keys and actual_keys <= callee.items.keys()
):
Expand Down Expand Up @@ -4337,7 +4340,7 @@ def visit_index_with_type(
else:
return self.nonliteral_tuple_index_helper(left_type, index)
elif isinstance(left_type, TypedDictType):
return self.visit_typeddict_index_expr(left_type, e.index)
return self.visit_typeddict_index_expr(left_type, e.index)[0]
elif isinstance(left_type, FunctionLike) and left_type.is_type_obj():
if left_type.type_object().is_enum:
return self.visit_enum_index_expr(left_type.type_object(), e.index, e)
Expand Down Expand Up @@ -4518,7 +4521,7 @@ def union_tuple_fallback_item(self, left_type: TupleType) -> Type:

def visit_typeddict_index_expr(
self, td_type: TypedDictType, index: Expression, setitem: bool = False
) -> Type:
) -> tuple[Type, set[str]]:
if isinstance(index, StrExpr):
key_names = [index.value]
else:
Expand All @@ -4541,17 +4544,17 @@ def visit_typeddict_index_expr(
key_names.append(key_type.value)
else:
self.msg.typeddict_key_must_be_string_literal(td_type, index)
return AnyType(TypeOfAny.from_error)
return AnyType(TypeOfAny.from_error), set()

value_types = []
for key_name in key_names:
value_type = td_type.items.get(key_name)
if value_type is None:
self.msg.typeddict_key_not_found(td_type, key_name, index, setitem)
return AnyType(TypeOfAny.from_error)
return AnyType(TypeOfAny.from_error), set()
else:
value_types.append(value_type)
return make_simplified_union(value_types)
return make_simplified_union(value_types), set(key_names)

def visit_enum_index_expr(
self, enum_type: TypeInfo, index: Expression, context: Context
Expand Down
6 changes: 5 additions & 1 deletion mypy/checkmember.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
freshen_all_functions_type_vars,
)
from mypy.maptype import map_instance_to_supertype
from mypy import errorcodes as codes
from mypy.messages import MessageBuilder
from mypy.nodes import (
ARG_POS,
Expand Down Expand Up @@ -1185,9 +1186,12 @@ def analyze_typeddict_access(
if isinstance(mx.context, IndexExpr):
# Since we can get this during `a['key'] = ...`
# it is safe to assume that the context is `IndexExpr`.
item_type = mx.chk.expr_checker.visit_typeddict_index_expr(
item_type, key_names = mx.chk.expr_checker.visit_typeddict_index_expr(
typ, mx.context.index, setitem=True
)
assigned_readonly_keys = typ.readonly_keys & key_names
if assigned_readonly_keys:
mx.msg.readonly_keys_mutated(assigned_readonly_keys, context=mx.context)
else:
# It can also be `a.__setitem__(...)` direct call.
# In this case `item_type` can be `Any`,
Expand Down
2 changes: 1 addition & 1 deletion mypy/checkpattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,7 @@ def get_mapping_item_type(
with self.msg.filter_errors() as local_errors:
result: Type | None = self.chk.expr_checker.visit_typeddict_index_expr(
mapping_type, key
)
)[0]
has_local_errors = local_errors.has_new_errors()
# If we can't determine the type statically fall back to treating it as a normal
# mapping
Expand Down
4 changes: 3 additions & 1 deletion mypy/copytype.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,9 @@ def visit_tuple_type(self, t: TupleType) -> ProperType:
return self.copy_common(t, TupleType(t.items, t.partial_fallback, implicit=t.implicit))

def visit_typeddict_type(self, t: TypedDictType) -> ProperType:
return self.copy_common(t, TypedDictType(t.items, t.required_keys, t.fallback))
return self.copy_common(
t, TypedDictType(t.items, t.required_keys, t.readonly_keys, t.fallback)
)

def visit_literal_type(self, t: LiteralType) -> ProperType:
return self.copy_common(t, LiteralType(value=t.value, fallback=t.fallback))
Expand Down
3 changes: 3 additions & 0 deletions mypy/errorcodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,9 @@ def __hash__(self) -> int:
ANNOTATION_UNCHECKED = ErrorCode(
"annotation-unchecked", "Notify about type annotations in unchecked functions", "General"
)
TYPEDDICT_READONLY_MUTATED = ErrorCode(
"typeddict-readonly-mutated", "TypedDict's ReadOnly key is mutated", "General"
)
POSSIBLY_UNDEFINED: Final[ErrorCode] = ErrorCode(
"possibly-undefined",
"Warn about variables that are defined only in some execution paths",
Expand Down
5 changes: 0 additions & 5 deletions mypy/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -680,11 +680,6 @@ def clear_errors_in_targets(self, path: str, targets: set[str]) -> None:
self.has_blockers.remove(path)

def generate_unused_ignore_errors(self, file: str) -> None:
if (
is_typeshed_file(self.options.abs_custom_typeshed_dir if self.options else None, file)
or file in self.ignored_files
):
return
ignored_lines = self.ignored_lines[file]
used_ignored_lines = self.used_ignored_lines[file]
for line, ignored_codes in ignored_lines.items():
Expand Down
2 changes: 1 addition & 1 deletion mypy/exprtotype.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ def expr_to_unanalyzed_type(
value, options, allow_new_syntax, expr
)
result = TypedDictType(
items, set(), Instance(MISSING_FALLBACK, ()), expr.line, expr.column
items, set(), set(), Instance(MISSING_FALLBACK, ()), expr.line, expr.column
)
result.extra_items_from = extra_items_from
return result
Expand Down
2 changes: 1 addition & 1 deletion mypy/fastparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -2170,7 +2170,7 @@ def visit_Dict(self, n: ast3.Dict) -> Type:
continue
return self.invalid_type(n)
items[item_name.value] = self.visit(value)
result = TypedDictType(items, set(), _dummy_fallback, n.lineno, n.col_offset)
result = TypedDictType(items, set(), set(), _dummy_fallback, n.lineno, n.col_offset)
result.extra_items_from = extra_items_from
return result

Expand Down
4 changes: 3 additions & 1 deletion mypy/join.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,7 +634,9 @@ def visit_typeddict_type(self, t: TypedDictType) -> ProperType:
# We need to filter by items.keys() since some required keys present in both t and
# self.s might be missing from the join if the types are incompatible.
required_keys = set(items.keys()) & t.required_keys & self.s.required_keys
return TypedDictType(items, required_keys, fallback)
# If one type has a key as readonly, we mark it as readonly for both:
readonly_keys = t.readonly_keys | t.readonly_keys
return TypedDictType(items, required_keys, readonly_keys, fallback)
elif isinstance(self.s, Instance):
return join_types(self.s, t.fallback)
else:
Expand Down
12 changes: 11 additions & 1 deletion mypy/meet.py
Original file line number Diff line number Diff line change
Expand Up @@ -1017,7 +1017,8 @@ def visit_typeddict_type(self, t: TypedDictType) -> ProperType:
items = dict(item_list)
fallback = self.s.create_anonymous_fallback()
required_keys = t.required_keys | self.s.required_keys
return TypedDictType(items, required_keys, fallback)
readonly_keys = t.readonly_keys | self.s.readonly_keys
return TypedDictType(items, required_keys, readonly_keys, fallback)
elif isinstance(self.s, Instance) and is_subtype(t, self.s):
return t
else:
Expand Down Expand Up @@ -1139,6 +1140,9 @@ def typed_dict_mapping_overlap(
- TypedDict(x=str, y=str, total=False) doesn't overlap with Dict[str, int]
- TypedDict(x=int, y=str, total=False) overlaps with Dict[str, str]
* A TypedDict with at least one ReadOnly[] key does not overlap
with Dict or MutableMapping, because they assume mutable data.
As usual empty, dictionaries lie in a gray area. In general, List[str] and List[str]
are considered non-overlapping despite empty list belongs to both. However, List[int]
and List[Never] are considered overlapping.
Expand All @@ -1159,6 +1163,12 @@ def typed_dict_mapping_overlap(
assert isinstance(right, TypedDictType)
typed, other = right, left

mutable_mapping = next(
(base for base in other.type.mro if base.fullname == "typing.MutableMapping"), None
)
if mutable_mapping is not None and typed.readonly_keys:
return False

mapping = next(base for base in other.type.mro if base.fullname == "typing.Mapping")
other = map_instance_to_supertype(other, mapping)
key_type, value_type = get_proper_types(other.args)
Expand Down
20 changes: 17 additions & 3 deletions mypy/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -926,6 +926,17 @@ def invalid_index_type(
code=code,
)

def readonly_keys_mutated(self, keys: set[str], context: Context) -> None:
if len(keys) == 1:
suffix = "is"
else:
suffix = "are"
self.fail(
"ReadOnly {} TypedDict {} mutated".format(format_key_list(list(sorted(keys))), suffix),
code=codes.TYPEDDICT_READONLY_MUTATED,
context=context,
)

def too_few_arguments(
self, callee: CallableType, context: Context, argument_names: Sequence[str | None] | None
) -> None:
Expand Down Expand Up @@ -2612,10 +2623,13 @@ def format_literal_value(typ: LiteralType) -> str:
return format(typ.fallback)
items = []
for item_name, item_type in typ.items.items():
modifier = "" if item_name in typ.required_keys else "?"
modifier = ""
if item_name not in typ.required_keys:
modifier += "?"
if item_name in typ.readonly_keys:
modifier += "="
items.append(f"{item_name!r}{modifier}: {format(item_type)}")
s = f"TypedDict({{{', '.join(items)}}})"
return s
return f"TypedDict({{{', '.join(items)}}})"
elif isinstance(typ, LiteralType):
return f"Literal[{format_literal_value(typ)}]"
elif isinstance(typ, UnionType):
Expand Down
4 changes: 4 additions & 0 deletions mypy/plugins/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,10 @@ def typed_dict_setdefault_callback(ctx: MethodContext) -> Type:
)
return AnyType(TypeOfAny.from_error)

assigned_readonly_keys = ctx.type.readonly_keys & set(keys)
if assigned_readonly_keys:
ctx.api.msg.readonly_keys_mutated(assigned_readonly_keys, context=ctx.context)

default_type = ctx.arg_types[1][0]

value_types = []
Expand Down
1 change: 1 addition & 0 deletions mypy/plugins/proper_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ def is_special_target(right: ProperType) -> bool:
"mypy.types.ErasedType",
"mypy.types.DeletedType",
"mypy.types.RequiredType",
"mypy.types.ReadOnlyType",
):
# Special case: these are not valid targets for a type alias and thus safe.
# TODO: introduce a SyntheticType base to simplify this?
Expand Down
8 changes: 4 additions & 4 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -7147,7 +7147,7 @@ def type_analyzer(
allow_tuple_literal: bool = False,
allow_unbound_tvars: bool = False,
allow_placeholder: bool = False,
allow_required: bool = False,
allow_typed_dict_special_forms: bool = False,
allow_param_spec_literals: bool = False,
allow_unpack: bool = False,
report_invalid_types: bool = True,
Expand All @@ -7166,7 +7166,7 @@ def type_analyzer(
allow_tuple_literal=allow_tuple_literal,
report_invalid_types=report_invalid_types,
allow_placeholder=allow_placeholder,
allow_required=allow_required,
allow_typed_dict_special_forms=allow_typed_dict_special_forms,
allow_param_spec_literals=allow_param_spec_literals,
allow_unpack=allow_unpack,
prohibit_self_type=prohibit_self_type,
Expand All @@ -7189,7 +7189,7 @@ def anal_type(
allow_tuple_literal: bool = False,
allow_unbound_tvars: bool = False,
allow_placeholder: bool = False,
allow_required: bool = False,
allow_typed_dict_special_forms: bool = False,
allow_param_spec_literals: bool = False,
allow_unpack: bool = False,
report_invalid_types: bool = True,
Expand Down Expand Up @@ -7224,7 +7224,7 @@ def anal_type(
allow_unbound_tvars=allow_unbound_tvars,
allow_tuple_literal=allow_tuple_literal,
allow_placeholder=allow_placeholder,
allow_required=allow_required,
allow_typed_dict_special_forms=allow_typed_dict_special_forms,
allow_param_spec_literals=allow_param_spec_literals,
allow_unpack=allow_unpack,
report_invalid_types=report_invalid_types,
Expand Down
2 changes: 1 addition & 1 deletion mypy/semanal_shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def anal_type(
tvar_scope: TypeVarLikeScope | None = None,
allow_tuple_literal: bool = False,
allow_unbound_tvars: bool = False,
allow_required: bool = False,
allow_typed_dict_special_forms: bool = False,
allow_placeholder: bool = False,
report_invalid_types: bool = True,
prohibit_self_type: str | None = None,
Expand Down
Loading

0 comments on commit 9a43f2e

Please sign in to comment.