Skip to content

Commit be47bc3

Browse files
committed
NXP backend: Do not infer format for unknown nodes.
1 parent 9fc32da commit be47bc3

File tree

7 files changed

+57
-24
lines changed

7 files changed

+57
-24
lines changed

backends/nxp/backend/edge_program_converter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from torch.fx import Node
1919
from torch.nn.parameter import Parameter
2020
from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters import * # noqa F403
21-
from executorch.backends.nxp.backend.node_format_inference import NXP_NODE_FORMAT
21+
from executorch.backends.nxp.backend.node_format import NXP_NODE_FORMAT
2222
from executorch.exir.dialects._ops import ops as exir_ops
2323

2424
# noinspection PyProtectedMember

backends/nxp/backend/ir/converter/builder/aten_model_builder_director.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from executorch.backends.nxp.backend.ir.converter.conversion import translator
1010
from executorch.backends.nxp.backend.ir.tensor_formatting import TensorFormat
1111
from executorch.backends.nxp.backend.ir.tflite_generator import tflite_model
12-
from executorch.backends.nxp.backend.node_format_inference import NodeFormat
12+
from executorch.backends.nxp.backend.node_format import NodeFormat
1313
from torch.fx import Node
1414
from torch.nn import Parameter
1515

backends/nxp/backend/ir/converter/node_converters/ops_converters/cat_converter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from executorch.backends.nxp.backend.ir.tflite_generator.builtin_options.concatenation_options import (
1919
Concatenation,
2020
)
21-
from executorch.backends.nxp.backend.node_format_inference import NXP_NODE_FORMAT
21+
from executorch.backends.nxp.backend.node_format import NXP_NODE_FORMAT
2222
from torch.fx import Node
2323
from torch.nn import Parameter
2424

backends/nxp/backend/ir/converter/node_converters/ops_converters/constant_pad_nd_converter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
pad_v2_options,
2929
)
3030

31-
from executorch.backends.nxp.backend.node_format_inference import NXP_NODE_FORMAT
31+
from executorch.backends.nxp.backend.node_format import NXP_NODE_FORMAT
3232
from torch.fx import Node
3333
from torch.nn import Parameter
3434

backends/nxp/backend/ir/tensor_formatting.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
#
88
from enum import Enum
99

10-
from executorch.backends.nxp.backend.node_format_inference import NodeFormat
10+
from executorch.backends.nxp.backend.node_format import NodeFormat
1111

1212

1313
class TensorFormat(Enum):
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# Copyright 2025 NXP
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from enum import Enum
7+
8+
# Key into the `meta` attribute of nodes, which is mapped to their inferred node format.
9+
NXP_NODE_FORMAT = "nxp_node_format"
10+
11+
12+
class NodeFormat(Enum):
13+
# Node's output in NCHW format
14+
CHANNELS_FIRST = 0
15+
16+
# Node's output format has no meaning
17+
FORMATLESS = 1
18+
19+
# Format has not been identified
20+
NONE = 2
21+
22+
def is_channels_first(self) -> bool:
23+
return self == NodeFormat.CHANNELS_FIRST

backends/nxp/backend/node_format_inference.py

Lines changed: 29 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4,31 +4,18 @@
44
# LICENSE file in the root directory of this source tree.
55

66
import logging
7-
from enum import Enum
7+
import operator
88

9+
from executorch.backends.nxp.backend.edge_program_converter import functions_converters
10+
from executorch.backends.nxp.backend.node_format import NodeFormat, NXP_NODE_FORMAT
911
from executorch.exir.dialects._ops import ops as exir_ops
12+
from executorch.exir.dialects.edge._ops import EdgeOpOverload
1013

11-
from torch import Node
1214
from torch.export import ExportedProgram
15+
from torch.fx import Node
1316

1417
logger = logging.getLogger(__name__)
1518

16-
NXP_NODE_FORMAT = "nxp_node_format" # Key into the `meta` attribute of nodes, which is mapped to the inferred format.
17-
18-
19-
class NodeFormat(Enum):
20-
# Node's output in NCHW format
21-
CHANNELS_FIRST = 0
22-
23-
# Node's output format has no meaning
24-
FORMATLESS = 1
25-
26-
# Format has not been identified
27-
NONE = 2
28-
29-
def is_channels_first(self) -> bool:
30-
return self == NodeFormat.CHANNELS_FIRST
31-
3219

3320
class NodeFormatInference:
3421
# Dictionary with Edge Aten ops that always use channels first format.
@@ -53,6 +40,9 @@ class NodeFormatInference:
5340
# Mapping between Node and its children (outputs)
5441
_node_outputs: dict[Node, list[Node]]
5542

43+
# List of all edge operations, which are supported by the converter.
44+
_known_targets: list[EdgeOpOverload]
45+
5646
def __init__(self, edge_program: ExportedProgram):
5747
self._edge_program = edge_program
5848

@@ -66,6 +56,12 @@ def __init__(self, edge_program: ExportedProgram):
6656

6757
self._type_changed_during_last_run = False
6858

59+
self._known_targets = list(functions_converters) + [
60+
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
61+
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
62+
operator.getitem,
63+
]
64+
6965
def identify_node_formats(self):
7066
self._type_changed_during_last_run = True
7167

@@ -100,9 +96,19 @@ def _infer_format_of_nodes(self, node: Node):
10096
logger.error(
10197
f"Node format inference for node type: {op_type} not found!"
10298
)
103-
else:
99+
elif node.op != "call_function" or (
100+
hasattr(node, "target") and node.target in self._known_targets
101+
):
102+
# Generic node, or tensor.
104103
self._handle_node_which_can_use_any_node_format(node)
105104

105+
else:
106+
# Don't infer the format for unknown nodes. These nodes will never be delegated, so they will divide
107+
# delegated partitions. Propagating the format here could unnecessarily enforce the format in one of these
108+
# partitions, which would require extra transpositions.
109+
for processed_node in self._node_inputs[node] + [node]:
110+
self._assign_format_to_node(processed_node, NodeFormat.NONE)
111+
106112
def _infer_format_based_on_io_ranks(self, node: Node):
107113
"""Determine the format of the output tensor of given "reshape style operator" based on the ranks of its input
108114
and output.
@@ -155,6 +161,10 @@ def _assign_format_to_node(self, node: Node, node_format: NodeFormat):
155161
# Once CHANNEL_FIRST was assigned, we don't want to reassign
156162
return
157163

164+
if node_format is NodeFormat.NONE and old_node_format is not NodeFormat.NONE:
165+
# A format has already been assigned to the node before. Don't replace it with `NONE`.
166+
return
167+
158168
if old_node_format != node_format:
159169
self._type_changed_during_last_run = True
160170

0 commit comments

Comments
 (0)