Support VLM calibration with image-text data #755
@@ -161,6 +161,25 @@ scripts/huggingface_example.sh --model $HF_PATH --quant [fp8|nvfp4|int8_sq|int4_

[PTQ for DeepSeek](../deepseek/README.md) shows how to quantize the DeepSeek model with FP4 and export to TensorRT-LLM.

#### VLM calibration with image-text pairs (e.g., Nemotron VL)

For vision-language models, calibration quality is likely to improve when image-text pairs are used instead of text-only data, especially on visual understanding tasks:

```bash
python hf_ptq.py \
    --pyt_ckpt_path <huggingface_model_card> \
    --qformat nvfp4 \
    --export_path <quantized_ckpt_path> \
    --trust_remote_code \
    --calib_with_images \

Contributor
Is there a way to check if the model supports an image as input and enable this automatically?

    --vlm_dataset nemotron_vlm_dataset_v2 \
    --vlm_subsets sparsetables,plotqa_cot,wiki_en \

Collaborator
So far for LLMs we just make these the defaults without introducing a flag. Curious to know if we can follow the same convention and reduce the flags.

    --calib_size 512
```

> Note: when `--calib_with_images` is set, `--calib_size` must be a single value.
> This functionality is currently in beta and has been tested on `nvidia/NVIDIA-Nemotron-Nano-12B-v2-VL-BF16`.
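
On the reviewer's question above about detecting image support automatically, one possible heuristic (a sketch only, not part of this PR; the helper name and the specific checks are assumptions) is to look for a vision config or an image processor on the checkpoint and only then enable image calibration:

```python
# Hypothetical sketch, not part of this PR: probe whether a checkpoint looks
# image-capable, so image-text calibration could be enabled automatically.
from transformers import AutoConfig, AutoProcessor


def checkpoint_accepts_images(ckpt_path: str, trust_remote_code: bool = False) -> bool:
    config = AutoConfig.from_pretrained(ckpt_path, trust_remote_code=trust_remote_code)
    # Many HF VLM configs expose a nested vision config.
    if getattr(config, "vision_config", None) is not None:
        return True
    try:
        processor = AutoProcessor.from_pretrained(ckpt_path, trust_remote_code=trust_remote_code)
    except Exception:
        return False
    # A combined processor (tokenizer + image processor) implies image inputs.
    return getattr(processor, "image_processor", None) is not None
```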

### NeMo Example Script

NeMo 2.0 framework PTQ and TensorRT-LLM deployment examples are maintained in the NeMo GitHub repo. Please refer to the [NeMo PTQ documentation](https://docs.nvidia.com/nemo-framework/user-guide/latest/model-optimization/quantization/quantization.html) for more details.

examples/llm_ptq/hf_ptq.py

@@ -14,6 +14,7 @@
# limitations under the License.

import argparse
import inspect
import random
import time
import warnings

@@ -32,6 +33,7 @@
    is_nemotron_vl,
    run_nemotron_vl_preview,
)
from nemotron_vl_calib import safe_nemotron_vl_forward
from torch.utils.data import DataLoader
from transformers import (
    AutoConfig,

@@ -65,7 +67,10 @@
from modelopt.torch.utils.image_processor import BaseImageProcessor, MllamaImageProcessor
from modelopt.torch.utils.memory_monitor import launch_memory_monitor
from modelopt.torch.utils.speech_dataset_utils import get_speech_dataset_dataloader
from modelopt.torch.utils.vlm_dataset_utils import get_vlm_dataset_dataloader
from modelopt.torch.utils.vlm_dataset_utils import (
    get_supported_vlm_datasets,
    get_vlm_dataset_dataloader,
)

RAND_SEED = 1234

@@ -107,7 +112,31 @@ def make_calib_dataloader(
) -> tuple[DataLoader, str | None]:
    calib_dataloader = None
    first_text_speech_dataset = None
    if model_type == "mllama":
    if getattr(args, "calib_with_images", False):
        # Generic multimodal calibration path (used for Nemotron VL and other HF VLMs).
        assert processor is not None, (
            "Please provide a processor (e.g., AutoProcessor) for image calibration."
        )
        assert len(args.calib_size) == 1, (
            "Image calibration currently supports a single dataset. "
            "Please pass --calib_size with one value (e.g., --calib_size 256)."
        )
        calib_dataloader = get_vlm_dataset_dataloader(
            dataset_name=getattr(args, "vlm_dataset", "scienceqa"),
            processor=processor,
            batch_size=args.batch_size,
            num_samples=args.calib_size[0],
            device=device,
            max_length=args.calib_seq,
            require_image=True,
            subsets=getattr(args, "vlm_subsets", None),
            shuffle_buffer_size=getattr(args, "vlm_shuffle_buffer", 10_000),
            seed=getattr(args, "vlm_shuffle_seed", 42),
            image_root=getattr(args, "vlm_image_root", None),
            use_media_shards=not getattr(args, "vlm_disable_media_shards", False),
            max_shards=getattr(args, "vlm_max_shards", None),
        )
    elif model_type == "mllama":
        assert processor is not None and isinstance(processor, MllamaImageProcessor), (
            "The MllamaImageProcessor must be set."
        )

@@ -164,6 +193,12 @@ def auto_quantize(
):
    """Auto search quantization of multiple formats."""

    if getattr(args, "calib_with_images", False):
        raise NotImplementedError(
            "AutoQuantize with image-text calibration is not supported yet. "
            "Please run plain PTQ (e.g., --qformat nvfp4) with --calib_with_images."
        )

    assert not (args.auto_quantize_bits and args.inference_pipeline_parallel > 1), (
        "Auto Quantization is not supported for pipeline parallel size > 1"
    )

@@ -292,6 +327,7 @@ def load_model(args: argparse.Namespace):
    language_model = full_model
    default_padding_side = None

    is_nemotron_vl_model = is_nemotron_vl(full_model)
    if model_type == "mllama":
        processor = get_processor(
            args.pyt_ckpt_path,

@@ -307,6 +343,46 @@
            device,
            trust_remote_code=args.trust_remote_code,
        )
    elif is_nemotron_vl_model and getattr(args, "calib_with_images", False):
        # For Nemotron VL image calibration, we need an AutoProcessor to build multimodal inputs.
        try:
            processor = AutoProcessor.from_pretrained(
                args.pyt_ckpt_path, trust_remote_code=args.trust_remote_code, padding_side="left"
            )
        except Exception as e:
            raise RuntimeError(
                "Failed to load AutoProcessor for Nemotron VL image calibration. "
                "Please ensure the checkpoint provides a compatible processor."
            ) from e

        if hasattr(processor, "tokenizer") and processor.tokenizer is not None:
            tokenizer = processor.tokenizer
        else:
            tokenizer = get_tokenizer(args.pyt_ckpt_path, trust_remote_code=args.trust_remote_code)

        # Some Nemotron tokenizers may not define pad_token by default, but we use padding=True during calibration.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

Contributor
Do we need to reverse this change before saving the tokenizer here? (Model-Optimizer/examples/llm_ptq/hf_ptq.py, line 549 in 951c6aa)
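
A possible way to address this, sketched only (the save call and the `export_path` name are assumptions, not code from this PR), is to restore the recorded padding side before the tokenizer is written out:

```python
# Hypothetical sketch, not part of this PR: undo the calibration-only
# left-padding before the tokenizer is saved alongside the quantized checkpoint.
if default_padding_side is not None:
    tokenizer.padding_side = default_padding_side
tokenizer.save_pretrained(export_path)  # `export_path` is an assumed variable name
```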

        assert tokenizer.pad_token is not None, f"Pad token for {args.pyt_ckpt_path} cannot be set!"

        default_padding_side = tokenizer.padding_side
        tokenizer.padding_side = "left"

        # Quantize only the language model, but keep the full_model for calibration forward.
        language_model_lineage = get_language_model_from_vl(full_model)
        if language_model_lineage is not None:
            language_model = language_model_lineage.pop(-1)
            ancestors = language_model_lineage
            disabled_quant_cfg = {"quant_cfg": {"default": {"enable": False}}, "algorithm": "max"}

            memo = set(ancestors) | {language_model}
            for ancestor in ancestors:
                for _, module in ancestor.named_children():
                    if module not in memo:
                        mtq.quantize(module, disabled_quant_cfg, forward_loop=None)
                        memo.add(module)

        model_type = get_model_type(language_model)
    else:
        if args.dataset is None:
            args.dataset = ["cnn_dailymail", "nemotron-post-training-dataset-v2"]

@@ -432,9 +508,33 @@ def mono_quantize(

    if not use_calibration:
        warnings.warn("Dynamic quantization. Calibration skipped.")
    calibrate_loop = (
        create_forward_loop(dataloader=calib_dataloader) if use_calibration else None
    )
    calibrate_loop = None
    if use_calibration:
        base_forward_loop = create_forward_loop(dataloader=calib_dataloader)
        # For Nemotron VL image calibration, the dataloader yields multimodal kwargs (e.g., pixel_values).
        # Those kwargs must be consumed by the *full* VLM model, not the extracted language_model.
        if getattr(args, "calib_with_images", False) and is_nemotron_vl_model:

            def calibrate_full_model(_model):
                forward_params = inspect.signature(full_model.forward).parameters
                accepts_kwargs = any(
                    p.kind == inspect.Parameter.VAR_KEYWORD for p in forward_params.values()
                )
                allowed_keys = set(forward_params.keys())

                full_model.eval()
                with torch.no_grad():
                    for batch in calib_dataloader:
                        if accepts_kwargs:
                            call_kwargs = batch
                        else:
                            call_kwargs = {k: v for k, v in batch.items() if k in allowed_keys}
                        call_kwargs = {k: v for k, v in call_kwargs.items() if v is not None}
                        safe_nemotron_vl_forward(full_model, call_kwargs)

            calibrate_loop = calibrate_full_model
        else:
            calibrate_loop = base_forward_loop

    if calibration_only:
        language_model = mtq.calibrate(

@@ -856,6 +956,67 @@ def parse_args() -> argparse.Namespace:
        type=str,
        default=None,
    )
    parser.add_argument(

Collaborator
Are all these added flags necessary? Why can't we just use the calib dataset instead?

Contributor
+1
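
One direction the reviewers' point suggests, sketched only (not code from this PR, and the helper name is an assumption): reuse the existing `--dataset` flag and infer image calibration from the dataset name, using the PR's own `get_supported_vlm_datasets()` registry.

```python
# Hypothetical sketch, not part of this PR: infer image-text calibration from
# the dataset names passed to --dataset instead of adding --calib_with_images.
from modelopt.torch.utils.vlm_dataset_utils import get_supported_vlm_datasets


def wants_image_calibration(dataset_names: list[str]) -> bool:
    vlm_names = set(get_supported_vlm_datasets())  # the PR's registry of VLM dataset names
    return any(name in vlm_names for name in dataset_names)
```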

        "--calib_with_images",
        action="store_true",
        help=(
            "Calibrate with image-text pairs (for VLMs). "
            "For Nemotron VL this enables multimodal calibration using --vlm_dataset."
        ),
    )
    parser.add_argument(
        "--vlm_dataset",
        type=str,
        default="scienceqa",
        help=f"VLM calibration dataset name (choices: {get_supported_vlm_datasets()}).",
    )
    parser.add_argument(
        "--vlm_subsets",
        type=str,
        default="sparsetables,plotqa_cot,wiki_en",
        help=(
            "Comma-separated subset/config names for multi-subset VLM datasets "
            "(e.g., nemotron_vlm_dataset_v2)."
        ),
    )
    parser.add_argument(
        "--vlm_shuffle_buffer",
        type=int,
        default=10_000,
        help="Shuffle buffer size for streaming VLM datasets (higher is more random but downloads more).",
    )
    parser.add_argument(
        "--vlm_shuffle_seed",
        type=int,
        default=42,
        help="Random seed for streaming VLM dataset shuffle.",
    )
    parser.add_argument(
        "--vlm_image_root",
        type=str,
        default=None,
        help=(
            "Local directory containing image files referenced by the VLM dataset annotations. "
            "Required for nemotron_vlm_dataset_v2 subsets that only ship JSONL (e.g., docvqa_cot, chartqa_cot)."
        ),
    )
    parser.add_argument(
        "--vlm_max_shards",
        type=int,
        default=1,
        help=(
            "For VLM subsets that include in-repo tar shards under `<subset>/media/*.tar`, "
            "limit how many shards to download/use for calibration. Increase if you don't get enough samples."
        ),
    )
    parser.add_argument(
        "--vlm_disable_media_shards",
        action="store_true",
        help=(
            "Disable reading in-repo `media/shard_*.tar` files for nemotron_vlm_dataset_v2. "
            "Useful if you want to use JSONL-only subsets together with --vlm_image_root."
        ),
    )
    parser.add_argument("--inference_tensor_parallel", type=int, default=1)
    parser.add_argument("--inference_pipeline_parallel", type=int, default=1)
    parser.add_argument("--awq_block_size", default=0, type=int)

@@ -1022,4 +1183,7 @@ def main(args: argparse.Namespace):

    args.dataset = args.dataset.split(",") if args.dataset else None
    args.calib_size = [int(num_sample) for num_sample in args.calib_size.split(",")]
    args.vlm_subsets = (
        [s.strip() for s in args.vlm_subsets.split(",") if s.strip()] if args.vlm_subsets else None
    )
    main(args)

nemotron_vl_calib.py (new file)

@@ -0,0 +1,98 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Nemotron VL calibration helpers.

The Nemotron Nano VL v2 remote-code wrapper's `forward()` is not ideal to call during PTQ calibration because it may:
- Call `torch.distributed.get_rank()` unconditionally
- Assume `past_key_values` exists in the language model output

Instead, we run a "safe multimodal forward" that exercises:
- Vision encoder feature extraction (C-RADIOv2-H)
- Insertion of vision embeddings into token embeddings at `img_context_token_id`
- Language model forward pass (to trigger quantizer calibration)
"""

from __future__ import annotations

import contextlib
from typing import Any

import torch


def safe_nemotron_vl_forward(full_model: torch.nn.Module, batch: dict[str, Any]) -> None:
    """Run a minimal multimodal forward for Nemotron VL that avoids wrapper output packaging."""
    pixel_values = batch.get("pixel_values")
    input_ids = batch.get("input_ids")
    attention_mask = batch.get("attention_mask")
    position_ids = batch.get("position_ids")
    image_flags = batch.get("image_flags")

    if pixel_values is None or input_ids is None:
        return

    # Nemotron Nano VL v2 expects `image_flags` in forward(), but the processor doesn't always emit it.
    # `pixel_values` is flattened across batch*images, so `image_flags` should align with pixel_values.shape[0].
    if image_flags is None and torch.is_tensor(pixel_values):
        image_flags = torch.ones(
            (pixel_values.shape[0], 1), device=pixel_values.device, dtype=torch.long
        )
    if image_flags is None:
        return

    # Match the model's preferred vision dtype (usually bf16).
    vision_dtype = None
    with contextlib.suppress(Exception):
        vision_dtype = getattr(full_model.vision_model.config, "torch_dtype", None)
    if vision_dtype is None:
        with contextlib.suppress(Exception):
            vision_dtype = getattr(full_model.language_model.config, "torch_dtype", None)
    if (
        vision_dtype is not None
        and torch.is_tensor(pixel_values)
        and pixel_values.dtype != vision_dtype
    ):
        pixel_values = pixel_values.to(dtype=vision_dtype)

    # Token embeddings
    inputs_embeds = full_model.language_model.get_input_embeddings()(input_ids)
    image_flags_s = image_flags.squeeze(-1)

    b, n, c = inputs_embeds.shape
    flat_embeds = inputs_embeds.reshape(b * n, c)
    flat_ids = input_ids.reshape(b * n)
    selected = flat_ids == full_model.img_context_token_id

    # Vision embeddings
    vit_embeds = full_model.extract_feature(pixel_values)
    vit_embeds = vit_embeds[image_flags_s == 1]
    try:
        flat_embeds[selected] = flat_embeds[selected] * 0.0 + vit_embeds.reshape(-1, c)
    except Exception:
        vit_embeds = vit_embeds.reshape(-1, c)
        n_token = selected.sum()
        flat_embeds[selected] = flat_embeds[selected] * 0.0 + vit_embeds[:n_token]

    inputs_embeds = flat_embeds.reshape(b, n, c)

    # LLM forward (drives activation stats)
    full_model.language_model(
        inputs_embeds=inputs_embeds,
        attention_mask=attention_mask,
        position_ids=position_ids,
        use_cache=False,
        return_dict=False,
    )

Can this be the default for VLM?