Skip to content

Commit efe1b34

Browse files
committed
🐛 Fix suite transformer
Make sure to visit non-suite child nodes. Added a consistent way to add child nodes.
1 parent ee728c5 commit efe1b34

File tree

11 files changed

+181
-71
lines changed

11 files changed

+181
-71
lines changed

src/python_minifier/module_printer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -494,7 +494,7 @@ def visit_FunctionDef(self, node, is_async=False):
494494
self.visit_arguments(node.args)
495495
self.code += ')'
496496

497-
if hasattr(node, 'returns') and node.returns:
497+
if hasattr(node, 'returns') and node.returns is not None:
498498
self.code += '->'
499499
self._expression(node.returns)
500500
self.code += ':'

src/python_minifier/transforms/combine_imports.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ class CombineImports(SuiteTransformer):
1111
1212
"""
1313

14-
def _combine_import(self, node_list):
14+
def _combine_import(self, node_list, parent):
1515

1616
alias = []
1717

@@ -20,15 +20,17 @@ def _combine_import(self, node_list):
2020
alias += statement.names
2121
else:
2222
if alias:
23-
yield ast.Import(names=alias)
23+
yield self.add_child(ast.Import(names=alias),
24+
parent=parent)
2425
alias = []
2526

2627
yield statement
2728

2829
if alias:
29-
yield ast.Import(names=alias)
30+
yield self.add_child(ast.Import(names=alias),
31+
parent=parent)
3032

31-
def _combine_import_from(self, node_list):
33+
def _combine_import_from(self, node_list, parent):
3234

3335
prev_import = None
3436
alias = []
@@ -54,16 +56,18 @@ def combine(statement):
5456
alias += statement.names
5557
else:
5658
if alias:
57-
yield ast.ImportFrom(module=prev_import.module, names=alias, level=prev_import.level)
59+
yield self.add_child(ast.ImportFrom(module=prev_import.module, names=alias, level=prev_import.level),
60+
parent=parent)
5861
alias = []
5962

6063
yield statement
6164

6265
if alias:
63-
yield ast.ImportFrom(module=prev_import.module, names=alias, level=prev_import.level)
66+
yield self.add_child(ast.ImportFrom(module=prev_import.module, names=alias, level=prev_import.level),
67+
parent=parent)
6468

6569
def suite(self, node_list, parent):
66-
a = list(self._combine_import(node_list))
67-
b = list(self._combine_import_from(a))
70+
a = list(self._combine_import(node_list, parent))
71+
b = list(self._combine_import_from(a, parent))
6872

6973
return [self.visit(n) for n in b]

src/python_minifier/transforms/remove_annotations.py

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import ast
22
import sys
33

4-
from python_minifier.rename.mapper import add_parent
54
from python_minifier.transforms.suite_transformer import SuiteTransformer
65

76

@@ -16,22 +15,21 @@ def __call__(self, node):
1615
return self.visit(node)
1716

1817
def visit_FunctionDef(self, node):
18+
node.args = self.visit_arguments(node.args)
19+
node.body = self.suite(node.body, parent=node)
20+
node.decorator_list = [self.visit(d) for d in node.decorator_list]
21+
1922
if hasattr(node, 'returns'):
2023
node.returns = None
21-
node.body = [self.visit(a) for a in node.body]
22-
23-
if node.args:
24-
node.args = self.visit_arguments(node.args)
2524

2625
return node
2726

28-
def visit_AsyncFunctionDef(self, node):
29-
return self.visit_FunctionDef(node)
30-
3127
def visit_arguments(self, node):
28+
assert isinstance(node, ast.arguments)
3229

3330
if node.args:
3431
node.args = [self.visit_arg(a) for a in node.args]
32+
3533
if hasattr(node, 'kwonlyargs') and node.kwonlyargs:
3634
node.kwonlyargs = [self.visit_arg(a) for a in node.kwonlyargs]
3735

@@ -76,15 +74,11 @@ def is_dataclass_field(node):
7674
if is_dataclass_field(node):
7775
return node
7876
elif node.value:
79-
assign = ast.Assign([node.target], node.value)
80-
assign.parent = node.parent
81-
assign.namespace = node.namespace
82-
return assign
77+
return self.add_child(ast.Assign([node.target], node.value), parent=node.parent)
8378
else:
8479
# Valueless annotations cause the interpreter to treat the variable as a local.
8580
# I don't know of another way to do that without assigning to it, so
8681
# keep it as an AnnAssign, but replace the annotation with '0'
8782

88-
node.annotation = ast.Num(0)
89-
add_parent(node, parent=node.parent, namespace=node.namespace)
83+
node.annotation = self.add_child(ast.Num(0), parent=node.parent)
9084
return node

src/python_minifier/transforms/remove_literal_statements.py

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import ast
22

3-
from python_minifier.rename.mapper import add_parent
43
from python_minifier.transforms.suite_transformer import SuiteTransformer
54

65

@@ -46,24 +45,15 @@ def is_literal_statement(self, node):
4645
if not isinstance(node, ast.Expr):
4746
return False
4847

49-
if (
50-
isinstance(node.value, (ast.Num, ast.Str, ast.NameConstant))
51-
or node.value.__class__.__name__ == 'Constant'
52-
or node.value.__class__.__name__ == 'Bytes'
53-
):
54-
return True
55-
56-
return False
48+
return self.is_node(node.value, (ast.Num, ast.Str, 'NameConstant', 'Bytes'))
5749

5850
def suite(self, node_list, parent):
59-
without_literals = [self.visit(a) for a in filter(lambda n: not self.is_literal_statement(n), node_list)]
51+
without_literals = [self.visit(n) for n in node_list if not self.is_literal_statement(n)]
6052

6153
if len(without_literals) == 0:
6254
if isinstance(parent, ast.Module):
6355
return []
6456
else:
65-
expr = ast.Expr(value=ast.Num(0))
66-
add_parent(expr, parent=parent, namespace=parent.namespace)
67-
return [expr]
57+
return [self.add_child(ast.Expr(value=ast.Num(0)), parent=parent)]
6858

6959
return without_literals

src/python_minifier/transforms/remove_pass.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import ast
22

33
from python_minifier.transforms.suite_transformer import SuiteTransformer
4-
from python_minifier.rename.mapper import add_parent
54

65

76
class RemovePass(SuiteTransformer):
@@ -15,14 +14,12 @@ def __call__(self, node):
1514
return self.visit(node)
1615

1716
def suite(self, node_list, parent):
18-
without_pass = [self.visit(a) for a in filter(lambda n: not isinstance(n, ast.Pass), node_list)]
17+
without_pass = [self.visit(a) for a in filter(lambda n: not self.is_node(n, ast.Pass), node_list)]
1918

2019
if len(without_pass) == 0:
2120
if isinstance(parent, ast.Module):
2221
return []
2322
else:
24-
expr = ast.Expr(value=ast.Num(0))
25-
add_parent(expr, parent=parent, namespace=parent.namespace)
26-
return [expr]
23+
return [self.add_child(ast.Expr(value=ast.Num(0)), parent=parent)]
2724

2825
return without_pass

src/python_minifier/transforms/suite_transformer.py

Lines changed: 116 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
import ast
22

3+
from python_minifier.rename.mapper import add_parent
4+
5+
36
class NodeVisitor(object):
47

58
def visit(self, node):
@@ -35,6 +38,53 @@ def visit_Constant(self, node):
3538
visitor = getattr(self, method, self.generic_visit)
3639
return visitor(node)
3740

41+
def is_node(self, node, types):
42+
"""
43+
Is a node one of the specified node types
44+
45+
A node type may be an actual ast class, or a string naming one.
46+
types is a single node type or an iterable of many.
47+
48+
If a node_type specified a specific Constant type (Str, Bytes, Num etc),
49+
returns true for Constant nodes of the correct type.
50+
51+
:type node: ast.AST
52+
:param types:
53+
:rtype: bool
54+
"""
55+
56+
57+
if not isinstance(types, tuple):
58+
types = types,
59+
60+
actual_types = []
61+
for type in types:
62+
if isinstance(type, str):
63+
node_type = getattr(ast, type, None)
64+
if node_type is not None:
65+
actual_types.append(node_type)
66+
else:
67+
actual_types.append(type)
68+
69+
if isinstance(node, tuple(actual_types)):
70+
return True
71+
72+
if hasattr(ast, 'Constant') and isinstance(node, ast.Constant):
73+
if node.value in [None, True, False]:
74+
return ast.NameConstant in types
75+
elif isinstance(node.value, (int, float, complex)):
76+
return ast.Num in types
77+
elif isinstance(node.value, str):
78+
return ast.Str in types
79+
elif isinstance(node.value, bytes):
80+
return ast.Bytes in types
81+
elif node.value == Ellipsis:
82+
return ast.Ellipsis in types
83+
else:
84+
raise RuntimeError('Unknown Constant value %r' % type(node.value))
85+
86+
return False
87+
3888
class SuiteTransformer(NodeVisitor):
3989
"""
4090
Transform suites of instructions
@@ -44,17 +94,39 @@ def __call__(self, node):
4494
return self.visit(node)
4595

