diff --git a/examples/llm_ptq/README.md b/examples/llm_ptq/README.md index 23ab1ecf9..2cf8c0afb 100755 --- a/examples/llm_ptq/README.md +++ b/examples/llm_ptq/README.md @@ -161,6 +161,25 @@ scripts/huggingface_example.sh --model $HF_PATH --quant [fp8|nvfp4|int8_sq|int4_ [PTQ for DeepSeek](../deepseek/README.md) shows how to quantize the DeepSeek model with FP4 and export to TensorRT-LLM. +#### VLM calibration with image-text pairs (e.g., Nemotron VL) + +For vision-language models, calibration quality can likely improve by using image-text pairs instead of text-only data, especially on visual understanding tasks: + +```bash +python hf_ptq.py \ + --pyt_ckpt_path \ + --qformat nvfp4 \ + --export_path \ + --trust_remote_code \ + --calib_with_images \ + --vlm_dataset nemotron_vlm_dataset_v2 \ + --vlm_subsets sparsetables,plotqa_cot,wiki_en \ + --calib_size 512 +``` + +> Note: when `--calib_with_images` is set, `--calib_size` must be a single value. +This functionality is currently in beta and has been tested on `nvidia/NVIDIA-Nemotron-Nano-12B-v2-VL-BF16`. + ### NeMo Example Script NeMo 2.0 framework PTQ and TensorRT-LLM deployment examples are maintained in the NeMo GitHub repo. Please refer to the [NeMo PTQ documentation](https://docs.nvidia.com/nemo-framework/user-guide/latest/model-optimization/quantization/quantization.html) for more details. diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index a9862a742..344ad74f2 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -14,6 +14,7 @@ # limitations under the License. import argparse +import inspect import random import time import warnings @@ -32,6 +33,7 @@ is_nemotron_vl, run_nemotron_vl_preview, ) +from nemotron_vl_calib import safe_nemotron_vl_forward from torch.utils.data import DataLoader from transformers import ( AutoConfig, @@ -65,7 +67,10 @@ from modelopt.torch.utils.image_processor import BaseImageProcessor, MllamaImageProcessor from modelopt.torch.utils.memory_monitor import launch_memory_monitor from modelopt.torch.utils.speech_dataset_utils import get_speech_dataset_dataloader -from modelopt.torch.utils.vlm_dataset_utils import get_vlm_dataset_dataloader +from modelopt.torch.utils.vlm_dataset_utils import ( + get_supported_vlm_datasets, + get_vlm_dataset_dataloader, +) RAND_SEED = 1234 @@ -107,7 +112,31 @@ def make_calib_dataloader( ) -> tuple[DataLoader, str | None]: calib_dataloader = None first_text_speech_dataset = None - if model_type == "mllama": + if getattr(args, "calib_with_images", False): + # Generic multimodal calibration path (used for Nemotron VL and other HF VLMs). + assert processor is not None, ( + "Please provide a processor (e.g., AutoProcessor) for image calibration." + ) + assert len(args.calib_size) == 1, ( + "Image calibration currently supports a single dataset. " + "Please pass --calib_size with one value (e.g., --calib_size 256)." 
+ ) + calib_dataloader = get_vlm_dataset_dataloader( + dataset_name=getattr(args, "vlm_dataset", "scienceqa"), + processor=processor, + batch_size=args.batch_size, + num_samples=args.calib_size[0], + device=device, + max_length=args.calib_seq, + require_image=True, + subsets=getattr(args, "vlm_subsets", None), + shuffle_buffer_size=getattr(args, "vlm_shuffle_buffer", 10_000), + seed=getattr(args, "vlm_shuffle_seed", 42), + image_root=getattr(args, "vlm_image_root", None), + use_media_shards=not getattr(args, "vlm_disable_media_shards", False), + max_shards=getattr(args, "vlm_max_shards", None), + ) + elif model_type == "mllama": assert processor is not None and isinstance(processor, MllamaImageProcessor), ( "The MllamaImageProcessor must be set." ) @@ -164,6 +193,12 @@ def auto_quantize( ): """Auto search quantization of multiple formats.""" + if getattr(args, "calib_with_images", False): + raise NotImplementedError( + "AutoQuantize with image-text calibration is not supported yet. " + "Please run plain PTQ (e.g., --qformat nvfp4) with --calib_with_images." + ) + assert not (args.auto_quantize_bits and args.inference_pipeline_parallel > 1), ( "Auto Quantization is not supported for pipeline parallel size > 1" ) @@ -292,6 +327,7 @@ def load_model(args: argparse.Namespace): language_model = full_model default_padding_side = None + is_nemotron_vl_model = is_nemotron_vl(full_model) if model_type == "mllama": processor = get_processor( args.pyt_ckpt_path, @@ -307,6 +343,46 @@ def load_model(args: argparse.Namespace): device, trust_remote_code=args.trust_remote_code, ) + elif is_nemotron_vl_model and getattr(args, "calib_with_images", False): + # For Nemotron VL image calibration, we need an AutoProcessor to build multimodal inputs. + try: + processor = AutoProcessor.from_pretrained( + args.pyt_ckpt_path, trust_remote_code=args.trust_remote_code, padding_side="left" + ) + except Exception as e: + raise RuntimeError( + "Failed to load AutoProcessor for Nemotron VL image calibration. " + "Please ensure the checkpoint provides a compatible processor." + ) from e + + if hasattr(processor, "tokenizer") and processor.tokenizer is not None: + tokenizer = processor.tokenizer + else: + tokenizer = get_tokenizer(args.pyt_ckpt_path, trust_remote_code=args.trust_remote_code) + + # Some Nemotron tokenizers may not define pad_token by default; but we use padding=True during calibration. + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + assert tokenizer.pad_token is not None, f"Pad token for {args.pyt_ckpt_path} cannot be set!" + + default_padding_side = tokenizer.padding_side + tokenizer.padding_side = "left" + + # Quantize only the language model, but keep the full_model for calibration forward. + language_model_lineage = get_language_model_from_vl(full_model) + if language_model_lineage is not None: + language_model = language_model_lineage.pop(-1) + ancestors = language_model_lineage + disabled_quant_cfg = {"quant_cfg": {"default": {"enable": False}}, "algorithm": "max"} + + memo = set(ancestors) | {language_model} + for ancestor in ancestors: + for _, module in ancestor.named_children(): + if module not in memo: + mtq.quantize(module, disabled_quant_cfg, forward_loop=None) + memo.add(module) + + model_type = get_model_type(language_model) else: if args.dataset is None: args.dataset = ["cnn_dailymail", "nemotron-post-training-dataset-v2"] @@ -432,9 +508,33 @@ def mono_quantize( if not use_calibration: warnings.warn("Dynamic quantization. 
Calibration skipped.") - calibrate_loop = ( - create_forward_loop(dataloader=calib_dataloader) if use_calibration else None - ) + calibrate_loop = None + if use_calibration: + base_forward_loop = create_forward_loop(dataloader=calib_dataloader) + # For Nemotron VL image calibration, the dataloader yields multimodal kwargs (e.g., pixel_values). + # Those kwargs must be consumed by the *full* VLM model, not the extracted language_model. + if getattr(args, "calib_with_images", False) and is_nemotron_vl_model: + + def calibrate_full_model(_model): + forward_params = inspect.signature(full_model.forward).parameters + accepts_kwargs = any( + p.kind == inspect.Parameter.VAR_KEYWORD for p in forward_params.values() + ) + allowed_keys = set(forward_params.keys()) + + full_model.eval() + with torch.no_grad(): + for batch in calib_dataloader: + if accepts_kwargs: + call_kwargs = batch + else: + call_kwargs = {k: v for k, v in batch.items() if k in allowed_keys} + call_kwargs = {k: v for k, v in call_kwargs.items() if v is not None} + safe_nemotron_vl_forward(full_model, call_kwargs) + + calibrate_loop = calibrate_full_model + else: + calibrate_loop = base_forward_loop if calibration_only: language_model = mtq.calibrate( @@ -856,6 +956,67 @@ def parse_args() -> argparse.Namespace: type=str, default=None, ) + parser.add_argument( + "--calib_with_images", + action="store_true", + help=( + "Calibrate with image-text pairs (for VLMs). " + "For Nemotron VL this enables multimodal calibration using --vlm_dataset." + ), + ) + parser.add_argument( + "--vlm_dataset", + type=str, + default="scienceqa", + help=f"VLM calibration dataset name (choices: {get_supported_vlm_datasets()}).", + ) + parser.add_argument( + "--vlm_subsets", + type=str, + default="sparsetables,plotqa_cot,wiki_en", + help=( + "Comma-separated subset/config names for multi-subset VLM datasets " + "(e.g., nemotron_vlm_dataset_v2)." + ), + ) + parser.add_argument( + "--vlm_shuffle_buffer", + type=int, + default=10_000, + help="Shuffle buffer size for streaming VLM datasets (higher is more random but downloads more).", + ) + parser.add_argument( + "--vlm_shuffle_seed", + type=int, + default=42, + help="Random seed for streaming VLM dataset shuffle.", + ) + parser.add_argument( + "--vlm_image_root", + type=str, + default=None, + help=( + "Local directory containing image files referenced by the VLM dataset annotations. " + "Required for nemotron_vlm_dataset_v2 subsets that only ship JSONL (e.g., docvqa_cot, chartqa_cot)." + ), + ) + parser.add_argument( + "--vlm_max_shards", + type=int, + default=1, + help=( + "For VLM subsets that include in-repo tar shards under `/media/*.tar`, " + "limit how many shards to download/use for calibration. Increase if you don't get enough samples." + ), + ) + parser.add_argument( + "--vlm_disable_media_shards", + action="store_true", + help=( + "Disable reading in-repo `media/shard_*.tar` files for nemotron_vlm_dataset_v2. " + "Useful if you want to use JSONL-only subsets together with --vlm_image_root." 
+ ), + ) parser.add_argument("--inference_tensor_parallel", type=int, default=1) parser.add_argument("--inference_pipeline_parallel", type=int, default=1) parser.add_argument("--awq_block_size", default=0, type=int) @@ -1022,4 +1183,7 @@ def main(args: argparse.Namespace): args.dataset = args.dataset.split(",") if args.dataset else None args.calib_size = [int(num_sample) for num_sample in args.calib_size.split(",")] + args.vlm_subsets = ( + [s.strip() for s in args.vlm_subsets.split(",") if s.strip()] if args.vlm_subsets else None + ) main(args) diff --git a/examples/llm_ptq/nemotron_vl_calib.py b/examples/llm_ptq/nemotron_vl_calib.py new file mode 100644 index 000000000..1b6187008 --- /dev/null +++ b/examples/llm_ptq/nemotron_vl_calib.py @@ -0,0 +1,98 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Nemotron VL calibration helpers. + +Nemotron Nano VL v2 remote-code wrapper `forward()` is not ideal to call during PTQ calibration because it may: +- Call `torch.distributed.get_rank()` unconditionally +- Assume `past_key_values` exists in the language model output + +Instead, we run a "safe multimodal forward" that exercises: +- Vision encoder feature extraction (C-RADIOv2-H) +- Insertion of vision embeddings into token embeddings at `img_context_token_id` +- Language model forward pass (to trigger quantizer calibration) +""" + +from __future__ import annotations + +import contextlib +from typing import Any + +import torch + + +def safe_nemotron_vl_forward(full_model: torch.nn.Module, batch: dict[str, Any]) -> None: + """Run a minimal multimodal forward for Nemotron VL that avoids wrapper output packaging.""" + pixel_values = batch.get("pixel_values") + input_ids = batch.get("input_ids") + attention_mask = batch.get("attention_mask") + position_ids = batch.get("position_ids") + image_flags = batch.get("image_flags") + + if pixel_values is None or input_ids is None: + return + + # Nemotron Nano VL v2 expects `image_flags` in forward(), but the processor doesn't always emit it. + # `pixel_values` is flattened across batch*images, so `image_flags` should align with pixel_values.shape[0]. + if image_flags is None and torch.is_tensor(pixel_values): + image_flags = torch.ones( + (pixel_values.shape[0], 1), device=pixel_values.device, dtype=torch.long + ) + if image_flags is None: + return + + # Match the model's preferred vision dtype (usually bf16). 
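+    # (HF processors typically emit float32 pixel_values; casting to the model's own dtype below
+    # avoids dtype-mismatch errors in the vision tower and keeps calibration activations on the
+    # same dtype the model uses at inference.)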
+ vision_dtype = None + with contextlib.suppress(Exception): + vision_dtype = getattr(full_model.vision_model.config, "torch_dtype", None) + if vision_dtype is None: + with contextlib.suppress(Exception): + vision_dtype = getattr(full_model.language_model.config, "torch_dtype", None) + if ( + vision_dtype is not None + and torch.is_tensor(pixel_values) + and pixel_values.dtype != vision_dtype + ): + pixel_values = pixel_values.to(dtype=vision_dtype) + + # Token embeddings + inputs_embeds = full_model.language_model.get_input_embeddings()(input_ids) + image_flags_s = image_flags.squeeze(-1) + + b, n, c = inputs_embeds.shape + flat_embeds = inputs_embeds.reshape(b * n, c) + flat_ids = input_ids.reshape(b * n) + selected = flat_ids == full_model.img_context_token_id + + # Vision embeddings + vit_embeds = full_model.extract_feature(pixel_values) + vit_embeds = vit_embeds[image_flags_s == 1] + try: + flat_embeds[selected] = flat_embeds[selected] * 0.0 + vit_embeds.reshape(-1, c) + except Exception: + vit_embeds = vit_embeds.reshape(-1, c) + n_token = selected.sum() + flat_embeds[selected] = flat_embeds[selected] * 0.0 + vit_embeds[:n_token] + + inputs_embeds = flat_embeds.reshape(b, n, c) + + # LLM forward (drives activation stats) + full_model.language_model( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + position_ids=position_ids, + use_cache=False, + return_dict=False, + ) diff --git a/modelopt/torch/utils/nemotron_vlm_dataset_utils.py b/modelopt/torch/utils/nemotron_vlm_dataset_utils.py new file mode 100644 index 000000000..36690d3b5 --- /dev/null +++ b/modelopt/torch/utils/nemotron_vlm_dataset_utils.py @@ -0,0 +1,227 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Nemotron VLM dataset utilities. + +This module contains the Nemotron-VLM-Dataset-v2 specific logic: +- Subsets can store images in `media/shard_*.tar` (images only) +- Prompts/messages live in `/.jsonl` and reference the image filename (e.g. `292180.png`) + +We join the tar images with the JSONL messages by the shared filename and yield samples compatible with our +VLM calibration pipeline. +""" + +from __future__ import annotations + +import functools +import json +import os +import random +import tarfile +from io import BytesIO +from typing import Any + +import torch + +_IMG_EXTS = {".jpg", ".jpeg", ".png", ".webp", ".bmp"} + + +@functools.lru_cache(maxsize=8) +def list_repo_files_cached(repo_id: str, repo_type: str = "dataset") -> list[str]: + """List files in a HuggingFace repo (cached). + + Args: + repo_id: HF repo id (e.g., a dataset repo). + repo_type: HF repo type, usually "dataset" here. 
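+
+    Returns:
+        All file paths in the repo, relative to the repo root.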
+ """ + from huggingface_hub import list_repo_files + + return list_repo_files(repo_id=repo_id, repo_type=repo_type) + + +def extract_first_image_from_messages(messages: Any) -> Any: + """Best-effort extraction of an image reference from Nemotron-style `messages`.""" + if not isinstance(messages, list): + return None + for msg in messages: + if not isinstance(msg, dict): + continue + content = msg.get("content") + if not isinstance(content, list): + continue + for part in content: + if not (isinstance(part, dict) and part.get("type") == "image"): + continue + if "image" in part: + return part["image"] + # fallback + for key in ("images", "path", "image_url", "url", "value", "data"): + if key in part: + return part[key] + return None + + +class NemotronTarPlusJsonlIterable(torch.utils.data.IterableDataset): + """Join Nemotron VLM `media/shard_*.tar` (images-only) with `/.jsonl` (messages).""" + + def __init__( + self, + repo_id: str, + subsets: list[str], + shard_paths: list[str], + num_samples: int, + seed: int, + shuffle_buffer_size: int, + max_shards: int | None, + ): + """Create an iterable dataset for Nemotron-VLM-Dataset-v2. + + Args: + repo_id: Dataset repo id, e.g. "nvidia/Nemotron-VLM-Dataset-v2". + subsets: Subset names to draw from (e.g., "sparsetables"). + shard_paths: Tar shard paths under `/media/`. + num_samples: Total number of samples to yield. + seed: RNG seed for sampling. + shuffle_buffer_size: Unused for now (kept for API compatibility). + max_shards: Max number of shards to use per subset (limits downloads). + """ + super().__init__() + self.repo_id = repo_id + self.subsets = subsets + self.shard_paths = shard_paths + self.num_samples = num_samples + self.seed = seed + self.shuffle_buffer_size = shuffle_buffer_size + self.max_shards = max_shards + + def __iter__(self): + from huggingface_hub import hf_hub_download + from PIL import Image + + rng = random.Random(self.seed) + + # Partition shards by subset. + shards_by_subset: dict[str, list[str]] = {s: [] for s in self.subsets} + for p in self.shard_paths: + subset = p.split("/", 1)[0] + if subset in shards_by_subset: + shards_by_subset[subset].append(p) + + for subset in list(shards_by_subset.keys()): + shard_list = sorted(shards_by_subset[subset]) + if self.max_shards is not None: + shard_list = shard_list[: max(0, self.max_shards)] + shards_by_subset[subset] = shard_list + + # Roughly split sample budget across subsets. + per_subset_target = max(1, self.num_samples // max(1, len(self.subsets))) + yielded_total = 0 + + for subset in self.subsets: + if yielded_total >= self.num_samples: + break + + shard_list = list(shards_by_subset.get(subset, [])) + if not shard_list: + continue + rng.shuffle(shard_list) + + # 1) Collect candidate image filenames from tar headers (no payload reads). 
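+            #    Oversample the candidates (header_limit is 50x the per-subset target) so the
+            #    JSONL join in step 2 still finds enough image/message matches.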
+ candidate_names: list[str] = [] + header_limit = per_subset_target * 50 + for shard in shard_list: + local_tar = hf_hub_download( + repo_id=self.repo_id, filename=shard, repo_type="dataset" + ) + with tarfile.open(local_tar, "r:*") as tf: + for member in tf: + if not member.isfile(): + continue + name = member.name + _, ext = os.path.splitext(name) + if ext.lower() not in _IMG_EXTS: + continue + candidate_names.append(name) + if len(candidate_names) >= header_limit: + break + if len(candidate_names) >= header_limit: + break + + if not candidate_names: + continue + + rng.shuffle(candidate_names) + lookup_limit = per_subset_target * 10 + candidate_set = set(candidate_names[:lookup_limit]) + + # 2) Scan JSONL to map image filename -> messages. + jsonl_path = hf_hub_download( + repo_id=self.repo_id, filename=f"{subset}/{subset}.jsonl", repo_type="dataset" + ) + meta_by_image: dict[str, dict[str, Any]] = {} + with open(jsonl_path, encoding="utf-8") as f: + for line in f: + try: + obj = json.loads(line) + except Exception: + continue + msgs = obj.get("messages") + img_name = extract_first_image_from_messages(msgs) if msgs is not None else None + if isinstance(img_name, str) and img_name in candidate_set: + meta_by_image[img_name] = {"id": obj.get("id"), "messages": msgs} + if len(meta_by_image) >= per_subset_target: + break + + if not meta_by_image: + continue + + # 3) Extract matched images and yield samples. + needed = set(meta_by_image.keys()) + for shard in shard_list: + if yielded_total >= self.num_samples or not needed: + break + local_tar = hf_hub_download( + repo_id=self.repo_id, filename=shard, repo_type="dataset" + ) + with tarfile.open(local_tar, "r:*") as tf: + for member in tf: + if yielded_total >= self.num_samples or not needed: + break + if not member.isfile(): + continue + name = member.name + if name not in needed: + continue + f = tf.extractfile(member) + if f is None: + continue + try: + raw = f.read() + if isinstance(raw, str): + raw = raw.encode() + raw_bytes: bytes = raw + img = Image.open(BytesIO(raw_bytes)).convert("RGB") + except Exception: + continue + meta = meta_by_image.get(name) + if not meta: + continue + yield { + "id": meta.get("id", name), + "messages": meta.get("messages"), + "image": img, + } + needed.discard(name) + yielded_total += 1 diff --git a/modelopt/torch/utils/vlm_dataset_utils.py b/modelopt/torch/utils/vlm_dataset_utils.py index 1d9f59484..841b82f65 100644 --- a/modelopt/torch/utils/vlm_dataset_utils.py +++ b/modelopt/torch/utils/vlm_dataset_utils.py @@ -13,29 +13,200 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Utility functions for getting samples and forward loop function for different vlm datasets.""" +"""Utility functions for getting samples and dataloader for different VLM calibration datasets. +This module supports both: +- Small non-streaming VLM datasets (e.g., ScienceQA) +- Large streaming VLM datasets (e.g., Nemotron-VLM-Dataset-v2) where we want to avoid downloading everything. +""" + +import contextlib +import copy +import itertools +from io import BytesIO +from pathlib import Path from typing import Any +import torch from torch.utils.data import DataLoader from .image_processor import MllamaImageProcessor +from .nemotron_vlm_dataset_utils import NemotronTarPlusJsonlIterable, list_repo_files_cached # Use dict to store the config for each dataset. # If we want to export more options to user like target languages, we need more standardized approach like dataclass. 
SUPPORTED_VLM_DATASET_CONFIG: dict[str, dict[str, Any]] = { "scienceqa": {"config": {"path": "derek-thomas/ScienceQA", "split": "train"}}, + # Large multi-subset dataset (use streaming to avoid downloading the entire dataset) + "nemotron_vlm_dataset_v2": { + "config": {"path": "nvidia/Nemotron-VLM-Dataset-v2", "split": "train", "streaming": True}, + # Provide a sane default that (a) includes in-repo media shards and (b) is document-centric. + # Subsets like docvqa_cot/chartqa_cot are JSONL-only in the dataset repo and require --vlm_image_root. + "default_subsets": ["sparsetables", "plotqa_cot", "wiki_en"], + }, } __all__ = ["get_supported_vlm_datasets", "get_vlm_dataset_dataloader"] -def _get_vlm_dataset(dataset_name: str, num_samples: int): +class _HFDatasetsIterableWrapper(torch.utils.data.IterableDataset): + """Wrap a HF streaming IterableDataset to be compatible with torch DataLoader.""" + + def __init__(self, hf_iterable, num_samples: int): + super().__init__() + self._hf_iterable = hf_iterable + self._num_samples = num_samples + + def __iter__(self): + return itertools.islice(iter(self._hf_iterable), self._num_samples) + + def __len__(self): + return self._num_samples + + +def _extract_text_from_messages(messages: Any) -> str | None: + """Best-effort extraction of a user text prompt from a chat-style `messages` field.""" + if not isinstance(messages, list): + return None + for msg in messages: + if not isinstance(msg, dict): + continue + if msg.get("role") != "user": + continue + content = msg.get("content") + if isinstance(content, str): + return content + if isinstance(content, list): + # Common multimodal format: [{"type":"image"}, {"type":"text","text":"..."}] + texts = [ + part["text"] + for part in content + if isinstance(part, dict) + and part.get("type") == "text" + and isinstance(part.get("text"), str) + ] + if texts: + return "\n".join(texts) + return None + + +def _messages_up_to_last_user(messages: Any) -> list[dict[str, Any]] | None: + """Return messages truncated to the last user turn (inclusive).""" + if not isinstance(messages, list): + return None + last_user_idx = None + for i, msg in enumerate(messages): + if isinstance(msg, dict) and msg.get("role") == "user": + last_user_idx = i + if last_user_idx is None: + return None + trimmed = messages[: last_user_idx + 1] + return [m for m in trimmed if isinstance(m, dict)] + + +def _extract_first_image_from_messages(messages: Any) -> Any: + """Best-effort extraction of an image object from a chat-style `messages` field.""" + if not isinstance(messages, list): + return None + for msg in messages: + if not isinstance(msg, dict): + continue + content = msg.get("content") + if not isinstance(content, list): + continue + for part in content: + if not (isinstance(part, dict) and part.get("type") == "image"): + continue + # Common keys used by HF datasets / chat templates + for key in ("image", "images", "value", "data", "path", "image_url", "url"): + if key in part: + val = part[key] + if isinstance(val, list) and val: + return val[0] + return val + # Fallback: return the dict itself (some processors may accept it) + return part + return None + + +def _maybe_load_image(image_obj: Any, repo_id: str | None, image_root: str | Path | None) -> Any: + """Convert common image references (path/bytes) into a PIL image if possible. + + For some streaming datasets, images are stored as file paths inside the dataset repo. + In that case, we lazily download just the referenced files via `hf_hub_download`. 
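+
+    Args:
+        image_obj: Image reference: a PIL/array image, a path string, a dict carrying bytes or a
+            path/url, or a list of such references (the first entry is used).
+        repo_id: HF dataset repo id used to resolve in-repo relative paths via `hf_hub_download`.
+        image_root: Local directory checked first when resolving relative path references.
+
+    Returns:
+        A PIL image when the reference can be resolved, None when loading fails, or the original
+        object unchanged (e.g., an already-decoded image) for the processor to validate.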
+ """ + if image_obj is None: + return None + + # If it's a list, take the first (some formats store a list for multi-image samples). + if isinstance(image_obj, list) and image_obj: + image_obj = image_obj[0] + + # Path-like reference + if isinstance(image_obj, str): + # First, try resolving against a local image root (best option for datasets that only ship JSONL refs). + if image_root is not None: + try: + from PIL import Image + + local_path = Path(image_root) / image_obj + if local_path.exists(): + return Image.open(local_path).convert("RGB") + except Exception: + pass + + if repo_id is None: + return image_obj + try: + from huggingface_hub import hf_hub_download + from PIL import Image + + local_path = hf_hub_download(repo_id=repo_id, filename=image_obj, repo_type="dataset") + return Image.open(local_path).convert("RGB") + except Exception: + return None + + # Dict-like reference (common in chat content items) + if isinstance(image_obj, dict): + # bytes payload + if "bytes" in image_obj and isinstance(image_obj["bytes"], (bytes, bytearray)): + try: + from PIL import Image + + return Image.open(BytesIO(image_obj["bytes"])).convert("RGB") + except Exception: + return None + + # path/url-ish payloads + for key in ("path", "image", "image_path", "file", "url", "image_url"): + if key in image_obj and isinstance(image_obj[key], str): + return _maybe_load_image(image_obj[key], repo_id=repo_id, image_root=image_root) + + # If it's already a PIL/numpy/torch image-like object, just return it and let the processor validate. + return image_obj + + +def _get_vlm_dataset( + dataset_name: str, + num_samples: int, + require_image: bool = True, + subsets: list[str] | None = None, + shuffle_buffer_size: int = 10_000, + seed: int = 42, + use_media_shards: bool = True, + max_shards: int | None = None, +): """Load a portion of train dataset with the dataset name and a given size. Args: dataset_name: Name of the dataset to load. num_samples: Number of samples to load from the dataset. + require_image: If True, keep only samples that have an image field. + subsets: Optional subset/config names for multi-subset datasets (e.g., Nemotron-VLM-Dataset-v2). + shuffle_buffer_size: Shuffle buffer size for streaming datasets (higher is "more random"). + seed: RNG seed for streaming dataset shuffle. + use_media_shards: If True, prefer reading in-repo `media/shard_*.tar` files when available. + max_shards: Optional cap on the number of tar shards to download/use. Returns: A hugging face Dataset. @@ -44,16 +215,91 @@ def _get_vlm_dataset(dataset_name: str, num_samples: int): if dataset_name in SUPPORTED_VLM_DATASET_CONFIG: from datasets import load_dataset - # Use streaming can reduce the downloading time for large datasets - dataset = load_dataset( - **SUPPORTED_VLM_DATASET_CONFIG[dataset_name]["config"], - ) + cfg = SUPPORTED_VLM_DATASET_CONFIG[dataset_name]["config"].copy() + streaming = bool(cfg.pop("streaming", False)) + + if dataset_name == "nemotron_vlm_dataset_v2": + # This dataset contains many subsets; load only the requested ones via `name=...`. + if not subsets: + subsets = SUPPORTED_VLM_DATASET_CONFIG[dataset_name].get("default_subsets", []) + if not subsets: + raise ValueError("No VLM subsets provided for nemotron_vlm_dataset_v2.") + + repo_id = cfg["path"] + + # Prefer in-repo media tar shards when present. HF `datasets` streaming alone does not join media. 
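+            # The tar shards contain images only; the matching prompts live in each subset's JSONL
+            # file, so NemotronTarPlusJsonlIterable joins the two by image filename.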
+ if use_media_shards: + all_files = list_repo_files_cached(repo_id, repo_type="dataset") + shard_paths: list[str] = [] + for subset in subsets: + prefix = f"{subset}/media/" + shard_paths.extend( + [ + p + for p in all_files + if p.startswith(prefix) and p.lower().endswith(".tar") + ] + ) + + shard_paths = sorted(set(shard_paths)) + if shard_paths: + return NemotronTarPlusJsonlIterable( + repo_id=repo_id, + subsets=subsets, + shard_paths=shard_paths, + num_samples=num_samples, + seed=seed, + shuffle_buffer_size=shuffle_buffer_size, + max_shards=max_shards, + ) + + # Load each subset as a separate (streaming) dataset, then interleave. + streams = [ + load_dataset( + cfg["path"], + name=subset, + split=cfg.get("split", "train"), + streaming=streaming, + ) + for subset in subsets + ] + try: + from datasets import interleave_datasets + + ds = interleave_datasets(streams) + except Exception: + # Fallback: round-robin by chaining (less balanced than interleave). + ds = itertools.chain.from_iterable(streams) + else: + dataset = load_dataset(**cfg, streaming=streaming) + split = cfg.get("split", "train") + ds = dataset[split] if hasattr(dataset, "__getitem__") and split in dataset else dataset else: raise NotImplementedError( f"dataset {dataset_name} is not supported. Please use one of the following:" f" {get_supported_vlm_datasets()}." ) - return dataset.select(range(num_samples)) + + # Streaming datasets: shuffle with bounded buffer and wrap into a torch IterableDataset. + if dataset_name == "nemotron_vlm_dataset_v2": + with contextlib.suppress(Exception): + ds = ds.shuffle(seed=seed, buffer_size=shuffle_buffer_size) + + if require_image: + # Keep only samples with a non-null image field (ScienceQA has both). + with contextlib.suppress(Exception): + ds = ds.filter( + lambda ex: ex.get("image", None) is not None + or ex.get("images", None) is not None + or _extract_first_image_from_messages(ex.get("messages")) is not None + ) + + # Select the first `num_samples` entries (or fewer if dataset is smaller). + try: + return ds.select(range(min(num_samples, len(ds)))) + except Exception: + # For streaming/iterable datasets without __len__/select, wrap for DataLoader iteration. + return _HFDatasetsIterableWrapper(ds, num_samples=num_samples) def get_supported_vlm_datasets() -> list[str]: @@ -75,9 +321,18 @@ def get_supported_vlm_datasets() -> list[str]: def get_vlm_dataset_dataloader( dataset_name: str = "scienceqa", - processor: MllamaImageProcessor = None, + processor: Any = None, batch_size: int = 1, num_samples: int = 512, + device: str | torch.device | None = None, + max_length: int | None = None, + require_image: bool = True, + subsets: list[str] | None = None, + shuffle_buffer_size: int = 10_000, + seed: int = 42, + image_root: str | Path | None = None, + use_media_shards: bool = True, + max_shards: int | None = None, ) -> DataLoader: """Get a dataloader with the dataset name and processor of the target model. @@ -86,22 +341,131 @@ def get_vlm_dataset_dataloader( processor: Processor used for encoding images and text data. batch_size: Batch size of the returned dataloader. num_samples: Number of samples from the dataset. + device: Device to move returned tensors to. If None, keep on CPU. + max_length: Optional max length for text tokenization (if supported by the processor). + require_image: If True, keep only samples that have an image field. Returns: An instance of dataloader. """ assert processor is not None, "Please provide a valid processor." 
- dataset = _get_vlm_dataset(dataset_name, num_samples=num_samples) - # Apply the preprocessing function to the dataset - processed_dataset = dataset.map( - processor.preprocess_function, batched=False, remove_columns=dataset.column_names - ) + # Optional: allow callers to set a local image root for datasets that only ship JSON references. + # We store it on the processor instance to avoid threading it through a bunch of nested closures. + if image_root is not None: + setattr(processor, "_modelopt_vlm_image_root", image_root) - # Create DataLoader with the custom collate function - return DataLoader( - processed_dataset, - batch_size=batch_size, - shuffle=False, - collate_fn=processor.collate_function, + if device is not None: + device = torch.device(device) + + dataset = _get_vlm_dataset( + dataset_name, + num_samples=num_samples, + require_image=require_image, + subsets=subsets, + shuffle_buffer_size=shuffle_buffer_size, + seed=seed, + use_media_shards=use_media_shards, + max_shards=max_shards, ) + + # Legacy path: our internal image processor wrapper (e.g., Mllama). + if isinstance(processor, MllamaImageProcessor): + processed_dataset = dataset.map( + processor.preprocess_function, batched=False, remove_columns=dataset.column_names + ) + return DataLoader( + processed_dataset, + batch_size=batch_size, + shuffle=False, + collate_fn=processor.collate_function, + ) + + # Generic HF ProcessorMixin / AutoProcessor path: tokenize & process images at collate-time. + # For Nemotron VLM datasets, we prefer to follow the model-card flow: + # prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + # inputs = processor(text=[prompt], images=[pil_image], ...) + + def _collate_fn(examples: list[dict[str, Any]]) -> dict[str, torch.Tensor] | dict[str, Any]: + repo_id = None + if dataset_name == "nemotron_vlm_dataset_v2": + repo_id = SUPPORTED_VLM_DATASET_CONFIG[dataset_name]["config"]["path"] + image_root = getattr(processor, "_modelopt_vlm_image_root", None) + + pairs: list[tuple[str, Any]] = [] + for ex in examples: + messages = ex.get("messages") + + # Image extraction + img = ex.get("image", None) + if img is None: + img = ex.get("images", None) + if img is None and messages is not None: + img = _extract_first_image_from_messages(messages) + img = _maybe_load_image(img, repo_id=repo_id, image_root=image_root) + if require_image and img is None: + continue + + # Prompt extraction + prompt = None + tok = getattr(processor, "tokenizer", None) + if tok is not None and messages is not None: + trimmed = _messages_up_to_last_user(messages) or [] + # For some Nemotron-style templates, the image content expects an empty string. + # Keep the actual image path separate for loading; blank it in the prompt message. + prompt_msgs = copy.deepcopy(trimmed) + for msg in prompt_msgs: + content = msg.get("content") + if isinstance(content, list): + for part in content: + if isinstance(part, dict) and part.get("type") == "image": + part["image"] = "" + with contextlib.suppress(Exception): + prompt = tok.apply_chat_template( + prompt_msgs, tokenize=False, add_generation_prompt=True + ) + + if prompt is None: + # Fallback: best-effort question-only prompt. + q = ex.get("question") + if q is None and messages is not None: + q = _extract_text_from_messages(messages) + prompt = q or "Describe the image." + + pairs.append((prompt, img)) + + if not pairs: + raise ValueError( + "No usable images found in the current batch. 
" + "If you're using JSONL-only subsets (e.g., docvqa_cot/chartqa_cot), provide " + "`--vlm_image_root ` so referenced paths can be resolved. " + "If you're using asset-included subsets, keep media shard loading enabled " + "(default) and consider increasing `--vlm_max_shards`." + ) + + prompts, images = zip(*pairs) + + kwargs: dict[str, Any] = { + "text": list(prompts), + "images": list(images), + "return_tensors": "pt", + "padding": True, + } + if max_length is not None: + kwargs.update({"truncation": True, "max_length": max_length}) + + enc = processor(**kwargs) + + # Some processors return BatchEncoding; normalize to plain dict of tensors. + if hasattr(enc, "data"): + enc = enc.data + out: dict[str, Any] = dict(enc) + + # Move tensors to device if requested. + if device is not None: + for k, v in list(out.items()): + if torch.is_tensor(v): + out[k] = v.to(device) + return out + + return DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=_collate_fn)
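As a reference for how these pieces compose outside `hf_ptq.py`, the sketch below loads a Nemotron VL checkpoint, builds image-text calibration batches with `get_vlm_dataset_dataloader`, and drives quantizer calibration through `safe_nemotron_vl_forward`. The checkpoint name comes from the README; loading via `AutoModel`/`AutoProcessor` with `trust_remote_code`, the `device`, `subsets`, and sample counts, and the `mtq.NVFP4_DEFAULT_CFG` choice are illustrative assumptions, not the exact `hf_ptq.py` code path.

```python
# Minimal sketch (not the exact hf_ptq.py flow): image-text calibration for a Nemotron VL model.
import torch
from transformers import AutoModel, AutoProcessor

import modelopt.torch.quantization as mtq
from modelopt.torch.utils.vlm_dataset_utils import get_vlm_dataset_dataloader
from nemotron_vl_calib import safe_nemotron_vl_forward  # examples/llm_ptq/nemotron_vl_calib.py

ckpt = "nvidia/NVIDIA-Nemotron-Nano-12B-v2-VL-BF16"  # checkpoint named in the README

# Assumption: the checkpoint loads through AutoModel/AutoProcessor with trust_remote_code.
model = AutoModel.from_pretrained(ckpt, torch_dtype=torch.bfloat16, trust_remote_code=True)
model = model.cuda().eval()
processor = AutoProcessor.from_pretrained(ckpt, trust_remote_code=True, padding_side="left")

# Batches carry multimodal kwargs (input_ids, attention_mask, pixel_values, ...) on the GPU.
calib_dataloader = get_vlm_dataset_dataloader(
    dataset_name="nemotron_vlm_dataset_v2",
    processor=processor,
    batch_size=1,
    num_samples=512,
    device="cuda",
    subsets=["sparsetables", "plotqa_cot", "wiki_en"],
)


def forward_loop(_model):
    # Run the full VLM (vision tower + LLM) so the quantizers on the language model
    # observe activations produced from real image-text inputs.
    with torch.no_grad():
        for batch in calib_dataloader:
            safe_nemotron_vl_forward(model, batch)


# Quantize only the language model, mirroring what hf_ptq.py does for Nemotron VL.
mtq.quantize(model.language_model, mtq.NVFP4_DEFAULT_CFG, forward_loop=forward_loop)
```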