|
10 | 10 |
|
11 | 11 | import libcst as cst
|
12 | 12 |
|
13 |
| -from ._utils import accumulate_qualname |
| 13 | +from ._utils import accumulate_qualname, module_name_from_path |
14 | 14 |
|
15 | 15 | logger = logging.getLogger(__name__)
|
16 | 16 |
|
@@ -42,16 +42,21 @@ def _shared_leading_path(*paths):
|
42 | 42 | class KnownImport:
|
43 | 43 | """Import information associated with a single known type annotation.
|
44 | 44 |
|
45 |
| - Parameters |
| 45 | + Attributes |
46 | 46 | ----------
|
47 |
| - import_name : |
48 |
| - Dotted names after "import". |
49 |
| - import_path : |
| 47 | + import_path : str, optional |
50 | 48 | Dotted names after "from".
|
51 |
| - import_alias : |
| 49 | + import_name : str, optional |
| 50 | + Dotted names after "import". |
| 51 | + import_alias : str, optional |
52 | 52 | Name (without ".") after "as".
|
53 |
| - builtin_name : |
| 53 | + builtin_name : str, optional |
54 | 54 | Names an object that's builtin and doesn't need an import.
|
| 55 | +
|
| 56 | + Examples |
| 57 | + -------- |
| 58 | + >>> KnownImport(import_path="numpy", import_name="uint8", import_alias="ui8") |
| 59 | + <KnownImport 'from numpy import uint8 as ui8'> |
55 | 60 | """
|
56 | 61 |
|
57 | 62 | import_name: str = None
|
@@ -170,14 +175,6 @@ def __str__(self):
|
170 | 175 | return out
|
171 | 176 |
|
172 | 177 |
|
173 |
| -@dataclass(slots=True, frozen=True) |
174 |
| -class InspectionContext: |
175 |
| - """Currently inspected module and other information.""" |
176 |
| - |
177 |
| - file_path: Path |
178 |
| - in_package_path: str |
179 |
| - |
180 |
| - |
181 | 178 | def _is_type(value) -> bool:
|
182 | 179 | """Check if value is a type."""
|
183 | 180 | # Checking for isinstance(..., type) isn't enough, some types such as
|
@@ -262,45 +259,57 @@ def common_known_imports():
|
262 | 259 | return known_imports
|
263 | 260 |
|
264 | 261 |
|
265 |
| -class KnownImportCollector(cst.CSTVisitor): |
| 262 | +class TypeCollector(cst.CSTVisitor): |
266 | 263 | @classmethod
|
267 |
| - def collect(cls, file, module_name): |
| 264 | + def collect(cls, file): |
| 265 | + """Collect importable type annotations in given file. |
| 266 | +
|
| 267 | + Parameters |
| 268 | + ---------- |
| 269 | + file : Path |
| 270 | +
|
| 271 | + Returns |
| 272 | + ------- |
| 273 | + collected : dict[str, KnownImport] |
| 274 | + """ |
268 | 275 | file = Path(file)
|
269 | 276 | with file.open("r") as fo:
|
270 | 277 | source = fo.read()
|
271 | 278 |
|
272 | 279 | tree = cst.parse_module(source)
|
273 |
| - collector = cls(module_name=module_name) |
| 280 | + collector = cls(module_name=module_name_from_path(file)) |
274 | 281 | tree.visit(collector)
|
275 | 282 | return collector.known_imports
|
276 | 283 |
|
277 | 284 | def __init__(self, *, module_name):
|
| 285 | + """Initialize type collector. |
| 286 | +
|
| 287 | + Parameters |
| 288 | + ---------- |
| 289 | + module_name : str |
| 290 | + """ |
278 | 291 | self.module_name = module_name
|
279 | 292 | self._stack = []
|
280 | 293 | self.known_imports = {}
|
281 | 294 |
|
282 |
| - def visit_ClassDef(self, node): |
| 295 | + def visit_ClassDef(self, node: cst.ClassDef) -> bool: |
283 | 296 | self._stack.append(node.name.value)
|
284 | 297 |
|
285 | 298 | class_name = ".".join(self._stack[:1])
|
286 | 299 | qualname = f"{self.module_name}.{'.'.join(self._stack)}"
|
287 |
| - |
288 |
| - known_import = KnownImport( |
289 |
| - import_name=class_name, |
290 |
| - import_path=self.module_name, |
291 |
| - ) |
| 300 | + known_import = KnownImport(import_path=self.module_name, import_name=class_name) |
292 | 301 | self.known_imports[qualname] = known_import
|
293 | 302 |
|
294 | 303 | return True
|
295 | 304 |
|
296 |
| - def leave_ClassDef(self, original_node): |
| 305 | + def leave_ClassDef(self, original_node: cst.ClassDef) -> None: |
297 | 306 | self._stack.pop()
|
298 | 307 |
|
299 |
| - def visit_FunctionDef(self, node): |
| 308 | + def visit_FunctionDef(self, node: cst.FunctionDef) -> bool: |
300 | 309 | self._stack.append(node.name.value)
|
301 | 310 | return True
|
302 | 311 |
|
303 |
| - def leave_FunctionDef(self, original_node): |
| 312 | + def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None: |
304 | 313 | self._stack.pop()
|
305 | 314 |
|
306 | 315 |
|
@@ -395,7 +404,8 @@ def query(self, search_name):
|
395 | 404 |
|
396 | 405 | if known_import is None and self.current_source:
|
397 | 406 | # Try scope of current module
|
398 |
| - try_qualname = f"{self.current_source.import_path}.{search_name}" |
| 407 | + module_name = module_name_from_path(self.current_source) |
| 408 | + try_qualname = f"{module_name}.{search_name}" |
399 | 409 | known_import = self.known_imports.get(try_qualname)
|
400 | 410 | if known_import:
|
401 | 411 | annotation_name = search_name
|
|
0 commit comments