103 changes: 103 additions & 0 deletions backends/nxp/_passes/remove_getitem_pass.py
@@ -0,0 +1,103 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2025 NXP
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch

from executorch.backends.nxp.backend.node_format_inference import (
NodeFormat,
NXP_NODE_FORMAT,
)
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult


class RemoveGetItemPass(ExportPass):
"""
This pass removes the `getitem` operator that follows the `max_pool2d_with_indices.default` operator and replaces the pair with a single operator that extracts only the first output. More specifically, we only keep the max values produced by the aten::max_pool2d_with_indices operator and drop the indices.
Before Pass:
MaxPool2d ---> GetItem[max_values, max_indexes]
After Pass:
MaxPool2d -> max_values
"""

def call(self, graph_module: torch.fx.GraphModule):
module = graph_module
for node in module.graph.nodes:
if node.op == "call_function":
if (
node.target.__name__ == "aten.max_pool2d_with_indices.default"
or node.target.__name__ == "aten.max.dim"
Collaborator:
Why do we handle aten.max.dim too? Is it a leftover from the original pass?

Contributor Author:
Yes, it was in the original file. I wanted to make as few changes as possible, as it is not the main focus of this PR.
):
users = list(node.users.keys())

if len(users) != 1:
if len(users) == 2 and node.target.__name__ == "aten.max.dim":
# Two users is allowed for max.dim. For that case,
# rather than removing the getitem node in this
# pass, we handle the getitem nodes in the op's
# visitor when serializing
continue
else:
raise AssertionError(
f"Invalid number of users for {node.target.__name__}: {len(users)}"
)

getitem_node = list(node.users.keys())[0]

if getitem_node.target.__name__ != "getitem":
raise AssertionError(
f"Expected max node's user to be getitem, got {getitem_node.target.__name__}"
)

getitem_index = getitem_node.args[1]

with module.graph.inserting_before(node):
if (
node.target.__name__
== "aten.max_pool2d_with_indices.default"
):
if getitem_index != 0:
raise AssertionError(
f"Expected second argument of getitem node for {node.target.__name__} to be 0, got {getitem_index}. XNNPACK delegate currently only supports getting just the max values from the op but not getting the corresponding indices."
)
new_max_wd = module.graph.create_node(
"call_function",
exir_ops.edge.aten.max_pool2d.default,
args=node.args,
kwargs=node.kwargs,
)

else:
if getitem_index != 0:
raise AssertionError(
f"Expected second argument of getitem node for {node.target.__name__} to be 0, got {getitem_index}. XNNPACK delegate currently only supports getting just the max values or getting both the max values and their corresponding indices from the op, but not getting the indices alone."
)
new_max_wd = module.graph.create_node(
"call_function",
exir_ops.edge.aten.amax.default,
args=node.args,
kwargs=node.kwargs,
)

# MODIFIED PART START
# Make sure to preserve the inferred node format.
new_max_wd.meta[NXP_NODE_FORMAT] = node.meta.get(
NXP_NODE_FORMAT, NodeFormat.NONE
)
# MODIFIED PART END

getitem_node.replace_all_uses_with(new_max_wd)

module.graph.erase_node(getitem_node)
module.graph.erase_node(node)

graph_module.recompile()
# Propagate metadata and retrace module
graph_module = super().call(graph_module).graph_module

return PassResult(graph_module, True)
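For context, here is a minimal sketch (not part of the diff) of how the pass could be exercised; the `SimplePool` module and the export flow are illustrative assumptions:

```python
import torch
from executorch.backends.nxp._passes.remove_getitem_pass import RemoveGetItemPass
from executorch.exir import to_edge

class SimplePool(torch.nn.Module):
    def forward(self, x):
        # In the edge dialect, max_pool2d is decomposed into
        # max_pool2d_with_indices followed by a getitem selecting output 0.
        return torch.nn.functional.max_pool2d(x, kernel_size=2)

edge = to_edge(torch.export.export(SimplePool(), (torch.randn(1, 8, 16, 16),)))
graph_module = edge.exported_program().graph_module

# The pass collapses `max_pool2d_with_indices -> getitem` into a single
# `max_pool2d` node that yields only the max values.
result = RemoveGetItemPass().call(graph_module)
print(result.graph_module.graph)
```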
13 changes: 3 additions & 10 deletions backends/nxp/backend/edge_program_converter.py
@@ -18,10 +18,7 @@
from torch.fx import Node
from torch.nn.parameter import Parameter
from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters import * # noqa F403
from executorch.backends.nxp.backend.node_format_inference import (
NodeFormat,
NodeFormatInference,
)
from executorch.backends.nxp.backend.node_format import NXP_NODE_FORMAT
from executorch.exir.dialects._ops import ops as exir_ops

# noinspection PyProtectedMember
@@ -70,12 +67,10 @@ def convert_program(
:param custom_delegation_options: Custom user options which affect node delegation.
:return: TFLite flatbuffers as bytes.
"""
node_formats = NodeFormatInference(edge_program).identify_node_formats()
parameters_mapping = self.map_inputs_to_parameters(edge_program)

cc = self.build_conversion_context(
parameters_mapping,
node_formats,
conversion_config,
custom_delegation_options,
)
@@ -101,7 +96,7 @@
def append_placeholders_and_tensors(nodes: list[Node], context: ConversionContext):
for node in nodes:
if node.op == "placeholder":
node_format = context.node_formats[node]
node_format = node.meta[NXP_NODE_FORMAT]

if node.name in context.parameters_mapping:
# Node is placeholder and has data -> append as static tensor with data
@@ -114,7 +109,7 @@ def append_placeholders_and_tensors(nodes: list[Node], context: ConversionContext):
context.tflite_builder.append_as_fake_tensor(node, node_format)
elif node.op == "call_function":
# Node is call function -> append only output as a tensor
node_format = context.node_formats[node]
node_format = node.meta[NXP_NODE_FORMAT]
context.tflite_builder.append_as_fake_tensor(node, node_format)
elif node.op == "output":
# Nothing to do
@@ -171,7 +166,6 @@ def map_inputs_to_parameters(edge_program: ExportedProgram) -> dict[str, Parameter]:
@staticmethod
def build_conversion_context(
parameters_mapping: dict,
node_formats: dict[Node, NodeFormat],
conversion_config: ConversionConfig = _default_conversion_config,
custom_delegation_options: CustomDelegationOptions = _default_delegation_options,
) -> ConversionContext:
@@ -186,7 +180,6 @@ def build_conversion_context(
tflite_builder,
conversion_config,
parameters_mapping,
node_formats,
custom_delegation_options,
)

5 changes: 0 additions & 5 deletions backends/nxp/backend/ir/conversion_context.py
@@ -10,24 +10,20 @@
from executorch.backends.nxp.backend.ir.converter.builder.aten_model_builder_director import (
AtenModelBuilderDirector,
)
from executorch.backends.nxp.backend.node_format_inference import NodeFormat
from torch import Node
from torch.nn import Parameter


class ConversionContext:
tflite_builder: AtenModelBuilderDirector
conversion_config: ConversionConfig
parameters_mapping: dict[str, Parameter]
node_formats: dict[Node, NodeFormat]
custom_delegation_options: CustomDelegationOptions

def __init__(
self,
tflite_builder: AtenModelBuilderDirector,
conversion_config: ConversionConfig,
parameters_mapping: dict,
node_formats: dict[Node, NodeFormat],
custom_delegation_options: CustomDelegationOptions,
):
"""
@@ -39,5 +35,4 @@ def __init__
self.tflite_builder = tflite_builder
self.conversion_config = conversion_config
self.parameters_mapping = parameters_mapping
self.node_formats = node_formats
self.custom_delegation_options = custom_delegation_options
@@ -9,7 +9,7 @@
from executorch.backends.nxp.backend.ir.converter.conversion import translator
from executorch.backends.nxp.backend.ir.tensor_formatting import TensorFormat
from executorch.backends.nxp.backend.ir.tflite_generator import tflite_model
from executorch.backends.nxp.backend.node_format_inference import NodeFormat
from executorch.backends.nxp.backend.node_format import NodeFormat
from torch.fx import Node
from torch.nn import Parameter

@@ -18,6 +18,7 @@
from executorch.backends.nxp.backend.ir.tflite_generator.builtin_options.concatenation_options import (
Concatenation,
)
from executorch.backends.nxp.backend.node_format import NXP_NODE_FORMAT
from torch.fx import Node
from torch.nn import Parameter

@@ -88,25 +89,27 @@ def _is_supported_on_target(
return False

# Neutron requires the channels to be a multiple of `8`. The channels could either be the second or the
# last dimension, depending on the formats of the node. The format, however, cannot be determined
# during conversion, as it depends on what other nodes are delegated.
# last dimension, depending on the formats of the node.
if node.meta[NXP_NODE_FORMAT].is_channels_first():
# During conversion to IR, the shape will be permuted to channels last, and the dimension on index
# `1` will end up being the channels (last dim in NHWC).
channels_index = 1
else:
# The shape will not be permuted during conversion, so the channels will remain the last dimension.
channels_index = -1

input_channels = [
# The second dimension is the channels in PyTorch. If the inputs/output are not channels first, it
# will still be the channels in the IR.
_get_shape(input_)[1]
for input_ in node.all_input_nodes
] + [
# If the inputs/outputs are channels first, the last dimension will be the channels.
_get_shape(input_)[-1]
_get_shape(input_)[channels_index]
for input_ in node.all_input_nodes
]
output_channels = _get_shape(node)[channels_index]

if any((input_channel % 8) != 0 for input_channel in input_channels):
# neutron-library/src/utils/NeutronLibraryInterrogation.cpp#1492
return False

output_channels = [_get_shape(node)[1], _get_shape(node)[-1]]
# neutron-library/src/utils/NeutronLibraryInterrogation.cpp#1493
if any((out_c % 8) != 0 for out_c in output_channels):
if (output_channels % 8) != 0:
# neutron-library/src/utils/NeutronLibraryInterrogation.cpp#1493
return False

if len(node.all_input_nodes) < 2: # Not supported on Neutron
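The new format-aware check can be summed up with a small sketch (mine, not from the PR): which dimension must be a multiple of 8 depends on whether the node is channels first.

```python
# Illustrative sketch of the Neutron channel-alignment rule enforced above.
def channels_multiple_of_8(shape: list[int], channels_first: bool) -> bool:
    # Channels-first tensors get permuted to NHWC during conversion, so the
    # PyTorch dim 1 ends up as the channels; otherwise the last dim already is.
    channels_index = 1 if channels_first else -1
    return shape[channels_index] % 8 == 0

assert channels_multiple_of_8([1, 16, 8, 8], channels_first=True)        # 16 is aligned
assert not channels_multiple_of_8([1, 16, 8, 12], channels_first=False)  # 12 is not
```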
@@ -27,6 +27,8 @@
pad_options,
pad_v2_options,
)

from executorch.backends.nxp.backend.node_format import NXP_NODE_FORMAT
from torch.fx import Node
from torch.nn import Parameter

@@ -41,11 +43,17 @@ def _is_supported_on_target(
) -> bool:
match target:
case Target.RT700:
# TODO: Consider different tensor formats (dim-order)
paddings = node.args[1]
if len(paddings) > 4 and paddings[4:6] != [0, 0]:
# Attempt to Pad channels dimension, which is not supported on Neutron.
return False
if node.meta[NXP_NODE_FORMAT].is_channels_first():
# Dim `1` will end up being the channels. It is padded by paddings[4:6].
if len(paddings) > 4 and paddings[4:6] != [0, 0]:
# Attempt to Pad channels dimension -> currently not supported
return False
else:
# Dim `-1` will end up being the channels. It is padded by paddings[:2].
if len(paddings) > 0 and paddings[:2] != [0, 0]:
# Attempt to Pad channels dimension -> currently not supported
return False

return True

@@ -71,10 +79,6 @@ def _is_supported_in_IR(
if not NodeConverter._has_shared_q_params_if_quantized(node):
return False

if len(paddings) > 4 and paddings[4:6] != [0, 0]:
# Attempt to Pad channels dimension -> currently not supported
return False

return True

# noinspection PyMethodMayBeStatic
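The slice indices follow PyTorch's pad convention, where the `paddings` list is ordered from the last dimension backwards; a quick sketch (not from the PR) of why `paddings[:2]` targets the last dimension and `paddings[4:6]` targets dim 1 of a 4D tensor:

```python
import torch
import torch.nn.functional as F

x = torch.zeros(1, 3, 4, 4)  # NCHW

# paddings[4:6] pad dim -3, i.e. the channels of an NCHW tensor.
y = F.pad(x, (0, 0, 0, 0, 1, 1))
assert y.shape == (1, 5, 4, 4)

# paddings[:2] pad dim -1, the channels of a channels-last (NHWC) layout.
z = F.pad(x, (2, 2))
assert z.shape == (1, 3, 4, 8)
```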
2 changes: 1 addition & 1 deletion backends/nxp/backend/ir/tensor_formatting.py
@@ -7,7 +7,7 @@
#
from enum import Enum

from executorch.backends.nxp.backend.node_format_inference import NodeFormat
from executorch.backends.nxp.backend.node_format import NodeFormat


class TensorFormat(Enum):
23 changes: 23 additions & 0 deletions backends/nxp/backend/node_format.py
@@ -0,0 +1,23 @@
# Copyright 2025 NXP
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from enum import Enum

# Key into the `meta` attribute of nodes, which is mapped to their inferred node format.
NXP_NODE_FORMAT = "nxp_node_format"


class NodeFormat(Enum):
# Node's output in NCHW format
CHANNELS_FIRST = 0

# Node's output format has no meaning
FORMATLESS = 1

# Format has not been identified
NONE = 2

def is_channels_first(self) -> bool:
return self == NodeFormat.CHANNELS_FIRST
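The key and enum above replace the `node_formats` dict that was previously threaded through the converter; a small sketch (not from the PR, the helper names are mine) of the intended usage:

```python
from executorch.backends.nxp.backend.node_format import NXP_NODE_FORMAT, NodeFormat

def tag_node(node, fmt: NodeFormat) -> None:
    # An inference step stamps each node with its format...
    node.meta[NXP_NODE_FORMAT] = fmt

def node_is_channels_first(node) -> bool:
    # ...and later consumers read it straight from `node.meta`,
    # defaulting to NONE for nodes that were never tagged.
    return node.meta.get(NXP_NODE_FORMAT, NodeFormat.NONE).is_channels_first()
```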