diff --git a/.gitignore b/.gitignore index 2bd422eca7..4f351e98f7 100644 --- a/.gitignore +++ b/.gitignore @@ -223,3 +223,4 @@ gha-creds-*.json *.jsonl **/*.jsonl scr/* +output \ No newline at end of file diff --git a/lib/levanter/scripts/launch_vlm_training.py b/lib/levanter/scripts/launch_vlm_training.py new file mode 100644 index 0000000000..ddf7ac3159 --- /dev/null +++ b/lib/levanter/scripts/launch_vlm_training.py @@ -0,0 +1,576 @@ +#!/usr/bin/env python3 +# Copyright 2025 The Levanter Authors +# SPDX-License-Identifier: Apache-2.0 + +""" +Launch script for VLM (Vision-Language Model) training with LLaVA OneVision. + +This script provides a complete training pipeline for LLaVA OneVision models +using real parquet data, with performance optimizations for TPU/GPU training. + +Usage: + # Train from scratch with small model config + python launch_vlm_training.py + + # Train with HuggingFace pretrained weights + python launch_vlm_training.py --initialize_from_hf + + # Train with a single parquet file + python launch_vlm_training.py --train_data /path/to/train.parquet --val_data /path/to/val.parquet + + # Train with a folder containing multiple parquet files + python launch_vlm_training.py --train_data /path/to/train_folder/ --val_data /path/to/val_folder/ + + # Train with glob pattern + python launch_vlm_training.py --train_data "/path/to/data/*.parquet" + + # Full training run + python launch_vlm_training.py --initialize_from_hf --num_train_steps 10000 --train_batch_size 32 + + # High-performance training with all speed optimizations enabled + python launch_vlm_training.py --initialize_from_hf --mp bfloat16 \\ + --freeze_vision_encoder --per_device_parallelism 8 + +Performance Optimization Flags: + --freeze_vision_encoder : Freeze vision encoder (only train projector + LLM) + --per_device_parallelism: Number of examples per device (for gradient accumulation) + --fsdp_axis : FSDP sharding axis (default: embed) +""" + +import argparse +import asyncio +import dataclasses +import logging + +import jmp # For mixed precision policy + +import levanter.main.train_vlm as train_vlm +from levanter.data.image import ConversationDatasetSourceConfig, ImageMixtureDatasetConfig +from levanter.distributed import DistributedConfig, RayConfig +from levanter.models.llava_onevision import LlavaOnevisionConfig +from levanter.models.siglip import SiglipVisionConfig +from levanter.models.qwen import Qwen3Config, QwenConfig +from levanter.models.rotary import DefaultRotaryEmbeddingsConfig +from levanter.layers.attention import AttentionBackend +from levanter.optim import AdamConfig +from levanter.tracker import NoopConfig +from levanter.tracker.wandb import WandbConfig +from levanter.checkpoint import CheckpointerConfig + + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def parse_args(): + parser = argparse.ArgumentParser(description="Launch VLM training with LLaVA OneVision") + + # Data arguments + parser.add_argument( + "--train_data", + type=str, + default="./output", + help="Path to training data. Can be: a single parquet file, a directory containing parquet files, " + "or a glob pattern (e.g., '/path/to/*.parquet')", + ) + parser.add_argument( + "--val_data", + type=str, + default=None, + help="Path to validation data. 
Same format as --train_data (defaults to train_data)", + ) + parser.add_argument( + "--cache_dir", + type=str, + default="/tmp/vlm_cache", + help="Directory for data caching", + ) + parser.add_argument( + "--no_cache", + action="store_true", + help="Disable caching and use streaming mode (processes images on-the-fly, saves disk space)", + ) + parser.add_argument( + "--max_length", + type=int, + default=8192, + help="Maximum sequence length", + ) + + # Model arguments + parser.add_argument( + "--model_name", + type=str, + default="llava-hf/llava-onevision-qwen2-7b-ov-hf", + help="HuggingFace model name for processor and optional weight initialization", + ) + parser.add_argument( + "--initialize_from_hf", + action="store_true", # Default is False; we use custom weight loading for SigLIP + Qwen3 + help="Initialize model weights from HuggingFace checkpoint (for unified llava-onevision models)", + ) + parser.add_argument( + "--use_hf_model_config", + action="store_true", # Default is False; use custom SigLIP + Qwen3 config + help="Use model config from HuggingFace checkpoint (set to True to load full llava-onevision model)", + ) + parser.add_argument( + "--use_small_model", + action="store_true", + help="Use small model config for testing (overrides --use_hf_model_config)", + ) + + # Training arguments + parser.add_argument( + "--num_train_steps", + type=int, + default=20000, + help="Number of training steps", + ) + parser.add_argument( + "--epoch", + type=int, + default=1, + help="Number of epochs to train (default: 1). If 0, train indefinitely until num_train_steps is reached. " + "If > 0, dataset will cycle through the data for the specified number of epochs.", + ) + parser.add_argument( + "--train_batch_size", + type=int, + default=8, + help="Training batch size", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-5, + help="Learning rate", + ) + parser.add_argument( + "--weight_decay", + type=float, + default=0.0, + help="Weight decay", + ) + parser.add_argument( + "--warmup_ratio", + type=float, + default=0.03, + help="Warmup ratio", + ) + + # === Performance Optimization Arguments === + parser.add_argument( + "--mp", + type=str, + default="bfloat16", + choices=["bfloat16", "float16", "float32", None], + help="Mixed precision mode: bfloat16 (recommended for TPU), float16 (GPU), or float32 (full precision)", + ) + parser.add_argument( + "--no_flash_attention", + action="store_true", + help="Disable flash attention (enabled by default for memory-efficient attention computation)", + ) + parser.add_argument( + "--flash_attention_block_size", + type=int, + default=1024, + help="Block size for flash attention (default: 512, use smaller values if OOM)", + ) + parser.add_argument( + "--per_device_parallelism", + type=int, + default=-1, + help="Number of examples to process per device. -1 means train_batch_size/num_devices. " + "Set lower for gradient accumulation to save memory.", + ) + parser.add_argument( + "--freeze_vision_encoder", + action="store_true", + help="Freeze vision encoder weights (only train projector and LLM). " + "Reduces compute by ~30% and often improves fine-tuning results.", + ) + parser.add_argument( + "--freeze_llm", + action="store_true", + help="Freeze LLM weights (only train projector and vision encoder). " + "Useful for vision encoder fine-tuning or projector-only training.", + ) + parser.add_argument( + "--fsdp_axis", + type=str, + default="embed", + help="Axis to use for FSDP sharding. 
Options: embed, mlp, or comma-separated list", + ) + parser.add_argument( + "--no_gradient_checkpointing", + action="store_true", + help="Disable gradient checkpointing (enabled by default to reduce memory usage)", + ) + + # Checkpoint arguments + parser.add_argument( + "--output_dir", + type=str, + default="/tmp/vlm_output", + help="Directory for saving checkpoints", + ) + parser.add_argument( + "--hf_save_path", + type=str, + default=None, + help="Path to save HuggingFace format checkpoints", + ) + parser.add_argument( + "--hf_save_steps", + type=int, + default=1000, + help="Save HF checkpoint every N steps", + ) + parser.add_argument( + "--checkpointer_path", + type=str, + default=None, + help="Path for Levanter checkpoints (defaults to output_dir/checkpoints)", + ) + + # Logging arguments + parser.add_argument( + "--wandb_project", + type=str, + default="marin-vlm", + help="Weights & Biases project name (None to disable)", + ) + parser.add_argument( + "--wandb_run_name", + type=str, + default=None, + help="Weights & Biases run name", + ) + + # Distributed arguments + parser.add_argument( + "--no_distributed", + action="store_true", + help="Disable JAX distributed initialization", + ) + + # Evaluation arguments + parser.add_argument( + "--max_eval_batches", + type=int, + default=10, + help="Maximum number of evaluation batches", + ) + parser.add_argument( + "--steps_per_eval", + type=int, + default=500, # Default to less frequent eval to reduce memory pressure from dual JIT + help="How often to run evaluation (in steps). Higher values reduce JIT compilation memory overhead.", + ) + parser.add_argument( + "--per_device_eval_parallelism", + type=int, + default=-1, # Same as training to potentially reuse XLA compilation cache + help="Number of examples to process per device during evaluation. 
" + "Default: -1 (same as training batch size).", + ) + parser.add_argument( + "--no_eval", + action="store_true", + help="Disable evaluation completely to save memory", + ) + + return parser.parse_args() + + +def get_model_config(args) -> LlavaOnevisionConfig: + """Get model configuration based on arguments with performance optimizations.""" + + # Determine gradient checkpointing setting + use_gradient_checkpointing = not args.no_gradient_checkpointing + + # Determine attention backend (flash attention enabled by default) + use_flash = not args.no_flash_attention + if use_flash: + attn_backend = AttentionBackend.DEFAULT + flash_block_size = args.flash_attention_block_size + else: + attn_backend = AttentionBackend.VANILLA + flash_block_size = None + + if args.use_small_model: + # Small model config for testing + logger.info("Using small model config for testing") + vision_config = SiglipVisionConfig( + hidden_size=64, + intermediate_size=256, + num_hidden_layers=2, + num_attention_heads=4, + image_size=384, + gradient_checkpointing=use_gradient_checkpointing, + use_flash_attention=use_flash, + attn_backend=attn_backend, + flash_attention_block_size=flash_block_size, + ) + text_config = QwenConfig( + hidden_dim=128, + intermediate_dim=512, + num_layers=2, + num_heads=4, + num_kv_heads=2, + gradient_checkpointing=use_gradient_checkpointing, + attn_backend=attn_backend, + flash_attention_block_size=flash_block_size, + ) + else: + # Custom config: SigLIP2 (from google/siglip2-so400m-patch16-384) + Qwen3-1.7B + # Vision: SigLIP2 so400m-patch16-384 config (using SigLIP architecture) + # LLM: Qwen3-1.7B config (not Qwen2) + logger.info("Using custom config: SigLIP2-so400m-patch16 + Qwen3-1.7B") + + # SigLIP2 so400m-patch16-384 config (from HuggingFace) + vision_config = SiglipVisionConfig( + hidden_size=1152, + intermediate_size=4304, + num_hidden_layers=27, + num_attention_heads=16, + image_size=384, + patch_size=16, + gradient_checkpointing=use_gradient_checkpointing, + use_flash_attention=use_flash, + attn_backend=attn_backend, + flash_attention_block_size=flash_block_size, + ) + + # Qwen3-1.7B config (from HuggingFace Qwen/Qwen3-1.7B) + text_config = Qwen3Config( + hidden_dim=2048, + intermediate_dim=6144, + num_layers=28, + num_heads=16, + num_kv_heads=8, + max_seq_len=40960, + gradient_checkpointing=use_gradient_checkpointing, + attn_backend=attn_backend, + flash_attention_block_size=flash_block_size, + rope=DefaultRotaryEmbeddingsConfig(theta=1000000.0), + use_bias=False, + tie_word_embeddings=True, + ) + + config = LlavaOnevisionConfig( + vision_config=vision_config, + text_config=text_config, + gradient_checkpointing=use_gradient_checkpointing, + ) + + # Log optimization settings + logger.info(f" Gradient checkpointing: {use_gradient_checkpointing}") + logger.info(f" Flash attention: {use_flash}") + if use_flash: + logger.info(f" Flash attention block size: {flash_block_size}") + + return config + + +def main(): + args = parse_args() + + # Set validation data to train data if not specified + if args.val_data is None: + args.val_data = args.train_data + + logger.info("=" * 60) + logger.info("VLM Training Configuration") + logger.info("=" * 60) + logger.info(f"Training data: {args.train_data}") + logger.info(f"Validation data: {args.val_data}") + logger.info(f"Model: {args.model_name}") + logger.info(f"Initialize from HF: {args.initialize_from_hf}") + logger.info(f"Num train steps: {args.num_train_steps}") + logger.info(f"Batch size: {args.train_batch_size}") + + # Log performance 
optimization settings + logger.info("-" * 60) + logger.info("Performance Optimizations:") + logger.info(f" Mixed precision: {args.mp or 'disabled (float32)'}") + logger.info(f" Flash attention: {not args.no_flash_attention}") + logger.info(f" Freeze vision encoder: {args.freeze_vision_encoder}") + logger.info(f" Per-device parallelism: {args.per_device_parallelism}") + logger.info(f" FSDP axis: {args.fsdp_axis}") + logger.info(f" Gradient checkpointing: {not args.no_gradient_checkpointing}") + logger.info("-" * 60) + + # Create data config + data_config = ImageMixtureDatasetConfig( + cache_dir=args.cache_dir, + configs={ + "train": ConversationDatasetSourceConfig( + train_urls=[f"file://{args.train_data}"], + validation_urls=[f"file://{args.val_data}"], + cache_dir=f"{args.cache_dir}/train", + ), + }, + train_weights={"train": 1.0}, + processor=args.model_name, + max_length=args.max_length, + use_cache=not args.no_cache, # Use streaming mode if --no_cache is set + ) + + if args.no_cache: + logger.info("Using streaming mode (no caching) - images will be processed on-the-fly") + + # Log dataset file count + logger.info("-" * 60) + logger.info("Dataset Files:") + for name, source_config in data_config.configs.items(): + train_urls = source_config.urls_for_split("train") + val_urls = source_config.urls_for_split("validation") + logger.info(f" {name}: {len(train_urls)} train file(s), {len(val_urls)} validation file(s)") + logger.info("-" * 60) + + # Calculate num_train_steps based on epoch if specified + num_train_steps = args.num_train_steps + if args.epoch > 0: + # Build training datasets to get the actual dataset size + logger.info("Building training datasets to calculate epoch-based steps...") + train_datasets = data_config.training_sets() + + # Calculate total dataset size from all training datasets + total_dataset_size = 0 + for name, ds in train_datasets.items(): + try: + ds_len = asyncio.run(ds.async_len()) + total_dataset_size += ds_len + logger.info(f" Dataset '{name}': {ds_len:,} samples") + except Exception as e: + logger.warning(f"Could not get length of dataset '{name}': {e}") + + if total_dataset_size > 0: + # Calculate steps needed for the specified number of epochs + steps_per_epoch = total_dataset_size // args.train_batch_size + epoch_based_steps = steps_per_epoch * args.epoch + num_train_steps = epoch_based_steps + logger.info( + f"Epoch-based training: {args.epoch} epoch(s) = {num_train_steps:,} steps " + f"({total_dataset_size:,} samples / {args.train_batch_size} batch_size * {args.epoch} epochs)" + ) + else: + logger.warning("Could not determine dataset size, using --num_train_steps instead") + + # Create model config with optimizations + model_config = get_model_config(args) + + # Create optimizer config + warmup_steps = int(num_train_steps * args.warmup_ratio) + optimizer_config = AdamConfig( + learning_rate=args.learning_rate, + weight_decay=args.weight_decay, + warmup=warmup_steps, + ) + + # Create tracker config + if args.wandb_project: + tracker_config = WandbConfig( + project=args.wandb_project, + name=args.wandb_run_name, + ) + else: + tracker_config = NoopConfig() + + # Create distributed config + distributed_config = DistributedConfig(initialize_jax_distributed=not args.no_distributed) + + # Set checkpoint path + checkpointer_path = args.checkpointer_path or f"{args.output_dir}/checkpoints" + checkpointer_config = CheckpointerConfig(base_path=checkpointer_path) + + # Parse FSDP axis (can be comma-separated for multi-axis) + fsdp_axis = args.fsdp_axis + if "," in 
fsdp_axis: + fsdp_axis = [ax.strip() for ax in fsdp_axis.split(",")] + + # Convert mixed precision string to jmp.Policy + # jmp.get_policy accepts strings like "f32", "bf16", "bfloat16", or + # "compute=bfloat16,params=float32,output=float32" + if args.mp: + mp_policy = jmp.get_policy(args.mp) + else: + mp_policy = jmp.get_policy("f32") # Default to full precision + + # Create trainer config with performance optimizations + trainer_config = train_vlm.TrainerConfig( + num_train_steps=num_train_steps, + train_batch_size=args.train_batch_size, + per_device_parallelism=args.per_device_parallelism, + per_device_eval_parallelism=args.per_device_eval_parallelism, # Smaller eval batch to save memory + max_eval_batches=args.max_eval_batches, + steps_per_eval=args.steps_per_eval, + tracker=tracker_config, + checkpointer=checkpointer_config, + distributed=distributed_config, + ray=RayConfig(auto_start_cluster=False), + # # FSDP configuration + # fsdp_axis=fsdp_axis, + # Mixed precision configuration + mp=mp_policy, + ) + + # Create main training config + # Note: When using custom config (SigLIP + Qwen3), we disable use_hf_model_config + # and initialize_from_hf since we'll load weights separately + use_custom_config = not args.use_small_model and not args.use_hf_model_config + config = train_vlm.TrainVLMConfig( + data=data_config, + model=model_config, + trainer=trainer_config, + optimizer=optimizer_config, + # Disable HF loading when using custom config - we'll load weights separately + initialize_from_hf=( + False + if use_custom_config + else ( + args.initialize_from_hf + if args.initialize_from_hf + else args.model_name if args.use_hf_model_config else False + ) + ), + use_hf_model_config=args.use_hf_model_config and not args.use_small_model, + hf_save_path=args.hf_save_path, + hf_save_steps=args.hf_save_steps, + # Custom weight loading paths for hybrid model + # Though it's SigLIP2, the architecture is the same as SigLIP, so we use the siglip config. 
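+        # Roughly: with the default custom config (neither --use_small_model nor --use_hf_model_config),
+        # initialize_from_hf resolves to False and the two checkpoints below are loaded separately:
+        #   vision tower   <- "google/siglip2-so400m-patch16-384"
+        #   language model <- "Qwen/Qwen3-1.7B"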
+ vision_checkpoint="google/siglip2-so400m-patch16-384" if use_custom_config else None, + llm_checkpoint="Qwen/Qwen3-1.7B" if use_custom_config else None, + # Evaluation control + no_eval=args.no_eval, + # Epoch control + epoch=args.epoch, + ) + + # Handle freezing if requested + if args.freeze_vision_encoder: + config = dataclasses.replace(config, freeze_vision_encoder=True) + if args.freeze_llm: + config = dataclasses.replace(config, freeze_llm=True) + + logger.info("=" * 60) + logger.info("Starting VLM training...") + logger.info(f"Checkpoints will be saved to: {checkpointer_path}") + if args.hf_save_path: + logger.info(f"HF checkpoints will be saved to: {args.hf_save_path}") + if args.epoch > 0: + logger.info(f"Training for {args.epoch} epoch(s) ({num_train_steps:,} steps)") + else: + logger.info(f"Training for {num_train_steps:,} steps (no epoch limit)") + + # Run training + train_vlm.main(config) + + logger.info("Training completed!") + + +if __name__ == "__main__": + main() diff --git a/lib/levanter/src/levanter/compat/hf_checkpoints.py b/lib/levanter/src/levanter/compat/hf_checkpoints.py index 53dff0ae82..95fe39ab96 100644 --- a/lib/levanter/src/levanter/compat/hf_checkpoints.py +++ b/lib/levanter/src/levanter/compat/hf_checkpoints.py @@ -41,6 +41,7 @@ from jax import ShapeDtypeStruct from jax._src.mesh import get_concrete_mesh from jax._src.partition_spec import PartitionSpec +from jax.sharding import NamedSharding from jax.random import PRNGKey from jaxtyping import Array, PRNGKeyArray from tqdm_loggable.auto import tqdm @@ -276,7 +277,10 @@ def _to_state_dict_with_dtype( logger.debug(f"Skipping dtype conversion for non-floating point array {k} with dtype {v.dtype}") # deshard. We could be smarter here and use a process mesh or host offloading, but this is simpler for now - state_dict = jax.lax.with_sharding_constraint(state_dict, PartitionSpec()) + mesh = get_concrete_mesh() + if mesh is not None and mesh.shape: + sharding = NamedSharding(mesh, PartitionSpec()) + state_dict = jax.lax.with_sharding_constraint(state_dict, sharding) return state_dict @@ -673,7 +677,13 @@ def load_pretrained( # Vocab: first we have to resize the vocab as loaded from the checkpoint tokenizer_Vocab = self.Vocab - Vocab = tokenizer_Vocab.resize(hf_config.vocab_size) + # For multimodal models like LlavaOnevision, vocab_size is in text_config + hf_vocab_size = getattr(hf_config, "vocab_size", None) + if hf_vocab_size is None and hasattr(hf_config, "text_config"): + hf_vocab_size = hf_config.text_config.vocab_size + if hf_vocab_size is None: + raise ValueError("Could not find vocab_size in hf_config or hf_config.text_config") + Vocab = tokenizer_Vocab.resize(hf_vocab_size) # TODO: in an ideal world, we would only load the part of the array we needed, but # AFAICT neither torch state dicts nor safetensors support this. diff --git a/lib/levanter/src/levanter/data/image.py b/lib/levanter/src/levanter/data/image.py new file mode 100644 index 0000000000..ae66342d1b --- /dev/null +++ b/lib/levanter/src/levanter/data/image.py @@ -0,0 +1,3076 @@ +# Copyright 2025 The Levanter Authors +# SPDX-License-Identifier: Apache-2.0 + +""" +Image data processing module for vision-language models like LLaVA OneVision. 
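+
+A typical wiring of this module from a training script (a minimal sketch mirroring
+scripts/launch_vlm_training.py in this change; the paths and processor name shown here
+are placeholders):
+
+    from levanter.data.image import ConversationDatasetSourceConfig, ImageMixtureDatasetConfig
+
+    data_config = ImageMixtureDatasetConfig(
+        cache_dir="/tmp/vlm_cache",
+        configs={
+            "train": ConversationDatasetSourceConfig(
+                train_urls=["file:///path/to/train.parquet"],
+                validation_urls=["file:///path/to/val.parquet"],
+                cache_dir="/tmp/vlm_cache/train",
+            ),
+        },
+        train_weights={"train": 1.0},
+        processor="llava-hf/llava-onevision-qwen2-7b-ov-hf",
+        max_length=8192,
+    )
+    train_sets = data_config.training_sets()
+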
+ +This module provides utilities for: +- Loading and preprocessing images from various sources (URLs, HuggingFace datasets) +- Processing conversation-format data with interleaved images and text +- Converting images to model-ready tensors with proper axes +- Batching and caching processed image-text pairs + +Conversation Format Example: +{ + "messages": [ + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": "What is in this image?"} + ] + }, + { + "role": "assistant", + "content": [ + {"type": "text", "text": "This image shows..."} + ] + } + ], + "images": ["path/to/image.jpg"] # or PIL Images, or URLs +} +""" + +import abc +import asyncio +import dataclasses +import json +import logging +import math +import os +import threading +import weakref +from collections import OrderedDict +from collections.abc import Iterable +from dataclasses import dataclass +from functools import cached_property +from typing import Any, Dict, Iterator, List, Mapping, Optional, Sequence, Tuple, Union, cast + +import braceexpand +import datasets +import equinox as eqx +import fsspec +import haliax as hax +import jax +import numpy +import numpy as np +from draccus import field +from haliax import Axis, NamedArray +from haliax.partitioning import ResourceMapping +from jax.sharding import Mesh, PartitionSpec + +from levanter.data.mixture import MixtureDataset, StopStrategy +from jaxtyping import PRNGKeyArray +from typing_extensions import TypedDict + +from levanter.compat.hf_checkpoints import load_processor +from levanter.data import AsyncDataset +from levanter.data._preprocessor import BatchProcessor +from levanter.data.dataset import EpochDataset, MappedAsyncDataset +from levanter.data.loader import DataLoader, DataLoaderIterator, _Batch +from levanter.data.sharded_datasource import ( + ShardedDataSource, + UrlBackedShardedDataSource, + WrappedHFDataSource, + _sniff_format_for_dataset, +) +from levanter.schedule import IntSchedule +from levanter.shapes import NamedShapeSpec, ShapeSpec +from levanter.store.cache import CacheOptions, TreeCache, build_or_load_cache +from levanter.utils.jax_utils import key_iterator +from levanter.utils.logging import silence_transformer_nag + +silence_transformer_nag() +from transformers import ( # noqa: E402 + BatchFeature, + PreTrainedTokenizerBase, + ProcessorMixin, +) +from transformers.image_processing_utils import select_best_resolution # noqa: E402 +from transformers.image_utils import ImageInput, get_image_size, to_numpy_array # noqa: E402 +from transformers.processing_utils import MultiModalData, ProcessingKwargs, Unpack # noqa: E402 +from transformers.tokenization_utils_base import PreTokenizedInput, TextInput # noqa: E402 +from transformers.utils import logging as transformers_logging # noqa: E402 +from transformers.video_utils import VideoInput # noqa: E402 + +# Image loading dependencies - imported at module level for performance +from io import BytesIO # noqa: E402 + +import requests # noqa: E402 +from PIL import Image # noqa: E402 + +logger = logging.getLogger("levanter.data.image") + + +class ImageTextUrlDataSource(UrlBackedShardedDataSource[dict]): + """ + Dataset for image-text pairs from various file formats (JSON, JSONL, Parquet). 
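+
+    For example (illustrative records, not from a real dataset), a JSONL shard read with the
+    default keys might contain lines like:
+
+        {"image": "https://example.com/images/0001.jpg", "text": "A red bicycle leaning against a wall."}
+        {"image": {"path": "images/0002.png"}, "text": "Two cats sleeping on a sofa."}
+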
+ + This data source reads image-text pairs where: + - image_key: points to the image data (can be path, URL, bytes, or HF dict format) + - text_key: points to the text description/caption + + Supports HuggingFace-style image formats: + - {"bytes": } + - {"path": "path/to/image.jpg"} + - Direct path string or URL + """ + + def __init__(self, urls, image_key="image", text_key="text"): + super().__init__(urls) + self.image_key = image_key + self.text_key = text_key + + def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[dict]: + url = self._shard_name_to_url_mapping[shard_name] + i = 0 + with fsspec.open(url, "r", compression="infer") as f: + format = _sniff_format_for_dataset(url) + match format: + case ".jsonl": + for line in f: + if i >= row: + data = json.loads(line) + yield { + "image": data[self.image_key], + "text": data[self.text_key], + } + i += 1 + case ".json": + data = json.load(f) + for doc in data[row:]: + yield { + "image": doc[self.image_key], + "text": doc[self.text_key], + } + case _: + raise ValueError(f"Unknown format {format}") + + +class ImageConversationUrlDataSource(UrlBackedShardedDataSource[dict]): + """ + Dataset for conversation-format image-text data (VLM training format). + + This data source reads conversation data with interleaved images and text, + used for vision-language model training like LLaVA. + + Expected data format: + { + "messages": [ + {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "..."}]}, + {"role": "assistant", "content": [{"type": "text", "text": "..."}]} + ], + "images": ["path/to/image.jpg"] # or PIL Images, URLs, or bytes + } + """ + + def __init__(self, urls, messages_key="messages", images_key="images"): + super().__init__(urls) + self.messages_key = messages_key + self.images_key = images_key + + def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[dict]: + url = self._shard_name_to_url_mapping[shard_name] + i = 0 + format = _sniff_format_for_dataset(url) + if format == ".parquet": + # Handle parquet files + import pyarrow.parquet as pq + + with fsspec.open(url, "rb") as f: + table = pq.read_table(f) + data = table.to_pydict() + num_rows = table.num_rows + for idx in range(row, num_rows): + yield { + "messages": data[self.messages_key][idx], + "images": data.get(self.images_key, [[]])[idx], + } + else: + with fsspec.open(url, "r", compression="infer") as f: + match format: + case ".jsonl": + for line in f: + if i >= row: + data = json.loads(line) + yield { + "messages": data[self.messages_key], + "images": data.get(self.images_key, []), + } + i += 1 + case ".json": + data = json.load(f) + for doc in data[row:]: + yield { + "messages": doc[self.messages_key], + "images": doc.get(self.images_key, []), + } + case _: + raise ValueError(f"Unknown format {format}") + + +class CustomVLMProcessor(ProcessorMixin): + """ + Custom VLM processor that combines components from different sources. + + This allows using a different tokenizer (e.g., Qwen3-1.7B) while keeping + the image/video processing from the original processor. Instead of mutating + the original processor's tokenizer, this creates a new processor instance + that properly combines the components. 
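+
+    Example (an illustrative sketch: the component checkpoints are placeholders, and the call at
+    the end assumes the usual ProcessorMixin-style __call__ of LLaVA-style processors):
+
+        from transformers import AutoImageProcessor, AutoTokenizer
+
+        image_processor = AutoImageProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-7b-ov-hf")
+        tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-1.7B")
+        processor = CustomVLMProcessor(image_processor, tokenizer)
+
+        # `pil_image` is a PIL.Image.Image; the image token is interleaved into the prompt text.
+        batch = processor(text="<image>\nWhat is in this picture?", images=[pil_image], return_tensors="np")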
+ """ + + attributes = ["image_processor", "tokenizer", "video_processor"] + image_processor_class = "AutoImageProcessor" + tokenizer_class = "AutoTokenizer" + video_processor_class = "AutoVideoProcessor" + + # Critical tokens for validation when combining processors + CRITICAL_SPECIAL_TOKENS = ["<|im_start|>", "<|im_end|>"] + CRITICAL_ROLE_TOKENS = ["assistant", "user", "system"] + + def __init__( + self, + image_processor, + tokenizer, + video_processor=None, + *, + chat_template=None, + image_token="", + video_token="