4696
def visit_ClassDef(self, node):
97+
node.bases = [self.visit(b) for b in node.bases]
98+
4799
node.body = self.suite(node.body, parent=node)
100+
node.decorator_list = [self.visit(d) for d in node.decorator_list]
101+
102+
if hasattr(node, 'starargs') and node.starargs is not None:
103+
node.starargs = self.visit(node.starargs)
104+
105+
if hasattr(node, 'kwargs') and node.kwargs is not None:
106+
node.kwargs = self.visit(node.kwargs)
107+
108+
if hasattr(node, 'keywords'):
109+
node.keywords = [self.visit(kw) for kw in node.keywords]
110+
48111
return node
49112

50113
def visit_FunctionDef(self, node):
114+
node.args = self.visit(node.args)
51115
node.body = self.suite(node.body, parent=node)
116+
node.decorator_list = [self.visit(d) for d in node.decorator_list]
117+
118+
if hasattr(node, 'returns') and node.returns is not None:
119+
node.returns = self.visit(node.returns)
120+
52121
return node
53122

54123
def visit_AsyncFunctionDef(self, node):
55124
return self.visit_FunctionDef(node)
56125

57126
def visit_For(self, node):
127+
node.target = self.visit(node.target)
128+
node.iter = self.visit(node.iter)
129+
58130
node.body = self.suite(node.body, parent=node)
59131

