Skip to content

Commit b669b4f

Browse files
NXP backend: Resolve limitations of uncertain tensor formats (#13942)
### Summary This PR resolves format related issues by inferring the format (NCHW/NHWC) for all nodes before partitioning. These formats are then used by the NeutronPartitioner to accurately determine which nodes are supported on Neutron. ### Test plan Unit tests provided, and correct function is tested by nearly every test in the nxp backend.
1 parent 30568d2 commit b669b4f

17 files changed

+362
-76
lines changed
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# Copyright 2025 NXP
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
import torch
9+
10+
from executorch.backends.nxp.backend.node_format_inference import (
11+
NodeFormat,
12+
NXP_NODE_FORMAT,
13+
)
14+
from executorch.exir.dialects._ops import ops as exir_ops
15+
from executorch.exir.pass_base import ExportPass, PassResult
16+
17+
18+
class RemoveGetItemPass(ExportPass):
19+
"""
20+
This remove item is used to remove getitem operator for max_pool2d_with_indices.default operator, and replace it with a single operator,
21+
that extracts the first output. More specifically, we are only getting the first output from aten::maxpool2d operator.
22+
Before Pass:
23+
MaxPool2d ---> GetItem[max_values, max_indexes]
24+
After Pass:
25+
MaxPool2d -> max_values
26+
"""
27+
28+
def call(self, graph_module: torch.fx.GraphModule):
29+
module = graph_module
30+
for node in module.graph.nodes:
31+
if node.op == "call_function":
32+
if (
33+
node.target.__name__ == "aten.max_pool2d_with_indices.default"
34+
or node.target.__name__ == "aten.max.dim"
35+
):
36+
users = list(node.users.keys())
37+
38+
if len(users) != 1:
39+
if len(users) == 2 and node.target.__name__ == "aten.max.dim":
40+
# Two users is allowed for max.dim. For that case,
41+
# rather than removing the getitem node in this
42+
# pass, we handle the getitem nodes in the op's
43+
# visitor when serializing
44+
continue
45+
else:
46+
raise AssertionError(
47+
f"Invalid number of users for {node.target.__name__}: {len(users)}"
48+
)
49+
50+
getitem_node = list(node.users.keys())[0]
51+
52+
if getitem_node.target.__name__ != "getitem":
53+
raise AssertionError(
54+
f"Expected max node's user to be getitem, got {getitem_node.target.__name__}"
55+
)
56+
57+
getitem_index = getitem_node.args[1]
58+
59+
with module.graph.inserting_before(node):
60+
if (
61+
node.target.__name__
62+
== "aten.max_pool2d_with_indices.default"
63+
):
64+
if getitem_index != 0:
65+
raise AssertionError(
66+
f"Expected second argument of getitem node for {node.target.__name__} to be 0, got {getitem_index}. XNNPACK delegate currently only supports getting just the max values from the op but not getting the corresponding indices."
67+
)
68+
new_max_wd = module.graph.create_node(
69+
"call_function",
70+
exir_ops.edge.aten.max_pool2d.default,
71+
args=node.args,
72+
kwargs=node.kwargs,
73+
)
74+
75+
else:
76+
if getitem_index != 0:
77+
raise AssertionError(
78+
f"Expected second argument of getitem node for {node.target.__name__} to be 0, got {getitem_index}. XNNPACK delegate currently only supports getting just the max values or getting both the max values and their corresponding indices from the op, but not getting the indices alone."
79+
)
80+
new_max_wd = module.graph.create_node(
81+
"call_function",
82+
exir_ops.edge.aten.amax.default,
83+
args=node.args,
84+
kwargs=node.kwargs,
85+
)
86+
87+
# MODIFIED PART START
88+
# Make sure to preserve the inferred node format.
89+
new_max_wd.meta[NXP_NODE_FORMAT] = node.meta.get(
90+
NXP_NODE_FORMAT, NodeFormat.NONE
91+
)
92+
# MODIFIED PART END
93+
94+
getitem_node.replace_all_uses_with(new_max_wd)
95+
96+
module.graph.erase_node(getitem_node)
97+
module.graph.erase_node(node)
98+
99+
graph_module.recompile()
100+
# Propagate metadata and retrace module
101+
graph_module = super().call(graph_module).graph_module
102+
103+
return PassResult(graph_module, True)

backends/nxp/backend/edge_program_converter.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,7 @@
1818
from torch.fx import Node
1919
from torch.nn.parameter import Parameter
2020
from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters import * # noqa F403
21-
from executorch.backends.nxp.backend.node_format_inference import (
22-
NodeFormat,
23-
NodeFormatInference,
24-
)
21+
from executorch.backends.nxp.backend.node_format import NXP_NODE_FORMAT
2522
from executorch.exir.dialects._ops import ops as exir_ops
2623

2724
# noinspection PyProtectedMember
@@ -70,12 +67,10 @@ def convert_program(
7067
:param custom_delegation_options: Custom user options which affect node delegation.
7168
:return: TFLite flatbuffers as bytes.
7269
"""
73-
node_formats = NodeFormatInference(edge_program).identify_node_formats()
7470
parameters_mapping = self.map_inputs_to_parameters(edge_program)
7571

7672
cc = self.build_conversion_context(
7773
parameters_mapping,
78-
node_formats,
7974
conversion_config,
8075
custom_delegation_options,
8176
)
@@ -101,7 +96,7 @@ def convert_program(
10196
def append_placeholders_and_tensors(nodes: list[Node], context: ConversionContext):
10297
for node in nodes:
10398
if node.op == "placeholder":
104-
node_format = context.node_formats[node]
99+
node_format = node.meta[NXP_NODE_FORMAT]
105100

106101
if node.name in context.parameters_mapping:
107102
# Node is placeholder and has data -> append as static tensor with data
@@ -114,7 +109,7 @@ def append_placeholders_and_tensors(nodes: list[Node], context: ConversionContex
114109
context.tflite_builder.append_as_fake_tensor(node, node_format)
115110
elif node.op == "call_function":
116111
# Node is call function -> append only output as a tensor
117-
node_format = context.node_formats[node]
112+
node_format = node.meta[NXP_NODE_FORMAT]
118113
context.tflite_builder.append_as_fake_tensor(node, node_format)
119114
elif node.op == "output":
120115
# Nothing to do
@@ -171,7 +166,6 @@ def map_inputs_to_parameters(edge_program: ExportedProgram) -> dict[str, Paramet
171166
@staticmethod
172167
def build_conversion_context(
173168
parameters_mapping: dict,
174-
node_formats: dict[Node, NodeFormat],
175169
conversion_config: ConversionConfig = _default_conversion_config,
176170
custom_delegation_options: CustomDelegationOptions = _default_delegation_options,
177171
) -> ConversionContext:
@@ -186,7 +180,6 @@ def build_conversion_context(
186180
tflite_builder,
187181
conversion_config,
188182
parameters_mapping,
189-
node_formats,
190183
custom_delegation_options,
191184
)
192185

backends/nxp/backend/ir/conversion_context.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,24 +10,20 @@
1010
from executorch.backends.nxp.backend.ir.converter.builder.aten_model_builder_director import (
1111
AtenModelBuilderDirector,
1212
)
13-
from executorch.backends.nxp.backend.node_format_inference import NodeFormat
14-
from torch import Node
1513
from torch.nn import Parameter
1614

1715

1816
class ConversionContext:
1917
tflite_builder: AtenModelBuilderDirector
2018
conversion_config: ConversionConfig
2119
parameters_mapping: dict[str, Parameter]
22-
node_formats: dict[Node, NodeFormat]
2320
custom_delegation_options: CustomDelegationOptions
2421

2522
def __init__(
2623
self,
2724
tflite_builder: AtenModelBuilderDirector,
2825
conversion_config: ConversionConfig,
2926
parameters_mapping: dict,
30-
node_formats: dict[Node, NodeFormat],
3127
custom_delegation_options: CustomDelegationOptions,
3228
):
3329
"""
@@ -39,5 +35,4 @@ def __init__(
3935
self.tflite_builder = tflite_builder
4036
self.conversion_config = conversion_config
4137
self.parameters_mapping = parameters_mapping
42-
self.node_formats = node_formats
4338
self.custom_delegation_options = custom_delegation_options

backends/nxp/backend/ir/converter/builder/aten_model_builder_director.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from executorch.backends.nxp.backend.ir.converter.conversion import translator
1010
from executorch.backends.nxp.backend.ir.tensor_formatting import TensorFormat
1111
from executorch.backends.nxp.backend.ir.tflite_generator import tflite_model
12-
from executorch.backends.nxp.backend.node_format_inference import NodeFormat
12+
from executorch.backends.nxp.backend.node_format import NodeFormat
1313
from torch.fx import Node
1414
from torch.nn import Parameter
1515

backends/nxp/backend/ir/converter/node_converters/ops_converters/cat_converter.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from executorch.backends.nxp.backend.ir.tflite_generator.builtin_options.concatenation_options import (
1919
Concatenation,
2020
)
21+
from executorch.backends.nxp.backend.node_format import NXP_NODE_FORMAT
2122
from torch.fx import Node
2223
from torch.nn import Parameter
2324

@@ -88,25 +89,27 @@ def _is_supported_on_target(
8889
return False
8990

9091
# Neutron requires the channels to be a multiple of `8`. The channels could either be the second or the
91-
# last dimension, depending on the formats of the node. The format, however, cannot be determined
92-
# during conversion, as it depends on what other nodes are delegated.
92+
# last dimension, depending on the formats of the node.
93+
if node.meta[NXP_NODE_FORMAT].is_channels_first():
94+
# During conversion to IR, the shape will be permuted to channels last, and the dimension on index
95+
# `1` will end up being the channels (last dim in NHWC).
96+
channels_index = 1
97+
else:
98+
# The shape will not be permuted during conversion, so the channels will remain the last dimension.
99+
channels_index = -1
100+
93101
input_channels = [
94-
# The second dimension is the channels in PyTorch. If the inputs/output are not channels first, it
95-
# will still be the channels in the IR.
96-
_get_shape(input_)[1]
97-
for input_ in node.all_input_nodes
98-
] + [
99-
# If the inputs/outputs are channels first, the last dimension will be the channels.
100-
_get_shape(input_)[-1]
102+
_get_shape(input_)[channels_index]
101103
for input_ in node.all_input_nodes
102104
]
105+
output_channels = _get_shape(node)[channels_index]
106+
103107
if any((input_channel % 8) != 0 for input_channel in input_channels):
104108
# neutron-library/src/utils/NeutronLibraryInterrogation.cpp#1492
105109
return False
106110

107-
output_channels = [_get_shape(node)[1], _get_shape(node)[-1]]
108-
# neutron-library/src/utils/NeutronLibraryInterrogation.cpp#1493
109-
if any((out_c % 8) != 0 for out_c in output_channels):
111+
if (output_channels % 8) != 0:
112+
# neutron-library/src/utils/NeutronLibraryInterrogation.cpp#1493
110113
return False
111114

112115
if len(node.all_input_nodes) < 2: # Not supported on Neutron

backends/nxp/backend/ir/converter/node_converters/ops_converters/constant_pad_nd_converter.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
pad_options,
2828
pad_v2_options,
2929
)
30+
31+
from executorch.backends.nxp.backend.node_format import NXP_NODE_FORMAT
3032
from torch.fx import Node
3133
from torch.nn import Parameter
3234

@@ -41,11 +43,17 @@ def _is_supported_on_target(
4143
) -> bool:
4244
match target:
4345
case Target.RT700:
44-
# TODO: Consider different tensor formats (dim-order)
4546
paddings = node.args[1]
46-
if len(paddings) > 4 and paddings[4:6] != [0, 0]:
47-
# Attempt to Pad channels dimension, which is not supported on Neutron.
48-
return False
47+
if node.meta[NXP_NODE_FORMAT].is_channels_first():
48+
# Dim `1` will end up being the channels. It is padded by paddings[4:6].
49+
if len(paddings) > 4 and paddings[4:6] != [0, 0]:
50+
# Attempt to Pad channels dimension -> currently not supported
51+
return False
52+
else:
53+
# Dim `-1` will end up being the channels. It is padded by paddings[:2].
54+
if len(paddings) > 0 and paddings[:2] != [0, 0]:
55+
# Attempt to Pad channels dimension -> currently not supported
56+
return False
4957

5058
return True
5159

@@ -71,10 +79,6 @@ def _is_supported_in_IR(
7179
if not NodeConverter._has_shared_q_params_if_quantized(node):
7280
return False
7381

74-
if len(paddings) > 4 and paddings[4:6] != [0, 0]:
75-
# Attempt to Pad channels dimension -> currently not supported
76-
return False
77-
7882
return True
7983

8084
# noinspection PyMethodMayBeStatic

backends/nxp/backend/ir/tensor_formatting.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
#
88
from enum import Enum
99

10-
from executorch.backends.nxp.backend.node_format_inference import NodeFormat
10+
from executorch.backends.nxp.backend.node_format import NodeFormat
1111

1212

1313
class TensorFormat(Enum):
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# Copyright 2025 NXP
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from enum import Enum
7+
8+
# Key into the `meta` attribute of nodes, which is mapped to their inferred node format.
9+
NXP_NODE_FORMAT = "nxp_node_format"
10+
11+
12+
class NodeFormat(Enum):
13+
# Node's output in NCHW format
14+
CHANNELS_FIRST = 0
15+
16+
# Node's output format has no meaning
17+
FORMATLESS = 1
18+
19+
# Format has not been identified
20+
NONE = 2
21+
22+
def is_channels_first(self) -> bool:
23+
return self == NodeFormat.CHANNELS_FIRST

0 commit comments

Comments
 (0)