5
5
6
6
import copy
7
7
8
+ from executorch .backends .nxp .backend .ir .edge_passes .remove_io_quant_ops_pass import (
9
+ RemoveIOQuantOpsPass ,
10
+ )
11
+
8
12
from executorch .backends .nxp .edge_passes .move_auxiliary_operator_into_separate_qdq_cluster_pass import (
9
13
MoveLeadingAuxiliaryOperatorIntoSeparateQDQClusterPass ,
10
14
MoveTrailingAuxiliaryOperatorIntoSeparateQDQClusterPass ,
24
28
25
29
class NeutronEdgePassManager (PassManager ):
26
30
27
- def __init__ (self , passes : list [NeutronEdgePass ] = None ):
31
+ def __init__ (
32
+ self , passes : list [NeutronEdgePass ] = None , remove_io_quant_ops : bool = False
33
+ ):
28
34
passes : list [NeutronEdgePass ] = passes or [
29
35
MoveLeadingAuxiliaryOperatorIntoSeparateQDQClusterPass (),
30
36
MoveTrailingAuxiliaryOperatorIntoSeparateQDQClusterPass (),
@@ -35,6 +41,8 @@ def __init__(self, passes: list[NeutronEdgePass] = None):
35
41
steps = 10 , # Empirical value. At most 10 cycles of passes will be run.
36
42
)
37
43
44
+ self .remove_io_quant_ops = remove_io_quant_ops
45
+
38
46
def _transform_graph_module (self , module : nn .Module ) -> PassResult :
39
47
"""Apply the passes to a single graph module."""
40
48
pass_result : PassResult = super ().__call__ (module )
@@ -78,12 +86,17 @@ def __call__(self, epm: EdgeProgramManager) -> EdgeProgramManager:
78
86
79
87
new_programs [name ] = new_program
80
88
81
- if len (new_programs ) == 0 :
82
- # No passes were run, return the old EdgeProgramManager.
83
- return epm
89
+ result = epm
84
90
85
- else :
86
- # Return a new EdgeProgramManager with the updated programs.
87
- return EdgeProgramManager (
91
+ if len ( new_programs ) > 0 :
92
+ # Use a new EdgeProgramManager with the updated programs if any update was performed .
93
+ result = EdgeProgramManager (
88
94
new_programs , copy .deepcopy (epm ._config_methods ), epm .compile_config
89
95
)
96
+
97
+ if self .remove_io_quant_ops :
98
+ result = result .transform (
99
+ [RemoveIOQuantOpsPass (edge_program_manager = result )]
100
+ )
101
+
102
+ return result
0 commit comments