Implement LPBQ encoding schema 2.0.0 (beta) in aimet-onnx (#3847)
Signed-off-by: Kyunggeun Lee <[email protected]>
quic-kyunggeu authored Feb 28, 2025
1 parent 22cba12 commit d909c00
Showing 3 changed files with 222 additions and 28 deletions.
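For context, the 2.0.0 (beta) schema factors each LPBQ quantizer's blockwise scale into an integer per-block factor and a float per-channel factor. A purely illustrative example of the exported payload, using the field names introduced in this diff (values are made up, for a hypothetical (2, 8) weight with channel_axis=0, block_axis=1, block_size=4):

# Illustrative only -- not produced by this commit's code or tests
lpbq_encoding_example = {
    "per_block_int_scale": [[12, 9], [15, 7]],         # one integer factor per block
    "per_channel_float_scale": [[0.0031], [0.0024]],   # one float factor per output channel
    "y_zero_point": None,                               # LPBQ is symmetric-only
    "axis": 1,                                          # block axis, as in onnx::QuantizeLinear
    "block_size": 4,
    "output_dtype": "int8",                             # decompressed bitwidth
}

The blockwise y_scale consumed by onnx::QuantizeLinear is recovered as per_block_int_scale * per_channel_float_scale.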
41 changes: 38 additions & 3 deletions TrainingExtensions/onnx/src/python/aimet_onnx/qc_quantize_op.py
@@ -881,13 +881,48 @@ def _export_1_0_0_encodings(self) -> Optional[Dict]:
encodings["compressed_bw"] = self.bitwidth
encodings["bw"] = self.decompressed_bw
scale, _ = lpbq_utils.encodings_to_scale_offset_arrays(self.get_encodings(), self._encoding_shape())
-int_scale, per_block_scale = lpbq_utils.grouped_dynamic_quantize(scale, self._block_grouping(), self.decompressed_bw - self.bitwidth)
-encodings['per_block_int_scale'] = int_scale.flatten().tolist()
-encodings['scale'] = per_block_scale.flatten().tolist()
compressed_bw = self.bitwidth
decompressed_bw = self.decompressed_bw
per_block_int_scale, per_channel_scale = lpbq_utils.grouped_dynamic_quantize(scale,
self._block_grouping(),
decompressed_bw - compressed_bw)
encodings['per_block_int_scale'] = per_block_int_scale.flatten().tolist()
encodings['scale'] = per_channel_scale.flatten().tolist()
encodings["offset"] = [-2 ** (self.decompressed_bw - 1) for _ in encodings['scale']]

return encodings

def _export_2_0_0_encodings(self) -> Optional[Dict]:
encodings = super()._export_2_0_0_encodings()

if encodings is None:
return None

output_dtype = encodings.pop("output_dtype")
y_zero_point = encodings.pop("y_zero_point")

if y_zero_point is not None and np.any(np.array(y_zero_point) != 0):
raise RuntimeError(
f"LPBQ only supports symmetric quantization; got non-zero y_zero_point {y_zero_point}"
)

compressed_bw = self.bitwidth
decompressed_bw = self.decompressed_bw
y_scale = np.array(encodings.pop("y_scale"))
per_block_int_scale, per_channel_scale = lpbq_utils.grouped_dynamic_quantize(y_scale,
self._block_grouping(),
decompressed_bw - compressed_bw)
per_channel_scale = per_channel_scale.squeeze(tuple(range(1, per_channel_scale.ndim, 2)))
assert per_block_int_scale.ndim == per_channel_scale.ndim

return {
"per_block_int_scale": per_block_int_scale.tolist(),
"per_channel_float_scale": per_channel_scale.tolist(),
"y_zero_point": None,
**encodings,
"output_dtype": f"int{decompressed_bw}" if output_dtype.startswith("int") else f"uint{decompressed_bw}"
}

def _fill_mismatching_encoding_settings_info(self, encoding_dict: Optional[dict], encoding_mismatch_info: _EncodingMismatchInfo):
super()._fill_mismatching_encoding_settings_info(encoding_dict, encoding_mismatch_info)
encoding_mismatch_info.bitwidth_mismatch = None
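The helper lpbq_utils.grouped_dynamic_quantize is not part of this diff; below is a minimal numpy sketch of the decomposition it is used for here, assuming a single grouping axis, max-based per-group scaling, and positive input scales (the real implementation may differ in grouping, rounding, and clipping details):

import numpy as np

def grouped_dynamic_quantize_sketch(scale, group_axis, group_size, scale_bw):
    # Split the grouping axis into (num_groups, group_size) so each group can be
    # reduced independently.
    num_groups = scale.shape[group_axis] // group_size
    shape = list(scale.shape)
    shape[group_axis:group_axis + 1] = [num_groups, group_size]
    grouped = scale.reshape(shape)
    # The integer scales get scale_bw bits (decompressed_bw - compressed_bw above).
    levels = 2 ** scale_bw - 1
    per_group_float = grouped.max(axis=group_axis + 1, keepdims=True) / levels
    per_block_int = np.clip(np.round(grouped / per_group_float), 1, levels)
    # Invariant the exporter relies on: per_block_int * per_group_float ~= scale.
    return per_block_int.reshape(scale.shape), np.squeeze(per_group_float, axis=group_axis + 1)

The 1.0.0 exporter above flattens the two factors into 'per_block_int_scale' and 'scale', while the new 2.0.0 exporter keeps them as shaped arrays under 'per_block_int_scale' and 'per_channel_float_scale'.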
181 changes: 174 additions & 7 deletions TrainingExtensions/onnx/test/python/test_qc_quantize_op.py
@@ -1100,13 +1100,13 @@ def _onnx_QuantizeDequantizeLinear(input_shape, y_scale, y_zero_point, axis, blo
# axis := |
# +- block axis (otherwise)
"input_shape, channel_axis, block_axis, block_size", [
-((10, 10, 3, 3), None, None, None), # per-tensor
-((10, 10, 3, 3), 0, None, None), # per-channel with axis=0 (Convolution)
-((10, 10, 3, 3), 1, None, None), # per-channel with axis=1 (Convolution)
((10, 10, 1, 1), None, None, None), # per-tensor
((10, 10, 1, 1), 0, None, None), # per-channel with axis=0 (Convolution)
((10, 10, 1, 1), 1, None, None), # per-channel with axis=1 (Convolution)
((10, 10), 0, None, None), # per-channel with axis=0 (Linear/Gemm)
((10, 10), 1, None, None), # per-channel with axis=1 (Linear/Gemm)
-((10, 10, 3, 3), 0, 1, 5), # per-block with block_axis=1 (Convolution)
-((10, 10, 3, 3), 1, 0, 5), # per-block with block_axis=0 (Convolution)
((10, 10, 1, 1), 0, 1, 5), # per-block with block_axis=1 (Convolution)
((10, 10, 1, 1), 1, 0, 5), # per-block with block_axis=0 (Convolution)
((10, 10), 0, 1, 5), # per-block with block_axis=1 (Linear/Gemm)
((10, 10), 1, 0, 5), # per-block with block_axis=0 (Linear/Gemm)
])
@@ -1218,7 +1218,174 @@ def test_affine_encoding_schema_2_0_0(input_shape, channel_axis, block_axis, blo
sess = ort.InferenceSession(full_path, providers=['CPUExecutionProvider'])
ort_out, = sess.run(None, {'x': input})

-ort_out = ort_out
aimet_out = session.run(None, {'input': input})
-atol = (abs(input).max() * 2) / (2 ** bitwidth - 1) # Allow off-by-one error
atol = y_scale # Allow off-by-one error
if block_axis is not None:
atol = atol.max(axis=block_axis, keepdims=True)
elif channel_axis is not None:
atol = atol.reshape(*(1 if axis != channel_axis else -1 for axis in range(input.ndim)))
assert np.allclose(ort_out, aimet_out, atol=atol)


def _onnx_LPBQ(input_shape, per_block_int_scale, per_channel_float_scale,
y_zero_point, axis, block_size, output_dtype):
op = OperatorSetIdProto()
op.version = 21

assert output_dtype in ("int8", "int16", "uint8", "uint16")
assert y_zero_point is None

x_int_dtype = TensorProto.INT16 if output_dtype == "int16" else \
TensorProto.INT8 if output_dtype == "int8" else \
TensorProto.INT4 if output_dtype == "int4" else \
TensorProto.UINT16 if output_dtype == "uint16" else \
TensorProto.UINT8 if output_dtype == "uint8" else \
TensorProto.UINT4 if output_dtype == "uint4" else \
None
assert x_int_dtype is not None

x = helper.make_tensor_value_info(name='x',
elem_type=TensorProto.FLOAT,
shape=input_shape)

per_block_int_scale = numpy_helper.from_array(np.array(per_block_int_scale).astype('float32'),
name='per_block_int_scale')
per_channel_float_scale = numpy_helper.from_array(np.array(per_channel_float_scale).astype('float32'),
name='per_channel_float_scale')

y = helper.make_tensor_value_info(name='y',
elem_type=TensorProto.FLOAT,
shape=input_shape)

mul_node = helper.make_node('Mul',
inputs=['per_block_int_scale', 'per_channel_float_scale'],
outputs=['y_scale'])

quantize_node = helper.make_node('QuantizeLinear',
inputs=['x', 'y_scale'],
outputs=['x_int'],
axis=axis,
block_size=block_size,
output_dtype=x_int_dtype)

dequantize_node = helper.make_node('DequantizeLinear',
inputs=['x_int', 'y_scale'],
outputs=['y'],
axis=axis,
block_size=block_size)

onnx_graph = helper.make_graph([mul_node, quantize_node, dequantize_node],
name='lpbq',
inputs=[x],
outputs=[y],
initializer=[per_block_int_scale, per_channel_float_scale])

model = helper.make_model(onnx_graph, opset_imports=[op])
onnx.checker.check_model(model, True)

return model


@pytest.mark.parametrize(
"input_shape, block_axis, block_size", [
((10, 50, 1, 1), 1, 5), # per-block with block_axis=1 (Convolution)
((50, 10, 1, 1), 0, 5), # per-block with block_axis=0 (Convolution)
((10, 50), 1, 5), # per-block with block_axis=1 (Linear/Gemm)
((50, 10), 0, 5), # per-block with block_axis=0 (Linear/Gemm)
])
@pytest.mark.parametrize(
"compressed_bw, decompressed_bw", [
(4, 8),
(8, 16),
])
def test_lpbq_encoding_schema_2_0_0(input_shape, block_axis, block_size, compressed_bw, decompressed_bw):
"""
Given: QcQuantizeOp
"""
input = np.random.randn(*input_shape).astype(np.float32)
channel_axis = 0 if block_axis == 1 else 1
quant_params = TensorQuantizerParams(input_shape, channel_axis, block_axis)

quant_info = libquant_info.QcQuantizeInfo()
quant_info.isIntDataType = True
quant_info.channelAxis = channel_axis
quant_info.blockAxis = block_axis

quant_node = helper.make_node(op_name, inputs=['input'], outputs=['output'],
domain=op_domain, quant_info=libpymo.PtrToInt64(quant_info))
model = create_model_from_node(quant_node, input.shape)
session = build_session(model, available_providers)
qtzr = GroupedBlockQuantizeDequantize(quant_info,
compressed_bw,
decompressed_bw,
block_size=block_size,
quant_scheme=QuantScheme.post_training_tf,
op_mode=OpMode.oneShotQuantizeDequantize,
tensor_quantizer_params=quant_params)

_, = session.run(None, {'input': input})
qtzr.compute_encodings()

"""
When: Export encoding in 2.0.0.beta schema
"""
encoding = qtzr.export_encodings("2.0.0.beta")

"""
Then: Exported qnn encoding should contain:
* "per_block_int_scale"
* "per_channel_float_scale"
* "y_zero_point"
* "axis"
* "block_size"
* "output_dtype"
all of which are defined as in onnx::QuantizeLinear, except that
per_block_int_scale * per_channel_float_scale takes the place of y_scale
"""

per_block_int_scale = np.array(encoding["per_block_int_scale"])
per_channel_float_scale = np.array(encoding["per_channel_float_scale"])

assert per_block_int_scale.ndim == per_channel_float_scale.ndim == input.ndim
assert per_block_int_scale.shape[channel_axis] == input.shape[channel_axis]
assert per_block_int_scale.shape[block_axis] == input.shape[block_axis] // block_size
assert all(dim == 1 for axis, dim in enumerate(per_block_int_scale.shape)
if axis not in (channel_axis, block_axis))
assert per_channel_float_scale.shape[channel_axis] == input.shape[channel_axis]
assert all(dim == 1 for axis, dim in enumerate(per_channel_float_scale.shape) if axis != channel_axis)

assert encoding["y_zero_point"] is None
assert encoding["axis"] == block_axis
assert encoding["block_size"] == block_size
assert encoding["output_dtype"] == f"int{decompressed_bw}"


"""
Then: The output of onnx::QuantizeLinear followed by DequantizeLinear with the exported qnn encoding
should be all-close to AIMET qdq output with off-by-one tolerance threshold
"""
if version.parse(ort.__version__) < version.parse("1.20.0"):
pytest.skip(reason="Remaining tests require onnxruntime>=1.20 for blockwise QuantizeLinear")

onnx_LPBQ = _onnx_LPBQ(input_shape=input.shape,
per_block_int_scale=encoding["per_block_int_scale"],
per_channel_float_scale=encoding["per_channel_float_scale"],
y_zero_point=encoding["y_zero_point"],
axis=encoding["axis"],
block_size=encoding["block_size"],
output_dtype=encoding["output_dtype"])

with tempfile.TemporaryDirectory() as tmp_dir:
full_path = os.path.join(tmp_dir, "model.onnx")

with open(full_path, "wb") as f:
f.write(onnx_LPBQ.SerializeToString())

sess = ort.InferenceSession(full_path, providers=['CPUExecutionProvider'])
ort_out, = sess.run(None, {'x': input})

aimet_out = session.run(None, {'input': input})
y_scale = per_block_int_scale * per_channel_float_scale
atol = y_scale.max(axis=block_axis, keepdims=True) # Allow off-by-one error
assert np.allclose(ort_out, aimet_out, atol=atol)
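As a cross-check of what the exported fields mean, here is a small numpy sketch (an assumption-laden reference, not part of the test) that mirrors the Mul -> QuantizeLinear -> DequantizeLinear graph built by _onnx_LPBQ above, assuming a signed output_dtype and ties-to-even rounding as in onnx::QuantizeLinear:

import numpy as np

def lpbq_qdq_reference(x, encoding):
    per_block_int_scale = np.asarray(encoding["per_block_int_scale"], dtype=np.float32)
    per_channel_float_scale = np.asarray(encoding["per_channel_float_scale"], dtype=np.float32)
    y_scale = per_block_int_scale * per_channel_float_scale       # the Mul node
    # Expand the per-block scale back to element granularity along the block axis.
    y_scale = np.repeat(y_scale, encoding["block_size"], axis=encoding["axis"])
    bw = int(encoding["output_dtype"].replace("uint", "").replace("int", ""))
    qmin, qmax = -(2 ** (bw - 1)), 2 ** (bw - 1) - 1               # symmetric range, y_zero_point is None
    x_int = np.clip(np.rint(x / y_scale), qmin, qmax)              # QuantizeLinear
    return x_int * y_scale                                         # DequantizeLinear

Its output should agree with both ort_out and aimet_out above within the same off-by-one tolerance used in the assert.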
@@ -1786,7 +1786,7 @@ def test_affine_encoding_schema_2_0_0(shape, block_size, axis,


def _onnx_LPBQ(input_shape, per_block_int_scale, per_channel_float_scale,
-               y_zero_point, axis, block_size, block_grouping, output_dtype):
y_zero_point, axis, block_size, output_dtype):
op = OperatorSetIdProto()
op.version = 21

@@ -1806,7 +1806,7 @@ def _onnx_LPBQ(input_shape, per_block_int_scale, per_channel_float_scale,
elem_type=TensorProto.FLOAT,
shape=input_shape)

-per_block_int_scale = numpy_helper.from_array(np.array(per_block_int_scale).astype('int32'),
per_block_int_scale = numpy_helper.from_array(np.array(per_block_int_scale).astype('float32'),
name='per_block_int_scale')
per_channel_float_scale = numpy_helper.from_array(np.array(per_channel_float_scale).astype('float32'),
name='per_channel_float_scale')
@@ -1815,16 +1815,9 @@ def _onnx_LPBQ(input_shape, per_block_int_scale, per_channel_float_scale,
elem_type=onnx_dtype,
shape=input_shape)

-group_axis, group_size = next(iter(
-    (group_axis, group_size)
-    for group_axis, group_size in enumerate(block_grouping)
-    if group_size != 1
-))
-dequantize_node = helper.make_node('DequantizeLinear',
dequantize_node = helper.make_node('Mul',
inputs=['per_block_int_scale', 'per_channel_float_scale'],
-                                   outputs=['y_scale'],
-                                   axis=group_axis,
-                                   block_size=group_size)
outputs=['y_scale'])

quantize_node = helper.make_node('QuantizeLinear',
inputs=['x', 'y_scale'],
@@ -1848,15 +1841,15 @@ def _onnx_LPBQ(input_shape, per_block_int_scale, per_channel_float_scale,
@torch.no_grad()
@pytest.mark.parametrize(
"shape, block_size, block_grouping, axis", [
-((10, 64, 1, 1), (1, 4, 1, 1), (1, 32, 1, 1), 1), # per-block with block_axis=1 (Convolution)
-((64, 10, 1, 1), (4, 1, 1, 1), (32, 1, 1, 1), 0), # per-block with block_axis=0 (Convolution)
-((10, 64), (1, 4), (1, 32), 1), # per-block with block_axis=1 (Linear/Gemm)
-((64, 10), (4, 1), (32, 1), 0), # per-block with block_axis=0 (Linear/Gemm)
((10, 64, 1, 1), (1, 8, 1, 1), (1, 64, 1, 1), 1), # per-block with block_axis=1 (Convolution)
((64, 10, 1, 1), (8, 1, 1, 1), (64, 1, 1, 1), 0), # per-block with block_axis=0 (Convolution)
((10, 64), (1, 8), (1, 64), 1), # per-block with block_axis=1 (Linear/Gemm)
((64, 10), (8, 1), (64, 1), 0), # per-block with block_axis=0 (Linear/Gemm)
])
@pytest.mark.parametrize(
"compressed_bw, decompressed_bw", [
(4, 8),
(8, 16),
])
def test_lpbq_encoding_schema_2_0_0(shape, block_size, block_grouping, axis, compressed_bw, decompressed_bw):
"""
@@ -1930,7 +1923,6 @@ def test_lpbq_encoding_schema_2_0_0(shape, block_size, block_grouping, axis, com
y_zero_point=encoding["y_zero_point"],
axis=encoding["axis"],
block_size=encoding["block_size"],
-block_grouping=block_grouping,
output_dtype=encoding["output_dtype"])

with tempfile.TemporaryDirectory() as tmp_dir:
