diff --git a/TrainingExtensions/onnx/src/python/aimet_onnx/qc_quantize_op.py b/TrainingExtensions/onnx/src/python/aimet_onnx/qc_quantize_op.py
index 9c3dd47378c..d4944ed3ef5 100644
--- a/TrainingExtensions/onnx/src/python/aimet_onnx/qc_quantize_op.py
+++ b/TrainingExtensions/onnx/src/python/aimet_onnx/qc_quantize_op.py
@@ -920,7 +920,7 @@ def _export_2_0_0_encodings(self) -> Optional[Dict]:
                 "per_channel_float_scale": per_channel_scale.tolist(),
                 "y_zero_point": None,
                 **encodings,
-                "output_dtype": f"int{decompressed_bw}" if output_dtype.startswith("int") else f"uint{decompressed_bw}"
+                "output_dtype": f"int{compressed_bw}" if output_dtype.startswith("int") else f"uint{compressed_bw}"
             }
 
     def _fill_mismatching_encoding_settings_info(self, encoding_dict: Optional[dict], encoding_mismatch_info: _EncodingMismatchInfo):
diff --git a/TrainingExtensions/onnx/test/python/test_qc_quantize_op.py b/TrainingExtensions/onnx/test/python/test_qc_quantize_op.py
index 1138e61f05c..ade7694f5ae 100644
--- a/TrainingExtensions/onnx/test/python/test_qc_quantize_op.py
+++ b/TrainingExtensions/onnx/test/python/test_qc_quantize_op.py
@@ -1232,7 +1232,6 @@ def _onnx_LPBQ(input_shape, per_block_int_scale, per_channel_float_scale,
     op = OperatorSetIdProto()
     op.version = 21
 
-    assert output_dtype in ("int8", "int16", "uint8", "uint16")
     assert y_zero_point is None
 
     x_int_dtype = TensorProto.INT16 if output_dtype == "int16" else \
@@ -1358,7 +1357,7 @@ def test_lpbq_encoding_schema_2_0_0(input_shape, block_axis, block_size, compres
         assert encoding["y_zero_point"] is None
         assert encoding["axis"] == block_axis
         assert encoding["block_size"] == block_size
-        assert encoding["output_dtype"] == f"int{decompressed_bw}"
+        assert encoding["output_dtype"] == f"int{compressed_bw}"
     """
diff --git a/TrainingExtensions/torch/src/python/aimet_torch/v2/quantization/affine/encoding.py b/TrainingExtensions/torch/src/python/aimet_torch/v2/quantization/affine/encoding.py
index 8a02e6d09ea..d61d4e81415 100644
--- a/TrainingExtensions/torch/src/python/aimet_torch/v2/quantization/affine/encoding.py
+++ b/TrainingExtensions/torch/src/python/aimet_torch/v2/quantization/affine/encoding.py
@@ -513,7 +513,7 @@ def to_qnn_encoding_dict(self, encoding_version=None) -> Union[List, Dict]:
             del encoding_dict["y_scale"]
             del encoding_dict["output_dtype"]
 
-            decompressed_bw = self.decompressed_bw
+            compressed_bw = self.bitwidth
             y_zero_point = encoding_dict.pop("y_zero_point")
 
             if y_zero_point is not None and torch.any(torch.tensor(y_zero_point) != 0):
@@ -526,7 +526,7 @@
                 "per_channel_float_scale": self.per_channel_scale.tolist(),
                 "y_zero_point": None,
                 **encoding_dict,
-                "output_dtype": f"int{decompressed_bw}" if self.signed else f"uint{decompressed_bw}"
+                "output_dtype": f"int{compressed_bw}" if self.signed else f"uint{compressed_bw}"
             }
 
         return encoding_dict
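Note: the ONNX and torch hunks above are two copies of the same fix. In a 2.0.0 LPBQ encoding, `output_dtype` must describe the bitwidth at which the quantized tensor is actually stored (the compressed bitwidth, e.g. 4), not the bitwidth of the decompressed per-block scale grid (e.g. 8). A minimal sketch of the corrected export rule, using a hypothetical helper name that is not part of the AIMET API:

```python
# Hypothetical sketch (not AIMET API): in LPBQ, weight values are stored at the
# compressed bitwidth; only the per-block integer scale uses the decompressed
# bitwidth. The exported dtype must therefore reflect the compressed bitwidth.
def lpbq_output_dtype(compressed_bw: int, signed: bool) -> str:
    return f"int{compressed_bw}" if signed else f"uint{compressed_bw}"

assert lpbq_output_dtype(4, signed=True) == "int4"    # the bug exported the decompressed bitwidth, e.g. "int8"
assert lpbq_output_dtype(8, signed=False) == "uint8"
```

This is also why the `assert output_dtype in ("int8", "int16", "uint8", "uint16")` guards are removed from both test helpers: with the fix, the exported dtype can legitimately be `int4`/`uint4`, which the helpers now map to `TensorProto.INT4`/`UINT4` (supported by `QuantizeLinear` from opset 21).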
diff --git a/TrainingExtensions/torch/test/python/v2/quantization/affine/test_affine_quantizer.py b/TrainingExtensions/torch/test/python/v2/quantization/affine/test_affine_quantizer.py
index 8bd72f3b02e..07a184135db 100644
--- a/TrainingExtensions/torch/test/python/v2/quantization/affine/test_affine_quantizer.py
+++ b/TrainingExtensions/torch/test/python/v2/quantization/affine/test_affine_quantizer.py
@@ -1790,17 +1790,16 @@ def _onnx_LPBQ(input_shape, per_block_int_scale, per_channel_float_scale,
     op = OperatorSetIdProto()
     op.version = 21
 
-    assert output_dtype in ("int8", "int16", "uint8", "uint16")
     assert y_zero_point is None
 
-    onnx_dtype = TensorProto.INT16 if output_dtype == "int16" else \
-                 TensorProto.INT8 if output_dtype == "int8" else \
-                 TensorProto.INT4 if output_dtype == "int4" else \
-                 TensorProto.UINT16 if output_dtype == "uint16" else \
-                 TensorProto.UINT8 if output_dtype == "uint8" else \
-                 TensorProto.UINT4 if output_dtype == "uint4" else \
-                 None
-    assert onnx_dtype is not None
+    x_int_dtype = TensorProto.INT16 if output_dtype == "int16" else \
+                  TensorProto.INT8 if output_dtype == "int8" else \
+                  TensorProto.INT4 if output_dtype == "int4" else \
+                  TensorProto.UINT16 if output_dtype == "uint16" else \
+                  TensorProto.UINT8 if output_dtype == "uint8" else \
+                  TensorProto.UINT4 if output_dtype == "uint4" else \
+                  None
+    assert x_int_dtype is not None
 
     x = helper.make_tensor_value_info(name='x',
                                       elem_type=TensorProto.FLOAT,
                                       shape=input_shape)
@@ -1812,21 +1811,27 @@ def _onnx_LPBQ(input_shape, per_block_int_scale, per_channel_float_scale,
                                                          name='per_channel_float_scale')
 
     y = helper.make_tensor_value_info(name='y',
-                                      elem_type=onnx_dtype,
+                                      elem_type=TensorProto.FLOAT,
                                       shape=input_shape)
 
-    dequantize_node = helper.make_node('Mul',
-                                       inputs=['per_block_int_scale', 'per_channel_float_scale'],
-                                       outputs=['y_scale'])
+    mul_node = helper.make_node('Mul',
+                                inputs=['per_block_int_scale', 'per_channel_float_scale'],
+                                outputs=['y_scale'])
 
     quantize_node = helper.make_node('QuantizeLinear',
                                      inputs=['x', 'y_scale'],
-                                     outputs=['y'],
+                                     outputs=['x_int'],
                                      axis=axis,
                                      block_size=block_size,
-                                     output_dtype=onnx_dtype)
+                                     output_dtype=x_int_dtype)
+
+    dequantize_node = helper.make_node('DequantizeLinear',
+                                       inputs=['x_int', 'y_scale'],
+                                       outputs=['y'],
+                                       axis=axis,
+                                       block_size=block_size)
 
-    onnx_graph = helper.make_graph([dequantize_node, quantize_node],
+    onnx_graph = helper.make_graph([mul_node, quantize_node, dequantize_node],
                                    name='lpbq',
                                    inputs=[x],
                                    outputs=[y],
@@ -1896,7 +1901,7 @@ def test_lpbq_encoding_schema_2_0_0(shape, block_size, block_grouping, axis, com
         assert encoding["y_zero_point"] is None
         assert encoding["axis"] == axis
         assert encoding["block_size"] == next(iter(blk for blk in block_size if blk != 1))
-        assert encoding["output_dtype"] == f"int{decompressed_bw}"
+        assert encoding["output_dtype"] == f"int{compressed_bw}"
     """
@@ -1935,6 +1940,9 @@ def test_lpbq_encoding_schema_2_0_0(shape, block_size, block_grouping, axis, com
     ort_out, = sess.run(None, {'x': x.numpy()})
     ort_out = torch.from_numpy(ort_out.astype("float64"))
 
-    torch_out = qtzr(x.to(torch.float64)).quantize()
+    torch_out = qtzr(x.to(torch.float64))
+    atol = per_block_int_scale * per_channel_float_scale  # Allow off-by-one error
+    atol = atol.amax(dim=[axis for axis, blk in enumerate(block_size) if blk != 1],
+                     keepdim=True)
 
-    assert torch.allclose(ort_out, torch_out, atol=1)
+    assert torch.all((ort_out - torch_out).abs() < atol)
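Two remarks on the reworked torch test. First, `_onnx_LPBQ` now builds the standard QDQ pattern: `Mul` produces the blockwise `y_scale`, `QuantizeLinear` emits the intermediate `x_int`, and the new `DequantizeLinear` returns a float `y`, so the ONNX Runtime output is directly comparable to the quantize-dequantize output of the torch quantizer (hence `.quantize()` is dropped). Second, ORT and torch may round borderline values to adjacent integer levels, so the dequantized outputs can differ by up to one scale step, and the step size varies per block; the new `atol` bounds exactly that. A standalone sketch of the bound, with illustrative shapes and values (assumptions for demonstration, not taken from the test):

```python
import torch

# Illustrative LPBQ scales: 2 output channels, blocked along the second axis.
per_channel_float_scale = torch.tensor([[0.10], [0.05]])   # (2, 1)
per_block_int_scale = torch.tensor([[3., 7., 5., 4.],
                                    [6., 2., 8., 1.]])     # (2, 4)
block_size = (1, 4)  # entries != 1 mark the blocked axes

# Effective LPBQ scale of each block, then the largest step per channel.
atol = per_block_int_scale * per_channel_float_scale       # (2, 4)
atol = atol.amax(dim=[ax for ax, blk in enumerate(block_size) if blk != 1],
                 keepdim=True)                             # (2, 1)

print(atol)  # -> [[0.7000], [0.4000]]
# The test's final check then broadcasts this bound per channel:
# torch.all((ort_out - torch_out).abs() < atol)
```

Compared with the old `torch.allclose(ort_out, torch_out, atol=1)` on raw integer outputs, this keeps the same off-by-one slack but expresses it in dequantized units, so the tolerance stays meaningful per block instead of being a single global constant.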