Skip to content

Commit

Permalink
stubgen: fix FunctionContext.fullname for nested classes
Browse files Browse the repository at this point in the history
  • Loading branch information
chadrik committed Oct 16, 2024
1 parent 7237d55 commit 23a81fc
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 10 deletions.
16 changes: 11 additions & 5 deletions mypy/stubgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,14 +468,18 @@ def __init__(
self._vars: list[list[str]] = [[]]
# What was generated previously in the stub file.
self._state = EMPTY
self._current_class: ClassDef | None = None
self._class_stack: list[ClassDef] = []
# Was the tree semantically analysed before?
self.analyzed = analyzed
# Short names of methods defined in the body of the current class
self.method_names: set[str] = set()
self.processing_enum = False
self.processing_dataclass = False

@property
def _current_class(self) -> ClassDef | None:
return self._class_stack[-1] if self._class_stack else None

def visit_mypy_file(self, o: MypyFile) -> None:
self.module_name = o.fullname # Current module being processed
self.path = o.path
Expand Down Expand Up @@ -646,12 +650,14 @@ def visit_func_def(self, o: FuncDef) -> None:
if init_code:
self.add(init_code)

if self._current_class is not None:
if self._class_stack:
if len(o.arguments):
self_var = o.arguments[0].variable.name
else:
self_var = "self"
class_info = ClassInfo(self._current_class.name, self_var)
class_info = None
for class_def in self._class_stack:
class_info = ClassInfo(class_def.name, self_var, parent=class_info)
else:
class_info = None

Expand Down Expand Up @@ -741,7 +747,7 @@ def get_fullname(self, expr: Expression) -> str:
return self.resolve_name(name)

def visit_class_def(self, o: ClassDef) -> None:
self._current_class = o
self._class_stack.append(o)
self.method_names = find_method_names(o.defs.body)
sep: int | None = None
if self.is_top_level() and self._state != EMPTY:
Expand Down Expand Up @@ -786,8 +792,8 @@ def visit_class_def(self, o: ClassDef) -> None:
self._state = CLASS
self.method_names = set()
self.processing_dataclass = False
self._class_stack.pop(-1)
self.processing_enum = False
self._current_class = None

def get_base_types(self, cdef: ClassDef) -> list[str]:
"""Get list of base classes for a class."""
Expand Down
10 changes: 7 additions & 3 deletions mypy/stubgenc.py
Original file line number Diff line number Diff line change
Expand Up @@ -787,7 +787,9 @@ def get_base_types(self, obj: type) -> list[str]:
bases.append(base)
return [self.strip_or_import(self.get_type_fullname(base)) for base in bases]

def generate_class_stub(self, class_name: str, cls: type, output: list[str]) -> None:
def generate_class_stub(
self, class_name: str, cls: type, output: list[str], parent_class: ClassInfo | None = None
) -> None:
"""Generate stub for a single class using runtime introspection.
The result lines will be appended to 'output'. If necessary, any
Expand All @@ -808,7 +810,9 @@ def generate_class_stub(self, class_name: str, cls: type, output: list[str]) ->
self.record_name(class_name)
self.indent()

class_info = ClassInfo(class_name, "", getattr(cls, "__doc__", None), cls)
class_info = ClassInfo(
class_name, "", getattr(cls, "__doc__", None), cls, parent=parent_class
)

for attr, value in items:
# use unevaluated descriptors when dealing with property inspection
Expand Down Expand Up @@ -843,7 +847,7 @@ def generate_class_stub(self, class_name: str, cls: type, output: list[str]) ->
class_info,
)
elif inspect.isclass(value) and self.is_defined_in_module(value):
self.generate_class_stub(attr, value, types)
self.generate_class_stub(attr, value, types, parent_class=class_info)
else:
attrs.append((attr, value))

Expand Down
16 changes: 14 additions & 2 deletions mypy/stubutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,12 +306,18 @@ def args_str(self, args: Iterable[Type]) -> str:

class ClassInfo:
def __init__(
self, name: str, self_var: str, docstring: str | None = None, cls: type | None = None
self,
name: str,
self_var: str,
docstring: str | None = None,
cls: type | None = None,
parent: ClassInfo | None = None,
) -> None:
self.name = name
self.self_var = self_var
self.docstring = docstring
self.cls = cls
self.parent = parent


class FunctionContext:
Expand All @@ -334,7 +340,13 @@ def __init__(
def fullname(self) -> str:
if self._fullname is None:
if self.class_info:
self._fullname = f"{self.module_name}.{self.class_info.name}.{self.name}"
parents = []
class_info = self.class_info
while class_info is not None:
parents.append(class_info.name)
class_info = class_info.parent
namespace = ".".join(reversed(parents))
self._fullname = f"{self.module_name}.{namespace}.{self.name}"
else:
self._fullname = f"{self.module_name}.{self.name}"
return self._fullname
Expand Down

0 comments on commit 23a81fc

Please sign in to comment.