123 changes: 116 additions & 7 deletions onnxscript/function_libs/torch_lib/ops/core.py
@@ -23,6 +23,7 @@
COMPLEX128,
DOUBLE,
FLOAT,
FLOAT16,
INT8,
INT16,
INT32,
@@ -3317,17 +3318,58 @@ def aten_eye(n: int) -> TensorType:
raise NotImplementedError()


@torch_op("aten::fake_quantize_per_channel_affine", trace_only=True)
def aten_fake_quantize_per_channel_affine(
self: TensorType,
scale: TensorType,
zero_point: TensorType,
self: TFloat,
scale: FLOAT, # float32 specifically!
zero_point: Union[INT32, FLOAT, FLOAT16], # int32, float32 or float16 only!
axis: int,
quant_min: int,
quant_max: int,
) -> TensorType:
"""fake_quantize_per_channel_affine(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max) -> Tensor"""

raise NotImplementedError()
# NOTE: (0, 127) is allowed as a special case. PyTorch restricts activations to the range (0, 127).
# https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422
if (quant_min, quant_max) not in [(0, 255), (-128, 127), (0, 127)]:
raise NotImplementedError(
"For (quant_min, quant_max), ONNX allows only "
"(0, 127), (0, 255) and (-128, 127). "
f"Got ({quant_min}, {quant_max})",
)

if quant_min == 0:
int_dtype = ir.DataType.UINT8
else:
int_dtype = ir.DataType.INT8

# TODO: When opset >= 19, remove this cast
orig_dtype = self.type.dtype
if self.type.dtype not in {ir.DataType.FLOAT, ir.DataType.INT32}:
self = op.Cast(self, to=ir.DataType.FLOAT)

if zero_point.type.dtype == ir.DataType.INT32:
zero_point = op.Cast(zero_point, to=int_dtype)
else:
raise NotImplementedError(
"ONNX only supports integer values for the zero_point parameter. "
f"Got {zero_point.type.dtype}",
)

quantized = op.QuantizeLinear(self, scale, zero_point, axis=axis)

# See the NOTE above on the PyTorch-specific (0, 127) handling
if (quant_min, quant_max) == (0, 127):
const_127 = op.Cast(127, to=int_dtype)
quantized = op.Clip(quantized, max=const_127)

output = op.DequantizeLinear(quantized, scale, zero_point, axis=axis)

# TODO: When opset >= 23, remove this cast and set output_dtype on DequantizeLinear
if orig_dtype != ir.DataType.FLOAT:
output = op.Cast(output, to=orig_dtype)

return output
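
For intuition, the QuantizeLinear -> Clip -> DequantizeLinear sequence above reproduces the standard fake-quantization arithmetic. A minimal NumPy sketch of the per-channel case, for reference only (the function name and broadcasting details are illustrative, not part of this change):

import numpy as np

def fake_quant_per_channel_reference(x, scale, zero_point, axis, quant_min, quant_max):
    # Reshape scale/zero_point so they broadcast along the channel axis.
    shape = [1] * x.ndim
    shape[axis] = -1
    scale = scale.reshape(shape)
    zero_point = zero_point.reshape(shape)
    # Quantize: divide by scale, round half to even, add zero_point, saturate.
    q = np.clip(np.rint(x / scale) + zero_point, quant_min, quant_max)
    # Dequantize back to float.
    return (q - zero_point) * scale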


def aten_fake_quantize_per_channel_affine_cachemask(
@@ -3351,12 +3393,79 @@ def aten_fake_quantize_per_channel_affine_cachemask_backward(
raise NotImplementedError()


@torch_op("aten::fake_quantize_per_tensor_affine", trace_only=True)
def aten_fake_quantize_per_tensor_affine(
self: TensorType, scale: float, zero_point: int, quant_min: int, quant_max: int
) -> TensorType:
self: TFloat,
scale: float,
zero_point: int,
quant_min: int,
quant_max: int,
) -> TFloat:
"""fake_quantize_per_tensor_affine(Tensor self, float scale, int zero_point, int quant_min, int quant_max) -> Tensor"""

raise NotImplementedError()
return _aten_fake_quantize_per_tensor_affine(self, scale, zero_point, quant_min, quant_max)


@torch_op("aten::fake_quantize_per_tensor_affine.tensor_qparams", trace_only=True)
def aten_fake_quantize_per_tensor_affine_tensor_qparams(
self: TFloat,
scale: TReal,
zero_point: TReal,
quant_min: int,
quant_max: int,
) -> TFloat:
"""fake_quantize_per_tensor_affine(Tensor self, Tensor scale, Tensor zero_point, int quant_min, int quant_max) -> Tensor"""

return _aten_fake_quantize_per_tensor_affine(self, scale, zero_point, quant_min, quant_max)


def _aten_fake_quantize_per_tensor_affine(
self: TFloat,
scale: Union[float, TReal],
zero_point: Union[int, TReal],
quant_min: int,
quant_max: int,
) -> TFloat:
# NOTE: (0, 127) is allowed as a special case. PyTorch restricts activations to the range (0, 127).
# https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422
if (quant_min, quant_max) not in [(0, 255), (-128, 127), (0, 127)]:
raise NotImplementedError(
"For (quant_min, quant_max), ONNX allows only "
"(0, 127), (0, 255) and (-128, 127). "
f"Got ({quant_min}, {quant_max})",
)

if quant_min == 0:
int_dtype = ir.DataType.UINT8
else:
int_dtype = ir.DataType.INT8

# TODO: When opset >= 19, remove this cast
orig_dtype = self.type.dtype
if self.type.dtype not in {ir.DataType.FLOAT, ir.DataType.INT32}:
self = op.Cast(self, to=ir.DataType.FLOAT)

# TODO: When opset >= 19, relax the condition for this cast
if isinstance(scale, float) or scale.type.dtype != ir.DataType.FLOAT:
scale = op.Cast(scale, to=ir.DataType.FLOAT)

if isinstance(zero_point, int) or zero_point.type.dtype != int_dtype:
zero_point = op.Cast(zero_point, to=int_dtype)

quantized = op.QuantizeLinear(self, scale, zero_point)

# See the NOTE above on the PyTorch-specific (0, 127) handling
if (quant_min, quant_max) == (0, 127):
const_127 = op.Cast(127, to=int_dtype)
quantized = op.Clip(quantized, max=const_127)

output = op.DequantizeLinear(quantized, scale, zero_point)

# TODO: When opset >= 23, remove this cast and set output_dtype on DequantizeLinear
if orig_dtype != ir.DataType.FLOAT:
output = op.Cast(output, to=orig_dtype)

return output
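
The per-tensor path can be sanity-checked against PyTorch eager mode. A small example, with scale, zero_point, and the quant range chosen arbitrarily for illustration:

import torch

x = torch.randn(4, 8)
out = torch.fake_quantize_per_tensor_affine(x, scale=0.25, zero_point=10, quant_min=0, quant_max=255)
# Same arithmetic as the QuantizeLinear -> DequantizeLinear pair above.
q = torch.clamp(torch.round(x / 0.25) + 10, 0, 255)
torch.testing.assert_close(out, (q - 10) * 0.25)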


def aten_fake_quantize_per_tensor_affine_cachemask(
119 changes: 119 additions & 0 deletions tests/function_libs/torch_lib/extra_opinfo.py
@@ -779,6 +779,109 @@ def sample_inputs__fft_c2r(self, device, dtype, requires_grad=False, **_):
)


def sample_inputs_fake_quantize_per_tensor_affine(
op_info, device, dtype, requires_grad, **kwargs
):
del op_info, kwargs # Unused
make_arg = functools.partial(
opinfo_core.make_tensor,
device=device,
requires_grad=requires_grad,
)

# Test 1D, empty and scalar tensors (like sample_inputs_elementwise_unary)
shapes = [
(S,),
(1, 0, 3),
(),
]

scale_zero_point_dtypes = [
# default (float, int)
(None, None)
] + [
# tensor_qparams (tensor, tensor)
(t1, t2)
for t1 in common_dtype.all_types_and()
for t2 in common_dtype.all_types_and()
]

# NOTE: (0, 127) is allowed as a special case. PyTorch restricts activations to the range (0, 127).
# https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422
quant_vals = [(0, 255), (-128, 127), (0, 127)]

cases = itertools.product(shapes, scale_zero_point_dtypes, quant_vals)
for shape, (scale_dtype, zero_point_dtype), (quant_min, quant_max) in cases:
scale = make_arg(
(),
dtype=scale_dtype or torch.float64,
)
if scale_dtype is None:
scale = scale.item()

zero_point = make_arg(
(),
dtype=zero_point_dtype or torch.int64,
# zero_point must be between quant_min and quant_max
low=quant_min,
high=quant_max,
)
if zero_point_dtype is None:
zero_point = zero_point.item()

args = (scale, zero_point, quant_min, quant_max)
yield opinfo_core.SampleInput(make_arg(shape, dtype=dtype), args=args)


def sample_inputs_fake_quantize_per_channel_affine(
op_info, device, dtype, requires_grad, **kwargs
):
del op_info, kwargs # Unused
make_arg = functools.partial(
opinfo_core.make_tensor,
device=device,
requires_grad=requires_grad,
)

# Test 1D, 2D, 4D and empty tensors (scalar tensors not supported)
axes_and_shapes = [
# 1D, 2D, 4D
(axis, (S,) * dims)
for dims in (1, 2, 4)
for axis in range(dims)
] + [
# empty
(0, (1, 0, 3)),
(2, (1, 0, 3)),
# empty channel axis causes an error due to
# an internal zero_point.min() calculation
# (1, (1, 0, 3)),
]

# tensor_qparams
scale_dtype = torch.float
zero_point_dtypes = [torch.int32, torch.float, torch.half]

# NOTE: (0, 127) is allowed as a special case. PyTorch restricts activations to the range (0, 127).
# https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422
quant_vals = [(0, 255), (-128, 127), (0, 127)]

cases = itertools.product(axes_and_shapes, zero_point_dtypes, quant_vals)
for (axis, shape), zero_point_dtype, (quant_min, quant_max) in cases:
scale = make_arg((shape[axis],), dtype=scale_dtype)

zero_point = make_arg(
(shape[axis],),
dtype=zero_point_dtype or torch.int64,
# zero_point must be between quant_min and quant_max
low=quant_min,
high=quant_max,
)

args = (scale, zero_point, axis, quant_min, quant_max)
yield opinfo_core.SampleInput(make_arg(shape, dtype=dtype), args=args)
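
Both generators follow the torch OpInfo convention of yielding SampleInput objects. A sketch of how a harness might consume the per-channel generator (op_info and kwargs are unused here, so None suffices; names assumed from the surrounding test code):

import torch

for sample in sample_inputs_fake_quantize_per_channel_affine(None, "cpu", torch.float32, False):
    scale, zero_point, axis, quant_min, quant_max = sample.args
    expected = torch.fake_quantize_per_channel_affine(
        sample.input, scale, zero_point, axis, quant_min, quant_max
    )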


def _index_variable_bool(shape, max_indices, device):
if not isinstance(shape, tuple):
shape = (shape,)
@@ -2408,6 +2511,22 @@ def __init__(self):
sample_inputs_func=sample_inputs__fft_r2c,
supports_out=False,
),
opinfo_core.OpInfo(
"ops.aten.fake_quantize_per_tensor_affine",
aten_name="fake_quantize_per_tensor_affine",
op=torch.fake_quantize_per_tensor_affine,
dtypes=common_dtype.floating_types_and(torch.half, torch.bfloat16),
sample_inputs_func=sample_inputs_fake_quantize_per_tensor_affine,
supports_out=False,
),
opinfo_core.OpInfo(
"ops.aten.fake_quantize_per_channel_affine",
aten_name="fake_quantize_per_channel_affine",
op=torch.fake_quantize_per_channel_affine,
dtypes=common_dtype.floating_types_and(torch.half, torch.bfloat16),
sample_inputs_func=sample_inputs_fake_quantize_per_channel_affine,
supports_out=False,
),
opinfo_core.BinaryUfuncInfo(
"ops.aten.floor_divide",
aten_name="floor_divide",
11 changes: 11 additions & 0 deletions tests/function_libs/torch_lib/ops_test_data.py
@@ -698,6 +698,17 @@ def _where_input_wrangler(
TorchLibOpInfo("special.erfcx", special_ops.aten_special_erfcx).xfail(
reason="fixme: The implementation is numerically unstable: https://github.com/microsoft/onnxscript/issues/1223"
),
TorchLibOpInfo(
"ops.aten.fake_quantize_per_channel_affine",
core_ops.aten_fake_quantize_per_channel_affine,
).xfail(
reason="fixme: ONNX (De)QuantizeLinear only supports integer zero_point values",
matcher=lambda sample: sample.args[1].dtype != torch.int32,
),
TorchLibOpInfo(
"ops.aten.fake_quantize_per_tensor_affine",
core_ops.aten_fake_quantize_per_tensor_affine,
),
TorchLibOpInfo("fill", core_ops.aten_fill),
TorchLibOpInfo("flip", core_ops.aten_flip).skip(
reason="fixme: size 0 inputs are not handled yet",
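
The xfail on the per-channel entry is applied per sample: the matcher inspects args[1] (the zero_point tensor) and marks every non-int32 variant as an expected failure, since PyTorch eager mode accepts float32/float16 zero points but the decomposition above raises NotImplementedError for them. A hypothetical sample the matcher would flag:

import torch

x = torch.randn(2, 3, 4)
scale = torch.full((3,), 0.25)  # float32, as the decomposition requires
zero_point = torch.zeros(3)  # float32 zero_point: fine for PyTorch eager...
eager = torch.fake_quantize_per_channel_affine(x, scale, zero_point, 1, 0, 255)
assert zero_point.dtype != torch.int32  # ...but marked as an expected failure for the ONNX path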