Skip to content

Commit ca2f68d

Browse files
committed
WIP
1 parent a501579 commit ca2f68d

File tree

7 files changed

+6035
-51
lines changed

7 files changed

+6035
-51
lines changed

src/docstub/_analysis.py

Lines changed: 137 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
"""Collect type information."""
22

3-
import builtins
43
import importlib
54
import json
65
import logging
@@ -226,27 +225,6 @@ def _is_type(value):
226225
return is_type
227226

228227

229-
def _builtin_types():
230-
"""Return known imports for all builtins (in the current runtime).
231-
232-
Returns
233-
-------
234-
known_imports : dict[str, KnownImport]
235-
"""
236-
known_builtins = set(dir(builtins))
237-
238-
known_imports = {}
239-
for name in known_builtins:
240-
if name.startswith("_"):
241-
continue
242-
value = getattr(builtins, name)
243-
if not _is_type(value):
244-
continue
245-
known_imports[name] = KnownImport(builtin_name=name)
246-
247-
return known_imports
248-
249-
250228
def _runtime_types_in_module(module_name):
251229
module = importlib.import_module(module_name)
252230
types = {}
@@ -277,18 +255,20 @@ def common_known_types():
277255
Examples
278256
--------
279257
>>> types = common_known_types()
280-
>>> types["str"]
281-
<KnownImport str (builtin)>
282-
>>> types["Iterable"]
283-
<KnownImport 'from collections.abc import Iterable'>
258+
>>> types["builtins.str"]
259+
<KnownImport 'from builtins import str'>
260+
>>> types["typing.Iterable"]
261+
<KnownImport 'from typing import Iterable'>
284262
>>> types["collections.abc.Iterable"]
285263
<KnownImport 'from collections.abc import Iterable'>
286264
"""
287-
known_imports = _builtin_types()
288-
known_imports |= _runtime_types_in_module("typing")
289-
# Overrides containers from typing
290-
known_imports |= _runtime_types_in_module("collections.abc")
291-
return known_imports
265+
from ._stdlib_types import stdlib_types
266+
267+
types = {
268+
f"{module}.{type_name}": KnownImport(import_path=module, import_name=type_name)
269+
for module, type_name in stdlib_types
270+
}
271+
return types
292272

293273

294274
class TypeCollector(cst.CSTVisitor):
@@ -334,7 +314,7 @@ def collect(cls, file):
334314
335315
Returns
336316
-------
337-
collected : dict[str, KnownImport]
317+
collected_types : dict[str, KnownImport]
338318
"""
339319
file = Path(file)
340320
with file.open("r") as fo:
@@ -343,7 +323,7 @@ def collect(cls, file):
343323
tree = cst.parse_module(source)
344324
collector = cls(module_name=module_name_from_path(file))
345325
tree.visit(collector)
346-
return collector.known_imports
326+
return collector.collected_types
347327

348328
def __init__(self, *, module_name):
349329
"""Initialize type collector.
@@ -354,7 +334,7 @@ def __init__(self, *, module_name):
354334
"""
355335
self.module_name = module_name
356336
self._stack = []
357-
self.known_imports = {}
337+
self.collected_types = {}
358338

