Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix strict optional handling in dataclasses #15571

Merged
merged 3 commits into from
Jul 3, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 17 additions & 13 deletions mypy/plugins/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def __init__(
info: TypeInfo,
kw_only: bool,
is_neither_frozen_nor_nonfrozen: bool,
api: SemanticAnalyzerPluginInterface,
) -> None:
self.name = name
self.alias = alias
Expand All @@ -116,6 +117,7 @@ def __init__(
self.info = info
self.kw_only = kw_only
self.is_neither_frozen_nor_nonfrozen = is_neither_frozen_nor_nonfrozen
self._api = api

def to_argument(self, current_info: TypeInfo) -> Argument:
arg_kind = ARG_POS
Expand All @@ -138,7 +140,10 @@ def expand_type(self, current_info: TypeInfo) -> Optional[Type]:
# however this plugin is called very late, so all types should be fully ready.
# Also, it is tricky to avoid eager expansion of Self types here (e.g. because
# we serialize attributes).
return expand_type(self.type, {self.info.self_type.id: fill_typevars(current_info)})
with state.strict_optional_set(self._api.options.strict_optional):
return expand_type(
self.type, {self.info.self_type.id: fill_typevars(current_info)}
)
return self.type

def to_var(self, current_info: TypeInfo) -> Var:
Expand All @@ -165,13 +170,14 @@ def deserialize(
) -> DataclassAttribute:
data = data.copy()
typ = deserialize_and_fixup_type(data.pop("type"), api)
return cls(type=typ, info=info, **data)
return cls(type=typ, info=info, **data, api=api)

def expand_typevar_from_subtype(self, sub_type: TypeInfo) -> None:
"""Expands type vars in the context of a subtype when an attribute is inherited
from a generic super type."""
if self.type is not None:
self.type = map_type_from_supertype(self.type, sub_type, self.info)
with state.strict_optional_set(self._api.options.strict_optional):
self.type = map_type_from_supertype(self.type, sub_type, self.info)


class DataclassTransformer:
Expand Down Expand Up @@ -230,12 +236,11 @@ def transform(self) -> bool:
and ("__init__" not in info.names or info.names["__init__"].plugin_generated)
and attributes
):
with state.strict_optional_set(self._api.options.strict_optional):
args = [
attr.to_argument(info)
for attr in attributes
if attr.is_in_init and not self._is_kw_only_type(attr.type)
]
args = [
attr.to_argument(info)
for attr in attributes
if attr.is_in_init and not self._is_kw_only_type(attr.type)
]

if info.fallback_to_any:
# Make positional args optional since we don't know their order.
Expand Down Expand Up @@ -355,8 +360,7 @@ def transform(self) -> bool:
self._add_dataclass_fields_magic_attribute()

if self._spec is _TRANSFORM_SPEC_FOR_DATACLASSES:
with state.strict_optional_set(self._api.options.strict_optional):
self._add_internal_replace_method(attributes)
self._add_internal_replace_method(attributes)
if "__post_init__" in info.names:
self._add_internal_post_init_method(attributes)

Expand Down Expand Up @@ -546,8 +550,7 @@ def collect_attributes(self) -> list[DataclassAttribute] | None:
# TODO: We shouldn't be performing type operations during the main
# semantic analysis pass, since some TypeInfo attributes might
# still be in flux. This should be performed in a later phase.
with state.strict_optional_set(self._api.options.strict_optional):
attr.expand_typevar_from_subtype(cls.info)
attr.expand_typevar_from_subtype(cls.info)
found_attrs[name] = attr

sym_node = cls.info.names.get(name)
Expand Down Expand Up @@ -693,6 +696,7 @@ def collect_attributes(self) -> list[DataclassAttribute] | None:
is_neither_frozen_nor_nonfrozen=_has_direct_dataclass_transform_metaclass(
cls.info
),
api=self._api,
)

all_attrs = list(found_attrs.values())
Expand Down
13 changes: 13 additions & 0 deletions test-data/unit/check-dataclasses.test
Original file line number Diff line number Diff line change
Expand Up @@ -2420,3 +2420,16 @@ class Test(Protocol):
def reset(self) -> None:
self.x = DEFAULT
[builtins fixtures/dataclasses.pyi]

[case testStrictOptionalAlwaysSet]
# flags: --strict-optional
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# flags: --strict-optional

--strict-optional is the default. So the flag wouldn't be necessary.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

--strict-optional is actually off by default in tests (at least for testcheck.py).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But this actually reminded me that this test should use real stubs, so it should be moved to pythoneval.test.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interesting, before writing the initial comment, I removed the flag from the test case and it still passed 🤔

--strict-optional is actually off by default in tests (at least for testcheck.py).

I enabled it for pythoneval tests in #15474, seems I should work on the rest as well.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interesting, before writing the initial comment, I removed the flag from the test case and it still passed

This is exactly why I moved it to pythoneval, otherwise it doesn't really test much.

from dataclasses import dataclass
from typing import Callable, Optional

@dataclass
class Description:
name_fn: Callable[[Optional[int]], Optional[str]]

def f(d: Description) -> None:
reveal_type(d.name_fn) # N: Revealed type is "def (Union[builtins.int, None]) -> Union[builtins.str, None]"
[builtins fixtures/dataclasses.pyi]