|
13 | 13 | # Local imports
|
14 | 14 | from .. import fixer_base
|
15 | 15 | from os.path import dirname, join, exists, pathsep
|
16 |
| -from ..fixer_util import FromImport, syms |
| 16 | +from ..fixer_util import FromImport, syms, token |
| 17 | + |
| 18 | + |
| 19 | +def traverse_imports(names): |
| 20 | + """ |
| 21 | + Walks over all the names imported in a dotted_as_names node. |
| 22 | + """ |
| 23 | + pending = [names] |
| 24 | + while pending: |
| 25 | + node = pending.pop() |
| 26 | + if node.type == token.NAME: |
| 27 | + yield node.value |
| 28 | + elif node.type == syms.dotted_name: |
| 29 | + yield "".join([ch.value for ch in node.children]) |
| 30 | + elif node.type == syms.dotted_as_name: |
| 31 | + pending.append(node.children[0]) |
| 32 | + elif node.type == syms.dotted_as_names: |
| 33 | + pending.extend(node.children[::-2]) |
| 34 | + else: |
| 35 | + raise AssertionError("unkown node type") |
| 36 | + |
17 | 37 |
|
18 | 38 | class FixImport(fixer_base.BaseFix):
|
19 | 39 |
|
20 | 40 | PATTERN = """
|
21 |
| - import_from< type='from' imp=any 'import' ['('] any [')'] > |
| 41 | + import_from< 'from' imp=any 'import' ['('] any [')'] > |
22 | 42 | |
|
23 |
| - import_name< type='import' imp=any > |
| 43 | + import_name< 'import' imp=any > |
24 | 44 | """
|
25 | 45 |
|
26 | 46 | def transform(self, node, results):
|
27 | 47 | imp = results['imp']
|
28 | 48 |
|
29 |
| - mod_name = unicode(imp.children[0] if imp.type == syms.dotted_as_name \ |
30 |
| - else imp) |
31 |
| - |
32 |
| - if mod_name.startswith('.'): |
33 |
| - # Already a new-style import |
34 |
| - return |
35 |
| - |
36 |
| - if not probably_a_local_import(mod_name, self.filename): |
37 |
| - # I guess this is a global import -- skip it! |
38 |
| - return |
39 |
| - |
40 |
| - if results['type'].value == 'from': |
| 49 | + if node.type == syms.import_from: |
41 | 50 | # Some imps are top-level (eg: 'import ham')
|
42 | 51 | # some are first level (eg: 'import ham.eggs')
|
43 | 52 | # some are third level (eg: 'import ham.eggs as spam')
|
44 | 53 | # Hence, the loop
|
45 | 54 | while not hasattr(imp, 'value'):
|
46 | 55 | imp = imp.children[0]
|
47 |
| - imp.value = "." + imp.value |
48 |
| - node.changed() |
| 56 | + if self.probably_a_local_import(imp.value): |
| 57 | + imp.value = "." + imp.value |
| 58 | + imp.changed() |
| 59 | + return node |
49 | 60 | else:
|
50 |
| - new = FromImport('.', getattr(imp, 'content', None) or [imp]) |
| 61 | + have_local = False |
| 62 | + have_absolute = False |
| 63 | + for mod_name in traverse_imports(imp): |
| 64 | + if self.probably_a_local_import(mod_name): |
| 65 | + have_local = True |
| 66 | + else: |
| 67 | + have_absolute = True |
| 68 | + if have_absolute: |
| 69 | + if have_local: |
| 70 | + # We won't handle both sibling and absolute imports in the |
| 71 | + # same statement at the moment. |
| 72 | + self.warning(node, "absolute and local imports together") |
| 73 | + return |
| 74 | + |
| 75 | + new = FromImport('.', [imp]) |
51 | 76 | new.set_prefix(node.get_prefix())
|
52 |
| - node = new |
53 |
| - return node |
| 77 | + return new |
54 | 78 |
|
55 |
| -def probably_a_local_import(imp_name, file_path): |
56 |
| - # Must be stripped because the right space is included by the parser |
57 |
| - imp_name = imp_name.split('.', 1)[0].strip() |
58 |
| - base_path = dirname(file_path) |
59 |
| - base_path = join(base_path, imp_name) |
60 |
| - # If there is no __init__.py next to the file its not in a package |
61 |
| - # so can't be a relative import. |
62 |
| - if not exists(join(dirname(base_path), '__init__.py')): |
| 79 | + def probably_a_local_import(self, imp_name): |
| 80 | + imp_name = imp_name.split('.', 1)[0] |
| 81 | + base_path = dirname(self.filename) |
| 82 | + base_path = join(base_path, imp_name) |
| 83 | + # If there is no __init__.py next to the file its not in a package |
| 84 | + # so can't be a relative import. |
| 85 | + if not exists(join(dirname(base_path), '__init__.py')): |
| 86 | + return False |
| 87 | + for ext in ['.py', pathsep, '.pyc', '.so', '.sl', '.pyd']: |
| 88 | + if exists(base_path + ext): |
| 89 | + return True |
63 | 90 | return False
|
64 |
| - for ext in ['.py', pathsep, '.pyc', '.so', '.sl', '.pyd']: |
65 |
| - if exists(base_path + ext): |
66 |
| - return True |
67 |
| - return False |
|
0 commit comments