Skip to content

Commit

Permalink
[Quant Tool] Prevent int32 quantized bias from clipping by adjusting …
Browse files Browse the repository at this point in the history
…the weight's scale (microsoft#22020)

### Description
Fixes scenario in which a bias input quantized to int32 has a scale that
is too small. A bias with a scale that is smaller than a certain
threshold will overflow the range of an `int32` when quantized, which
significantly decreases accuracy.

Credit to @yihonglyu for finding out about this issue and the fix.

### Motivation and Context
Consider the following Convolution with very small weights and a
constant bias input of `[5, -4.5]`.

![image](https://github.com/user-attachments/assets/4bde2bd9-892f-4ae9-887b-61a6668779a1)

The QDQ quantizer first computes the following quantization scale for
`input_0` and `weight`:
- `input_0`: scale=0.5
- `weight`: scale=7.843e-10 **[really small]**

The QDQ quantizer then computes the bias input's scale as follows:
```
bias_scale = input_0_scale * weight_0_scale = 0.5 * 7.843e-10 = 3.9215686274509805e-11
```

This `bias_scale` is too small. Before this PR, the QDQ quantizer would
quantize the f32 bias with this `bias_scale`:
```
bias_quant = round(bias_f32 / bias_scale) =  round([5.0/bias_scale, -4.5/bias_scale]) = [127500000000, -114750000000]
```
These quantized bias values exceed the range of int32, and so are
clipped to [int32.min(), int32.max()], which is very inaccurate.

#### New approach
This PR increases the `weight_0_scale` by the necessary amount to ensure
that `bias_scale` (which equals `weight_0_scale * input_0_scale`) is
appropriate for the int32 quantization type.

The smallest valid bias scale is given by the normal scale formula: 
`bias_smallest_valid_scale = (bias_f32_max - bias_f32_min) / (int32_max
- int32_min)`

Then, we compute the candidate bias scale:
`bias_scale_candidate = input_0_scale * weight_0_scale`

If the candidate scale is smaller than the smallest valid scale, we
increase the `weight_0_scale` by the necessary ratio:
```python
if bias_scale_candidate < bias_smallest_valid_scale:
    ratio = bias_smallest_valid_scale / bias_scale_candidate
    weight_0_scale = ratio * weight_0_scale
```

Then, we recompute the final bias scale:
```python
bias_scale = input_0_scale * weight_0_scale
```

#### Impact on accuracy
Here's the above model's quantized output compared to the f32
(ground-truth) output.
- Before PR: 
  - f32 model output[0]: **5.0f**
  - qdq model output[0]: **0.075**
  - SNR: 0.1369 (higher is better)
- After PR:
  - f32 model output[0]: **5.0f**
  - qdq model output[0]: **4.992**
  - SNR: 55.656 (higher is better)
  • Loading branch information
adrianlizarraga authored and Ishwar Raut committed Nov 19, 2024
1 parent e06d459 commit 3a33382
Show file tree
Hide file tree
Showing 9 changed files with 822 additions and 186 deletions.
61 changes: 43 additions & 18 deletions onnxruntime/python/tools/quantization/base_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from .quant_utils import (
ONNX_TYPE_TO_NP_TYPE,
TENSOR_NAME_QUANT_SUFFIX,
QuantType,
find_by_name,
model_has_infer_metadata,
normalize_axis,
Expand All @@ -40,18 +39,26 @@ def __init__(self, **data: Dict[str, Any]):
for k, v in data.items():
if not isinstance(k, str):
raise TypeError(f"Keys must be strings not {type(k)} for k={k!r}.")
if not isinstance(v, (int, str, np.ndarray)):
if k != "axis" and not isinstance(v, (int, str, np.ndarray)):
raise TypeError(f"Values must be numpy arrays, int, float, str not {type(v)} for k={k!r}.")
if k == "axis" and not isinstance(v, int) and v is not None:
raise TypeError(f"Axis value must be an int or None, not {type(v)}.")
if k == "scale" and v.dtype not in (np.float32, np.float16):
raise ValueError(f"scale must a float32 or float16 numpy element but is {v.dtype} for k={k!r}")
self.data[k] = v

def get(self, key, default_value=None):
return self.data.get(key, default_value)

def __iter__(self):
yield from self.data

def __getitem__(self, key):
return self.data[key]

def __setitem__(self, key, value):
self.data[key] = value

def __len__(self):
return len(self.data)

Expand Down Expand Up @@ -88,9 +95,10 @@ def __init__(
self.force_quantize_no_input_check = (
"ForceQuantizeNoInputCheck" in self.extra_options and self.extra_options["ForceQuantizeNoInputCheck"]
)
self.is_weight_symmetric = self.extra_options.get(
"WeightSymmetric", weight_qType in (QuantType.QInt8, QuantType.QInt16, QuantType.QFLOAT8E4M3FN)
)

# If user does not explicitly set "WeightSymmetric", then the weight's quantization type determines
# the symmetry (i.e., signed integer types will use symmetric quantization). See `def is_weight_symmetric()`
self._is_weight_symmetric: bool | None = self.extra_options.get("WeightSymmetric", None)
self.is_activation_symmetric = self.extra_options.get("ActivationSymmetric", False)
self.min_real_range = self.extra_options.get("MinimumRealRange")

Expand Down Expand Up @@ -131,6 +139,16 @@ def __init__(

self.tensor_quant_override_qtypes = self.tensor_quant_overrides.get_quant_types()

def is_weight_symmetric(self, weight_quant_type: onnx.TensorProto.DataType) -> bool:
if self._is_weight_symmetric is not None:
return self._is_weight_symmetric # Return value explicitly set by user.
return weight_quant_type in (
onnx.TensorProto.INT4,
onnx.TensorProto.INT8,
onnx.TensorProto.INT16,
onnx.TensorProto.FLOAT8E4M3FN,
)

def quantize_model(self):
raise NotImplementedError

Expand Down Expand Up @@ -230,9 +248,19 @@ def quantize_bias_static_impl(self, bias_name, input_scale, weight_scale, beta=1
# TODO: This formula should be explained including why the scale is not estimated for the bias as well.
bias_scale = input_scale * weight_scale * beta

quantized_data = (np.asarray(bias_data) / bias_scale).round()
quantized_data = np.clip(quantized_data, np.iinfo(np.int32).min, np.iinfo(np.int32).max)
quantized_data = quantized_data.astype(np.int32)
# Quantize by dividing by bias_scale
quantized_data = np.asarray(bias_data, dtype=np.float64) / np.asarray(bias_scale, dtype=np.float64)
quantized_data = quantized_data.round()

# Clip quantized data to the range of a int32
int32_min = np.float64(np.iinfo(np.int32).min)
int32_max = np.float64(np.iinfo(np.int32).max)
if np.any(quantized_data < int32_min) or np.any(quantized_data > int32_max):
logging.warning(
f"Quantized bias `{bias_name}` exceeds the range of a int32. The bias scale is too small."
)

quantized_data = np.clip(quantized_data, int32_min, int32_max).astype(np.int32)

# update bias initializer
bias_np_data = np.asarray(quantized_data, dtype=np.int32).reshape(bias_initializer.dims)
Expand Down Expand Up @@ -282,6 +310,7 @@ def quantize_initializer_impl(self, weight, qType, reduce_range=False, keep_floa
If keep_float_weight is False, quantize the weight, or don't quantize the weight.
:return: quantized weight name, zero point name, scale name
"""
# TODO(adrianlizarraga): This function is now only used by onnx_quantizer.py, so move it there.
q_weight_name = weight.name + TENSOR_NAME_QUANT_SUFFIX
zp_name = weight.name + "_zero_point"
scale_name = weight.name + "_scale"
Expand All @@ -303,10 +332,11 @@ def quantize_initializer_impl(self, weight, qType, reduce_range=False, keep_floa
assert isinstance(scale, np.ndarray), f"Unexpected type {type(scale)}"

else:
_, _, zero_point, scale, q_weight_data = quantize_data(
symmetric = self.is_weight_symmetric(qType) if qType == self.weight_qType else self.is_activation_symmetric
zero_point, scale, q_weight_data = quantize_data(
weight_data.flatten(),
qType,
quant_overrides.get("symmetric", self.is_weight_symmetric),
quant_overrides.get("symmetric", symmetric),
reduce_range=quant_overrides.get("reduce_range", self.reduce_range and reduce_range),
min_real_range=self.min_real_range,
rmin_override=quant_overrides.get("rmin"),
Expand Down Expand Up @@ -371,6 +401,7 @@ def quantize_weight_per_channel_impl(
reduce_range=True,
keep_float_weight=False,
):
# TODO(adrianlizarraga): This function is now only used by onnx_quantizer.py, so move it there.
initializer = find_by_name(weight_name, self.model.initializer())
if initializer is None:
raise ValueError("{} is not an initializer", weight_name)
Expand Down Expand Up @@ -409,13 +440,7 @@ def quantize_weight_per_channel_impl(
if "quant_type" in quant_overrides_for_channels[0]:
weight_qType = quant_overrides_for_channels[0]["quant_type"].tensor_type # noqa: N806

symmetric = quant_overrides_for_channels[0].get(
"symmetric",
(
self.is_weight_symmetric
or weight_qType in (onnx.TensorProto.INT8, onnx.TensorProto.FLOAT8E4M3FN, onnx.TensorProto.INT4)
),
)
symmetric = quant_overrides_for_channels[0].get("symmetric", self.is_weight_symmetric(weight_qType))
reduce_range = quant_overrides_for_channels[0].get("reduce_range", self.reduce_range and reduce_range)
zero_point_list = []
scale_list = []
Expand Down Expand Up @@ -444,7 +469,7 @@ def quantize_weight_per_channel_impl(
), f"Unexpected type {type(quantized_per_channel_data)}"

else:
_, _, zero_point, scale, quantized_per_channel_data = quantize_data(
zero_point, scale, quantized_per_channel_data = quantize_data(
per_channel_data.flatten(),
weight_qType,
symmetric,
Expand Down
20 changes: 20 additions & 0 deletions onnxruntime/python/tools/quantization/onnx_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,26 @@ def get_largest_node_name_suffix(self, node_name_prefix):

return suffix

def get_largest_initializer_name_suffix(self, initializer_name_prefix):
"""
Gets the largest initializer name integer suffix for all initializer names that begin
with `initializer_name_prefix`. This can be used to create unique initializer names.
Example: for initializer names 'my_weight_0' and 'my_weight_3', this method returns 3 if
`initializer_name_prefix` is 'my_weight_'.
"""
suffix = -1

for initializer in self.model.graph.initializer:
if initializer.name.startswith(initializer_name_prefix):
try:
index = int(initializer.name[len(initializer_name_prefix) :])
suffix = max(index, suffix)
except ValueError:
continue

return suffix

def find_nodes_by_initializer(self, graph, initializer):
"""
Find all nodes with given initializer as an input.
Expand Down
Loading

0 comments on commit 3a33382

Please sign in to comment.