From 2c1b17ce98a3505a62adf8533abaf2d186c47d8e Mon Sep 17 00:00:00 2001
From: Adrian Lizarraga
Date: Wed, 6 Nov 2024 10:27:02 -0800
Subject: [PATCH] [Quant Tool] Introduce get_qdq_config() helper to get QDQ
 configurations (#22677)

### Description
Introduces the `get_qdq_config()` function to get a quantization configuration for a full integer QDQ model. This function provides an easier way of specifying commonly used options and sets convenient defaults. Specifically:

- Instead of requiring the user to pass a dictionary of `extra_options`, the new interface adds function parameters for common settings:
  - All calibrator settings
  - Whether activations/weights are symmetric
  - Whether to keep or fuse relu/clip into Q
  - Minimum real range for quantization
  - Dictionary of tensor quantization overrides.
- Automatically scans the input floating-point model and fills out the operator types to quantize. Otherwise, only a limited number of operator types would be quantized by default.
- Detects if the input model uses external data. If so, ensures that the generated QDQ model also uses external data.
- Detects if the model will use newly introduced quantization types (int4/int16) with an older opset. If so, forces the use of the `com.microsoft` domain for Q/DQ ops, which support all types.
- Automatically enables the "extra option" called `ForceQuantizeNoInputCheck` to ensure data movement operators (e.g., Transpose) are always quantized.
- The user can pass a function to indicate which nodes to exclude from quantization.
- The user can still pass their own `extra_options` to override any of the above if necessary.

```python
from onnxruntime.quantization import get_qdq_config, quantize  # , ...

# Get QDQ configuration
qdq_config = get_qdq_config(
    float_model,
    data_reader,
    calibrate_method=CalibrationMethod.Percentile,
    calibrate_args={"percentile": 99.98},  # Converted to extra_options
    activation_type=QuantType.QUInt8,
    weight_type=QuantType.QInt8,
    per_channel=True,
    nodes_to_exclude=["Mul"],  # Could also be a function. Ex: `lambda model, node: node.op_type == "Softmax"`
    # Other options converted to extra_options:
    min_real_range=0.0001,
    keep_removable_activations=True,
    activation_symmetric=True,
    weight_symmetric=True,
)

# Quantize model
quantize(float_model_path, qdq_model_path, qdq_config)
```

### Motivation and Context
Need a version of `get_qnn_qdq_config()` that is not EP-specific.
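As a second, purely illustrative sketch, the helper also accepts a callable for `nodes_to_exclude` and tensor-level overrides. The model paths and `data_reader` below are placeholders, not part of this change:

```python
from onnxruntime.quantization import QuantType, get_qdq_config, quantize

# Sketch only: "my_model.onnx", "my_model.qdq.onnx", and `data_reader` are placeholders.
qdq_config = get_qdq_config(
    "my_model.onnx",
    data_reader,
    activation_type=QuantType.QUInt8,
    weight_type=QuantType.QInt8,
    # Exclude every Softmax node from quantization via a callable.
    nodes_to_exclude=lambda model, node: node.op_type == "Softmax",
    # Override a single initializer (named "weight" here) to 4-bit.
    tensor_quant_overrides={"weight": [{"quant_type": QuantType.QInt4}]},
)

quantize("my_model.onnx", "my_model.qdq.onnx", qdq_config)
```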
---
 .../python/tools/quantization/__init__.py    |   1 +
 .../python/tools/quantization/quant_utils.py |   1 +
 .../python/tools/quantization/quantize.py    | 177 ++++++++++-
 .../quantization/test_get_qdq_config.py      | 276 ++++++++++++++++++
 4 files changed, 448 insertions(+), 7 deletions(-)
 create mode 100644 onnxruntime/test/python/quantization/test_get_qdq_config.py

diff --git a/onnxruntime/python/tools/quantization/__init__.py b/onnxruntime/python/tools/quantization/__init__.py
index 9d397499d45a4..712e15a6a1ca9 100644
--- a/onnxruntime/python/tools/quantization/__init__.py
+++ b/onnxruntime/python/tools/quantization/__init__.py
@@ -10,6 +10,7 @@
 from .quantize import DynamicQuantConfig  # noqa: F401
 from .quantize import QuantizationMode  # noqa: F401
 from .quantize import StaticQuantConfig  # noqa: F401
+from .quantize import get_qdq_config  # noqa: F401
 from .quantize import quantize  # noqa: F401
 from .quantize import quantize_dynamic  # noqa: F401
 from .quantize import quantize_static  # noqa: F401
diff --git a/onnxruntime/python/tools/quantization/quant_utils.py b/onnxruntime/python/tools/quantization/quant_utils.py
index 9228ad33130f2..54c791d0a3c58 100644
--- a/onnxruntime/python/tools/quantization/quant_utils.py
+++ b/onnxruntime/python/tools/quantization/quant_utils.py
@@ -43,6 +43,7 @@
 DEQUANT_OP_NAME = "DequantizeLinear"
 DEQUANT_OUTPUT_SUFFIX = "_DequantizeLinear_Output"
 TENSOR_NAME_QUANT_SUFFIX = "_quantized"
+MODEL_SIZE_THRESHOLD = 2147483648  # Quant model should use external data if >= 2GB

 FLOAT8_DISTRIBUTIONS = {}

diff --git a/onnxruntime/python/tools/quantization/quantize.py b/onnxruntime/python/tools/quantization/quantize.py
index 745344dc01fcb..038cbdce92e94 100644
--- a/onnxruntime/python/tools/quantization/quantize.py
+++ b/onnxruntime/python/tools/quantization/quantize.py
@@ -3,10 +3,13 @@
 # Licensed under the MIT License. See License.txt in the project root for
 # license information.
 # --------------------------------------------------------------------------
+from __future__ import annotations
+
+import copy
 import logging
 import tempfile
 from pathlib import Path
-from typing import Union
+from typing import Any, Callable

 import onnx

@@ -14,6 +17,7 @@
 from .onnx_quantizer import ONNXQuantizer
 from .qdq_quantizer import QDQQuantizer
 from .quant_utils import (
+    MODEL_SIZE_THRESHOLD,
     QuantFormat,
     QuantizationMode,
     QuantType,
@@ -22,6 +26,7 @@
     save_and_reload_model_with_shape_infer,
 )
 from .registry import IntegerOpsRegistry, QDQRegistry, QLinearOpsRegistry
+from .tensor_quant_overrides import TensorQuantOverridesHelper


 class QuantConfig:
@@ -213,6 +218,163 @@ def __init__(
         self.extra_options = extra_options or {}


+def get_qdq_config(
+    model_input: str | Path | onnx.ModelProto,
+    calibration_data_reader: CalibrationDataReader,
+    calibrate_method=CalibrationMethod.MinMax,
+    calibrate_args: dict[str, Any] | None = None,
+    activation_type=QuantType.QUInt8,
+    weight_type=QuantType.QInt8,
+    activation_symmetric: bool = False,
+    weight_symmetric: bool | None = None,
+    per_channel: bool = False,
+    keep_removable_activations: bool = False,
+    min_real_range: float | None = None,
+    tensor_quant_overrides: dict[str, list[dict[str, Any]]] | None = None,
+    nodes_to_exclude: list[str] | Callable[[onnx.ModelProto, onnx.NodeProto], bool] | None = None,
+    extra_options: dict | None = None,
+) -> StaticQuantConfig:
+    """
+    Returns a configuration suitable for quantizing the entire model to integer precision.
+
+    Params:
+        model_input: Path to the input model file or ModelProto.
+        calibration_data_reader: Calibration data reader.
+        calibrate_method: The calibration method. Defaults to MinMax.
+        calibrate_args: Optional arguments for the calibration method (e.g., {"percentile": 99.98}), which are
+            converted to calibrator extra options. Defaults to None.
+        activation_type: The default activation quantization type. Defaults to QUInt8.
+        weight_type: The default weight quantization type. Defaults to QInt8.
+        activation_symmetric: True if activations should be quantized symmetrically (i.e., rmax == -rmin) by default.
+            Defaults to false. For int8 and int16, this results in zero-point values of 0. For uint8 and uint16,
+            the zero-point values are 127 and 32,767, respectively.
+        weight_symmetric: True if weights should be quantized symmetrically (i.e., rmax == -rmin) by default.
+            Defaults to None. If set to None, weight_symmetric is assumed true if a weight's quant type is a signed int.
+        per_channel: Global option that determines if a fixed set of operator types should be quantized per-channel.
+            Defaults to false. Alternatively, use the tensor-level `tensor_quant_overrides` to select individual operators
+            and their quantization axes.
+        keep_removable_activations: Defaults to false. If true, "removable" activations (e.g., Clip or Relu) will not
+            be removed and will be explicitly represented in the QDQ model. If false, these activations
+            are automatically removed if activations are asymmetrically quantized. Keeping these activations
+            is necessary if optimizations or EP transformations will later remove
+            QuantizeLinear/DequantizeLinear operators from the model.
+        min_real_range: Default is None. If set to a floating-point value, the calculation of the quantization parameters
+            (i.e., scale and zero point) will enforce a minimum range between rmin and rmax. If (rmax - rmin)
+            is less than the specified minimum range, rmax will be set to rmin + min_real_range.
+        tensor_quant_overrides: Tensor-level quantization overrides. Defaults to None.
+            The key is a tensor name and the value is a list of dictionaries. For per-tensor quantization, the list
+            contains a single dictionary. For per-channel quantization, the list contains either a dictionary for
+            each channel in the tensor or a single dictionary that is assumed to apply to all channels. An 'axis'
+            key must be present in the first dictionary for per-channel quantization.
+
+            Each dictionary contains optional overrides with the following keys and values.
+                'quant_type' = QuantType : The tensor's quantization data type.
+                'axis' = Int : The per-channel axis. Must be present for per-channel weights.
+                'scale' = Float : The scale value to use. Must also specify `zero_point` if set.
+                'zero_point' = Int : The zero-point value to use. Must also specify `scale` if set.
+                'symmetric' = Bool : If the tensor should use symmetric quantization. Invalid if `scale` or
+                    `zero_point` is also set.
+                'reduce_range' = Bool : If the quantization range should be reduced. Invalid if `scale` or
+                    `zero_point` is also set. Only valid for initializers.
+                'rmax' = Float : Override the maximum real tensor value in calibration data.
+                    Invalid if `scale` or `zero_point` is also set.
+                'rmin' = Float : Override the minimum real tensor value in calibration data.
+                    Invalid if `scale` or `zero_point` is also set.
+                'convert' = Dict : A nested dictionary with the same keys for an activation
+                    tensor that should be converted to another quantization type.
+                'convert["recv_nodes"]' = Set : Set of node names that consume the converted activation;
+                    other nodes get the original type. If not specified,
+                    assume all consumer nodes get the converted type.
+        nodes_to_exclude: List of node names to exclude from quantization.
Alternatively, a function can be provided that
+            accepts an onnx.ModelProto and onnx.NodeProto as arguments and returns true if the given onnx.NodeProto
+            should be excluded from quantization.
+        extra_options: Additional options specified as string key/value pairs. Refer to the documentation for
+            `quantize_static` for valid keys and values.
+
+    Returns:
+        A StaticQuantConfig object
+    """
+    q16_types = {QuantType.QInt16, QuantType.QUInt16}
+    q4_types = {QuantType.QInt4, QuantType.QUInt4}
+    op_types_to_exclude = {"Cast", "DequantizeLinear", "QuantizeLinear"}
+
+    model = (
+        model_input
+        if isinstance(model_input, onnx.ModelProto)
+        else onnx.load_model(model_input, load_external_data=False)
+    )
+
+    op_types = set()
+    model_has_external_data = False
+    overrides_helper = TensorQuantOverridesHelper(
+        copy.deepcopy(tensor_quant_overrides) if tensor_quant_overrides else {}
+    )
+
+    # Check if the model has external data.
+    for initializer in model.graph.initializer:
+        if onnx.external_data_helper.uses_external_data(initializer):
+            model_has_external_data = True
+
+    final_nodes_to_exclude = []
+    if nodes_to_exclude is not None and isinstance(nodes_to_exclude, list):
+        final_nodes_to_exclude.extend(nodes_to_exclude)
+
+    # Iterate through nodes to get all operator types in the model and
+    # call the user's function to filter out nodes from quantization.
+    for node in model.graph.node:
+        op_types.add(node.op_type)
+        if nodes_to_exclude is not None and callable(nodes_to_exclude):
+            if nodes_to_exclude(model, node):
+                final_nodes_to_exclude.append(node.name)
+
+    final_extra_options = {
+        "MinimumRealRange": min_real_range,
+        "QDQKeepRemovableActivations": keep_removable_activations,
+        "ActivationSymmetric": activation_symmetric,
+        "WeightSymmetric": weight_symmetric,
+        "ForceQuantizeNoInputCheck": True,
+        "TensorQuantOverrides": overrides_helper.get_dict(),
+    }
+
+    # Pass along known calibration options
+    if calibrate_args:
+        calib_extra_options_keys = [
+            ("symmetric", "CalibTensorRangeSymmetric"),
+            ("moving_average", "CalibMovingAverage"),
+            ("averaging_constant", "CalibMovingAverageConstant"),
+            ("max_intermediate_outputs", "CalibMaxIntermediateOutputs"),
+            ("percentile", "CalibPercentile"),
+        ]
+        calib_extra_options = {
+            key: calibrate_args.get(name) for (name, key) in calib_extra_options_keys if name in calibrate_args
+        }
+        final_extra_options.update(calib_extra_options)
+
+    # ONNX opset < 21 does not support 16-bit or 4-bit quantization, so must use the 'com.microsoft' domain
+    # on Q/DQ operators if using 16-bit or 4-bit quantization.
+    onnx_opset = next(x for x in model.opset_import if x.domain == "" or x.domain == "ai.onnx")
+    if onnx_opset.version < 21:
+        opset21_types = q16_types.union(q4_types)
+        overrides_have_opset21_types = any(t in opset21_types for t in overrides_helper.get_quant_types())
+        if activation_type in opset21_types or weight_type in opset21_types or overrides_have_opset21_types:
+            final_extra_options["UseQDQContribOps"] = True
+
+    # Allow the user's extra_options to override our final_extra_options.
+ if extra_options: + final_extra_options.update(extra_options) + + return StaticQuantConfig( + calibration_data_reader, + calibrate_method=calibrate_method, + quant_format=QuantFormat.QDQ, + activation_type=activation_type, + weight_type=weight_type, + op_types_to_quantize=list(op_types.difference(op_types_to_exclude)), + nodes_to_exclude=final_nodes_to_exclude, + per_channel=per_channel, + use_external_data_format=(model_has_external_data or model.ByteSize() >= MODEL_SIZE_THRESHOLD), + extra_options=final_extra_options, + ) + + class DynamicQuantConfig(QuantConfig): def __init__( self, @@ -290,8 +452,8 @@ def check_static_quant_arguments(quant_format: QuantFormat, activation_type: Qua def quantize_static( - model_input: Union[str, Path, onnx.ModelProto], - model_output: Union[str, Path], + model_input: str | Path | onnx.ModelProto, + model_output: str | Path, calibration_data_reader: CalibrationDataReader, quant_format=QuantFormat.QDQ, op_types_to_quantize=None, @@ -473,6 +635,7 @@ def quantize_static( ("CalibMovingAverage", "moving_average"), ("CalibMovingAverageConstant", "averaging_constant"), ("CalibMaxIntermediateOutputs", "max_intermediate_outputs"), + ("CalibPercentile", "percentile"), ] calib_extra_options = { key: extra_options.get(name) for (name, key) in calib_extra_options_keys if name in extra_options @@ -590,8 +753,8 @@ def inc_dataloader(): def quantize_dynamic( - model_input: Union[str, Path, onnx.ModelProto], - model_output: Union[str, Path], + model_input: str | Path | onnx.ModelProto, + model_output: str | Path, op_types_to_quantize=None, per_channel=False, reduce_range=False, @@ -690,8 +853,8 @@ def quantize_dynamic( def quantize( - model_input: Union[str, Path, onnx.ModelProto], - model_output: Union[str, Path], + model_input: str | Path | onnx.ModelProto, + model_output: str | Path, quant_config: QuantConfig, ): """Quantize a model with QuantConfig. diff --git a/onnxruntime/test/python/quantization/test_get_qdq_config.py b/onnxruntime/test/python/quantization/test_get_qdq_config.py new file mode 100644 index 0000000000000..d7055764f745a --- /dev/null +++ b/onnxruntime/test/python/quantization/test_get_qdq_config.py @@ -0,0 +1,276 @@ +#!/usr/bin/env python +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +from __future__ import annotations + +import os +import tempfile +import unittest + +import numpy as np +import onnx +from op_test_utils import TestDataFeeds, check_model_correctness, check_op_type_count + +from onnxruntime.quantization import CalibrationMethod, QuantFormat, QuantType, get_qdq_config, quantize + + +class TestGetQDQConfig(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls._tmp_model_dir = tempfile.TemporaryDirectory(prefix="ort.int_qdq_config_") + + # Note: swap with the commented line if you want to see the models in local test dir. + cls._tmp_dir_path = cls._tmp_model_dir.name + # cls._tmp_dir_path = "." + + @classmethod + def tearDownClass(cls): + cls._tmp_model_dir.cleanup() + + def build_add_model( + self, + shape: list[int], + tensor_type: onnx.TensorProto.DataType, + weight: onnx.TensorProto | None = None, + opset: int = 21, + ) -> onnx.ModelProto: + """ + Returns an onnx.ModelProto with a single Add operator. The second input can be optionally made + a static weight. 
+ """ + graph_inputs = [onnx.helper.make_tensor_value_info("input_0", tensor_type, shape)] + graph_outputs = [onnx.helper.make_tensor_value_info("output_0", tensor_type, shape)] + initializers = [] + add_input_names = ["input_0"] + + if weight is not None: + initializers.append(weight) + add_input_names.append(weight.name) + else: + graph_inputs.append(onnx.helper.make_tensor_value_info("input_1", tensor_type, shape)) + add_input_names.append("input_1") + + add_node = onnx.helper.make_node("Add", add_input_names, ["output_0"], name="Add0") + + graph = onnx.helper.make_graph( + [add_node], + "AddGraph", + graph_inputs, + graph_outputs, + initializer=initializers, + ) + opset_imports = [onnx.helper.make_opsetid("", opset)] + model = onnx.helper.make_model(graph, opset_imports=opset_imports) + model = onnx.shape_inference.infer_shapes(model) + onnx.checker.check_model(model, True) + return model + + def test_basic_args(self): + """ + Test that get_qdq_config() returns a config that sets the basic args. + """ + + shape = [1, 8, 8] + tensor_type = onnx.TensorProto.FLOAT + np_dtype = onnx.helper.tensor_dtype_to_np_dtype(tensor_type) + weight = onnx.numpy_helper.from_array(np.ones(shape, dtype=np_dtype), "weight") + float_model = self.build_add_model(shape, tensor_type, weight, opset=21) + + input_data_list = [ + {"input_0": np.ones(shape, dtype=np_dtype) * np.array(-2, dtype=np_dtype)}, + {"input_0": np.ones(shape, dtype=np_dtype) * np.array(2, dtype=np_dtype)}, + ] + data_reader = TestDataFeeds(input_data_list) + + qdq_config = get_qdq_config( + float_model, + data_reader, + calibrate_method=CalibrationMethod.Percentile, + calibrate_args={"percentile": 99.98}, # Converted to extra_options + activation_type=QuantType.QUInt16, + weight_type=QuantType.QInt16, + per_channel=True, + nodes_to_exclude=["Mul"], + # Other options converted to extra_options: + min_real_range=0.0001, + keep_removable_activations=True, + activation_symmetric=True, + weight_symmetric=True, + ) + self.assertEqual(qdq_config.calibrate_method, CalibrationMethod.Percentile) + self.assertEqual(qdq_config.activation_type, QuantType.QUInt16) + self.assertEqual(qdq_config.weight_type, QuantType.QInt16) + self.assertTrue(qdq_config.per_channel) + self.assertEqual(set(qdq_config.nodes_to_exclude), {"Mul"}) + self.assertEqual(set(qdq_config.op_types_to_quantize), {"Add"}) + + # Check that calibration args are translated to extra_options. + self.assertEqual(qdq_config.extra_options["CalibPercentile"], 99.98) + + # Check that other args are also translated to extra_options. + self.assertEqual(qdq_config.extra_options["MinimumRealRange"], 0.0001) + self.assertTrue(qdq_config.extra_options["QDQKeepRemovableActivations"]) + self.assertTrue(qdq_config.extra_options["ActivationSymmetric"]) + self.assertTrue(qdq_config.extra_options["WeightSymmetric"]) + + # The following options should always be set to specific values. + self.assertTrue(qdq_config.extra_options["ForceQuantizeNoInputCheck"]) + self.assertEqual(qdq_config.quant_format, QuantFormat.QDQ) + + # Should use onnx domain Q/DQ ops because onnx opset >= 21. + self.assertFalse(qdq_config.extra_options.get("UseQDQContribOps", False)) + + def test_exclude_nodes_callable(self): + """ + Test passing a function/callable to exclude nodes from quantization. 
+ """ + + shape = [1, 8, 8] + tensor_type = onnx.TensorProto.FLOAT + np_dtype = onnx.helper.tensor_dtype_to_np_dtype(tensor_type) + weight = onnx.numpy_helper.from_array(np.ones(shape, dtype=np_dtype), "weight") + float_model = self.build_add_model(shape, tensor_type, weight, opset=21) + + input_data_list = [ + {"input_0": np.ones(shape, dtype=np_dtype) * np.array(-2, dtype=np_dtype)}, + {"input_0": np.ones(shape, dtype=np_dtype) * np.array(2, dtype=np_dtype)}, + ] + data_reader = TestDataFeeds(input_data_list) + + # Local function that excludes all "Add" nodes. + def should_exclude_node_(model: onnx.ModelProto, node: onnx.NodeProto) -> bool: + return node.op_type == "Add" + + qdq_config = get_qdq_config( + float_model, + data_reader, + nodes_to_exclude=should_exclude_node_, + ) + + expected_excluded_nodes = set([node.name for node in float_model.graph.node if node.op_type == "Add"]) + self.assertTrue(bool(expected_excluded_nodes)) + self.assertEqual(set(qdq_config.nodes_to_exclude), expected_excluded_nodes) + + def test_external_data(self): + """ + Test that get_qdq_config() returns a config that enables external data + if the input model has external data. + """ + + # Create model with a weight large enough (> 1024 bytes) to be stored externally. + shape = [1, 32, 32] + tensor_type = onnx.TensorProto.FLOAT + np_dtype = onnx.helper.tensor_dtype_to_np_dtype(tensor_type) + large_weight = onnx.numpy_helper.from_array(np.ones(shape, dtype=np_dtype), "weight") + float_model = self.build_add_model(shape, tensor_type, large_weight) + float_model_path = os.path.join(self._tmp_dir_path, "add_ext_data_int_qdq_config.onnx") + + onnx.save_model( + float_model, + float_model_path, + save_as_external_data=True, + all_tensors_to_one_file=True, + location="add_ext_data_int_qdq_config.bin", + ) + + input_data_list = [ + {"input_0": np.ones(shape, dtype=np_dtype) * np.array(-2, dtype=np_dtype)}, + {"input_0": np.ones(shape, dtype=np_dtype) * np.array(0, dtype=np_dtype)}, + {"input_0": np.ones(shape, dtype=np_dtype) * np.array(2, dtype=np_dtype)}, + ] + data_reader = TestDataFeeds(input_data_list) + + # Create a quantization config and check that it sets boolean to use external data + qdq_config = get_qdq_config( + float_model_path, data_reader, activation_type=QuantType.QUInt8, weight_type=QuantType.QInt8 + ) + self.assertEqual(set(qdq_config.op_types_to_quantize), {"Add"}) + self.assertTrue(qdq_config.use_external_data_format) + + # Quantize the model and check computational correctness against float model. + qdq_model_path = os.path.join(self._tmp_dir_path, "add_ext_data_int_qdq_config.qdq.onnx") + quantize(float_model_path, qdq_model_path, qdq_config) + + expected_op_counts = {"DequantizeLinear": 3, "QuantizeLinear": 2, "Add": 1} + check_op_type_count(self, qdq_model_path, **expected_op_counts) + + data_reader.rewind() + check_model_correctness(self, float_model_path, qdq_model_path, data_reader.get_next()) + + # The quantized weight should still be stored in an external file. 
+ qdq_model = onnx.load_model(qdq_model_path, load_external_data=False) + weight_quantized = next( + ( + initializer + for initializer in qdq_model.graph.initializer + if initializer.name == f"{large_weight.name}_quantized" + ), + None, + ) + self.assertIsNotNone(weight_quantized) + self.assertEqual(weight_quantized.data_location, onnx.TensorProto.EXTERNAL) + + def test_use_qdq_contrib_ops_for_int16_opset19(self): + """ + Test that get_qdq_config() returns a config that forces 'com.microsoft' Q/DQ ops for + use of int16 in opset < 21. + """ + + shape = [1, 8, 8] + tensor_type = onnx.TensorProto.FLOAT + np_dtype = onnx.helper.tensor_dtype_to_np_dtype(tensor_type) + weight = onnx.numpy_helper.from_array(np.ones(shape, dtype=np_dtype), "weight") + float_model = self.build_add_model(shape, tensor_type, weight, opset=19) + + input_data_list = [ + {"input_0": np.ones(shape, dtype=np_dtype) * np.array(-2, dtype=np_dtype)}, + {"input_0": np.ones(shape, dtype=np_dtype) * np.array(2, dtype=np_dtype)}, + ] + data_reader = TestDataFeeds(input_data_list) + + qdq_config = get_qdq_config( + float_model, + data_reader, + activation_type=QuantType.QUInt16, + weight_type=QuantType.QInt8, + ) + + self.assertEqual(qdq_config.activation_type, QuantType.QUInt16) + self.assertTrue(qdq_config.extra_options["UseQDQContribOps"]) + + def test_use_qdq_contrib_ops_for_int4_opset19(self): + """ + Test that get_qdq_config() returns a config that forces 'com.microsoft' Q/DQ ops for + use of int4 in opset < 21. + """ + + shape = [1, 8, 8] + tensor_type = onnx.TensorProto.FLOAT + np_dtype = onnx.helper.tensor_dtype_to_np_dtype(tensor_type) + weight = onnx.numpy_helper.from_array(np.ones(shape, dtype=np_dtype), "weight") + float_model = self.build_add_model(shape, tensor_type, weight, opset=19) + + input_data_list = [ + {"input_0": np.ones(shape, dtype=np_dtype) * np.array(-2, dtype=np_dtype)}, + {"input_0": np.ones(shape, dtype=np_dtype) * np.array(2, dtype=np_dtype)}, + ] + data_reader = TestDataFeeds(input_data_list) + + # Use int4 in tensor quantization overrides. This should still force use of 'com.microsoft' Q/DQ ops. + qdq_config = get_qdq_config( + float_model, + data_reader, + activation_type=QuantType.QUInt8, + weight_type=QuantType.QInt8, + tensor_quant_overrides={"weight": [{"quant_type": QuantType.QInt4}]}, + ) + + self.assertEqual(qdq_config.extra_options["TensorQuantOverrides"]["weight"][0]["quant_type"], QuantType.QInt4) + self.assertTrue(qdq_config.extra_options["UseQDQContribOps"]) + + +if __name__ == "__main__": + unittest.main()