Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

stubgen: Replace obsolete typing aliases with builtin containers #16780

Merged
merged 3 commits into from
Feb 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 33 additions & 18 deletions mypy/stubgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
import os.path
import sys
import traceback
from typing import Final, Iterable
from typing import Final, Iterable, Iterator

import mypy.build
import mypy.mixedtraverser
Expand Down Expand Up @@ -114,6 +114,7 @@
from mypy.stubdoc import ArgSig, FunctionSig
from mypy.stubgenc import InspectionStubGenerator, generate_stub_for_c_module
from mypy.stubutil import (
TYPING_BUILTIN_REPLACEMENTS,
BaseStubGenerator,
CantImport,
ClassInfo,
Expand Down Expand Up @@ -289,20 +290,19 @@ def visit_call_expr(self, node: CallExpr) -> str:
raise ValueError(f"Unknown argument kind {kind} in call")
return f"{callee}({', '.join(args)})"

def _visit_ref_expr(self, node: NameExpr | MemberExpr) -> str:
fullname = self.stubgen.get_fullname(node)
if fullname in TYPING_BUILTIN_REPLACEMENTS:
return self.stubgen.add_name(TYPING_BUILTIN_REPLACEMENTS[fullname], require=False)
qualname = get_qualified_name(node)
self.stubgen.import_tracker.require_name(qualname)
return qualname

def visit_name_expr(self, node: NameExpr) -> str:
self.stubgen.import_tracker.require_name(node.name)
return node.name
return self._visit_ref_expr(node)

def visit_member_expr(self, o: MemberExpr) -> str:
node: Expression = o
trailer = ""
while isinstance(node, MemberExpr):
trailer = "." + node.name + trailer
node = node.expr
if not isinstance(node, NameExpr):
return ERROR_MARKER
self.stubgen.import_tracker.require_name(node.name)
return node.name + trailer
return self._visit_ref_expr(o)

def visit_str_expr(self, node: StrExpr) -> str:
return repr(node.value)
Expand Down Expand Up @@ -351,11 +351,17 @@ def find_defined_names(file: MypyFile) -> set[str]:
return finder.names


def get_assigned_names(lvalues: Iterable[Expression]) -> Iterator[str]:
for lvalue in lvalues:
if isinstance(lvalue, NameExpr):
yield lvalue.name
elif isinstance(lvalue, TupleExpr):
yield from get_assigned_names(lvalue.items)


class DefinitionFinder(mypy.traverser.TraverserVisitor):
"""Find names of things defined at the top level of a module."""

# TODO: Assignment statements etc.

def __init__(self) -> None:
# Short names of things defined at the top level.
self.names: set[str] = set()
Expand All @@ -368,6 +374,10 @@ def visit_func_def(self, o: FuncDef) -> None:
# Don't recurse, as we only keep track of top-level definitions.
self.names.add(o.name)

def visit_assignment_stmt(self, o: AssignmentStmt) -> None:
for name in get_assigned_names(o.lvalues):
self.names.add(name)


def find_referenced_names(file: MypyFile) -> set[str]:
finder = ReferenceFinder()
Expand Down Expand Up @@ -1023,10 +1033,15 @@ def is_alias_expression(self, expr: Expression, top_level: bool = True) -> bool:
and isinstance(expr.node, (FuncDef, Decorator, MypyFile))
or isinstance(expr.node, TypeInfo)
) and not self.is_private_member(expr.node.fullname)
elif (
isinstance(expr, IndexExpr)
and isinstance(expr.base, NameExpr)
and not self.is_private_name(expr.base.name)
elif isinstance(expr, IndexExpr) and (
(isinstance(expr.base, NameExpr) and not self.is_private_name(expr.base.name))
or ( # Also some known aliases that could be member expression
isinstance(expr.base, MemberExpr)
and not self.is_private_member(get_qualified_name(expr.base))
and self.get_fullname(expr.base).startswith(
("builtins.", "typing.", "typing_extensions.", "collections.abc.")
)
)
):
if isinstance(expr.index, TupleExpr):
indices = expr.index.items
Expand Down
33 changes: 29 additions & 4 deletions mypy/stubutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,26 @@
# Modules that may fail when imported, or that may have side effects (fully qualified).
NOT_IMPORTABLE_MODULES = ()

# Typing constructs to be replaced by their builtin equivalents.
TYPING_BUILTIN_REPLACEMENTS: Final = {
# From typing
"typing.Text": "builtins.str",
"typing.Tuple": "builtins.tuple",
"typing.List": "builtins.list",
"typing.Dict": "builtins.dict",
"typing.Set": "builtins.set",
"typing.FrozenSet": "builtins.frozenset",
"typing.Type": "builtins.type",
# From typing_extensions
"typing_extensions.Text": "builtins.str",
"typing_extensions.Tuple": "builtins.tuple",
"typing_extensions.List": "builtins.list",
"typing_extensions.Dict": "builtins.dict",
"typing_extensions.Set": "builtins.set",
"typing_extensions.FrozenSet": "builtins.frozenset",
"typing_extensions.Type": "builtins.type",
}


class CantImport(Exception):
def __init__(self, module: str, message: str) -> None:
Expand Down Expand Up @@ -229,6 +249,8 @@ def visit_unbound_type(self, t: UnboundType) -> str:
return " | ".join([item.accept(self) for item in t.args])
if fullname == "typing.Optional":
return f"{t.args[0].accept(self)} | None"
if fullname in TYPING_BUILTIN_REPLACEMENTS:
s = self.stubgen.add_name(TYPING_BUILTIN_REPLACEMENTS[fullname], require=True)
if self.known_modules is not None and "." in s:
# see if this object is from any of the modules that we're currently processing.
# reverse sort so that subpackages come before parents: e.g. "foo.bar" before "foo".
Expand Down Expand Up @@ -476,7 +498,7 @@ def reexport(self, name: str) -> None:
def import_lines(self) -> list[str]:
"""The list of required import lines (as strings with python code).

In order for a module be included in this output, an indentifier must be both
In order for a module be included in this output, an identifier must be both
'required' via require_name() and 'imported' via add_import_from()
or add_import()
"""
Expand Down Expand Up @@ -585,9 +607,9 @@ def __init__(
# a corresponding import statement.
self.known_imports = {
"_typeshed": ["Incomplete"],
"typing": ["Any", "TypeVar", "NamedTuple"],
"typing": ["Any", "TypeVar", "NamedTuple", "TypedDict"],
"collections.abc": ["Generator"],
"typing_extensions": ["TypedDict", "ParamSpec", "TypeVarTuple"],
"typing_extensions": ["ParamSpec", "TypeVarTuple"],
}

def get_sig_generators(self) -> list[SignatureGenerator]:
Expand All @@ -613,7 +635,10 @@ def add_name(self, fullname: str, require: bool = True) -> str:
"""
module, name = fullname.rsplit(".", 1)
alias = "_" + name if name in self.defined_names else None
self.import_tracker.add_import_from(module, [(name, alias)], require=require)
while alias in self.defined_names:
alias = "_" + alias
if module != "builtins" or alias: # don't import from builtins unless needed
self.import_tracker.add_import_from(module, [(name, alias)], require=require)
return alias or name

def add_import_line(self, line: str) -> None:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
from . import demo as demo
from typing import List, Tuple, overload
from typing import overload

class StaticMethods:
def __init__(self, *args, **kwargs) -> None: ...
Expand All @@ -22,6 +22,6 @@ class TestStruct:

def func_incomplete_signature(*args, **kwargs): ...
def func_returning_optional() -> int | None: ...
def func_returning_pair() -> Tuple[int, float]: ...
def func_returning_pair() -> tuple[int, float]: ...
def func_returning_path() -> os.PathLike: ...
def func_returning_vector() -> List[float]: ...
def func_returning_vector() -> list[float]: ...
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import ClassVar, List, overload
from typing import ClassVar, overload

PI: float
__version__: str
Expand Down Expand Up @@ -47,7 +47,7 @@ class Point:
def __init__(self) -> None: ...
@overload
def __init__(self, x: float, y: float) -> None: ...
def as_list(self) -> List[float]: ...
def as_list(self) -> list[float]: ...
@overload
def distance_to(self, x: float, y: float) -> float: ...
@overload
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
from . import demo as demo
from typing import List, Tuple, overload
from typing import overload

class StaticMethods:
def __init__(self, *args, **kwargs) -> None:
Expand Down Expand Up @@ -44,9 +44,9 @@ def func_incomplete_signature(*args, **kwargs):
"""func_incomplete_signature() -> dummy_sub_namespace::HasNoBinding"""
def func_returning_optional() -> int | None:
"""func_returning_optional() -> Optional[int]"""
def func_returning_pair() -> Tuple[int, float]:
def func_returning_pair() -> tuple[int, float]:
"""func_returning_pair() -> Tuple[int, float]"""
def func_returning_path() -> os.PathLike:
"""func_returning_path() -> os.PathLike"""
def func_returning_vector() -> List[float]:
def func_returning_vector() -> list[float]:
"""func_returning_vector() -> List[float]"""
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import ClassVar, List, overload
from typing import ClassVar, overload

PI: float
__version__: str
Expand Down Expand Up @@ -73,7 +73,7 @@ class Point:

2. __init__(self: pybind11_fixtures.demo.Point, x: float, y: float) -> None
"""
def as_list(self) -> List[float]:
def as_list(self) -> list[float]:
"""as_list(self: pybind11_fixtures.demo.Point) -> List[float]"""
@overload
def distance_to(self, x: float, y: float) -> float:
Expand Down
Loading
Loading