Skip to content

Commit 019e4e8

Browse files
authored
Only use | None for optional parameters if appropriate (#14)
* Allow sphinx references enclosed in "`" only The prefix in ":class:`Foo`" is actually optional for Sphinx to detect it as a reference. So allow "`Foo`" too. * Only use `| None` for optional parameters if appropriate This commit also includes a few other tweaks. I should really pay a bit more attention to making more focused changes. Partial commits only get you so far... A minor change is that `PackageFile` was replaced with `module_name_from_path` which should make for a more sensible approach that doesn't rely as much on carrying a subclass of `Path` around.
1 parent f014783 commit 019e4e8

File tree

12 files changed

+231
-128
lines changed

12 files changed

+231
-128
lines changed

examples/example_pkg-stubs/_basic.pyi

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,11 @@ class ExampleClass:
3838
def method_in_nested_class(self, a1: complex) -> None: ...
3939

4040
def __init__(self, a1: str, a2: float = ...) -> None: ...
41-
def method(self, a1: float, a2: float | None) -> list[float]: ...
41+
def method(
42+
self, a1: float, a2: float = ..., a3: float | None = ...
43+
) -> list[float]: ...
4244
@staticmethod
43-
def some_staticmethod(a1: float, a2: float | None = ...) -> dict[str, Any]: ...
45+
def some_staticmethod(a1: float, a2: str = ...) -> dict[str, Any]: ...
4446
@property
4547
def some_property(self) -> str: ...
4648
@some_property.setter

examples/example_pkg-stubs/_numpy.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ def func_object_with_numpy_objects(
66
a1: np.int8, a2: np.int16, a3: np.typing.DTypeLike, a4: np.typing.DTypeLike
77
) -> None: ...
88
def func_ndarray(
9-
a1: NDArray, a2: np.NDArray, a3: NDArray[float], a4: NDArray[np.uint8] = ...
9+
a1: NDArray, a2: np.NDArray, a3: NDArray[float], a4: NDArray[np.uint8] | None = ...
1010
) -> tuple[NDArray[np.uint8], NDArray[complex]]: ...
1111
def func_array_like(
1212
a1: ArrayLike, a2: ArrayLike, a3: ArrayLike[float], a4: ArrayLike[np.uint8]

examples/example_pkg/_basic.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,13 +90,14 @@ def method_in_nested_class(self, a1):
9090
def __init__(self, a1, a2=0):
9191
pass
9292

93-
def method(self, a1, a2):
93+
def method(self, a1, a2=0, a3=None):
9494
"""Dummy.
9595
9696
Parameters
9797
----------
9898
a1 : float
9999
a2 : float, optional
100+
a3 : float, optional
100101
101102
Returns
102103
-------
@@ -110,7 +111,7 @@ def some_staticmethod(a1, a2="uno"):
110111
Parameters
111112
----------
112113
a1 : float
113-
a2 : float, optional
114+
a2 : str, optional
114115
115116
Returns
116117
-------

src/docstub/_analysis.py

Lines changed: 38 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
import libcst as cst
1212

13-
from ._utils import accumulate_qualname
13+
from ._utils import accumulate_qualname, module_name_from_path
1414

1515
logger = logging.getLogger(__name__)
1616

@@ -42,16 +42,21 @@ def _shared_leading_path(*paths):
4242
class KnownImport:
4343
"""Import information associated with a single known type annotation.
4444
45-
Parameters
45+
Attributes
4646
----------
47-
import_name :
48-
Dotted names after "import".
49-
import_path :
47+
import_path : str, optional
5048
Dotted names after "from".
51-
import_alias :
49+
import_name : str, optional
50+
Dotted names after "import".
51+
import_alias : str, optional
5252
Name (without ".") after "as".
53-
builtin_name :
53+
builtin_name : str, optional
5454
Names an object that's builtin and doesn't need an import.
55+
56+
Examples
57+
--------
58+
>>> KnownImport(import_path="numpy", import_name="uint8", import_alias="ui8")
59+
<KnownImport 'from numpy import uint8 as ui8'>
5560
"""
5661

5762
import_name: str = None
@@ -170,14 +175,6 @@ def __str__(self):
170175
return out
171176

172177

173-
@dataclass(slots=True, frozen=True)
174-
class InspectionContext:
175-
"""Currently inspected module and other information."""
176-
177-
file_path: Path
178-
in_package_path: str
179-
180-
181178
def _is_type(value) -> bool:
182179
"""Check if value is a type."""
183180
# Checking for isinstance(..., type) isn't enough, some types such as
@@ -262,45 +259,57 @@ def common_known_imports():
262259
return known_imports
263260

264261

265-
class KnownImportCollector(cst.CSTVisitor):
262+
class TypeCollector(cst.CSTVisitor):
266263
@classmethod
267-
def collect(cls, file, module_name):
264+
def collect(cls, file):
265+
"""Collect importable type annotations in given file.
266+
267+
Parameters
268+
----------
269+
file : Path
270+
271+
Returns
272+
-------
273+
collected : dict[str, KnownImport]
274+
"""
268275
file = Path(file)
269276
with file.open("r") as fo:
270277
source = fo.read()
271278

272279
tree = cst.parse_module(source)
273-
collector = cls(module_name=module_name)
280+
collector = cls(module_name=module_name_from_path(file))
274281
tree.visit(collector)
275282
return collector.known_imports
276283

277284
def __init__(self, *, module_name):
285+
"""Initialize type collector.
286+
287+
Parameters
288+
----------
289+
module_name : str
290+
"""
278291
self.module_name = module_name
279292
self._stack = []
280293
self.known_imports = {}
281294

282-
def visit_ClassDef(self, node):
295+
def visit_ClassDef(self, node: cst.ClassDef) -> bool:
283296
self._stack.append(node.name.value)
284297

285298
class_name = ".".join(self._stack[:1])
286299
qualname = f"{self.module_name}.{'.'.join(self._stack)}"
287-
288-
known_import = KnownImport(
289-
import_name=class_name,
290-
import_path=self.module_name,
291-
)
300+
known_import = KnownImport(import_path=self.module_name, import_name=class_name)
292301
self.known_imports[qualname] = known_import
293302

294303
return True
295304

296-
def leave_ClassDef(self, original_node):
305+
def leave_ClassDef(self, original_node: cst.ClassDef) -> None:
297306
self._stack.pop()
298307

299-
def visit_FunctionDef(self, node):
308+
def visit_FunctionDef(self, node: cst.FunctionDef) -> bool:
300309
self._stack.append(node.name.value)
301310
return True
302311

303-
def leave_FunctionDef(self, original_node):
312+
def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None:
304313
self._stack.pop()
305314

306315

@@ -395,7 +404,8 @@ def query(self, search_name):
395404

396405
if known_import is None and self.current_source:
397406
# Try scope of current module
398-
try_qualname = f"{self.current_source.import_path}.{search_name}"
407+
module_name = module_name_from_path(self.current_source)
408+
try_qualname = f"{module_name}.{search_name}"
399409
known_import = self.known_imports.get(try_qualname)
400410
if known_import:
401411
annotation_name = search_name

src/docstub/_cli.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66

77
from ._analysis import (
88
KnownImport,
9-
KnownImportCollector,
109
StaticInspector,
10+
TypeCollector,
1111
common_known_imports,
1212
)
1313
from ._config import Config
@@ -92,9 +92,7 @@ def main(source_dir, out_dir, config_path, verbose):
9292
known_imports = common_known_imports()
9393
for source_path in walk_source(source_dir):
9494
logger.info("collecting types in %s", source_path)
95-
known_imports_in_source = KnownImportCollector.collect(
96-
source_path, module_name=source_path.import_path
97-
)
95+
known_imports_in_source = TypeCollector.collect(source_path)
9896
known_imports.update(known_imports_in_source)
9997
known_imports.update(KnownImport.many_from_config(config.known_imports))
10098

src/docstub/_docstrings.py

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,9 @@ def __post_init__(self):
4848
object.__setattr__(self, "imports", frozenset(self.imports))
4949
if "~" in self.value:
5050
raise ValueError(f"unexpected '~' in annotation value: {self.value}")
51+
for import_ in self.imports:
52+
if not isinstance(import_, KnownImport):
53+
raise TypeError(f"unexpected type {type(import_)} in `imports`")
5154

5255
def __str__(self) -> str:
5356
return self.value
@@ -95,6 +98,22 @@ def as_yields_generator(cls, yield_types, receive_types=()):
9598
# TODO
9699
raise NotImplementedError()
97100

101+
def as_optional(self):
102+
"""Return optional version of this annotation by appending `| None`.
103+
104+
Returns
105+
-------
106+
optional : Annotation
107+
108+
Examples
109+
--------
110+
>>> Annotation(value="int").as_optional()
111+
Annotation(value='int | None', imports=frozenset())
112+
"""
113+
value = f"{self.value} | None"
114+
optional = type(self)(value=value, imports=self.imports)
115+
return optional
116+
98117
@staticmethod
99118
def _aggregate_annotations(*types):
100119
"""Aggregate values and imports of given Annotations.
@@ -118,14 +137,7 @@ def _aggregate_annotations(*types):
118137

119138
GrammarErrorFallback = Annotation(
120139
value="Any",
121-
imports=frozenset(
122-
(
123-
KnownImport(
124-
import_name="Any",
125-
import_path="typing",
126-
),
127-
)
128-
),
140+
imports=frozenset((KnownImport(import_path="typing", import_name="Any"),)),
129141
)
130142

131143

@@ -233,12 +245,8 @@ def types_or(self, tree):
233245
return out
234246

235247
def optional(self, tree):
236-
out = "None"
237-
literal = [child for child in tree.children if child.type == "LITERAL"]
238-
assert len(literal) <= 1
239-
if literal:
240-
out = lark.Discard # Type should cover the default
241-
return out
248+
logger.debug("dropping optional / default info")
249+
return lark.Discard
242250

243251
def extra_info(self, tree):
244252
logger.debug("dropping extra info")

0 commit comments

Comments
 (0)