VLM for Marin #2298
Conversation
dlwh left a comment:
haven't finished reviewing model and tests yet
# deshard. We could be smarter here and use a process mesh or host offloading, but this is simpler for now
state_dict = jax.lax.with_sharding_constraint(state_dict, PartitionSpec())
mesh = get_concrete_mesh()
we don't want the concrete mesh inside jit in general since it breaks compilation caching. can we do abstract?
also why is this necessary?
It's not strictly necessary. I added it because checkpoints saved outside of a mesh context were throwing errors: the desharding constraint needs a non-empty mesh to work.
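For reference, a minimal sketch (not the PR's code) of one way to deshard to a fully replicated layout without looking up a concrete mesh inside jit: build the replicated `NamedSharding` outside the jitted function and close over it. `make_desharder` is a hypothetical helper name.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def make_desharder(mesh: Mesh):
    # An empty PartitionSpec means "replicate every dimension".
    replicated = NamedSharding(mesh, PartitionSpec())

    @jax.jit
    def deshard(state_dict):
        # The sharding was built outside the trace, so nothing in here needs a
        # concrete-mesh lookup, and compilation caching is unaffected.
        return jax.lax.with_sharding_constraint(state_dict, replicated)

    return deshard


mesh = Mesh(np.array(jax.devices()), ("data",))
deshard = make_desharder(mesh)
```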
return LlavaOnevisionModel.init(Vocab, config.model, key=model_key)

# For freezing, we use is_trainable=True and handle gradient zeroing separately
# This avoids haliax partitioning issues with non-trivial is_trainable filters
wait, what's wrong?
When passing a non-trivial is_trainable filter (e.g., {"vision_tower": False, "projector": True}) to the trainer, I ran into issues with Haliax's partitioning/sharding logic - it had trouble computing consistent axis mappings for the non-uniform set of trainable parameters. Using is_trainable=True and applying jax.lax.stop_gradient() in the loss function achieves the same freezing behavior while keeping the model structure uniform from Haliax's perspective.
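A minimal sketch of the stop_gradient approach described above, assuming an Equinox-style model with a `vision_tower` attribute; `compute_loss` and the attribute names are illustrative, not the PR's actual API.

```python
import equinox as eqx
import jax


def stop_grads(tree):
    # Zero gradients for every inexact-array leaf, leaving the tree structure
    # and any non-array leaves untouched.
    return jax.tree_util.tree_map(
        lambda x: jax.lax.stop_gradient(x) if eqx.is_inexact_array(x) else x,
        tree,
    )


def loss_fn(model, batch):
    # Swap in a stop_gradient'd copy of the frozen submodule before the forward
    # pass: its parameters receive zero gradients, but the model structure stays
    # uniform, so the trainer can keep is_trainable=True everywhere.
    model = eqx.tree_at(
        lambda m: m.vision_tower, model, replace=stop_grads(model.vision_tower)
    )
    return compute_loss(model, batch)  # hypothetical loss helper
```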
from transformers import LlavaOnevisionConfig as HfLlavaOnevisionConfig  # noqa: E402

@LmConfig.register_subclass("llava_onevision")
should this be an LmConfig at all?
Adding VlmConfig.
num_unpadded_tokens = unpad_indices.axis_size("num_image_tokens")

# Gather features in HF's unpadded order
image_features_reordered = self._batch_gather(image_features_flat.array, unpad_indices.array)
pretty sure you shouldn't need this. I think haliax ought to handle this case with its indexing, though maybe I don't understand it.
Why unpad_indices: HuggingFace's LLaVA OneVision applies spatial unpadding based on aspect ratio after vision encoding (landscape images remove top/bottom padding, portraits remove left/right). Since Levanter uses fixed-size tensors, we precompute unpad_indices to map our padded features back to HF's spatial order.
Why _batch_gather: We need per-batch dynamic indexing. Haliax indexing works for uniform operations across the batch, but here each image has a different aspect ratio, so each batch element needs its own index set. vmap(lambda arr, idx: arr[idx]) is the cleanest way to express this in JAX.
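A minimal sketch of that per-example gather, with illustrative shapes and names: vmap maps a plain gather over the batch axis so each element uses its own indices.

```python
import jax
import jax.numpy as jnp


def batch_gather(features, indices):
    # features: [batch, num_tokens, hidden]; indices: [batch, num_image_tokens]
    return jax.vmap(lambda arr, idx: arr[idx])(features, indices)


features = jnp.ones((2, 16, 8))
indices = jnp.array([[0, 3, 5, 7], [1, 2, 4, 6]])
gathered = batch_gather(features, indices)  # -> shape (2, 4, 8)
```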
dlwh left a comment:
Can we reduce the tests by a lot?
- don't need tests that are basically reimplementations of the class, just written as asserts
- asserting shapes isn't useful if that's all we're doing
- don't print so much in tests
- tolerances should be 1e-4 for single layer, 1e-3 for multi layer unless we're pretty sure they're right
let's not set env vars at import time. most of the jax config updates shouldn't be necessary
I'm working on reducing the tests. For the JAX configs, in my previous experience, if we don't force JAX to do float32 computation, the results differ a lot from the HF ones.
Also, for SigLIP the mean difference can get below 1e-3, but the max difference only gets below 1e-2. Is this expected? I double-checked many times; the implementation should be correct.
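A sketch of how the float32 comparison could be scoped to the test itself rather than via env vars or global config at import time, assuming the tolerances suggested above for single-layer checks; `levanter_forward`, `hf_forward`, and `inputs` are placeholders for the two implementations under test.

```python
import jax
import numpy as np

# Force full-precision matmuls only for this comparison (relevant on TPU,
# where matmuls otherwise run at reduced precision).
with jax.default_matmul_precision("float32"):
    ours = levanter_forward(inputs)    # placeholder: Levanter forward pass
    theirs = hf_forward(inputs)        # placeholder: HF reference forward pass

np.testing.assert_allclose(np.asarray(ours), np.asarray(theirs), atol=1e-4, rtol=1e-4)
```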
Description
This PR adds Vision-Language Model (VLM) training support to Levanter, with a focus on the LLaVA OneVision architecture.
Key Changes
New Features
- SigLIP & Siglip2 Vision Encoder (models/siglip.py & models/siglip2.py)
- LLaVA OneVision Model (models/llava_onevision.py)
- Image Data Pipeline (data/image.py)
- VLM Training Infrastructure
- Data Sources (data/sharded_datasource.py)

Improvements
- Splash Attention Explicit Mask Support (layers/attention.py)
- Qwen Model (models/qwen.py)
- HuggingFace Checkpoint Compatibility (compat/hf_checkpoints.py)
- Cache Improvements (store/cache.py)