Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix strict optional handling in dataclasses #15571

Merged
merged 3 commits into from
Jul 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 17 additions & 13 deletions mypy/plugins/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def __init__(
info: TypeInfo,
kw_only: bool,
is_neither_frozen_nor_nonfrozen: bool,
api: SemanticAnalyzerPluginInterface,
) -> None:
self.name = name
self.alias = alias
Expand All @@ -116,6 +117,7 @@ def __init__(
self.info = info
self.kw_only = kw_only
self.is_neither_frozen_nor_nonfrozen = is_neither_frozen_nor_nonfrozen
self._api = api

def to_argument(self, current_info: TypeInfo) -> Argument:
arg_kind = ARG_POS
Expand All @@ -138,7 +140,10 @@ def expand_type(self, current_info: TypeInfo) -> Optional[Type]:
# however this plugin is called very late, so all types should be fully ready.
# Also, it is tricky to avoid eager expansion of Self types here (e.g. because
# we serialize attributes).
return expand_type(self.type, {self.info.self_type.id: fill_typevars(current_info)})
with state.strict_optional_set(self._api.options.strict_optional):
return expand_type(
self.type, {self.info.self_type.id: fill_typevars(current_info)}
)
return self.type

def to_var(self, current_info: TypeInfo) -> Var:
Expand All @@ -165,13 +170,14 @@ def deserialize(
) -> DataclassAttribute:
data = data.copy()
typ = deserialize_and_fixup_type(data.pop("type"), api)
return cls(type=typ, info=info, **data)
return cls(type=typ, info=info, **data, api=api)

def expand_typevar_from_subtype(self, sub_type: TypeInfo) -> None:
"""Expands type vars in the context of a subtype when an attribute is inherited
from a generic super type."""
if self.type is not None:
self.type = map_type_from_supertype(self.type, sub_type, self.info)
with state.strict_optional_set(self._api.options.strict_optional):
self.type = map_type_from_supertype(self.type, sub_type, self.info)


class DataclassTransformer:
Expand Down Expand Up @@ -230,12 +236,11 @@ def transform(self) -> bool:
and ("__init__" not in info.names or info.names["__init__"].plugin_generated)
and attributes
):
with state.strict_optional_set(self._api.options.strict_optional):
args = [
attr.to_argument(info)
for attr in attributes
if attr.is_in_init and not self._is_kw_only_type(attr.type)
]
args = [
attr.to_argument(info)
for attr in attributes
if attr.is_in_init and not self._is_kw_only_type(attr.type)
]

if info.fallback_to_any:
# Make positional args optional since we don't know their order.
Expand Down Expand Up @@ -355,8 +360,7 @@ def transform(self) -> bool:
self._add_dataclass_fields_magic_attribute()

if self._spec is _TRANSFORM_SPEC_FOR_DATACLASSES:
with state.strict_optional_set(self._api.options.strict_optional):
self._add_internal_replace_method(attributes)
self._add_internal_replace_method(attributes)
if "__post_init__" in info.names:
self._add_internal_post_init_method(attributes)

Expand Down Expand Up @@ -546,8 +550,7 @@ def collect_attributes(self) -> list[DataclassAttribute] | None:
# TODO: We shouldn't be performing type operations during the main
# semantic analysis pass, since some TypeInfo attributes might
# still be in flux. This should be performed in a later phase.
with state.strict_optional_set(self._api.options.strict_optional):
attr.expand_typevar_from_subtype(cls.info)
attr.expand_typevar_from_subtype(cls.info)
found_attrs[name] = attr

sym_node = cls.info.names.get(name)
Expand Down Expand Up @@ -693,6 +696,7 @@ def collect_attributes(self) -> list[DataclassAttribute] | None:
is_neither_frozen_nor_nonfrozen=_has_direct_dataclass_transform_metaclass(
cls.info
),
api=self._api,
)

all_attrs = list(found_attrs.values())
Expand Down
18 changes: 15 additions & 3 deletions test-data/unit/pythoneval.test
Original file line number Diff line number Diff line change
Expand Up @@ -2094,7 +2094,6 @@ grouped = groupby(pairs, key=fst)
[out]

[case testDataclassReplaceOptional]
# flags: --strict-optional
from dataclasses import dataclass, replace
from typing import Optional

Expand All @@ -2107,5 +2106,18 @@ reveal_type(a)
a2 = replace(a, x=None) # OK
reveal_type(a2)
[out]
_testDataclassReplaceOptional.py:10: note: Revealed type is "_testDataclassReplaceOptional.A"
_testDataclassReplaceOptional.py:12: note: Revealed type is "_testDataclassReplaceOptional.A"
_testDataclassReplaceOptional.py:9: note: Revealed type is "_testDataclassReplaceOptional.A"
_testDataclassReplaceOptional.py:11: note: Revealed type is "_testDataclassReplaceOptional.A"

[case testDataclassStrictOptionalAlwaysSet]
from dataclasses import dataclass
from typing import Callable, Optional

@dataclass
class Description:
name_fn: Callable[[Optional[int]], Optional[str]]

def f(d: Description) -> None:
reveal_type(d.name_fn)
[out]
_testDataclassStrictOptionalAlwaysSet.py:9: note: Revealed type is "def (Union[builtins.int, None]) -> Union[builtins.str, None]"