Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
chadrik committed Jul 29, 2023
1 parent 7c066f3 commit ab25c53
Show file tree
Hide file tree
Showing 6 changed files with 22 additions and 38 deletions.
4 changes: 2 additions & 2 deletions docs/source/stubgen.rst
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,8 @@ alter the default behavior:

Import and inspect modules instead of parsing source code. This is the default
behavior for c modules and pyc-only packages. The flag is useful to force
inspection for pure python modules that make use of dynamically generated
members that would otherwiswe be omitted when using the default behavior of
inspection for pure python modules that make use of dynamically generated
members that would otherwiswe be omitted when using the default behavior of
code parsing. Implies :option:`--no-analysis` as analysis requires source
code.

Expand Down
26 changes: 11 additions & 15 deletions mypy/stubgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,12 @@
from __future__ import annotations

import argparse
import glob
import keyword
import os
import os.path
import sys
import traceback
from collections import defaultdict
from typing import Final, Iterable, Mapping
from typing_extensions import Final
from typing import Final, Iterable

import mypy.build
import mypy.mixedtraverser
Expand Down Expand Up @@ -113,10 +110,10 @@
from mypy.stubdoc import ArgSig, FunctionSig
from mypy.stubgenc import InspectionStubGenerator, generate_stub_for_c_module
from mypy.stubutil import (
BaseStubGenerator,
CantImport,
ClassInfo,
FunctionContext,
BaseStubGenerator,
common_dir_prefix,
fail_missing,
find_module_path_and_all_py3,
Expand Down Expand Up @@ -735,19 +732,20 @@ def get_base_types(self, cdef: ClassDef) -> list[str]:
typename = base.args[0].value
if nt_fields is not None:
fields_str = ", ".join(f"({f!r}, {t})" for f, t in nt_fields)
namedtuple_name = self.typing_name("NamedTuple")
namedtuple_name = self.add_typing_import("NamedTuple")
base_types.append(f"{namedtuple_name}({typename!r}, [{fields_str}])")
self.add_typing_import("NamedTuple")
else:
# Invalid namedtuple() call, cannot determine fields
base_types.append(self.typing_name("Incomplete"))
base_types.append(
self.add_obj_import("_typeshed", "Incomplete", require=True)
)
elif self.is_typed_namedtuple(base):
base_types.append(base.accept(p))
else:
# At this point, we don't know what the base class is, so we
# just use Incomplete as the base class.
base_types.append(self.typing_name("Incomplete"))
self.add_typing_import("Incomplete")
base_types.append(self.add_obj_import("_typeshed", "Incomplete", require=True))
for name, value in cdef.keywords.items():
if name == "metaclass":
continue # handled separately
Expand Down Expand Up @@ -860,8 +858,7 @@ def process_namedtuple(self, lvalue: NameExpr, rvalue: CallExpr) -> None:
if fields is None:
self.annotate_as_incomplete(lvalue)
return
self.add_typing_import("NamedTuple")
bases = self.typing_name("NamedTuple")
bases = self.add_typing_import("NamedTuple")
# TODO: Add support for generic NamedTuples. Requires `Generic` as base class.
class_def = f"{self._indent}class {lvalue.name}({bases}):"
if len(fields) == 0:
Expand Down Expand Up @@ -918,8 +915,7 @@ def process_typeddict(self, lvalue: NameExpr, rvalue: CallExpr) -> None:
self.add(f"{self._indent}{lvalue.name} = {rvalue.accept(p)}\n")
self._state = VAR
else:
incomplete = self.add_obj_import("_typeshed", "Incomplete", require=True)
bases = self.typing_name("TypedDict")
bases = self.add_typing_import("TypedDict")
# TODO: Add support for generic TypedDicts. Requires `Generic` as base class.
if total is not None:
bases += f", total={total.accept(p)}"
Expand All @@ -936,8 +932,8 @@ def process_typeddict(self, lvalue: NameExpr, rvalue: CallExpr) -> None:
self._state = CLASS

def annotate_as_incomplete(self, lvalue: NameExpr) -> None:
self.add_typing_import("Incomplete")
self.add(f"{self._indent}{lvalue.name}: {self.typing_name('Incomplete')}\n")
incomplete = self.add_obj_import("_typeshed", "Incomplete", require=True)
self.add(f"{self._indent}{lvalue.name}: {incomplete}\n")
self._state = VAR

def is_alias_expression(self, expr: Expression, top_level: bool = True) -> bool:
Expand Down
8 changes: 3 additions & 5 deletions mypy/stubgenc.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,8 @@
import inspect
import keyword
import os.path
import re
from abc import abstractmethod
from types import FunctionType, ModuleType
from typing import Any, Final, Mapping
from typing import Any, Mapping

from mypy.fastparse import parse_type_comment
from mypy.moduleinspect import is_c_module
Expand All @@ -31,10 +29,10 @@
parse_all_signatures,
)
from mypy.stubutil import (
BaseStubGenerator,
ClassInfo,
FunctionContext,
SignatureGenerator,
BaseStubGenerator,
infer_method_ret_type,
)