60132
if node.orelse:
@@ -63,15 +135,13 @@ def visit_For(self, node):
63135
return node
64136

65137
def visit_AsyncFor(self, node):
66-
node.body = self.suite(node.body, parent=node)
67-
68-
if node.orelse:
69-
node.orelse = self.suite(node.orelse, parent=node)
70-
71-
return node
138+
return self.visit_For(node)
72139

73140
def visit_If(self, node):
141+
node.test = self.visit(node.test)
142+
74143
node.body = self.suite(node.body, parent=node)
144+
75145
if node.orelse:
76146
node.orelse = self.suite(node.orelse, parent=node)
77147

@@ -80,6 +150,8 @@ def visit_If(self, node):
80150
def visit_Try(self, node):
81151
node.body = self.suite(node.body, parent=node)
82152

153+
node.handlers = [self.visit(h) for h in node.handlers]
154+
83155
if node.orelse:
84156
node.orelse = self.suite(node.orelse, parent=node)
85157

@@ -89,6 +161,8 @@ def visit_Try(self, node):
89161
return node
90162

91163
def visit_While(self, node):
164+
node.test = self.visit(node.test)
165+
92166
node.body = self.suite(node.body, parent=node)
93167

94168
if node.orelse:
@@ -97,12 +171,20 @@ def visit_While(self, node):
97171
return node
98172

99173
def visit_With(self, node):
174+
175+
if hasattr(node, 'items'):
176+
node.items = [self.visit(i) for i in node.items]
177+
else:
178+
if node.context_expr:
179+
node.context_expr = self.visit(node.context_expr)
180+
if node.optional_vars:
181+
node.optional_vars = self.visit(node.optional_vars)
182+
100183
node.body = self.suite(node.body, parent=node)
101184
return node
102185

103186
def visit_AsyncWith(self, node):
104-
node.body = self.suite(node.body, parent=node)
105-
return node
187+
return self.visit_With(node)
106188

107189
def visit_Module(self, node):
108190
node.body = self.suite(node.body, parent=node)
@@ -132,3 +214,29 @@ def generic_visit(self, node):
132214
else:
133215
setattr(node, field, new_node)
134216
return node
217+
218+
def add_child(self, child, parent, namespace=None):
219+
220+
def nearest_function_namespace(node):
221+
"""
222+
Return the namespace node for the nearest function scope.
223+
224+
This could be itself.
225+
226+
:param node: The node to get the function namespace of
227+
:type node: ast.Node
228+
:rtype: ast.Node
229+
230+
"""
231+
232+
if isinstance(node, (ast.FunctionDef, ast.Module)):
233+
return node
234+
if hasattr(ast, 'AsyncFunctionDef') and isinstance(node, ast.AsyncFunctionDef):
235+
return node
236+
return nearest_function_namespace(node.parent)
237+
238+
if namespace is None:
239+
namespace = nearest_function_namespace(parent)
240+
241+
add_parent(child, parent=parent, namespace=namespace)
242+
return child

0 commit comments

Comments
 (0)