578 changes: 578 additions & 0 deletions lib/levanter/scripts/launch_vlm_training.py

Large diffs are not rendered by default.

14 changes: 12 additions & 2 deletions lib/levanter/src/levanter/compat/hf_checkpoints.py
@@ -41,6 +41,7 @@
from jax import ShapeDtypeStruct
from jax._src.mesh import get_concrete_mesh
from jax._src.partition_spec import PartitionSpec
from jax.sharding import NamedSharding
from jax.random import PRNGKey
from jaxtyping import Array, PRNGKeyArray
from tqdm_loggable.auto import tqdm
@@ -276,7 +277,10 @@ def _to_state_dict_with_dtype(
logger.debug(f"Skipping dtype conversion for non-floating point array {k} with dtype {v.dtype}")

# deshard. We could be smarter here and use a process mesh or host offloading, but this is simpler for now
state_dict = jax.lax.with_sharding_constraint(state_dict, PartitionSpec())
mesh = get_concrete_mesh()
Member: We don't want the concrete mesh inside jit in general, since it breaks compilation caching. Can we use an abstract mesh instead?

Member: Also, why is this necessary?

Author: It's not strictly necessary. I added this because saving a checkpoint outside of a mesh context would throw errors, since with_sharding_constraint needs a non-empty mesh to work. (See the standalone sketch after this hunk.)

if mesh is not None and mesh.shape:
    sharding = NamedSharding(mesh, PartitionSpec())
    state_dict = jax.lax.with_sharding_constraint(state_dict, sharding)

return state_dict
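
For reference, here is a minimal standalone sketch of the guarded desharding pattern from this hunk. The deshard helper and the toy params are illustrative, not code from the PR; it reuses the same private get_concrete_mesh import the diff uses.

    import jax
    import jax.numpy as jnp
    from jax._src.mesh import get_concrete_mesh
    from jax.sharding import NamedSharding, PartitionSpec

    def deshard(state_dict):
        # Constrain every leaf to a fully replicated sharding (empty PartitionSpec)
        # so the arrays are gathered before saving. Skip the constraint when no
        # mesh is active (e.g. saving a checkpoint outside of a mesh context),
        # since with_sharding_constraint needs a non-empty mesh to resolve the spec.
        mesh = get_concrete_mesh()
        if mesh is not None and mesh.shape:
            sharding = NamedSharding(mesh, PartitionSpec())
            state_dict = jax.lax.with_sharding_constraint(state_dict, sharding)
        return state_dict

    # With no active mesh the guard makes this a no-op instead of an error:
    params = {"w": jnp.ones((4, 4))}
    desharded = deshard(params)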

@@ -673,7 +677,13 @@ def load_pretrained(

# Vocab: first we have to resize the vocab as loaded from the checkpoint
tokenizer_Vocab = self.Vocab
Vocab = tokenizer_Vocab.resize(hf_config.vocab_size)
# For multimodal models like LlavaOnevision, vocab_size is in text_config
hf_vocab_size = getattr(hf_config, "vocab_size", None)
if hf_vocab_size is None and hasattr(hf_config, "text_config"):
    hf_vocab_size = hf_config.text_config.vocab_size
if hf_vocab_size is None:
    raise ValueError("Could not find vocab_size in hf_config or hf_config.text_config")
Vocab = tokenizer_Vocab.resize(hf_vocab_size)

# TODO: in an ideal world, we would only load the part of the array we needed, but
# AFAICT neither torch state dicts nor safetensors support this.
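
Relatedly, here is a standalone sketch of the vocab_size fallback added above, written as a hypothetical helper (resolve_vocab_size and the example checkpoint id are illustrative, not part of the PR):

    from transformers import AutoConfig

    def resolve_vocab_size(hf_config) -> int:
        # Plain language-model configs expose vocab_size at the top level;
        # multimodal configs such as LlavaOnevision nest it under text_config.
        vocab_size = getattr(hf_config, "vocab_size", None)
        if vocab_size is None and hasattr(hf_config, "text_config"):
            vocab_size = getattr(hf_config.text_config, "vocab_size", None)
        if vocab_size is None:
            raise ValueError("Could not find vocab_size in hf_config or hf_config.text_config")
        return vocab_size

    # Example (downloads the config from the Hub):
    config = AutoConfig.from_pretrained("llava-hf/llava-onevision-qwen2-7b-ov-hf")
    print(resolve_vocab_size(config))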