Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 20 additions & 7 deletions backends/nxp/edge_passes/neutron_edge_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@
MoveTrailingAuxiliaryOperatorIntoSeparateQDQClusterPass,
)
from executorch.backends.nxp.edge_passes.neutron_edge_pass import NeutronEdgePass

from executorch.backends.nxp.edge_passes.remove_io_quant_ops_pass import (
RemoveIOQuantOpsPass,
)
from executorch.exir import EdgeProgramManager
from executorch.exir.program._program import (
_get_updated_graph_signature,
Expand All @@ -24,7 +28,9 @@

class NeutronEdgePassManager(PassManager):

def __init__(self, passes: list[NeutronEdgePass] = None):
def __init__(
self, passes: list[NeutronEdgePass] = None, remove_io_quant_ops: bool = False
):
passes: list[NeutronEdgePass] = passes or [
MoveLeadingAuxiliaryOperatorIntoSeparateQDQClusterPass(),
MoveTrailingAuxiliaryOperatorIntoSeparateQDQClusterPass(),
Expand All @@ -35,6 +41,8 @@ def __init__(self, passes: list[NeutronEdgePass] = None):
steps=10, # Empirical value. At most 10 cycles of passes will be run.
)

self.remove_io_quant_ops = remove_io_quant_ops

def _transform_graph_module(self, module: nn.Module) -> PassResult:
"""Apply the passes to a single graph module."""
pass_result: PassResult = super().__call__(module)
Expand Down Expand Up @@ -78,12 +86,17 @@ def __call__(self, epm: EdgeProgramManager) -> EdgeProgramManager:

new_programs[name] = new_program

if len(new_programs) == 0:
# No passes were run, return the old EdgeProgramManager.
return epm
result = epm

else:
# Return a new EdgeProgramManager with the updated programs.
return EdgeProgramManager(
if len(new_programs) > 0:
# Use a new EdgeProgramManager with the updated programs if any update was performed.
result = EdgeProgramManager(
new_programs, copy.deepcopy(epm._config_methods), epm.compile_config
)

if self.remove_io_quant_ops:
result = result.transform(
[RemoveIOQuantOpsPass(edge_program_manager=result)]
)

return result
12 changes: 3 additions & 9 deletions backends/nxp/tests/executorch_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,6 @@
from executorch.backends.nxp.edge_passes.neutron_edge_pass_manager import (
NeutronEdgePassManager,
)
from executorch.backends.nxp.edge_passes.remove_io_quant_ops_pass import (
RemoveIOQuantOpsPass,
)
from executorch.backends.nxp.neutron_partitioner import NeutronPartitioner
from executorch.backends.nxp.nxp_backend import generate_neutron_compile_spec
from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer
Expand Down Expand Up @@ -115,7 +112,9 @@ def to_quantized_edge_program(
edge_compile_config=edge_compile_config,
)

edge_program_manager = NeutronEdgePassManager()(edge_program_manager)
edge_program_manager = NeutronEdgePassManager(
remove_io_quant_ops=remove_quant_io_ops
)(edge_program_manager)

compile_spec = generate_neutron_compile_spec(
target,
Expand All @@ -125,11 +124,6 @@ def to_quantized_edge_program(
partitioner = NeutronPartitioner(compile_spec, custom_delegation_options)
edge_program_manager = edge_program_manager.to_backend(partitioner)

if remove_quant_io_ops:
edge_program_manager = edge_program_manager.transform(
[RemoveIOQuantOpsPass(edge_program_manager=edge_program_manager)]
)

return edge_program_manager


Expand Down
54 changes: 22 additions & 32 deletions examples/nxp/aot_neutron_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
import executorch.kernels.quantized # noqa F401

import torch
from executorch.backends.nxp.edge_passes.remove_io_quant_ops_pass import (
RemoveIOQuantOpsPass,
from executorch.backends.nxp.edge_passes.neutron_edge_pass_manager import (
NeutronEdgePassManager,
)
from executorch.backends.nxp.neutron_partitioner import NeutronPartitioner
from executorch.backends.nxp.nxp_backend import generate_neutron_compile_spec
Expand All @@ -33,7 +33,6 @@
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

from .experimental.cifar_net.cifar_net import CifarNet, test_cifarnet_model

from .models.mobilenet_v2 import MobilenetV2

FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
Expand Down Expand Up @@ -228,7 +227,7 @@ def _get_batch_size(data):

module = exported_program.module()

# 4. Quantize if required
# 3. Quantize if required
if args.quantize:
if calibration_inputs is None:
logging.warning(
Expand All @@ -254,39 +253,30 @@ def _get_batch_size(data):
quantized_str = "quantized " if args.quantize else ""
print(f"\nAccuracy of the {quantized_str}`{args.model_name}`: {accuracy}\n")

# 5. Export to edge program
partitioner_list = []
if args.delegate is True:
partitioner_list = [
NeutronPartitioner(
generate_neutron_compile_spec(
args.target,
args.neutron_converter_flavor,
operators_not_to_delegate=args.operators_not_to_delegate,
)
)
]
# 4. Transform and lower

compile_spec = generate_neutron_compile_spec(
args.target,
operators_not_to_delegate=args.operators_not_to_delegate,
neutron_converter_flavor=args.neutron_converter_flavor,
)
partitioners = [NeutronPartitioner(compile_spec)] if args.delegate else []

edge_program = to_edge_transform_and_lower(
edge_program_manager = to_edge_transform_and_lower(
export(module, example_inputs, strict=True),
partitioner=partitioner_list,
compile_config=EdgeCompileConfig(
_check_ir_validity=False,
),
partitioner=partitioners,
compile_config=EdgeCompileConfig(),
)
logging.debug(f"Exported graph:\n{edge_program.exported_program().graph}")

if args.remove_quant_io_ops:
edge_program = edge_program.transform(
[RemoveIOQuantOpsPass(edge_program_manager=edge_program)]
)
logging.debug(
f"Exported graph (RemoveIOQuantOpsPass):\n{edge_program.exported_program().graph}"
)
edge_program_manager = NeutronEdgePassManager(
remove_io_quant_ops=args.remove_quant_io_ops
)(edge_program_manager)

logging.debug(f"Lowered graph:\n{edge_program_manager.exported_program().graph}")

# 6. Export to ExecuTorch program
# 5. Export to ExecuTorch program
try:
exec_prog = edge_program.to_executorch(
exec_prog = edge_program_manager.to_executorch(
config=ExecutorchBackendConfig(extract_delegate_segments=False)
)
except RuntimeError as e:
Expand All @@ -306,7 +296,7 @@ def executorch_program_to_str(ep, verbose=False):

logging.debug(f"Executorch program:\n{executorch_program_to_str(exec_prog)}")

# 7. Serialize to *.pte
# 6. Serialize to *.pte
model_name = f"{args.model_name}" + (
"_nxp_delegate" if args.delegate is True else ""
)
Expand Down
Loading