diff --git a/examples/example_pkg-stubs/_basic.pyi b/examples/example_pkg-stubs/_basic.pyi index 64e53e5..17fd686 100644 --- a/examples/example_pkg-stubs/_basic.pyi +++ b/examples/example_pkg-stubs/_basic.pyi @@ -4,9 +4,11 @@ import logging from collections.abc import Sequence from typing import Any, Literal, Self, Union +from _typeshed import Incomplete + from . import CustomException -logger = ... +logger: Incomplete __all__ = [ "func_empty", diff --git a/src/docstub/_analysis.py b/src/docstub/_analysis.py index d4340ac..8f18bdb 100644 --- a/src/docstub/_analysis.py +++ b/src/docstub/_analysis.py @@ -7,9 +7,11 @@ import re import typing from dataclasses import asdict, dataclass +from functools import cache from pathlib import Path import libcst as cst +import libcst.matchers as cstm from ._utils import accumulate_qualname, module_name_from_path, pyfile_checksum @@ -45,13 +47,13 @@ class KnownImport: Attributes ---------- - import_path : str, optional + import_path Dotted names after "from". - import_name : str, optional + import_name Dotted names after "import". - import_alias : str, optional + import_alias Name (without ".") after "as". - builtin_name : str, optional + builtin_name Names an object that's builtin and doesn't need an import. Examples @@ -65,6 +67,26 @@ class KnownImport: import_alias: str = None builtin_name: str = None + @classmethod + @cache + def typeshed_Incomplete(cls): + """Create import corresponding to ``from _typeshed import Incomplete``. + + This type is not actually available at runtime and only intended to be + used in stub files [1]_. + + Returns + ------- + import : KnownImport + The import corresponding to ``from _typeshed import Incomplete``. + + References + ---------- + .. [1] https://typing.readthedocs.io/en/latest/guides/writing_stubs.html#incomplete-stubs + """ + import_ = cls(import_path="_typeshed", import_name="Incomplete") + return import_ + @classmethod def one_from_config(cls, name, *, info): """Create one KnownImport from the configuration format. @@ -327,23 +349,47 @@ def __init__(self, *, module_name): def visit_ClassDef(self, node: cst.ClassDef) -> bool: self._stack.append(node.name.value) - - class_name = ".".join(self._stack[:1]) - qualname = f"{self.module_name}.{'.'.join(self._stack)}" - known_import = KnownImport(import_path=self.module_name, import_name=class_name) - self.known_imports[qualname] = known_import - + self._collect_type_annotation(self._stack) return True def leave_ClassDef(self, original_node: cst.ClassDef) -> None: self._stack.pop() def visit_FunctionDef(self, node: cst.FunctionDef) -> bool: - self._stack.append(node.name.value) - return True + return False + + def visit_TypeAlias(self, node: cst.TypeAlias) -> bool: + """Collect type alias with 3.12 syntax.""" + stack = [*self._stack, node.name.value] + self._collect_type_annotation(stack) + return False + + def visit_AnnAssign(self, node: cst.AnnAssign) -> bool: + """Collect type alias annotated with `TypeAlias`.""" + is_type_alias = cstm.matches( + node, + cstm.AnnAssign( + annotation=cstm.Annotation(annotation=cstm.Name(value="TypeAlias")) + ), + ) + if is_type_alias and node.value is not None: + names = cstm.findall(node.target, cstm.Name()) + assert len(names) == 1 + stack = [*self._stack, names[0].value] + self._collect_type_annotation(stack) + return False - def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None: - self._stack.pop() + def _collect_type_annotation(self, stack): + """Collect an importable type annotation. + + Parameters + ---------- + stack : Iterable[str] + A list of names that form the path to the collected type. + """ + qualname = ".".join([self.module_name, *stack]) + known_import = KnownImport(import_path=self.module_name, import_name=stack[0]) + self.known_imports[qualname] = known_import class TypesDatabase: diff --git a/src/docstub/_docstrings.py b/src/docstub/_docstrings.py index e4d26b8..6eb74f5 100644 --- a/src/docstub/_docstrings.py +++ b/src/docstub/_docstrings.py @@ -13,7 +13,7 @@ from numpydoc.docscrape import NumpyDocString from ._analysis import KnownImport -from ._utils import ContextFormatter, accumulate_qualname, escape_qualname +from ._utils import ContextFormatter, DocstubError, accumulate_qualname, escape_qualname logger = logging.getLogger(__name__) @@ -135,16 +135,29 @@ def _aggregate_annotations(*types): return values, imports -GrammarErrorFallback = Annotation( - value="Any", - imports=frozenset((KnownImport(import_path="typing", import_name="Any"),)), +FallbackAnnotation = Annotation( + value="Incomplete", imports=frozenset([KnownImport.typeshed_Incomplete()]) ) +class QualnameIsKeyword(DocstubError): + """Raised when a qualname is a blacklisted Python keyword.""" + + @lark.visitors.v_args(tree=True) class DoctypeTransformer(lark.visitors.Transformer): """Transformer for docstring type descriptions (doctypes). + Attributes + ---------- + blacklisted_qualnames : frozenset[str] + All Python keywords [1]_ are blacklisted from use in qualnames except for ``True`` + ``False`` and ``None``. + + References + ---------- + .. [1] https://docs.python.org/3/reference/lexical_analysis.html#keywords + Examples -------- >>> transformer = DoctypeTransformer() @@ -155,6 +168,43 @@ class DoctypeTransformer(lark.visitors.Transformer): [('tuple', 0, 5), ('int', 9, 12)] """ + blacklisted_qualnames = frozenset( + { + "await", + "else", + "import", + "pass", + "break", + "except", + "in", + "raise", + "class", + "finally", + "is", + "return", + "and", + "continue", + "for", + "lambda", + "try", + "as", + "def", + "from", + "nonlocal", + "while", + "assert", + "del", + "global", + "not", + "with", + "async", + "elif", + "if", + "or", + "yield", + } + ) + def __init__(self, *, types_db=None, replace_doctypes=None, **kwargs): """ Parameters @@ -204,7 +254,11 @@ def doctype_to_annotation(self, doctype): value=value, imports=frozenset(self._collected_imports) ) return annotation, self._unknown_qualnames - except (lark.exceptions.LexError, lark.exceptions.ParseError): + except ( + lark.exceptions.LexError, + lark.exceptions.ParseError, + QualnameIsKeyword, + ): self.stats["grammar_errors"] += 1 raise finally: @@ -274,6 +328,13 @@ def qualname(self, tree): _qualname = self._find_import(_qualname, meta=tree.meta) + if _qualname in self.blacklisted_qualnames: + msg = ( + f"qualname {_qualname!r} in docstring type description " + "is a reserved Python keyword and not allowed" + ) + raise QualnameIsKeyword(msg) + _qualname = lark.Token(type="QUALNAME", value=_qualname) return _qualname @@ -399,7 +460,7 @@ def _doctype_to_annotation(self, doctype, ds_line=0): details = details.replace("^", click.style("^", fg="red", bold=True)) if ctx: ctx.print_message("invalid syntax in doctype", details=details) - return GrammarErrorFallback + return FallbackAnnotation except lark.visitors.VisitError as e: tb = "\n".join(traceback.format_exception(e.orig_exc)) @@ -408,7 +469,7 @@ def _doctype_to_annotation(self, doctype, ds_line=0): ctx.print_message( "unexpected error while parsing doctype", details=details ) - return GrammarErrorFallback + return FallbackAnnotation else: for name, start_col, stop_col in unknown_qualnames: @@ -421,6 +482,28 @@ def _doctype_to_annotation(self, doctype, ds_line=0): ) return annotation + @cached_property + def attributes(self) -> dict[str, Annotation]: + annotations = {} + for attribute in self.np_docstring["Attributes"]: + if not attribute.type: + continue + + ds_line = 0 + for i, line in enumerate(self.docstring.split("\n")): + if attribute.name in line and attribute.type in line: + ds_line = i + break + + if attribute.name in annotations: + logger.warning("duplicate parameter name %r, ignoring", attribute.name) + continue + + annotation = self._doctype_to_annotation(attribute.type, ds_line=ds_line) + annotations[attribute.name] = annotation + + return annotations + @cached_property def parameters(self) -> dict[str, Annotation]: all_params = chain( diff --git a/src/docstub/_stubs.py b/src/docstub/_stubs.py index 3951a77..975338d 100644 --- a/src/docstub/_stubs.py +++ b/src/docstub/_stubs.py @@ -3,6 +3,7 @@ import enum import logging from dataclasses import dataclass +from functools import wraps import libcst as cst import libcst.matchers as cstm @@ -106,7 +107,7 @@ def try_format_stub(stub: str) -> str: return stub -class FuncType(enum.StrEnum): +class ScopeType(enum.StrEnum): MODULE = enum.auto() CLASS = enum.auto() FUNC = enum.auto() @@ -119,19 +120,19 @@ class FuncType(enum.StrEnum): class _Scope: """""" - type: FuncType + type: ScopeType node: cst.CSTNode = None @property def has_self_or_cls(self): - return self.type in {FuncType.METHOD, FuncType.CLASSMETHOD} + return self.type in {ScopeType.METHOD, ScopeType.CLASSMETHOD} @property def is_method(self): return self.type in { - FuncType.METHOD, - FuncType.CLASSMETHOD, - FuncType.STATICMETHOD, + ScopeType.METHOD, + ScopeType.CLASSMETHOD, + ScopeType.STATICMETHOD, } @property @@ -174,12 +175,47 @@ def _get_docstring_node(node): return docstring_node +def _log_error_with_line_context(func): + """Log unexpected errors in Py2StubTransformer` with line context. + + Parameters + ---------- + func : callable + A `leave_*` method of `Py2StubTransformer`. + + Returns + ------- + wrapped : callable + """ + + @wraps(func) + def wrapped(self: "Py2StubTransformer", original_node, updated_node): + try: + return func(self, original_node, updated_node) + except (SystemError, KeyboardInterrupt): + raise + except Exception: + position = self.get_metadata( + cst.metadata.PositionProvider, original_node + ).start + logger.exception( + "unexpected exception at %s:%s", self.current_source, position.line + ) + return updated_node + + return wrapped + + class Py2StubTransformer(cst.CSTTransformer): - """Transform syntax tree of a Python file into the tree of a stub file. + """Transform syntax tree of a Python file into the tree of a stub file [1]_. Attributes ---------- types_db : ~.TypesDatabase + + References + ---------- + .. [1] Stub file specification https://typing.readthedocs.io/en/latest/spec/distributing.html#stub-files """ METADATA_DEPENDENCIES = (cst.metadata.PositionProvider,) @@ -195,7 +231,7 @@ class Py2StubTransformer(cst.CSTTransformer): leading_whitespace=cst.SimpleWhitespace(value=" "), body=[cst.Expr(value=cst.Ellipsis())], ) - _Annotation_Any = cst.Annotation(cst.Name("Any")) + _Annotation_Incomplete = cst.Annotation(cst.Name("Incomplete")) _Annotation_None = cst.Annotation(cst.Name("None")) def __init__(self, *, types_db=None, replace_doctypes=None): @@ -228,7 +264,7 @@ def current_source(self, value): if self.types_db is not None: self.types_db.current_source = value - def python_to_stub(self, source, *, module_path=None): + def python_to_stub(self, source, *, module_path=None, try_format=True): """Convert Python source code to stub-file ready code. Parameters @@ -238,6 +274,9 @@ def python_to_stub(self, source, *, module_path=None): The location of the source that is transformed into a stub file. If given, used to enhance logging & error messages with more context information. + try_format : bool, optional + Try to format the output, if the appropriate dependencies are + installed. Returns ------- @@ -253,7 +292,8 @@ def python_to_stub(self, source, *, module_path=None): source_tree = cst.metadata.MetadataWrapper(source_tree) stub_tree = source_tree.visit(self) stub = stub_tree.code - stub = try_format_stub(stub) + if try_format is True: + stub = try_format_stub(stub) return stub finally: self._scope_stack = None @@ -272,7 +312,7 @@ def visit_ClassDef(self, node): ------- out : Literal[True] """ - self._scope_stack.append(_Scope(type=FuncType.CLASS, node=node)) + self._scope_stack.append(_Scope(type=ScopeType.CLASS, node=node)) pytypes = self._annotations_from_node(node) self._pytypes_stack.append(pytypes) return True @@ -330,15 +370,17 @@ def leave_FunctionDef(self, original_node, updated_node): ds_annotations = self._pytypes_stack.pop() if ds_annotations and ds_annotations.returns: assert ds_annotations.returns.value - node_changes["returns"] = cst.Annotation( + annotation = cst.Annotation( cst.parse_expression(ds_annotations.returns.value) ) + node_changes["returns"] = annotation self._required_imports |= ds_annotations.returns.imports updated_node = updated_node.with_changes(**node_changes) self._scope_stack.pop() return updated_node + @_log_error_with_line_context def leave_Param(self, original_node, updated_node): """Add type annotation to parameter. @@ -378,10 +420,10 @@ def leave_Param(self, original_node, updated_node): if pytype.imports: self._required_imports |= pytype.imports - # Potentially use "Any" except for first param in (class)methods + # Potentially use "Incomplete" except for first param in (class)methods elif not is_self_or_cls and updated_node.annotation is None: - node_changes["annotation"] = self._Annotation_Any - import_ = KnownImport(import_path="typing", import_name="Any") + node_changes["annotation"] = self._Annotation_Incomplete + import_ = KnownImport.typeshed_Incomplete() self._required_imports.add(import_) if node_changes: @@ -416,8 +458,9 @@ def leave_Comment(self, original_node, updated_node): """ return cst.RemovalSentinel.REMOVE + @_log_error_with_line_context def leave_Assign(self, original_node, updated_node): - """Drop value of assign statements from stub files. + """Handle assignment statements without annotations. Parameters ---------- @@ -426,17 +469,89 @@ def leave_Assign(self, original_node, updated_node): Returns ------- - updated_node : cst.Assign + updated_node : cst.Assign or cst.FlattenSentinel """ - targets = cstm.findall(updated_node, cstm.AssignTarget()) - names_are__all__ = [ - name - for target in targets - for name in cstm.findall(target, cst.Name(value="__all__")) + target_names = [ + name.value + for target in updated_node.targets + for name in cstm.findall(target, cstm.Name()) ] - if not names_are__all__: - # TODO replace with AnnAssign if possible / figure out assign type? - updated_node = updated_node.with_changes(value=self._body_replacement) + if "__all__" in target_names: + if len(target_names) > 1: + logger.warning( + "found `__all__` in assignment with multiple targets, not modifying it" + ) + return updated_node + + assert len(original_node.targets) > 0 + if len(target_names) == 1: + # Replace with annotated assignment + updated_node = self._create_annotated_assign(name=target_names[0]) + + else: + # Unpack assignment with multiple targets into multiple annotated ones + # e.g. `x, y = (1, 2)` -> `x: Any = ...; y: Any = ...` + unpacked = [] + for name in target_names: + is_last = name == target_names[-1] + sub_node = self._create_annotated_assign( + name=name, trailing_semicolon=not is_last + ) + unpacked.append(sub_node) + updated_node = cst.FlattenSentinel(unpacked) + + return updated_node + + @_log_error_with_line_context + def leave_AnnAssign(self, original_node, updated_node): + """Handle annotated assignment statements. + + Parameters + ---------- + original_node : cst.AnnAssign + updated_node : cst.AnnAssign + + Returns + ------- + updated_node : cst.AnnAssign + """ + name = updated_node.target.value + is_type_alias = cstm.matches( + updated_node.annotation, cstm.Annotation(cstm.Name("TypeAlias")) + ) + is__all__ = cstm.matches(updated_node.target, cstm.Name("__all__")) + + # Remove value if not type alias or __all__ + if updated_node.value is not None and not is_type_alias and not is__all__: + updated_node = updated_node.with_changes( + value=None, equal=cst.MaybeSentinel.DEFAULT + ) + + # Replace with type annotation from docstring, if available + pytypes = self._pytypes_stack[-1] + if pytypes and name in pytypes.attributes: + pytype = pytypes.attributes[name] + expr = cst.parse_expression(pytype.value) + self._required_imports |= pytype.imports + + if updated_node.annotation is not None: + # Turn original annotation into str and print with context + position = self.get_metadata( + cst.metadata.PositionProvider, original_node + ).start + ctx = ContextFormatter(path=self.current_source, line=position.line) + replaced = cst.Module([]).code_for_node( + updated_node.annotation.annotation + ) + ctx.print_message( + short="replacing existing inline annotation", + details=f"{replaced}\n{"^" * len(replaced)} -> {pytype.value}", + ) + + updated_node = updated_node.with_deep_changes( + updated_node.annotation, annotation=expr + ) + return updated_node def visit_Module(self, node): @@ -450,7 +565,7 @@ def visit_Module(self, node): ------- Literal[True] """ - self._scope_stack.append(_Scope(type=FuncType.MODULE, node=node)) + self._scope_stack.append(_Scope(type=ScopeType.MODULE, node=node)) pytypes = self._annotations_from_node(node) self._pytypes_stack.append(pytypes) return True @@ -557,20 +672,20 @@ def _function_type(self, func_def): Returns ------- - func_type : FuncType + func_type : ScopeType """ - func_type = FuncType.FUNC - if self._scope_stack[-1].type == FuncType.CLASS: - func_type = FuncType.METHOD + func_type = ScopeType.FUNC + if self._scope_stack[-1].type == ScopeType.CLASS: + func_type = ScopeType.METHOD for decorator in func_def.decorators: if not hasattr(decorator.decorator, "value"): continue if decorator.decorator.value == "classmethod": - func_type = FuncType.CLASSMETHOD + func_type = ScopeType.CLASSMETHOD break if decorator.decorator.value == "staticmethod": - assert func_type == FuncType.METHOD - func_type = FuncType.STATICMETHOD + assert func_type == ScopeType.METHOD + func_type = ScopeType.STATICMETHOD break return func_type @@ -592,7 +707,6 @@ def _annotations_from_node(self, node): position = self.get_metadata( cst.metadata.PositionProvider, docstring_node ).start - ctx = ContextFormatter(path=self.current_source, line=position.line) try: annotations = DocstringAnnotations( @@ -609,3 +723,36 @@ def _annotations_from_node(self, node): e, ) return annotations + + def _create_annotated_assign(self, *, name, trailing_semicolon=False): + """Create an annotated assign. + + Parameters + ---------- + name : str + trailing_semicolon : bool, optional + + Returns + ------- + replacement : cst.AnnAssign + """ + pytypes = self._pytypes_stack[-1] + if pytypes and name in pytypes.attributes: + pytype = pytypes.attributes[name] + annotation = cst.Annotation(cst.parse_expression(pytype.value)) + self._required_imports |= pytype.imports + else: + annotation = self._Annotation_Incomplete + self._required_imports.add(KnownImport.typeshed_Incomplete()) + + semicolon = ( + cst.Semicolon(whitespace_after=cst.SimpleWhitespace(" ")) + if trailing_semicolon + else cst.MaybeSentinel.DEFAULT + ) + node = cst.AnnAssign( + target=cst.Name(name), + annotation=annotation, + semicolon=semicolon, + ) + return node diff --git a/src/docstub/_utils.py b/src/docstub/_utils.py index b85fb9a..329b1ba 100644 --- a/src/docstub/_utils.py +++ b/src/docstub/_utils.py @@ -261,3 +261,7 @@ def __post_init__(self): if self.path is not None and not isinstance(self.path, Path): msg = f"expected `path` to be of type `Path`, got {type(self.path)!r}" raise TypeError(msg) + + +class DocstubError(Exception): + """An error raised by docstub.""" diff --git a/tests/test_analysis.py b/tests/test_analysis.py index 8959ec4..fd74554 100644 --- a/tests/test_analysis.py +++ b/tests/test_analysis.py @@ -1,6 +1,94 @@ +from textwrap import dedent + import pytest -from docstub._analysis import KnownImport, TypesDatabase +from docstub._analysis import KnownImport, TypeCollector, TypesDatabase + + +@pytest.fixture() +def module_factory(tmp_path): + """Fixture to help with creating adhoc modules with a given source. + + Parameters + ---------- + tmp_path : Path + + Returns + ------- + module_factory : callable + A callable with the signature `(src: str, module_name: str) -> Path`. + """ + + def _module_factory(src, module_name): + *parents, name = module_name.split(".") + + cwd = tmp_path + for parent in parents: + package = cwd / parent + package.mkdir() + (package / "__init__.py").touch() + cwd = package + + module_path = cwd / f"{name}.py" + with open(module_path, "w") as fp: + fp.write(src) + + return module_path + + return _module_factory + + +class Test_TypeCollector: + + def test_classes(self, module_factory): + module_path = module_factory( + src=dedent( + """ + class TopLevelClass: + class NestedClass: + pass + """ + ), + module_name="sub.module", + ) + imports = TypeCollector.collect(file=module_path) + assert len(imports) == 2 + assert imports["sub.module.TopLevelClass"] == KnownImport( + import_path="sub.module", import_name="TopLevelClass" + ) + # The import for the nested class should still use only the top-level + # class as an import target + assert imports["sub.module.TopLevelClass.NestedClass"] == KnownImport( + import_path="sub.module", import_name="TopLevelClass" + ) + + @pytest.mark.parametrize( + "src", ["type alias_name = int", "alias_name: TypeAlias = int"] + ) + def test_type_alias(self, module_factory, src): + module_path = module_factory(src=src, module_name="sub.module") + imports = TypeCollector.collect(file=module_path) + assert len(imports) == 1 + assert imports == { + "sub.module.alias_name": KnownImport( + import_path="sub.module", import_name="alias_name" + ) + } + + @pytest.mark.parametrize( + "src", + [ + "assign_name = 3", + "assign_name: int", + "assign_name: int = 3", + "assign_name = int", # Valid type alias, but not supported (yet) + "assign_name: TypeAlias", # No value, so should be ignored as a target + ], + ) + def test_ignores_assigns(self, module_factory, src): + module_path = module_factory(src=src, module_name="sub.module") + imports = TypeCollector.collect(file=module_path) + assert len(imports) == 0 class Test_TypesDatabase: diff --git a/tests/test_stubs.py b/tests/test_stubs.py index d14b70e..5b3dba7 100644 --- a/tests/test_stubs.py +++ b/tests/test_stubs.py @@ -1,7 +1,9 @@ +import re from textwrap import dedent import libcst as cst import libcst.matchers as cstm +import pytest from docstub._stubs import Py2StubTransformer, _get_docstring_node @@ -73,6 +75,43 @@ def foo(a, b=None): assert docstring_node is None +MODULE_ATTRIBUTE_TEMPLATE = '''\ +"""Module docstring. + +Attributes +---------- +{doctype} +""" + +{assign} +''' + +CLASS_ATTRIBUTE_TEMPLATE = '''\ +class TopLevel: + """Class docstring. + + Attributes + ---------- + {doctype} + """ + + {assign} +''' + +NESTED_CLASS_ATTRIBUTE_TEMPLATE = '''\ +class TopLevel: + class Nested: + """Class docstring. + + Attributes + ---------- + {doctype} + """ + + {assign} +''' + + class Test_Py2StubTransformer: def test_default_None(self): @@ -95,3 +134,85 @@ def foo(a=None, b=1): transformer = Py2StubTransformer() result = transformer.python_to_stub(source) assert expected in result + + # fmt: off + @pytest.mark.parametrize( + ("assign", "expected"), + [ + ("annotated: int", "annotated: int"), + # No implicit optional for values of `None` + ("annotated_value: int = None", "annotated_value: int"), + ("undocumented_assign = None", "undocumented_assign: Incomplete"), + # Type aliases are untouched + ("annot_alias: TypeAlias = int", "annot_alias: TypeAlias = int"), + ("type type_stmt = int", "type type_stmt = int"), + # Unpacking assignments are expanded + ("a, b = (4, 5)", "a: Incomplete; b: Incomplete"), + ("x, *y = (4, 5)", "x: Incomplete; y: Incomplete"), + # All is untouched + ("__all__ = ['foo']", "__all__ = ['foo']"), + ("__all__: list[str] = ['foo']", "__all__: list[str] = ['foo']"), + ], + ) + @pytest.mark.parametrize("scope", ["module", "class", "nested class"]) + def test_attributes_no_doctype(self, assign, expected, scope): + if scope == "module": + src = MODULE_ATTRIBUTE_TEMPLATE.format(assign=assign, doctype="") + elif scope == "class": + src = CLASS_ATTRIBUTE_TEMPLATE.format(assign=assign, doctype="") + elif scope == "nested class": + src = NESTED_CLASS_ATTRIBUTE_TEMPLATE.format(assign=assign, doctype="") + + transformer = Py2StubTransformer() + result = transformer.python_to_stub(src, try_format=False) + + # Find exactly one occurrence of `expected` + pattern = f"^ *({re.escape(expected)})$" + matches = re.findall(pattern, result, flags=re.MULTILINE) + assert [matches] == [[expected]], result + + # Docstrings are stripped + assert "'''" not in result + assert '"""' not in result + if "Any" in result: + assert "from typing import Any" in result + # fmt: on + + # fmt: off + @pytest.mark.parametrize( + ("assign", "doctype", "expected"), + [ + ("plain = 3", "plain : int", "plain: int"), + ("plain = None", "plain : int", "plain: int"), + ("x, y = (1, 2)", "x : int", "x: int; y: Incomplete"), + # Replace pre-existing annotations + ("annotated: float = 1.0", "annotated : int", "annotated: int"), + # Type aliases are untouched + ("alias: TypeAlias = int", "alias: str", "alias: TypeAlias = int"), + ("type alias = int", "alias: str", "type alias = int"), + ], + ) + # @pytest.mark.parametrize("scope", ["module", "class", "nested class"]) + @pytest.mark.parametrize("scope", ["module"]) + def test_attributes_with_doctype(self, assign, doctype, expected, scope): + if scope == "module": + src = MODULE_ATTRIBUTE_TEMPLATE.format(assign=assign, doctype=doctype) + elif scope == "class": + src = CLASS_ATTRIBUTE_TEMPLATE.format(assign=assign, doctype=doctype) + elif scope == "nested class": + src = NESTED_CLASS_ATTRIBUTE_TEMPLATE.format(assign=assign, doctype=doctype) + + transformer = Py2StubTransformer() + result = transformer.python_to_stub(src, try_format=False) + + # Find exactly one occurrence of `expected` + pattern = f"^ *({re.escape(expected)})$" + matches = re.findall(pattern, result, flags=re.MULTILINE) + assert [matches] == [[expected]], result + + # Docstrings are stripped + assert "'''" not in result + assert '"""' not in result + if "Any" in result: + assert "from typing import Any" in result + # fmt: on