diff --git a/experiments/evals/exp1602_lm_eval_harness.py b/experiments/evals/exp1602a_lm_eval_harness.py
similarity index 100%
rename from experiments/evals/exp1602_lm_eval_harness.py
rename to experiments/evals/exp1602a_lm_eval_harness.py
diff --git a/experiments/evals/exp1602b_lm_eval_harness_selected.py b/experiments/evals/exp1602b_lm_eval_harness_selected.py
new file mode 100644
index 0000000000..a4ce1037ed
--- /dev/null
+++ b/experiments/evals/exp1602b_lm_eval_harness_selected.py
@@ -0,0 +1,79 @@
+# Copyright 2025 The Marin Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Run the selected LM Eval Harness tasks across a set of Marin, Qwen 2.5, OLMo 2, Llama 3, and OLMo 3 models.
+"""
+
+from collections.abc import Iterable
+from dataclasses import replace
+
+from fray.cluster import ResourceConfig
+from experiments.evals.evals import default_eval
+from experiments.evals.task_configs import LM_EVAL_HARNESS_SELECTED_TASKS
+from experiments.models import (
+    llama_3_1_8b,
+    llama_3_70b,
+    marin_8b_base,
+    marin_32b_base,
+    olmo_2_base_32b,
+    olmo_2_base_8b,
+    olmo_3_32b,
+    olmo_3_7b,
+    qwen2_5_32b,
+)
+from marin.evaluation.evaluation_config import EvalTaskConfig
+from marin.execution.executor import ExecutorStep, executor_main
+
+MARIN_MODELS: tuple[ExecutorStep, ...] = (marin_8b_base, marin_32b_base)
+QWEN_2_5_MODELS: tuple[ExecutorStep, ...] = (qwen2_5_32b, )
+OLMO_2_MODELS: tuple[ExecutorStep, ...] = (olmo_2_base_8b, olmo_2_base_32b)
+LLAMA_3_MODELS: tuple[ExecutorStep, ...] = (llama_3_1_8b, llama_3_70b)
+OLMO_3_MODELS: tuple[ExecutorStep, ...] = (olmo_3_7b, olmo_3_32b)
+
+ALL_MODEL_STEPS: tuple[ExecutorStep, ...] = (
+    # *MARIN_MODELS,
+    # *QWEN_2_5_MODELS,
+    # *OLMO_2_MODELS,
+    # *LLAMA_3_MODELS,
+    *OLMO_3_MODELS,
+)
+
+
+def _create_per_task_eval_steps(model_step: ExecutorStep, tasks: Iterable[EvalTaskConfig]) -> list[ExecutorStep]:
+    """Return one evaluation step per LM Eval Harness task for a given model."""
+
+    per_task_steps: list[ExecutorStep] = []
+    for task in tasks:
+        eval_step = default_eval(
+            step=model_step,
+            resource_config=ResourceConfig.with_tpu("v5p-8"),
+            evals=(task,),
+            discover_latest_checkpoint=False,
+        )
+        task_label = task.task_alias or task.name
+        # Make it obvious which harness task is running to simplify scheduling/debugging.
+        per_task_steps.append(replace(eval_step, name=f"{eval_step.name}/{task_label}"))
+
+    return per_task_steps
+
+
+eval_steps: list[ExecutorStep] = []
+for model_step in ALL_MODEL_STEPS:
+    eval_steps.extend(_create_per_task_eval_steps(model_step, LM_EVAL_HARNESS_SELECTED_TASKS))
+
+if __name__ == "__main__":
+    # executor_main(steps=eval_steps)
+    for i in range(0, len(eval_steps), 4):
+        executor_main(steps=eval_steps[i : i + 4])
diff --git a/experiments/evals/task_configs.py b/experiments/evals/task_configs.py
index 38762345c0..1ffa1a70a5
--- a/experiments/evals/task_configs.py
+++ b/experiments/evals/task_configs.py
@@ -111,6 +111,34 @@
 )
 
 # LM-Eval-Harness Tasks
+LM_EVAL_HARNESS_SELECTED_TASKS = (  # multiple choice tasks
+    EvalTaskConfig("copa", 0),
+    EvalTaskConfig("mmlu", 5),
+    EvalTaskConfig("leaderboard_musr", 0),
+    EvalTaskConfig("anli_r1", 0),
+    EvalTaskConfig("anli_r2", 0),
+    EvalTaskConfig("anli_r3", 0),
+    EvalTaskConfig("truthfulqa_mc2", 6),
+    EvalTaskConfig("race", 0),
+    EvalTaskConfig("toxigen", 0),
+    EvalTaskConfig("agieval_lsat_ar", 3),
+    EvalTaskConfig("arc_easy", 10),
+    EvalTaskConfig("arc_challenge", 10),
+    EvalTaskConfig("leaderboard_bbh", 3),
+    EvalTaskConfig("boolq", 10),
+    EvalTaskConfig("commonsense_qa", 10),
+    EvalTaskConfig("leaderboard_gpqa", 0),
+    EvalTaskConfig("hellaswag", 10),
+    EvalTaskConfig("leaderboard_mmlu_pro", 5),
+    EvalTaskConfig("openbookqa", 0),
+    EvalTaskConfig("piqa", 10),
+    EvalTaskConfig("winogrande", 0),
+    EvalTaskConfig("wsc273", 0),
+    EvalTaskConfig("squadv2", 0),
+    EvalTaskConfig("minerva_math", 4),
+)
+
+
 # Reasoning and Logic Tasks
 REASONING_TASKS = (
     EvalTaskConfig("anli_r1", 0, task_alias="anli_r1_0shot"),
@@ -312,6 +340,7 @@
 )
 
+
 
 def convert_to_levanter_task_config(tasks: Sequence[EvalTaskConfig]) -> list[TaskConfig]:
     """
     Convert a list of EvalTaskConfig to a list of TaskConfig that Levanter's eval_harness expects.
diff --git a/experiments/models.py b/experiments/models.py index 395dfdb83a..a70a85e8f3 100644 --- a/experiments/models.py +++ b/experiments/models.py @@ -128,6 +128,14 @@ def get_model_local_path(step: ExecutorStep) -> str: ) ) +llama_3_70b = download_model_step( + ModelConfig( + hf_repo_id="meta-llama/Meta-Llama-3-70B", + hf_revision="main", + ) +) + + tulu_3_1_8b_sft = download_model_step( ModelConfig( hf_repo_id="allenai/Llama-3.1-Tulu-3-8B-SFT", @@ -163,6 +171,20 @@ def get_model_local_path(step: ExecutorStep) -> str: ) ) +olmo_3_7b = download_model_step( + ModelConfig( + hf_repo_id="allenai/OLMo-3-1025-7B", + hf_revision="main", + ) +) + +olmo_3_32b = download_model_step( + ModelConfig( + hf_repo_id="allenai/OLMo-3-1125-32B", + hf_revision="main", + ) +) + amber_base_7b = download_model_step( ModelConfig( hf_repo_id="LLM360/Amber", @@ -185,6 +207,13 @@ def get_model_local_path(step: ExecutorStep) -> str: ) ) +marin_32b_base = download_model_step( + ModelConfig( + hf_repo_id="marin-community/marin-32b-base", + hf_revision="main", + ) +) + llama_3_2_1b = download_model_step( ModelConfig( hf_repo_id="meta-llama/Llama-3.2-1B", diff --git a/lib/levanter/src/levanter/eval_harness.py b/lib/levanter/src/levanter/eval_harness.py index e0591a37d5..d4d67d58ae 100644 --- a/lib/levanter/src/levanter/eval_harness.py +++ b/lib/levanter/src/levanter/eval_harness.py @@ -1309,11 +1309,18 @@ def _compute_averages(outputs): for metric in metric_keys: # Collect valid tasks for this metric # We iterate over the n-samples because real tasks (as opposed to aggregates like "mmlu") have counts - valid_tasks = [ - (outputs["results"][task_name].get(metric), outputs["n-samples"][task_name]["effective"]) - for task_name in outputs["n-samples"] - if outputs["results"][task_name].get(metric, None) is not None - ] + valid_tasks = [] + for task_name, sample_counts in outputs["n-samples"].items(): + task_results = outputs["results"].get(task_name) + if task_results is None: + logger.debug("Skipping %s because no results were produced.", task_name) + continue + + metric_value = task_results.get(metric) + if metric_value is None: + continue + + valid_tasks.append((metric_value, sample_counts["effective"])) if not valid_tasks: continue # Skip metrics with no valid tasks diff --git a/lib/levanter/src/levanter/models/olmo.py b/lib/levanter/src/levanter/models/olmo.py index e48085cb2c..7024b128f4 100644 --- a/lib/levanter/src/levanter/models/olmo.py +++ b/lib/levanter/src/levanter/models/olmo.py @@ -6,6 +6,7 @@ from typing import Callable, Dict, Optional, Type, Union import equinox as eqx +import numpy as np import jax.numpy as jnp import jax.random as jrandom from jaxtyping import PRNGKeyArray @@ -14,7 +15,7 @@ import haliax.nn as hnn from haliax import Axis, AxisSpec, NamedArray from haliax.jax_utils import maybe_rng_split, named_call, shaped_rng_split -from haliax.nn.scan import Stacked +from haliax.nn.scan import BlockSeq, Stacked from haliax.state_dict import ModuleWithStateDictSerialization from levanter.compat.hf_checkpoints import HFCheckpointConverter, HFCompatConfig @@ -35,6 +36,7 @@ silence_transformer_nag() from transformers import Olmo2Config as HfOlmo2Config # noqa: E402 +from transformers import Olmo3Config as HfOlmo3Config # noqa: E402 from transformers import PretrainedConfig as HfConfig # noqa: E402 @@ -601,3 +603,337 @@ def resize_vocab(self, new_size: int, key=None) -> "LmHeadModel[Olmo2Config]": return dataclasses.replace(self, embeddings=new_embeddings, lm_head=new_lm_head) else: return 
dataclasses.replace(self, embeddings=new_embeddings) + + +@LmConfig.register_subclass("olmo3") +@dataclass(frozen=True) +class Olmo3Config(HFCompatConfig): + max_seq_len: int = 65536 + hidden_dim: int = 4096 + intermediate_dim: int = 11008 + num_layers: int = 32 + num_heads: int = 32 + num_kv_heads: int = 32 + activation_function: ActivationFunctionEnum = ActivationFunctionEnum.silu + initializer_range: float = 0.02 + layer_norm_epsilon: float = 1e-6 + tie_word_embeddings: bool = False + attention_bias: bool = False + attention_dropout: float = 0.0 + + upcast_attn: bool = False + use_flash_attention: Optional[bool] = True + attn_backend: Optional[AttentionBackend] = None + flash_attention_block_size: Optional[int] = None + + gradient_checkpointing: bool = True + scan_layers: bool = True + + use_bias: bool = False + use_layer_norm_weight: bool = True + rope: RotaryEmbeddingsConfig = dataclasses.field(default_factory=DefaultRotaryEmbeddingsConfig) + + sliding_window: int = 4096 + layer_types: tuple[str, ...] = tuple( + ["sliding_attention", "sliding_attention", "sliding_attention", "full_attention"] * 8 + ) + + reference_checkpoint: str = "allenai/Olmo-3-1025-7B" + tokenizer: Optional[str] = None + + @property + def seq_len(self) -> int: + return self.max_seq_len + + @property + def Pos(self) -> Axis: + return Axis(name="position", size=self.max_seq_len) + + @property + def KeyPos(self) -> Axis: + return self.Pos.alias("key_position") + + @property + def Embed(self) -> Axis: + return Axis(name="embed", size=self.hidden_dim) + + Heads = property(lambda self: Axis(name="heads", size=self.num_heads)) + KVHeads = property(lambda self: Axis(name="kv_head", size=self.num_kv_heads)) + Layers = property(lambda self: Axis(name="layers", size=self.num_layers)) + Mlp = property(lambda self: Axis(name="mlp", size=self.intermediate_dim)) + HeadSize = property(lambda self: Axis(name="head_size", size=self.hidden_dim // self.num_heads)) + + def __post_init__(self): + assert self.num_heads % self.num_kv_heads == 0 + assert len(self.layer_types) == self.num_layers + + def hf_checkpoint_converter( + self, ref_checkpoint: Optional[str] = None + ) -> HFCheckpointConverter["Olmo3Config"]: # type: ignore + hf_model_path = self.reference_checkpoint if ref_checkpoint is None else ref_checkpoint + + return HFCheckpointConverter( + self.__class__, + reference_checkpoint=hf_model_path, + trust_remote_code=True, + tokenizer=hf_model_path if self.tokenizer is None else self.tokenizer, + HfConfigClass=HfOlmo3Config, + ) + + @classmethod + def from_hf_config(cls, hf_config: HfConfig): + rope_config = RotaryEmbeddingsConfig.from_hf_config( + getattr(hf_config, "rope_theta", None), getattr(hf_config, "rope_scaling", None) + ) + hf_sliding_window = getattr(hf_config, "sliding_window", None) + if hf_sliding_window is None: + hf_sliding_window = getattr(hf_config, "max_position_embeddings", None) or cls.sliding_window + + return Olmo3Config( + max_seq_len=hf_config.max_position_embeddings, + hidden_dim=hf_config.hidden_size, + intermediate_dim=hf_config.intermediate_size, + num_layers=hf_config.num_hidden_layers, + num_heads=hf_config.num_attention_heads, + num_kv_heads=hf_config.num_key_value_heads, + activation_function=ActivationFunctionEnum(hf_config.hidden_act), + initializer_range=hf_config.initializer_range, + layer_norm_epsilon=hf_config.rms_norm_eps, + tie_word_embeddings=hf_config.tie_word_embeddings, + attention_bias=hf_config.attention_bias, + attention_dropout=hf_config.attention_dropout, + 
sliding_window=hf_sliding_window, + layer_types=tuple( + getattr(hf_config, "layer_types", ["sliding_attention"] * hf_config.num_hidden_layers) + ), + rope=rope_config, + ) + + def to_hf_config(self, vocab_size: int, config_overrides: Optional[Dict] = None) -> HfOlmo3Config: + if config_overrides is None: + config_overrides = {} + + rope_theta, rope_scaling = self.rope.to_hf_config() + + return HfOlmo3Config( + max_position_embeddings=self.max_seq_len, + hidden_size=self.hidden_dim, + intermediate_size=self.intermediate_dim, + num_hidden_layers=self.num_layers, + num_attention_heads=self.num_heads, + num_key_value_heads=self.num_kv_heads, + hidden_act=self.activation_function.name, + initializer_range=self.initializer_range, + rms_norm_eps=self.layer_norm_epsilon, + tie_word_embeddings=self.tie_word_embeddings, + attention_bias=self.attention_bias, + attention_dropout=self.attention_dropout, + sliding_window=self.sliding_window, + layer_types=list(self.layer_types), + vocab_size=vocab_size, + pad_token_id=None, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + _attn_implementation="eager", + **config_overrides, + ) + + @property + def model_type(self): + return Olmo3LMHeadModel + + def attention_config(self) -> AttentionConfig: + return AttentionConfig( + Embed=self.Embed, + num_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + use_bias=self.attention_bias, + upcast_attn=self.upcast_attn, + attn_backend=self.attn_backend, + flash_attention_block_size=self.flash_attention_block_size, + rope=self.rope, + qk_norm=self.norm_config, + ) + + @property + def norm_config(self) -> RmsNormConfig: + return RmsNormConfig( + eps=self.layer_norm_epsilon, + use_weight=self.use_layer_norm_weight, + use_bias=self.use_bias, + ) + + +class Olmo3Attention(ModuleWithStateDictSerialization, Attention): + sliding_window: Optional[int] = eqx.field(static=True, default=None) + + @staticmethod + def init(config: AttentionConfig, *, key, sliding_window: int | None = None): + attn_config = config.attention_config() if hasattr(config, "attention_config") else config + base = Attention.init(attn_config, key=key) + q_norm = None + k_norm = None + if attn_config.qk_norm is not None: + q_norm = attn_config.qk_norm.build( + (attn_config.KVHeads, attn_config.QHeadsPerGroup, attn_config.HeadSize) + ) + k_norm = attn_config.qk_norm.build((attn_config.KVHeads, attn_config.HeadSize)) + return Olmo3Attention( + base.config, + base.q_proj, + base.k_proj, + base.v_proj, + base.o_proj, + q_norm, + k_norm, + base.rot_embs, + sliding_window=int(sliding_window) if sliding_window is not None else None, + ) + + def __call__(self, x, mask, *, key=None, pos_ids=None): + if self.sliding_window is not None: + if isinstance(mask, AttentionMask): + mask = mask.with_sliding_window(self.sliding_window) + elif mask is None: + mask = AttentionMask.causal(sliding_window=self.sliding_window) + else: + mask = AttentionMask(explicit_mask=mask, sliding_window=self.sliding_window) + + return super().__call__(x, mask, key=key, pos_ids=pos_ids) + + +class Olmo3DecoderLayer(ModuleWithStateDictSerialization, eqx.Module): + config: Olmo3Config = eqx.field(static=True) + self_attn: Olmo3Attention + mlp: Olmo2MLP + post_attention_layernorm: hnn.RmsNorm + post_feedforward_layernorm: hnn.RmsNorm + + @staticmethod + def init(config: Olmo3Config, layer_sliding_window: int, *, key): + k_attn, k_mlp = jrandom.split(key, 2) + + sliding = None if layer_sliding_window is None or layer_sliding_window < 0 else int(layer_sliding_window) + + attn = 
Olmo3Attention.init(config.attention_config(), key=k_attn, sliding_window=sliding) + mlp = Olmo2MLP.init( + config.Embed, + config.Mlp, + config.activation_function, + key=k_mlp, + use_bias=config.use_bias, + ) + + ln1 = config.norm_config.build(config.Embed) + ln2 = config.norm_config.build(config.Embed) + + return Olmo3DecoderLayer(config, attn, mlp, ln1, ln2) + + def __call__(self, x, mask, *, key=None, pos_ids=None): + k1, k2 = maybe_rng_split(key, 2) + h = x + self.post_attention_layernorm(self.self_attn(x, mask, key=k1, pos_ids=pos_ids)) + return h + self.post_feedforward_layernorm(self.mlp(h, key=k2)) + + +class Olmo3Transformer(ModuleWithStateDictSerialization, eqx.Module): + config: Olmo3Config = eqx.field(static=True) + layers: Stacked[Olmo3DecoderLayer] + norm: hnn.RmsNorm + + @staticmethod + def init(config: Olmo3Config, *, key): + keys = shaped_rng_split(key, config.num_layers) + + layer_sliding_windows = np.array( + [config.sliding_window if layer == "sliding_attention" else -1 for layer in config.layer_types], + dtype=int, + ) + + # BlockSeq is used to allow per-layer sliding window choices + layers = BlockSeq.init( + config.Layers, + Olmo3DecoderLayer, + gradient_checkpointing=config.gradient_checkpointing, + )(config, layer_sliding_windows, key=keys) + + ln_f = config.norm_config.build(config.Embed) + return Olmo3Transformer(config, layers, ln_f) + + def __call__(self, x, attn_mask, *, key, pos_ids=None): + keys = maybe_rng_split(key, self.config.num_layers) if key is not None else None + x = self.layers.fold(x, mask=attn_mask, key=keys, pos_ids=pos_ids) + return self.norm(x) + + +class Olmo3LMHeadModel(ModuleWithStateDictSerialization, LmHeadModel[Olmo3Config]): + transformer: Olmo3Transformer + embeddings: Olmo2Embedding + lm_head: Optional[hnn.Linear] + + @classmethod + def init(cls, Vocab: Axis, config: Olmo3Config, *, key): + k_t, k_emb, k_head = jrandom.split(key, 3) + + transformer = Olmo3Transformer.init(config, key=k_t) + embeddings = Olmo2Embedding.init(Vocab, config, key=k_emb) + + if config.tie_word_embeddings: + lm_head = None + else: + lm_head = hnn.Linear.init( + In=config.Embed, Out=Vocab, key=k_head, use_bias=False, out_first=True + ) + + return Olmo3LMHeadModel(transformer, embeddings, lm_head) + + def __call__(self, input_ids, attn_mask=None, pos_ids=None, *, key=None): + k_t, k_h = maybe_rng_split(key, 2) + x = self.embeddings.embed(input_ids) + x = self.transformer(x, attn_mask=attn_mask, key=k_t, pos_ids=pos_ids) + + if self.lm_head is None: + return self.embeddings.unembed(x) + else: + return self.lm_head(x, key=k_h) + + @property + def config(self) -> Olmo3Config: + return self.transformer.config + + @property + def Vocab(self) -> Axis: + return self.embeddings.Vocab + + def activations( + self, + input_ids: NamedArray, + attn_mask: Optional[AttentionMask | NamedArray] = None, + *, + key=None, + pos_ids: NamedArray | None = None, + ) -> NamedArray: + x = self.embeddings.embed(input_ids) + return self.transformer(x, attn_mask=attn_mask, key=key, pos_ids=pos_ids) + + def get_lm_head(self) -> hax.NamedArray: + if self.lm_head is None: + return self.embeddings.token_embeddings.weight + return self.lm_head.weight + + def _state_dict_key_map(self) -> Dict[str, Optional[str]]: + return { + "transformer": "model", + "embeddings": "model", + "lm_head": "lm_head", + } + + def resize_vocab(self, new_size: int, key=None) -> "LmHeadModel[Olmo3Config]": + new_vocab = self.Vocab.resize(new_size) + k1, k2 = maybe_rng_split(key, 2) + new_embeddings = 
self.embeddings.resize_embeddings(new_size, key=k1) + if self.lm_head is None: + return dataclasses.replace(self, embeddings=new_embeddings) + + new_lm_matrix = hax.tree_util.resize_axis(self.lm_head.weight, self.Vocab, new_size, key=k2) + new_lm_head = dataclasses.replace(self.lm_head, Out=new_vocab, weight=new_lm_matrix) + return dataclasses.replace(self, embeddings=new_embeddings, lm_head=new_lm_head) diff --git a/lib/levanter/src/levanter/models/qwen.py b/lib/levanter/src/levanter/models/qwen.py index ea636d9cd8..2f321688ef 100644 --- a/lib/levanter/src/levanter/models/qwen.py +++ b/lib/levanter/src/levanter/models/qwen.py @@ -17,6 +17,7 @@ from levanter.compat.hf_checkpoints import HFCheckpointConverter from levanter.layers.attention import Attention, AttentionConfig, AttentionMask +from levanter.layers.normalization import LayerNormConfigBase, RmsNormConfig from levanter.layers.rotary import RotaryEmbeddingsConfig from levanter.models.llama import LlamaConfig, LlamaEmbedding, LlamaLMHeadModel, LlamaMlp, LlamaTransformer from levanter.models.lm_model import LmConfig, LmHeadModel @@ -61,6 +62,10 @@ def hf_checkpoint_converter( def from_hf_config(cls, hf_config: HfConfig): rope_theta = hf_config.rope_theta rope_config = RotaryEmbeddingsConfig.from_hf_config(rope_theta, hf_config.rope_scaling) + use_bias = getattr(hf_config, "attention_bias", None) + if use_bias is None: + # Qwen2Config in newer transformers drops no_bias; assume bias by default + use_bias = not getattr(hf_config, "no_bias", False) return QwenConfig( max_seq_len=hf_config.max_position_embeddings, hidden_dim=hf_config.hidden_size, @@ -76,7 +81,7 @@ def from_hf_config(cls, hf_config: HfConfig): layer_norm_epsilon=hf_config.rms_norm_eps, tie_word_embeddings=hf_config.tie_word_embeddings, rope=rope_config, - use_bias=not hf_config.no_bias, + use_bias=use_bias, ) def to_hf_config(self, vocab_size: int, config_overrides: Optional[Dict] = None) -> HfQwenConfig: @@ -156,7 +161,6 @@ def init(config: QwenConfig, *, key) -> "QwenDecoderLayer": config.Mlp, config.activation_function, key=k_mlp, - use_bias=config.use_bias, ) ln_1 = config.mk_LayerNorm(config.Embed) ln_2 = config.mk_LayerNorm(config.Embed) @@ -310,6 +314,14 @@ class Qwen3Config(LlamaConfig): def model_type(self): # noqa: D401 return Qwen3LMHeadModel + @property # type: ignore[override] + def norm_config(self) -> LayerNormConfigBase: + return RmsNormConfig( + use_weight=self.use_layer_norm_weight, + use_bias=False, + eps=self.layer_norm_epsilon, + ) + def hf_checkpoint_converter( self, ref_checkpoint: Optional[str] = None ) -> HFCheckpointConverter["Qwen3Config"]: # type: ignore diff --git a/lib/levanter/tests/test_olmo.py b/lib/levanter/tests/test_olmo.py index c8cc805b15..41b1e4f499 100644 --- a/lib/levanter/tests/test_olmo.py +++ b/lib/levanter/tests/test_olmo.py @@ -13,7 +13,16 @@ import haliax.nn as hnn from levanter.layers.attention import AttentionMask -from levanter.models.olmo import Olmo2Attention, Olmo2Config, Olmo2DecoderLayer, Olmo2LMHeadModel +from levanter.models.olmo import ( + Olmo2Attention, + Olmo2Config, + Olmo2DecoderLayer, + Olmo2LMHeadModel, + Olmo3Config, + Olmo3LMHeadModel, + Olmo3Attention, + Olmo3Transformer, +) from levanter.utils.jax_utils import parameter_count from test_utils import skip_if_no_torch, use_test_mesh @@ -45,6 +54,36 @@ def _get_random_inputs(config: Olmo2Config, override_Pos=None): return x, mask +def _get_olmo3_config( + seq_len=128, + hidden_dim=16, + intermediate_dim=32, + num_layers=4, + num_heads=4, + num_kv_heads=2, + 
sliding_window=16, + layer_types=None, + tie_embeddings=False, +) -> Olmo3Config: + if layer_types is None: + layer_types = tuple(["sliding_attention"] * (num_layers - 1) + ["full_attention"]) + + return Olmo3Config( + max_seq_len=seq_len, + hidden_dim=hidden_dim, + intermediate_dim=intermediate_dim, + num_layers=num_layers, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + gradient_checkpointing=False, + scan_layers=False, + use_flash_attention=False, + sliding_window=sliding_window, + layer_types=tuple(layer_types), + tie_word_embeddings=tie_embeddings, + ) + + @skip_if_no_torch def test_olmo2_config(): # Check we can create a config @@ -401,3 +440,202 @@ def compute(model, input): jax_out_2 = compute(model, input_256)[config.max_Pos, :128] assert np.allclose(jax_out_1.array, jax_out_2.array, rtol=1e-6, atol=1e-6) + + +def test_olmo3_config_roundtrip(): + layer_types = ("sliding_attention", "full_attention", "sliding_attention") + config = _get_olmo3_config( + seq_len=256, + hidden_dim=32, + intermediate_dim=64, + num_layers=3, + num_heads=6, + num_kv_heads=3, + sliding_window=64, + layer_types=layer_types, + ) + + assert config.Pos.size == 256 + assert config.Embed.size == 32 + assert config.Heads.size == 6 + assert config.KVHeads.size == 3 + assert config.Layers.size == 3 + assert config.Mlp.size == 64 + assert config.HeadSize.size == 32 // 6 + assert config.layer_types == layer_types + assert config.sliding_window == 64 + + hf_config = config.to_hf_config(vocab_size=321) + assert hf_config.sliding_window == 64 + assert tuple(hf_config.layer_types) == layer_types + + config2 = Olmo3Config.from_hf_config(hf_config) + assert config2.max_seq_len == 256 + assert config2.hidden_dim == 32 + assert config2.intermediate_dim == 64 + assert config2.num_layers == 3 + assert config2.num_heads == 6 + assert config2.num_kv_heads == 3 + assert config2.layer_types == layer_types + assert config2.sliding_window == 64 + + +def test_olmo3_layer_sliding_window_selection(): + layer_types = ("sliding_attention", "full_attention", "sliding_attention", "full_attention") + config = _get_olmo3_config(num_layers=4, layer_types=layer_types, sliding_window=8) + + transformer = Olmo3Transformer.init(config, key=random.PRNGKey(0)) + sliding_windows = [layer.self_attn.sliding_window for layer in transformer.layers.blocks] + + assert sliding_windows == [8, None, 8, None] + + +def test_olmo3_attention_applies_sliding_window(): + config = _get_olmo3_config( + seq_len=12, + hidden_dim=24, + intermediate_dim=48, + num_layers=1, + num_heads=4, + num_kv_heads=2, + sliding_window=3, + layer_types=("sliding_attention",), + ) + attn = Olmo3Attention.init(config.attention_config(), key=random.PRNGKey(0), sliding_window=config.sliding_window) + attn_no_sliding = eqx.tree_at(lambda a: a.sliding_window, attn, None) + + Batch = hax.Axis("batch", 1) + x = hax.random.normal(random.PRNGKey(1), (Batch, config.Pos, config.Embed)) + base_mask = AttentionMask.causal() + manual_mask = base_mask.with_sliding_window(config.sliding_window) + + out_with_attr = attn(x, base_mask, key=random.PRNGKey(2)) + out_manual = attn_no_sliding(x, manual_mask, key=random.PRNGKey(2)) + + chex.assert_trees_all_close(out_with_attr, out_manual, rtol=1e-6, atol=1e-6) + + +@skip_if_no_torch +def test_olmo3_roundtrip(): + import torch + from transformers import AutoModelForCausalLM, Olmo3ForCausalLM + + converter = Olmo3Config().hf_checkpoint_converter() + + config = Olmo3Config( + max_seq_len=64, + hidden_dim=16, + intermediate_dim=32, + num_heads=4, + 
num_kv_heads=2, + num_layers=4, + gradient_checkpointing=False, + scan_layers=False, + use_flash_attention=False, + sliding_window=16, + layer_types=("sliding_attention", "sliding_attention", "full_attention", "full_attention"), + ) + # Use large vocab so tokenizer pad/eos ids from the reference checkpoint stay in range + Vocab = hax.Axis("vocab", 150000) + hf_config = config.to_hf_config(Vocab.size) + + input = hax.random.randint(random.PRNGKey(0), config.Pos, 0, Vocab.size) + attn_mask = AttentionMask.causal() + input_torch = torch.from_numpy(np.array(input.array)).to(torch.int32).unsqueeze(0) + + torch.random.manual_seed(0) + torch_model = Olmo3ForCausalLM(hf_config) + torch_model.eval() + torch_out = torch_model(input_torch) + torch_logits = torch_out.logits[0].detach().cpu().numpy() + + with tempfile.TemporaryDirectory() as tmpdir, use_test_mesh(): + model_path = f"{tmpdir}/torch_model" + torch_model.save_pretrained(model_path) + + model = converter.load_pretrained( + Olmo3LMHeadModel, ref=model_path, resize_vocab_to_match_tokenizer=False + ) + + @hax.named_jit + def compute(model, ids): + return model(ids, attn_mask=attn_mask) + + jax_out = compute(model, input).array + + assert torch_logits.shape == jax_out.shape, f"{torch_logits.shape} != {jax_out.shape}" + + abs_diff = np.abs(torch_logits - jax_out.astype(np.float32)) + max_diff_idx = np.unravel_index(np.argmax(abs_diff), abs_diff.shape) + print(f"\nOLMo3 max diff at {max_diff_idx}: {abs_diff[max_diff_idx]}") + print(f"HF value: {torch_logits[max_diff_idx]}, JAX value: {jax_out[max_diff_idx]}") + + assert np.isclose(torch_logits, np.array(jax_out), rtol=1e-4, atol=1e-4).all(), f"{torch_logits} != {jax_out}" + + converter_with_ref = converter.replaced(reference_checkpoint=model_path) + converter_with_ref.save_pretrained(model, f"{tmpdir}/lev_model", save_reference_code=False) + + torch_model2 = AutoModelForCausalLM.from_pretrained(f"{tmpdir}/lev_model") + torch_model2.eval() + + torch_out2 = torch_model2(input_torch) + torch_logits2 = torch_out2.logits[0].detach().cpu().numpy() + assert torch_logits2.shape == jax_out.shape, f"{torch_logits2.shape} != {jax_out.shape}" + np.testing.assert_allclose(torch_logits2, jax_out, rtol=1e-5, atol=1e-5) + + +def test_olmo3_lm_head_model_forward_and_grad(): + config = _get_olmo3_config( + seq_len=16, + hidden_dim=24, + intermediate_dim=48, + num_layers=2, + num_heads=4, + num_kv_heads=2, + sliding_window=4, + layer_types=("sliding_attention", "full_attention"), + ) + Batch = hax.Axis("batch", 2) + Vocab = hax.Axis("vocab", 100) + input_ids = hax.random.randint(random.PRNGKey(0), (Batch, config.Pos), 0, Vocab.size) + mask = AttentionMask.causal() + + model = Olmo3LMHeadModel.init(Vocab=Vocab, config=config, key=random.PRNGKey(1)) + outputs = model(input_ids, attn_mask=mask, key=random.PRNGKey(2)) + + assert outputs.array.shape == (Batch.size, config.Pos.size, Vocab.size) + assert outputs.axes[0] == Batch + assert outputs.axes[1] == config.Pos + + def loss_fn(m, ids): + logits = m(ids, attn_mask=mask) + return hax.sum(logits).scalar() + + _, grads = eqx.filter_value_and_grad(loss_fn)(model, input_ids) + assert grads is not None + + +def test_olmo3_tied_embeddings(): + config = _get_olmo3_config( + seq_len=8, + hidden_dim=12, + intermediate_dim=24, + num_layers=1, + num_heads=4, + num_kv_heads=2, + sliding_window=4, + layer_types=("full_attention",), + tie_embeddings=True, + ) + Batch = hax.Axis("batch", 2) + Vocab = hax.Axis("vocab", 50) + input_ids = hax.random.randint(random.PRNGKey(0), (Batch, 
config.Pos), 0, Vocab.size)
+    mask = AttentionMask.causal()
+
+    model = Olmo3LMHeadModel.init(Vocab=Vocab, config=config, key=random.PRNGKey(1))
+
+    assert model.lm_head is None
+    chex.assert_trees_all_equal(model.get_lm_head(), model.embeddings.token_embeddings.weight)
+
+    outputs = model(input_ids, attn_mask=mask, key=random.PRNGKey(2))
+    assert outputs.array.shape == (Batch.size, config.Pos.size, Vocab.size)
diff --git a/scripts/gpu_eval/pt_lm_eval_harness.sh b/scripts/gpu_eval/pt_lm_eval_harness.sh
new file mode 100644
index 0000000000..a3e6ff34e2
--- /dev/null
+++ b/scripts/gpu_eval/pt_lm_eval_harness.sh
@@ -0,0 +1,36 @@
+#!/bin/bash
+# We evaluate the models not supported on TPU (e.g. Gemma 3, Nemotron) with this script
+
+set -euo pipefail
+
+DEFAULT_MODELS=(
+    "meta-llama/Meta-Llama-3-70B"
+    "marin-community/marin-32b-base"
+    "allenai/OLMo-2-0325-32B"
+    "Qwen/Qwen2.5-32B"
+    "allenai/OLMo-3-1125-32B"
+    "google/gemma-3-27b-pt"
+    "meta-llama/Llama-3.1-8B"
+    "marin-community/marin-8b-base"
+    "allenai/OLMo-2-1124-7B"
+    "Qwen/Qwen3-8B-Base"
+    "allenai/OLMo-3-1025-7B"
+)
+
+if [ "$#" -gt 0 ]; then
+    MODELS=("$@")
+else
+    MODELS=("${DEFAULT_MODELS[@]}")
+fi
+
+for MODEL in "${MODELS[@]}"; do
+    echo "Running lm_eval for model: ${MODEL}"
+    HF_ALLOW_CODE_EVAL=1 lm_eval \
+        --model vllm \
+        --model_args "pretrained=$MODEL,trust_remote_code=True,dtype=auto,gpu_memory_utilization=0.8,max_model_len=4096,max_gen_toks=4096" \
+        --tasks paloma_c4_en,leaderboard_musr,anli,triviaqa,drop,truthfulqa_mc2,squadv2,race,toxigen,blimp,nq_open,xsum,uncheatable_eval,agieval_lsat_ar,arc_easy,arc_challenge,leaderboard_bbh,boolq,commonsense_qa,copa,leaderboard_gpqa,gsm8k_cot,hellaswag,humaneval,lambada_openai,minerva_math,mmlu,leaderboard_mmlu_pro,openbookqa,piqa,winogrande,wsc273 \
+        --batch_size auto \
+        --output_path ./local-eval-results \
+        --confirm_run_unsafe_code \
+        --apply_chat_template
+done
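A minimal usage sketch for the new GPU script, assuming a host with lm_eval and vLLM installed: positional arguments override DEFAULT_MODELS, so a subset of models can be evaluated directly.

    # Evaluate only the OLMo 3 checkpoints; omit arguments to run the full default list.
    # Per-model results are written under ./local-eval-results.
    bash scripts/gpu_eval/pt_lm_eval_harness.sh "allenai/OLMo-3-1025-7B" "allenai/OLMo-3-1125-32B"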