1
1
"""Collect type information."""
2
2
3
- import builtins
4
3
import importlib
5
4
import json
6
5
import logging
@@ -226,27 +225,6 @@ def _is_type(value):
226
225
return is_type
227
226
228
227
229
- def _builtin_types ():
230
- """Return known imports for all builtins (in the current runtime).
231
-
232
- Returns
233
- -------
234
- known_imports : dict[str, KnownImport]
235
- """
236
- known_builtins = set (dir (builtins ))
237
-
238
- known_imports = {}
239
- for name in known_builtins :
240
- if name .startswith ("_" ):
241
- continue
242
- value = getattr (builtins , name )
243
- if not _is_type (value ):
244
- continue
245
- known_imports [name ] = KnownImport (builtin_name = name )
246
-
247
- return known_imports
248
-
249
-
250
228
def _runtime_types_in_module (module_name ):
251
229
module = importlib .import_module (module_name )
252
230
types = {}
@@ -277,18 +255,20 @@ def common_known_types():
277
255
Examples
278
256
--------
279
257
>>> types = common_known_types()
280
- >>> types["str"]
281
- <KnownImport str (builtin) >
282
- >>> types["Iterable"]
283
- <KnownImport 'from collections.abc import Iterable'>
258
+ >>> types["builtins. str"]
259
+ <KnownImport 'from builtins import str' >
260
+ >>> types["typing. Iterable"]
261
+ <KnownImport 'from typing import Iterable'>
284
262
>>> types["collections.abc.Iterable"]
285
263
<KnownImport 'from collections.abc import Iterable'>
286
264
"""
287
- known_imports = _builtin_types ()
288
- known_imports |= _runtime_types_in_module ("typing" )
289
- # Overrides containers from typing
290
- known_imports |= _runtime_types_in_module ("collections.abc" )
291
- return known_imports
265
+ from ._stdlib_types import stdlib_types
266
+
267
+ types = {
268
+ f"{ module } .{ type_name } " : KnownImport (import_path = module , import_name = type_name )
269
+ for module , type_name in stdlib_types
270
+ }
271
+ return types
292
272
293
273
294
274
class TypeCollector (cst .CSTVisitor ):
@@ -334,7 +314,7 @@ def collect(cls, file):
334
314
335
315
Returns
336
316
-------
337
- collected : dict[str, KnownImport]
317
+ collected_types : dict[str, KnownImport]
338
318
"""
339
319
file = Path (file )
340
320
with file .open ("r" ) as fo :
@@ -343,7 +323,7 @@ def collect(cls, file):
343
323
tree = cst .parse_module (source )
344
324
collector = cls (module_name = module_name_from_path (file ))
345
325
tree .visit (collector )
346
- return collector .known_imports
326
+ return collector .collected_types
347
327
348
328
def __init__ (self , * , module_name ):
349
329
"""Initialize type collector.
@@ -354,7 +334,7 @@ def __init__(self, *, module_name):
354
334
"""
355
335
self .module_name = module_name
356
336
self ._stack = []
357
- self .known_imports = {}
337
+ self .collected_types = {}
358
338
359
339
def visit_ClassDef (self , node : cst .ClassDef ) -> bool :
360
340
self ._stack .append (node .name .value )
@@ -396,9 +376,104 @@ def _collect_type_annotation(self, stack):
396
376
stack : Iterable[str]
397
377
A list of names that form the path to the collected type.
398
378
"""
399
- qualname = "." .join ([self .module_name , * stack ])
400
379
known_import = KnownImport (import_path = self .module_name , import_name = stack [0 ])
401
- self .known_imports [qualname ] = known_import
380
+
381
+ qualname = f"{ self .module_name } .{ '.' .join (stack )} "
382
+ scoped_name = f"{ self .module_name } :{ '.' .join (stack )} "
383
+ self .collected_types [qualname ] = known_import
384
+ self .collected_types [scoped_name ] = known_import
385
+
386
+
387
+ class StubTypeCollector (TypeCollector ):
388
+
389
+ def __init__ (self , * , module_name ):
390
+ """Initialize type collector.
391
+
392
+ Parameters
393
+ ----------
394
+ module_name : str
395
+ """
396
+ super ().__init__ (module_name = module_name )
397
+ self .collected_types = set ()
398
+ self .dunder_all = set ()
399
+
400
+ @classmethod
401
+ def collect (cls , file ):
402
+ """Collect importable type annotations in given file.
403
+
404
+ Parameters
405
+ ----------
406
+ file : Path
407
+
408
+ Returns
409
+ -------
410
+ collected_types : dict[str, KnownImport]
411
+ """
412
+ file = Path (file )
413
+ with file .open ("r" ) as fo :
414
+ source = fo .read ()
415
+
416
+ tree = cst .parse_module (source )
417
+ collector = cls (module_name = module_name_from_path (file ))
418
+ tree .visit (collector )
419
+ return collector .collected_types , collector .dunder_all
420
+
421
+ def visit_ImportFrom (self , node ):
422
+ # https://typing.python.org/en/latest/spec/distributing.html#import-conventions
423
+
424
+ if cstm .matches (node , cstm .ImportFrom (names = cstm .ImportStar ())):
425
+ module_names = cstm .findall (node .module , cstm .Name ())
426
+ module = "_" .join (name .value for name in module_names )
427
+ stack = [* self ._stack , f"<Reference: { module } .*>" ]
428
+ self ._collect_type_annotation (stack )
429
+
430
+ names = cstm .findall (node , cstm .AsName ())
431
+ for name in names :
432
+ if cstm .matches (name , cstm .AsName (name = cstm .Name ())):
433
+ value = name .name .value
434
+ assert value
435
+ if value == "__all__" :
436
+ continue
437
+
438
+ stack = [* self ._stack , value ]
439
+ self ._collect_type_annotation (stack )
440
+
441
+ def visit_AugAssign (self , node ):
442
+ is_add_assign_to_dunder_all = cstm .matches (
443
+ node ,
444
+ cstm .AugAssign (
445
+ target = cstm .Name (value = "__all__" ), operator = cstm .AddAssign ()
446
+ ),
447
+ )
448
+ is_assign_list = cstm .matches (node .value , cstm .List ())
449
+ if is_add_assign_to_dunder_all and is_assign_list :
450
+ strings = cstm .findall (node .value , cstm .SimpleString ())
451
+ for string in strings :
452
+ self ._collect_dunder_all (string .value )
453
+
454
+ def visit_Assign (self , node ):
455
+ is_assign_to_dunder_all = cstm .matches (
456
+ node ,
457
+ cstm .Assign (targets = [cstm .AssignTarget (target = cstm .Name (value = "__all__" ))]),
458
+ )
459
+ is_assign_list = cstm .matches (node .value , cstm .List ())
460
+ if is_assign_to_dunder_all and is_assign_list :
461
+ strings = cstm .findall (node .value , cstm .SimpleString ())
462
+ for string in strings :
463
+ self ._collect_dunder_all (string .value )
464
+
465
+ def _collect_type_annotation (self , stack ):
466
+ """Collect an importable type annotation.
467
+
468
+ Parameters
469
+ ----------
470
+ stack : Iterable[str]
471
+ A list of names that form the path to the collected type.
472
+ """
473
+ self .collected_types .add ((self .module_name , "." .join (stack )))
474
+
475
+ def _collect_dunder_all (self , value ):
476
+ self .dunder_all .add ((self .module_name , value .strip ("'\" " )))
402
477
403
478
404
479
class TypeMatcher :
@@ -427,6 +502,7 @@ def __init__(
427
502
types = None ,
428
503
type_prefixes = None ,
429
504
type_nicknames = None ,
505
+ implicit_modules = ("collections.abc" , "typing" , "_typeshed" ),
430
506
):
431
507
"""
432
508
Parameters
@@ -438,6 +514,7 @@ def __init__(
438
514
self .types = types or common_known_types ()
439
515
self .type_prefixes = type_prefixes or {}
440
516
self .type_nicknames = type_nicknames or {}
517
+ self .implicit_modules = implicit_modules
441
518
self .successful_queries = 0
442
519
self .unknown_qualnames = []
443
520
@@ -492,20 +569,39 @@ def match(self, search_name):
492
569
# Replace alias
493
570
search_name = self .type_nicknames .get (search_name , search_name )
494
571
495
- if type_origin is None and self .current_module :
496
- # Try scope of current module
497
- module_name = module_name_from_path (self .current_module )
498
- try_qualname = f"{ module_name } .{ search_name } "
572
+ if type_origin is None :
573
+ # Try builtin
574
+ try_qualname = f"builtins.{ search_name } "
499
575
type_origin = self .types .get (try_qualname )
500
576
if type_origin :
501
577
type_name = search_name
502
578
503
579
if type_origin is None and search_name in self .types :
580
+ # Direct match
504
581
type_name = search_name
505
582
type_origin = self .types [search_name ]
506
583
584
+ if type_origin is None and self .current_module :
585
+ # Try scope of current module
586
+ for sep in ["." , ":" ]:
587
+ try_qualname = f"{ self .current_module } { sep } { search_name } "
588
+ type_origin = self .types .get (try_qualname )
589
+ if type_origin :
590
+ type_name = search_name
591
+ break
592
+
593
+ if type_origin is None and self .implicit_modules :
594
+ # Try implicit modules
595
+ for module in self .implicit_modules :
596
+ try_qualname = f"{ module } .{ search_name } "
597
+ type_origin = self .types .get (try_qualname )
598
+ if type_origin :
599
+ type_name = search_name
600
+ break
601
+
507
602
if type_origin is None :
508
- # Try a subset of the qualname (first 'a.b.c', then 'a.b' and 'a')
603
+ # Try matching with module prefix,
604
+ # try a subset of the qualname (first 'a.b.c', then 'a.b' and 'a')
509
605
for partial_qualname in reversed (accumulate_qualname (search_name )):
510
606
type_origin = self .type_prefixes .get (partial_qualname )
511
607
if type_origin :
0 commit comments