Commit 57ca96f

StrycekSimon, roman-janik-nxp, and skywall authored
NXP backend: Per-channel quantization of convolution layer (#14061)
### Summary

Add per-channel quantization for the convolution layer and introduce the NodeArgsIdx class in the Neutron Quantizer for better handling of indexes into a quantized node's args list. NodeArgsIdx allows selecting nested objects, e.g. an object inside a list in the node's args list. It also simplifies the NeutronAtenQuantizer annotation process by using annotate_inputs() for inputs, weights, and biases alike.

### Test plan

The implementation should be covered by either existing or newly added unit tests.

Co-authored-by: Roman Janik <[email protected]>
Co-authored-by: Lukas Sztefek <[email protected]>
1 parent ab31007 · commit 57ca96f

18 files changed: +375 −166 lines
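The summary introduces NodeArgsIdx without showing its shape. Below is a minimal sketch of what such a class could look like, inferred from the `args_idx.idx` / `args_idx.inner_idx` accesses visible in the neutron_quantizer.py diff further down; the real definition lives in the quantizer patterns module and may differ.

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class NodeArgsIdx:
    """Index into a node's args list; inner_idx reaches into a nested list."""

    idx: int
    inner_idx: int | None = None


def resolve_arg(args: tuple, args_idx: NodeArgsIdx) -> Any:
    # Mirrors the selection logic in annotate_inputs() in the diff below.
    if args_idx.inner_idx is None:
        return args[args_idx.idx]
    return args[args_idx.idx][args_idx.inner_idx]


args = ("activation", ["w0", "w1", "w2"], "bias")
assert resolve_arg(args, NodeArgsIdx(2)) == "bias"
assert resolve_arg(args, NodeArgsIdx(1, 0)) == "w0"  # nested: args[1][0]
```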

backends/nxp/backend/edge_program_converter.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -134,6 +134,7 @@ def _process_nodes(self, nodes: list[Node], conversion_context: ConversionContext
 
         qdq_related_functions = [
             exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
+            exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
             exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
         ]
 
@@ -203,7 +204,8 @@ def _convert_qdq_cluster_q_dq_nodes(
         :param conversion_context: ConversionContext instance.
         """
         qdq_q_ops_converters = {
-            exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: QDQDequantizeConverter,  # noqa F405
+            exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: QDQPerTensorDequantizeConverter,  # noqa F405
+            exir_ops.edge.quantized_decomposed.dequantize_per_channel.default: QDQPerChannelDequantizeConverter,  # noqa F405
             exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: QDQQuantizeConverter,  # noqa F405
         }
 
```

backends/nxp/backend/ir/converter/node_converters/ops_converters/__init__.py

Lines changed: 4 additions & 2 deletions
```diff
@@ -41,7 +41,8 @@
     PermuteCopyConverter,
 )
 from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.qdq_dequantize_converter import (
-    QDQDequantizeConverter,
+    QDQPerChannelDequantizeConverter,
+    QDQPerTensorDequantizeConverter,
 )
 from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.qdq_quantize_converter import (
     QDQQuantizeConverter,
@@ -70,7 +71,8 @@
     "PermuteCopyConverter",
     "SoftmaxConverter",
     "ViewCopyConverter",
-    "QDQDequantizeConverter",
+    "QDQPerTensorDequantizeConverter",
+    "QDQPerChannelDequantizeConverter",
     "QDQQuantizeConverter",
     "ConstantPadNDConverter",
     "ReLUConverter",
```

backends/nxp/backend/ir/converter/node_converters/ops_converters/convolution_converter.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -321,6 +321,10 @@ def _convert_2d_conv(
             t_op.tmp_inputs[1] = self.builder.create_transposed_tensor(
                 weight_tensor, perm
             )
+
+            if t_op.tmp_inputs[1].quantization is not None:
+                # Model is quantized
+                t_op.tmp_inputs[1].quantization.quantized_dimension = 3
         else:
             raise NotImplementedError("Dynamic Depthwise Conv weights.")
 
```

backends/nxp/backend/ir/converter/node_converters/ops_converters/qdq_dequantize_converter.py

Lines changed: 32 additions & 6 deletions
```diff
@@ -2,6 +2,7 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
+from abc import ABC, abstractmethod
 
 import numpy as np
 
@@ -19,15 +20,23 @@
 from torch.nn import Parameter
 
 
-class QDQDequantizeConverter(NodeConverter):
+class QDQDequantizeConverterBase(NodeConverter, ABC):
+
+    @abstractmethod
+    def get_zero_point(self, node: Node) -> np.ndarray:
+        pass
+
+    @abstractmethod
+    def get_scale(self, node: Node) -> np.ndarray:
+        pass
 
     @staticmethod
     def _is_supported_in_IR(
         node: Node,
         parameters_mapping: dict[str, Parameter],
         custom_delegation_options: CustomDelegationOptions,
     ) -> bool:
-        zero_point_type = torch_type_to_numpy_type(node.args[5])
+        zero_point_type = torch_type_to_numpy_type(node.args[-1])
         if "cluster" not in node.meta or zero_point_type not in [np.int8, np.int32]:
             return False
 
@@ -39,10 +48,8 @@ def convert(self, node: Node):
         from_tensor = self.builder.tensor_for_name(node.name)
         to_tensor = self.builder.tensor_for_name(node.args[0].name)
 
-        zero_point_type = torch_type_to_numpy_type(node.args[5])
-
-        scale = np.array(node.args[1], dtype=np.float32)
-        zero_point = np.array(node.args[2], dtype=zero_point_type)
+        scale = self.get_scale(node)
+        zero_point = self.get_zero_point(node)
 
         if self.context.parameters_mapping.get(node.args[0].name, None) is None:
             # Convert dequantize as identity op (Transpose that will be removed) because
@@ -63,3 +70,22 @@ def convert(self, node: Node):
             # Change type so we pass check tensor similarity check when redirecting
             from_tensor.type = to_tensor.type
             self.builder.redirect_tensor(from_tensor, to_tensor)
+
+
+class QDQPerTensorDequantizeConverter(QDQDequantizeConverterBase):
+
+    def get_zero_point(self, node: Node) -> np.ndarray:
+        zero_point_type = torch_type_to_numpy_type(node.args[5])
+        return np.array(node.args[2], dtype=zero_point_type)
+
+    def get_scale(self, node: Node) -> np.ndarray:
+        return np.array(node.args[1], dtype=np.float32)
+
+
+class QDQPerChannelDequantizeConverter(QDQDequantizeConverterBase):
+
+    def get_zero_point(self, node: Node) -> np.ndarray:
+        return self.context.parameters_mapping[node.args[2].name].numpy()
+
+    def get_scale(self, node: Node) -> np.ndarray:
+        return self.context.parameters_mapping[node.args[1].name].numpy()
```
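The split reflects how the two edge ops carry their parameters: dequantize_per_tensor has a scalar scale at args[1] and zero point at args[2] (with the dtype as the last argument, hence args[-1] in the shared check), while dequantize_per_channel points at whole parameter tensors that must be fetched from parameters_mapping. A hedged numpy sketch of the math each converter ultimately represents — the shapes and the axis-0 choice here are illustrative:

```python
import numpy as np

q = np.random.randint(-128, 127, size=(4, 2), dtype=np.int8)

# Per-tensor: one scalar scale and zero point for the whole tensor.
scale, zero_point = 0.05, 3
x_per_tensor = (q.astype(np.float32) - zero_point) * scale

# Per-channel: one scale/zero point per slice along the quantized axis
# (axis 0 here, as an example). The parameters are arrays, not scalars,
# which is why they arrive as graph tensors rather than literal node args.
scales = np.array([0.05, 0.02, 0.04, 0.1], dtype=np.float32)
zero_points = np.zeros(4, dtype=np.int32)
x_per_channel = (q.astype(np.float32) - zero_points[:, None]) * scales[:, None]
```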

backends/nxp/quantizer/neutron_quantizer.py

Lines changed: 14 additions & 34 deletions
```diff
@@ -4,8 +4,6 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import List, Optional, Tuple, Union
-
 import torch
 
 from executorch.backends.nxp.aten_passes.neutron_aten_pass_manager import (
@@ -27,6 +25,7 @@
     LinearPattern,
     MaxPoolPattern,
     MeanDimPattern,
+    NodeArgsIdx,
     PadPattern,
     PermutePattern,
     QuantizationPattern,
@@ -106,57 +105,43 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
         )
 
         def annotate_inputs(
-            inputs: Union[
-                List[Tuple[fx.Node, int]],
-                List[Tuple[fx.Node, int, DerivedQuantizationSpec],],
-            ],
-            spec: Optional[QuantizationSpec],
+            inputs: (
+                list[tuple[fx.Node, NodeArgsIdx]]
+                | list[tuple[fx.Node, NodeArgsIdx, DerivedQuantizationSpec]]
+            ),
+            spec: QuantizationSpec | None,
         ) -> None:
-            for node, idx, *custom_spec in inputs:
+            for node, args_idx, *custom_spec in inputs:
                 # pyre-ignore[16]: no attribute
                 annotation = node.meta.get(
                     Q_ANNOTATION_KEY,
                     QuantizationAnnotation(_annotated=True),
                 )
                 arg = (
                     # pyre-ignore[16]: no attribute
-                    node.args[idx]
-                    if isinstance(idx, int)
+                    node.args[args_idx.idx]
+                    if args_idx.inner_idx is None
                     # pyre-ignore[16]: no attribute
-                    else node.args[idx[0]][idx[1]]
+                    else node.args[args_idx.idx][args_idx.inner_idx]
                 )
                 annotation.input_qspec_map[arg] = (
                     custom_spec[0] if custom_spec else spec
                 )
                 # pyre-ignore[16]: no attribute
                 node.meta[Q_ANNOTATION_KEY] = annotation
 
-        def annotate_weights_or_biases(
-            weights_or_biases: List[Tuple[fx.Node, int]],
-            spec: Optional[QuantizationSpec],
-        ) -> None:
-            for node, idx, *custom_spec in weights_or_biases:
-                annotation = node.meta.get(
-                    Q_ANNOTATION_KEY,
-                    QuantizationAnnotation(_annotated=True),
-                )
-                annotation.input_qspec_map[node.args[idx]] = (
-                    custom_spec[0] if custom_spec else spec
-                )
-                node.meta[Q_ANNOTATION_KEY] = annotation
-
         # pyre-ignore[6]: incompatible parameter type
         annotate_inputs(anchors.inputs, input_act_qspec)
-        annotate_weights_or_biases(anchors.weights, weight_qspec)
+        annotate_inputs(anchors.weights, weight_qspec)
         # pyre-ignore[6]: incompatible parameter type
-        annotate_weights_or_biases(anchors.biases, bias_qspec)
+        annotate_inputs(anchors.biases, bias_qspec)
         return model
 
     def validate(self, model: fx.GraphModule) -> None:
         pass
 
     @classmethod
-    def get_supported_operators(cls) -> List[OperatorConfig]:
+    def get_supported_operators(cls) -> list[OperatorConfig]:
         return []
 
 
@@ -195,12 +180,7 @@ def get_supported_operators(cls) -> List[OperatorConfig]:
 
 class NeutronQuantizer(ComposableQuantizer):
     def __init__(self):
-        static_qconfig = QuantizationConfig(
-            act_qspec,
-            act_qspec,
-            wgt_qspec,
-            None,
-        )
+        static_qconfig = QuantizationConfig(act_qspec, act_qspec, wgt_qspec, None)
         static_fc_qconfig = QuantizationConfig(act_qspec, act_qspec, wgt_fc_qspec, None)
         super().__init__(
             [
```
