Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve support for imports #122

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 64 additions & 9 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,12 @@ API
Basically Beniget provides three analyses:

- ``beniget.Ancestors`` that maps each node to the list of enclosing nodes;
- ``beniget.DefUseChains`` that maps each node to the list of definition points in that node;
- ``beniget.DefUseChains`` that:

- maps each node to the list of definition points in that node;
- maps each scope node to its locals dictionary;
- maps each alias node to its resolved import;

- ``beniget.UseDefChains`` that maps each node to the list of possible definitions of that node.

See sample usages and/or run ``pydoc beniget`` for more information :-).
Expand All @@ -34,22 +39,20 @@ This is a very basic usage: look for def without any use, and warn about them, f
>>> import beniget, gast as ast

# parse some simple statements
>>> code = "from math import cos, sin; print(cos(3))"
>>> code = "from math import cos, sin; import x, y; print(cos(3) + y.f(2))"
>>> module = ast.parse(code)

# compute the def-use chains at module level
>>> duc = beniget.DefUseChains()
>>> duc.visit(module)

# grab the import statement
>>> imported = module.body[0].names

# inspect the users of each imported name
>>> for name in imported:
... ud = duc.chains[name]
>>> for alias in duc.imports:
... ud = duc.chains[alias]
... if not ud.users():
... print("Unused import: {}".format(ud.name()))
... print(f"Unused import: {ud.name()}")
Unused import: sin
Unused import: x

*NOTE*: Due to the dynamic nature of Python, one can fool this analysis by
calling the ``eval`` function, eventually through an indirection, or by performing a lookup
Expand Down Expand Up @@ -231,8 +234,60 @@ let's use the UseDef chains combined with the ancestors.
>>> list(map(type, capturex.external))
[<class 'gast.gast.Assign'>, <class 'gast.gast.Assign'>, <class 'gast.gast.Assign'>]

Report usage of imported names
******************************

This analysis takes a collection of names and
reports when they are being imported and used.

.. code:: python

>>> import ast, beniget
>>> def find_references_to(names, defuse: beniget.DefUseChains,
... ancestors: beniget.Ancestors) -> 'list[beniget.Def]':
... names = dict.fromkeys(names)
... found = []
... for al,imp in defuse.imports.items():
... if imp.target() in names: # "from x import y;y" form
... for use in defuse.chains[al].users():
... found.append(use)
... # Note: this doesn't handle aliasing.
... else: # "import x; x.y" form
... for n in names:
... if n.startswith(f'{imp.target()}.'):
... diffnames = n[len(f'{imp.target()}.'):].split('.')
... for use in defuse.chains[al].users():
... attr_node = parent_node = ancestors.parent(use.node)
... index = 0
... # check if node is part of an attribute access matching the dotted name
... while isinstance(parent_node, ast.Attribute) and index < len(diffnames):
... if parent_node.attr != diffnames[index]:
... break
... attr_node = parent_node
... parent_node = ancestors.parent(parent_node)
... index += 1
... else:
... if index: # It has not break and did a loop, meaning we found a match
... found.append(defuse.chains[attr_node])
...
... return found
...
>>> module = ast.parse('''\
... from typing import List, Dict; import typing as t; import numpy as np
... def f() -> List[str]: ...
... def g(a: Dict) -> t.overload: return np.fft.calc(0)''')
>>> c = beniget.DefUseChains()
>>> c.visit(module)
>>> a = beniget.Ancestors()
>>> a.visit(module)
>>> print([str(i) for i in find_references_to(['typing.Dict', 'typing.List', 'typing.overload', 'numpy.fft.calc'], c, a)])
['List -> (<Subscript> -> ())', 'Dict -> ()', '.overload -> ()', '.calc -> (<Call> -> ())']

>>> print([str(i) for i in find_references_to(['typing'], c, a)])
['t -> (.overload -> ())']

Acknowledgments
---------------

