4
4
# LICENSE file in the root directory of this source tree.
5
5
6
6
import logging
7
- from enum import Enum
7
+ import operator
8
8
9
+ from executorch .backends .nxp .backend .edge_program_converter import functions_converters
10
+ from executorch .backends .nxp .backend .node_format import NodeFormat , NXP_NODE_FORMAT
9
11
from executorch .exir .dialects ._ops import ops as exir_ops
12
+ from executorch .exir .dialects .edge ._ops import EdgeOpOverload
10
13
11
- from torch import Node
12
14
from torch .export import ExportedProgram
15
+ from torch .fx import Node
13
16
14
17
logger = logging .getLogger (__name__ )
15
18
16
- NXP_NODE_FORMAT = "nxp_node_format" # Key into the `meta` attribute of nodes, which is mapped to the inferred format.
17
-
18
-
19
- class NodeFormat (Enum ):
20
- # Node's output in NCHW format
21
- CHANNELS_FIRST = 0
22
-
23
- # Node's output format has no meaning
24
- FORMATLESS = 1
25
-
26
- # Format has not been identified
27
- NONE = 2
28
-
29
- def is_channels_first (self ) -> bool :
30
- return self == NodeFormat .CHANNELS_FIRST
31
-
32
19
33
20
class NodeFormatInference :
34
21
# Dictionary with Edge Aten ops that always use channels first format.
@@ -53,6 +40,9 @@ class NodeFormatInference:
53
40
# Mapping between Node and its children (outputs)
54
41
_node_outputs : dict [Node , list [Node ]]
55
42
43
+ # List of all edge operations, which are supported by the converter.
44
+ _known_targets : list [EdgeOpOverload ]
45
+
56
46
def __init__ (self , edge_program : ExportedProgram ):
57
47
self ._edge_program = edge_program
58
48
@@ -66,6 +56,12 @@ def __init__(self, edge_program: ExportedProgram):
66
56
67
57
self ._type_changed_during_last_run = False
68
58
59
+ self ._known_targets = list (functions_converters ) + [
60
+ exir_ops .edge .quantized_decomposed .dequantize_per_tensor .default ,
61
+ exir_ops .edge .quantized_decomposed .quantize_per_tensor .default ,
62
+ operator .getitem ,
63
+ ]
64
+
69
65
def identify_node_formats (self ):
70
66
self ._type_changed_during_last_run = True
71
67
@@ -100,9 +96,19 @@ def _infer_format_of_nodes(self, node: Node):
100
96
logger .error (
101
97
f"Node format inference for node type: { op_type } not found!"
102
98
)
103
- else :
99
+ elif node .op != "call_function" or (
100
+ hasattr (node , "target" ) and node .target in self ._known_targets
101
+ ):
102
+ # Generic node, or tensor.
104
103
self ._handle_node_which_can_use_any_node_format (node )
105
104
105
+ else :
106
+ # Don't infer the format for unknown nodes. These nodes will never be delegated, so they will divide
107
+ # delegated partitions. Propagating the format here could unnecessarily enforce the format in one of these
108
+ # partitions, which would require extra transpositions.
109
+ for processed_node in self ._node_inputs [node ] + [node ]:
110
+ self ._assign_format_to_node (processed_node , NodeFormat .NONE )
111
+
106
112
def _infer_format_based_on_io_ranks (self , node : Node ):
107
113
"""Determine the format of the output tensor of given "reshape style operator" based on the ranks of its input
108
114
and output.
@@ -155,6 +161,10 @@ def _assign_format_to_node(self, node: Node, node_format: NodeFormat):
155
161
# Once CHANNEL_FIRST was assigned, we don't want to reassign
156
162
return
157
163
164
+ if node_format is NodeFormat .NONE and old_node_format is not NodeFormat .NONE :
165
+ # A format has already been assigned to the node before. Don't replace it with `NONE`.
166
+ return
167
+
158
168
if old_node_format != node_format :
159
169
self ._type_changed_during_last_run = True
160
170
0 commit comments