Skip to content

Commit

Permalink
stubgen: include __all__ in output (#16356)
Browse files Browse the repository at this point in the history
Fixes #10314
  • Loading branch information
JelleZijlstra authored Oct 30, 2023
1 parent 128176a commit 0ff7a29
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 19 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

## Unreleased

...
Stubgen will now include `__all__` in its output if it is in the input file (PR [16356](https://github.com/python/mypy/pull/16356)).

#### Other Notable Changes and Fixes
...
Expand Down
56 changes: 39 additions & 17 deletions mypy/stubutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,10 +614,24 @@ def get_imports(self) -> str:

def output(self) -> str:
"""Return the text for the stub."""
imports = self.get_imports()
if imports and self._output:
imports += "\n"
return imports + "".join(self._output)
pieces: list[str] = []
if imports := self.get_imports():
pieces.append(imports)
if dunder_all := self.get_dunder_all():
pieces.append(dunder_all)
if self._output:
pieces.append("".join(self._output))
return "\n".join(pieces)

def get_dunder_all(self) -> str:
"""Return the __all__ list for the stub."""
if self._all_:
# Note we emit all names in the runtime __all__ here, even if they
# don't actually exist. If that happens, the runtime has a bug, and
# it's not obvious what the correct behavior should be. We choose
# to reflect the runtime __all__ as closely as possible.
return f"__all__ = {self._all_!r}\n"
return ""

def add(self, string: str) -> None:
"""Add text to generated stub."""
Expand Down Expand Up @@ -651,8 +665,7 @@ def set_defined_names(self, defined_names: set[str]) -> None:
self.defined_names = defined_names
# Names in __all__ are required
for name in self._all_ or ():
if name not in self.IGNORED_DUNDERS:
self.import_tracker.reexport(name)
self.import_tracker.reexport(name)

# These are "soft" imports for objects which might appear in annotations but not have
# a corresponding import statement.
Expand Down Expand Up @@ -751,7 +764,13 @@ def is_private_name(self, name: str, fullname: str | None = None) -> bool:
return False
if name == "_":
return False
return name.startswith("_") and (not name.endswith("__") or name in self.IGNORED_DUNDERS)
if not name.startswith("_"):
return False
if self._all_ and name in self._all_:
return False
if name.startswith("__") and name.endswith("__"):
return name in self.IGNORED_DUNDERS
return True

def should_reexport(self, name: str, full_module: str, name_is_alias: bool) -> bool:
if (
Expand All @@ -761,18 +780,21 @@ def should_reexport(self, name: str, full_module: str, name_is_alias: bool) -> b
):
# Special case certain names that should be exported, against our general rules.
return True
if name_is_alias:
return False
if self.export_less:
return False
if not self.module_name:
return False
is_private = self.is_private_name(name, full_module + "." + name)
if is_private:
return False
top_level = full_module.split(".")[0]
self_top_level = self.module_name.split(".", 1)[0]
if (
not name_is_alias
and not self.export_less
and (not self._all_ or name in self.IGNORED_DUNDERS)
and self.module_name
and not is_private
and top_level in (self_top_level, "_" + self_top_level)
):
if top_level not in (self_top_level, "_" + self_top_level):
# Export imports from the same package, since we can't reliably tell whether they
# are part of the public API.
return True
return False
return False
if self._all_:
return name in self._all_
return True
40 changes: 39 additions & 1 deletion test-data/unit/stubgen.test
Original file line number Diff line number Diff line change
Expand Up @@ -587,20 +587,26 @@ __all__ = [] + ['f']
def f(): ...
def g(): ...
[out]
__all__ = ['f']

def f() -> None: ...

[case testOmitDefsNotInAll_semanal]
__all__ = ['f']
def f(): ...
def g(): ...
[out]
__all__ = ['f']

def f() -> None: ...

[case testOmitDefsNotInAll_inspect]
__all__ = [] + ['f']
def f(): ...
def g(): ...
[out]
__all__ = ['f']

def f(): ...

[case testVarDefsNotInAll_import]
Expand All @@ -610,6 +616,8 @@ x = 1
y = 1
def g(): ...
[out]
__all__ = ['f', 'g']

def f() -> None: ...
def g() -> None: ...

Expand All @@ -620,6 +628,8 @@ x = 1
y = 1
def g(): ...
[out]
__all__ = ['f', 'g']

def f(): ...
def g(): ...

Expand All @@ -628,6 +638,8 @@ __all__ = [] + ['f']
def f(): ...
class A: ...
[out]
__all__ = ['f']

def f() -> None: ...

class A: ...
Expand All @@ -637,6 +649,8 @@ __all__ = [] + ['f']
def f(): ...
class A: ...
[out]
__all__ = ['f']

def f(): ...

class A: ...
Expand All @@ -647,6 +661,8 @@ class A:
x = 1
def f(self): ...
[out]
__all__ = ['A']

class A:
x: int
def f(self) -> None: ...
Expand Down Expand Up @@ -684,6 +700,8 @@ x = 1
[out]
from re import match as match, sub as sub

__all__ = ['match', 'sub', 'x']

x: int

[case testExportModule_import]
Expand All @@ -694,6 +712,8 @@ y = 2
[out]
import re as re

__all__ = ['re', 'x']

x: int

[case testExportModule2_import]
Expand All @@ -704,6 +724,8 @@ y = 2
[out]
import re as re

__all__ = ['re', 'x']

x: int

[case testExportModuleAs_import]
Expand All @@ -714,6 +736,8 @@ y = 2
[out]
import re as rex

__all__ = ['rex', 'x']

x: int

[case testExportModuleInPackage_import]
Expand All @@ -722,13 +746,17 @@ __all__ = ['p']
[out]
import urllib.parse as p

__all__ = ['p']

[case testExportPackageOfAModule_import]
import urllib.parse
__all__ = ['urllib']

[out]
import urllib as urllib

__all__ = ['urllib']

[case testRelativeImportAll]
from .x import *
[out]
Expand All @@ -741,6 +769,8 @@ x = 1
class C:
def g(self): ...
[out]
__all__ = ['f', 'x', 'C', 'g']

def f() -> None: ...

x: int
Expand All @@ -758,6 +788,8 @@ x = 1
class C:
def g(self): ...
[out]
__all__ = ['f', 'x', 'C', 'g']

def f(): ...

x: int
Expand Down Expand Up @@ -2343,6 +2375,8 @@ else:
[out]
import cookielib as cookielib

__all__ = ['cookielib']

[case testCannotCalculateMRO_semanal]
class X: pass

Expand Down Expand Up @@ -2788,6 +2822,8 @@ class A: pass
# p/__init__.pyi
from p.a import A

__all__ = ['a']

a: A
# p/a.pyi
class A: ...
Expand Down Expand Up @@ -2961,7 +2997,9 @@ __uri__ = ''
__version__ = ''

[out]
from m import __version__ as __version__
from m import __about__ as __about__, __author__ as __author__, __version__ as __version__

__all__ = ['__about__', '__author__', '__version__']

[case testAttrsClass_semanal]
import attrs
Expand Down

0 comments on commit 0ff7a29

Please sign in to comment.