diff --git a/mypy/stubgen.py b/mypy/stubgen.py index 7721366f5c0c..22028694ad6b 100755 --- a/mypy/stubgen.py +++ b/mypy/stubgen.py @@ -453,6 +453,7 @@ def __init__( self.analyzed = analyzed # Short names of methods defined in the body of the current class self.method_names: set[str] = set() + self.processing_enum = False self.processing_dataclass = False def visit_mypy_file(self, o: MypyFile) -> None: @@ -727,6 +728,8 @@ def visit_class_def(self, o: ClassDef) -> None: if base_types: for base in base_types: self.import_tracker.require_name(base) + if self.analyzed and o.info.is_enum: + self.processing_enum = True if isinstance(o.metaclass, (NameExpr, MemberExpr)): meta = o.metaclass.accept(AliasPrinter(self)) base_types.append("metaclass=" + meta) @@ -756,6 +759,7 @@ def visit_class_def(self, o: ClassDef) -> None: self._state = CLASS self.method_names = set() self.processing_dataclass = False + self.processing_enum = False self._current_class = None def get_base_types(self, cdef: ClassDef) -> list[str]: @@ -1153,6 +1157,9 @@ def get_init( # Final without type argument is invalid in stubs. final_arg = self.get_str_type_of_node(rvalue) typename += f"[{final_arg}]" + elif self.processing_enum: + initializer, _ = self.get_str_default_of_node(rvalue) + return f"{self._indent}{lvalue} = {initializer}\n" elif self.processing_dataclass: # attribute without annotation is not a dataclass field, don't add annotation. return f"{self._indent}{lvalue} = ...\n" diff --git a/test-data/unit/stubgen.test b/test-data/unit/stubgen.test index 53baa2c0ca06..916e2e3a8e17 100644 --- a/test-data/unit/stubgen.test +++ b/test-data/unit/stubgen.test @@ -4342,3 +4342,27 @@ alias = tuple[()] def f(x: tuple[()]): ... class C(tuple[()]): ... + +[case testPreserveEnumValue_semanal] +from enum import Enum + +class Foo(Enum): + A = 1 + B = 2 + C = 3 + +class Bar(Enum): + A = object() + B = "a" + "b" + +[out] +from enum import Enum + +class Foo(Enum): + A = 1 + B = 2 + C = 3 + +class Bar(Enum): + A = ... + B = ...