Redefine LPBQ output_dtype of 2.0.0 encoding schema (#3849)
Signed-off-by: Kyunggeun Lee <[email protected]>
quic-kyunggeu authored Feb 28, 2025
1 parent 4a5ee65 commit 5d0817f
Showing 4 changed files with 31 additions and 24 deletions.
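
This commit redefines the "output_dtype" field of the 2.0.0 LPBQ encoding schema to carry the compressed bitwidth (compressed_bw, read from the quantizer's bitwidth) instead of the decompressed bitwidth. As a hedged illustration, not taken verbatim from the repository, a resulting encoding dictionary might look like this (values are hypothetical; the key set follows the diff below):

# Sketch of a 2.0.0 LPBQ encoding after this change; all values hypothetical.
encoding = {
    "per_block_int_scale": [[12, 9], [15, 7]],
    "per_channel_float_scale": [0.0021, 0.0017],
    "y_zero_point": None,
    "axis": 0,
    "block_size": 64,
    # previously f"int{decompressed_bw}" (e.g. "int8"),
    # now f"int{compressed_bw}" (e.g. "int4")
    "output_dtype": "int4",
}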
@@ -920,7 +920,7 @@ def _export_2_0_0_encodings(self) -> Optional[Dict]:
             "per_channel_float_scale": per_channel_scale.tolist(),
             "y_zero_point": None,
             **encodings,
-            "output_dtype": f"int{decompressed_bw}" if output_dtype.startswith("int") else f"uint{decompressed_bw}"
+            "output_dtype": f"int{compressed_bw}" if output_dtype.startswith("int") else f"uint{compressed_bw}"
         }

     def _fill_mismatching_encoding_settings_info(self, encoding_dict: Optional[dict], encoding_mismatch_info: _EncodingMismatchInfo):
3 changes: 1 addition & 2 deletions TrainingExtensions/onnx/test/python/test_qc_quantize_op.py
@@ -1232,7 +1232,6 @@ def _onnx_LPBQ(input_shape, per_block_int_scale, per_channel_float_scale,
     op = OperatorSetIdProto()
     op.version = 21

-    assert output_dtype in ("int8", "int16", "uint8", "uint16")
     assert y_zero_point is None

     x_int_dtype = TensorProto.INT16 if output_dtype == "int16" else \
@@ -1358,7 +1357,7 @@ def test_lpbq_encoding_schema_2_0_0(input_shape, block_axis, block_size, compres
     assert encoding["y_zero_point"] is None
     assert encoding["axis"] == block_axis
     assert encoding["block_size"] == block_size
-    assert encoding["output_dtype"] == f"int{decompressed_bw}"
+    assert encoding["output_dtype"] == f"int{compressed_bw}"


"""
@@ -513,7 +513,7 @@ def to_qnn_encoding_dict(self, encoding_version=None) -> Union[List, Dict]:
         del encoding_dict["y_scale"]
         del encoding_dict["output_dtype"]

-        decompressed_bw = self.decompressed_bw
+        compressed_bw = self.bitwidth
         y_zero_point = encoding_dict.pop("y_zero_point")

         if y_zero_point is not None and torch.any(torch.tensor(y_zero_point) != 0):
@@ -526,7 +526,7 @@ def to_qnn_encoding_dict(self, encoding_version=None) -> Union[List, Dict]:
                 "per_channel_float_scale": self.per_channel_scale.tolist(),
                 "y_zero_point": None,
                 **encoding_dict,
-                "output_dtype": f"int{decompressed_bw}" if self.signed else f"uint{decompressed_bw}"
+                "output_dtype": f"int{compressed_bw}" if self.signed else f"uint{compressed_bw}"
             }

         return encoding_dict
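
To make the renaming concrete: compressed_bw is read from self.bitwidth, whereas the old code used self.decompressed_bw. A tiny sketch under the assumption of a typical LPBQ setup where 4-bit data is expanded to 8 bits (the attribute semantics are inferred, not stated in this diff):

# Hypothetical quantizer attributes, for illustration only.
compressed_bw = 4       # self.bitwidth: bitwidth the quantized data is stored at
decompressed_bw = 8     # self.decompressed_bw: bitwidth after expansion by per-block int scales
assert f"int{compressed_bw}" == "int4"      # exported as of this commit
assert f"int{decompressed_bw}" == "int8"    # exported before this commit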
@@ -1790,17 +1790,16 @@ def _onnx_LPBQ(input_shape, per_block_int_scale, per_channel_float_scale,
     op = OperatorSetIdProto()
     op.version = 21

-    assert output_dtype in ("int8", "int16", "uint8", "uint16")
     assert y_zero_point is None

-    onnx_dtype = TensorProto.INT16 if output_dtype == "int16" else \
-                 TensorProto.INT8 if output_dtype == "int8" else \
-                 TensorProto.INT4 if output_dtype == "int4" else \
-                 TensorProto.UINT16 if output_dtype == "uint16" else \
-                 TensorProto.UINT8 if output_dtype == "uint8" else \
-                 TensorProto.UINT4 if output_dtype == "uint4" else \
-                 None
-    assert onnx_dtype is not None
+    x_int_dtype = TensorProto.INT16 if output_dtype == "int16" else \
+                  TensorProto.INT8 if output_dtype == "int8" else \
+                  TensorProto.INT4 if output_dtype == "int4" else \
+                  TensorProto.UINT16 if output_dtype == "uint16" else \
+                  TensorProto.UINT8 if output_dtype == "uint8" else \
+                  TensorProto.UINT4 if output_dtype == "uint4" else \
+                  None
+    assert x_int_dtype is not None

     x = helper.make_tensor_value_info(name='x',
                                       elem_type=TensorProto.FLOAT,
@@ -1812,21 +1811,27 @@ def _onnx_LPBQ(input_shape, per_block_int_scale, per_channel_float_scale,
                                                         name='per_channel_float_scale')

     y = helper.make_tensor_value_info(name='y',
-                                      elem_type=onnx_dtype,
+                                      elem_type=TensorProto.FLOAT,
                                       shape=input_shape)

-    dequantize_node = helper.make_node('Mul',
-                                       inputs=['per_block_int_scale', 'per_channel_float_scale'],
-                                       outputs=['y_scale'])
+    mul_node = helper.make_node('Mul',
+                                inputs=['per_block_int_scale', 'per_channel_float_scale'],
+                                outputs=['y_scale'])

     quantize_node = helper.make_node('QuantizeLinear',
                                      inputs=['x', 'y_scale'],
-                                     outputs=['y'],
+                                     outputs=['x_int'],
                                      axis=axis,
                                      block_size=block_size,
-                                     output_dtype=onnx_dtype)
+                                     output_dtype=x_int_dtype)

+    dequantize_node = helper.make_node('DequantizeLinear',
+                                       inputs=['x_int', 'y_scale'],
+                                       outputs=['y'],
+                                       axis=axis,
+                                       block_size=block_size)
+
-    onnx_graph = helper.make_graph([dequantize_node, quantize_node],
+    onnx_graph = helper.make_graph([mul_node, quantize_node, dequantize_node],
                                    name='lpbq',
                                    inputs=[x],
                                    outputs=[y],
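
The rebuilt reference graph now models the full quantize-dequantize round trip, so its output y is float rather than integer: the Mul node forms the effective blockwise scale, QuantizeLinear emits the intermediate x_int, and DequantizeLinear maps it back to float. A minimal PyTorch sketch of the same computation; the shapes, rounding mode, and the qmin/qmax defaults are assumptions for illustration:

import torch

def lpbq_reference(x, per_block_int_scale, per_channel_float_scale,
                   axis, block_size, qmin=-8, qmax=7):
    # Mul node: effective per-block scale
    y_scale = per_block_int_scale * per_channel_float_scale
    # Expand each block's scale across its block along `axis`
    y_scale = y_scale.repeat_interleave(block_size, dim=axis)
    # QuantizeLinear: scale, round, clamp to the integer grid
    x_int = torch.clamp(torch.round(x / y_scale), qmin, qmax)
    # DequantizeLinear: back to float, matching the graph's float output y
    return x_int * y_scale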
@@ -1896,7 +1901,7 @@ def test_lpbq_encoding_schema_2_0_0(shape, block_size, block_grouping, axis, com
     assert encoding["y_zero_point"] is None
     assert encoding["axis"] == axis
     assert encoding["block_size"] == next(iter(blk for blk in block_size if blk != 1))
-    assert encoding["output_dtype"] == f"int{decompressed_bw}"
+    assert encoding["output_dtype"] == f"int{compressed_bw}"


"""
@@ -1935,6 +1940,9 @@ def test_lpbq_encoding_schema_2_0_0(shape, block_size, block_grouping, axis, com
     ort_out, = sess.run(None, {'x': x.numpy()})

     ort_out = torch.from_numpy(ort_out.astype("float64"))
-    torch_out = qtzr(x.to(torch.float64)).quantize()
+    torch_out = qtzr(x.to(torch.float64))
+    atol = per_block_int_scale * per_channel_float_scale  # Allow off-by-one error
+    atol = atol.amax(dim=[axis for axis, blk in enumerate(block_size) if blk != 1],
+                     keepdim=True)

-    assert torch.allclose(ort_out, torch_out, atol=1)
+    assert torch.all((ort_out - torch_out).abs() < atol)
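
On the new tolerance: onnxruntime and torch may round a borderline value to adjacent integers, so the dequantized outputs can legitimately differ by one quantization step, which in the float domain equals the effective block scale per_block_int_scale * per_channel_float_scale. Reducing with amax over the block axes (keepdim=True) keeps atol broadcastable against the elementwise difference. A toy check with hypothetical shapes (2 channels, 2 blocks along dim 1):

import torch
per_block_int_scale = torch.tensor([[12., 9.], [15., 7.]])   # hypothetical
per_channel_float_scale = torch.tensor([[0.002], [0.003]])   # hypothetical
atol = (per_block_int_scale * per_channel_float_scale).amax(dim=1, keepdim=True)
assert atol.shape == (2, 1)  # broadcasts against a (2, N) difference tensor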
