
Commit b3a8cdd

add int8 quantization support (#3058)
1 parent d84cd18 commit b3a8cdd

File tree

6 files changed: +107 -22 lines changed

examples/dynamo/vgg16_fp8_ptq.py renamed to examples/dynamo/vgg16_ptq.py (+21 -7)

@@ -1,10 +1,10 @@
 """
-.. _vgg16_fp8_ptq:
+.. _vgg16_ptq:

 Deploy Quantized Models using Torch-TensorRT
 ======================================================

-Here we demonstrate how to deploy a model quantized to FP8 using the Dynamo frontend of Torch-TensorRT
+Here we demonstrate how to deploy a model quantized to INT8 or FP8 using the Dynamo frontend of Torch-TensorRT
 """

 # %%
@@ -111,7 +111,12 @@ def vgg16(num_classes=1000, init_weights=False):
     type=int,
     help="Batch size for tuning the model with PTQ and FP8",
 )
-
+PARSER.add_argument(
+    "--quantize-type",
+    default="int8",
+    type=str,
+    help="quantization type, currently supported int8 or fp8 for PTQ",
+)
 args = PARSER.parse_args()

 model = vgg16(num_classes=10, init_weights=False)
@@ -191,8 +196,10 @@ def calibrate_loop(model):
 # %%
 # Tune the pre-trained model with FP8 and PTQ
 # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-
-quant_cfg = mtq.FP8_DEFAULT_CFG
+if args.quantize_type == "int8":
+    quant_cfg = mtq.INT8_DEFAULT_CFG
+elif args.quantize_type == "fp8":
+    quant_cfg = mtq.FP8_DEFAULT_CFG
 # PTQ with in-place replacement to quantized modules
 mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
 # model has FP8 qdq nodes at this point
@@ -226,11 +233,18 @@ def calibrate_loop(model):
     with export_torch_mode():
         # Compile the model with Torch-TensorRT Dynamo backend
         input_tensor = images.cuda()
-        exp_program = torch.export.export(model, (input_tensor,))
+        # torch.export.export() failed due to RuntimeError: Attempting to use FunctionalTensor on its own. Instead, please use it with a corresponding FunctionalTensorMode()
+        from torch.export._trace import _export
+
+        exp_program = _export(model, (input_tensor,))
+        if args.quantize_type == "int8":
+            enabled_precisions = {torch.int8}
+        elif args.quantize_type == "fp8":
+            enabled_precisions = {torch.float8_e4m3fn}
         trt_model = torchtrt.dynamo.compile(
             exp_program,
             inputs=[input_tensor],
-            enabled_precisions={torch.float8_e4m3fn},
+            enabled_precisions=enabled_precisions,
             min_block_size=1,
             debug=False,
         )
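
A minimal sketch of the INT8 path the renamed example now takes, assuming nvidia-modelopt and a CUDA device are available; model, calibrate_loop, and images stand in for the example's VGG16 model, its calibration loop, and a batch from its dataloader:

import modelopt.torch.quantization as mtq
import torch
import torch_tensorrt as torchtrt
from modelopt.torch.quantization.utils import export_torch_mode
from torch.export._trace import _export  # workaround noted in the diff above

quant_cfg = mtq.INT8_DEFAULT_CFG  # mtq.FP8_DEFAULT_CFG when --quantize-type fp8
mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)  # in-place insertion of Q/DQ nodes

with torch.no_grad():
    with export_torch_mode():
        input_tensor = images.cuda()
        exp_program = _export(model, (input_tensor,))
        trt_model = torchtrt.dynamo.compile(
            exp_program,
            inputs=[input_tensor],
            enabled_precisions={torch.int8},  # {torch.float8_e4m3fn} for FP8
            min_block_size=1,
        )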

py/torch_tensorrt/_enums.py (+1 -2)

@@ -5,11 +5,10 @@
 from typing import Any, Optional, Type, Union

 import numpy as np
+import tensorrt as trt
 import torch
 from torch_tensorrt._features import ENABLED_FEATURES, needs_torch_tensorrt_runtime

-import tensorrt as trt
-

 class dtype(Enum):
     """Enum to describe data types to Torch-TensorRT, has compatibility with torch, tensorrt and numpy dtypes"""

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py (+6 -4)

@@ -606,28 +606,30 @@ def aten_ops_neg(
 try:
     import modelopt.torch.quantization as mtq  # noqa: F401

-    assert torch.ops.trt.quantize_fp8.default
+    assert torch.ops.tensorrt.quantize_op.default
 except Exception as e:
     _LOGGER.warning(
         "Unable to import quantization op. Please install modelopt library (https://github.com/NVIDIA/TensorRT-Model-Optimizer?tab=readme-ov-file#installation) to add support for compiling quantized models"
     )
 else:

-    @dynamo_tensorrt_converter(torch.ops.trt.quantize_fp8.default)
-    def aten_ops_quantize_fp8(
+    @dynamo_tensorrt_converter(torch.ops.tensorrt.quantize_op.default)
+    def aten_ops_quantize_op(
         ctx: ConversionContext,
         target: Target,
         args: Tuple[Argument, ...],
         kwargs: Dict[str, Argument],
         name: str,
     ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-        return impl.quantize.quantize_fp8(
+        return impl.quantize.quantize(
             ctx,
             target,
             SourceIR.ATEN,
             name,
             args[0],
             args[1],
+            args[2],
+            args[3],
         )
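
The converter now forwards args[2] and args[3] because the exported quantize_op presumably carries (input, amax, num_bits, exponent_bits) as positional arguments, matching the updated impl.quantize.quantize signature below. A minimal sanity check, assuming nvidia-modelopt is installed:

import torch
import modelopt.torch.quantization  # noqa: F401  # importing modelopt registers the custom quantize op

# If the op resolves, the converter registered above can handle quantized graphs.
assert torch.ops.tensorrt.quantize_op.default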

py/torch_tensorrt/dynamo/conversion/impl/quantize.py (+27 -9)

@@ -10,36 +10,54 @@
 from torch_tensorrt.fx.types import TRTTensor


-def quantize_fp8(
+def quantize(
     ctx: ConversionContext,
     target: Target,
     source_ir: Optional[SourceIR],
     name: str,
     input_tensor: TRTTensor,
-    scale: np.ndarray,
+    amax: np.ndarray,
+    num_bits: int,
+    exponent_bits: int,
 ) -> TRTTensor:
     """
     Adds quantize and dequantize ops (QDQ) which quantize to INT8 or FP8 based
     on the output_type set and dequantizes them back.
     """
-    if (isinstance(input_tensor, TRTTensor)) and not (
-        input_tensor.dtype == trt.float32 or input_tensor.dtype == trt.float16
+    if isinstance(input_tensor, TRTTensor) and input_tensor.dtype not in (
+        trt.float32,
+        trt.float16,
     ):
         raise ValueError(
-            f"quantize_fp8 converter received an input of {input_tensor.dtype} type. Supported types: float32 | float16"
+            f"quantize converter received an input of {input_tensor.dtype} type. Supported types: float32 | float16"
         )
-
+    if num_bits != 8 or exponent_bits not in (0, 4):
+        raise ValueError(
+            f"quantize converter currently only accept INT8 or FP8 based quantize, got {num_bits=}, {exponent_bits=}"
+        )
+    if num_bits == 8 and exponent_bits == 0:
+        max_bound = 127
+    elif num_bits == 8 and exponent_bits == 4:
+        max_bound = 448
+    scale = np.divide(amax, max_bound)
     scale = get_trt_tensor(ctx, scale, name + "_scale")
     # Add Q node
     quantize_layer = ctx.net.add_quantize(input_tensor, scale)
-    quantize_layer.set_output_type(0, trt.DataType.FP8)
+    if num_bits == 8 and exponent_bits == 0:
+        quantize_layer.set_output_type(0, trt.DataType.INT8)
+    elif num_bits == 8 and exponent_bits == 4:
+        quantize_layer.set_output_type(0, trt.DataType.FP8)
+
     set_layer_name(quantize_layer, target, name + "_quantize", source_ir)
     q_output = quantize_layer.get_output(0)
     # Add DQ node
     dequantize_layer = ctx.net.add_dequantize(q_output, scale)
     set_layer_name(dequantize_layer, target, name + "_dequantize", source_ir)
-    # Set DQ layer precision to FP8
-    dequantize_layer.precision = trt.DataType.FP8
+    if num_bits == 8 and exponent_bits == 0:
+        dequantize_layer.precision = trt.DataType.INT8
+    elif num_bits == 8 and exponent_bits == 4:
+        # Set DQ layer precision to FP8
+        dequantize_layer.precision = trt.DataType.FP8
     dq_output = dequantize_layer.get_output(0)

     return dq_output
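
The converter's contract changes from a precomputed scale to the calibration amax plus (num_bits, exponent_bits): the scale is amax divided by the format's maximum representable magnitude, 127 for INT8 (exponent_bits=0) and 448 for FP8 E4M3 (exponent_bits=4). A small worked example of that arithmetic, with an illustrative amax value:

import numpy as np

amax = np.float32(6.35)    # per-tensor calibration maximum (illustrative)
int8_scale = amax / 127.0  # num_bits=8, exponent_bits=0 -> max_bound 127 -> 0.05
fp8_scale = amax / 448.0   # num_bits=8, exponent_bits=4 -> max_bound 448 -> ~0.0142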

tests/py/dynamo/models/test_models_export.py (+50 -0)

@@ -1,6 +1,7 @@
 # type: ignore
 import unittest

+import modelopt
 import pytest
 import timm
 import torch
@@ -225,3 +226,52 @@ def calibrate_loop(model):
     )
     outputs_trt = trt_model(input_tensor)
     assert torch.allclose(output_pyt, outputs_trt, rtol=1e-3, atol=1e-2)
+
+
+@unittest.skipIf(
+    modelopt.__version__ < "0.16.1",
+    "Int8 quantization is supported in modelopt since 0.16.1 or later",
+)
+@pytest.mark.unit
+def test_base_int8(ir):
+    class SimpleNetwork(torch.nn.Module):
+        def __init__(self):
+            super(SimpleNetwork, self).__init__()
+            self.linear1 = torch.nn.Linear(in_features=10, out_features=5)
+            self.linear2 = torch.nn.Linear(in_features=5, out_features=1)
+
+        def forward(self, x):
+            x = self.linear1(x)
+            x = torch.nn.ReLU()(x)
+            x = self.linear2(x)
+            return x
+
+    import modelopt.torch.quantization as mtq
+    from modelopt.torch.quantization.utils import export_torch_mode
+
+    def calibrate_loop(model):
+        """Simple calibration function for testing."""
+        model(input_tensor)
+
+    input_tensor = torch.randn(1, 10).cuda()
+    model = SimpleNetwork().eval().cuda()
+
+    quant_cfg = mtq.INT8_DEFAULT_CFG
+    mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
+    # model has INT8 qdq nodes at this point
+    output_pyt = model(input_tensor)
+
+    with torch.no_grad():
+        with export_torch_mode():
+            from torch.export._trace import _export
+
+            exp_program = _export(model, (input_tensor,))
+            trt_model = torchtrt.dynamo.compile(
+                exp_program,
+                inputs=[input_tensor],
+                enabled_precisions={torch.int8},
+                min_block_size=1,
+                debug=True,
+            )
+            outputs_trt = trt_model(input_tensor)
+            assert torch.allclose(output_pyt, outputs_trt, rtol=1e-3, atol=1e-2)
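
Assuming a CUDA device, TensorRT, and nvidia-modelopt 0.16.1 or later are available, the new test can be run in isolation with: pytest tests/py/dynamo/models/test_models_export.py -k test_base_int8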

tests/py/requirements.txt (+2 -0)

@@ -9,4 +9,6 @@ pytest-xdist>=3.6.1
 pyyaml
 timm>=1.0.3
 transformers==4.40.2
+# TODO: once 0.16.1 is out, update it here
+nvidia-modelopt>=0.15.1
 --extra-index-url https://pypi.nvidia.com
