4 changes: 3 additions & 1 deletion backends/nxp/backend/edge_program_converter.py
@@ -134,6 +134,7 @@ def _process_nodes(self, nodes: list[Node], conversion_context: ConversionContext

qdq_related_functions = [
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
]

@@ -203,7 +204,8 @@ def _convert_qdq_cluster_q_dq_nodes(
:param conversion_context: ConversionContext instance.
"""
qdq_q_ops_converters = {
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: QDQDequantizeConverter, # noqa F405
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: QDQPerTensorDequantizeConverter, # noqa F405
exir_ops.edge.quantized_decomposed.dequantize_per_channel.default: QDQPerChannelDequantizeConverter, # noqa F405
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: QDQQuantizeConverter, # noqa F405
}

@@ -41,7 +41,8 @@
PermuteCopyConverter,
)
from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.qdq_dequantize_converter import (
QDQDequantizeConverter,
QDQPerChannelDequantizeConverter,
QDQPerTensorDequantizeConverter,
)
from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.qdq_quantize_converter import (
QDQQuantizeConverter,
@@ -70,7 +71,8 @@
"PermuteCopyConverter",
"SoftmaxConverter",
"ViewCopyConverter",
"QDQDequantizeConverter",
"QDQPerTensorDequantizeConverter",
"QDQPerChannelDequantizeConverter",
"QDQQuantizeConverter",
Contributor: Only QDQDequantizeConverter needs to be updated, not QDQQuantizeConverter too?

Collaborator: Correct, there are no changes needed to QDQQuantizeConverter. The per-channel quantization scheme is used only for weights and biases, which are inputs, i.e. dequantize nodes.
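For illustration, a minimal sketch (not part of this PR) of how the two dequantize variants expose their quantization parameters, assuming the standard quantized_decomposed op signatures; read_dequant_params is a hypothetical helper, not an actual API:

import numpy as np

def read_dequant_params(node, parameters_mapping):
    # Hypothetical helper mirroring the two converters in this PR.
    if "per_tensor" in str(node.target):
        # dequantize_per_tensor(x, scale, zero_point, quant_min, quant_max, dtype):
        # scale and zero_point are scalars stored directly in node.args.
        scale = np.array(node.args[1], dtype=np.float32)
        zero_point = np.array(node.args[2])
    else:
        # dequantize_per_channel(x, scales, zero_points, axis, quant_min, quant_max, dtype):
        # scales and zero_points are parameter tensors with one value per channel,
        # so they are looked up by node name, just like weights and biases.
        scale = parameters_mapping[node.args[1].name].numpy()
        zero_point = parameters_mapping[node.args[2].name].numpy()
    return scale, zero_point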

"ConstantPadNDConverter",
"ReLUConverter",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,10 @@ def _convert_2d_conv(
t_op.tmp_inputs[1] = self.builder.create_transposed_tensor(
weight_tensor, perm
)

if t_op.tmp_inputs[1].quantization is not None:
# The model is quantized. TFLite depthwise conv weights are per-channel quantized
# along the last dimension, so update the quantized dimension to 3 to match the
# transposed weight layout.
t_op.tmp_inputs[1].quantization.quantized_dimension = 3
else:
raise NotImplementedError("Dynamic Depthwise Conv weights.")

@@ -2,6 +2,7 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from abc import ABC, abstractmethod

import numpy as np

@@ -19,15 +20,23 @@
from torch.nn import Parameter


class QDQDequantizeConverter(NodeConverter):
class QDQDequantizeConverterBase(NodeConverter, ABC):

@abstractmethod
def get_zero_point(self, node: Node) -> np.ndarray:
pass

@abstractmethod
def get_scale(self, node: Node) -> np.ndarray:
pass

@staticmethod
def _is_supported_in_IR(
node: Node,
parameters_mapping: dict[str, Parameter],
custom_delegation_options: CustomDelegationOptions,
) -> bool:
zero_point_type = torch_type_to_numpy_type(node.args[5])
zero_point_type = torch_type_to_numpy_type(node.args[-1])
if "cluster" not in node.meta or zero_point_type not in [np.int8, np.int32]:
return False

@@ -39,10 +48,8 @@ def convert(self, node: Node):
from_tensor = self.builder.tensor_for_name(node.name)
to_tensor = self.builder.tensor_for_name(node.args[0].name)

zero_point_type = torch_type_to_numpy_type(node.args[5])

scale = np.array(node.args[1], dtype=np.float32)
zero_point = np.array(node.args[2], dtype=zero_point_type)
scale = self.get_scale(node)
zero_point = self.get_zero_point(node)

if self.context.parameters_mapping.get(node.args[0].name, None) is None:
# Convert dequantize as identity op (Transpose that will be removed) because
@@ -63,3 +70,22 @@ def convert(self, node: Node):
# Change type so we pass check tensor similarity check when redirecting
from_tensor.type = to_tensor.type
self.builder.redirect_tensor(from_tensor, to_tensor)


class QDQPerTensorDequantizeConverter(QDQDequantizeConverterBase):

def get_zero_point(self, node: Node) -> np.ndarray:
zero_point_type = torch_type_to_numpy_type(node.args[5])
return np.array(node.args[2], dtype=zero_point_type)

def get_scale(self, node: Node) -> np.ndarray:
return np.array(node.args[1], dtype=np.float32)


class QDQPerChannelDequantizeConverter(QDQDequantizeConverterBase):

def get_zero_point(self, node: Node) -> np.ndarray:
return self.context.parameters_mapping[node.args[2].name].numpy()

def get_scale(self, node: Node) -> np.ndarray:
return self.context.parameters_mapping[node.args[1].name].numpy()
48 changes: 14 additions & 34 deletions backends/nxp/quantizer/neutron_quantizer.py
@@ -4,8 +4,6 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import List, Optional, Tuple, Union

import torch

from executorch.backends.nxp.aten_passes.neutron_aten_pass_manager import (
@@ -27,6 +25,7 @@
LinearPattern,
MaxPoolPattern,
MeanDimPattern,
NodeArgsIdx,
PadPattern,
PermutePattern,
QuantizationPattern,
@@ -106,57 +105,43 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
)

def annotate_inputs(
inputs: Union[
List[Tuple[fx.Node, int]],
List[Tuple[fx.Node, int, DerivedQuantizationSpec],],
],
spec: Optional[QuantizationSpec],
inputs: (
list[tuple[fx.Node, NodeArgsIdx]]
| list[tuple[fx.Node, NodeArgsIdx, DerivedQuantizationSpec]]
),
spec: QuantizationSpec | None,
Contributor: Just curious, why switch from Optional to | None?

Collaborator: It's part of the move to Python 3.10 type hints and dropping the imports from typing (a short before/after sketch follows this function below).

) -> None:
for node, idx, *custom_spec in inputs:
for node, args_idx, *custom_spec in inputs:
# pyre-ignore[16]: no attribute
annotation = node.meta.get(
Q_ANNOTATION_KEY,
QuantizationAnnotation(_annotated=True),
)
arg = (
# pyre-ignore[16]: no attribute
node.args[idx]
if isinstance(idx, int)
node.args[args_idx.idx]
if args_idx.inner_idx is None
# pyre-ignore[16]: no attribute
else node.args[idx[0]][idx[1]]
else node.args[args_idx.idx][args_idx.inner_idx]
)
annotation.input_qspec_map[arg] = (
custom_spec[0] if custom_spec else spec
)
# pyre-ignore[16]: no attribute
node.meta[Q_ANNOTATION_KEY] = annotation
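
As mentioned in the review discussion above, a minimal before/after sketch of the typing style change; the function and parameter names here are illustrative only:

# Pre-3.10 style, requiring imports from the typing module:
from typing import List, Optional, Tuple

def annotate_old(inputs: List[Tuple[str, int]], spec: Optional[int]) -> None: ...

# Python 3.10+ style used in this PR: builtin generics plus the PEP 604 union
# operator, with no typing imports needed.
def annotate_new(inputs: list[tuple[str, int]], spec: int | None) -> None: ...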

def annotate_weights_or_biases(
weights_or_biases: List[Tuple[fx.Node, int]],
spec: Optional[QuantizationSpec],
) -> None:
for node, idx, *custom_spec in weights_or_biases:
annotation = node.meta.get(
Q_ANNOTATION_KEY,
QuantizationAnnotation(_annotated=True),
)
annotation.input_qspec_map[node.args[idx]] = (
custom_spec[0] if custom_spec else spec
)
node.meta[Q_ANNOTATION_KEY] = annotation

# pyre-ignore[6]: incompatible parameter type
annotate_inputs(anchors.inputs, input_act_qspec)
annotate_weights_or_biases(anchors.weights, weight_qspec)
Contributor: Is this function no longer used at all now and can be removed entirely?

Collaborator: Yes, it is replaced by annotate_inputs().

annotate_inputs(anchors.weights, weight_qspec)
# pyre-ignore[6]: incompatible parameter type
annotate_weights_or_biases(anchors.biases, bias_qspec)
annotate_inputs(anchors.biases, bias_qspec)
return model
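
The NodeArgsIdx type imported above is defined in the quantizer patterns module and is not shown in this diff; below is a minimal sketch of the shape annotate_inputs assumes (this dataclass is an assumption for illustration, not the actual definition):

from dataclasses import dataclass

@dataclass
class NodeArgsIdx:
    # Position of the argument in node.args.
    idx: int
    # Optional index into that argument when it is itself a list/tuple of nodes;
    # None for plain arguments.
    inner_idx: int | None = None

# annotate_inputs resolves the annotated argument as:
#   node.args[args_idx.idx]                      when args_idx.inner_idx is None
#   node.args[args_idx.idx][args_idx.inner_idx]  otherwise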

def validate(self, model: fx.GraphModule) -> None:
pass

@classmethod
def get_supported_operators(cls) -> List[OperatorConfig]:
def get_supported_operators(cls) -> list[OperatorConfig]:
return []


@@ -195,12 +180,7 @@ def get_supported_operators(cls) -> List[OperatorConfig]:

class NeutronQuantizer(ComposableQuantizer):
def __init__(self):
static_qconfig = QuantizationConfig(
act_qspec,
act_qspec,
wgt_qspec,
None,
)
static_qconfig = QuantizationConfig(act_qspec, act_qspec, wgt_qspec, None)
static_fc_qconfig = QuantizationConfig(act_qspec, act_qspec, wgt_fc_qspec, None)
super().__init__(
[