Skip to content

Commit

Permalink
stubgen: preserve enum value initialisers (#17125)
Browse files Browse the repository at this point in the history
  • Loading branch information
hauntsaninja authored May 21, 2024
1 parent 2892ed4 commit 42157ba
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 0 deletions.
7 changes: 7 additions & 0 deletions mypy/stubgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,7 @@ def __init__(
self.analyzed = analyzed
# Short names of methods defined in the body of the current class
self.method_names: set[str] = set()
self.processing_enum = False
self.processing_dataclass = False

def visit_mypy_file(self, o: MypyFile) -> None:
Expand Down Expand Up @@ -727,6 +728,8 @@ def visit_class_def(self, o: ClassDef) -> None:
if base_types:
for base in base_types:
self.import_tracker.require_name(base)
if self.analyzed and o.info.is_enum:
self.processing_enum = True
if isinstance(o.metaclass, (NameExpr, MemberExpr)):
meta = o.metaclass.accept(AliasPrinter(self))
base_types.append("metaclass=" + meta)
Expand Down Expand Up @@ -756,6 +759,7 @@ def visit_class_def(self, o: ClassDef) -> None:
self._state = CLASS
self.method_names = set()
self.processing_dataclass = False
self.processing_enum = False
self._current_class = None

def get_base_types(self, cdef: ClassDef) -> list[str]:
Expand Down Expand Up @@ -1153,6 +1157,9 @@ def get_init(
# Final without type argument is invalid in stubs.
final_arg = self.get_str_type_of_node(rvalue)
typename += f"[{final_arg}]"
elif self.processing_enum:
initializer, _ = self.get_str_default_of_node(rvalue)
return f"{self._indent}{lvalue} = {initializer}\n"
elif self.processing_dataclass:
# attribute without annotation is not a dataclass field, don't add annotation.
return f"{self._indent}{lvalue} = ...\n"
Expand Down
24 changes: 24 additions & 0 deletions test-data/unit/stubgen.test
Original file line number Diff line number Diff line change
Expand Up @@ -4342,3 +4342,27 @@ alias = tuple[()]
def f(x: tuple[()]): ...

class C(tuple[()]): ...

[case testPreserveEnumValue_semanal]
from enum import Enum

class Foo(Enum):
A = 1
B = 2
C = 3

class Bar(Enum):
A = object()
B = "a" + "b"

[out]
from enum import Enum

class Foo(Enum):
A = 1
B = 2
C = 3

class Bar(Enum):
A = ...
B = ...

0 comments on commit 42157ba

Please sign in to comment.