359339
def visit_ClassDef(self, node: cst.ClassDef) -> bool:
360340
self._stack.append(node.name.value)
@@ -396,9 +376,104 @@ def _collect_type_annotation(self, stack):
396376
stack : Iterable[str]
397377
A list of names that form the path to the collected type.
398378
"""
399-
qualname = ".".join([self.module_name, *stack])
400379
known_import = KnownImport(import_path=self.module_name, import_name=stack[0])
401-
self.known_imports[qualname] = known_import
380+
381+
qualname = f"{self.module_name}.{'.'.join(stack)}"
382+
scoped_name = f"{self.module_name}:{'.'.join(stack)}"
383+
self.collected_types[qualname] = known_import
384+
self.collected_types[scoped_name] = known_import
385+
386+
387+
class StubTypeCollector(TypeCollector):
388+
389+
def __init__(self, *, module_name):
390+
"""Initialize type collector.
391+
392+
Parameters
393+
----------
394+
module_name : str
395+
"""
396+
super().__init__(module_name=module_name)
397+
self.collected_types = set()
398+
self.dunder_all = set()
399+
400+
@classmethod
401+
def collect(cls, file):
402+
"""Collect importable type annotations in given file.
403+
404+
Parameters
405+
----------
406+
file : Path
407+
408+
Returns
409+
-------
410+
collected_types : dict[str, KnownImport]
411+
"""
412+
file = Path(file)
413+
with file.open("r") as fo:
414+
source = fo.read()
415+
416+
tree = cst.parse_module(source)
417+
collector = cls(module_name=module_name_from_path(file))
418+
tree.visit(collector)
419+
return collector.collected_types, collector.dunder_all
420+
421+
def visit_ImportFrom(self, node):
422+
# https://typing.python.org/en/latest/spec/distributing.html#import-conventions
423+
424+
if cstm.matches(node, cstm.ImportFrom(names=cstm.ImportStar())):
425+
module_names = cstm.findall(node.module, cstm.Name())
426+
module = "_".join(name.value for name in module_names)
427+
stack = [*self._stack, f"<Reference: {module}.*>"]
428+
self._collect_type_annotation(stack)
429+
430+
names = cstm.findall(node, cstm.AsName())
431+
for name in names:
432+
if cstm.matches(name, cstm.AsName(name=cstm.Name())):
433+
value = name.name.value
434+
assert value
435+
if value == "__all__":
436+
continue
437+
438+
stack = [*self._stack, value]
439+
self._collect_type_annotation(stack)
440+
441+
def visit_AugAssign(self, node):
442+
is_add_assign_to_dunder_all = cstm.matches(
443+
node,
444+
cstm.AugAssign(
445+
target=cstm.Name(value="__all__"), operator=cstm.AddAssign()
446+
),
447+
)
448+
is_assign_list = cstm.matches(node.value, cstm.List())
449+
if is_add_assign_to_dunder_all and is_assign_list:
450+
strings = cstm.findall(node.value, cstm.SimpleString())
451+
for string in strings:
452+
self._collect_dunder_all(string.value)
453+
454+
def visit_Assign(self, node):
455+
is_assign_to_dunder_all = cstm.matches(
456+
node,
457+
cstm.Assign(targets=[cstm.AssignTarget(target=cstm.Name(value="__all__"))]),
458+
)
459+
is_assign_list = cstm.matches(node.value, cstm.List())
460+
if is_assign_to_dunder_all and is_assign_list:
461+
strings = cstm.findall(node.value, cstm.SimpleString())
462+
for string in strings:
463+
self._collect_dunder_all(string.value)
464+
465+
def _collect_type_annotation(self, stack):
466+
"""Collect an importable type annotation.
467+
468+
Parameters
469+
----------
470+
stack : Iterable[str]
471+
A list of names that form the path to the collected type.
472+
"""
473+
self.collected_types.add((self.module_name, ".".join(stack)))
474+
475+
def _collect_dunder_all(self, value):
476+
self.dunder_all.add((self.module_name, value.strip("'\"")))
402477

403478

404479
class TypeMatcher:
@@ -427,6 +502,7 @@ def __init__(
427502
types=None,
428503
type_prefixes=None,
429504
type_nicknames=None,
505+
implicit_modules=("collections.abc", "typing", "_typeshed"),
430506
):
431507
"""
432508
Parameters
@@ -438,6 +514,7 @@ def __init__(
438514
self.types = types or common_known_types()
439515
self.type_prefixes = type_prefixes or {}
440516
self.type_nicknames = type_nicknames or {}
517+
self.implicit_modules = implicit_modules
441518
self.successful_queries = 0
442519
self.unknown_qualnames = []
443520

@@ -492,20 +569,39 @@ def match(self, search_name):
492569
# Replace alias
493570
search_name = self.type_nicknames.get(search_name, search_name)
494571

495-
if type_origin is None and self.current_module:
496-
# Try scope of current module
497-
module_name = module_name_from_path(self.current_module)
498-
try_qualname = f"{module_name}.{search_name}"
572+
if type_origin is None:
573+
# Try builtin
574+
try_qualname = f"builtins.{search_name}"
499575
type_origin = self.types.get(try_qualname)
500576
if type_origin:
501577
type_name = search_name
502578

503579
if type_origin is None and search_name in self.types:
580+
# Direct match
504581
type_name = search_name
505582
type_origin = self.types[search_name]
506583

584+
if type_origin is None and self.current_module:
585+
# Try scope of current module
586+
for sep in [".", ":"]:
587+
try_qualname = f"{self.current_module}{sep}{search_name}"
588+
type_origin = self.types.get(try_qualname)
589+
if type_origin:
590+
type_name = search_name
591+
break
592+
593+
if type_origin is None and self.implicit_modules:
594+
# Try implicit modules
595+
for module in self.implicit_modules:
596+
try_qualname = f"{module}.{search_name}"
597+
type_origin = self.types.get(try_qualname)
598+
if type_origin:
599+
type_name = search_name
600+
break
601+
507602
if type_origin is None:
508-
# Try a subset of the qualname (first 'a.b.c', then 'a.b' and 'a')
603+
# Try matching with module prefix,
604+
# try a subset of the qualname (first 'a.b.c', then 'a.b' and 'a')
509605
for partial_qualname in reversed(accumulate_qualname(search_name)):
510606
type_origin = self.type_prefixes.get(partial_qualname)
511607
if type_origin:

src/docstub/_cli.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,14 +89,14 @@ def _collect_types(root_path):
8989
-------
9090
types : dict[str, ~.KnownImport]
9191
"""
92-
types = common_known_types()
93-
92+
types = {}
9493
collect_cached_types = FileCache(
9594
func=TypeCollector.collect,
9695
serializer=TypeCollector.ImportSerializer(),
9796
cache_dir=Path.cwd() / ".docstub_cache",
9897
name=f"{__version__}/collected_types",
9998
)
99+
100100
if root_path.is_dir():
101101
for source_path in walk_python_package(root_path):
102102
logger.info("collecting types in %s", source_path)

0 commit comments

Comments
 (0)