From e7c43b67c2c87e06929ce132bbe9fc6e87272c0b Mon Sep 17 00:00:00 2001 From: Roman Janik Date: Wed, 18 Jun 2025 18:25:52 +0200 Subject: [PATCH 1/5] NXP backend: Abstract PartitionAnchors annotations of arg indexes --- backends/nxp/quantizer/neutron_quantizer.py | 48 +++---- backends/nxp/quantizer/patterns.py | 131 +++++++++++--------- 2 files changed, 86 insertions(+), 93 deletions(-) diff --git a/backends/nxp/quantizer/neutron_quantizer.py b/backends/nxp/quantizer/neutron_quantizer.py index d3f84144aa3..d9dd019c864 100644 --- a/backends/nxp/quantizer/neutron_quantizer.py +++ b/backends/nxp/quantizer/neutron_quantizer.py @@ -4,8 +4,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import List, Optional, Tuple, Union - import torch from executorch.backends.nxp.aten_passes.neutron_aten_pass_manager import ( @@ -27,6 +25,7 @@ LinearPattern, MaxPoolPattern, MeanDimPattern, + NodeArgsIdx, PadPattern, PermutePattern, QuantizationPattern, @@ -106,13 +105,13 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: ) def annotate_inputs( - inputs: Union[ - List[Tuple[fx.Node, int]], - List[Tuple[fx.Node, int, DerivedQuantizationSpec],], - ], - spec: Optional[QuantizationSpec], + inputs: ( + list[tuple[fx.Node, NodeArgsIdx]] + | list[tuple[fx.Node, NodeArgsIdx, DerivedQuantizationSpec]] + ), + spec: QuantizationSpec | None, ) -> None: - for node, idx, *custom_spec in inputs: + for node, args_idx, *custom_spec in inputs: # pyre-ignore[16]: no attribute annotation = node.meta.get( Q_ANNOTATION_KEY, @@ -120,10 +119,10 @@ def annotate_inputs( ) arg = ( # pyre-ignore[16]: no attribute - node.args[idx] - if isinstance(idx, int) + node.args[args_idx.idx] + if args_idx.inner_idx is None # pyre-ignore[16]: no attribute - else node.args[idx[0]][idx[1]] + else node.args[args_idx.idx][args_idx.inner_idx] ) annotation.input_qspec_map[arg] = ( custom_spec[0] if custom_spec else spec @@ -131,32 +130,18 @@ def annotate_inputs( # pyre-ignore[16]: no attribute node.meta[Q_ANNOTATION_KEY] = annotation - def annotate_weights_or_biases( - weights_or_biases: List[Tuple[fx.Node, int]], - spec: Optional[QuantizationSpec], - ) -> None: - for node, idx, *custom_spec in weights_or_biases: - annotation = node.meta.get( - Q_ANNOTATION_KEY, - QuantizationAnnotation(_annotated=True), - ) - annotation.input_qspec_map[node.args[idx]] = ( - custom_spec[0] if custom_spec else spec - ) - node.meta[Q_ANNOTATION_KEY] = annotation - # pyre-ignore[6]: incompatible parameter type annotate_inputs(anchors.inputs, input_act_qspec) - annotate_weights_or_biases(anchors.weights, weight_qspec) + annotate_inputs(anchors.weights, weight_qspec) # pyre-ignore[6]: incompatible parameter type - annotate_weights_or_biases(anchors.biases, bias_qspec) + annotate_inputs(anchors.biases, bias_qspec) return model def validate(self, model: fx.GraphModule) -> None: pass @classmethod - def get_supported_operators(cls) -> List[OperatorConfig]: + def get_supported_operators(cls) -> list[OperatorConfig]: return [] @@ -195,12 +180,7 @@ def get_supported_operators(cls) -> List[OperatorConfig]: class NeutronQuantizer(ComposableQuantizer): def __init__(self): - static_qconfig = QuantizationConfig( - act_qspec, - act_qspec, - wgt_qspec, - None, - ) + static_qconfig = QuantizationConfig(act_qspec, act_qspec, wgt_qspec, None) static_fc_qconfig = QuantizationConfig(act_qspec, act_qspec, wgt_fc_qspec, None) super().__init__( [ diff --git 
a/backends/nxp/quantizer/patterns.py b/backends/nxp/quantizer/patterns.py index 651f995d570..1608c75c412 100644 --- a/backends/nxp/quantizer/patterns.py +++ b/backends/nxp/quantizer/patterns.py @@ -7,7 +7,6 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, field -from typing import List, Optional, Tuple, Type, Union import torch @@ -22,11 +21,27 @@ from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY +@dataclass +class NodeArgsIdx: + """ + Specifies indexes to args paramater of Node in node input annotation. + + + Attributes: + idx (int): Index to Node's args paramater (list). Selects an input Node or a list of Nodes at the index. + inner_idx (int): If specified, index to a list pointed by 'idx' attribute. Selects an input Node at the index. + Default: None. + """ + + idx: int + inner_idx: int = None + + @dataclass class PartitionAnchors: """ - All fields except output are lists of (node, args_index) pair, where node is from - the given partition and node.args[args_index] is an input to the partition. Assumes + All fields except output are lists of (node, node_args_idx) or (node, node_args_idx, quantization_spec) tuples, + where node is from the given partition and node.args[node_args_idx] is an input to the partition. Assumes a single output. Quantizer uses inputs, weights and biases for quantization annotation. The others @@ -35,25 +50,21 @@ class PartitionAnchors: """ # Inputs can share quantization parameters - inputs: List[ - Union[ - Tuple[fx.Node, Union[int, Tuple[int, int]]], - Tuple[ - fx.Node, - Union[int, Tuple[int, int]], - SharedQuantizationSpec, - ], - ] + inputs: list[ + tuple[fx.Node, NodeArgsIdx] + | tuple[fx.Node, NodeArgsIdx, SharedQuantizationSpec], ] = field(default_factory=list) - weights: List[Tuple[fx.Node, int]] = field(default_factory=list) - biases: List[ - Union[Tuple[fx.Node, int], Tuple[fx.Node, int, DerivedQuantizationSpec]] + weights: list[tuple[fx.Node, NodeArgsIdx]] = field(default_factory=list) + biases: list[ + tuple[fx.Node, NodeArgsIdx] + | tuple[fx.Node, NodeArgsIdx, DerivedQuantizationSpec], + ] = field(default_factory=list) + others: list[tuple[fx.Node, NodeArgsIdx]] = field(default_factory=list) + literals: list[tuple[fx.Node, NodeArgsIdx]] = field(default_factory=list) + output: list[ + tuple[fx.Node] + | tuple[fx.Node, FixedQParamsQuantizationSpec | SharedQuantizationSpec], ] = field(default_factory=list) - others: List[Tuple[fx.Node, int]] = field(default_factory=list) - literals: List[Tuple[fx.Node, int]] = field(default_factory=list) - output: List[Union[Tuple[fx.Node], Tuple[fx.Node, SharedQuantizationSpec]]] = field( - default_factory=list - ) empty: bool = False @@ -67,8 +78,8 @@ def partition_types(self) -> list[OpOverload]: @abstractmethod def get_anchors( - self, gm: torch.fx.GraphModule, fused_partition: List[fx.GraphModule] - ) -> Optional[PartitionAnchors]: + self, gm: torch.fx.GraphModule, fused_partition: list[fx.GraphModule] + ) -> PartitionAnchors | None: pass @@ -80,11 +91,11 @@ class SharedSpecPattern(QuantizationPattern): quantization parameters (scale and zero-point). 
""" - def partition_types(self) -> List[Type[torch.nn.Module]]: + def partition_types(self) -> list[torch.nn.Module]: pass def get_anchors( - self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule] + self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule] ) -> PartitionAnchors | None: node = fused_partition[0].nodes[-1] assert len(fused_partition[0].input_nodes) == 1 @@ -97,7 +108,7 @@ def get_anchors( qspec = SharedQuantizationSpec(prev_node) return PartitionAnchors( - inputs=[(node, 0)], + inputs=[(node, NodeArgsIdx(0))], weights=[], biases=[], output=[ @@ -126,7 +137,7 @@ def get_anchors_for_fixed_quant_specs( ) return PartitionAnchors( - inputs=[(node, 0)], + inputs=[(node, NodeArgsIdx(0))], weights=[], biases=[], output=[ @@ -154,11 +165,11 @@ def partition_types(self): class AddmmPattern(QuantizationPattern): - def partition_types(self) -> List[OpOverload]: + def partition_types(self) -> list[OpOverload]: return [torch.ops.aten.addmm.default] def get_anchors( - self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule] + self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule] ) -> PartitionAnchors: # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge... addmm_node = fused_partition[0].nodes[-1] @@ -176,9 +187,9 @@ def get_anchors( ) return PartitionAnchors( - inputs=[(addmm_node, 1)], - weights=[(addmm_node, 2)], - biases=[(addmm_node, 0, bias_qspec)], + inputs=[(addmm_node, NodeArgsIdx(1))], + weights=[(addmm_node, NodeArgsIdx(2))], + biases=[(addmm_node, NodeArgsIdx(0), bias_qspec)], output=[(addmm_node,)], ) @@ -190,16 +201,16 @@ class AddTensorPattern(QuantizationPattern): Basic quantization for all inputs and output. """ - def partition_types(self) -> List[Type[torch.nn.Module]]: + def partition_types(self) -> list[torch.nn.Module]: return [torch.ops.aten.add.Tensor] def get_anchors( - self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule] + self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule] ) -> PartitionAnchors | None: node = fused_partition[0].nodes[-1] - inputs = [(node, 0)] + inputs = [(node, NodeArgsIdx(0))] if len(fused_partition[0].input_nodes) == 2: - inputs = [(node, 0), (node, 1)] + inputs = [(node, NodeArgsIdx(0)), (node, NodeArgsIdx(1))] return PartitionAnchors( inputs=inputs, @@ -242,13 +253,15 @@ def get_anchors( if quantized_input is not None: inputs = [] for idx, _ in enumerate(node.args[0]): - inputs.append((node, (0, idx), SharedQuantizationSpec(quantized_input))) + inputs.append( + (node, NodeArgsIdx(0, idx), SharedQuantizationSpec(quantized_input)) + ) outputs = [(node, SharedQuantizationSpec(quantized_input))] else: # No previous node was quantized => we are not able to share q-params. The conversion to IR will have to # re-quantize the inputs if necessary. - inputs = [(node, (0, idx)) for idx in range(len(node.args[0]))] + inputs = [(node, NodeArgsIdx(0, idx)) for idx in range(len(node.args[0]))] outputs = [(node,)] return PartitionAnchors( @@ -260,11 +273,11 @@ def get_anchors( class Conv1dPattern(QuantizationPattern): - def partition_types(self) -> List[OpOverload]: + def partition_types(self) -> list[OpOverload]: return [torch.ops.aten.conv1d.default] def get_anchors( - self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule] + self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule] ) -> PartitionAnchors: # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge... 
conv1d_node = fused_partition[0].nodes[-1] @@ -284,11 +297,11 @@ def get_anchors( # Keep bias empty if not supplied bias = [] if len(conv1d_node.args) > 2 and conv1d_node.args[2] is not None: - bias = [(conv1d_node, 2, bias_qspec)] + bias = [(conv1d_node, NodeArgsIdx(2), bias_qspec)] return PartitionAnchors( - inputs=[(conv1d_node, 0)], - weights=[(conv1d_node, 1)], + inputs=[(conv1d_node, NodeArgsIdx(0))], + weights=[(conv1d_node, NodeArgsIdx(1))], # pyre-fixme[6]: Incompatible parameter type biases=bias, output=[(conv1d_node,)], @@ -296,11 +309,11 @@ def get_anchors( class Conv2dPattern(QuantizationPattern): - def partition_types(self) -> List[OpOverload]: + def partition_types(self) -> list[OpOverload]: return [torch.ops.aten.conv2d.default] def get_anchors( - self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule] + self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule] ) -> PartitionAnchors: # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge... conv2d_node = fused_partition[0].nodes[-1] @@ -320,11 +333,11 @@ def get_anchors( # Keep bias empty if not supplied bias = [] if len(conv2d_node.args) > 2 and conv2d_node.args[2] is not None: - bias = [(conv2d_node, 2, bias_qspec)] + bias = [(conv2d_node, NodeArgsIdx(2), bias_qspec)] return PartitionAnchors( - inputs=[(conv2d_node, 0)], - weights=[(conv2d_node, 1)], + inputs=[(conv2d_node, NodeArgsIdx(0))], + weights=[(conv2d_node, NodeArgsIdx(1))], # pyre-fixme[6]: Incompatible parameter type biases=bias, output=[(conv2d_node,)], @@ -359,12 +372,12 @@ def partition_types(self): return [torch.ops.aten.hardtanh.default] def get_anchors( - self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule] + self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule] ) -> PartitionAnchors | None: node = fused_partition[0].nodes[-1] return PartitionAnchors( - inputs=[(node, 0)], + inputs=[(node, NodeArgsIdx(0))], weights=[], biases=[], output=[(node,)], @@ -384,12 +397,12 @@ def partition_types(self): return [torch.ops.aten.hardtanh_.default] def get_anchors( - self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule] + self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule] ) -> PartitionAnchors | None: node = fused_partition[0].nodes[-1] return PartitionAnchors( - inputs=[(node, 0)], + inputs=[(node, NodeArgsIdx(0))], weights=[], biases=[], output=[(node,)], @@ -400,11 +413,11 @@ def replacement_op(self): class LinearPattern(QuantizationPattern): - def partition_types(self) -> List[OpOverload]: + def partition_types(self) -> list[OpOverload]: return [torch.ops.aten.linear.default] def get_anchors( - self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule] + self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule] ) -> PartitionAnchors: # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge... linear_node = fused_partition[0].nodes[-1] @@ -424,11 +437,11 @@ def get_anchors( # Keep bias empty if not supplied bias = [] if len(linear_node.args) > 2: - bias = [(linear_node, 2, bias_qspec)] + bias = [(linear_node, NodeArgsIdx(2), bias_qspec)] return PartitionAnchors( - inputs=[(linear_node, 0)], - weights=[(linear_node, 1)], + inputs=[(linear_node, NodeArgsIdx(0))], + weights=[(linear_node, NodeArgsIdx(1))], # pyre-fixme[6]: Incompatible parameter type biases=bias, output=[(linear_node,)], @@ -515,7 +528,7 @@ class SoftMaxPattern(QuantizationPattern): The quantization of Softmax output is fixed to scale 1/256, zero point -128, dtype int8. 
""" - def partition_types(self) -> List[OpOverload]: + def partition_types(self) -> list[OpOverload]: return [torch.ops.aten.softmax.int] def get_anchors( @@ -569,11 +582,11 @@ class SigmoidPattern(QuantizationPattern): The quantization of Sigmoid output is fixed to scale 1/256, zero point -128, dtype int8. """ - def partition_types(self) -> List[OpOverload]: + def partition_types(self) -> list[OpOverload]: return [torch.ops.aten.sigmoid.default] def get_anchors( - self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule] + self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule] ) -> PartitionAnchors: return get_anchors_for_fixed_quant_specs( fused_partition, scale=1.0 / 256.0, zero_point=-128 From 2c82054785d314e3913d4a54c2e92b51039a615e Mon Sep 17 00:00:00 2001 From: Lukas Sztefek Date: Wed, 18 Jun 2025 16:18:26 +0200 Subject: [PATCH 2/5] NXP backend: Add support for per-channel quantization for Conv --- .../nxp/backend/edge_program_converter.py | 4 +- .../ops_converters/__init__.py | 6 +- .../qdq_dequantize_converter.py | 38 ++++- backends/nxp/quantizer/patterns.py | 5 +- backends/nxp/tests/executorch_pipeline.py | 11 +- .../nxp/tests/test_per_channel_conversion.py | 153 ++++++++++++++++++ backends/nxp/tests/test_removing_dead_code.py | 4 +- .../nxp/tests/test_split_group_convolution.py | 6 +- 8 files changed, 211 insertions(+), 16 deletions(-) create mode 100644 backends/nxp/tests/test_per_channel_conversion.py diff --git a/backends/nxp/backend/edge_program_converter.py b/backends/nxp/backend/edge_program_converter.py index ddbbf5b2e3a..192798c151e 100644 --- a/backends/nxp/backend/edge_program_converter.py +++ b/backends/nxp/backend/edge_program_converter.py @@ -134,6 +134,7 @@ def _process_nodes(self, nodes: list[Node], conversion_context: ConversionContex qdq_related_functions = [ exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, + exir_ops.edge.quantized_decomposed.dequantize_per_channel.default, exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, ] @@ -203,7 +204,8 @@ def _convert_qdq_cluster_q_dq_nodes( :param conversion_context: ConversionContext instance. 
""" qdq_q_ops_converters = { - exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: QDQDequantizeConverter, # noqa F405 + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: QDQPerTensorDequantizeConverter, # noqa F405 + exir_ops.edge.quantized_decomposed.dequantize_per_channel.default: QDQPerChannelDequantizeConverter, # noqa F405 exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: QDQQuantizeConverter, # noqa F405 } diff --git a/backends/nxp/backend/ir/converter/node_converters/ops_converters/__init__.py b/backends/nxp/backend/ir/converter/node_converters/ops_converters/__init__.py index d1674e16a9f..472a3495e19 100755 --- a/backends/nxp/backend/ir/converter/node_converters/ops_converters/__init__.py +++ b/backends/nxp/backend/ir/converter/node_converters/ops_converters/__init__.py @@ -41,7 +41,8 @@ PermuteCopyConverter, ) from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.qdq_dequantize_converter import ( - QDQDequantizeConverter, + QDQPerChannelDequantizeConverter, + QDQPerTensorDequantizeConverter, ) from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.qdq_quantize_converter import ( QDQQuantizeConverter, @@ -70,7 +71,8 @@ "PermuteCopyConverter", "SoftmaxConverter", "ViewCopyConverter", - "QDQDequantizeConverter", + "QDQPerTensorDequantizeConverter", + "QDQPerChannelDequantizeConverter", "QDQQuantizeConverter", "ConstantPadNDConverter", "ReLUConverter", diff --git a/backends/nxp/backend/ir/converter/node_converters/ops_converters/qdq_dequantize_converter.py b/backends/nxp/backend/ir/converter/node_converters/ops_converters/qdq_dequantize_converter.py index c6ea7f90042..1d7c6b44627 100644 --- a/backends/nxp/backend/ir/converter/node_converters/ops_converters/qdq_dequantize_converter.py +++ b/backends/nxp/backend/ir/converter/node_converters/ops_converters/qdq_dequantize_converter.py @@ -2,6 +2,7 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+from abc import ABC, abstractmethod import numpy as np @@ -19,7 +20,15 @@ from torch.nn import Parameter -class QDQDequantizeConverter(NodeConverter): +class QDQDequantizeConverterBase(NodeConverter, ABC): + + @abstractmethod + def get_zero_point(self, node: Node) -> np.ndarray: + pass + + @abstractmethod + def get_scale(self, node: Node) -> np.ndarray: + pass @staticmethod def _is_supported_in_IR( @@ -27,7 +36,7 @@ def _is_supported_in_IR( parameters_mapping: dict[str, Parameter], custom_delegation_options: CustomDelegationOptions, ) -> bool: - zero_point_type = torch_type_to_numpy_type(node.args[5]) + zero_point_type = torch_type_to_numpy_type(node.args[-1]) if "cluster" not in node.meta or zero_point_type not in [np.int8, np.int32]: return False @@ -39,10 +48,8 @@ def convert(self, node: Node): from_tensor = self.builder.tensor_for_name(node.name) to_tensor = self.builder.tensor_for_name(node.args[0].name) - zero_point_type = torch_type_to_numpy_type(node.args[5]) - - scale = np.array(node.args[1], dtype=np.float32) - zero_point = np.array(node.args[2], dtype=zero_point_type) + scale = self.get_scale(node) + zero_point = self.get_zero_point(node) if self.context.parameters_mapping.get(node.args[0].name, None) is None: # Convert dequantize as identity op (Transpose that will be removed) because @@ -63,3 +70,22 @@ def convert(self, node: Node): # Change type so we pass check tensor similarity check when redirecting from_tensor.type = to_tensor.type self.builder.redirect_tensor(from_tensor, to_tensor) + + +class QDQPerTensorDequantizeConverter(QDQDequantizeConverterBase): + + def get_zero_point(self, node: Node) -> np.ndarray: + zero_point_type = torch_type_to_numpy_type(node.args[5]) + return np.array(node.args[2], dtype=zero_point_type) + + def get_scale(self, node: Node) -> np.ndarray: + return np.array(node.args[1], dtype=np.float32) + + +class QDQPerChannelDequantizeConverter(QDQDequantizeConverterBase): + + def get_zero_point(self, node: Node) -> np.ndarray: + return self.context.parameters_mapping[node.args[2].name].numpy() + + def get_scale(self, node: Node) -> np.ndarray: + return self.context.parameters_mapping[node.args[1].name].numpy() diff --git a/backends/nxp/quantizer/patterns.py b/backends/nxp/quantizer/patterns.py index 1608c75c412..9b23a617d60 100644 --- a/backends/nxp/quantizer/patterns.py +++ b/backends/nxp/quantizer/patterns.py @@ -16,6 +16,7 @@ from torchao.quantization.pt2e.quantizer import ( DerivedQuantizationSpec, FixedQParamsQuantizationSpec, + QuantizationSpec, SharedQuantizationSpec, ) from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY @@ -54,7 +55,9 @@ class PartitionAnchors: tuple[fx.Node, NodeArgsIdx] | tuple[fx.Node, NodeArgsIdx, SharedQuantizationSpec], ] = field(default_factory=list) - weights: list[tuple[fx.Node, NodeArgsIdx]] = field(default_factory=list) + weights: list[ + tuple[fx.Node, NodeArgsIdx] | tuple[fx.Node, NodeArgsIdx, QuantizationSpec], + ] = field(default_factory=list) biases: list[ tuple[fx.Node, NodeArgsIdx] | tuple[fx.Node, NodeArgsIdx, DerivedQuantizationSpec], diff --git a/backends/nxp/tests/executorch_pipeline.py b/backends/nxp/tests/executorch_pipeline.py index f2f625ad0c8..c675586a057 100644 --- a/backends/nxp/tests/executorch_pipeline.py +++ b/backends/nxp/tests/executorch_pipeline.py @@ -38,9 +38,9 @@ class ModelInputSpec: dtype: torch.dtype = torch.float32 -def _quantize_model(model, calibration_inputs: list[tuple[torch.Tensor, ...]]): - quantizer = NeutronQuantizer() - +def _quantize_model( + model, 
quantizer, calibration_inputs: list[tuple[torch.Tensor, ...]] +): m = prepare_pt2e(model, quantizer) for data in calibration_inputs: m(*data) @@ -91,6 +91,7 @@ def to_quantized_edge_program( neutron_converter_flavor="SDK_25_06", remove_quant_io_ops=False, custom_delegation_options=CustomDelegationOptions(), # noqa B008 + get_quantizer_fn=lambda: NeutronQuantizer(), ) -> EdgeProgramManager: calibration_inputs = get_calibration_inputs_fn(to_model_input_spec(input_spec)) @@ -102,7 +103,9 @@ def to_quantized_edge_program( exir_program_aten = torch.export.export(model, example_input, strict=True) exir_program_aten__module_quant = _quantize_model( - exir_program_aten.module(), calibration_inputs + exir_program_aten.module(), + get_quantizer_fn(), + calibration_inputs, ) edge_compile_config = EdgeCompileConfig(_check_ir_validity=False) diff --git a/backends/nxp/tests/test_per_channel_conversion.py b/backends/nxp/tests/test_per_channel_conversion.py new file mode 100644 index 00000000000..6fc724b9254 --- /dev/null +++ b/backends/nxp/tests/test_per_channel_conversion.py @@ -0,0 +1,153 @@ +# Copyright 2025 NXP +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +import kgb +import numpy as np +import torch + +from executorch.backends.nxp.backend.edge_program_converter import ( + EdgeProgramToIRConverter, +) +from executorch.backends.nxp.quantizer.neutron_quantizer import ( + act_qspec, + NeutronAtenQuantizer, + wgt_qspec, +) +from executorch.backends.nxp.quantizer.patterns import ( + NodeArgsIdx, + PartitionAnchors, + QuantizationPattern, +) +from executorch.backends.nxp.quantizer.utils import get_bias_qparams +from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program +from executorch.backends.nxp.tests.executors import ( + convert_run_compare, + ToChannelFirstPreprocess, + ToChannelLastPreprocess, +) +from executorch.backends.nxp.tests.models import Conv2dModule +from executorch.backends.nxp.tests.test_quantizer import _get_target_name + +from torch import fx +from torch._ops import OpOverload +from torch.export import ExportedProgram +from torchao.quantization.pt2e import MinMaxObserver, PerChannelMinMaxObserver +from torchao.quantization.pt2e.quantizer import ( + DerivedQuantizationSpec, + QuantizationConfig, + QuantizationSpec, +) + + +class Conv2dPatternPerChannel(QuantizationPattern): + + def __init__(self, is_per_channel: bool): + super().__init__() + self.is_per_channel = is_per_channel + + def partition_types(self) -> list[OpOverload]: + return [torch.ops.aten.conv2d.default] + + def get_anchors( + self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule] + ) -> PartitionAnchors: + conv2d_node = fused_partition[0].nodes[-1] + + bias_qscheme = ( + torch.per_channel_symmetric + if self.is_per_channel + else torch.per_tensor_symmetric + ) + bias_quantization_qspec = DerivedQuantizationSpec( + derived_from=[ + (conv2d_node.args[0], conv2d_node), + (conv2d_node.args[1], conv2d_node), + ], + derive_qparams_fn=get_bias_qparams, + dtype=torch.int32, + quant_min=-(2**31) + 1, + quant_max=2**31 - 1, + qscheme=bias_qscheme, + ch_axis=0, + ) + + weight_qscheme = ( + torch.per_channel_symmetric + if self.is_per_channel + else torch.per_tensor_symmetric + ) + weight_observer_or_fake_quant_ctr = ( + PerChannelMinMaxObserver if self.is_per_channel else MinMaxObserver + ) + weight_quantization_spec = QuantizationSpec( + dtype=torch.int8, + 
observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr, + quant_min=-127, + quant_max=127, + qscheme=weight_qscheme, + ch_axis=0, + ) + + return PartitionAnchors( + inputs=[(conv2d_node, NodeArgsIdx(0))], + weights=[(conv2d_node, NodeArgsIdx(1), weight_quantization_spec)], + biases=[(conv2d_node, NodeArgsIdx(2), bias_quantization_qspec)], + output=[(conv2d_node,)], + ) + + +class TestPerChannelConversion(unittest.TestCase): + __test__ = False # Prevent interfering with PyTest tests + + def test_per_channel_convolution(self): + with kgb.spy_on( + EdgeProgramToIRConverter.convert_program, call_original=True + ) as converter_spy: + model = Conv2dModule( + in_channels=8, out_channels=32, kernel_size=5, padding=3 + ) + input_shape = (1, 8, 32, 32) + + static_qconfig = QuantizationConfig(act_qspec, act_qspec, wgt_qspec, None) + _ = to_quantized_edge_program( + model, + input_shape, + get_quantizer_fn=lambda: NeutronAtenQuantizer( + Conv2dPatternPerChannel(is_per_channel=True), static_qconfig + ), + ) + + tflite_flatbuffers_model, io_formats = converter_spy.calls[-1].return_value + exported_program: ExportedProgram = converter_spy.calls[-1].args[0] + + input_data = (np.random.random(input_shape).astype(np.float32) * 50).astype( + np.int8 + ) + + convert_run_compare( + exported_program, + tflite_input_preprocess=ToChannelLastPreprocess(), + tfl_model=tflite_flatbuffers_model, + tflite_output_preprocess=ToChannelFirstPreprocess(), + input_data=input_data, + atol=1.0, + ) + + nodes = list(exported_program.graph.nodes) + + assert _get_target_name(nodes[8]).endswith( + "quantized_decomposed.dequantize_per_channel.default" + ) + assert _get_target_name(nodes[9]).endswith( + "quantized_decomposed.dequantize_per_channel.default" + ) + assert nodes[10].name == "aten_convolution_default" + + @classmethod + def setUpClass(cls): + torch.manual_seed(25) + np.random.seed(25) diff --git a/backends/nxp/tests/test_removing_dead_code.py b/backends/nxp/tests/test_removing_dead_code.py index 7b8641fb247..f5ce1211eb4 100644 --- a/backends/nxp/tests/test_removing_dead_code.py +++ b/backends/nxp/tests/test_removing_dead_code.py @@ -9,6 +9,7 @@ import pytest import torch +from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer from executorch.backends.nxp.tests.executorch_pipeline import _quantize_model from executorch.backends.nxp.tests.executors import graph_contains_any_of_ops @@ -45,8 +46,9 @@ def test_removing_dead_code(self): ) # The `NeutronQuantizer` should remove the dead code in the `transform_for_annotation()` method. + quantizer = NeutronQuantizer() exir_program_aten_quant = _quantize_model( - exir_program_aten.module(), [example_inputs] + exir_program_aten.module(), quantizer, [example_inputs] ) # Make sure the is no `add` operation in the graph anymore. 
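For context on the per-channel support introduced in this patch: `quantized_decomposed.dequantize_per_channel` carries one (scale, zero_point) pair per slice along `ch_axis`, which is why `QDQPerChannelDequantizeConverter` reads whole arrays from `parameters_mapping` instead of scalar args. The following is a minimal NumPy sketch of that dequantization rule, not the converter itself; the array names, shapes, and values are illustrative only.

    import numpy as np

    def dequantize_per_channel(q, scales, zero_points, ch_axis):
        # One (scale, zero_point) pair per slice along `ch_axis`; broadcast them across the other dims.
        shape = [1] * q.ndim
        shape[ch_axis] = -1
        scales = np.asarray(scales, dtype=np.float32).reshape(shape)
        zero_points = np.asarray(zero_points, dtype=np.int64).reshape(shape)
        return (q.astype(np.int64) - zero_points).astype(np.float32) * scales

    # Conv2d weights quantized per output channel (ch_axis=0), as the Conv2dPatternPerChannel test sets up.
    q_weight = np.random.randint(-127, 128, size=(32, 8, 5, 5), dtype=np.int8)
    scales = np.random.uniform(1e-3, 2e-2, size=32).astype(np.float32)
    zero_points = np.zeros(32, dtype=np.int64)  # symmetric quantization: zero points are all zero
    w_fp32 = dequantize_per_channel(q_weight, scales, zero_points, ch_axis=0)
    assert w_fp32.shape == q_weight.shape and w_fp32.dtype == np.float32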
diff --git a/backends/nxp/tests/test_split_group_convolution.py b/backends/nxp/tests/test_split_group_convolution.py index 1da53af794d..b908c850f53 100644 --- a/backends/nxp/tests/test_split_group_convolution.py +++ b/backends/nxp/tests/test_split_group_convolution.py @@ -17,6 +17,7 @@ ) from executorch.backends.nxp.neutron_partitioner import NeutronPartitioner from executorch.backends.nxp.nxp_backend import generate_neutron_compile_spec +from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer from executorch.backends.nxp.tests.executorch_pipeline import ( _quantize_model, get_random_calibration_inputs, @@ -39,8 +40,11 @@ def _quantize_and_lower_module( module: GraphModule, input_shape: tuple[int, ...], target="imxrt700" ) -> EdgeProgramManager: calibration_inputs = get_random_calibration_inputs(to_model_input_spec(input_shape)) + quantizer = NeutronQuantizer() - exir_program_aten__module_quant = _quantize_model(module, calibration_inputs) + exir_program_aten__module_quant = _quantize_model( + module, quantizer, calibration_inputs + ) edge_compile_config = EdgeCompileConfig(_check_ir_validity=False) edge_program_manager = export_to_edge( From b340ad1eef4fc2c8477e191f029a8f57e873a83e Mon Sep 17 00:00:00 2001 From: Lukas Sztefek Date: Fri, 20 Jun 2025 11:37:05 +0200 Subject: [PATCH 3/5] NXP backend: Use per-channel quantization for Conv in NeutronQuantizer --- .../ops_converters/convolution_converter.py | 4 ++ backends/nxp/quantizer/patterns.py | 24 +++++-- backends/nxp/quantizer/utils.py | 2 +- .../node_converter/test_hardtanh_converter.py | 2 +- .../node_converter/test_mean_dim_converter.py | 1 + .../node_converter/test_tanh_converter.py | 2 +- backends/nxp/tests/test_batch_norm_fusion.py | 2 +- .../nxp/tests/test_qdq_clustering_conv.py | 12 ++-- backends/nxp/tests/test_quantizer.py | 67 ++++++++++--------- 9 files changed, 66 insertions(+), 50 deletions(-) diff --git a/backends/nxp/backend/ir/converter/node_converters/ops_converters/convolution_converter.py b/backends/nxp/backend/ir/converter/node_converters/ops_converters/convolution_converter.py index 0f3a4b9bb5a..8955b4c8fd4 100644 --- a/backends/nxp/backend/ir/converter/node_converters/ops_converters/convolution_converter.py +++ b/backends/nxp/backend/ir/converter/node_converters/ops_converters/convolution_converter.py @@ -321,6 +321,10 @@ def _convert_2d_conv( t_op.tmp_inputs[1] = self.builder.create_transposed_tensor( weight_tensor, perm ) + + if t_op.tmp_inputs[1].quantization is not None: + # Model is quantized + t_op.tmp_inputs[1].quantization.quantized_dimension = 3 else: raise NotImplementedError("Dynamic Depthwise Conv weights.") diff --git a/backends/nxp/quantizer/patterns.py b/backends/nxp/quantizer/patterns.py index 9b23a617d60..e2d6f6dc9ea 100644 --- a/backends/nxp/quantizer/patterns.py +++ b/backends/nxp/quantizer/patterns.py @@ -13,6 +13,7 @@ from executorch.backends.nxp.quantizer.utils import get_bias_qparams from torch import fx from torch._ops import OpOverload +from torchao.quantization.pt2e import PerChannelMinMaxObserver from torchao.quantization.pt2e.quantizer import ( DerivedQuantizationSpec, FixedQParamsQuantizationSpec, @@ -318,30 +319,39 @@ def partition_types(self) -> list[OpOverload]: def get_anchors( self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule] ) -> PartitionAnchors: - # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge... 
conv2d_node = fused_partition[0].nodes[-1] - bias_qspec = DerivedQuantizationSpec( + bias_quantization_qspec = DerivedQuantizationSpec( derived_from=[ (conv2d_node.args[0], conv2d_node), (conv2d_node.args[1], conv2d_node), ], derive_qparams_fn=get_bias_qparams, dtype=torch.int32, - quant_min=-(2**31), + quant_min=-(2**31) + 1, quant_max=2**31 - 1, - qscheme=torch.per_tensor_affine, + qscheme=torch.per_channel_symmetric, + ch_axis=0, + ) + + weight_observer_or_fake_quant_ctr = PerChannelMinMaxObserver + weight_quantization_spec = QuantizationSpec( + dtype=torch.int8, + observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr, + quant_min=-127, + quant_max=127, + qscheme=torch.per_channel_symmetric, + ch_axis=0, ) # Keep bias empty if not supplied bias = [] if len(conv2d_node.args) > 2 and conv2d_node.args[2] is not None: - bias = [(conv2d_node, NodeArgsIdx(2), bias_qspec)] + bias = [(conv2d_node, NodeArgsIdx(2), bias_quantization_qspec)] return PartitionAnchors( inputs=[(conv2d_node, NodeArgsIdx(0))], - weights=[(conv2d_node, NodeArgsIdx(1))], - # pyre-fixme[6]: Incompatible parameter type + weights=[(conv2d_node, NodeArgsIdx(1), weight_quantization_spec)], biases=bias, output=[(conv2d_node,)], ) diff --git a/backends/nxp/quantizer/utils.py b/backends/nxp/quantizer/utils.py index ed94183c2db..12c722a8ab3 100644 --- a/backends/nxp/quantizer/utils.py +++ b/backends/nxp/quantizer/utils.py @@ -49,7 +49,7 @@ def get_bias_qparams( act_scale, _ = obs_or_fqs[0].calculate_qparams() weight_scale, _ = obs_or_fqs[1].calculate_qparams() bias_scale = act_scale * weight_scale - bias_zero_point = torch.zeros_like(bias_scale, dtype=torch.int32) + bias_zero_point = torch.zeros_like(bias_scale, dtype=torch.int64) return bias_scale, bias_zero_point diff --git a/backends/nxp/tests/ir/converter/node_converter/test_hardtanh_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_hardtanh_converter.py index e17868d16e2..c4bc559817b 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_hardtanh_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_hardtanh_converter.py @@ -57,7 +57,7 @@ def test_relu6_quant(mocker, input_shape: tuple[int], inplace: bool): tflite_input_preprocess=ToNHWCPreprocess(), tflite_output_preprocess=ToNCHWPreprocess(), input_data=input_data, - atol=1.0, + atol=2.0, ) diff --git a/backends/nxp/tests/ir/converter/node_converter/test_mean_dim_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_mean_dim_converter.py index 0032eae5c1a..a634416f8a7 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_mean_dim_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_mean_dim_converter.py @@ -49,6 +49,7 @@ def test_mean_dim_conv_quant_conversion(mocker, input_shape, dim, keeepdim=True) input_data=input_data, tflite_output_preprocess=ToChannelFirstPreprocess(), tfl_model=tflite_flatbuffers_model, + atol=1.0, ) diff --git a/backends/nxp/tests/ir/converter/node_converter/test_tanh_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_tanh_converter.py index 40857d18eb8..86db5685604 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_tanh_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_tanh_converter.py @@ -76,7 +76,7 @@ def test_conv_tanh( tflite_input_preprocess=ToChannelLastPreprocess(), tflite_output_preprocess=ToChannelFirstPreprocess(), input_data=input_data, - atol=1.0, + atol=2.0, ) @classmethod diff --git a/backends/nxp/tests/test_batch_norm_fusion.py 
b/backends/nxp/tests/test_batch_norm_fusion.py index 3f1106c6d24..788d04c6dad 100644 --- a/backends/nxp/tests/test_batch_norm_fusion.py +++ b/backends/nxp/tests/test_batch_norm_fusion.py @@ -168,7 +168,7 @@ def test_batch_norm_conv_fusing__full_pipeline__1d(bias: bool): nodes = list(edge_program.graph.nodes) assert ( - len(nodes) == 13 + len(nodes) == 17 ) # 1D Conv currently isn't delegated, because it doesn't get quantized. assert not any( node.op == "call_function" and "batch_norm" in node.target.__name__ diff --git a/backends/nxp/tests/test_qdq_clustering_conv.py b/backends/nxp/tests/test_qdq_clustering_conv.py index 1713aace1fe..ffae931dbb4 100644 --- a/backends/nxp/tests/test_qdq_clustering_conv.py +++ b/backends/nxp/tests/test_qdq_clustering_conv.py @@ -16,13 +16,13 @@ def test_conv2d_partitioner(): lowered_module = edge_program.exported_program().graph_module.lowered_module_0 nodes = list(lowered_module.original_module.graph.nodes) - assert len(nodes) == 7 + assert len(nodes) == 9 - q_x_node = nodes[1] - dq_w_node = nodes[2] - dq_x_node = nodes[3] - conv_node = nodes[4] - q_y_node = nodes[5] + q_x_node = nodes[3] + dq_w_node = nodes[4] + dq_x_node = nodes[5] + conv_node = nodes[6] + q_y_node = nodes[7] assert "cluster" not in q_x_node.meta assert dq_w_node.meta["cluster"] == "aten_convolution_default_cluster" diff --git a/backends/nxp/tests/test_quantizer.py b/backends/nxp/tests/test_quantizer.py index ef5fbb0cbca..624e350ed21 100644 --- a/backends/nxp/tests/test_quantizer.py +++ b/backends/nxp/tests/test_quantizer.py @@ -1,4 +1,4 @@ -# Copyright 2024 NXP +# Copyright 2024-2025 NXP # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -34,26 +34,26 @@ def test_quantizer_conv2d(): m(*example_input) nodes = list(m.graph.nodes) - assert len(nodes) == 11 - assert nodes[7].name == "conv2d" + assert len(nodes) == 15 + assert nodes[11].name == "conv2d" # [0]: Input, [1] : weights, [2]: bias assert ( - _get_target_name(nodes[7].args[0]) + _get_target_name(nodes[11].args[0]) == "torch.ops.quantized_decomposed.dequantize_per_tensor.default" ) assert ( - _get_target_name(nodes[7].args[1]) - == "torch.ops.quantized_decomposed.dequantize_per_tensor.default" + _get_target_name(nodes[11].args[1]) + == "torch.ops.quantized_decomposed.dequantize_per_channel.default" ) assert ( - _get_target_name(nodes[7].args[2]) - == "torch.ops.quantized_decomposed.dequantize_per_tensor.default" + _get_target_name(nodes[11].args[2]) + == "torch.ops.quantized_decomposed.dequantize_per_channel.default" ) assert ( - _get_target_name(nodes[8]) + _get_target_name(nodes[12]) == "torch.ops.quantized_decomposed.quantize_per_tensor.default" ) - assert nodes[8].args[0].name == "conv2d" + assert nodes[12].args[0].name == "conv2d" def test_quantizer_linear(): @@ -112,22 +112,22 @@ def test_quantizer_maxpool2d(): m(*example_input) nodes = list(m.graph.nodes) - assert len(nodes) == 14 + assert len(nodes) == 18 # Check if QDQ pattern: - assert nodes[10].name == "max_pool2d" + assert nodes[14].name == "max_pool2d" assert ( - _get_target_name(nodes[10].args[0]) + _get_target_name(nodes[14].args[0]) == "torch.ops.quantized_decomposed.dequantize_per_tensor.default" ) assert ( - _get_target_name(nodes[11]) + _get_target_name(nodes[15]) == "torch.ops.quantized_decomposed.quantize_per_tensor.default" ) - assert nodes[11].args[0].name == "max_pool2d" + assert nodes[15].args[0].name == "max_pool2d" # Check if input and output quantization is same - 
input_quant = nodes[10].args[0].args[1:] - output_quant = nodes[11].args[1:] + input_quant = nodes[14].args[0].args[1:] + output_quant = nodes[15].args[1:] assert input_quant == output_quant @@ -207,10 +207,10 @@ def test_quantizer_conv2d_relu(): m(*example_input) nodes = list(m.graph.nodes) - assert len(nodes) == 12 - assert nodes[7].name == "dequantize_per_tensor_default_2" - assert nodes[8].name == "relu" - assert nodes[9].name == "quantize_per_tensor_default_3" + assert len(nodes) == 14 + assert nodes[9].name == "dequantize_per_tensor_default_1" + assert nodes[10].name == "relu" + assert nodes[11].name == "quantize_per_tensor_default_2" def test_quantizer_conv2d_avg_pool2d(): @@ -230,10 +230,10 @@ def test_quantizer_conv2d_avg_pool2d(): m(*example_input) nodes = list(m.graph.nodes) - assert len(nodes) == 14 - assert nodes[9].name == "dequantize_per_tensor_default_3" - assert nodes[10].name == "avg_pool2d" - assert nodes[11].name == "quantize_per_tensor_default_4" + assert len(nodes) == 18 + assert nodes[13].name == "dequantize_per_tensor_default_1" + assert nodes[14].name == "avg_pool2d" + assert nodes[15].name == "quantize_per_tensor_default_2" def test_quantizer_conv2d_permute(): @@ -253,10 +253,11 @@ def test_quantizer_conv2d_permute(): m(*example_input) nodes = list(m.graph.nodes) - assert len(nodes) == 12 - assert nodes[7].name == "dequantize_per_tensor_default_2" - assert nodes[8].name == "permute" - assert nodes[9].name == "quantize_per_tensor_default_3" + + assert len(nodes) == 14 + assert nodes[9].name == "dequantize_per_tensor_default_1" + assert nodes[10].name == "permute" + assert nodes[11].name == "quantize_per_tensor_default_2" def test_multiple_shared_spec_ops_in_row(): @@ -281,15 +282,15 @@ def test_multiple_shared_spec_ops_in_row(): nodes = list(m.graph.nodes) - assert len(nodes) == 15 - assert nodes[-5].name == "dequantize_per_tensor_default_3" + assert len(nodes) == 17 + assert nodes[-5].name.startswith("dequantize_per_tensor_default") assert nodes[-4].name == "max_pool2d" - assert nodes[-3].name == "quantize_per_tensor_default_4" + assert nodes[-3].name.startswith("quantize_per_tensor_default") # Assert that post-ReLU quantize and pre-MaxPool dequantize has same specs assert nodes[-6].args[1:] == nodes[-5].args[1:] # Assert that post-Conv quantize and pre-ReLU dequantize has same specs - assert nodes[6].args[1:] == nodes[7].args[1:] + assert nodes[5].args[1:] == nodes[6].args[1:] def test_quantizers_order_invariance(): From 82e6032c45955353b9cb7387a63a382e49e88f8f Mon Sep 17 00:00:00 2001 From: Lukas Sztefek Date: Thu, 3 Jul 2025 09:13:27 +0200 Subject: [PATCH 4/5] NXP backend: Print information about max error during the output tensor comparison --- backends/nxp/tests/executors.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/backends/nxp/tests/executors.py b/backends/nxp/tests/executors.py index afdb15af106..592717c0b3b 100644 --- a/backends/nxp/tests/executors.py +++ b/backends/nxp/tests/executors.py @@ -196,6 +196,11 @@ def compare_output_arrays( assert tfl_output.shape == edge_output.shape, "Output shapes don't match!" + if (max_diff := np.abs(np.max(tfl_output - edge_output))) > 0.0: + logger.w( + f"Maximum absolute difference of the tensor '{output_name}': '{max_diff}'" + ) + assert np.allclose( tfl_output, edge_output, rtol=rtol, atol=atol, equal_nan=True ), f"Output values of the `{output_name}` tensor don't match!" 
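One note on the metric logged above: the maximum absolute difference between two output tensors is conventionally the maximum of the element-wise absolute differences, i.e. `np.max(np.abs(a - b))`, which can differ from `np.abs(np.max(a - b))` whenever the largest-magnitude deviation is negative. A tiny, self-contained illustration (the array names mirror `compare_output_arrays`, the values are made up):

    import numpy as np

    tfl_output = np.array([1.0, 2.0, 3.0], dtype=np.float32)
    edge_output = np.array([1.0, 2.5, 3.1], dtype=np.float32)

    diff = tfl_output - edge_output   # [0.0, -0.5, -0.1]
    print(np.abs(np.max(diff)))       # 0.0 -> absolute value of the largest signed difference
    print(np.max(np.abs(diff)))       # 0.5 -> largest element-wise absolute difference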
From 1a3d9e5d4cbfc912b845c3a6a181f9c3133a6d1e Mon Sep 17 00:00:00 2001 From: Simon Strycek Date: Thu, 11 Sep 2025 08:36:20 +0200 Subject: [PATCH 5/5] NXP backend: Make setUpClass placement consistent across unit tests --- .../ir/converter/node_converter/test_tanh_converter.py | 10 +++++----- backends/nxp/tests/test_per_channel_conversion.py | 10 +++++----- backends/nxp/tests/test_removing_dead_code.py | 10 +++++----- 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/backends/nxp/tests/ir/converter/node_converter/test_tanh_converter.py b/backends/nxp/tests/ir/converter/node_converter/test_tanh_converter.py index 86db5685604..bb4500bc1e2 100644 --- a/backends/nxp/tests/ir/converter/node_converter/test_tanh_converter.py +++ b/backends/nxp/tests/ir/converter/node_converter/test_tanh_converter.py @@ -27,6 +27,11 @@ class TestTanhConverter(unittest.TestCase): __test__ = False # Prevent interfering with PyTest tests + @classmethod + def setUpClass(cls): + torch.manual_seed(23) + np.random.seed(23) + @parameterized.expand( input=[ ( @@ -78,8 +83,3 @@ def test_conv_tanh( input_data=input_data, atol=2.0, ) - - @classmethod - def setUpClass(cls): - torch.manual_seed(23) - np.random.seed(23) diff --git a/backends/nxp/tests/test_per_channel_conversion.py b/backends/nxp/tests/test_per_channel_conversion.py index 6fc724b9254..043ba8fc001 100644 --- a/backends/nxp/tests/test_per_channel_conversion.py +++ b/backends/nxp/tests/test_per_channel_conversion.py @@ -103,6 +103,11 @@ def get_anchors( class TestPerChannelConversion(unittest.TestCase): __test__ = False # Prevent interfering with PyTest tests + @classmethod + def setUpClass(cls): + torch.manual_seed(25) + np.random.seed(25) + def test_per_channel_convolution(self): with kgb.spy_on( EdgeProgramToIRConverter.convert_program, call_original=True @@ -146,8 +151,3 @@ def test_per_channel_convolution(self): "quantized_decomposed.dequantize_per_channel.default" ) assert nodes[10].name == "aten_convolution_default" - - @classmethod - def setUpClass(cls): - torch.manual_seed(25) - np.random.seed(25) diff --git a/backends/nxp/tests/test_removing_dead_code.py b/backends/nxp/tests/test_removing_dead_code.py index f5ce1211eb4..cc51746c81c 100644 --- a/backends/nxp/tests/test_removing_dead_code.py +++ b/backends/nxp/tests/test_removing_dead_code.py @@ -33,6 +33,11 @@ def forward(self, x): class TestRemovingDeadCode(unittest.TestCase): __test__ = False # Prevent interfering with PyTest tests + @classmethod + def setUpClass(cls): + torch.manual_seed(23) + np.random.seed(23) + def test_removing_dead_code(self): input_shape = (42,) example_inputs = (torch.ones(input_shape),) @@ -55,8 +60,3 @@ def test_removing_dead_code(self): assert not any( "add" in str(node.target) for node in exir_program_aten_quant.graph.nodes ) - - @classmethod - def setUpClass(cls): - torch.manual_seed(23) - np.random.seed(23)
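Taken together, the series replaces bare integer argument indexes in `PartitionAnchors` with `NodeArgsIdx`, so a pattern can point either at a positional argument (`node.args[idx]`) or at one element of a list argument (`node.args[idx][inner_idx]`, as the concatenation-style anchors in PATCH 1 do). Below is a minimal, self-contained sketch of the lookup that `annotate_inputs` performs; `FakeNode` and `resolve_arg` are purely illustrative stand-ins, not part of the backend.

    from dataclasses import dataclass

    @dataclass
    class NodeArgsIdx:            # same fields as the NodeArgsIdx added in PATCH 1
        idx: int
        inner_idx: int | None = None

    @dataclass
    class FakeNode:               # stand-in for torch.fx.Node, illustration only
        name: str
        args: tuple

    def resolve_arg(node: FakeNode, args_idx: NodeArgsIdx):
        # Plain index for scalar args; nested index for list args (e.g. the inputs of a cat-like op).
        if args_idx.inner_idx is None:
            return node.args[args_idx.idx]
        return node.args[args_idx.idx][args_idx.inner_idx]

    x = FakeNode("x", ())
    y = FakeNode("y", ())
    conv = FakeNode("conv2d", (x, y))
    cat = FakeNode("cat", ([x, y], 1))

    assert resolve_arg(conv, NodeArgsIdx(0)) is x      # annotates node.args[0]
    assert resolve_arg(cat, NodeArgsIdx(0, 1)) is y    # annotates node.args[0][1]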