Add 4bit quantizer to onnx runtime doc (#21835)
### Description
Introduce how to use `matmul_4bits_quantizer` to do weight-only quantization.

### Motivation and Context
Add 4bit quantizer to onnx runtime doc
fajin-corp authored Aug 27, 2024
1 parent d491241 commit 6265c3a
Showing 1 changed file with 53 additions and 0 deletions: docs/performance/model-optimizations/quantization.md
@@ -231,6 +231,59 @@ ONNX Runtime leverages the TensorRT Execution Provider for quantization on GPU n

We provide two end-to-end examples: [Yolo V3](https://github.com/microsoft/onnxruntime-inference-examples/tree/main/quantization/object_detection/trt/yolov3) and [resnet50](https://github.com/microsoft/onnxruntime-inference-examples/tree/main/quantization/image_classification/trt/resnet50).

## Quantize to Int4/UInt4

ONNX Runtime can quantize certain operators in a model to 4-bit integer types. Block-wise weight-only quantization is applied to these operators; a short sketch of what that means follows the list below. The supported op types are:
- [MatMul](https://github.com/onnx/onnx/blob/main/docs/Operators.md#matmul):
  - The node is quantized only if the input `B` is constant.
  - Supports the QOperator or QDQ format.
    - If QOperator is selected, the node is converted to a [MatMulNBits](https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#commicrosoftmatmulnbits) node. Weight `B` is block-wise quantized and saved in the new node. The [HQQ](https://arxiv.org/pdf/2309.15531.pdf), [GPTQ](https://huggingface.co/docs/transformers/main/en/quantization/gptq), and RTN (default) algorithms are supported.
    - If QDQ is selected, the MatMul node is replaced by a DequantizeLinear -> MatMul pair. Weight `B` is block-wise quantized and saved in the DequantizeLinear node as an initializer.
- [Gather](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Gather):
  - The node is quantized only if the input `data` is constant.
  - Supports the QOperator format only.
    - Gather is quantized to a [GatherBlockQuantized](https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#commicrosoftgatherblockquantized) node. Input `data` is block-wise quantized and saved in the new node. Only the RTN algorithm is supported.
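
To make "block-wise quantization" concrete, here is a minimal NumPy sketch of symmetric round-to-nearest (RTN) 4-bit quantization along axis 0 of a weight matrix. It is an illustration only, not ONNX Runtime's implementation; the function name, the padding choice, and the per-block/per-column scale layout are assumptions made for the example.

```python
import numpy as np

def rtn_blockwise_int4(weights: np.ndarray, block_size: int = 128):
    """Illustrative symmetric RTN block-wise quantization along axis 0 (not ORT's actual code)."""
    rows, cols = weights.shape
    pad = (-rows) % block_size                                 # pad rows up to a multiple of block_size
    padded = np.pad(weights, ((0, pad), (0, 0)))
    blocks = padded.reshape(-1, block_size, cols)              # (num_blocks, block_size, cols)
    scales = np.abs(blocks).max(axis=1, keepdims=True) / 7.0   # one scale per block and column; signed int4 range is [-8, 7]
    scales = np.where(scales == 0, 1.0, scales)                # avoid division by zero for all-zero blocks
    q = np.clip(np.round(blocks / scales), -8, 7).astype(np.int8)
    return q, scales                                           # dequantize approximately with q * scales

# The quantizer stores the 4-bit values (packed into nibbles) plus the per-block scales
# in place of the original float weights.
w = np.random.randn(512, 64).astype(np.float32)
q, scales = rtn_blockwise_int4(w, block_size=128)
```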

Since the Int4/UInt4 types were introduced in [onnx opset 21](https://github.com/onnx/onnx/releases/tag/v1.16.0), a model whose onnx domain version is < 21 is force-upgraded to opset 21. Please make sure the operators in the model are compatible with onnx opset 21.
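
To check a model's opset before quantizing, the default-domain version can be read with the `onnx` package. This is a minimal sketch; the model path is a placeholder:

```python
import onnx

model = onnx.load("path/to/original/model.onnx")  # placeholder path
# The default domain ("" or "ai.onnx") entry carries the core ONNX opset version.
core_opset = next(
    entry.version for entry in model.opset_import if entry.domain in ("", "ai.onnx")
)
print(core_opset)  # a version below 21 will be force-upgraded to opset 21 during quantization
```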

ONNX Runtime 1.20 or later is required to run a model that contains GatherBlockQuantized nodes.

Code Examples:

```python
from onnxruntime.quantization import (
    matmul_4bits_quantizer,
    quant_utils,
    quantize
)
from pathlib import Path

model_fp32_path = "path/to/original/model.onnx"
model_int4_path = "path/to/save/quantized/model.onnx"

quant_config = matmul_4bits_quantizer.DefaultWeightOnlyQuantConfig(
    block_size=128,  # power of 2 and >= 16
    is_symmetric=True,  # if True, quantize to Int4; otherwise, quantize to UInt4
    accuracy_level=4,  # used by MatMulNBits, see https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#attributes-35
    quant_format=quant_utils.QuantFormat.QOperator,
    op_types_to_quantize=("MatMul", "Gather"),  # specify which op types to quantize
    quant_axes=(("MatMul", 0), ("Gather", 1),),  # specify which axis to quantize for an op type
)

model = quant_utils.load_model_with_shape_infer(Path(model_fp32_path))
quant = matmul_4bits_quantizer.MatMul4BitsQuantizer(
    model,
    nodes_to_exclude=None,  # specify a list of nodes to exclude from quantization
    nodes_to_include=None,  # specify a list of nodes to force include in quantization
    algo_config=quant_config,
)
quant.process()
quant.model.save_model_to_file(
    model_int4_path,
    True,  # save the quantized weights to an external data file
)

```
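
Once the Int4 model is saved, it can be loaded and run like any other ONNX model. Below is a minimal sketch; the dummy input's shape and dtype are placeholders that depend on your model:

```python
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession(model_int4_path, providers=["CPUExecutionProvider"])
input_meta = session.get_inputs()[0]
print(input_meta.name, input_meta.shape, input_meta.type)

# Placeholder input; replace the shape and dtype with what your model expects.
dummy_input = np.random.rand(1, 3, 224, 224).astype(np.float32)
outputs = session.run(None, {input_meta.name: dummy_input})
```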

For AWQ and GPTQ quantization usage, please refer to [Gen-AI model builder](https://github.com/microsoft/onnxruntime-genai/tree/main/src/python/py/models#quantized-pytorch-model).

## FAQ
### Why am I not seeing performance improvements?
{: .no_toc }
