From 0ff7a29d5336dad6400a9356bd4116b59c20a875 Mon Sep 17 00:00:00 2001 From: Jelle Zijlstra Date: Mon, 30 Oct 2023 11:48:07 -0700 Subject: [PATCH] stubgen: include __all__ in output (#16356) Fixes #10314 --- CHANGELOG.md | 2 +- mypy/stubutil.py | 56 ++++++++++++++++++++++++++----------- test-data/unit/stubgen.test | 40 +++++++++++++++++++++++++- 3 files changed, 79 insertions(+), 19 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d8237795112b..74f7c676c279 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,7 +2,7 @@ ## Unreleased -... +Stubgen will now include `__all__` in its output if it is in the input file (PR [16356](https://github.com/python/mypy/pull/16356)). #### Other Notable Changes and Fixes ... diff --git a/mypy/stubutil.py b/mypy/stubutil.py index cc3b63098fd2..5ec240087145 100644 --- a/mypy/stubutil.py +++ b/mypy/stubutil.py @@ -614,10 +614,24 @@ def get_imports(self) -> str: def output(self) -> str: """Return the text for the stub.""" - imports = self.get_imports() - if imports and self._output: - imports += "\n" - return imports + "".join(self._output) + pieces: list[str] = [] + if imports := self.get_imports(): + pieces.append(imports) + if dunder_all := self.get_dunder_all(): + pieces.append(dunder_all) + if self._output: + pieces.append("".join(self._output)) + return "\n".join(pieces) + + def get_dunder_all(self) -> str: + """Return the __all__ list for the stub.""" + if self._all_: + # Note we emit all names in the runtime __all__ here, even if they + # don't actually exist. If that happens, the runtime has a bug, and + # it's not obvious what the correct behavior should be. We choose + # to reflect the runtime __all__ as closely as possible. + return f"__all__ = {self._all_!r}\n" + return "" def add(self, string: str) -> None: """Add text to generated stub.""" @@ -651,8 +665,7 @@ def set_defined_names(self, defined_names: set[str]) -> None: self.defined_names = defined_names # Names in __all__ are required for name in self._all_ or (): - if name not in self.IGNORED_DUNDERS: - self.import_tracker.reexport(name) + self.import_tracker.reexport(name) # These are "soft" imports for objects which might appear in annotations but not have # a corresponding import statement. @@ -751,7 +764,13 @@ def is_private_name(self, name: str, fullname: str | None = None) -> bool: return False if name == "_": return False - return name.startswith("_") and (not name.endswith("__") or name in self.IGNORED_DUNDERS) + if not name.startswith("_"): + return False + if self._all_ and name in self._all_: + return False + if name.startswith("__") and name.endswith("__"): + return name in self.IGNORED_DUNDERS + return True def should_reexport(self, name: str, full_module: str, name_is_alias: bool) -> bool: if ( @@ -761,18 +780,21 @@ def should_reexport(self, name: str, full_module: str, name_is_alias: bool) -> b ): # Special case certain names that should be exported, against our general rules. return True + if name_is_alias: + return False + if self.export_less: + return False + if not self.module_name: + return False is_private = self.is_private_name(name, full_module + "." + name) + if is_private: + return False top_level = full_module.split(".")[0] self_top_level = self.module_name.split(".", 1)[0] - if ( - not name_is_alias - and not self.export_less - and (not self._all_ or name in self.IGNORED_DUNDERS) - and self.module_name - and not is_private - and top_level in (self_top_level, "_" + self_top_level) - ): + if top_level not in (self_top_level, "_" + self_top_level): # Export imports from the same package, since we can't reliably tell whether they # are part of the public API. - return True - return False + return False + if self._all_: + return name in self._all_ + return True diff --git a/test-data/unit/stubgen.test b/test-data/unit/stubgen.test index 895500c1ba57..2a43ce16383d 100644 --- a/test-data/unit/stubgen.test +++ b/test-data/unit/stubgen.test @@ -587,6 +587,8 @@ __all__ = [] + ['f'] def f(): ... def g(): ... [out] +__all__ = ['f'] + def f() -> None: ... [case testOmitDefsNotInAll_semanal] @@ -594,6 +596,8 @@ __all__ = ['f'] def f(): ... def g(): ... [out] +__all__ = ['f'] + def f() -> None: ... [case testOmitDefsNotInAll_inspect] @@ -601,6 +605,8 @@ __all__ = [] + ['f'] def f(): ... def g(): ... [out] +__all__ = ['f'] + def f(): ... [case testVarDefsNotInAll_import] @@ -610,6 +616,8 @@ x = 1 y = 1 def g(): ... [out] +__all__ = ['f', 'g'] + def f() -> None: ... def g() -> None: ... @@ -620,6 +628,8 @@ x = 1 y = 1 def g(): ... [out] +__all__ = ['f', 'g'] + def f(): ... def g(): ... @@ -628,6 +638,8 @@ __all__ = [] + ['f'] def f(): ... class A: ... [out] +__all__ = ['f'] + def f() -> None: ... class A: ... @@ -637,6 +649,8 @@ __all__ = [] + ['f'] def f(): ... class A: ... [out] +__all__ = ['f'] + def f(): ... class A: ... @@ -647,6 +661,8 @@ class A: x = 1 def f(self): ... [out] +__all__ = ['A'] + class A: x: int def f(self) -> None: ... @@ -684,6 +700,8 @@ x = 1 [out] from re import match as match, sub as sub +__all__ = ['match', 'sub', 'x'] + x: int [case testExportModule_import] @@ -694,6 +712,8 @@ y = 2 [out] import re as re +__all__ = ['re', 'x'] + x: int [case testExportModule2_import] @@ -704,6 +724,8 @@ y = 2 [out] import re as re +__all__ = ['re', 'x'] + x: int [case testExportModuleAs_import] @@ -714,6 +736,8 @@ y = 2 [out] import re as rex +__all__ = ['rex', 'x'] + x: int [case testExportModuleInPackage_import] @@ -722,6 +746,8 @@ __all__ = ['p'] [out] import urllib.parse as p +__all__ = ['p'] + [case testExportPackageOfAModule_import] import urllib.parse __all__ = ['urllib'] @@ -729,6 +755,8 @@ __all__ = ['urllib'] [out] import urllib as urllib +__all__ = ['urllib'] + [case testRelativeImportAll] from .x import * [out] @@ -741,6 +769,8 @@ x = 1 class C: def g(self): ... [out] +__all__ = ['f', 'x', 'C', 'g'] + def f() -> None: ... x: int @@ -758,6 +788,8 @@ x = 1 class C: def g(self): ... [out] +__all__ = ['f', 'x', 'C', 'g'] + def f(): ... x: int @@ -2343,6 +2375,8 @@ else: [out] import cookielib as cookielib +__all__ = ['cookielib'] + [case testCannotCalculateMRO_semanal] class X: pass @@ -2788,6 +2822,8 @@ class A: pass # p/__init__.pyi from p.a import A +__all__ = ['a'] + a: A # p/a.pyi class A: ... @@ -2961,7 +2997,9 @@ __uri__ = '' __version__ = '' [out] -from m import __version__ as __version__ +from m import __about__ as __about__, __author__ as __author__, __version__ as __version__ + +__all__ = ['__about__', '__author__', '__version__'] [case testAttrsClass_semanal] import attrs