Skip to content

Commit

Permalink
[Quant Tool] Introduce get_qdq_config() helper to get QDQ configurati…
Browse files Browse the repository at this point in the history
…ons (#22677)

### Description
Introduces the `get_qdq_config()` function to get a quantization
configuration for a full integer QDQ model. This function provides an
easier way of specifying commonly used options and sets convenient
defaults. Specifically:

- Instead of requiring the user to pass a dictionary of `extra_options`,
the new interface adds function parameters for common settings:
  - All calibrator settings
  - Whether activations/weights are symmetric
  - Whether to keep or fuse relu/clip into Q
  - Minimum real range for quantization
  - Dictionary of tensor quantization overrides.
- Automatically scans the input floating-point model and fills out the
operator types to quantize. Otherwise, only a limited number of operator
types would be quantized by default.
- Detects if the input model uses external data. If so, ensures that the
generated QDQ model also uses external data.
- Detects if the model will use newly introduced quantization types
(int4/int16) with an older opset. If so, forces the use of the
`com.microsoft` domain for Q/DQ ops, which support all types.
- Automatically enables the "extra option" called
`ForceQuantizeNoInputCheck` to ensure data movement operators (e.g.,
Transpose) are always quantized.
- User can pass a function to indicate which nodes to exclude from
quantization.
- The user can still pass their own `extra_options` to override any of
the above if necessary.
 
```python
from onnxruntime.quantization import get_int_qdq_config, quantize # , ...

# Get QDQ configuration
qdq_config = get_int_qdq_config(
    float_model,
    data_reader,
    calibrate_method=CalibrationMethod.Percentile,
    calibrate_args={"percentile": 99.98},  # Converted to extra_options
    activation_type=QuantType.QUInt8,
    weight_type=QuantType.QInt8,
    per_channel=True,
    nodes_to_exclude=["Mul"], # Could also be a function. Ex: `lambda model, node: node.op_type == "Softmax"`

    # Other options converted to extra_options:
    min_real_range=0.0001,
    keep_removable_activations=True,
    activation_symmetric=True,
    weight_symmetric=True,
)

# Quantize model
quantize(float_model_path, qdq_model_path, qdq_config)
```
### Motivation and Context
Need a version of `get_qnn_qdq_config()` that is not EP-specific.
  • Loading branch information
adrianlizarraga authored Nov 6, 2024
1 parent 72186bb commit 2c1b17c
Show file tree
Hide file tree
Showing 4 changed files with 448 additions and 7 deletions.
1 change: 1 addition & 0 deletions onnxruntime/python/tools/quantization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .quantize import DynamicQuantConfig # noqa: F401
from .quantize import QuantizationMode # noqa: F401
from .quantize import StaticQuantConfig # noqa: F401
from .quantize import get_qdq_config # noqa: F401
from .quantize import quantize # noqa: F401
from .quantize import quantize_dynamic # noqa: F401
from .quantize import quantize_static # noqa: F401
Expand Down
1 change: 1 addition & 0 deletions onnxruntime/python/tools/quantization/quant_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
DEQUANT_OP_NAME = "DequantizeLinear"
DEQUANT_OUTPUT_SUFFIX = "_DequantizeLinear_Output"
TENSOR_NAME_QUANT_SUFFIX = "_quantized"
MODEL_SIZE_THRESHOLD = 2147483648 # Quant model should use external data if >= 2GB

FLOAT8_DISTRIBUTIONS = {}

Expand Down
177 changes: 170 additions & 7 deletions onnxruntime/python/tools/quantization/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,21 @@
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
from __future__ import annotations

import copy
import logging
import tempfile
from pathlib import Path
from typing import Union
from typing import Any, Callable

import onnx

from .calibrate import CalibrationDataReader, CalibrationMethod, TensorsData, create_calibrator
from .onnx_quantizer import ONNXQuantizer
from .qdq_quantizer import QDQQuantizer
from .quant_utils import (
MODEL_SIZE_THRESHOLD,
QuantFormat,
QuantizationMode,
QuantType,
Expand All @@ -22,6 +26,7 @@
save_and_reload_model_with_shape_infer,
)
from .registry import IntegerOpsRegistry, QDQRegistry, QLinearOpsRegistry
from .tensor_quant_overrides import TensorQuantOverridesHelper


class QuantConfig:
Expand Down Expand Up @@ -213,6 +218,163 @@ def __init__(
self.extra_options = extra_options or {}


def get_qdq_config(
model_input: str | Path | onnx.ModelProto,
calibration_data_reader: CalibrationDataReader,
calibrate_method=CalibrationMethod.MinMax,
calibrate_args: dict[str, Any] | None = None,
activation_type=QuantType.QUInt8,
weight_type=QuantType.QInt8,
activation_symmetric: bool = False,
weight_symmetric: bool | None = None,
per_channel: bool = False,
keep_removable_activations: bool = False,
min_real_range: float | None = None,
tensor_quant_overrides: dict[str, list[dict[str, Any]]] | None = None,
nodes_to_exclude: list[str] | Callable[[onnx.ModelProto, onnx.NodeProto], bool] | None = None,
extra_options: dict | None = None,
) -> StaticQuantConfig:
"""
Returns a configuration suitable that quantizes the entire model to integer precision.
Params:
model_input: Path to the input model file or ModelProto.
calibration_data_reader: Calibration data reader.
calibrate_methode: The calibration method. Defaults to MinMax.
activation_type: The default activation quantization type. Defaults to QUInt8.
weight_type: The default weight quantization type. Defaults to QUInt8.
activation_symmetric: True if activations should be quantized symmetrically (i.e, rmax == -rmin) by default.
Defaults to false. For int8 and int16, this results in zero-point values of 0. For uint8 and uint16,
the zero-point values are 127 and 32,767, respectively.
weight_symmetric: True if weights should be quantized symmetrically (i.e., rmax == -rmin) by default.
Defaults to None. If set to None, weight_symmetric is assumed true if a weight's quant type is a signed int.
per_channel: Global option that determines if a fixed set of operator types should be quantized per-channel.
Defaults to false. Alternatively, use the tensor-level `tensor_quant_overrides` to select individual operators
and their quantization axes.
keep_removable_activations: Defaults to false. If true, "removable" activations (e.g., Clip or Relu) will not
be removed, and will be explicitly represented in the QDQ model. If false, these activations
are automatically removed if activations are asymmetrically quantized. Keeping these activations
is necessary if optimizations or EP transformations will later remove
QuantizeLinear/DequantizeLinear operators from the model.
min_real_range: Default is None. If set to a floating-point value, the calculation of the quantization parameters
(i.e., scale and zero point) will enforce a minimum range between rmin and rmax. If (rmax - rmin)
is less than the specified minimum range, rmax will be set to rmin + min_real_range.
tensor_quant_overrides: tensor-level quantization overrides. Defaults to None.
The key is a tensor name and the value is a list of dictionaries. For per-tensor quantization, the list
contains a single dictionary. For per-channel quantization, the list contains either a dictionary for
each channel in the tensor or a single dictionary that is assumed to apply to all channels. An 'axis'
key must be present in the first dictionary for per-channel quantization.
Each dictionary contains optional overrides with the following keys and values.
'quant_type' = QuantType : The tensor's quantization data type.
'axis' = Int : The per-channel axis. Must be present for per-channel weights.
'scale' = Float : The scale value to use. Must also specify `zero_point` if set.
'zero_point' = Int : The zero-point value to use. Must also specify `scale` is set.
'symmetric' = Bool : If the tensor should use symmetric quantization. Invalid if also
set `scale` or `zero_point`.
'reduce_range' = Bool : If the quantization range should be reduced. Invalid if also
set `scale` or `zero_point`. Only valid for initializers.
'rmax' = Float : Override the maximum real tensor value in calibration data.
Invalid if also set `scale` or `zero_point`.
'rmin' = Float : Override the minimum real tensor value in calibration data.
Invalid if also set `scale` or `zero_point`.
'convert' = Dict : A nested dictionary with the same keys for an activation
tensor that should be converted to another quantization type.
'convert["recv_nodes"] = Set : Set of node names that consume the converted activation,
other nodes get the original type. If not specified,
assume all consumer nodes get the converted type.
nodes_to_exclude: List of nodes names to exclude from quantization. Alternatively, can provide a function that
accepts an onnx.ModelProto and onnx.NodeProto as arguments and returns true if the give onnx.NodeProto
should be excluded from quantization.
extra_options: Additional options specified as string key/value pairs. Refer to the documentation for
`quantize_static` for valid keys and values.
Returns:
A StaticQuantConfig object
"""
q16_types = {QuantType.QInt16, QuantType.QUInt16}
q4_types = {QuantType.QInt4, QuantType.QUInt4}
op_types_to_exclude = {"Cast", "DequantizeLinear", "QuantizeLinear"}

model = (
model_input
if isinstance(model_input, onnx.ModelProto)
else onnx.load_model(model_input, load_external_data=False)
)

op_types = set()
model_has_external_data = False
overrides_helper = TensorQuantOverridesHelper(
copy.deepcopy(tensor_quant_overrides) if tensor_quant_overrides else {}
)

# check if the model has external data.
for initializer in model.graph.initializer:
if onnx.external_data_helper.uses_external_data(initializer):
model_has_external_data = True

final_nodes_to_exclude = []
if nodes_to_exclude is not None and isinstance(nodes_to_exclude, list):
final_nodes_to_exclude.extend(nodes_to_exclude)

# Iterate through nodes to get all operator types in the model and
# call user's function to filter out nodes from quantization.
for node in model.graph.node:
op_types.add(node.op_type)
if nodes_to_exclude is not None and callable(nodes_to_exclude):
if nodes_to_exclude(model, node):
final_nodes_to_exclude.append(node.name)

final_extra_options = {
"MinimumRealRange": min_real_range,
"QDQKeepRemovableActivations": keep_removable_activations,
"ActivationSymmetric": activation_symmetric,
"WeightSymmetric": weight_symmetric,
"ForceQuantizeNoInputCheck": True,
"TensorQuantOverrides": overrides_helper.get_dict(),
}

# Pass along known calibration options
if calibrate_args:
calib_extra_options_keys = [
("symmetric", "CalibTensorRangeSymmetric"),
("moving_average", "CalibMovingAverage"),
("averaging_constant", "CalibMovingAverageConstant"),
("max_intermediate_outputs", "CalibMaxIntermediateOutputs"),
("percentile", "CalibPercentile"),
]
calib_extra_options = {
key: calibrate_args.get(name) for (name, key) in calib_extra_options_keys if name in calibrate_args
}
final_extra_options.update(calib_extra_options)

# ONNX opset < 21 does not support 16-bit quantization, so must use 'com.microsoft' domain
# on Q/DQ operators if using 16-bit or 4-bit quantization.
onnx_opset = next(x for x in model.opset_import if x.domain == "" or x.domain == "ai.onnx")
if onnx_opset.version < 21:
opset21_types = q16_types.union(q4_types)
overrides_have_opset21_types = any(t in opset21_types for t in overrides_helper.get_quant_types())
if activation_type in opset21_types or weight_type in opset21_types or overrides_have_opset21_types:
final_extra_options["UseQDQContribOps"] = True

# Allow user's extra_options to override our final_extra_options.
if extra_options:
final_extra_options.update(extra_options)

return StaticQuantConfig(
calibration_data_reader,
calibrate_method=calibrate_method,
quant_format=QuantFormat.QDQ,
activation_type=activation_type,
weight_type=weight_type,
op_types_to_quantize=list(op_types.difference(op_types_to_exclude)),
nodes_to_exclude=final_nodes_to_exclude,
per_channel=per_channel,
use_external_data_format=(model_has_external_data or model.ByteSize() >= MODEL_SIZE_THRESHOLD),
extra_options=final_extra_options,
)


class DynamicQuantConfig(QuantConfig):
def __init__(
self,
Expand Down Expand Up @@ -290,8 +452,8 @@ def check_static_quant_arguments(quant_format: QuantFormat, activation_type: Qua


def quantize_static(
model_input: Union[str, Path, onnx.ModelProto],
model_output: Union[str, Path],
model_input: str | Path | onnx.ModelProto,
model_output: str | Path,
calibration_data_reader: CalibrationDataReader,
quant_format=QuantFormat.QDQ,
op_types_to_quantize=None,
Expand Down Expand Up @@ -473,6 +635,7 @@ def quantize_static(
("CalibMovingAverage", "moving_average"),
("CalibMovingAverageConstant", "averaging_constant"),
("CalibMaxIntermediateOutputs", "max_intermediate_outputs"),
("CalibPercentile", "percentile"),
]
calib_extra_options = {
key: extra_options.get(name) for (name, key) in calib_extra_options_keys if name in extra_options
Expand Down Expand Up @@ -590,8 +753,8 @@ def inc_dataloader():


def quantize_dynamic(
model_input: Union[str, Path, onnx.ModelProto],
model_output: Union[str, Path],
model_input: str | Path | onnx.ModelProto,
model_output: str | Path,
op_types_to_quantize=None,
per_channel=False,
reduce_range=False,
Expand Down Expand Up @@ -690,8 +853,8 @@ def quantize_dynamic(


def quantize(
model_input: Union[str, Path, onnx.ModelProto],
model_output: Union[str, Path],
model_input: str | Path | onnx.ModelProto,
model_output: str | Path,
quant_config: QuantConfig,
):
"""Quantize a model with QuantConfig.
Expand Down
Loading

0 comments on commit 2c1b17c

Please sign in to comment.