Beniget is in Pierre Augier's debt, for he triggered the birth of beniget and provided
countless meaningful bug reports and advices. Trugarez!
countless meaningful bug reports and advices. Trugarez!
128 changes: 127 additions & 1 deletion beniget/beniget.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,122 @@ def _str(self, nodes):
for u in self._users)
)

class Import:
    """
    Represents an `ast.alias` node with resolved
    origin module and target of the locally bound name.

    Two instances compare equal (and hash equal) when they render to the
    same import statement, see `code()`. Since the hash is derived from the
    attributes, do not mutate an instance while it is stored in a set or
    used as a dictionary key.

    :note: `orgname` will be ``*`` for wildcard imports.
    """
    __slots__ = 'orgmodule', 'orgname', 'asname', '_orgroot'

    def __init__(self, orgmodule, orgname=None, asname=None):
        """
        Create instances of this class with parse_import().

        :param orgmodule: str, The origin module
        :param orgname: str or None, The origin name
        :param asname: str or None, Import asname
        """
        self.orgmodule = orgmodule
        # First component of the dotted module path: this is the name that
        # a plain ``import a.b.c`` binds locally.
        self._orgroot = orgmodule.split(".", 1)[0]
        self.orgname = orgname
        self.asname = asname

    def name(self):
        """
        Returns the local name of the imported symbol, str.

        This will be equal to the ``name()`` of the `Def` of the `ast.alias`
        node this `Import` represents.
        """
        if self.asname:
            return self.asname
        if self.orgname:
            return self.orgname
        return self._orgroot

    def target(self):
        """
        Returns the fully qualified name of the target of the imported symbol, str.

        - ``from a.b import c [as d]`` targets ``a.b.c``;
        - ``import a.b as c`` targets ``a.b``;
        - ``import a.b`` targets ``a``, the root package being the bound name.
        """
        if self.orgname:
            return "{}.{}".format(self.orgmodule, self.orgname)
        if self.asname:
            return self.orgmodule
        return self._orgroot

    def code(self):
        """
        Returns this imported name as an import code statement, str.
        """
        if self.orgname:
            if self.asname:
                return "from {} import {} as {}".format(
                    self.orgmodule, self.orgname, self.asname)
            return "from {} import {}".format(self.orgmodule, self.orgname)
        if self.asname:
            return "import {} as {}".format(self.orgmodule, self.asname)
        return "import {}".format(self.orgmodule)

    def __eq__(self, value):
        if isinstance(value, Import):
            return self.code() == value.code()
        return NotImplemented

    def __hash__(self):
        # Defining __eq__ alone would implicitly set __hash__ to None and
        # make instances unhashable; keep the hash consistent with __eq__.
        return hash(self.code())

    def __repr__(self):
        return f'Import({self.code()!r})'

def parse_import(node, modname=None, is_package=False):
    """
    Parse the given import node into a mapping of aliases to their `Import`.

    :param node: The import node (ast.Import or ast.ImportFrom).
    :param modname: The name of the module, required to resolve relative imports.
    :type modname: string or None (in which case relative imports can't be resolved)
    :param bool is_package: Whether the module is the ``__init__`` file of a package,
        required to correctly resolve relative imports in package's __init__.py files.
    :rtype: dict[ast.alias, Import]
    :raises TypeError: If ``node`` is neither an ``Import`` nor an ``ImportFrom`` node.
    """
    # NOTE(review): pkg() is defined elsewhere in this file; presumably it
    # returns the ast-like module (ast or gast) the node belongs to - confirm.
    ast = pkg(node)

    if isinstance(node, ast.Import):
        return {
            al: Import(orgmodule=al.name, asname=al.asname)
            for al in node.names
        }

    if not isinstance(node, ast.ImportFrom):
        raise TypeError('unexpected node type: {}'.format(type(node)))

    # ``from . import x`` has no module part at all.
    if node.module is None:
        module_parts = ()
    else:
        module_parts = tuple(node.module.split("."))

    if not node.level:
        # Absolute import: nothing to resolve.
        orgmodule = ".".join(module_parts)
    else:
        resolved = ()
        if modname:
            # Parse relative imports, if module name is provided.
            # In a package's __init__ file, the first dot designates the
            # package itself, so there is one component less to strip.
            level = node.level - 1 if is_package else node.level
            curr = tuple(modname.split('.'))
            for _ in range(level):
                if not curr:
                    break
                curr = curr[:-1]
            if curr:
                resolved = curr + module_parts

        if resolved:
            orgmodule = ".".join(resolved)
        else:
            # An unresolved relative import makes no sense for beniget...
            # we simply prefix the name with one dot per level. The prefix
            # is built explicitly so that ``from . import x`` yields "."
            # (joining empty components would drop one dot when there is
            # no module part).
            orgmodule = "." * node.level + ".".join(module_parts)

    return {
        alias: Import(
            orgmodule=orgmodule,
            orgname=alias.name,
            asname=alias.asname,
        )
        for alias in node.names
    }