Expand Down Expand Up @@ -587,7 +585,7 @@ def _fix_iter(
ctx.class_info
and ctx.class_info.cls is not None
and ctx.name == "__getitem__"
and "__iter__" not in getattr(ctx.class_info.cls, "__dict__")
and "__iter__" not in ctx.class_info.cls.__dict__
):
item_type: str | None = None
for sig in inferred:
Expand Down
5 changes: 2 additions & 3 deletions mypy/stubutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from __future__ import annotations

import keyword
import os.path
import re
import sys
Expand Down Expand Up @@ -657,9 +656,9 @@ def set_defined_names(self, defined_names: set[str]) -> None:
# a corresponding import statement.
known_imports = {
"_typeshed": ["Incomplete"],
"typing": ["Any", "TypeVar", "ParamSpec", "NamedTuple"],
"typing": ["Any", "TypeVar", "NamedTuple"],
"collections.abc": ["Generator"],
"typing_extensions": ["TypedDict"],
"typing_extensions": ["TypedDict", "ParamSpec", "TypeVarTuple"],
}
for pkg, imports in known_imports.items():
for t in imports:
Expand Down
15 changes: 3 additions & 12 deletions mypy/test/teststubgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -918,7 +918,6 @@ def test(cls, arg0: str) -> None:
pass

output: list[str] = []
imports: list[str] = []
mod = ModuleType(TestClass.__module__, "")
gen = InspectionStubGenerator(mod.__name__, known_modules=[mod.__name__], module=mod)
gen.generate_function_stub(
Expand Down Expand Up @@ -1095,19 +1094,11 @@ def test(arg0: str) -> None:
test.__doc__ = property(lambda self: "test(arg0: str) -> None") # type: ignore[assignment]

output: list[str] = []
imports: list[str] = []
mod = ModuleType(self.__module__, "")
generate_c_function_stub(
mod,
"test",
test,
output=output,
imports=imports,
known_modules=[mod.__name__],
sig_generators=get_sig_generators(parse_options([])),
)
gen = InspectionStubGenerator(mod.__name__, known_modules=[mod.__name__], module=mod)
gen.generate_function_stub("test", test, output=output)
assert_equal(output, ["def test(*args, **kwargs) -> Any: ..."])
assert_equal(imports, [])
assert_equal(gen.get_imports().splitlines(), [])

def test_generate_c_property_with_pybind11(self) -> None:
"""Signatures included by PyBind11 inside property.fget are read."""
Expand Down
2 changes: 1 addition & 1 deletion test-data/unit/stubgen.test
Original file line number Diff line number Diff line change
Expand Up @@ -2853,7 +2853,7 @@ __uri__ = ''
__version__ = ''

[out]
from m import __version__ as __version__
from m import __version__ as __version__

class A: ...

Expand Down

0 comments on commit ab25c53

Please sign in to comment.