Skip to content

Commit 1e00163

Browse files
authored
Allow prefixing common known types with module (#53)
Previously `collections.abc.Iterable` couldn't be used in docstrings, which seems a bit unexpected.
1 parent 5282c0b commit 1e00163

File tree

7 files changed

+60
-47
lines changed

7 files changed

+60
-47
lines changed

docs/user_guide.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,9 @@ There are several interesting things to note here:
100100
- Optional arguments that default to `None` are recognized and a `| None` is appended automatically if the type doesn't include it already.
101101
The `optional` or `default = ...` part don't influence the annotation.
102102
103-
- Common container types from Python's standard library such as `Iterable` can be used and a necessary import will be added automatically.
103+
- Referencing the `float` and `Iterable` types worked out of the box.
104+
All builtin types as well as types from the standard libraries `typing` and `collections.abc` module can be used.
105+
Necessary imports will be added automatically to the stub file.
104106
105107
106108
## Using types & nicknames

examples/example_pkg-stubs/_basic.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def func_literals(
3232
def func_use_from_elsewhere(
3333
a1: CustomException,
3434
a2: ExampleClass,
35-
a3: CustomException.NestedClass,
35+
a3: ExampleClass.NestedClass,
3636
a4: ExampleClass.NestedClass,
3737
) -> tuple[CustomException, ExampleClass.NestedClass]: ...
3838

examples/example_pkg-stubs/_numpy.pyi

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
# File generated with docstub
22

3+
import numpy
34
import numpy as np
45
from numpy.typing import ArrayLike, NDArray
56

67
def func_object_with_numpy_objects(
7-
a1: np.int8, a2: np.int16, a3: np.typing.DTypeLike, a4: np.typing.DTypeLike
8+
a1: numpy.int8, a2: np.int16, a3: numpy.typing.DTypeLike, a4: np.typing.DTypeLike
89
) -> None: ...
910
def func_ndarray(
1011
a1: NDArray, a2: np.NDArray, a3: NDArray[float], a4: NDArray[np.uint8] | None = ...

examples/example_pkg/_basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def func_contains(a1, a2, a3, a4, a5, a6, a7):
3333
----------
3434
a1 : list[float]
3535
a2 : dict[str, Union[int, str]]
36-
a3 : Sequence[int | float]
36+
a3 : collections.abc.Sequence[int | float]
3737
a4 : frozenset[bytes]
3838
a5 : tuple of int
3939
a6 : list of (int, str)

src/docstub/_analysis.py

Lines changed: 28 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
"""Collect type information."""
22

33
import builtins
4-
import collections.abc
4+
import importlib
55
import json
66
import logging
77
import re
8-
import typing
98
from dataclasses import asdict, dataclass
109
from functools import cache
1110
from pathlib import Path
@@ -227,7 +226,7 @@ def _is_type(value):
227226
return is_type
228227

229228

230-
def _builtin_imports():
229+
def _builtin_types():
231230
"""Return known imports for all builtins (in the current runtime).
232231
233232
Returns
@@ -248,45 +247,24 @@ def _builtin_imports():
248247
return known_imports
249248

250249

251-
def _typing_imports():
252-
"""Return known imports for public types in the `typing` module.
253-
254-
Returns
255-
-------
256-
known_imports : dict[str, KnownImport]
257-
"""
258-
known_imports = {}
259-
for name in typing.__all__:
250+
def _runtime_types_in_module(module_name):
251+
module = importlib.import_module(module_name)
252+
types = {}
253+
for name in module.__all__:
260254
if name.startswith("_"):
261255
continue
262-
value = getattr(typing, name)
256+
value = getattr(module, name)
263257
if not _is_type(value):
264258
continue
265-
known_imports[name] = KnownImport.one_from_config(name, info={"from": "typing"})
266-
return known_imports
267-
268259

269-
def _collections_abc_imports():
270-
"""Return known imports for public types in the `collections.abc` module.
260+
import_ = KnownImport(import_path=module_name, import_name=name)
261+
types[name] = import_
262+
types[f"{module_name}.{name}"] = import_
271263

272-
Returns
273-
-------
274-
known_imports : dict[str, KnownImport]
275-
"""
276-
known_imports = {}
277-
for name in collections.abc.__all__:
278-
if name.startswith("_"):
279-
continue
280-
value = getattr(collections.abc, name)
281-
if not _is_type(value):
282-
continue
283-
known_imports[name] = KnownImport.one_from_config(
284-
name, info={"from": "collections.abc"}
285-
)
286-
return known_imports
264+
return types
287265

288266

289-
def common_known_imports():
267+
def common_known_types():
290268
"""Return known imports for commonly supported types.
291269
292270
This includes builtin types, and types from the `typing` or
@@ -295,10 +273,21 @@ def common_known_imports():
295273
Returns
296274
-------
297275
known_imports : dict[str, KnownImport]
276+
277+
Examples
278+
--------
279+
>>> types = common_known_types()
280+
>>> types["str"]
281+
<KnownImport str (builtin)>
282+
>>> types["Iterable"]
283+
<KnownImport 'from collections.abc import Iterable'>
284+
>>> types["collections.abc.Iterable"]
285+
<KnownImport 'from collections.abc import Iterable'>
298286
"""
299-
known_imports = _builtin_imports()
300-
known_imports |= _typing_imports()
301-
known_imports |= _collections_abc_imports() # Overrides containers from typing
287+
known_imports = _builtin_types()
288+
known_imports |= _runtime_types_in_module("typing")
289+
# Overrides containers from typing
290+
known_imports |= _runtime_types_in_module("collections.abc")
302291
return known_imports
303292

304293

@@ -426,7 +415,7 @@ class TypeMatcher:
426415
427416
Examples
428417
--------
429-
>>> from docstub._analysis import TypeMatcher, common_known_imports
418+
>>> from docstub._analysis import TypeMatcher, common_known_types
430419
>>> db = TypeMatcher()
431420
>>> db.match("Any")
432421
('Any', <KnownImport 'from typing import Any'>)
@@ -446,7 +435,7 @@ def __init__(
446435
type_prefixes : dict[str, KnownImport]
447436
type_nicknames : dict[str, str]
448437
"""
449-
self.types = types or common_known_imports()
438+
self.types = types or common_known_types()
450439
self.type_prefixes = type_prefixes or {}
451440
self.type_nicknames = type_nicknames or {}
452441
self.successful_queries = 0

src/docstub/_cli.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
KnownImport,
1212
TypeCollector,
1313
TypeMatcher,
14-
common_known_imports,
14+
common_known_types,
1515
)
1616
from ._cache import FileCache
1717
from ._config import Config
@@ -89,7 +89,7 @@ def _collect_types(root_path):
8989
-------
9090
types : dict[str, ~.KnownImport]
9191
"""
92-
types = common_known_imports()
92+
types = common_known_types()
9393

9494
collect_cached_types = FileCache(
9595
func=TypeCollector.collect,
@@ -213,7 +213,7 @@ def run(root_path, out_dir, config_paths, group_errors, allow_errors, verbose):
213213

214214
config = _load_configuration(config_paths)
215215

216-
types = common_known_imports()
216+
types = common_known_types()
217217
types |= _collect_types(root_path)
218218
types |= {
219219
type_name: KnownImport(import_path=module, import_name=type_name)

tests/test_analysis.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,11 @@
22

33
import pytest
44

5-
from docstub._analysis import KnownImport, TypeCollector, TypeMatcher
5+
from docstub._analysis import (
6+
KnownImport,
7+
TypeCollector,
8+
TypeMatcher,
9+
)
610

711

812
class Test_KnownImport:
@@ -182,3 +186,20 @@ def test_query_prefix(self, search_name, expected_name, expected_origin):
182186
assert type_name.startswith(type_origin.target)
183187
assert type_name == expected_name
184188
# fmt: on
189+
190+
@pytest.mark.parametrize(
191+
("search_name", "import_path"),
192+
[
193+
("Iterable", "collections.abc"),
194+
("collections.abc.Iterable", "collections.abc"),
195+
("Literal", "typing"),
196+
("typing.Literal", "typing"),
197+
],
198+
)
199+
def test_common_known_types(self, search_name, import_path):
200+
matcher = TypeMatcher()
201+
type_name, type_origin = matcher.match(search_name)
202+
203+
assert type_name == search_name.split(".")[-1]
204+
assert type_origin is not None
205+
assert type_origin.import_path == import_path

0 commit comments

Comments
 (0)