diff --git a/ACKNOWLEDGMENTS.md b/ACKNOWLEDGMENTS.md
index 19db5b2e3..422c17d2b 100644
--- a/ACKNOWLEDGMENTS.md
+++ b/ACKNOWLEDGMENTS.md
@@ -13,7 +13,7 @@ MLX LM was developed with contributions from the following individuals:
 THUKEG's `GLM`, `GLM4`, Rednote `dots.llm1`, Baisu's `Ernie4.5 MoE`, inclusionAI's `Bailing MoE e.g. Ling-family`, Klear team
 - Kuaishou Technology's `Klear`, IBM's `Granite MoE`, Meituan's `LongCat`, Nvidia's `Nemotron H`, Swiss-AI's
- `Apertus`, Nikity's `Lille130m`, Alibaba Qwen's `Qwen3Next`, and Allenai's `OLMoE`;
+ `Apertus`, Nikity's `Lille130m`, Alibaba Qwen's `Qwen3Next`, and Allenai's `OLMoE` and `Olmo 3`;
 Helped add support for the following model architectures: Alibaba Qwen's `Qwen3 & Qwen3MoE)`;
 Added support for the following training algorithms: `Full Weight Fine-Tuning`, and the `Muon` optimizer;
 Added support for the following other features: `Multiple Optimizers
diff --git a/mlx_lm/models/olmo3.py b/mlx_lm/models/olmo3.py
new file mode 100644
index 000000000..61d5ea485
--- /dev/null
+++ b/mlx_lm/models/olmo3.py
@@ -0,0 +1,226 @@
+# Copyright © 2025 Apple Inc.
+
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Union
+
+import mlx.core as mx
+import mlx.nn as nn
+
+from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
+from .rope_utils import initialize_rope
+
+
+@dataclass
+class ModelArgs(BaseModelArgs):
+    model_type: str
+    hidden_size: int
+    num_hidden_layers: int
+    intermediate_size: int
+    num_attention_heads: int
+    rms_norm_eps: float
+    vocab_size: int
+    max_position_embeddings: int
+    sliding_window: int
+    rope_theta: float
+    attention_bias: bool = False
+    layer_types: Optional[List[str]] = None
+    num_key_value_heads: Optional[int] = None
+    head_dim: Optional[int] = None
+    rope_scaling: Optional[Dict[str, Union[float, str]]] = None
+    tie_word_embeddings: bool = False
+
+    def __post_init__(self):
+        if self.num_key_value_heads is None:
+            self.num_key_value_heads = self.num_attention_heads
+        if self.layer_types is None:
+            self.layer_types = [
+                "full_attention" if (i + 1) % 4 == 0 else "sliding_attention"
+                for i in range(self.num_hidden_layers)
+            ]
+
+
+class Olmo3Attention(nn.Module):
+    def __init__(self, args: ModelArgs, layer_idx: int):
+        super().__init__()
+        self.num_attention_heads = args.num_attention_heads
+        self.num_key_value_heads = args.num_key_value_heads
+        self.layer_idx = layer_idx
+
+        self.head_dim = args.head_dim or args.hidden_size // args.num_attention_heads
+        self.scale = self.head_dim**-0.5
+
+        self.q_proj = nn.Linear(
+            args.hidden_size,
+            args.num_attention_heads * self.head_dim,
+            bias=args.attention_bias,
+        )
+        self.k_proj = nn.Linear(
+            args.hidden_size,
+            args.num_key_value_heads * self.head_dim,
+            bias=args.attention_bias,
+        )
+        self.v_proj = nn.Linear(
+            args.hidden_size,
+            args.num_key_value_heads * self.head_dim,
+            bias=args.attention_bias,
+        )
+        self.o_proj = nn.Linear(
+            args.num_attention_heads * self.head_dim,
+            args.hidden_size,
+            bias=args.attention_bias,
+        )
+
+        self.q_norm = nn.RMSNorm(
+            args.num_attention_heads * self.head_dim, eps=args.rms_norm_eps
+        )
+        self.k_norm = nn.RMSNorm(
+            args.num_key_value_heads * self.head_dim, eps=args.rms_norm_eps
+        )
+        self.is_full = args.layer_types[layer_idx] == "full_attention"
+
+        if self.is_full:
+            self.rope = nn.RoPE(self.head_dim, traditional=False, base=args.rope_theta)
+        else:
+            self.rope = initialize_rope(
+                self.head_dim,
+                traditional=False,
+                base=args.rope_theta,
+                scaling_config=args.rope_scaling,
+                max_position_embeddings=args.max_position_embeddings,
+            )
+
+    def __call__(
+        self,
+        x: mx.array,
+        mask: Optional[mx.array] = None,
+        cache: Optional[Any] = None,
+    ) -> mx.array:
+        B, L, _ = x.shape
+        queries = self.q_norm(self.q_proj(x))
+        keys = self.k_norm(self.k_proj(x))
+        values = self.v_proj(x)
+
+        queries = queries.reshape(B, L, self.num_attention_heads, -1).transpose(
+            0, 2, 1, 3
+        )
+        keys = keys.reshape(B, L, self.num_key_value_heads, -1).transpose(0, 2, 1, 3)
+        values = values.reshape(B, L, self.num_key_value_heads, -1).transpose(
+            0, 2, 1, 3
+        )
+
+        if cache is not None:
+            queries = self.rope(queries, offset=cache.offset)
+            keys = self.rope(keys, offset=cache.offset)
+            keys, values = cache.update_and_fetch(keys, values)
+        else:
+            queries = self.rope(queries)
+            keys = self.rope(keys)
+
+        output = scaled_dot_product_attention(
+            queries, keys, values, cache=cache, scale=self.scale, mask=mask
+        )
+        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+        return self.o_proj(output)
+
+
+class Olmo3MLP(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        self.gate_proj = nn.Linear(args.hidden_size, args.intermediate_size, bias=False)
+        self.down_proj = nn.Linear(args.intermediate_size, args.hidden_size, bias=False)
+        self.up_proj = nn.Linear(args.hidden_size, args.intermediate_size, bias=False)
+
+    def __call__(self, x: mx.array) -> mx.array:
+        return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
+
+
+class Olmo3DecoderLayer(nn.Module):
+    def __init__(self, args: ModelArgs, layer_idx: int):
+        super().__init__()
+        self.num_attention_heads = args.num_attention_heads
+        self.hidden_size = args.hidden_size
+        self.self_attn = Olmo3Attention(args, layer_idx=layer_idx)
+        self.mlp = Olmo3MLP(args)
+        self.post_attention_layernorm = nn.RMSNorm(
+            args.hidden_size, eps=args.rms_norm_eps
+        )
+        self.post_feedforward_layernorm = nn.RMSNorm(
+            args.hidden_size, eps=args.rms_norm_eps
+        )
+        self.args = args
+
+    def __call__(
+        self,
+        x: mx.array,
+        mask: Optional[mx.array] = None,
+        cache: Optional[Any] = None,
+    ) -> mx.array:
+        r = self.post_attention_layernorm(self.self_attn(x, mask, cache))
+        h = x + r
+        r = self.post_feedforward_layernorm(self.mlp(h))
+        out = h + r
+        return out
+
+
+class Olmo3Model(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        self.sliding_window = args.sliding_window
+
+        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
+        self.layers = [
+            Olmo3DecoderLayer(args=args, layer_idx=i)
+            for i in range(args.num_hidden_layers)
+        ]
+        self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+
+        self.swa_idx = args.layer_types.index("sliding_attention")
+        self.ga_idx = args.layer_types.index("full_attention")
+        self.layer_types = args.layer_types
+
+    def __call__(
+        self,
+        inputs: mx.array,
+        cache: Optional[Any] = None,
+    ) -> mx.array:
+        h = self.embed_tokens(inputs)
+
+        if cache is None:
+            cache = [None] * len(self.layers)
+
+        full_mask = create_attention_mask(h, cache[self.ga_idx])
+        sliding_window_mask = create_attention_mask(
+            h, cache[self.swa_idx], window_size=self.sliding_window
+        )
+
+        for layer, c, layer_type in zip(self.layers, cache, self.layer_types):
+            mask = full_mask if layer_type == "full_attention" else sliding_window_mask
+            h = layer(h, mask, cache=c)
+
+        return self.norm(h)
+
+
+class Model(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        self.args = args
+        self.model_type = args.model_type
+        self.model = Olmo3Model(args)
+        if not args.tie_word_embeddings:
+            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
+
+    def __call__(
+        self,
+        inputs: mx.array,
+        cache: Optional[Any] = None,
+    ) -> mx.array:
+        out = self.model(inputs, cache)
+        if self.args.tie_word_embeddings:
+            out = self.model.embed_tokens.as_linear(out)
+        else:
+            out = self.lm_head(out)
+        return out
+
+    @property
+    def layers(self):
+        return self.model.layers
diff --git a/tests/test_models.py b/tests/test_models.py
index 7c031a41e..39ad73469 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -1824,6 +1824,21 @@ def test_all_models(self):
                 "time_step_limit": (0.01, 10),
                 "time_step_rank": "auto",
             },
+            {
+                "model_type": "olmo3",
+                "num_heads": 8,
+                "head_dim": 16,
+                "vocab_size": 1000,
+                "hidden_size": 128,
+                "intermediate_size": 128,
+                "num_attention_heads": 8,
+                "rope_theta": 1000,
+                "num_hidden_layers": 8,
+                "rms_norm_eps": 1e-4,
+                "sliding_window": 128,
+                "tie_word_embeddings": True,
+                "max_position_embeddings": 1000,
+            },
         ]
         for config in test_configs:
             model_type = config["model_type"]
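
For reference, a minimal sketch (not part of the diff) of exercising the new module with the same toy configuration the test case above adds. It assumes an mlx_lm checkout that includes this change and uses `ModelArgs.from_dict` (inherited from `BaseModelArgs`), which drops keys such as `num_heads` that are not `ModelArgs` fields:

    import mlx.core as mx

    from mlx_lm.models.olmo3 import Model, ModelArgs

    # Same toy configuration as the new entry in tests/test_models.py.
    config = {
        "model_type": "olmo3",
        "vocab_size": 1000,
        "hidden_size": 128,
        "intermediate_size": 128,
        "num_attention_heads": 8,
        "head_dim": 16,
        "rope_theta": 1000,
        "num_hidden_layers": 8,
        "rms_norm_eps": 1e-4,
        "sliding_window": 128,
        "tie_word_embeddings": True,
        "max_position_embeddings": 1000,
    }
    model = Model(ModelArgs.from_dict(config))

    # Uncached forward pass over a toy prompt. With layer_types unset, the
    # ModelArgs.__post_init__ default gives every fourth layer full attention
    # and the rest sliding-window attention.
    logits = model(mx.array([[0, 1, 2, 3]]))
    print(logits.shape)  # (1, 4, 1000)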