Skip to content

Commit

Permalink
stubgen: Replace obsolete typing aliases with builtin containers (pyt…
Browse files Browse the repository at this point in the history
…hon#16780)

Addresses part of python#16737

This only replaces typing symbols that have equivalents in the
`builtins` module. Replacing other symbols, like those from the
`collections.abc` module, are a bit more complicated so I suggest we
handle them separately.

I also changed the default `TypedDict` module from `typing_extensions`
to `typing` as typeshed dropped support for Python 3.7.
  • Loading branch information
hamdanal committed Feb 20, 2024
1 parent 15a7282 commit 1ee3e0b
Show file tree
Hide file tree
Showing 7 changed files with 148 additions and 50 deletions.
51 changes: 33 additions & 18 deletions mypy/stubgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
import os.path
import sys
import traceback
from typing import Final, Iterable
from typing import Final, Iterable, Iterator

import mypy.build
import mypy.mixedtraverser
Expand Down Expand Up @@ -114,6 +114,7 @@
from mypy.stubdoc import ArgSig, FunctionSig
from mypy.stubgenc import InspectionStubGenerator, generate_stub_for_c_module
from mypy.stubutil import (
TYPING_BUILTIN_REPLACEMENTS,
BaseStubGenerator,
CantImport,
ClassInfo,
Expand Down Expand Up @@ -289,20 +290,19 @@ def visit_call_expr(self, node: CallExpr) -> str:
raise ValueError(f"Unknown argument kind {kind} in call")
return f"{callee}({', '.join(args)})"

def _visit_ref_expr(self, node: NameExpr | MemberExpr) -> str:
fullname = self.stubgen.get_fullname(node)
if fullname in TYPING_BUILTIN_REPLACEMENTS:
return self.stubgen.add_name(TYPING_BUILTIN_REPLACEMENTS[fullname], require=False)
qualname = get_qualified_name(node)
self.stubgen.import_tracker.require_name(qualname)
return qualname

def visit_name_expr(self, node: NameExpr) -> str:
self.stubgen.import_tracker.require_name(node.name)
return node.name
return self._visit_ref_expr(node)

def visit_member_expr(self, o: MemberExpr) -> str:
node: Expression = o
trailer = ""
while isinstance(node, MemberExpr):
trailer = "." + node.name + trailer
node = node.expr
if not isinstance(node, NameExpr):
return ERROR_MARKER
self.stubgen.import_tracker.require_name(node.name)
return node.name + trailer
return self._visit_ref_expr(o)

def visit_str_expr(self, node: StrExpr) -> str:
return repr(node.value)
Expand Down Expand Up @@ -351,11 +351,17 @@ def find_defined_names(file: MypyFile) -> set[str]:
return finder.names


def get_assigned_names(lvalues: Iterable[Expression]) -> Iterator[str]:
for lvalue in lvalues:
if isinstance(lvalue, NameExpr):
yield lvalue.name
elif isinstance(lvalue, TupleExpr):
yield from get_assigned_names(lvalue.items)


class DefinitionFinder(mypy.traverser.TraverserVisitor):
"""Find names of things defined at the top level of a module."""

# TODO: Assignment statements etc.

def __init__(self) -> None:
# Short names of things defined at the top level.
self.names: set[str] = set()
Expand All @@ -368,6 +374,10 @@ def visit_func_def(self, o: FuncDef) -> None:
# Don't recurse, as we only keep track of top-level definitions.
self.names.add(o.name)

def visit_assignment_stmt(self, o: AssignmentStmt) -> None:
for name in get_assigned_names(o.lvalues):
self.names.add(name)


def find_referenced_names(file: MypyFile) -> set[str]:
finder = ReferenceFinder()
Expand Down Expand Up @@ -1023,10 +1033,15 @@ def is_alias_expression(self, expr: Expression, top_level: bool = True) -> bool:
and isinstance(expr.node, (FuncDef, Decorator, MypyFile))
or isinstance(expr.node, TypeInfo)
) and not self.is_private_member(expr.node.fullname)
elif (
isinstance(expr, IndexExpr)
and isinstance(expr.base, NameExpr)
and not self.is_private_name(expr.base.name)
elif isinstance(expr, IndexExpr) and (
(isinstance(expr.base, NameExpr) and not self.is_private_name(expr.base.name))
or ( # Also some known aliases that could be member expression
isinstance(expr.base, MemberExpr)
and not self.is_private_member(get_qualified_name(expr.base))
and self.get_fullname(expr.base).startswith(
("builtins.", "typing.", "typing_extensions.", "collections.abc.")
)
)
):
if isinstance(expr.index, TupleExpr):
indices = expr.index.items
Expand Down
33 changes: 29 additions & 4 deletions mypy/stubutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,26 @@
# Modules that may fail when imported, or that may have side effects (fully qualified).
NOT_IMPORTABLE_MODULES = ()

# Typing constructs to be replaced by their builtin equivalents.
TYPING_BUILTIN_REPLACEMENTS: Final = {
# From typing
"typing.Text": "builtins.str",
"typing.Tuple": "builtins.tuple",
"typing.List": "builtins.list",
"typing.Dict": "builtins.dict",
"typing.Set": "builtins.set",
"typing.FrozenSet": "builtins.frozenset",
"typing.Type": "builtins.type",
# From typing_extensions
"typing_extensions.Text": "builtins.str",
"typing_extensions.Tuple": "builtins.tuple",
"typing_extensions.List": "builtins.list",
"typing_extensions.Dict": "builtins.dict",
"typing_extensions.Set": "builtins.set",
"typing_extensions.FrozenSet": "builtins.frozenset",
"typing_extensions.Type": "builtins.type",
}


class CantImport(Exception):
def __init__(self, module: str, message: str) -> None:
Expand Down Expand Up @@ -229,6 +249,8 @@ def visit_unbound_type(self, t: UnboundType) -> str:
return " | ".join([item.accept(self) for item in t.args])
if fullname == "typing.Optional":
return f"{t.args[0].accept(self)} | None"
if fullname in TYPING_BUILTIN_REPLACEMENTS:
s = self.stubgen.add_name(TYPING_BUILTIN_REPLACEMENTS[fullname], require=True)
if self.known_modules is not None and "." in s:
# see if this object is from any of the modules that we're currently processing.
# reverse sort so that subpackages come before parents: e.g. "foo.bar" before "foo".
Expand Down Expand Up @@ -476,7 +498,7 @@ def reexport(self, name: str) -> None:
def import_lines(self) -> list[str]:
"""The list of required import lines (as strings with python code).
In order for a module be included in this output, an indentifier must be both
In order for a module be included in this output, an identifier must be both
'required' via require_name() and 'imported' via add_import_from()
or add_import()
"""
Expand Down Expand Up @@ -585,9 +607,9 @@ def __init__(
# a corresponding import statement.
self.known_imports = {
"_typeshed": ["Incomplete"],
"typing": ["Any", "TypeVar", "NamedTuple"],
"typing": ["Any", "TypeVar", "NamedTuple", "TypedDict"],
"collections.abc": ["Generator"],
"typing_extensions": ["TypedDict", "ParamSpec", "TypeVarTuple"],
"typing_extensions": ["ParamSpec", "TypeVarTuple"],
}

def get_sig_generators(self) -> list[SignatureGenerator]:
Expand All @@ -613,7 +635,10 @@ def add_name(self, fullname: str, require: bool = True) -> str:
"""
module, name = fullname.rsplit(".", 1)
alias = "_" + name if name in self.defined_names else None
self.import_tracker.add_import_from(module, [(name, alias)], require=require)
while alias in self.defined_names:
alias = "_" + alias
if module != "builtins" or alias: # don't import from builtins unless needed
self.import_tracker.add_import_from(module, [(name, alias)], require=require)
return alias or name

def add_import_line(self, line: str) -> None:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
from . import demo as demo
from typing import List, Tuple, overload
from typing import overload

class StaticMethods:
def __init__(self, *args, **kwargs) -> None: ...
Expand All @@ -22,6 +22,6 @@ class TestStruct:

def func_incomplete_signature(*args, **kwargs): ...
def func_returning_optional() -> int | None: ...
def func_returning_pair() -> Tuple[int, float]: ...
def func_returning_pair() -> tuple[int, float]: ...
def func_returning_path() -> os.PathLike: ...
def func_returning_vector() -> List[float]: ...
def func_returning_vector() -> list[float]: ...
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import ClassVar, List, overload
from typing import ClassVar, overload

PI: float
__version__: str
Expand Down Expand Up @@ -47,7 +47,7 @@ class Point:
def __init__(self) -> None: ...
@overload
def __init__(self, x: float, y: float) -> None: ...
def as_list(self) -> List[float]: ...
def as_list(self) -> list[float]: ...
@overload
def distance_to(self, x: float, y: float) -> float: ...
@overload
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
from . import demo as demo
from typing import List, Tuple, overload
from typing import overload

class StaticMethods:
def __init__(self, *args, **kwargs) -> None:
Expand Down Expand Up @@ -44,9 +44,9 @@ def func_incomplete_signature(*args, **kwargs):
"""func_incomplete_signature() -> dummy_sub_namespace::HasNoBinding"""
def func_returning_optional() -> int | None:
"""func_returning_optional() -> Optional[int]"""
def func_returning_pair() -> Tuple[int, float]:
def func_returning_pair() -> tuple[int, float]:
"""func_returning_pair() -> Tuple[int, float]"""
def func_returning_path() -> os.PathLike:
"""func_returning_path() -> os.PathLike"""
def func_returning_vector() -> List[float]:
def func_returning_vector() -> list[float]:
"""func_returning_vector() -> List[float]"""
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import ClassVar, List, overload
from typing import ClassVar, overload

PI: float
__version__: str
Expand Down Expand Up @@ -73,7 +73,7 @@ class Point:
2. __init__(self: pybind11_fixtures.demo.Point, x: float, y: float) -> None
"""
def as_list(self) -> List[float]:
def as_list(self) -> list[float]:
"""as_list(self: pybind11_fixtures.demo.Point) -> List[float]"""
@overload
def distance_to(self, x: float, y: float) -> float:
Expand Down
Loading

0 comments on commit 1ee3e0b

Please sign in to comment.