From 4c1eaa3bfc6e4c2db182aa4f280f0ab70ad560e1 Mon Sep 17 00:00:00 2001 From: tristanlatr Date: Wed, 12 Feb 2025 16:57:33 -0500 Subject: [PATCH 1/5] Add support for parsing imports --- README.rst | 73 ++++++++++++++++++++--- beniget/beniget.py | 132 +++++++++++++++++++++++++++++++++++++++++- tests/test_imports.py | 97 +++++++++++++++++++++++++++++++ 3 files changed, 292 insertions(+), 10 deletions(-) create mode 100644 tests/test_imports.py diff --git a/README.rst b/README.rst index ea889f9..bf1f511 100644 --- a/README.rst +++ b/README.rst @@ -15,7 +15,12 @@ API Basically Beniget provides three analyse: - ``beniget.Ancestors`` that maps each node to the list of enclosing nodes; -- ``beniget.DefUseChains`` that maps each node to the list of definition points in that node; +- ``beniget.DefUseChains`` that: + + - maps each node to the list of definition points in that node; + - maps each scope node to their locals dictionary; + - maps each alias node to their resolved import; + - ``beniget.UseDefChains`` that maps each node to the list of possible definition of that node. See sample usages and/or run ``pydoc beniget`` for more information :-). @@ -34,22 +39,20 @@ This is a very basic usage: look for def without any use, and warn about them, f >>> import beniget, gast as ast # parse some simple statements - >>> code = "from math import cos, sin; print(cos(3))" + >>> code = "from math import cos, sin; import x, y; print(cos(3) + y.f(2))" >>> module = ast.parse(code) # compute the def-use chains at module level >>> duc = beniget.DefUseChains() >>> duc.visit(module) - # grab the import statement - >>> imported = module.body[0].names - # inspect the users of each imported name - >>> for name in imported: - ... ud = duc.chains[name] + >>> for alias in duc.imports: + ... ud = duc.chains[alias] ... if not ud.users(): - ... print("Unused import: {}".format(ud.name())) + ... print(f"Unused import: {ud.name()}") Unused import: sin + Unused import: x *NOTE*: Due to the dynamic nature of Python, one can fool this analysis by calling the ``eval`` function, eventually through an indirection, or by performing a lookup @@ -231,8 +234,60 @@ let's use the UseDef chains combined with the ancestors. >>> list(map(type, capturex.external)) [, , ] +Report usage of imported names +****************************** + +This analysis takes a collection of names and +reports when their beeing imported and used. + +.. code:: python + + >>> import ast, beniget + >>> def find_references_to(names, defuse: beniget.DefUseChains, + ... ancestors: beniget.Ancestors) -> 'list[beniget.Def]': + ... names = dict.fromkeys(names) + ... found = [] + ... for al,imp in defuse.imports.items(): + ... if imp.target() in names: # "from x import y;y" form + ... for use in defuse.chains[al].users(): + ... found.append(use) + ... # Note: this doesn't handle aliasing. + ... else: # "import x; x.y" form + ... for n in names: + ... if n.startswith(f'{imp.target()}.'): + ... diffnames = n[len(f'{imp.target()}.'):].split('.') + ... for use in defuse.chains[al].users(): + ... attr_node = parent_node = ancestors.parent(use.node) + ... index = 0 + ... # check if node is part of an attribute access matching the dotted name + ... while isinstance(parent_node, ast.Attribute) and index < len(diffnames): + ... if parent_node.attr != diffnames[index]: + ... break + ... attr_node = parent_node + ... parent_node = ancestors.parent(parent_node) + ... index += 1 + ... else: + ... if index: # It has not break and did a loop, meaning we found a match + ... found.append(defuse.chains[attr_node]) + ... + ... return found + ... + >>> module = ast.parse('''\ + ... from typing import List, Dict; import typing as t; import numpy as np + ... def f() -> List[str]: ... + ... def g(a: Dict) -> t.overload: return np.fft.calc(0)''') + >>> c = beniget.DefUseChains() + >>> c.visit(module) + >>> a = beniget.Ancestors() + >>> a.visit(module) + >>> print([str(i) for i in find_references_to(['typing.Dict', 'typing.List', 'typing.overload', 'numpy.fft.calc'], c, a)]) + ['List -> ( -> ())', 'Dict -> ()', '.overload -> ()', '.calc -> ( -> ())'] + + >>> print([str(i) for i in find_references_to(['typing'], c, a)]) + ['t -> (.overload -> ())'] + Acknowledgments --------------- Beniget is in Pierre Augier's debt, for he triggered the birth of beniget and provided -countless meaningful bug reports and advices. Trugarez! +countless meaningful bug reports and advices. Trugarez! \ No newline at end of file diff --git a/beniget/beniget.py b/beniget/beniget.py index e7714d3..ffa19ba 100644 --- a/beniget/beniget.py +++ b/beniget/beniget.py @@ -193,6 +193,126 @@ def _str(self, nodes): for u in self._users) ) +class Import: + """ + Represents an `ast.alias` node with resolved + origin module and target of the locally bound name. + :note: `orgname` will be ``*`` for wildcard imports. + """ + __slots__ = 'orgmodule', 'orgname', 'asname', '_orgroot' + + def __init__(self, orgmodule, orgname=None, asname=None): + """ + Create instances of this class with parse_import(). + :param orgmodule: str, The origin module + :param orgname: str or None, The origin name + :param orgname: str or None, Import asname + """ + self.orgmodule = orgmodule + self._orgroot = orgmodule.split(".", 1)[0] + self.orgname = orgname + self.asname = asname + + def name(self): + """ + Returns the local name of the imported symbol, str. + This will be equal to the ``name()`` of the `Def` of the `ast.alias` node this `Import` represents. + """ + if self.asname: + return self.asname + if self.orgname: + return self.orgname + return self._orgroot + + def target(self): + """ + Returns the fully qualified name of the target of the imported symbol, str. + """ + if self.orgname: + return "{}.{}".format(self.orgmodule, self.orgname) + if self.asname: + return self.orgmodule + return self._orgroot + + def code(self): + """ + Returns this imported name as an import code statement, str. + """ + if self.orgname: + if self.asname: + return "from {} import {} as {}".format(self.orgmodule, self.orgname, self.asname) + return "from {} import {}".format(self.orgmodule, self.orgname) + elif self.asname: + return "import {} as {}".format(self.orgmodule, self.asname) + return "import {}".format(self.orgmodule) + + def __eq__(self, value): + if isinstance(value, Import): + return self.code() == value.code() + return NotImplemented + +def parse_import(node, modname=None, is_package=False): + """ + Parse the given import node into a mapping of aliases to their `Import`. + + :param node: The import node (ast.Import or ast.ImportFrom). + :param modname: The name of the module, required to resolve relative imports. + :type modname: string or None (it wich case we can't resolve relative imports) + :param bool is_package: Whether the module is the ``__init__`` file of a package, + required to correctly resolve relative imports in package's __init__.py files. + :rtype: dict[ast.alias, Import] + """ + result = {} + + ast = pkg(node) + if isinstance(node, ast.Import): + for al in node.names: + result[al] = Import(orgmodule=al.name, + asname=al.asname) + + elif isinstance(node, ast.ImportFrom): + if node.module is None: + module = () + else: + module = tuple(node.module.split(".")) + + if not node.level: + source_module = module + else: + if modname: + # parse relative imports, if module name if provided. + current_module = tuple(modname.split(".")) + if node.level == 1: + if is_package: + relative_module = current_module + else: + relative_module = current_module[:-1] + else: + if is_package: + relative_module = current_module[: 1 - node.level] + else: + relative_module = current_module[: -node.level] + else: + relative_module = () + + if not relative_module: + # We don't raise errors when an relative import makes no sens, + # we simply pad the name with dots. + relative_module = ("",) * node.level + + source_module = relative_module + module + + for alias in node.names: + result[alias] = Import( + orgmodule=".".join(source_module), + orgname=alias.name, + asname=alias.asname, + ) + + else: + raise TypeError('unexpected node type: {}'.format(type(node))) + + return result import builtins BuiltinsSrc = builtins.__dict__ @@ -324,13 +444,17 @@ class DefUseChains(gast.NodeVisitor): """ - def __init__(self, filename=None): + def __init__(self, filename=None, modname=None, is_package=False): """ - filename: str, included in error messages if specified """ self.chains = {} self.locals = defaultdict(list) + self.imports = {} + self.filename = filename + self.modname = modname + self.is_package = is_package # deep copy of builtins, to remain reentrant self._builtins = {k: Def(v) for k, v in Builtins.items()} @@ -1093,6 +1217,9 @@ def visit_Import(self, node): base = alias.name.split(".", 1)[0] self.set_definition(alias.asname or base, dalias) self.add_to_locals(alias.asname or base, dalias) + + self.imports.update(parse_import(node, self.modname, + is_package=self.is_package)) def visit_ImportFrom(self, node): for alias in node.names: @@ -1102,6 +1229,9 @@ def visit_ImportFrom(self, node): else: self.set_definition(alias.asname or alias.name, dalias) self.add_to_locals(alias.asname or alias.name, dalias) + + self.imports.update(parse_import(node, self.modname, + is_package=self.is_package)) def visit_Global(self, node): for name in node.names: diff --git a/tests/test_imports.py b/tests/test_imports.py new file mode 100644 index 0000000..a63f5ba --- /dev/null +++ b/tests/test_imports.py @@ -0,0 +1,97 @@ +import gast as _gast +import ast as _ast +from unittest import TestCase +from textwrap import dedent + +from beniget.beniget import Def, parse_import + +class TestImportParser(TestCase): + ast = _gast + def test_import_parsing(self): + code = ''' + import mod2 + import pack.subpack + import pack.subpack as a + from mod2 import _k as k, _l as l + from pack.subpack.stuff import C + from ast import * + ''' + # orgmodule, orgname, asname, name, target, code + expected = [ + {'mod2':('mod2', None, None, 'mod2', 'mod2', 'import mod2')}, + {'pack':('pack.subpack', None, None, 'pack', 'pack', 'import pack.subpack')}, + {'a':('pack.subpack', None, 'a', 'a', 'pack.subpack', 'import pack.subpack as a')}, + { + 'k':('mod2','_k', 'k', 'k', 'mod2._k', 'from mod2 import _k as k'), + 'l':('mod2','_l', 'l', 'l', 'mod2._l', 'from mod2 import _l as l'), + }, + {'C':('pack.subpack.stuff','C', None, 'C', 'pack.subpack.stuff.C', 'from pack.subpack.stuff import C')}, + {'*': ('ast', '*', None, '*', 'ast.*', 'from ast import *')}] + + node = self.ast.parse(dedent(code)) + assert len(expected)==len(node.body) + for import_node, expected_names in zip(node.body, expected): + assert isinstance(import_node, (self.ast.Import, self.ast.ImportFrom)) + for al,i in parse_import(import_node, 'mod1', is_package=False).items(): + assert Def(al).name() in expected_names + (expected_orgmodule, expected_orgname, + expected_asname, expected_name, + expected_target, expected_code) = expected_names[Def(al).name()] + assert i.orgmodule == expected_orgmodule + assert i.orgname == expected_orgname + assert i.asname == expected_asname + assert i.name() == expected_name + assert i.target() == expected_target + assert i.code() == expected_code + + ran=True + assert ran + + def test_import_parsing_relative_package(self): + code = ''' + from ...mod2 import bar as b + from .pack import foo + from ......error import x + ''' + expected = [{'b':('top.mod2','bar')}, + {'foo':('top.subpack.other.pack','foo')}, + {'x': ('......error', 'x')}] + node = self.ast.parse(dedent(code)) + assert len(expected)==len(node.body) + for import_node, expected_names in zip(node.body, expected): + assert isinstance(import_node, (self.ast.Import, self.ast.ImportFrom)) + for al,i in parse_import(import_node, + 'top.subpack.other', + is_package=True).items(): + assert Def(al).name() in expected_names + expected_orgmodule, expected_orgname = expected_names[Def(al).name()] + assert i.orgmodule == expected_orgmodule + assert i.orgname == expected_orgname + ran=True + assert ran + + def test_import_parsing_relative_module(self): + code = ''' + from ..mod2 import bar as b + from .pack import foo + from ......error import x + ''' + expected = [{'b':('top.mod2','bar')}, + {'foo':('top.subpack.pack','foo')}, + {'x': ('......error', 'x')}] + node = self.ast.parse(dedent(code)) + assert len(expected)==len(node.body) + for import_node, expected_names in zip(node.body, expected): + assert isinstance(import_node, (self.ast.Import, self.ast.ImportFrom)) + for al,i in parse_import(import_node, + 'top.subpack.other', + is_package=False).items(): + assert Def(al).name() in expected_names + expected_orgmodule, expected_orgname = expected_names[Def(al).name()] + assert i.orgmodule == expected_orgmodule + assert i.orgname == expected_orgname + ran=True + assert ran + +class TestImportParserStdlib(TestImportParser): + ast = _ast \ No newline at end of file From 42414f48f53092829531b672491b07c9ba3edb26 Mon Sep 17 00:00:00 2001 From: tristanlatr Date: Thu, 13 Feb 2025 12:16:24 -0500 Subject: [PATCH 2/5] Refactor the relative module support into less lines --- beniget/beniget.py | 52 +++++++++++++++++++++------------------------- 1 file changed, 24 insertions(+), 28 deletions(-) diff --git a/beniget/beniget.py b/beniget/beniget.py index ffa19ba..77757f6 100644 --- a/beniget/beniget.py +++ b/beniget/beniget.py @@ -251,6 +251,9 @@ def __eq__(self, value): return self.code() == value.code() return NotImplemented + def __repr__(self): + return f'Import({self.code()!r})' + def parse_import(node, modname=None, is_package=False): """ Parse the given import node into a mapping of aliases to their `Import`. @@ -267,44 +270,37 @@ def parse_import(node, modname=None, is_package=False): ast = pkg(node) if isinstance(node, ast.Import): for al in node.names: - result[al] = Import(orgmodule=al.name, - asname=al.asname) - + result[al] = Import(orgmodule=al.name, asname=al.asname) + elif isinstance(node, ast.ImportFrom): if node.module is None: - module = () - else: - module = tuple(node.module.split(".")) - - if not node.level: - source_module = module + orgmodule = () else: + orgmodule = tuple(node.module.split(".")) + level = node.level + if level: + relative_module = () if modname: # parse relative imports, if module name if provided. - current_module = tuple(modname.split(".")) - if node.level == 1: - if is_package: - relative_module = current_module - else: - relative_module = current_module[:-1] - else: - if is_package: - relative_module = current_module[: 1 - node.level] - else: - relative_module = current_module[: -node.level] - else: - relative_module = () - + curr = tuple(modname.split('.')) + if is_package: + level -= 1 + for _ in range(level): + if not curr: + break + curr = curr[:-1] + if curr: + relative_module = curr + orgmodule if not relative_module: - # We don't raise errors when an relative import makes no sens, + # An relative import makes no sens for beniget... # we simply pad the name with dots. - relative_module = ("",) * node.level - - source_module = relative_module + module + relative_module = ("",) * node.level + orgmodule + + orgmodule = relative_module for alias in node.names: result[alias] = Import( - orgmodule=".".join(source_module), + orgmodule=".".join(orgmodule), orgname=alias.name, asname=alias.asname, ) From d18b99ca21a5944fd17dbff463db8d93e2ebb0b4 Mon Sep 17 00:00:00 2001 From: tristanlatr Date: Thu, 13 Feb 2025 12:16:50 -0500 Subject: [PATCH 3/5] Actually test the DefUseCains.imports attribute as part of the import parsing testing. --- tests/test_imports.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/test_imports.py b/tests/test_imports.py index a63f5ba..5561804 100644 --- a/tests/test_imports.py +++ b/tests/test_imports.py @@ -3,6 +3,7 @@ from unittest import TestCase from textwrap import dedent +from .test_chains import StrictDefUseChains from beniget.beniget import Def, parse_import class TestImportParser(TestCase): @@ -29,6 +30,9 @@ def test_import_parsing(self): {'*': ('ast', '*', None, '*', 'ast.*', 'from ast import *')}] node = self.ast.parse(dedent(code)) + du = StrictDefUseChains('./mod1.py', 'mod1', is_package=False) + du.visit(node) + assert len(expected)==len(node.body) for import_node, expected_names in zip(node.body, expected): assert isinstance(import_node, (self.ast.Import, self.ast.ImportFrom)) @@ -44,6 +48,9 @@ def test_import_parsing(self): assert i.target() == expected_target assert i.code() == expected_code + assert i == du.imports[al] + assert repr(i).startswith(("Import('import ", "Import('from ")) + ran=True assert ran From 35111c71e6c0c7e8d30a2ca24f6a9d7c113a1a24 Mon Sep 17 00:00:00 2001 From: tristanlatr Date: Thu, 13 Feb 2025 14:04:17 -0500 Subject: [PATCH 4/5] Checks DefUseChains.imports in all import tests --- tests/test_imports.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/tests/test_imports.py b/tests/test_imports.py index 5561804..88f38f4 100644 --- a/tests/test_imports.py +++ b/tests/test_imports.py @@ -30,7 +30,7 @@ def test_import_parsing(self): {'*': ('ast', '*', None, '*', 'ast.*', 'from ast import *')}] node = self.ast.parse(dedent(code)) - du = StrictDefUseChains('./mod1.py', 'mod1', is_package=False) + du = StrictDefUseChains(modname='mod1') du.visit(node) assert len(expected)==len(node.body) @@ -63,7 +63,11 @@ def test_import_parsing_relative_package(self): expected = [{'b':('top.mod2','bar')}, {'foo':('top.subpack.other.pack','foo')}, {'x': ('......error', 'x')}] + node = self.ast.parse(dedent(code)) + du = StrictDefUseChains(modname='top.subpack.other', is_package=True) + du.visit(node) + assert len(expected)==len(node.body) for import_node, expected_names in zip(node.body, expected): assert isinstance(import_node, (self.ast.Import, self.ast.ImportFrom)) @@ -74,6 +78,10 @@ def test_import_parsing_relative_package(self): expected_orgmodule, expected_orgname = expected_names[Def(al).name()] assert i.orgmodule == expected_orgmodule assert i.orgname == expected_orgname + + assert i == du.imports[al] + assert repr(i).startswith(("Import('import ", "Import('from ")) + ran=True assert ran @@ -86,7 +94,11 @@ def test_import_parsing_relative_module(self): expected = [{'b':('top.mod2','bar')}, {'foo':('top.subpack.pack','foo')}, {'x': ('......error', 'x')}] + node = self.ast.parse(dedent(code)) + du = StrictDefUseChains(modname='top.subpack.other', is_package=False) + du.visit(node) + assert len(expected)==len(node.body) for import_node, expected_names in zip(node.body, expected): assert isinstance(import_node, (self.ast.Import, self.ast.ImportFrom)) @@ -97,6 +109,10 @@ def test_import_parsing_relative_module(self): expected_orgmodule, expected_orgname = expected_names[Def(al).name()] assert i.orgmodule == expected_orgmodule assert i.orgname == expected_orgname + + assert i == du.imports[al] + assert repr(i).startswith(("Import('import ", "Import('from ")) + ran=True assert ran From 06d2b13f0bf5400507e8463794a4656f15df3885 Mon Sep 17 00:00:00 2001 From: tristanlatr <19967168+tristanlatr@users.noreply.github.com> Date: Thu, 13 Feb 2025 14:26:15 -0500 Subject: [PATCH 5/5] fix typo --- beniget/beniget.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/beniget/beniget.py b/beniget/beniget.py index 77757f6..bed9ef1 100644 --- a/beniget/beniget.py +++ b/beniget/beniget.py @@ -281,7 +281,7 @@ def parse_import(node, modname=None, is_package=False): if level: relative_module = () if modname: - # parse relative imports, if module name if provided. + # parse relative imports, if module name is provided. curr = tuple(modname.split('.')) if is_package: level -= 1