Skip to content

Commit

Permalink
allow-types-from--annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
InvincibleRMC committed Jul 2, 2024
1 parent 177c8ee commit 68b1a6f
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 4 deletions.
30 changes: 26 additions & 4 deletions mypy/stubgenc.py
Original file line number Diff line number Diff line change
Expand Up @@ -847,13 +847,25 @@ def generate_class_stub(self, class_name: str, cls: type, output: list[str]) ->
else:
attrs.append((attr, value))

# Gets annotations if they exist
try:
annotations = cls.__annotations__
except AttributeError:
annotations = {}

for attr, value in attrs:
if attr == "__hash__" and value is None:
# special case for __hash__
continue
prop_type_name = self.strip_or_import(self.get_type_annotation(value))
classvar = self.add_name("typing.ClassVar")
static_properties.append(f"{self._indent}{attr}: {classvar}[{prop_type_name}] = ...")
if attr in annotations:
prop_type_name = self.strip_or_import(annotations[attr])
static_properties.append(f"{self._indent}{attr}: {prop_type_name} = ...")
else:
prop_type_name = self.strip_or_import(self.get_type_annotation(value))
classvar = self.add_name("typing.ClassVar")
static_properties.append(
f"{self._indent}{attr}: {classvar}[{prop_type_name}] = ..."
)

self.dedent()

Expand Down Expand Up @@ -893,7 +905,17 @@ def generate_variable_stub(self, name: str, obj: object, output: list[str]) -> N
if self.is_private_name(name, f"{self.module_name}.{name}") or self.is_not_in_all(name):
return
self.record_name(name)
type_str = self.strip_or_import(self.get_type_annotation(obj))

# Gets annotations if they exist
try:
annotations = self.module.__annotations__
except AttributeError:
annotations = {}

if name in annotations:
type_str = self.strip_or_import(annotations[name])
else:
type_str = self.strip_or_import(self.get_type_annotation(obj))
output.append(f"{name}: {type_str}")


Expand Down
21 changes: 21 additions & 0 deletions mypy/test/teststubgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -955,6 +955,27 @@ def test(self, arg0: str) -> None:
assert_equal(output, ["def test(self, arg0: int) -> Any: ..."])
assert_equal(gen.get_imports().splitlines(), [])

def test_generate_c_class_fields_from__annotations__(self) -> None:
class TestClass:
__annotations__ = {"x": "dict[str, int]"}
x = {} # type:ignore [var-annotated]

output: list[str] = []
mod = ModuleType("module", "") # any module is fine
gen = InspectionStubGenerator(mod.__name__, known_modules=[mod.__name__], module=mod)
gen.generate_class_stub("C", TestClass, output)
assert_equal(output, ["class C:", " x: dict[str, int] = ..."])
assert_equal(gen.get_imports().splitlines(), [])

def test_generate_c_module_fields_from__annotations__(self) -> None:
mod = ModuleType("module", "") # any module is fine
mod.__annotations__ = {"x": "dict[str, int]"}
mod.x = {} # type:ignore [attr-defined]
gen = InspectionStubGenerator(mod.__name__, known_modules=[mod.__name__], module=mod)
gen.generate_module()
assert_equal(gen.output(), "x: dict[str, int]\n")
assert_equal(gen.get_imports().splitlines(), [])

def test_generate_c_type_classmethod(self) -> None:
class TestClass:
@classmethod
Expand Down

0 comments on commit 68b1a6f

Please sign in to comment.