
Conversation

@ruili33 ruili33 commented Jan 7, 2026

Description

This PR adds Vision-Language Model (VLM) training support to Levanter, with a focus on the LLaVA OneVision architecture.

Key Changes

New Features

SigLIP & Siglip2 Vision Encoder (models/siglip.py & models/siglip2.py)

  • Full implementation of the SigLIP (Sigmoid Loss for Language Image Pre-Training) and Siglip2 vision encoders

LLaVA OneVision Model (models/llava_onevision.py)

  • Complete multimodal model combining SigLIP/Siglip2 vision encoders with Qwen language models
  • Support for loading HuggingFace pretrained weights
  • Inference engine integration with KV cache support

Image Data Pipeline (data/image.py)

  • Image preprocessing pipeline for VLM training
  • Support for multiple data sources: URLs, HuggingFace datasets, parquet files
  • Conversation-format data handling with interleaved images and text
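
For illustration, a conversation-format record with an interleaved image might look like the following (field names are illustrative assumptions, not necessarily the exact schema the pipeline expects):

```python
# Illustrative conversation-format record (field names are assumptions):
example_record = {
    "images": ["https://example.com/cat.jpg"],
    "conversations": [
        {"role": "user", "content": "<image>\nWhat is in this picture?"},
        {"role": "assistant", "content": "A cat sitting on a windowsill."},
    ],
}
```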

VLM Training Infrastructure

  • train_vlm.py: End-to-end VLM training main script
  • launch_vlm_training.py: Launch script with TPU optimizations
  • ImageDataLoader in data/loader.py: Specialized data loader for variable-length image patches with proper batching and padding
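
To make the variable-length batching concrete, here is a minimal sketch of the padding problem ImageDataLoader solves (toy code, not the loader's actual implementation):

```python
import numpy as np

def pad_patch_batch(patch_seqs, max_patches):
    # Pad variable-length patch sequences to a fixed length and return a
    # boolean mask marking real (non-padding) patches.
    dim = patch_seqs[0].shape[-1]
    batch = np.zeros((len(patch_seqs), max_patches, dim), dtype=patch_seqs[0].dtype)
    mask = np.zeros((len(patch_seqs), max_patches), dtype=bool)
    for i, seq in enumerate(patch_seqs):
        batch[i, : len(seq)] = seq
        mask[i, : len(seq)] = True
    return batch, mask

batch, mask = pad_patch_batch([np.ones((5, 8)), np.ones((3, 8))], max_patches=6)
# batch: (2, 6, 8); mask rows have 5 and 3 True entries respectively
```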

Data Sources (data/sharded_datasource.py)

  • ImageTextUrlDataSource: Dataset for image-text pairs from JSON/JSONL/Parquet
  • ConversationUrlDataSource: Dataset for conversation-format VLM training data

Improvements

Splash Attention Explicit Mask Support (layers/attention.py)

  • Implemented explicit mask support for TPU Splash Attention
  • Converts NamedArray explicit masks to NumpyMask for Splash Attention compatibility
  • Proper error handling for dynamic masks during JIT tracing
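
A hedged sketch of that conversion, assuming the explicit mask is static (a concrete array rather than a traced value); the module path is the one found in recent JAX releases and may move:

```python
import numpy as np
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask as mask_lib

def to_splash_mask(explicit_mask: np.ndarray) -> mask_lib.NumpyMask:
    # Splash Attention expects a numpy bool mask of shape [q_len, kv_len],
    # where True marks positions that may attend.
    return mask_lib.NumpyMask(np.asarray(explicit_mask, dtype=bool))
```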

Qwen Model (models/qwen.py)

  • Added decode() method to QwenDecoderLayer for paged decoding with KV cache

HuggingFace Checkpoint Compatibility (compat/hf_checkpoints.py)

  • Extended vocab_size lookup to support multimodal models (e.g., LlavaOnevision with nested text_config)
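
A hypothetical helper illustrating the extended lookup (not the exact hf_checkpoints.py code):

```python
def lookup_vocab_size(hf_config) -> int:
    # Plain LMs expose vocab_size at the top level; multimodal configs like
    # LlavaOnevision nest it under text_config.
    vocab = getattr(hf_config, "vocab_size", None)
    if vocab is not None:
        return vocab
    return hf_config.text_config.vocab_size
```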

Cache Improvements (store/cache.py)

  • Added progress bar with total row count for shard building
  • Fixed bug in _extend_cache_metadata_with_other: now correctly slices shape data to actual row count instead of copying entire pre-allocated shapes store
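
An illustrative sketch of the slicing fix (hypothetical names, not the actual cache.py code):

```python
import numpy as np

def extend_cache_shapes(dst: list, src_shapes: np.ndarray, src_num_rows: int) -> None:
    # Buggy version copied the entire pre-allocated shapes store:
    #   dst.extend(src_shapes.tolist())
    # Fixed version slices down to the rows actually written:
    dst.extend(src_shapes[:src_num_rows].tolist())
```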

@ruili33 ruili33 requested review from Helw150 and dlwh January 7, 2026 23:38

@dlwh dlwh left a comment


haven't finished reviewing model and tests yet


```python
# deshard. We could be smarter here and use a process mesh or host offloading, but this is simpler for now
state_dict = jax.lax.with_sharding_constraint(state_dict, PartitionSpec())
mesh = get_concrete_mesh()
```

dlwh:
we don't want the concrete mesh inside jit in general since it breaks compilation caching. can we do abstract?

dlwh:
also why is this necessary?

ruili33:
It's not strictly necessary. I added it because saving checkpoints outside of a mesh context would throw errors; it needs a non-empty mesh to work.
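
One possible shape of the workaround being discussed, as a hedged sketch (illustrative only, not the PR's actual code): enter a trivial fallback mesh before desharding, so that saving outside any mesh context doesn't error.

```python
import jax
from jax.sharding import Mesh, PartitionSpec

def deshard_for_save(state_dict):
    # Fallback 1-D mesh over all devices, in case no mesh is active.
    mesh = Mesh(jax.devices(), axis_names=("data",))
    with mesh:
        # Empty PartitionSpec -> fully replicated, safe to write to disk.
        return jax.lax.with_sharding_constraint(state_dict, PartitionSpec())
```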

```python
    return LlavaOnevisionModel.init(Vocab, config.model, key=model_key)

# For freezing, we use is_trainable=True and handle gradient zeroing separately
# This avoids haliax partitioning issues with non-trivial is_trainable filters
```

dlwh:
wait what's wrong

ruili33:
When passing a non-trivial is_trainable filter (e.g., {"vision_tower": False, "projector": True}) to the trainer, I ran into issues with Haliax's partitioning/sharding logic - it had trouble computing consistent axis mappings for the non-uniform set of trainable parameters. Using is_trainable=True and applying jax.lax.stop_gradient() in the loss function achieves the same freezing behavior while keeping the model structure uniform from Haliax's perspective.
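
A minimal, self-contained sketch of that freezing approach (toy parameter pytree; `vision_tower` and `projector` stand in for the real modules):

```python
import jax
import jax.numpy as jnp

def loss_fn(params, x):
    # Keep the pytree uniform for the partitioner, but cut gradients to the
    # frozen subtree with stop_gradient.
    vt = jax.lax.stop_gradient(params["vision_tower"])
    return jnp.sum((x @ vt) @ params["projector"])

params = {"vision_tower": jnp.ones((4, 4)), "projector": jnp.ones((4, 2))}
grads = jax.grad(loss_fn)(params, jnp.ones((3, 4)))
# grads["vision_tower"] is all zeros; grads["projector"] is not.
```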

```python
from transformers import LlavaOnevisionConfig as HfLlavaOnevisionConfig  # noqa: E402


@LmConfig.register_subclass("llava_onevision")
```

dlwh:
should this be an LmConfig at all

ruili33:
Adding VlmConfig.

```python
num_unpadded_tokens = unpad_indices.axis_size("num_image_tokens")

# Gather features in HF's unpadded order
image_features_reordered = self._batch_gather(image_features_flat.array, unpad_indices.array)
```

dlwh:
pretty sure you shouldn't need this. haliax ought to handle this case i think with its indexing though maybe i don't understand it

ruili33:
Why unpad_indices: HuggingFace's LLaVA OneVision applies spatial unpadding based on aspect ratio after vision encoding (landscape images remove top/bottom padding, portraits remove left/right). Since Levanter uses fixed-size tensors, we precompute unpad_indices to map our padded features back to HF's spatial order.

Why _batch_gather: We need per-batch dynamic indexing - each image has different aspect ratio → different indices. Haliax indexing works for uniform operations across batches, but here each batch element needs its own index set. vmap(lambda arr, idx: arr[idx]) is the cleanest way to express this in JAX.
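
For concreteness, a sketch of that per-batch gather with assumed shapes (not the exact `_batch_gather` from the PR):

```python
import jax
import jax.numpy as jnp

def batch_gather(features, indices):
    # features: [batch, num_patches, dim]; indices: [batch, num_image_tokens].
    # Each batch element gets its own index set, so vmap a simple take.
    return jax.vmap(lambda arr, idx: arr[idx])(features, indices)

feats = jnp.arange(2 * 16 * 8).reshape(2, 16, 8)
idx = jnp.tile(jnp.arange(4), (2, 1))
out = batch_gather(feats, idx)  # shape (2, 4, 8)
```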


@dlwh dlwh left a comment


Can we reduce the tests by a lot?

  • don't need tests that are basically reimplementations of the class but as asserts
  • asserting shapes isn't useful if that's all we're doing
  • don't print so much in tests
  • tolerances should be 1e-4 for single layer, 1e-3 for multi layer unless we're pretty sure they're right

let's not set env vars at import time. most of the jax config updates shouldn't be necessary

@ruili33 ruili33 commented Jan 11, 2026

> Can we reduce the tests by a lot?
>
>   • don't need tests that are basically reimplementations of the class but as asserts
>   • asserting shapes isn't useful if that's all we're doing
>   • don't print so much in tests
>   • tolerances should be 1e-4 for single layer, 1e-3 for multi layer unless we're pretty sure they're right
>
> let's not set env vars at import time. most of the jax config updates shouldn't be necessary

I'm working on reducing the tests. As for the JAX configs: in my experience, if we don't force JAX to compute in float32, the results differ a lot from the HF ones.

Also, for SigLIP the mean difference goes below 1e-3, but the max difference only goes below 1e-2. Is this expected? I've double-checked many times and the implementation should be correct.
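
One way to scope the float32 forcing to individual tests rather than setting env vars or global config at import time, as a sketch (toy matmul standing in for a model forward pass):

```python
import jax
import jax.numpy as jnp

def test_matmul_in_float32():
    x = jnp.ones((4, 8))
    w = jnp.ones((8, 8))
    # Matmuls inside this block run at full float32 precision; nothing
    # leaks into other tests or import time.
    with jax.default_matmul_precision("float32"):
        y = x @ w
    assert y.shape == (4, 8)
```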

@ruili33 ruili33 requested a review from dlwh January 12, 2026 05:18