From d299baea9381fa32eaa8e618d15312c1b467df8a Mon Sep 17 00:00:00 2001 From: Martin Pavella Date: Wed, 3 Sep 2025 12:39:22 +0200 Subject: [PATCH 1/6] NXP backend: Store inferred node format in the `node.meta`. --- .../nxp/backend/edge_program_converter.py | 11 +++----- backends/nxp/backend/ir/conversion_context.py | 5 ---- backends/nxp/backend/node_format_inference.py | 27 ++++++++++++------- .../nxp/tests/test_node_format_inference.py | 19 ++++++------- 4 files changed, 32 insertions(+), 30 deletions(-) diff --git a/backends/nxp/backend/edge_program_converter.py b/backends/nxp/backend/edge_program_converter.py index ddbbf5b2e3a..657f1526da1 100644 --- a/backends/nxp/backend/edge_program_converter.py +++ b/backends/nxp/backend/edge_program_converter.py @@ -19,8 +19,8 @@ from torch.nn.parameter import Parameter from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters import * # noqa F403 from executorch.backends.nxp.backend.node_format_inference import ( - NodeFormat, NodeFormatInference, + NXP_NODE_FORMAT, ) from executorch.exir.dialects._ops import ops as exir_ops @@ -70,12 +70,11 @@ def convert_program( :param custom_delegation_options: Custom user options which affect node delegation. :return: TFLite flatbuffers as bytes. """ - node_formats = NodeFormatInference(edge_program).identify_node_formats() + NodeFormatInference(edge_program).identify_node_formats() parameters_mapping = self.map_inputs_to_parameters(edge_program) cc = self.build_conversion_context( parameters_mapping, - node_formats, conversion_config, custom_delegation_options, ) @@ -101,7 +100,7 @@ def convert_program( def append_placeholders_and_tensors(nodes: list[Node], context: ConversionContext): for node in nodes: if node.op == "placeholder": - node_format = context.node_formats[node] + node_format = node.meta[NXP_NODE_FORMAT] if node.name in context.parameters_mapping: # Node is placeholder and has data -> append as static tensor with data @@ -114,7 +113,7 @@ def append_placeholders_and_tensors(nodes: list[Node], context: ConversionContex context.tflite_builder.append_as_fake_tensor(node, node_format) elif node.op == "call_function": # Node is call function -> append only output as a tensor - node_format = context.node_formats[node] + node_format = node.meta[NXP_NODE_FORMAT] context.tflite_builder.append_as_fake_tensor(node, node_format) elif node.op == "output": # Nothing to do @@ -171,7 +170,6 @@ def map_inputs_to_parameters(edge_program: ExportedProgram) -> dict[str, Paramet @staticmethod def build_conversion_context( parameters_mapping: dict, - node_formats: dict[Node, NodeFormat], conversion_config: ConversionConfig = _default_conversion_config, custom_delegation_options: CustomDelegationOptions = _default_delegation_options, ) -> ConversionContext: @@ -186,7 +184,6 @@ def build_conversion_context( tflite_builder, conversion_config, parameters_mapping, - node_formats, custom_delegation_options, ) diff --git a/backends/nxp/backend/ir/conversion_context.py b/backends/nxp/backend/ir/conversion_context.py index 6fb7e98424e..d4746fbde01 100644 --- a/backends/nxp/backend/ir/conversion_context.py +++ b/backends/nxp/backend/ir/conversion_context.py @@ -10,8 +10,6 @@ from executorch.backends.nxp.backend.ir.converter.builder.aten_model_builder_director import ( AtenModelBuilderDirector, ) -from executorch.backends.nxp.backend.node_format_inference import NodeFormat -from torch import Node from torch.nn import Parameter @@ -19,7 +17,6 @@ class ConversionContext: tflite_builder: AtenModelBuilderDirector conversion_config: ConversionConfig parameters_mapping: dict[str, Parameter] - node_formats: dict[Node, NodeFormat] custom_delegation_options: CustomDelegationOptions def __init__( @@ -27,7 +24,6 @@ def __init__( tflite_builder: AtenModelBuilderDirector, conversion_config: ConversionConfig, parameters_mapping: dict, - node_formats: dict[Node, NodeFormat], custom_delegation_options: CustomDelegationOptions, ): """ @@ -39,5 +35,4 @@ def __init__( self.tflite_builder = tflite_builder self.conversion_config = conversion_config self.parameters_mapping = parameters_mapping - self.node_formats = node_formats self.custom_delegation_options = custom_delegation_options diff --git a/backends/nxp/backend/node_format_inference.py b/backends/nxp/backend/node_format_inference.py index 76b05d172a4..1e4c86c2aec 100644 --- a/backends/nxp/backend/node_format_inference.py +++ b/backends/nxp/backend/node_format_inference.py @@ -13,6 +13,8 @@ logger = logging.getLogger(__name__) +NXP_NODE_FORMAT = "nxp_node_format" # Key into the `meta` attribute of nodes, which is mapped to the inferred format. + class NodeFormat(Enum): # Node's output in NCHW format @@ -43,8 +45,6 @@ class NodeFormatInference: # are channels first but output is formatless). ops_that_can_change_tensor_format = {exir_ops.edge.aten.view_copy.default} - _node_format_mapping: dict[Node, NodeFormat] - _type_changed_during_last_run: bool # Mapping between Node and its ancestors (inputs) @@ -57,7 +57,6 @@ def __init__(self, edge_program: ExportedProgram): self._edge_program = edge_program self._nodes = edge_program.graph.nodes - self._node_format_mapping = {} self._node_inputs = { node: node.all_input_nodes for node in edge_program.graph.nodes } @@ -67,7 +66,7 @@ def __init__(self, edge_program: ExportedProgram): self._type_changed_during_last_run = False - def identify_node_formats(self) -> dict[Node, NodeFormat]: + def identify_node_formats(self): self._type_changed_during_last_run = True # Re-run format inference until there are no changes @@ -77,7 +76,15 @@ def identify_node_formats(self) -> dict[Node, NodeFormat]: for node in self._nodes: self._infer_format_of_nodes(node) - return self._node_format_mapping + for node in self._nodes: + if self._get_node_op_type(node) is None: + continue + if not hasattr(node, "meta"): + logging.warning(f"Node `{node}` does not have the `meta` attribute.") + node.meta = {} + if NXP_NODE_FORMAT not in node.meta: + logging.warning(f"Node `{node}` does not have inferred format.") + node.meta[NXP_NODE_FORMAT] = NodeFormat.NONE def _infer_format_of_nodes(self, node: Node): op_type = self._get_node_op_type(node) @@ -151,7 +158,7 @@ def _assign_format_to_node(self, node: Node, node_format: NodeFormat): if old_node_format != node_format: self._type_changed_during_last_run = True - self._node_format_mapping[node] = node_format + node.meta[NXP_NODE_FORMAT] = node_format def _get_node_op_type(self, node: Node) -> str | None: """ @@ -252,8 +259,10 @@ def _node_produces_or_consumes_channels_first_format(self, node) -> bool: for ancestor_node in input_nodes ) - def _get_node_format(self, node): - return self._node_format_mapping.get(node, NodeFormat.NONE) + def _get_node_format(self, node) -> NodeFormat: + if not hasattr(node, "meta"): + node.meta = {} + return node.meta.get(NXP_NODE_FORMAT, NodeFormat.NONE) - def _node_is_placeholder(self, node: Node): + def _node_is_placeholder(self, node: Node) -> bool: return node.op == "placeholder" diff --git a/backends/nxp/tests/test_node_format_inference.py b/backends/nxp/tests/test_node_format_inference.py index e2796187ce8..d0a73328037 100644 --- a/backends/nxp/tests/test_node_format_inference.py +++ b/backends/nxp/tests/test_node_format_inference.py @@ -9,6 +9,7 @@ from executorch.backends.nxp.backend.node_format_inference import ( NodeFormat, NodeFormatInference, + NXP_NODE_FORMAT, ) from executorch.backends.nxp.neutron_pass_manager import NeutronPassManager from executorch.backends.nxp.tests.models import ( @@ -27,7 +28,7 @@ def test_convolution(): exir_program = torch.export.export(model, example_input) edge_program = exir.to_edge(exir_program).exported_program() - node_formats = NodeFormatInference(edge_program).identify_node_formats() + NodeFormatInference(edge_program).identify_node_formats() expected_mapping = { "p_conv_weight": NodeFormat.CHANNELS_FIRST, @@ -37,8 +38,8 @@ def test_convolution(): "output": NodeFormat.CHANNELS_FIRST, } - for node, node_format in node_formats.items(): - assert expected_mapping[node.name] == node_format + for node in edge_program.graph.nodes: + assert expected_mapping[node.name] == node.meta[NXP_NODE_FORMAT] def test_softmax(): @@ -48,7 +49,7 @@ def test_softmax(): exir_program = torch.export.export(model, example_input) edge_program = exir.to_edge(exir_program).exported_program() - node_formats = NodeFormatInference(edge_program).identify_node_formats() + NodeFormatInference(edge_program).identify_node_formats() expected_mapping = { "x": NodeFormat.FORMATLESS, @@ -56,8 +57,8 @@ def test_softmax(): "output": NodeFormat.FORMATLESS, } - for node, node_format in node_formats.items(): - assert expected_mapping[node.name] == node_format + for node in edge_program.graph.nodes: + assert expected_mapping[node.name] == node.meta[NXP_NODE_FORMAT] def test_maxpool2d(): @@ -78,7 +79,7 @@ def test_maxpool2d(): # Remove MaxPool-related "getitem" nodes from graph edge_program = NeutronPassManager(edge_program, [RemoveGetItemPass]).transform() - node_formats = NodeFormatInference(edge_program).identify_node_formats() + NodeFormatInference(edge_program).identify_node_formats() expected_mapping = { "x": NodeFormat.CHANNELS_FIRST, @@ -86,5 +87,5 @@ def test_maxpool2d(): "output": NodeFormat.CHANNELS_FIRST, } - for node, node_format in node_formats.items(): - assert expected_mapping[node.name] == node_format + for node in edge_program.graph.nodes: + assert expected_mapping[node.name] == node.meta[NXP_NODE_FORMAT] From df3a6019bbd317fbb4112252b9ef08894ef1c5d0 Mon Sep 17 00:00:00 2001 From: Martin Pavella Date: Wed, 3 Sep 2025 13:33:26 +0200 Subject: [PATCH 2/6] NXP backend: Update pass which removes GetItem nodes, to preserve the node format. The pass `RemoveGetItemPass` replaces a `max_pool2d_with_indices` node with a `max_pool2d` node, that doesn't require a GetItem afterward. The new operator must, however, preserve the original node format. Therefore, a copy of the pass was created in `backends/nxp/_passes`, where it was modified. The new directory was created, because the pass doesn't follow the `NeutronEdgePass` interface. --- backends/nxp/_passes/remove_getitem_pass.py | 103 ++++++++++++++++++++ backends/nxp/nxp_backend.py | 2 +- 2 files changed, 104 insertions(+), 1 deletion(-) create mode 100644 backends/nxp/_passes/remove_getitem_pass.py diff --git a/backends/nxp/_passes/remove_getitem_pass.py b/backends/nxp/_passes/remove_getitem_pass.py new file mode 100644 index 00000000000..646f5083adf --- /dev/null +++ b/backends/nxp/_passes/remove_getitem_pass.py @@ -0,0 +1,103 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright 2025 NXP +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch + +from executorch.backends.nxp.backend.node_format_inference import ( + NodeFormat, + NXP_NODE_FORMAT, +) +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, PassResult + + +class RemoveGetItemPass(ExportPass): + """ + This remove item is used to remove getitem operator for max_pool2d_with_indices.default operator, and replace it with a single operator, + that extracts the first output. More specifically, we are only getting the first output from aten::maxpool2d operator. + Before Pass: + MaxPool2d ---> GetItem[max_values, max_indexes] + After Pass: + MaxPool2d -> max_values + """ + + def call(self, graph_module: torch.fx.GraphModule): + module = graph_module + for node in module.graph.nodes: + if node.op == "call_function": + if ( + node.target.__name__ == "aten.max_pool2d_with_indices.default" + or node.target.__name__ == "aten.max.dim" + ): + users = list(node.users.keys()) + + if len(users) != 1: + if len(users) == 2 and node.target.__name__ == "aten.max.dim": + # Two users is allowed for max.dim. For that case, + # rather than removing the getitem node in this + # pass, we handle the getitem nodes in the op's + # visitor when serializing + continue + else: + raise AssertionError( + f"Invalid number of users for {node.target.__name__}: {len(users)}" + ) + + getitem_node = list(node.users.keys())[0] + + if getitem_node.target.__name__ != "getitem": + raise AssertionError( + f"Expected max node's user to be getitem, got {getitem_node.target.__name__}" + ) + + getitem_index = getitem_node.args[1] + + with module.graph.inserting_before(node): + if ( + node.target.__name__ + == "aten.max_pool2d_with_indices.default" + ): + if getitem_index != 0: + raise AssertionError( + f"Expected second argument of getitem node for {node.target.__name__} to be 0, got {getitem_index}. XNNPACK delegate currently only supports getting just the max values from the op but not getting the corresponding indices." + ) + new_max_wd = module.graph.create_node( + "call_function", + exir_ops.edge.aten.max_pool2d.default, + args=node.args, + kwargs=node.kwargs, + ) + + else: + if getitem_index != 0: + raise AssertionError( + f"Expected second argument of getitem node for {node.target.__name__} to be 0, got {getitem_index}. XNNPACK delegate currently only supports getting just the max values or getting both the max values and their corresponding indices from the op, but not getting the indices alone." + ) + new_max_wd = module.graph.create_node( + "call_function", + exir_ops.edge.aten.amax.default, + args=node.args, + kwargs=node.kwargs, + ) + + # MODIFIED PART START + # Make sure to preserve the inferred node format. + new_max_wd.meta[NXP_NODE_FORMAT] = node.meta.get( + NXP_NODE_FORMAT, NodeFormat.NONE + ) + # MODIFIED PART END + + getitem_node.replace_all_uses_with(new_max_wd) + + module.graph.erase_node(getitem_node) + module.graph.erase_node(node) + + graph_module.recompile() + # Propagate metadata and retrace module + graph_module = super().call(graph_module).graph_module + + return PassResult(graph_module, True) diff --git a/backends/nxp/nxp_backend.py b/backends/nxp/nxp_backend.py index c801eefec81..e6e6c0db443 100644 --- a/backends/nxp/nxp_backend.py +++ b/backends/nxp/nxp_backend.py @@ -14,6 +14,7 @@ import numpy as np import torch +from executorch.backends.nxp._passes.remove_getitem_pass import RemoveGetItemPass from executorch.backends.nxp.backend.edge_program_converter import ( EdgeProgramToIRConverter, @@ -28,7 +29,6 @@ NeutronNodeArtifacts, ) from executorch.backends.nxp.neutron_pass_manager import NeutronPassManager -from executorch.backends.transforms.remove_getitem_op import RemoveGetItemPass from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult from executorch.exir.backend.compile_spec_schema import CompileSpec from executorch.exir.verification.verifier import EXIREdgeDialectVerifier From d7558e976bd477bfd27fc6fe4356dbb48760cd28 Mon Sep 17 00:00:00 2001 From: Martin Pavella Date: Wed, 3 Sep 2025 16:25:29 +0200 Subject: [PATCH 3/6] NXP backend: Perform node format inference before partitioning. Before, the format inference was done during conversion to NeutronIR (after partitioning), so the partitioner didn't yet know the formats. Now, the partitioner has the format data, which can be used to accurately select nodes for delegation. --- backends/nxp/backend/edge_program_converter.py | 6 +----- backends/nxp/neutron_partitioner.py | 5 +++++ backends/nxp/tests/executors.py | 6 ++++-- .../tests/ir/converter/node_converter/test_cat_converter.py | 6 ++++++ .../ir/converter/node_converter/test_softmax_converter.py | 4 ++++ backends/nxp/tests/test_neutron_converter_manager.py | 3 +++ 6 files changed, 23 insertions(+), 7 deletions(-) diff --git a/backends/nxp/backend/edge_program_converter.py b/backends/nxp/backend/edge_program_converter.py index 657f1526da1..0b945b75694 100644 --- a/backends/nxp/backend/edge_program_converter.py +++ b/backends/nxp/backend/edge_program_converter.py @@ -18,10 +18,7 @@ from torch.fx import Node from torch.nn.parameter import Parameter from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters import * # noqa F403 -from executorch.backends.nxp.backend.node_format_inference import ( - NodeFormatInference, - NXP_NODE_FORMAT, -) +from executorch.backends.nxp.backend.node_format_inference import NXP_NODE_FORMAT from executorch.exir.dialects._ops import ops as exir_ops # noinspection PyProtectedMember @@ -70,7 +67,6 @@ def convert_program( :param custom_delegation_options: Custom user options which affect node delegation. :return: TFLite flatbuffers as bytes. """ - NodeFormatInference(edge_program).identify_node_formats() parameters_mapping = self.map_inputs_to_parameters(edge_program) cc = self.build_conversion_context( diff --git a/backends/nxp/neutron_partitioner.py b/backends/nxp/neutron_partitioner.py index 5bcdee0f8b6..d04c10502bd 100644 --- a/backends/nxp/neutron_partitioner.py +++ b/backends/nxp/neutron_partitioner.py @@ -24,6 +24,7 @@ from torch.fx.passes.operator_support import OperatorSupportBase from torch.nn import Parameter from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters import * # noqa F403 +from executorch.backends.nxp.backend.node_format_inference import NodeFormatInference from executorch.backends.nxp.nxp_backend import NeutronBackend from executorch.exir.backend.compile_spec_schema import CompileSpec from executorch.exir.backend.partitioner import ( @@ -342,6 +343,10 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult: allows_single_node_partition=True, ) + # Identify the format (NCHW/NHWC/...) for all nodes in the graph, and store it in the `node.meta`. + # This format will be used by the `CapabilityBasedPartitioner` to determine which nodes will be delegated. + NodeFormatInference(exported_program).identify_node_formats() + partition_list = capability_partitioner.propose_partitions() for partition in partition_list: for node in partition.nodes: diff --git a/backends/nxp/tests/executors.py b/backends/nxp/tests/executors.py index afdb15af106..f55b173ddae 100644 --- a/backends/nxp/tests/executors.py +++ b/backends/nxp/tests/executors.py @@ -1,4 +1,4 @@ -# Copyright 2023-2024 NXP +# Copyright 2023-2025 NXP # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -22,11 +22,12 @@ NodeConverter, Target, ) + +from executorch.backends.nxp.backend.node_format_inference import NodeFormatInference from torch.export import ExportedProgram from torch.fx import Node from torch.fx.graph import Graph - # If executed on i.MX platform, there is no tensorflow module. And typically the intention is to use the tflite python # interpreter available in tflite_runtime try: @@ -305,6 +306,7 @@ def convert_run_compare( ) -> (TFLiteExecutor, EdgeProgramExecutor): if tfl_model is None: + NodeFormatInference(edge_program).identify_node_formats() tfl_model, _ = EdgeProgramToIRConverter().convert_program( edge_program, conversion_config ) diff --git a/backends/nxp/tests/ir/converter/node_converter/test_cat_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_cat_converter.py index 3df703f5bba..cd66602cb57 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_cat_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_cat_converter.py @@ -17,6 +17,8 @@ from executorch.backends.nxp.tests.executors import ( convert_run_compare, graph_contains_any_of_ops, + ToNCHWPreprocess, + ToNHWCPreprocess, ) from executorch.exir.dialects._ops import ops as exir_ops from torch.export import ExportedProgram @@ -126,6 +128,8 @@ def test_cat__channels_first__same_shapes(dim, num_inputs, mocker): exported_program, tfl_model=tflite_flatbuffers_model, input_data=input_data, + tflite_input_preprocess=ToNHWCPreprocess(), + tflite_output_preprocess=ToNCHWPreprocess(), atol=1, ) @@ -241,6 +245,8 @@ def test_cat__channels_first__different_shapes(dim, num_inputs, mocker): exported_program, tfl_model=tflite_flatbuffers_model, input_data=input_data, + tflite_input_preprocess=ToNHWCPreprocess(), + tflite_output_preprocess=ToNCHWPreprocess(), atol=1, ) diff --git a/backends/nxp/tests/ir/converter/node_converter/test_softmax_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_softmax_converter.py index 92af90b923d..b2e00fefc5a 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_softmax_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_softmax_converter.py @@ -11,6 +11,7 @@ EdgeProgramToIRConverter, ) from executorch.backends.nxp.backend.ir.conversion_config import ConversionConfig +from executorch.backends.nxp.backend.node_format_inference import NodeFormatInference from executorch.backends.nxp.tests.executorch_pipeline import to_edge_program from executorch.backends.nxp.tests.executors import convert_run_compare from executorch.backends.nxp.tests.models import SoftmaxConvModule, SoftmaxModule @@ -56,6 +57,7 @@ def test_softmax_conversion__unknown_input_format(input_shape, dim: int): model = SoftmaxModule(dim) edge_program = to_edge_program(model, input_shape).exported_program() + NodeFormatInference(edge_program).identify_node_formats() # Currently this test not pass because the convertibility checker doesn't use tensor formats. with pytest.raises( @@ -78,6 +80,7 @@ def test_softmax_conversion_channel_last(input_shape, dim: int): model = SoftmaxConvModule(dim) edge_program = to_edge_program(model, input_shape).exported_program() + NodeFormatInference(edge_program).identify_node_formats() # TODO (Robert Kalmar) Currently this test not pass because the convertibility checker doesn't use tensor formats. with pytest.raises( @@ -104,6 +107,7 @@ def test_softmax_conversion_unsupported_dims(input_shape, dim: int): model = SoftmaxModule(dim) edge_program = to_edge_program(model, input_shape).exported_program() + NodeFormatInference(edge_program).identify_node_formats() with pytest.raises( AssertionError, match="`aten__softmax_default` is not convertible" diff --git a/backends/nxp/tests/test_neutron_converter_manager.py b/backends/nxp/tests/test_neutron_converter_manager.py index af723ec9c7a..31a33940b6e 100644 --- a/backends/nxp/tests/test_neutron_converter_manager.py +++ b/backends/nxp/tests/test_neutron_converter_manager.py @@ -13,6 +13,7 @@ from executorch.backends.nxp.backend.neutron_converter_manager import ( NeutronConverterManager, ) +from executorch.backends.nxp.backend.node_format_inference import NodeFormatInference from executorch.backends.nxp.tests.models import Conv2dModule @@ -23,6 +24,7 @@ def test_conv2d_neutron_conversion__default_flavor(): exir_program = torch.export.export(model, example_input) edge_program_manager = exir.to_edge(exir_program) + NodeFormatInference(edge_program_manager.exported_program()).identify_node_formats() edge_program_converter = EdgeProgramToIRConverter() tflite_model, _ = edge_program_converter.convert_program( edge_program_manager.exported_program() @@ -45,6 +47,7 @@ def test__conv2d_neutron_conversion__invalid_flavor(): exir_program = torch.export.export(model, example_input) edge_program_manager = exir.to_edge(exir_program) + NodeFormatInference(edge_program_manager.exported_program()).identify_node_formats() edge_program_converter = EdgeProgramToIRConverter() tflite_model, _ = edge_program_converter.convert_program( edge_program_manager.exported_program() From 904d36b95293a995b43b23982e8055d5af34e626 Mon Sep 17 00:00:00 2001 From: Martin Pavella Date: Thu, 4 Sep 2025 09:01:44 +0200 Subject: [PATCH 4/6] NXP backend: Improve `cat` delegation by using inferred node formats. --- .../ops_converters/cat_converter.py | 27 ++++--- .../node_converter/test_cat_converter.py | 75 +++++++++++++++++++ 2 files changed, 90 insertions(+), 12 deletions(-) diff --git a/backends/nxp/backend/ir/converter/node_converters/ops_converters/cat_converter.py b/backends/nxp/backend/ir/converter/node_converters/ops_converters/cat_converter.py index 4f7f00fe5ba..ef3163d2da9 100644 --- a/backends/nxp/backend/ir/converter/node_converters/ops_converters/cat_converter.py +++ b/backends/nxp/backend/ir/converter/node_converters/ops_converters/cat_converter.py @@ -18,6 +18,7 @@ from executorch.backends.nxp.backend.ir.tflite_generator.builtin_options.concatenation_options import ( Concatenation, ) +from executorch.backends.nxp.backend.node_format_inference import NXP_NODE_FORMAT from torch.fx import Node from torch.nn import Parameter @@ -88,25 +89,27 @@ def _is_supported_on_target( return False # Neutron requires the channels to be a multiple of `8`. The channels could either be the second or the - # last dimension, depending on the formats of the node. The format, however, cannot be determined - # during conversion, as it depends on what other nodes are delegated. + # last dimension, depending on the formats of the node. + if node.meta[NXP_NODE_FORMAT].is_channels_first(): + # During conversion to IR, the shape will be permuted to channels last, and the dimension on index + # `1` will end up being the channels (last dim in NHWC). + channels_index = 1 + else: + # The shape will not be permuted during conversion, so the channels will remain the last dimension. + channels_index = -1 + input_channels = [ - # The second dimension is the channels in PyTorch. If the inputs/output are not channels first, it - # will still be the channels in the IR. - _get_shape(input_)[1] - for input_ in node.all_input_nodes - ] + [ - # If the inputs/outputs are channels first, the last dimension will be the channels. - _get_shape(input_)[-1] + _get_shape(input_)[channels_index] for input_ in node.all_input_nodes ] + output_channels = _get_shape(node)[channels_index] + if any((input_channel % 8) != 0 for input_channel in input_channels): # neutron-library/src/utils/NeutronLibraryInterrogation.cpp#1492 return False - output_channels = [_get_shape(node)[1], _get_shape(node)[-1]] - # neutron-library/src/utils/NeutronLibraryInterrogation.cpp#1493 - if any((out_c % 8) != 0 for out_c in output_channels): + if (output_channels % 8) != 0: + # neutron-library/src/utils/NeutronLibraryInterrogation.cpp#1493 return False if len(node.all_input_nodes) < 2: # Not supported on Neutron diff --git a/backends/nxp/tests/ir/converter/node_converter/test_cat_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_cat_converter.py index cd66602cb57..d9b58eda839 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_cat_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_cat_converter.py @@ -296,3 +296,78 @@ def test_cat__force_delegate(): graph=quantized_program.graph, ops=[exir_ops.edge.aten.cat.default] ) assert any("lowered_module" in node.name for node in quantized_program.graph.nodes) + + +def test_cat__format_specific_support__formatless(mocker): + # The last dim will end up being the channels, as the format is `formatless`. + # Only the last dim satisfies the Neutron requirements for the channels. + input_shape = (3, 3, 3, 8) + num_inputs = 2 + dim = 2 + + input_shapes = [input_shape] * num_inputs + + converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") + + quantized_program = to_quantized_edge_program( + CatModule(dim), input_shapes + ).exported_program() + + # Make sure the `Cat` was delegated. + assert not graph_contains_any_of_ops( + graph=quantized_program.graph, ops=[exir_ops.edge.aten.cat.default] + ) + assert any("lowered_module" in node.name for node in quantized_program.graph.nodes) + + tflite_flatbuffers_model, io_formats = converter_spy.spy_return + exported_program: ExportedProgram = converter_spy.call_args.args[1] + input_data = { + i: (np.random.random(shape) * 50).astype(np.int8) + for i, shape in enumerate(input_shapes) + } + convert_run_compare( + exported_program, + tfl_model=tflite_flatbuffers_model, + input_data=input_data, + atol=1, + ) + + +def test_cat__format_specific_support__channels_first(mocker): + # The second dim will end up being the channels, as the format is `formatless`. + # Only the second dim satisfies the Neutron requirements for the channels. + input_shape = (3, 8, 3, 3) + num_inputs = 2 + dim = 2 + + input_shapes = [input_shape] * num_inputs + + converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program") + + channels = ( + sum(shape[1] for shape in input_shapes) if dim in [1, -3] else input_shape[1] + ) + quantized_program = to_quantized_edge_program( + CatConvModule(dim, channels), input_shapes + ).exported_program() + + # Make sure the `Cat` was delegated. + assert not graph_contains_any_of_ops( + graph=quantized_program.graph, ops=[exir_ops.edge.aten.cat.default] + ) + assert any("lowered_module" in node.name for node in quantized_program.graph.nodes) + + tflite_flatbuffers_model, io_formats = converter_spy.spy_return + exported_program: ExportedProgram = converter_spy.call_args.args[1] + input_data = { + i: (np.random.random(shape) * 50).astype(np.int8) + for i, shape in enumerate(input_shapes) + } + convert_run_compare( + exported_program, + tfl_model=tflite_flatbuffers_model, + input_data=input_data, + tflite_input_preprocess=ToNHWCPreprocess(), + tflite_output_preprocess=ToNCHWPreprocess(), + atol=1, + ) From 8255593d21c9aaa370df5ea395e0d679061f3ce8 Mon Sep 17 00:00:00 2001 From: Martin Pavella Date: Thu, 4 Sep 2025 09:28:07 +0200 Subject: [PATCH 5/6] NXP backend: Improve `constant_pad_nd` delegation by using inferred node formats. --- .../constant_pad_nd_converter.py | 20 ++++--- .../test_constant_pad_nd_converter.py | 52 ++++++++++++++++++- 2 files changed, 63 insertions(+), 9 deletions(-) diff --git a/backends/nxp/backend/ir/converter/node_converters/ops_converters/constant_pad_nd_converter.py b/backends/nxp/backend/ir/converter/node_converters/ops_converters/constant_pad_nd_converter.py index f58df1a88d9..9257ee7e229 100644 --- a/backends/nxp/backend/ir/converter/node_converters/ops_converters/constant_pad_nd_converter.py +++ b/backends/nxp/backend/ir/converter/node_converters/ops_converters/constant_pad_nd_converter.py @@ -27,6 +27,8 @@ pad_options, pad_v2_options, ) + +from executorch.backends.nxp.backend.node_format_inference import NXP_NODE_FORMAT from torch.fx import Node from torch.nn import Parameter @@ -41,11 +43,17 @@ def _is_supported_on_target( ) -> bool: match target: case Target.RT700: - # TODO: Consider different tensor formats (dim-order) paddings = node.args[1] - if len(paddings) > 4 and paddings[4:6] != [0, 0]: - # Attempt to Pad channels dimension, which is not supported on Neutron. - return False + if node.meta[NXP_NODE_FORMAT].is_channels_first(): + # Dim `1` will end up being the channels. It is padded by paddings[4:6]. + if len(paddings) > 4 and paddings[4:6] != [0, 0]: + # Attempt to Pad channels dimension -> currently not supported + return False + else: + # Dim `-1` will end up being the channels. It is padded by paddings[:2]. + if len(paddings) > 0 and paddings[:2] != [0, 0]: + # Attempt to Pad channels dimension -> currently not supported + return False return True @@ -71,10 +79,6 @@ def _is_supported_in_IR( if not NodeConverter._has_shared_q_params_if_quantized(node): return False - if len(paddings) > 4 and paddings[4:6] != [0, 0]: - # Attempt to Pad channels dimension -> currently not supported - return False - return True # noinspection PyMethodMayBeStatic diff --git a/backends/nxp/tests/ir/converter/node_converter/test_constant_pad_nd_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_constant_pad_nd_converter.py index 47cd54c4efb..56be613a664 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_constant_pad_nd_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_constant_pad_nd_converter.py @@ -1,4 +1,4 @@ -# Copyright 2024 NXP +# Copyright 2024-2025 NXP # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -13,6 +13,7 @@ ) from executorch.backends.nxp.tests.executors import ( convert_run_compare, + graph_contains_any_of_ops, ToNCHWPreprocess, ToNHWCPreprocess, ) @@ -20,6 +21,7 @@ ConstantPadNDConvModule, ConstantPadNDModule, ) +from executorch.exir.dialects._ops import ops as exir_ops @pytest.fixture(autouse=True) @@ -121,3 +123,51 @@ def test_constant_pad_nd__unsupported_paddings(input_shape, paddings): nodes = list(exec_program.graph.nodes) # There is at least one non-delegated Pad node assert any(node.name == "aten_constant_pad_nd_default" for node in nodes) + + +def test_constant_pad_nd__delegation__formatless__supported_padding(): + input_shape = (2, 4, 6, 8) # Formatless -> the last dim (8) will be padded. + paddings = [0, 0, 1, 2, 3, 4] # The last dim is padded using the first 2 paddings. + model = ConstantPadNDModule(paddings) + exec_program = to_quantized_edge_program(model, input_shape).exported_program() + + # Make sure the `pad` was delegated. + assert not graph_contains_any_of_ops( + exec_program.graph, [exir_ops.edge.aten.constant_pad_nd.default] + ) + + +def test_constant_pad_nd__delegation__formatless__unsupported_padding(): + input_shape = (2, 4, 6, 8) # Formatless -> the last dim (8) will be padded. + paddings = [0, 1] # The last dim is padded using the first 2 paddings. + model = ConstantPadNDModule(paddings) + exec_program = to_quantized_edge_program(model, input_shape).exported_program() + + # Make sure the `pad` was NOT delegated. + assert graph_contains_any_of_ops( + exec_program.graph, [exir_ops.edge.aten.constant_pad_nd.default] + ) + + +def test_constant_pad_nd__delegation__channels_first__supported_padding(): + input_shape = (2, 4, 6, 8) # Channels first -> the second dim (4) will be padded. + paddings = [1, 2, 3, 4, 0, 0] # The second dim is padded using the paddings[4:6]. + model = ConstantPadNDConvModule(paddings) + exec_program = to_quantized_edge_program(model, input_shape).exported_program() + + # Make sure the `pad` was delegated. + assert not graph_contains_any_of_ops( + exec_program.graph, [exir_ops.edge.aten.constant_pad_nd.default] + ) + + +def test_constant_pad_nd__delegation__channels_first__unsupported_padding(): + input_shape = (2, 3, 6, 8) # Channels first -> the second dim (3) will be padded. + paddings = [0, 0, 0, 0, 1, 0] # The second dim is padded using the paddings[4:6]. + model = ConstantPadNDConvModule(paddings) + exec_program = to_quantized_edge_program(model, input_shape).exported_program() + + # Make sure the `pad` was NOT delegated. + assert graph_contains_any_of_ops( + exec_program.graph, [exir_ops.edge.aten.constant_pad_nd.default] + ) From 403f8924a51223618f268f6f01343d5d3971af10 Mon Sep 17 00:00:00 2001 From: Martin Pavella Date: Thu, 11 Sep 2025 15:48:35 +0200 Subject: [PATCH 6/6] NXP backend: Do not infer format for unknown nodes. --- .../nxp/backend/edge_program_converter.py | 2 +- .../builder/aten_model_builder_director.py | 2 +- .../ops_converters/cat_converter.py | 2 +- .../constant_pad_nd_converter.py | 2 +- backends/nxp/backend/ir/tensor_formatting.py | 2 +- backends/nxp/backend/node_format.py | 23 +++++++++ backends/nxp/backend/node_format_inference.py | 48 +++++++++++-------- 7 files changed, 57 insertions(+), 24 deletions(-) create mode 100644 backends/nxp/backend/node_format.py diff --git a/backends/nxp/backend/edge_program_converter.py b/backends/nxp/backend/edge_program_converter.py index 0b945b75694..522bebcb186 100644 --- a/backends/nxp/backend/edge_program_converter.py +++ b/backends/nxp/backend/edge_program_converter.py @@ -18,7 +18,7 @@ from torch.fx import Node from torch.nn.parameter import Parameter from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters import * # noqa F403 -from executorch.backends.nxp.backend.node_format_inference import NXP_NODE_FORMAT +from executorch.backends.nxp.backend.node_format import NXP_NODE_FORMAT from executorch.exir.dialects._ops import ops as exir_ops # noinspection PyProtectedMember diff --git a/backends/nxp/backend/ir/converter/builder/aten_model_builder_director.py b/backends/nxp/backend/ir/converter/builder/aten_model_builder_director.py index a420cea9aa7..51a4a226fc8 100644 --- a/backends/nxp/backend/ir/converter/builder/aten_model_builder_director.py +++ b/backends/nxp/backend/ir/converter/builder/aten_model_builder_director.py @@ -9,7 +9,7 @@ from executorch.backends.nxp.backend.ir.converter.conversion import translator from executorch.backends.nxp.backend.ir.tensor_formatting import TensorFormat from executorch.backends.nxp.backend.ir.tflite_generator import tflite_model -from executorch.backends.nxp.backend.node_format_inference import NodeFormat +from executorch.backends.nxp.backend.node_format import NodeFormat from torch.fx import Node from torch.nn import Parameter diff --git a/backends/nxp/backend/ir/converter/node_converters/ops_converters/cat_converter.py b/backends/nxp/backend/ir/converter/node_converters/ops_converters/cat_converter.py index ef3163d2da9..67355d4ecbf 100644 --- a/backends/nxp/backend/ir/converter/node_converters/ops_converters/cat_converter.py +++ b/backends/nxp/backend/ir/converter/node_converters/ops_converters/cat_converter.py @@ -18,7 +18,7 @@ from executorch.backends.nxp.backend.ir.tflite_generator.builtin_options.concatenation_options import ( Concatenation, ) -from executorch.backends.nxp.backend.node_format_inference import NXP_NODE_FORMAT +from executorch.backends.nxp.backend.node_format import NXP_NODE_FORMAT from torch.fx import Node from torch.nn import Parameter diff --git a/backends/nxp/backend/ir/converter/node_converters/ops_converters/constant_pad_nd_converter.py b/backends/nxp/backend/ir/converter/node_converters/ops_converters/constant_pad_nd_converter.py index 9257ee7e229..78c2b1479af 100644 --- a/backends/nxp/backend/ir/converter/node_converters/ops_converters/constant_pad_nd_converter.py +++ b/backends/nxp/backend/ir/converter/node_converters/ops_converters/constant_pad_nd_converter.py @@ -28,7 +28,7 @@ pad_v2_options, ) -from executorch.backends.nxp.backend.node_format_inference import NXP_NODE_FORMAT +from executorch.backends.nxp.backend.node_format import NXP_NODE_FORMAT from torch.fx import Node from torch.nn import Parameter diff --git a/backends/nxp/backend/ir/tensor_formatting.py b/backends/nxp/backend/ir/tensor_formatting.py index aab22c3c368..492900e788a 100644 --- a/backends/nxp/backend/ir/tensor_formatting.py +++ b/backends/nxp/backend/ir/tensor_formatting.py @@ -7,7 +7,7 @@ # from enum import Enum -from executorch.backends.nxp.backend.node_format_inference import NodeFormat +from executorch.backends.nxp.backend.node_format import NodeFormat class TensorFormat(Enum): diff --git a/backends/nxp/backend/node_format.py b/backends/nxp/backend/node_format.py new file mode 100644 index 00000000000..91049c200d7 --- /dev/null +++ b/backends/nxp/backend/node_format.py @@ -0,0 +1,23 @@ +# Copyright 2025 NXP +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from enum import Enum + +# Key into the `meta` attribute of nodes, which is mapped to their inferred node format. +NXP_NODE_FORMAT = "nxp_node_format" + + +class NodeFormat(Enum): + # Node's output in NCHW format + CHANNELS_FIRST = 0 + + # Node's output format has no meaning + FORMATLESS = 1 + + # Format has not been identified + NONE = 2 + + def is_channels_first(self) -> bool: + return self == NodeFormat.CHANNELS_FIRST diff --git a/backends/nxp/backend/node_format_inference.py b/backends/nxp/backend/node_format_inference.py index 1e4c86c2aec..77f4fd17900 100644 --- a/backends/nxp/backend/node_format_inference.py +++ b/backends/nxp/backend/node_format_inference.py @@ -4,31 +4,18 @@ # LICENSE file in the root directory of this source tree. import logging -from enum import Enum +import operator +from executorch.backends.nxp.backend.edge_program_converter import functions_converters +from executorch.backends.nxp.backend.node_format import NodeFormat, NXP_NODE_FORMAT from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.dialects.edge._ops import EdgeOpOverload -from torch import Node from torch.export import ExportedProgram +from torch.fx import Node logger = logging.getLogger(__name__) -NXP_NODE_FORMAT = "nxp_node_format" # Key into the `meta` attribute of nodes, which is mapped to the inferred format. - - -class NodeFormat(Enum): - # Node's output in NCHW format - CHANNELS_FIRST = 0 - - # Node's output format has no meaning - FORMATLESS = 1 - - # Format has not been identified - NONE = 2 - - def is_channels_first(self) -> bool: - return self == NodeFormat.CHANNELS_FIRST - class NodeFormatInference: # Dictionary with Edge Aten ops that always use channels first format. @@ -53,6 +40,9 @@ class NodeFormatInference: # Mapping between Node and its children (outputs) _node_outputs: dict[Node, list[Node]] + # List of all edge operations, which are supported by the converter. + _known_targets: list[EdgeOpOverload] + def __init__(self, edge_program: ExportedProgram): self._edge_program = edge_program @@ -66,6 +56,12 @@ def __init__(self, edge_program: ExportedProgram): self._type_changed_during_last_run = False + self._known_targets = list(functions_converters) + [ + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, + exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, + operator.getitem, + ] + def identify_node_formats(self): self._type_changed_during_last_run = True @@ -100,9 +96,19 @@ def _infer_format_of_nodes(self, node: Node): logger.error( f"Node format inference for node type: {op_type} not found!" ) - else: + elif node.op != "call_function" or ( + hasattr(node, "target") and node.target in self._known_targets + ): + # Generic node, or tensor. self._handle_node_which_can_use_any_node_format(node) + else: + # Don't infer the format for unknown nodes. These nodes will never be delegated, so they will divide + # delegated partitions. Propagating the format here could unnecessarily enforce the format in one of these + # partitions, which would require extra transpositions. + for processed_node in self._node_inputs[node] + [node]: + self._assign_format_to_node(processed_node, NodeFormat.NONE) + def _infer_format_based_on_io_ranks(self, node: Node): """Determine the format of the output tensor of given "reshape style operator" based on the ranks of its input and output. @@ -155,6 +161,10 @@ def _assign_format_to_node(self, node: Node, node_format: NodeFormat): # Once CHANNEL_FIRST was assigned, we don't want to reassign return + if node_format is NodeFormat.NONE and old_node_format is not NodeFormat.NONE: + # A format has already been assigned to the node before. Don't replace it with `NONE`. + return + if old_node_format != node_format: self._type_changed_during_last_run = True