Skip to content

Commit

Permalink
Infer symmetry from offset in sim.onnx.export (#3833)
Browse files: browse the repository at this point in the history
Signed-off-by: Kyunggeun Lee <[email protected]>
  • Loading branch information
quic-kyunggeu authored Feb 25, 2025
1 parent 934265c commit 8fb89a3
Showing 1 changed file with 14 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,9 @@
"""Utility APIs for onnx export"""

from contextlib import contextmanager, ExitStack
import functools
from collections import defaultdict
import functools
import math
from typing import Sequence, Iterable

import onnx
Expand Down Expand Up @@ -287,6 +288,7 @@ def _get_encoding_from_onnx_node(onnx_model: onnx.ModelProto, quant_node: onnx.N
from aimet_torch.v2.quantization.affine.encoding import AffineEncoding
assert quant_node.op_type in ONNX_QUANTIZER_OP_TYPES

scale, offset = None, None
qmin, qmax, block_size = None, None, None
scale_name, offset_name = quant_node.input[1], quant_node.input[2]

Expand All @@ -303,7 +305,17 @@ def _get_encoding_from_onnx_node(onnx_model: onnx.ModelProto, quant_node: onnx.N
scale = torch.tensor(_get_tensor_from_constant_name(onnx_model, scale_name))
offset = torch.tensor(_get_tensor_from_constant_name(onnx_model, offset_name))

return AffineEncoding(scale, offset, qmin, qmax, block_size=block_size)
assert scale is not None
assert offset is not None

if scale.numel() == 1 and offset.numel() == 1:
scale = scale.squeeze()
offset = offset.squeeze()

centroid = math.ceil((qmin + qmax) / 2)
symmetry = bool(torch.all(offset == -centroid))

return AffineEncoding(scale, offset, qmin, qmax, symmetry=symmetry, block_size=block_size)


def _remove_constants(onnx_model: onnx.ModelProto, constant_names: Iterable[str]):
Expand Down

0 comments on commit 8fb89a3

Please sign in to comment.