Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 14 additions & 34 deletions backends/nxp/quantizer/neutron_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import List, Optional, Tuple, Union

import torch
from executorch.backends.nxp.aten_passes.neutron_aten_pass_manager import (
NeutronAtenPassManager,
Expand All @@ -25,6 +23,7 @@
LinearPattern,
MaxPoolPattern,
MeanDimPattern,
NodeArgsIdx,
PadPattern,
PermutePattern,
QuantizationPattern,
Expand Down Expand Up @@ -102,57 +101,43 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
)

def annotate_inputs(
inputs: Union[
List[Tuple[fx.Node, int]],
List[Tuple[fx.Node, int, DerivedQuantizationSpec],],
],
spec: Optional[QuantizationSpec],
inputs: (
list[tuple[fx.Node, NodeArgsIdx]]
| list[tuple[fx.Node, NodeArgsIdx, DerivedQuantizationSpec]]
),
spec: QuantizationSpec | None,
) -> None:
for node, idx, *custom_spec in inputs:
for node, args_idx, *custom_spec in inputs:
# pyre-ignore[16]: no attribute
annotation = node.meta.get(
Q_ANNOTATION_KEY,
QuantizationAnnotation(_annotated=True),
)
arg = (
# pyre-ignore[16]: no attribute
node.args[idx]
if isinstance(idx, int)
node.args[args_idx.idx]
if args_idx.inner_idx is None
# pyre-ignore[16]: no attribute
else node.args[idx[0]][idx[1]]
else node.args[args_idx.idx][args_idx.inner_idx]
)
annotation.input_qspec_map[arg] = (
custom_spec[0] if custom_spec else spec
)
# pyre-ignore[16]: no attribute
node.meta[Q_ANNOTATION_KEY] = annotation

def annotate_weights_or_biases(
weights_or_biases: List[Tuple[fx.Node, int]],
spec: Optional[QuantizationSpec],
) -> None:
for node, idx, *custom_spec in weights_or_biases:
annotation = node.meta.get(
Q_ANNOTATION_KEY,
QuantizationAnnotation(_annotated=True),
)
annotation.input_qspec_map[node.args[idx]] = (
custom_spec[0] if custom_spec else spec
)
node.meta[Q_ANNOTATION_KEY] = annotation

# pyre-ignore[6]: incompatible parameter type
annotate_inputs(anchors.inputs, input_act_qspec)
annotate_weights_or_biases(anchors.weights, weight_qspec)
annotate_inputs(anchors.weights, weight_qspec)
# pyre-ignore[6]: incompatible parameter type
annotate_weights_or_biases(anchors.biases, bias_qspec)
annotate_inputs(anchors.biases, bias_qspec)
return model

def validate(self, model: fx.GraphModule) -> None:
pass

@classmethod
def get_supported_operators(cls) -> List[OperatorConfig]:
def get_supported_operators(cls) -> list[OperatorConfig]:
return []


Expand Down Expand Up @@ -191,12 +176,7 @@ def get_supported_operators(cls) -> List[OperatorConfig]:

class NeutronQuantizer(ComposableQuantizer):
def __init__(self):
static_qconfig = QuantizationConfig(
act_qspec,
act_qspec,
wgt_qspec,
None,
)
static_qconfig = QuantizationConfig(act_qspec, act_qspec, wgt_qspec, None)
static_fc_qconfig = QuantizationConfig(act_qspec, act_qspec, wgt_fc_qspec, None)
super().__init__(
[
Expand Down
Loading
Loading