From 3a333823e8f2e8891a61301823358b833181445c Mon Sep 17 00:00:00 2001 From: Adrian Lizarraga Date: Wed, 6 Nov 2024 10:44:54 -0800 Subject: [PATCH] [Quant Tool] Prevent int32 quantized bias from clipping by adjusting the weight's scale (#22020) ### Description Fixes a scenario in which a bias input quantized to int32 has a scale that is too small. A bias with a scale that is smaller than a certain threshold will overflow the range of an `int32` when quantized, which significantly decreases accuracy. Credit to @yihonglyu for finding out about this issue and the fix. ### Motivation and Context Consider the following Convolution with very small weights and a constant bias input of `[5, -4.5]`. ![image](https://github.com/user-attachments/assets/4bde2bd9-892f-4ae9-887b-61a6668779a1) The QDQ quantizer first computes the following quantization scale for `input_0` and `weight`: - `input_0`: scale=0.5 - `weight`: scale=7.843e-11 **[really small]** The QDQ quantizer then computes the bias input's scale as follows: ``` bias_scale = input_0_scale * weight_0_scale = 0.5 * 7.843e-11 = 3.9215686274509805e-11 ``` This `bias_scale` is too small. Before this PR, the QDQ quantizer would quantize the f32 bias with this `bias_scale`: ``` bias_quant = round(bias_f32 / bias_scale) = round([5.0/bias_scale, -4.5/bias_scale]) = [127500000000, -114750000000] ``` These quantized bias values exceed the range of int32, and so are clipped to [int32.min(), int32.max()], which is very inaccurate. #### New approach This PR increases the `weight_0_scale` by the necessary amount to ensure that `bias_scale` (which equals `weight_0_scale * input_0_scale`) is appropriate for the int32 quantization type. 
The smallest valid bias scale is given by the normal scale formula: `bias_smallest_valid_scale = (bias_f32_max - bias_f32_min) / (int32_max - int32_min)` Then, we compute the candidate bias scale: `bias_scale_candidate = input_0_scale * weight_0_scale` If the candidate scale is smaller than the smallest valid scale, we increase the `weight_0_scale` by the necessary ratio: ```python if bias_scale_candidate < bias_smallest_valid_scale: ratio = bias_smallest_valid_scale / bias_scale_candidate weight_0_scale = ratio * weight_0_scale ``` Then, we recompute the final bias scale: ```python bias_scale = input_0_scale * weight_0_scale ``` #### Impact on accuracy Here's the above model's quantized output compared to the f32 (ground-truth) output. - Before PR: - f32 model output[0]: **5.0f** - qdq model output[0]: **0.075** - SNR: 0.1369 (higher is better) - After PR: - f32 model output[0]: **5.0f** - qdq model output[0]: **4.992** - SNR: 55.656 (higher is better) --- .../tools/quantization/base_quantizer.py | 61 ++- .../python/tools/quantization/onnx_model.py | 20 + .../tools/quantization/qdq_quantizer.py | 508 +++++++++++++----- .../python/tools/quantization/quant_utils.py | 196 +++++-- .../python/tools/quantization/quantize.py | 6 + .../quantization/tensor_quant_overrides.py | 4 + .../test/python/quantization/test_qdq.py | 199 +++++++ .../python/quantization/test_quant_util.py | 2 +- .../test_tensor_quant_overrides_option.py | 12 +- 9 files changed, 822 insertions(+), 186 deletions(-) diff --git a/onnxruntime/python/tools/quantization/base_quantizer.py b/onnxruntime/python/tools/quantization/base_quantizer.py index b20af5137d206..b12465ffa7926 100644 --- a/onnxruntime/python/tools/quantization/base_quantizer.py +++ b/onnxruntime/python/tools/quantization/base_quantizer.py @@ -21,7 +21,6 @@ from .quant_utils import ( ONNX_TYPE_TO_NP_TYPE, TENSOR_NAME_QUANT_SUFFIX, - QuantType, find_by_name, model_has_infer_metadata, normalize_axis, @@ -40,18 +39,26 @@ def __init__(self, 
**data: Dict[str, Any]): for k, v in data.items(): if not isinstance(k, str): raise TypeError(f"Keys must be strings not {type(k)} for k={k!r}.") - if not isinstance(v, (int, str, np.ndarray)): + if k != "axis" and not isinstance(v, (int, str, np.ndarray)): raise TypeError(f"Values must be numpy arrays, int, float, str not {type(v)} for k={k!r}.") + if k == "axis" and not isinstance(v, int) and v is not None: + raise TypeError(f"Axis value must be an int or None, not {type(v)}.") if k == "scale" and v.dtype not in (np.float32, np.float16): raise ValueError(f"scale must a float32 or float16 numpy element but is {v.dtype} for k={k!r}") self.data[k] = v + def get(self, key, default_value=None): + return self.data.get(key, default_value) + def __iter__(self): yield from self.data def __getitem__(self, key): return self.data[key] + def __setitem__(self, key, value): + self.data[key] = value + def __len__(self): return len(self.data) @@ -88,9 +95,10 @@ def __init__( self.force_quantize_no_input_check = ( "ForceQuantizeNoInputCheck" in self.extra_options and self.extra_options["ForceQuantizeNoInputCheck"] ) - self.is_weight_symmetric = self.extra_options.get( - "WeightSymmetric", weight_qType in (QuantType.QInt8, QuantType.QInt16, QuantType.QFLOAT8E4M3FN) - ) + + # If user does not explicitly set "WeightSymmetric", then the weight's quantization type determines + # the symmetry (i.e., signed integer types will use symmetric quantization). 
See `def is_weight_symmetric()` + self._is_weight_symmetric: bool | None = self.extra_options.get("WeightSymmetric", None) self.is_activation_symmetric = self.extra_options.get("ActivationSymmetric", False) self.min_real_range = self.extra_options.get("MinimumRealRange") @@ -131,6 +139,16 @@ def __init__( self.tensor_quant_override_qtypes = self.tensor_quant_overrides.get_quant_types() + def is_weight_symmetric(self, weight_quant_type: onnx.TensorProto.DataType) -> bool: + if self._is_weight_symmetric is not None: + return self._is_weight_symmetric # Return value explicitly set by user. + return weight_quant_type in ( + onnx.TensorProto.INT4, + onnx.TensorProto.INT8, + onnx.TensorProto.INT16, + onnx.TensorProto.FLOAT8E4M3FN, + ) + def quantize_model(self): raise NotImplementedError @@ -230,9 +248,19 @@ def quantize_bias_static_impl(self, bias_name, input_scale, weight_scale, beta=1 # TODO: This formula should be explained including why the scale is not estimated for the bias as well. bias_scale = input_scale * weight_scale * beta - quantized_data = (np.asarray(bias_data) / bias_scale).round() - quantized_data = np.clip(quantized_data, np.iinfo(np.int32).min, np.iinfo(np.int32).max) - quantized_data = quantized_data.astype(np.int32) + # Quantize by dividing by bias_scale + quantized_data = np.asarray(bias_data, dtype=np.float64) / np.asarray(bias_scale, dtype=np.float64) + quantized_data = quantized_data.round() + + # Clip quantized data to the range of a int32 + int32_min = np.float64(np.iinfo(np.int32).min) + int32_max = np.float64(np.iinfo(np.int32).max) + if np.any(quantized_data < int32_min) or np.any(quantized_data > int32_max): + logging.warning( + f"Quantized bias `{bias_name}` exceeds the range of a int32. The bias scale is too small." 
+ ) + + quantized_data = np.clip(quantized_data, int32_min, int32_max).astype(np.int32) # update bias initializer bias_np_data = np.asarray(quantized_data, dtype=np.int32).reshape(bias_initializer.dims) @@ -282,6 +310,7 @@ def quantize_initializer_impl(self, weight, qType, reduce_range=False, keep_floa If keep_float_weight is False, quantize the weight, or don't quantize the weight. :return: quantized weight name, zero point name, scale name """ + # TODO(adrianlizarraga): This function is now only used by onnx_quantizer.py, so move it there. q_weight_name = weight.name + TENSOR_NAME_QUANT_SUFFIX zp_name = weight.name + "_zero_point" scale_name = weight.name + "_scale" @@ -303,10 +332,11 @@ def quantize_initializer_impl(self, weight, qType, reduce_range=False, keep_floa assert isinstance(scale, np.ndarray), f"Unexpected type {type(scale)}" else: - _, _, zero_point, scale, q_weight_data = quantize_data( + symmetric = self.is_weight_symmetric(qType) if qType == self.weight_qType else self.is_activation_symmetric + zero_point, scale, q_weight_data = quantize_data( weight_data.flatten(), qType, - quant_overrides.get("symmetric", self.is_weight_symmetric), + quant_overrides.get("symmetric", symmetric), reduce_range=quant_overrides.get("reduce_range", self.reduce_range and reduce_range), min_real_range=self.min_real_range, rmin_override=quant_overrides.get("rmin"), @@ -371,6 +401,7 @@ def quantize_weight_per_channel_impl( reduce_range=True, keep_float_weight=False, ): + # TODO(adrianlizarraga): This function is now only used by onnx_quantizer.py, so move it there. 
initializer = find_by_name(weight_name, self.model.initializer()) if initializer is None: raise ValueError("{} is not an initializer", weight_name) @@ -409,13 +440,7 @@ def quantize_weight_per_channel_impl( if "quant_type" in quant_overrides_for_channels[0]: weight_qType = quant_overrides_for_channels[0]["quant_type"].tensor_type # noqa: N806 - symmetric = quant_overrides_for_channels[0].get( - "symmetric", - ( - self.is_weight_symmetric - or weight_qType in (onnx.TensorProto.INT8, onnx.TensorProto.FLOAT8E4M3FN, onnx.TensorProto.INT4) - ), - ) + symmetric = quant_overrides_for_channels[0].get("symmetric", self.is_weight_symmetric(weight_qType)) reduce_range = quant_overrides_for_channels[0].get("reduce_range", self.reduce_range and reduce_range) zero_point_list = [] scale_list = [] @@ -444,7 +469,7 @@ def quantize_weight_per_channel_impl( ), f"Unexpected type {type(quantized_per_channel_data)}" else: - _, _, zero_point, scale, quantized_per_channel_data = quantize_data( + zero_point, scale, quantized_per_channel_data = quantize_data( per_channel_data.flatten(), weight_qType, symmetric, diff --git a/onnxruntime/python/tools/quantization/onnx_model.py b/onnxruntime/python/tools/quantization/onnx_model.py index 174bf5fd1509c..43105550139de 100644 --- a/onnxruntime/python/tools/quantization/onnx_model.py +++ b/onnxruntime/python/tools/quantization/onnx_model.py @@ -296,6 +296,26 @@ def get_largest_node_name_suffix(self, node_name_prefix): return suffix + def get_largest_initializer_name_suffix(self, initializer_name_prefix): + """ + Gets the largest initializer name integer suffix for all initializer names that begin + with `initializer_name_prefix`. This can be used to create unique initializer names. + + Example: for initializer names 'my_weight_0' and 'my_weight_3', this method returns 3 if + `initializer_name_prefix` is 'my_weight_'. 
+ """ + suffix = -1 + + for initializer in self.model.graph.initializer: + if initializer.name.startswith(initializer_name_prefix): + try: + index = int(initializer.name[len(initializer_name_prefix) :]) + suffix = max(index, suffix) + except ValueError: + continue + + return suffix + def find_nodes_by_initializer(self, graph, initializer): """ Find all nodes with given initializer as an input. diff --git a/onnxruntime/python/tools/quantization/qdq_quantizer.py b/onnxruntime/python/tools/quantization/qdq_quantizer.py index b71f332252850..048c7f3296503 100644 --- a/onnxruntime/python/tools/quantization/qdq_quantizer.py +++ b/onnxruntime/python/tools/quantization/qdq_quantizer.py @@ -20,6 +20,7 @@ from .calibrate import TensorData from .quant_utils import ( DEQUANT_OP_NAME, + ONNX_TYPE_TO_NP_TYPE, QUANT_OP_NAME, QuantizedValue, QuantizedValueType, @@ -30,12 +31,14 @@ add_quant_input_suffix, add_quant_output_suffix, add_quant_suffix, + compute_data_quant_params, compute_scale_zp, compute_scale_zp_float8, find_by_name, get_qmin_qmax_for_qType, ms_domain, normalize_axis, + quantize_onnx_initializer, tensor_proto_to_array, ) from .registry import CreateQDQQuantizer @@ -86,6 +89,18 @@ class QDQTensorQuantParams: converted: QuantizationParams | None # Converted type consumed by some (or all/none) consumer nodes. converted_recv_nodes: set[str] | None # The name of nodes that consume the converted type. + def get_for_consumer(self, consumer_node_name) -> QuantizationParams: + if self.converted is None: # Quantized value is not converted, return original + return self.original + + if self.converted_recv_nodes is None: # All consumers receive the converted value + return self.converted + + # Check if consumer node name is in the list of nodes that + # receive the converted quantization value. If not, return the original value generated + # by the tensor's producer. 
+ return self.converted if (consumer_node_name in self.converted_recv_nodes) else self.original + # Holds scale and zero_point initializer TensorProtos. @dataclass @@ -153,8 +168,8 @@ def __init__( op_types_to_quantize, extra_options, ) - self.tensors_to_quantize = {} - self.bias_to_quantize = {} + self.tensors_to_quantize: dict[str, QDQTensorQuantInfo] = {} + self.bias_to_quantize: dict[str, QDQBiasQuantInfo] = {} self.nodes_to_remove = [] @@ -191,6 +206,9 @@ def __init__( # Used in the QDQRemovableActivation class. self.qdq_keep_removable_activations = extra_options.get("QDQKeepRemovableActivations", False) + # Let user disable adjustment of weight scales for bias inputs that are quantized to int32. + self.qdq_disable_weight_adjust_for_int32_bias = extra_options.get("QDQDisableWeightAdjustForInt32Bias", False) + # The ONNX spec did not support 16-bit Q/DQ ops before opset 21. # So, may have to override the Q/DQ op domain to 'com.microsoft' if the activation or weight types # are 16-bit or 4-bit integers. @@ -213,6 +231,7 @@ def __init__( self.qdq_op_domain = ms_domain self.quantization_params = self.calc_graph_quant_params() + self.initializer_quant_params: dict[str, QuantizationParams] = {} # Map of all original value names to quantized value names self.quantized_value_map = {} @@ -328,6 +347,18 @@ def quantize_weight_tensor_per_channel(self, tensor_name, axis): else: logging.warning(f"only support per-channel quantization on weight. Tensor: {tensor_name} is not quantized.") + def _dup_initializer(self, initializer: onnx.TensorProto) -> onnx.TensorProto: + """ + Duplicates an existing initializer and adds it to the model. Returns the new initializer. 
+ """ + name_suffix: int = self.model.get_largest_initializer_name_suffix(initializer.name) + 1 + new_initializer_name = f"{initializer.name}{name_suffix}" + new_initializer = onnx.TensorProto() + new_initializer.CopyFrom(initializer) + new_initializer.name = new_initializer_name + self.model.add_initializer(new_initializer) + return new_initializer + def quantize_bias_tensor(self, node_name, bias_name, input_name, weight_name, beta=1.0): """ Adds a bias tensor to the list of bias tensors to quantize. Called by op quantizers that @@ -353,15 +384,160 @@ def quantize_bias_tensor(self, node_name, bias_name, input_name, weight_name, be self.quantize_weight_tensor(bias_name) return - weight = find_by_name(bias_name, self.model.initializer()) - if weight is not None: - if weight.data_type in (onnx_proto.TensorProto.FLOAT, onnx_proto.TensorProto.FLOAT16): - if bias_name not in self.bias_to_quantize: - self.bias_to_quantize[bias_name] = QDQBiasQuantInfo(node_name, input_name, weight_name, beta) - else: - logging.warning(f"Bias {bias_name} has already been marked for quantization") - else: - logging.warning(f"Expected {bias_name} to be a weight") + bias_initializer = find_by_name(bias_name, self.model.initializer()) + if bias_initializer is None: + logging.warning(f"Expected bias '{bias_name}' to be an initializer") + return + + if bias_initializer.data_type not in (onnx_proto.TensorProto.FLOAT, onnx_proto.TensorProto.FLOAT16): + logging.info(f"Expected bias '{bias_name}' to be an floating-point initializer") + return + + actual_bias_name = bias_name + if bias_name in self.bias_to_quantize: + # This bias input is consumed by two different nodes. We need to duplicate the bias so that + # each node has its own bias input. This is necessary because the bias's scale is computed + # from the node's other input scales. 
+ new_bias_initializer = self._dup_initializer(bias_initializer) + actual_bias_name = new_bias_initializer.name + + # Replace this node's bias input + self.model.replace_input_of_nodes(bias_name, actual_bias_name, {node_name}) + logging.info(f"Created a copy of bias input '{bias_name}' called '{actual_bias_name}'") + + # Add this to our list of biases to quantize. + self.bias_to_quantize[actual_bias_name] = QDQBiasQuantInfo(node_name, input_name, weight_name, beta) + + def _adjust_weight_scale_for_int32_bias( + self, + input_scale: np.ndarray, + weight_scale: np.ndarray, + weight_name: str, + bias_tp: onnx.TensorProto, + is_per_channel: bool, + ) -> tuple[bool, np.ndarray | None]: + """ + Checks if the bias scale (input_scale * weight_scale) that we intend to use is too small. + A bias scale that is too small leads to quantized bias values that fall outside the range of a int32 and have to + be clipped, which decreases accuracy. If this function detects such a scenario, the weight_scale value will be + increased to prevent this from happening. + + Although the adjustment method and amount differs, the idea to adjust the weight's scale came from the following + reference: + https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/tools/optimize/quantization_utils.cc#L252 + + :param input_scale: The input's scale. + :param weight_scale: The weight scale to potentially adjust. + :param weight_name: The weight initializer's name. Used for logging. + :param bias_tp: The bias ONNX initializer. + :param is_per_channel: True if the bias and weight are quantized per-channel. + :return: A tuple with a bool indicating if the weight's scale was adjusted and the new weight scale. 
+ """ + if not weight_scale.size: + return False, None + + bias_float_data = tensor_proto_to_array(bias_tp) + + int32_info = np.iinfo(np.int32) + multiplicative_epsilon = 1.0001 + qrange = np.array(int32_info.max, dtype=np.float64) - np.array(int32_info.min + 1, dtype=np.float64) + weight_scale_dtype = weight_scale.dtype + updated_an_elem = False + + if not is_per_channel: + rmin = np.minimum(bias_float_data.min(), np.array(0, dtype=np.float64)) + rmax = np.maximum(bias_float_data.max(), np.array(0, dtype=np.float64)) + absmax = np.maximum(np.abs(rmin), np.abs(rmax)) + bias_smallest_valid_scale = multiplicative_epsilon * (2.0 * absmax) / qrange + + input_scale_fp64 = np.array(input_scale.item(), dtype=np.float64) + weight_scale_fp64 = np.array(weight_scale.item(), dtype=np.float64) + bias_candidate_scale = input_scale_fp64 * weight_scale_fp64 + + if (bias_candidate_scale < bias_smallest_valid_scale) and (bias_candidate_scale > 0.0): + # The candidate bias scale would be too small, so increase the weight_scale by the necessary ratio. + ratio = bias_smallest_valid_scale / bias_candidate_scale + logging.info( + f"Increasing scale for weight `{weight_name}` by the ratio {ratio} to " + f"ensure bias input `{bias_tp.name}` has a valid scale." 
+ ) + new_scale = weight_scale_fp64 * ratio + weight_scale = new_scale.astype(weight_scale_dtype) + updated_an_elem = True + elif weight_scale.shape and len(weight_scale.shape) == 1: + # per-channel case + num_elems = weight_scale.shape[0] + + for i in range(num_elems): + bias_rmax = np.abs(bias_float_data[i]) + bias_smallest_valid_scale = multiplicative_epsilon * (2.0 * bias_rmax) / qrange + + input_scale_fp64 = np.array(input_scale.item(), dtype=np.float64) + weight_scale_fp64 = np.array(weight_scale[i].item(), dtype=np.float64) + bias_candidate_scale = input_scale_fp64 * weight_scale_fp64 + if (bias_candidate_scale < bias_smallest_valid_scale) and (bias_candidate_scale > 0.0): + # The candidate bias scale would be too small, so increase the weight_scale by the necessary ratio. + ratio = bias_smallest_valid_scale / bias_candidate_scale + logging.info( + f"Increased scale[{i}] for weight `{weight_name}` by ratio {ratio} " + f"to ensure bias input `{bias_tp.name}` has a valid scale." + ) + new_scale = weight_scale_fp64 * ratio + weight_scale[i] = new_scale.astype(weight_scale_dtype) + updated_an_elem = True + + return updated_an_elem, weight_scale + + def _adjust_weight_quant_params_for_bias_tensors(self): + """ + Iterates through all bias inputs that should be quantized to int32. If the intended + bias scale (equal to input_scale * weight_scale) is too small, this function will increase + the associated weight's scale to ensure the bias does not overflow the int32 range when quantized. + """ + + if self.qdq_disable_weight_adjust_for_int32_bias: + # User passed an extra_option to disable this adjustment. + return + + for bias_name, bias_info in self.bias_to_quantize.items(): + if ( + bias_info.input_name not in self.quantization_params + or bias_info.input_name not in self.tensors_to_quantize + or bias_info.weight_name not in self.initializer_quant_params + ): + continue + + # Get the associated input's scale. 
+ input_qparams = self.quantization_params[bias_info.input_name].get_for_consumer(bias_info.node_name) + input_info = self.tensors_to_quantize[bias_info.input_name] + input_scale = np.asarray( + input_qparams["scale"], dtype=onnx.helper.tensor_dtype_to_np_dtype(input_info.data_type) + ) + + weight_quant_params = self.initializer_quant_params[bias_info.weight_name] + weight_quant_type = weight_quant_params["quant_type"] + if weight_quant_type not in (onnx.TensorProto.INT8, onnx.TensorProto.INT16): + continue + + weight_zero_point: np.ndarray = weight_quant_params["zero_point"] + if weight_zero_point.any(): + # Skip if zero_point(s) are not all zero (i.e., symmetric quant) + continue + + weight_scale: np.ndarray = weight_quant_params["scale"] + is_per_channel = weight_quant_params.get("axis", None) is not None + + # Get adjusted weight scales. + did_update_weight_scale, new_weight_scale = self._adjust_weight_scale_for_int32_bias( + input_scale, + weight_scale, + bias_info.weight_name, + find_by_name(bias_name, self.model.initializer()), + is_per_channel, + ) + + if did_update_weight_scale: + weight_quant_params["scale"] = new_weight_scale def remove_node(self, node): self.nodes_to_remove.append(node) @@ -380,6 +556,8 @@ def quantize_model(self): self.tensor_to_its_receiving_nodes[tensor_name] = [] self.tensor_to_its_receiving_nodes[tensor_name].append(node) + self.initializer_quant_params = self._calc_initializer_quant_params() + self._adjust_weight_quant_params_for_bias_tensors() self._quantize_normal_tensors() self._quantize_sharing_param_tensors() if self.quantize_bias: @@ -475,38 +653,26 @@ def _create_qdq_nodes( ) self.model.add_nodes([qlinear_node, dequant_node]) - def _add_qdq_pair_for_initializer(self, weight_proto, tensor_type, axis=None): + def _add_qdq_nodes_for_initializer(self, weight_proto: onnx.TensorProto): + """ + Adds Q/DQ nodes for an initializer. If `self.add_qdq_pair_to_weight` is true, creates + the sequence (weight_f32 -> Q -> DQ -> ). 
Otherwise, this function quantizes the initializer + and adds the sequence (weight_quant -> DQ ->). + """ weight_name = weight_proto.name - if axis is not None: - if self.opset_version < 13: - raise ValueError("Per-Channel support with QDQ format requires onnx opset version 13 or above.") - - qtype = self.weight_qType if tensor_type is QDQQuantTensorType.WEIGHT else self.activation_qType - if qtype == onnx.onnx_pb.TensorProto.UINT8: - qtype = onnx_proto.TensorProto.INT8 - - q_weight_name, zp_name, scale_name = self.quantize_weight_per_channel( - weight_name, - # Quantization type is forced to be TensorProto.INT8. - # when the expected value would be (see below) - # self.weight_qType if tensor_type is QDQQuantTensorType.WEIGHT else self.activation_qType. - # QLinearConv expects to have a unique value for all channels. - # This code does not enforce that but it is necessarily the case when the - # quantization is symmetric (as for INT8). - qtype, - axis, - keep_float_weight=self.add_qdq_pair_to_weight, - ) - else: - q_weight_name, zp_name, scale_name = self.quantize_initializer( - weight_proto, - self.weight_qType if tensor_type is QDQQuantTensorType.WEIGHT else self.activation_qType, - keep_float_weight=self.add_qdq_pair_to_weight, - ) + if weight_name in self.quantized_value_map: + return + quant_params: QuantizationParams = self.initializer_quant_params[weight_name] + axis: int = quant_params.get("axis") + scale_zp_initializers = self._make_scale_zp_initializers(weight_name, quant_params) + q_weight_name: str | None = None weight_dequant_output = add_dequant_output_suffix(weight_name) self.model.replace_input_of_all_nodes(weight_name, weight_dequant_output) + if self.add_qdq_pair_to_weight: + # Don't actually quantize the weight. 
Instead, keep floating-point weight and create the node + # sequence (weight_f32 -> Q -> DQ -> weight_dequant) weight_quant_output = add_quant_output_suffix(weight_name) self._create_qdq_nodes( @@ -516,14 +682,26 @@ def _add_qdq_pair_for_initializer(self, weight_proto, tensor_type, axis=None): weight_quant_output, weight_dequant_output, add_dequant_suffix(weight_name), - scale_name, - zp_name, + scale_zp_initializers.scale.name, + scale_zp_initializers.zero_point.name, axis, ) else: + # Quantize the weight and create the node sequence: + # (weight_quantized -> DQ -> weight_dequant) + quant_weight = quantize_onnx_initializer( + weight_proto, + quant_params["quant_type"], + quant_params["zero_point"], + quant_params["scale"], + axis, + ) + self.model.add_initializer(quant_weight) + + q_weight_name = quant_weight.name dequant_node = onnx.helper.make_node( DEQUANT_OP_NAME, - [q_weight_name, scale_name, zp_name], + [quant_weight.name, scale_zp_initializers.scale.name, scale_zp_initializers.zero_point.name], [weight_dequant_output], add_dequant_suffix(weight_name), axis=axis, @@ -531,6 +709,17 @@ def _add_qdq_pair_for_initializer(self, weight_proto, tensor_type, axis=None): ) self.model.add_node(dequant_node) + # Log entry for this quantized weight + quantized_value = QuantizedValue( + weight_name, + q_weight_name, + scale_zp_initializers.scale.name, + scale_zp_initializers.zero_point.name, + QuantizedValueType.Initializer, + axis=axis, + ) + self.quantized_value_map[weight_name] = QDQTensorQuantizedValue(quantized_value, None, None) + def _add_qdq_pair_for_activation(self, tensor_name, scale_name, zp_name, data_type=None): if ( self.dedicated_qdq_pair @@ -767,7 +956,7 @@ def _quantize_normal_tensors(self): # Quantize the input initializer = find_by_name(tensor_name, self.model.initializer()) if initializer: - self._add_qdq_pair_for_initializer(initializer, tensor_info.tensor_type, tensor_info.axis) + self._add_qdq_nodes_for_initializer(initializer) else: 
tensor_qparam_initializers = self._make_tensor_scale_zp_initializers(tensor_name) if not tensor_qparam_initializers: @@ -909,45 +1098,6 @@ def _quantize_bias_tensors(self): def is_tensor_quantized(self, tensor_name: str): return tensor_name in self.tensors_to_quantize or tensor_name in self.bias_to_quantize - def quantize_initializer( - self, - weight: onnx.TensorProto, - qType: onnx.TensorProto.DataType, - reduce_range: bool = False, - keep_float_weight: bool = False, - ) -> tuple[str, str, str]: - """ - :param weight: TensorProto initializer - :param qType: type to quantize to - :param keep_float_weight: Whether to quantize the weight. In some cases, we only want to qunatize scale and zero point. - If keep_float_weight is False, quantize the weight, or don't quantize the weight. - :return: quantized weight name, zero point name, scale name - """ - # Find if this input is already quantized - if weight.name in self.quantized_value_map: - quantized_value = self.quantized_value_map[weight.name].original - return ( - quantized_value.q_name, - quantized_value.zp_name, - quantized_value.scale_name, - ) - - q_weight_name, zp_name, scale_name = self.quantize_initializer_impl( - weight, qType, reduce_range, keep_float_weight - ) - - # Log entry for this quantized weight - quantized_value = QuantizedValue( - weight.name, - q_weight_name, - scale_name, - zp_name, - QuantizedValueType.Initializer, - None, - ) - self.quantized_value_map[weight.name] = QDQTensorQuantizedValue(quantized_value, None, None) - return q_weight_name, zp_name, scale_name - def is_tensor_per_channel( self, tensor_name: str, @@ -997,38 +1147,6 @@ def is_tensor_per_channel( return True, axis - def quantize_weight_per_channel( - self, - weight_name: str, - weight_qType: onnx.TensorProto.DataType, - channel_axis: int, - reduce_range: bool = True, - keep_float_weight: bool = False, - ) -> tuple[str, str, str]: - # Find if this input is already quantized - if weight_name in self.quantized_value_map: - 
quantized_value = self.quantized_value_map[weight_name].original - return ( - quantized_value.q_name, - quantized_value.zp_name, - quantized_value.scale_name, - ) - - q_weight_name, zp_name, scale_name = self.quantize_weight_per_channel_impl( - weight_name, weight_qType, channel_axis, reduce_range, keep_float_weight - ) - quantized_value = QuantizedValue( - weight_name, - q_weight_name, - scale_name, - zp_name, - QuantizedValueType.Initializer, - None, - ) - self.quantized_value_map[weight_name] = QDQTensorQuantizedValue(quantized_value, None, None) - - return q_weight_name, zp_name, scale_name - def quantize_bias_static(self, bias_name: str, bias_info: QDQBiasQuantInfo) -> str: """ Quantized the bias. Zero Point == 0 and Scale == Input_Scale * Weight_Scale @@ -1040,15 +1158,15 @@ def quantize_bias_static(self, bias_name: str, bias_info: QDQBiasQuantInfo) -> s # get scale for weight weight_scale_name = self.quantized_value_map[bias_info.weight_name].original.scale_name - weight_initializer = find_by_name(weight_scale_name, self.model.initializer()) - weight_scale = tensor_proto_to_array(weight_initializer) + weight_scale_initializer = find_by_name(weight_scale_name, self.model.initializer()) + weight_scale = tensor_proto_to_array(weight_scale_initializer) # get scale for input input_scale_name = ( self.quantized_value_map[bias_info.input_name].get_for_consumer(bias_info.node_name).scale_name ) - inputscale_initializer = find_by_name(input_scale_name, self.model.initializer()) - input_scale = tensor_proto_to_array(inputscale_initializer) + input_scale_initializer = find_by_name(input_scale_name, self.model.initializer()) + input_scale = tensor_proto_to_array(input_scale_initializer) ( quantized_bias_name, @@ -1074,7 +1192,7 @@ def quantize_bias_static(self, bias_name: str, bias_info: QDQBiasQuantInfo) -> s return quantized_bias_name def _make_scale_zp_initializers( - self, param_name: str, params: QuantizationParams, init_name_suffix: str = "" + self, param_name: 
str, quant_params: QuantizationParams, init_name_suffix: str = "" ) -> QDQScaleZpInitializers: """ Creates and returns scale and zero-point initializers for the given quantization params. The initializers are @@ -1082,31 +1200,31 @@ def _make_scale_zp_initializers( - {param_name}_zero_point{init_name_suffix} - {param_name}_scale{init_name_suffix} """ - zero_point_values = np.array([params["zero_point"]]) - if not hasattr(params["scale"], "dtype") or params["scale"].dtype not in (np.float32, np.float16): - raise ValueError(f"Unexpected type {type(params['scale'])} and param_name={param_name!r}") - scale_values = np.array([params["scale"]]) - assert scale_values.dtype != np.float64 - zero_point_type = params.data.get("quant_type", self.activation_qType) - - zero_point_shape = [] + zero_point = quant_params["zero_point"] + scale = quant_params["scale"] + zero_point_type = quant_params["quant_type"] + axis: int | None = quant_params.get("axis") + assert (axis is not None and len(scale.shape) == 1) or ( + axis is None and len(scale.shape) == 0 + ), "Wrong scale/zp shapes" + assert len(scale.shape) == len(zero_point.shape), "Scale and zero-point must have the same rank" + zero_point_name = param_name + "_zero_point" + init_name_suffix - scale_shape = [] scale_name = param_name + "_scale" + init_name_suffix # Add initializers to model init_zp = onnx.helper.make_tensor( - zero_point_name, zero_point_type, zero_point_shape, zero_point_values.ravel().tolist() + zero_point_name, zero_point_type, zero_point.shape, zero_point.ravel().tolist() ) self.model.add_initializer(init_zp) - if scale_values.dtype == np.float32: + if scale.dtype == np.float32: scale_type = onnx_proto.TensorProto.FLOAT - elif scale_values.dtype == np.float16: + elif scale.dtype == np.float16: scale_type = onnx_proto.TensorProto.FLOAT16 else: - raise ValueError(f"Unexpected dtype={scale_values.dtype} for param_name={param_name!r}") - init_scale = onnx.helper.make_tensor(scale_name, scale_type, scale_shape, 
scale_values.reshape((-1,)).tolist()) + raise ValueError(f"Unexpected dtype={scale.dtype} for param_name={param_name!r}") + init_scale = onnx.helper.make_tensor(scale_name, scale_type, scale.shape, scale.ravel().tolist()) self.model.add_initializer(init_scale) return QDQScaleZpInitializers(init_scale, init_zp) @@ -1155,7 +1273,7 @@ def calc_quant_params(self, tensor_data: TensorData, quant_overrides: dict[str, qmin, qmax = get_qmin_qmax_for_qType(quant_type, reduce_range=reduce_range, symmetric=symmetric) zero, scale = compute_scale_zp(rmin, rmax, qmin, qmax, symmetric, self.min_real_range) - return QuantizationParams(zero_point=zero, scale=scale, quant_type=quant_type) + return QuantizationParams(zero_point=zero.squeeze(), scale=scale.squeeze(), quant_type=quant_type) def calc_graph_quant_params(self) -> dict[str, QDQTensorQuantParams]: """ @@ -1185,3 +1303,127 @@ def calc_graph_quant_params(self) -> dict[str, QDQTensorQuantParams]: quantization_params[tensor_name] = QDQTensorQuantParams(original, converted, converted_recv_nodes) return quantization_params + + def _calc_initializer_quant_params(self) -> dict[str, QuantizationParams]: + """ + Returns quantization parameters (scale/zero_point/quant_type) for all initializers. + """ + + quantization_params: dict[str, QuantizationParams] = {} + for tensor_name, tensor_info in self.tensors_to_quantize.items(): + initializer = find_by_name(tensor_name, self.model.initializer()) + if not initializer: + continue + + initializer_data = tensor_proto_to_array(initializer) + initializer_rank = len(initializer_data.shape) + + # initializers for elementwise ops use the quant_type for activations. + is_weight = tensor_info.tensor_type is QDQQuantTensorType.WEIGHT + quant_type = self.weight_qType if is_weight else self.activation_qType + + # Try to get scale/zp directly from user's overrides and avoid computation. 
+ if self.tensor_quant_overrides.overrides_scale_zp(tensor_name): + overrides = self.tensor_quant_overrides[tensor_name] + if "quant_type" in overrides[0]: + quant_type = overrides[0]["quant_type"].tensor_type + + zp_dtype = ONNX_TYPE_TO_NP_TYPE[quant_type] + is_per_channel = "axis" in overrides[0] + if not is_per_channel: + quantization_params[tensor_name] = QuantizationParams( + zero_point=np.array(overrides[0]["zero_point"], dtype=zp_dtype), + scale=np.array(overrides[0]["scale"], initializer_data.dtype), + quant_type=quant_type, + ) + else: + zero_points_list = [] + scales_list = [] + for chan_overrides in overrides: + zero_points_list.append(np.array(chan_overrides["zero_point"], zp_dtype)) + scales_list.append(np.array(chan_overrides["scale"], dtype=initializer_data.dtype)) + + channel_axis = overrides[0]["axis"] + is_axis_valid, norm_channel_axis = normalize_axis(channel_axis, initializer_rank) + if not is_axis_valid: + raise ValueError( + f"Weight {initializer.name} has a per-channel axis with value {channel_axis} that is " + f"out-of-bounds for rank {initializer_rank}" + ) + + quantization_params[tensor_name] = QuantizationParams( + zero_point=np.array(zero_points_list), + scale=np.array(scales_list), + quant_type=quant_type, + axis=norm_channel_axis, + ) + + continue + + # Compute scale/zp normally. User's overrides may still override parameters + # used to compute the scale/zp (e.g., rmin, rmax, symmetric, etc.) + overrides = self.tensor_quant_overrides.get(tensor_name, [{}]) + if "quant_type" in overrides[0]: + quant_type = overrides[0]["quant_type"].tensor_type + + channel_axis = overrides[0].get("axis", tensor_info.axis) + is_per_channel = channel_axis is not None + + # Note: always quantize per-channel initializers as symmetric because QLinear* ops require the + # same zero-point in every channel, which is necessarily the case for symmetric quantization. 
+ is_symmetric_default = is_per_channel or ( + self.is_weight_symmetric(quant_type) if is_weight else self.is_activation_symmetric + ) + is_symmetric = overrides[0].get("symmetric", is_symmetric_default) + reduce_range = overrides[0].get("reduce_range", self.reduce_range) + zero_point: np.ndarray | None = None + scale: np.ndarray | None = None + + if not is_per_channel: + zero_point, scale = compute_data_quant_params( + initializer_data.flatten(), + quant_type, + is_symmetric, + reduce_range=reduce_range, + min_real_range=self.min_real_range, + rmin_override=overrides[0].get("rmin"), + rmax_override=overrides[0].get("rmax"), + ) + else: + is_axis_valid, norm_channel_axis = normalize_axis(channel_axis, initializer_rank) + if not is_axis_valid: + raise ValueError( + f"Weight {initializer.name} has a per-channel axis with value {channel_axis} that is " + f"out-of-bounds for rank {initializer_rank}" + ) + + channel_axis = norm_channel_axis + channel_count = initializer_data.shape[channel_axis] + zero_points_list = [] + scales_list = [] + for i in range(channel_count): + per_channel_data = initializer_data.take(i, channel_axis) + channel_overrides = overrides[i] if overrides and i < len(overrides) else {} + channel_zero_point, channel_scale = compute_data_quant_params( + per_channel_data.ravel(), + quant_type, + is_symmetric, + reduce_range=reduce_range, + min_real_range=self.min_real_range, + rmin_override=channel_overrides.get("rmin"), + rmax_override=channel_overrides.get("rmax"), + ) + zero_points_list.append(channel_zero_point) + scales_list.append(channel_scale) + + zero_point = np.asarray(zero_points_list) + scale = np.asarray(scales_list) + + quantization_params[tensor_name] = QuantizationParams( + zero_point=zero_point, + scale=scale, + quant_type=quant_type, + axis=channel_axis, + ) + + return quantization_params diff --git a/onnxruntime/python/tools/quantization/quant_utils.py b/onnxruntime/python/tools/quantization/quant_utils.py index 
54c791d0a3c58..2bf675745d093 100644 --- a/onnxruntime/python/tools/quantization/quant_utils.py +++ b/onnxruntime/python/tools/quantization/quant_utils.py @@ -33,6 +33,12 @@ int4 = None uint4 = None +try: + from onnx.reference.op_run import to_array_extended +except ImportError: + # old version of onnx. + to_array_extended = None + __producer__ = "onnx.quantize" __version__ = "0.1.0" @@ -157,7 +163,9 @@ def from_string(format): } ONNX_INT_TYPE_SYMMETRIC_RANGE = { + onnx_proto.TensorProto.UINT8: (numpy.array(0, dtype=numpy.uint8), numpy.array(254, dtype=numpy.uint8)), onnx_proto.TensorProto.INT8: (numpy.array(-127, dtype=numpy.int8), numpy.array(127, dtype=numpy.int8)), + onnx_proto.TensorProto.UINT16: (numpy.array(0, dtype=numpy.uint16), numpy.array(65534, dtype=numpy.uint16)), onnx_proto.TensorProto.INT16: (numpy.array(-32767, dtype=numpy.int16), numpy.array(32767, dtype=numpy.int16)), } @@ -230,7 +238,7 @@ def quantize_nparray(qType, arr, scale, zero_point, low=None, high=None): # which matches the python reference ONNX implementation of QuantizeLinear. # This data can be packed into 4-bit elements by using pack_bytes_to_4bit(). dtype = ONNX_TYPE_TO_NP_TYPE[qType] - (qmin, qmax) = get_qmin_qmax_for_qType(qType, reduce_range=False, symmetric=True) + qmin, qmax = get_qmin_qmax_for_qType(qType, reduce_range=False, symmetric=False) cliplow = max(qmin, low) if low is not None else qmin cliphigh = min(qmax, high) if high is not None else qmax @@ -270,7 +278,7 @@ def compute_scale_zp(rmin, rmax, qmin, qmax, symmetric=False, min_real_range=Non # Ensure a minimum float-point range if specified. 
if min_real_range is not None: - rmax = max(rmax, rmin + min_real_range) + rmax = max(rmax, rmin + numpy.asarray(min_real_range, dtype=rmin.dtype)) if symmetric: absmax = numpy.maximum(numpy.abs(rmin), numpy.abs(rmax)) @@ -339,13 +347,75 @@ def compute_scale_zp_float8(element_type, std): return [zero, scale] +def compute_data_quant_params( + data: numpy.ndarray, + quant_type: onnx.TensorProto.DataType, + symmetric: bool, + reduce_range: bool = False, + min_real_range: float | None = None, + rmin_override: float | None = None, + rmax_override: float | None = None, +) -> tuple[numpy.ndarray, numpy.ndarray]: + """ + Returns the zero_point and scale for the given data. + + :param data: The data for which to compute quantization parameters. + :param quant_type: The quantization data type. + :param symmetric: whether symmetric quantization is used or not. + :parameter reduce_range: True if the quantization range should be reduced. Defaults to False. + :parameter min_real_range: Minimum floating-point range (i.e., rmax - rmin) to enforce. Defaults to None. + :parameter rmin_override: The value of rmin to use if not None. Otherwise, uses min(data). + :parameter rmax_override: The value of rmax to use if not None. Otherwise, uses max(data). 
+ :return: zero point and scale + """ + if not isinstance(data, numpy.ndarray): + raise TypeError(f"Weight must be given as an array not {type(data)}.") + if rmin_override is not None: + rmin = rmin_override + else: + rmin = data.min() if len(data) else 0.0 + + if rmax_override is not None: + rmax = rmax_override + else: + rmax = data.max() if len(data) else 0.0 + + rmin = numpy.array(rmin, dtype=data.dtype) + rmax = numpy.array(rmax, dtype=data.dtype) + scale = numpy.array(1.0, dtype=data.dtype) + + if quant_type == TensorProto.FLOAT8E4M3FN: + if reduce_range: + raise RuntimeError("Unsupported option reduce_range=True for float 8.") + std = numpy.std(data) + zero_point, scale = compute_scale_zp_float8(quant_type, std) + return _check_type(zero_point, scale, zero_point_index=0) + + if quant_type in ( + TensorProto.INT8, + TensorProto.UINT8, + TensorProto.INT16, + TensorProto.UINT16, + TensorProto.INT4, + TensorProto.UINT4, + ): + qmin, qmax = get_qmin_qmax_for_qType(quant_type, reduce_range, symmetric=symmetric) + if len(data): + zero_point, scale = compute_scale_zp(rmin, rmax, qmin, qmax, symmetric, min_real_range) + else: + zero_point = numpy.array(0, dtype=qmin.dtype) + return _check_type(zero_point, scale, zero_point_index=0) + + raise ValueError(f"Unexpected value for quant_type={quant_type}.") + + def quantize_data( data, qType, symmetric, reduce_range=False, min_real_range=None, rmin_override=None, rmax_override=None -): +) -> tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray]: """ :param data: data to quantize - :param qType: data type to quantize to. Supported types UINT8 and INT8 - :param symmetric: whether symmetric quantization is used or not. This is applied to INT8. + :param qType: data type to quantize to. + :param symmetric: whether symmetric quantization is used or not. :parameter reduce_range: True if the quantization range should be reduced. Defaults to False. :parameter min_real_range: Minimum floating-point range (i.e., rmax - rmin) to enforce. 
Defaults to None. :parameter rmin_override: The value of rmin to use if not None. Otherwise, uses min(data). @@ -367,28 +437,16 @@ def quantize_data( - *S*: scale - *z*: zero point """ - if not isinstance(data, numpy.ndarray): - raise TypeError(f"Weight must be given as an array not {type(data)}.") - if rmin_override is not None: - rmin = rmin_override - else: - rmin = data.min() if len(data) else 0.0 - - if rmax_override is not None: - rmax = rmax_override - else: - rmax = data.max() if len(data) else 0.0 - - rmin = numpy.array(rmin, dtype=data.dtype) - rmax = numpy.array(rmax, dtype=data.dtype) - zero_point = 0 - scale = numpy.array(1.0, dtype=data.dtype) - + zero_point, scale = compute_data_quant_params( + data, + qType, + symmetric, + reduce_range, + min_real_range, + rmin_override, + rmax_override, + ) if qType == TensorProto.FLOAT8E4M3FN: - if reduce_range: - raise RuntimeError("Unsupported option reduce_range=True for float 8.") - std = numpy.std(data) - zero_point, scale = compute_scale_zp_float8(qType, std) quantized_data = quantize_nparray(qType, data, scale, zero_point) if any((quantized_data.astype(numpy.uint8).ravel() & 127) == 127): np_data = numpy.asarray(data) @@ -396,7 +454,7 @@ def quantize_data( f"One of the quantized value is NaN data in [{np_data.min()}, {np_data.max()}], " f"quantized_data in [{quantized_data.min()}, {quantized_data.max()}]." 
) - return _check_type(rmin, rmax, zero_point, scale, quantized_data, zero_point_index=2) + return zero_point, scale, quantized_data if qType in ( TensorProto.INT8, @@ -406,15 +464,91 @@ def quantize_data( TensorProto.INT4, TensorProto.UINT4, ): - if len(data): - qmin, qmax = get_qmin_qmax_for_qType(qType, reduce_range, symmetric=symmetric) - zero_point, scale = compute_scale_zp(rmin, rmax, qmin, qmax, symmetric, min_real_range) quantized_data = quantize_nparray(qType, data, scale, zero_point) - return _check_type(rmin, rmax, zero_point, scale, quantized_data, zero_point_index=2) + return zero_point, scale, quantized_data raise ValueError(f"Unexpected value for qType={qType}.") +def quantize_onnx_initializer( + weight: onnx.TensorProto, + quant_type: onnx.TensorProto.DataType, + zero_point: numpy.ndarray, + scale: numpy.ndarray, + axis: int | None = None, + quant_weight_name: str | None = None, +) -> onnx.TensorProto: + """ + Returns a quantized version of the given ONNX initializer. + + :param weight: The ONNX initializer to quantize. + :param quant_type: The final quantized data type. + :param zero_point: The zero-point value to use for quantization. + :param scale: The scale value to use for quantization. + :param axis: The quantization axis if quantizing per-channel. Defaults to None. + :param quant_weight_name: The name of the quantized initializer. + If not specified, the quantized name is generated. + :return: The quantized ONNX initializer. 
+ """ + weight_data = tensor_proto_to_array(weight) + q_weight_data: numpy.ndarray | None = None + + if axis is None: # Per-tensor quantization + q_weight_data = quantize_nparray(quant_type, weight_data.ravel(), scale, zero_point) + else: # Per-channel quantization + channel_count = weight_data.shape[axis] + channel_dims = list(weight_data.shape) # deep copy + channel_dims[axis] = 1 # only one per channel for reshape + quantized_channel_data_list = [] + + for i in range(channel_count): + channel_data = weight_data.take(i, axis) + channel_scale = scale[i] + channel_zero_point = zero_point[i] + quantized_channel_data = quantize_nparray( + quant_type, channel_data.ravel(), channel_scale, channel_zero_point + ) + quantized_channel_data_list.append(numpy.asarray(quantized_channel_data).reshape(channel_dims)) + + q_weight_data = numpy.concatenate(quantized_channel_data_list, axis) + + q_weight_name = quant_weight_name if quant_weight_name else f"{weight.name}{TENSOR_NAME_QUANT_SUFFIX}" + + if quant_type == onnx.TensorProto.FLOAT8E4M3FN: + q_weight_initializer = onnx.TensorProto() + q_weight_initializer.data_type = quant_type + q_weight_initializer.dims.extend(weight.dims) + q_weight_initializer.name = q_weight_name + # Do not remove .flatten().copy() numpy is not clear about data persistence. + q_weight_initializer.raw_data = q_weight_data.flatten().copy().tobytes() + if to_array_extended is not None: + # This test should not be needed but it helped catch some issues + # with data persistence and tobytes. + check = to_array_extended(q_weight_initializer) + if check.shape != weight_data.shape or check.tobytes() != q_weight_data.tobytes(): + raise RuntimeError( + f"The initializer of shape {weight_data.shape} could not be created, expecting " + f"{q_weight_data.tobytes()[:10]}, got {check.tobytes()[:10]} and shape={weight.shape}" + f"\nraw={str(q_weight_initializer)[:200]}." 
+ ) + elif quant_type in (onnx.TensorProto.INT4, onnx.TensorProto.UINT4): + if q_weight_data.dtype not in (numpy.int8, numpy.uint8): + raise RuntimeError(f"Quantized weights for {q_weight_name} must be 8-bit before packing as 4-bit values.") + + # We do not use onnx.helper.pack_float32_to_4bit() due to performance. + # This can be the difference between a large model taking 30 minutes to quantize vs 5 minutes. + packed_data = bytes(pack_bytes_to_4bit(q_weight_data.tobytes())) + + # We only use onnx.helper.make_tensor with raw data due to bug: https://github.com/onnx/onnx/pull/6161 + q_weight_initializer = onnx.helper.make_tensor(q_weight_name, quant_type, weight.dims, packed_data, raw=True) + else: + quant_np_dtype = onnx.helper.tensor_dtype_to_np_dtype(quant_type) + q_weight_data = numpy.asarray(q_weight_data, dtype=quant_np_dtype).reshape(weight.dims) + q_weight_initializer = onnx.numpy_helper.from_array(q_weight_data, q_weight_name) + + return q_weight_initializer + + def get_qmin_qmax_for_qType(qType, reduce_range=False, symmetric=False): # noqa: N802 """ Return qmin and qmax, the minimum and maximum value representable by the given qType diff --git a/onnxruntime/python/tools/quantization/quantize.py b/onnxruntime/python/tools/quantization/quantize.py index 038cbdce92e94..f368f35955955 100644 --- a/onnxruntime/python/tools/quantization/quantize.py +++ b/onnxruntime/python/tools/quantization/quantize.py @@ -197,6 +197,9 @@ def __init__( removed if activations are asymmetrically quantized. Keeping these activations is necessary if optimizations or EP transformations will later remove QuantizeLinear/DequantizeLinear operators from the model. + QDQDisableWeightAdjustForInt32Bias = True/False: + Default is False. If true, QDQ quantizer will not adjust the weight's scale when the bias + has a scale (input_scale * weight_scale) that is too small. execution_provider : A enum indicates the Execution Provider such as: CPU, TRT, NNAPI, SNE, etc. 
Raises: ValueError: Raise ValueError if execution provider is unknown @@ -600,6 +603,9 @@ def quantize_static( removed if activations are asymmetrically quantized. Keeping these activations is necessary if optimizations or EP transformations will later remove QuantizeLinear/DequantizeLinear operators from the model. + QDQDisableWeightAdjustForInt32Bias = True/False: + Default is False. If true, QDQ quantizer will not adjust the weight's scale when the bias + has a scale (input_scale * weight_scale) that is too small. """ if activation_type == QuantType.QFLOAT8E4M3FN or weight_type == QuantType.QFLOAT8E4M3FN: if calibrate_method != CalibrationMethod.Distribution: diff --git a/onnxruntime/python/tools/quantization/tensor_quant_overrides.py b/onnxruntime/python/tools/quantization/tensor_quant_overrides.py index 219d929d22fce..fbd0cc17f5d81 100644 --- a/onnxruntime/python/tools/quantization/tensor_quant_overrides.py +++ b/onnxruntime/python/tools/quantization/tensor_quant_overrides.py @@ -78,6 +78,10 @@ def has_per_channel_overrides(self, tensor_name: str) -> bool: overrides_list = self.overrides.get(tensor_name) return overrides_list and "axis" in overrides_list[0] + def overrides_scale_zp(self, tensor_name: str) -> bool: + overrides_list = self.overrides.get(tensor_name) + return overrides_list and ("scale" in overrides_list[0]) and ("zero_point" in overrides_list[0]) + def get_per_tensor_overrides( self, tensor_name: str, diff --git a/onnxruntime/test/python/quantization/test_qdq.py b/onnxruntime/test/python/quantization/test_qdq.py index b99c11abf6d2c..24039fe7398a8 100644 --- a/onnxruntime/test/python/quantization/test_qdq.py +++ b/onnxruntime/test/python/quantization/test_qdq.py @@ -1726,5 +1726,204 @@ def test_json_serialization(self): write_calibration_table(new_calibrate_tensors_range) +class TestAdjustWeightScaleForInt32Bias(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls._tmp_model_dir = 
tempfile.TemporaryDirectory(prefix="ort.qdq.adj_int32_bias_") + + # Note: swap with the commented line if you want to see the models in local test dir. + cls._tmp_dir_path = cls._tmp_model_dir.name + # cls._tmp_dir_path = "." + + @classmethod + def tearDownClass(cls): + cls._tmp_model_dir.cleanup() + + def build_conv_test_model( + self, + input0_shape: list[int], + weight_shape: list[int], + onnx_float_type: onnx.TensorProto.DataType, + ): + np_float_type = onnx.helper.tensor_dtype_to_np_dtype(onnx_float_type) + input_0 = onnx.helper.make_tensor_value_info("input_0", onnx_float_type, input0_shape) + output_0 = onnx.helper.make_tensor_value_info("output_0", onnx_float_type, None) + + tiny_value = 1e-7 if np_float_type == np.float32 else 0.007782 + # weight_scale = 2*tiny_value / 255.0 = 7.84313725490196e-10 + + weight_data = np.full(weight_shape, tiny_value, dtype=np_float_type) + with np.nditer(weight_data, op_flags=["readwrite"]) as it: + for i, x in enumerate(it): + if i % 2 == 0: + x[...] = -x + + weight = onnx.numpy_helper.from_array(weight_data, "weight") + + # if we set input_scale to 0.05, then normally bias_scale would be + # (input_scale * weight_scale) => (0.05 * 7.84314e-10) => 3.9215686274509805e-11 + # + # If we quantize the f32 bias with this bias_scale, we get + # [5.0/bias_scale, 4.0/bias_scale] = [127500000000, 102000000000]. These quantized bias values exceed the + # range of int32. + # + # The ORT quantization tool will clamp these out-of-bounds values to int32::max(), + # which can be very inaccurate. + bias_shape = [weight_shape[0]] + bias_data = np.ones(bias_shape, dtype=np_float_type) + with np.nditer(bias_data, op_flags=["readwrite"]) as it: + for i, x in enumerate(it): + if i % 2 == 0: + x[...] = 5.0 if np_float_type == np.float32 else 1400 + else: + x[...] 
= -4.5 if np_float_type == np.float32 else -1200 + + bias = onnx.numpy_helper.from_array(bias_data, "bias") + + conv_node = onnx.helper.make_node("Conv", ["input_0", "weight", "bias"], ["output_0"], name="Conv0") + graph = onnx.helper.make_graph( + [conv_node], + "Convfloat", + [input_0], + [output_0], + initializer=[weight, bias], + ) + opset_imports = [onnx.helper.make_opsetid("", 21)] + model = onnx.helper.make_model(graph, opset_imports=opset_imports) + model = onnx.shape_inference.infer_shapes(model) + onnx.checker.check_model(model, True) + return model + + def test_adjust_weight_scale_for_int32_bias(self): + """ + Test adjustment of weight input's scale to ensure int32 bias's scale is not too small. + """ + test_configs = [ + (onnx.TensorProto.FLOAT, True), + (onnx.TensorProto.FLOAT, False), + (onnx.TensorProto.FLOAT16, True), + (onnx.TensorProto.FLOAT16, False), + ] + + for float_type, per_channel in test_configs: + with self.subTest(float_type=float_type, per_channel=per_channel): + label = f"_f{float_type}_perchannel{per_channel}" + float_model_path = os.path.join(self._tmp_dir_path, f"conv{label}.float.onnx") + qdq_model_path = os.path.join(self._tmp_dir_path, f"conv{label}.qdq.onnx") + + # Create float model with a Conv that has tiny weight values. + # This tiny weight scale would normally create a very small bias scale that will saturate + # bias's int32 range. But, the qdq_quantizer adjusts the weight's scale to ensure this doesn't happen. 
+ input0_shape = [1, 2, 4, 4] + weight_shape = [2, 2, 2, 2] + float_model = self.build_conv_test_model(input0_shape, weight_shape, float_type) + onnx.save_model(float_model, float_model_path) + + # Create a data reader + np_float_type = onnx.helper.tensor_dtype_to_np_dtype(float_type) + input0_rmin = 0.0 + input0_scale = 0.05 if float_type == onnx.TensorProto.FLOAT else 0.01 + input0_rmax = (input0_scale * 255.0) + input0_rmin + input_data_list = [ + {"input_0": np.full(input0_shape, input0_rmin, dtype=np_float_type)}, + {"input_0": np.full(input0_shape, (input0_rmax - input0_rmin) / 2.0, dtype=np_float_type)}, + {"input_0": np.full(input0_shape, input0_rmax, dtype=np_float_type)}, + ] + data_reader = TestDataFeeds(input_data_list) + + # quantize model to QDQ + quantize_static( + float_model_path, + qdq_model_path, + data_reader, + activation_type=QuantType.QUInt8, + weight_type=QuantType.QInt8, + per_channel=per_channel, + ) + + # Check correctness + data_reader.rewind() + check_model_correctness(self, float_model_path, qdq_model_path, data_reader.get_next()) + + def build_model_convs_share_bias( + self, + input0_shape: list[int], + weight_shape: list[int], + onnx_float_type: onnx.TensorProto.DataType, + ): + np_float_type = onnx.helper.tensor_dtype_to_np_dtype(onnx_float_type) + input_0 = onnx.helper.make_tensor_value_info("input_0", onnx_float_type, input0_shape) + output_0 = onnx.helper.make_tensor_value_info("output_0", onnx_float_type, None) + output_1 = onnx.helper.make_tensor_value_info("output_1", onnx_float_type, None) + + weight_0_data = np.ones(weight_shape, dtype=np_float_type) + weight_0 = onnx.numpy_helper.from_array(weight_0_data, "weight_0") + + weight_1_data = np.full(weight_shape, 0.5, dtype=np_float_type) + weight_1 = onnx.numpy_helper.from_array(weight_1_data, "weight_1") + + bias_shape = [weight_shape[0]] + bias_data = np.ones(bias_shape, dtype=np_float_type) + bias_shared = onnx.numpy_helper.from_array(bias_data, "bias_shared") + + 
conv_0_node = onnx.helper.make_node("Conv", ["input_0", "weight_0", "bias_shared"], ["output_0"], name="Conv0") + conv_1_node = onnx.helper.make_node("Conv", ["input_0", "weight_1", "bias_shared"], ["output_1"], name="Conv1") + graph = onnx.helper.make_graph( + [conv_0_node, conv_1_node], + "ConvWithSharedBiasToDup", + [input_0], + [output_0, output_1], + initializer=[weight_0, weight_1, bias_shared], + ) + opset_imports = [onnx.helper.make_opsetid("", 21)] + model = onnx.helper.make_model(graph, opset_imports=opset_imports) + model = onnx.shape_inference.infer_shapes(model) + onnx.checker.check_model(model, True) + return model + + def test_dup_shared_bias(self): + """ + Test duplicating a bias that is shared by two nodes that want to quantize their bias to int32. + """ + float_model_path = os.path.join(self._tmp_dir_path, "convs_share_bias.float.onnx") + qdq_model_path = os.path.join(self._tmp_dir_path, "convs_share_bias.qdq.onnx") + + # Create float model with a Convs that share a bias input. The QDQ quantizer should add a + # duplicate bias so that each node has its own. 
+ input0_shape = [1, 2, 4, 4] + weight_shape = [2, 2, 2, 2] + float_model = self.build_model_convs_share_bias(input0_shape, weight_shape, onnx.TensorProto.FLOAT) + onnx.save_model(float_model, float_model_path) + + # Create a data reader + input0_rmin = 0.0 + input0_scale = 0.05 + input0_rmax = (input0_scale * 255.0) + input0_rmin + input_data_list = [ + {"input_0": np.full(input0_shape, input0_rmin, dtype=np.float32)}, + {"input_0": np.full(input0_shape, (input0_rmax - input0_rmin) / 2.0, dtype=np.float32)}, + {"input_0": np.full(input0_shape, input0_rmax, dtype=np.float32)}, + ] + data_reader = TestDataFeeds(input_data_list) + + # quantize model to QDQ + quantize_static( + float_model_path, + qdq_model_path, + data_reader, + activation_type=QuantType.QUInt8, + weight_type=QuantType.QInt8, + ) + + qdq_model = onnx.load_model(qdq_model_path) + bias_names = set() + + for node in qdq_model.graph.node: + if node.op_type == "DequantizeLinear" and node.input[0].startswith("bias_shared"): + bias_names.add(node.input[0]) + + self.assertEqual(len(bias_names), 2) + + if __name__ == "__main__": unittest.main() diff --git a/onnxruntime/test/python/quantization/test_quant_util.py b/onnxruntime/test/python/quantization/test_quant_util.py index 96d841654adbd..b23d53f2a04e8 100644 --- a/onnxruntime/test/python/quantization/test_quant_util.py +++ b/onnxruntime/test/python/quantization/test_quant_util.py @@ -145,7 +145,7 @@ def test_quantize_data_4bit(self): for onnx_type, symmetric in subtest_configs: with self.subTest(onnx_type=onnx_type, symmetric=symmetric): - _, _, zero_point, scale, data_quant = quantize_data(data_float, onnx_type, symmetric) + zero_point, scale, data_quant = quantize_data(data_float, onnx_type, symmetric) is_signed = onnx_type == onnx.TensorProto.INT4 np_int_type = numpy.int8 if is_signed else numpy.uint8 qmin = numpy.array(-8 if is_signed else 0, dtype=np_int_type) diff --git a/onnxruntime/test/python/quantization/test_tensor_quant_overrides_option.py 
b/onnxruntime/test/python/quantization/test_tensor_quant_overrides_option.py index 21a772c5f56c7..41dae04f1c6ff 100644 --- a/onnxruntime/test/python/quantization/test_tensor_quant_overrides_option.py +++ b/onnxruntime/test/python/quantization/test_tensor_quant_overrides_option.py @@ -36,7 +36,7 @@ def setUp(self): self.bias = np.array([0.0, 1.0], dtype=np.float32) self.default_act_qtype = onnx.TensorProto.UINT8 self.default_wgt_qtype = onnx.TensorProto.UINT8 - self.default_wgt_qtype_per_channel = onnx.TensorProto.INT8 + self.default_wgt_qtype_per_channel = onnx.TensorProto.UINT8 self.default_bias_qtype = onnx.TensorProto.INT32 self.default_zp_scales = { @@ -49,7 +49,8 @@ def setUp(self): self.default_zp_scales_per_channel = { "INP": (0, np.float32(0.0235294122248888)), "SIG_OUT": (0, np.float32(0.003911871928721666)), - "WGT": ([0, 0], [np.float32(0.015748031437397003), np.float32(0.011811023578047752)]), + # per-channel weights are always symmetric (ie. zp = (qmin + qmax) / 2) + "WGT": ([127, 127], [np.float32(0.015748031437397003), np.float32(0.011811023578047752)]), "BIAS": ([0, 0], [np.float32(0.00006160428165458143), np.float32(0.00004620321124093607)]), "OUT": (0, np.float32(0.005075461231172085)), } @@ -420,12 +421,17 @@ def test_qdq_overrides_per_channel2(self): self.assertEqual(wgt_zp.data_type, quant_type.tensor_type) for index, (zp, scale) in enumerate(zip(wgt_zp.int32_data, wgt_sc.float_data)): - wgt_qmin, wgt_qmax = get_qmin_qmax_for_qType(wgt_zp.data_type, reduce_range=reduce_range) + wgt_qmin, wgt_qmax = get_qmin_qmax_for_qType( + wgt_zp.data_type, + symmetric=True, # per-channel is always symmetric + reduce_range=reduce_range, + ) expected_zp, expected_scale = compute_scale_zp( np.array(rmin_vals[index], dtype=np.float32), np.array(rmax_vals[index], dtype=np.float32), wgt_qmin, wgt_qmax, + symmetric=True, # per-channel is always symmetric ) self.assertEqual(zp, expected_zp) self.assertEqual(scale, np.float32(expected_scale))