Commit e6ff711

implement onnx conversion for aten::fake_quantize_per_tensor_affine
1 parent 1707e14 commit e6ff711

File tree: 1 file changed, +54 additions, −4 deletions

onnxscript/function_libs/torch_lib/ops/core.py
@@ -3392,12 +3392,62 @@ def aten_fake_quantize_per_channel_affine_cachemask_backward(
     raise NotImplementedError()
 
 
+@torch_op("aten::fake_quantize_per_tensor_affine", trace_only=True)
 def aten_fake_quantize_per_tensor_affine(
-    self: TensorType, scale: float, zero_point: int, quant_min: int, quant_max: int
-) -> TensorType:
-    """fake_quantize_per_tensor_affine(Tensor self, float scale, int zero_point, int quant_min, int quant_max) -> Tensor"""
+    self: TFloat,
+    scale: TReal | float,
+    zero_point: TReal | int,
+    quant_min: int,
+    quant_max: int,
+) -> TFloat:
+    """fake_quantize_per_tensor_affine(Tensor self, Tensor | float scale, Tensor | int zero_point, int quant_min, int quant_max) -> Tensor"""
 
-    raise NotImplementedError()
+    # NOTE: (0, 127) is allowed as a special case: PyTorch restricts activations to the range (0, 127).
+    # https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422
+    if (quant_min, quant_max) not in [(0, 255), (-128, 127), (0, 127)]:
+        raise NotImplementedError(
+            "For (quant_min, quant_max), ONNX allows only "
+            "(0, 127), (0, 255) and (-128, 127). "
+            f"Got ({quant_min}, {quant_max})",
+        )
+
+    if quant_min == 0:
+        int_dtype = ir.DataType.UINT8
+    else:
+        int_dtype = ir.DataType.INT8
+
+    # TODO: When opset >= 19, remove this cast
+    orig_dtype = self.type.dtype
+    if self.type.dtype not in {ir.DataType.FLOAT, ir.DataType.INT32}:
+        self = op.Cast(self, to=ir.DataType.FLOAT)
+
+    # TODO: When opset >= 19, relax the condition for this cast
+    if isinstance(scale, float) or scale.type.dtype != ir.DataType.FLOAT:
+        scale = op.Cast(scale, to=ir.DataType.FLOAT)
+
+    if isinstance(zero_point, int) or zero_point.type.dtype != int_dtype:
+        zero_point = op.Cast(zero_point, to=int_dtype)
+
+    quantized = op.QuantizeLinear(self, scale, zero_point)
+
+    # See the NOTE above: PyTorch-specific (0, 127) handling
+    if (quant_min, quant_max) == (0, 127):
+        const_127 = op.Cast(127, to=int_dtype)
+        quantized = op.Clip(quantized, max=const_127)
+
+    output = op.DequantizeLinear(quantized, scale, zero_point)
+
+    # TODO: When opset >= 23, remove this cast and set output_dtype on DequantizeLinear
+    if orig_dtype != ir.DataType.FLOAT:
+        output = op.Cast(output, to=orig_dtype)
+
+    return output
 
 
 def aten_fake_quantize_per_tensor_affine_cachemask(
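
For reference, a minimal numeric sketch (not part of the commit) of what the QuantizeLinear → Clip → DequantizeLinear sequence above computes, checked against PyTorch's eager kernel. The input values, scale, and the (0, 127) range here are illustrative assumptions, not taken from the diff.

import torch

x = torch.tensor([-1.5, -0.3, 0.0, 0.4, 2.0])
scale, zero_point = 0.1, 0
quant_min, quant_max = 0, 127  # the PyTorch-specific special case handled above

# Quantize, clamp to [quant_min, quant_max] (the Clip in the diff), dequantize.
q = torch.clamp(torch.round(x / scale) + zero_point, quant_min, quant_max)
fake_quantized = (q - zero_point) * scale

expected = torch.fake_quantize_per_tensor_affine(
    x, scale, zero_point, quant_min, quant_max
)
assert torch.allclose(fake_quantized, expected)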

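A hedged end-to-end sketch of how the new function would be exercised: the dynamo-based ONNX exporter dispatches aten::fake_quantize_per_tensor_affine to torch_lib functions such as the one added here. The module and shapes are made up for illustration.

import torch

class FakeQuant(torch.nn.Module):
    def forward(self, x):
        return torch.fake_quantize_per_tensor_affine(x, 0.1, 0, -128, 127)

# dynamo=True selects the exporter backed by onnxscript's torch_lib.
onnx_program = torch.onnx.export(FakeQuant(), (torch.randn(4),), dynamo=True)
# The exported graph should contain QuantizeLinear/DequantizeLinear nodes.
print(onnx_program.model)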