diff --git a/backends/nxp/quantizer/neutron_quantizer.py b/backends/nxp/quantizer/neutron_quantizer.py
index 2279c177f59..b849061434b 100644
--- a/backends/nxp/quantizer/neutron_quantizer.py
+++ b/backends/nxp/quantizer/neutron_quantizer.py
@@ -4,8 +4,6 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import List, Optional, Tuple, Union
-
 import torch
 from executorch.backends.nxp.aten_passes.neutron_aten_pass_manager import (
     NeutronAtenPassManager,
@@ -25,6 +23,7 @@
     LinearPattern,
     MaxPoolPattern,
     MeanDimPattern,
+    NodeArgsIdx,
     PadPattern,
     PermutePattern,
     QuantizationPattern,
@@ -102,13 +101,13 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
         )
 
         def annotate_inputs(
-            inputs: Union[
-                List[Tuple[fx.Node, int]],
-                List[Tuple[fx.Node, int, DerivedQuantizationSpec],],
-            ],
-            spec: Optional[QuantizationSpec],
+            inputs: (
+                list[tuple[fx.Node, NodeArgsIdx]]
+                | list[tuple[fx.Node, NodeArgsIdx, DerivedQuantizationSpec]]
+            ),
+            spec: QuantizationSpec | None,
         ) -> None:
-            for node, idx, *custom_spec in inputs:
+            for node, args_idx, *custom_spec in inputs:
                 # pyre-ignore[16]: no attribute
                 annotation = node.meta.get(
                     Q_ANNOTATION_KEY,
@@ -116,10 +115,10 @@
                 )
                 arg = (
                     # pyre-ignore[16]: no attribute
-                    node.args[idx]
-                    if isinstance(idx, int)
+                    node.args[args_idx.idx]
+                    if args_idx.inner_idx is None
                     # pyre-ignore[16]: no attribute
-                    else node.args[idx[0]][idx[1]]
+                    else node.args[args_idx.idx][args_idx.inner_idx]
                 )
                 annotation.input_qspec_map[arg] = (
                     custom_spec[0] if custom_spec else spec
@@ -127,32 +126,18 @@
                 # pyre-ignore[16]: no attribute
                 node.meta[Q_ANNOTATION_KEY] = annotation
 
-        def annotate_weights_or_biases(
-            weights_or_biases: List[Tuple[fx.Node, int]],
-            spec: Optional[QuantizationSpec],
-        ) -> None:
-            for node, idx, *custom_spec in weights_or_biases:
-                annotation = node.meta.get(
-                    Q_ANNOTATION_KEY,
-                    QuantizationAnnotation(_annotated=True),
-                )
-                annotation.input_qspec_map[node.args[idx]] = (
-                    custom_spec[0] if custom_spec else spec
-                )
-                node.meta[Q_ANNOTATION_KEY] = annotation
-
         # pyre-ignore[6]: incompatible parameter type
         annotate_inputs(anchors.inputs, input_act_qspec)
-        annotate_weights_or_biases(anchors.weights, weight_qspec)
+        annotate_inputs(anchors.weights, weight_qspec)
         # pyre-ignore[6]: incompatible parameter type
-        annotate_weights_or_biases(anchors.biases, bias_qspec)
+        annotate_inputs(anchors.biases, bias_qspec)
         return model
 
     def validate(self, model: fx.GraphModule) -> None:
         pass
 
     @classmethod
-    def get_supported_operators(cls) -> List[OperatorConfig]:
+    def get_supported_operators(cls) -> list[OperatorConfig]:
         return []
 
 
@@ -191,12 +176,7 @@ class NeutronQuantizer(ComposableQuantizer):
     def __init__(self):
-        static_qconfig = QuantizationConfig(
-            act_qspec,
-            act_qspec,
-            wgt_qspec,
-            None,
-        )
+        static_qconfig = QuantizationConfig(act_qspec, act_qspec, wgt_qspec, None)
         static_fc_qconfig = QuantizationConfig(act_qspec, act_qspec, wgt_fc_qspec, None)
         super().__init__(
             [
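The two annotation helpers collapse into one because every anchor, input, weight, or bias alike, is now addressed the same way: through the NodeArgsIdx imported above (defined in patterns.py below) rather than a bare int or an ad-hoc (int, int) tuple. A minimal standalone sketch of the resolution rule applied by annotate_inputs (the helper and the sample values are illustrative, not part of the diff):

from executorch.backends.nxp.quantizer.patterns import NodeArgsIdx


def resolve_arg(args: tuple, args_idx: NodeArgsIdx):
    # Flat argument: args[idx]; nested argument: args[idx][inner_idx].
    if args_idx.inner_idx is None:
        return args[args_idx.idx]
    return args[args_idx.idx][args_idx.inner_idx]


# conv2d(input, weight, bias): the weight is args[1].
assert resolve_arg(("x", "w", "b"), NodeArgsIdx(1)) == "w"
# cat([t0, t1], dim): the second list element is args[0][1].
assert resolve_arg((["t0", "t1"], 0), NodeArgsIdx(0, 1)) == "t1"

Making the nested case explicit is what lets annotate_inputs subsume annotate_weights_or_biases, whose body only handled the flat case.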
diff --git a/backends/nxp/quantizer/patterns.py b/backends/nxp/quantizer/patterns.py
index cf79b539060..c45a67d7809 100644
--- a/backends/nxp/quantizer/patterns.py
+++ b/backends/nxp/quantizer/patterns.py
@@ -7,7 +7,6 @@
 
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
-from typing import List, Optional, Tuple, Type, Union
 
 import torch
 
@@ -22,11 +21,27 @@
 from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY
 
 
+@dataclass
+class NodeArgsIdx:
+    """
+    Specifies an index (and optional inner index) into a Node's args, for use in node input annotation.
+
+
+    Attributes:
+        idx (int): Index into the Node's args list. Selects an input Node or a list of Nodes at the index.
+        inner_idx (int, optional): If specified, index into the list selected by 'idx'. Selects an input Node at the index.
+            Default: None.
+    """
+
+    idx: int
+    inner_idx: int | None = None
+
+
 @dataclass
 class PartitionAnchors:
     """
-    All fields except output are lists of (node, args_index) pair, where node is from
-    the given partition and node.args[args_index] is an input to the partition. Assumes
+    All fields except output are lists of (node, node_args_idx) or (node, node_args_idx, quantization_spec) tuples,
+    where node is from the given partition and node_args_idx locates an input to the partition within node.args. Assumes
     a single output.
 
     Quantizer uses inputs, weights and biases for quantization annotation. The others
@@ -35,25 +50,21 @@ class PartitionAnchors:
     """
 
     # Inputs can share quantization parameters
-    inputs: List[
-        Union[
-            Tuple[fx.Node, Union[int, Tuple[int, int]]],
-            Tuple[
-                fx.Node,
-                Union[int, Tuple[int, int]],
-                SharedQuantizationSpec,
-            ],
-        ]
+    inputs: list[
+        tuple[fx.Node, NodeArgsIdx]
+        | tuple[fx.Node, NodeArgsIdx, SharedQuantizationSpec],
     ] = field(default_factory=list)
-    weights: List[Tuple[fx.Node, int]] = field(default_factory=list)
-    biases: List[
-        Union[Tuple[fx.Node, int], Tuple[fx.Node, int, DerivedQuantizationSpec]]
+    weights: list[tuple[fx.Node, NodeArgsIdx]] = field(default_factory=list)
+    biases: list[
+        tuple[fx.Node, NodeArgsIdx]
+        | tuple[fx.Node, NodeArgsIdx, DerivedQuantizationSpec],
+    ] = field(default_factory=list)
+    others: list[tuple[fx.Node, NodeArgsIdx]] = field(default_factory=list)
+    literals: list[tuple[fx.Node, NodeArgsIdx]] = field(default_factory=list)
+    output: list[
+        tuple[fx.Node]
+        | tuple[fx.Node, FixedQParamsQuantizationSpec | SharedQuantizationSpec],
     ] = field(default_factory=list)
-    others: List[Tuple[fx.Node, int]] = field(default_factory=list)
-    literals: List[Tuple[fx.Node, int]] = field(default_factory=list)
-    output: List[Union[Tuple[fx.Node], Tuple[fx.Node, SharedQuantizationSpec]]] = field(
-        default_factory=list
-    )
 
     empty: bool = False
 
@@ -67,8 +78,8 @@
 
     @abstractmethod
     def get_anchors(
-        self, gm: torch.fx.GraphModule, fused_partition: List[fx.GraphModule]
-    ) -> Optional[PartitionAnchors]:
+        self, gm: torch.fx.GraphModule, fused_partition: list[fx.GraphModule]
+    ) -> PartitionAnchors | None:
         pass
 
 
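Every PartitionAnchors field now carries a NodeArgsIdx, but inner_idx only comes into play for operators whose argument is itself a list of nodes. A hypothetical helper for one such operator, aten.cat (illustrative only; no cat pattern is added in this diff), would anchor each list element separately:

from torch import fx

from executorch.backends.nxp.quantizer.patterns import NodeArgsIdx, PartitionAnchors


def cat_anchors(cat_node: fx.Node) -> PartitionAnchors:
    # cat's args[0] is a list of input nodes, so element i is addressed
    # as NodeArgsIdx(0, i) instead of a plain index into args.
    return PartitionAnchors(
        inputs=[(cat_node, NodeArgsIdx(0, i)) for i in range(len(cat_node.args[0]))],
        weights=[],
        biases=[],
        output=[(cat_node,)],
    )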
""" - def partition_types(self) -> List[Type[torch.nn.Module]]: + def partition_types(self) -> list[torch.nn.Module]: pass def get_anchors( - self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule] + self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule] ) -> PartitionAnchors | None: node = fused_partition[0].nodes[-1] assert len(fused_partition[0].input_nodes) == 1 @@ -97,7 +108,7 @@ def get_anchors( qspec = SharedQuantizationSpec(prev_node) return PartitionAnchors( - inputs=[(node, 0)], + inputs=[(node, NodeArgsIdx(0))], weights=[], biases=[], output=[ @@ -125,11 +136,11 @@ def partition_types(self): class AddmmPattern(QuantizationPattern): - def partition_types(self) -> List[OpOverload]: + def partition_types(self) -> list[OpOverload]: return [torch.ops.aten.addmm.default] def get_anchors( - self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule] + self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule] ) -> PartitionAnchors: # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge... addmm_node = fused_partition[0].nodes[-1] @@ -147,9 +158,9 @@ def get_anchors( ) return PartitionAnchors( - inputs=[(addmm_node, 1)], - weights=[(addmm_node, 2)], - biases=[(addmm_node, 0, bias_qspec)], + inputs=[(addmm_node, NodeArgsIdx(1))], + weights=[(addmm_node, NodeArgsIdx(2))], + biases=[(addmm_node, NodeArgsIdx(0), bias_qspec)], output=[(addmm_node,)], ) @@ -161,16 +172,16 @@ class AddTensorPattern(QuantizationPattern): Basic quantization for all inputs and output. """ - def partition_types(self) -> List[Type[torch.nn.Module]]: + def partition_types(self) -> list[torch.nn.Module]: return [torch.ops.aten.add.Tensor] def get_anchors( - self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule] + self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule] ) -> PartitionAnchors | None: node = fused_partition[0].nodes[-1] - inputs = [(node, 0)] + inputs = [(node, NodeArgsIdx(0))] if len(fused_partition[0].input_nodes) == 2: - inputs = [(node, 0), (node, 1)] + inputs = [(node, NodeArgsIdx(0)), (node, NodeArgsIdx(1))] return PartitionAnchors( inputs=inputs, @@ -190,11 +201,11 @@ def partition_types(self): class Conv1dPattern(QuantizationPattern): - def partition_types(self) -> List[OpOverload]: + def partition_types(self) -> list[OpOverload]: return [torch.ops.aten.conv1d.default] def get_anchors( - self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule] + self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule] ) -> PartitionAnchors: # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge... 
         conv1d_node = fused_partition[0].nodes[-1]
@@ -214,11 +225,11 @@ def get_anchors(
         # Keep bias empty if not supplied
         bias = []
         if len(conv1d_node.args) > 2 and conv1d_node.args[2] is not None:
-            bias = [(conv1d_node, 2, bias_qspec)]
+            bias = [(conv1d_node, NodeArgsIdx(2), bias_qspec)]
 
         return PartitionAnchors(
-            inputs=[(conv1d_node, 0)],
-            weights=[(conv1d_node, 1)],
+            inputs=[(conv1d_node, NodeArgsIdx(0))],
+            weights=[(conv1d_node, NodeArgsIdx(1))],
             # pyre-fixme[6]: Incompatible parameter type
             biases=bias,
             output=[(conv1d_node,)],
@@ -226,11 +237,11 @@
 
 
 class Conv2dPattern(QuantizationPattern):
-    def partition_types(self) -> List[OpOverload]:
+    def partition_types(self) -> list[OpOverload]:
         return [torch.ops.aten.conv2d.default]
 
     def get_anchors(
-        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
+        self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
     ) -> PartitionAnchors:
         # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
         conv2d_node = fused_partition[0].nodes[-1]
@@ -250,11 +261,11 @@
         # Keep bias empty if not supplied
         bias = []
         if len(conv2d_node.args) > 2 and conv2d_node.args[2] is not None:
-            bias = [(conv2d_node, 2, bias_qspec)]
+            bias = [(conv2d_node, NodeArgsIdx(2), bias_qspec)]
 
         return PartitionAnchors(
-            inputs=[(conv2d_node, 0)],
-            weights=[(conv2d_node, 1)],
+            inputs=[(conv2d_node, NodeArgsIdx(0))],
+            weights=[(conv2d_node, NodeArgsIdx(1))],
             # pyre-fixme[6]: Incompatible parameter type
             biases=bias,
             output=[(conv2d_node,)],
@@ -289,12 +300,12 @@ def partition_types(self):
         return [torch.ops.aten.hardtanh.default]
 
     def get_anchors(
-        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
+        self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
     ) -> PartitionAnchors | None:
         node = fused_partition[0].nodes[-1]
 
         return PartitionAnchors(
-            inputs=[(node, 0)],
+            inputs=[(node, NodeArgsIdx(0))],
             weights=[],
             biases=[],
             output=[(node,)],
@@ -314,12 +325,12 @@ def partition_types(self):
         return [torch.ops.aten.hardtanh_.default]
 
     def get_anchors(
-        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
+        self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
     ) -> PartitionAnchors | None:
         node = fused_partition[0].nodes[-1]
 
         return PartitionAnchors(
-            inputs=[(node, 0)],
+            inputs=[(node, NodeArgsIdx(0))],
             weights=[],
             biases=[],
             output=[(node,)],
@@ -330,11 +341,11 @@ def replacement_op(self):
 
 
 class LinearPattern(QuantizationPattern):
-    def partition_types(self) -> List[OpOverload]:
+    def partition_types(self) -> list[OpOverload]:
         return [torch.ops.aten.linear.default]
 
     def get_anchors(
-        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
+        self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
     ) -> PartitionAnchors:
         # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
         linear_node = fused_partition[0].nodes[-1]
@@ -354,11 +365,11 @@ def get_anchors(
         # Keep bias empty if not supplied
         bias = []
         if len(linear_node.args) > 2:
-            bias = [(linear_node, 2, bias_qspec)]
+            bias = [(linear_node, NodeArgsIdx(2), bias_qspec)]
 
         return PartitionAnchors(
-            inputs=[(linear_node, 0)],
-            weights=[(linear_node, 1)],
+            inputs=[(linear_node, NodeArgsIdx(0))],
+            weights=[(linear_node, NodeArgsIdx(1))],
             # pyre-fixme[6]: Incompatible parameter type
             biases=bias,
             output=[(linear_node,)],
@@ -439,7 +450,7 @@ def partition_types(self):
 
 
 def get_anchors_for_softmax_like_operators(
-    fused_partition: List[fx.GraphModule],
+    fused_partition: list[fx.GraphModule],
 ) -> PartitionAnchors:
     node = fused_partition[0].nodes[-1]
     assert len(fused_partition[0].input_nodes) == 1
@@ -454,7 +465,7 @@ def get_anchors_for_softmax_like_operators(
     )
 
     return PartitionAnchors(
-        inputs=[(node, 0)],
+        inputs=[(node, NodeArgsIdx(0))],
         weights=[],
         biases=[],
         output=[
@@ -470,11 +481,11 @@ class SoftMaxPattern(QuantizationPattern):
     The quantization of Softmax output is fixed to scale 1/256, zero point -128, dtype int8.
     """
 
-    def partition_types(self) -> List[OpOverload]:
+    def partition_types(self) -> list[OpOverload]:
         return [torch.ops.aten.softmax.int]
 
     def get_anchors(
-        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
+        self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
     ) -> PartitionAnchors:
         return get_anchors_for_softmax_like_operators(fused_partition)
 
@@ -486,10 +497,10 @@ class SigmoidPattern(QuantizationPattern):
     The quantization of Sigmoid output is fixed to scale 1/256, zero point -128, dtype int8.
     """
 
-    def partition_types(self) -> List[OpOverload]:
+    def partition_types(self) -> list[OpOverload]:
         return [torch.ops.aten.sigmoid.default]
 
     def get_anchors(
-        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
+        self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
     ) -> PartitionAnchors:
         return get_anchors_for_softmax_like_operators(fused_partition)
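The fixed output quantization mentioned in the two docstrings above follows from the operators' output range: softmax and sigmoid produce values in [0, 1], and with scale 1/256 and zero point -128 the int8 grid covers exactly [0, 255/256]. A sketch of the corresponding spec (assuming torchao's FixedQParamsQuantizationSpec, the type already referenced in the PartitionAnchors.output annotation; construction details may differ in the actual code):

import torch
from torchao.quantization.pt2e.quantizer.quantizer import FixedQParamsQuantizationSpec

qspec = FixedQParamsQuantizationSpec(
    dtype=torch.int8,
    scale=1.0 / 256.0,
    zero_point=-128,
    quant_min=-128,
    quant_max=127,
    qscheme=torch.per_tensor_affine,
)
# Dequantized edges of the int8 grid: (q - zero_point) * scale.
assert (qspec.quant_min - qspec.zero_point) * qspec.scale == 0.0
assert (qspec.quant_max - qspec.zero_point) * qspec.scale == 255 / 256

Because the output scale and zero point are constants, no observer statistics are needed for these operators' outputs.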