@@ -1,10 +1,10 @@
 """
-.. _vgg16_fp8_ptq:
+.. _vgg16_ptq:
 
 Deploy Quantized Models using Torch-TensorRT
 ======================================================
 
-Here we demonstrate how to deploy a model quantized to FP8 using the Dynamo frontend of Torch-TensorRT
+Here we demonstrate how to deploy a model quantized to INT8 or FP8 using the Dynamo frontend of Torch-TensorRT
 """
 
 # %%
@@ -111,7 +111,12 @@ def vgg16(num_classes=1000, init_weights=False):
     type=int,
     help="Batch size for tuning the model with PTQ and FP8",
 )
-
+PARSER.add_argument(
+    "--quantize-type",
+    default="int8",
+    type=str,
+    help="quantization type, currently supported int8 or fp8 for PTQ",
+)
 args = PARSER.parse_args()
 
 model = vgg16(num_classes=10, init_weights=False)
@@ -191,8 +196,10 @@ def calibrate_loop(model):
 # %%
 # Tune the pre-trained model with FP8 and PTQ
 # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-
-quant_cfg = mtq.FP8_DEFAULT_CFG
+if args.quantize_type == "int8":
+    quant_cfg = mtq.INT8_DEFAULT_CFG
+elif args.quantize_type == "fp8":
+    quant_cfg = mtq.FP8_DEFAULT_CFG
 # PTQ with in-place replacement to quantized modules
 mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
 # model has FP8 qdq nodes at this point
@@ -226,11 +233,18 @@ def calibrate_loop(model):
 with export_torch_mode():
     # Compile the model with Torch-TensorRT Dynamo backend
     input_tensor = images.cuda()
-    exp_program = torch.export.export(model, (input_tensor,))
+    # torch.export.export() failed due to RuntimeError: Attempting to use FunctionalTensor on its own. Instead, please use it with a corresponding FunctionalTensorMode()
+    from torch.export._trace import _export
+
+    exp_program = _export(model, (input_tensor,))
+    if args.quantize_type == "int8":
+        enabled_precisions = {torch.int8}
+    elif args.quantize_type == "fp8":
+        enabled_precisions = {torch.float8_e4m3fn}
     trt_model = torchtrt.dynamo.compile(
         exp_program,
         inputs=[input_tensor],
-        enabled_precisions={torch.float8_e4m3fn},
+        enabled_precisions=enabled_precisions,
         min_block_size=1,
         debug=False,
     )
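For context, the patch threads the new CLI flag through two decision points: which modelopt config drives `mtq.quantize()`, and which precision set is handed to `torchtrt.dynamo.compile()`. Below is a minimal sketch of that selection logic factored into a helper; `select_quantization` is a hypothetical name, not part of the patch, and it assumes `nvidia-modelopt` is installed and importable as `modelopt.torch.quantization`:

```python
import torch
import modelopt.torch.quantization as mtq


def select_quantization(quantize_type: str):
    """Map the --quantize-type flag to a (modelopt config, TensorRT precision set) pair."""
    choices = {
        "int8": (mtq.INT8_DEFAULT_CFG, {torch.int8}),
        "fp8": (mtq.FP8_DEFAULT_CFG, {torch.float8_e4m3fn}),
    }
    if quantize_type not in choices:
        # The patch's if/elif chain leaves quant_cfg undefined for any other
        # value; raising here surfaces that failure mode explicitly.
        raise ValueError(f"unsupported --quantize-type: {quantize_type!r}")
    return choices[quantize_type]


quant_cfg, enabled_precisions = select_quantization("fp8")  # e.g. args.quantize_type
```

Note also that the patch swaps the public `torch.export.export()` for the private `torch.export._trace._export()` to work around the `FunctionalTensor` RuntimeError quoted in the inline comment; `_export` is an internal API that can change between PyTorch releases, so the workaround is worth revisiting once the public exporter accepts the quantized graph.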