import builtins
BuiltinsSrc = builtins.__dict__
Expand Down Expand Up @@ -324,13 +440,17 @@ class DefUseChains(gast.NodeVisitor):
"""


def __init__(self, filename=None):
def __init__(self, filename=None, modname=None, is_package=False):
"""
- filename: str, included in error messages if specified
"""
self.chains = {}
self.locals = defaultdict(list)
self.imports = {}

self.filename = filename
self.modname = modname
self.is_package = is_package

# deep copy of builtins, to remain reentrant
self._builtins = {k: Def(v) for k, v in Builtins.items()}
Expand Down Expand Up @@ -1093,6 +1213,9 @@ def visit_Import(self, node):
base = alias.name.split(".", 1)[0]
self.set_definition(alias.asname or base, dalias)
self.add_to_locals(alias.asname or base, dalias)

self.imports.update(parse_import(node, self.modname,
is_package=self.is_package))

def visit_ImportFrom(self, node):
for alias in node.names:
Expand All @@ -1102,6 +1225,9 @@ def visit_ImportFrom(self, node):
else:
self.set_definition(alias.asname or alias.name, dalias)
self.add_to_locals(alias.asname or alias.name, dalias)

self.imports.update(parse_import(node, self.modname,
is_package=self.is_package))

def visit_Global(self, node):
for name in node.names:
Expand Down
120 changes: 120 additions & 0 deletions tests/test_imports.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
import gast as _gast
import ast as _ast
from unittest import TestCase
from textwrap import dedent

from .test_chains import StrictDefUseChains
from beniget.beniget import Def, parse_import

class TestImportParser(TestCase):
    """
    Tests for parse_import() and for the ``imports`` mapping filled by the
    def-use chains visitor, run against the ``gast`` nodes.
    """
    # AST implementation used to parse the test code; subclasses may override it.
    ast = _gast
    def test_import_parsing(self):
        """
        Absolute imports: check every attribute and accessor of the parsed
        `Import` objects, and that the visitor recorded an equal `Import`.
        """
        code = '''
import mod2
import pack.subpack
import pack.subpack as a
from mod2 import _k as k, _l as l
from pack.subpack.stuff import C
from ast import *
'''
        # orgmodule, orgname, asname, name, target, code
        expected = [
            {'mod2':('mod2', None, None, 'mod2', 'mod2', 'import mod2')},
            {'pack':('pack.subpack', None, None, 'pack', 'pack', 'import pack.subpack')},
            {'a':('pack.subpack', None, 'a', 'a', 'pack.subpack', 'import pack.subpack as a')},
            {
                'k':('mod2','_k', 'k', 'k', 'mod2._k', 'from mod2 import _k as k'),
                'l':('mod2','_l', 'l', 'l', 'mod2._l', 'from mod2 import _l as l'),
            },
            {'C':('pack.subpack.stuff','C', None, 'C', 'pack.subpack.stuff.C', 'from pack.subpack.stuff import C')},
            {'*': ('ast', '*', None, '*', 'ast.*', 'from ast import *')}]

        node = self.ast.parse(dedent(code))
        du = StrictDefUseChains(modname='mod1')
        du.visit(node)

        # one expectation dict per import statement, in source order
        assert len(expected)==len(node.body)
        for import_node, expected_names in zip(node.body, expected):
            assert isinstance(import_node, (self.ast.Import, self.ast.ImportFrom))
            for al,i in parse_import(import_node, 'mod1', is_package=False).items():
                # expectations are keyed by the locally bound name of the alias
                assert Def(al).name() in expected_names
                (expected_orgmodule, expected_orgname,
                 expected_asname, expected_name,
                 expected_target, expected_code) = expected_names[Def(al).name()]
                assert i.orgmodule == expected_orgmodule
                assert i.orgname == expected_orgname
                assert i.asname == expected_asname
                assert i.name() == expected_name
                assert i.target() == expected_target
                assert i.code() == expected_code

                # the visitor must have recorded the same import for this alias
                assert i == du.imports[al]
                assert repr(i).startswith(("Import('import ", "Import('from "))

                # flag that at least one alias was actually checked
                ran=True
        assert ran

    def test_import_parsing_relative_package(self):
        """
        Relative imports resolved from the ``__init__`` file of the package
        ``top.subpack.other``; an import with more levels than the module
        name has components keeps its dotted, unresolved form.
        """
        code = '''
from ...mod2 import bar as b
from .pack import foo
from ......error import x
'''
        # orgmodule, orgname
        expected = [{'b':('top.mod2','bar')},
                    {'foo':('top.subpack.other.pack','foo')},
                    {'x': ('......error', 'x')}]

        node = self.ast.parse(dedent(code))
        du = StrictDefUseChains(modname='top.subpack.other', is_package=True)
        du.visit(node)

        # one expectation dict per import statement, in source order
        assert len(expected)==len(node.body)
        for import_node, expected_names in zip(node.body, expected):
            assert isinstance(import_node, (self.ast.Import, self.ast.ImportFrom))
            for al,i in parse_import(import_node,
                                     'top.subpack.other',
                                     is_package=True).items():
                assert Def(al).name() in expected_names
                expected_orgmodule, expected_orgname = expected_names[Def(al).name()]
                assert i.orgmodule == expected_orgmodule
                assert i.orgname == expected_orgname

                # the visitor must have recorded the same import for this alias
                assert i == du.imports[al]
                assert repr(i).startswith(("Import('import ", "Import('from "))

                # flag that at least one alias was actually checked
                ran=True
        assert ran

    def test_import_parsing_relative_module(self):
        """
        Relative imports resolved from the regular (non-package) module
        ``top.subpack.other``; an import with more levels than the module
        name has components keeps its dotted, unresolved form.
        """
        code = '''
from ..mod2 import bar as b
from .pack import foo
from ......error import x
'''
        # orgmodule, orgname
        expected = [{'b':('top.mod2','bar')},
                    {'foo':('top.subpack.pack','foo')},
                    {'x': ('......error', 'x')}]

        node = self.ast.parse(dedent(code))
        du = StrictDefUseChains(modname='top.subpack.other', is_package=False)
        du.visit(node)

        # one expectation dict per import statement, in source order
        assert len(expected)==len(node.body)
        for import_node, expected_names in zip(node.body, expected):
            assert isinstance(import_node, (self.ast.Import, self.ast.ImportFrom))
            for al,i in parse_import(import_node,
                                     'top.subpack.other',
                                     is_package=False).items():
                assert Def(al).name() in expected_names
                expected_orgmodule, expected_orgname = expected_names[Def(al).name()]
                assert i.orgmodule == expected_orgmodule
                assert i.orgname == expected_orgname

                # the visitor must have recorded the same import for this alias
                assert i == du.imports[al]
                assert repr(i).startswith(("Import('import ", "Import('from "))

                # flag that at least one alias was actually checked
                ran=True
        assert ran

class TestImportParserStdlib(TestImportParser):
    """
    Re-run the `TestImportParser` tests with the standard library ``ast``
    module instead of ``gast``.
    """
    # the inherited tests parse their code through ``self.ast``
    ast = _ast
Loading