2 changes: 1 addition & 1 deletion ACKNOWLEDGMENTS.md
@@ -13,7 +13,7 @@ MLX LM was developed with contributions from the following individuals:
THUKEG's `GLM`, `GLM4`, Rednote `dots.llm1`, Baisu's `Ernie4.5 MoE`, inclusionAI's
`Bailing MoE e.g. Ling-family`, Klear team - Kuaishou Technology's `Klear`,
IBM's `Granite MoE`, Meituan's `LongCat`, Nvidia's `Nemotron H`, Swiss-AI's
-`Apertus`, Nikity's `Lille130m`, Alibaba Qwen's `Qwen3Next`, and Allenai's `OLMoE`;
+`Apertus`, Nikity's `Lille130m`, Alibaba Qwen's `Qwen3Next`, and Allenai's `OLMoE` and `OLMo3`;
Helped add support for the following model architectures: Alibaba Qwen's `Qwen3 & Qwen3MoE)`;
Added support for the following training algorithms: `Full Weight Fine-Tuning`, and the `Muon`
optimizer; Added support for the following other features: `Multiple Optimizers
226 changes: 226 additions & 0 deletions mlx_lm/models/olmo3.py
@@ -0,0 +1,226 @@
# Copyright © 2025 Apple Inc.

from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union

import mlx.core as mx
import mlx.nn as nn

from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
from .rope_utils import initialize_rope


@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
hidden_size: int
num_hidden_layers: int
intermediate_size: int
num_attention_heads: int
rms_norm_eps: float
vocab_size: int
max_position_embeddings: int
sliding_window: int
rope_theta: float
attention_bias: bool = False
layer_types: Optional[List[str]] = None
num_key_value_heads: Optional[int] = None
head_dim: Optional[int] = None
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
tie_word_embeddings: bool = False

def __post_init__(self):
if self.num_key_value_heads is None:
self.num_key_value_heads = self.num_attention_heads
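        # Default layer pattern: every fourth layer uses full attention; the rest use sliding-window attention.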
if self.layer_types is None:
self.layer_types = [
"full_attention" if (i + 1) % 4 == 0 else "sliding_attention"
for i in range(self.num_hidden_layers)
]


class Olmo3Attention(nn.Module):
def __init__(self, args: ModelArgs, layer_idx: int):
super().__init__()
self.num_attention_heads = args.num_attention_heads
self.num_key_value_heads = args.num_key_value_heads
self.layer_idx = layer_idx

self.head_dim = args.head_dim or args.hidden_size // args.num_attention_heads
self.scale = self.head_dim**-0.5

self.q_proj = nn.Linear(
args.hidden_size,
args.num_attention_heads * self.head_dim,
bias=args.attention_bias,
)
self.k_proj = nn.Linear(
args.hidden_size,
args.num_key_value_heads * self.head_dim,
bias=args.attention_bias,
)
self.v_proj = nn.Linear(
args.hidden_size,
args.num_key_value_heads * self.head_dim,
bias=args.attention_bias,
)
self.o_proj = nn.Linear(
args.num_attention_heads * self.head_dim,
args.hidden_size,
bias=args.attention_bias,
)

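        # QK-norm: RMSNorm is applied to the full query/key projections before they are split into heads.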
self.q_norm = nn.RMSNorm(
args.num_attention_heads * self.head_dim, eps=args.rms_norm_eps
)
self.k_norm = nn.RMSNorm(
args.num_key_value_heads * self.head_dim, eps=args.rms_norm_eps
)
self.is_full = args.layer_types[layer_idx] == "full_attention"

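        # Full-attention layers use plain RoPE; sliding-window layers apply the configured RoPE scaling, if any.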
if self.is_full:
self.rope = nn.RoPE(self.head_dim, traditional=False, base=args.rope_theta)
else:
self.rope = initialize_rope(
self.head_dim,
traditional=False,
base=args.rope_theta,
scaling_config=args.rope_scaling,
max_position_embeddings=args.max_position_embeddings,
)

def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, _ = x.shape
queries = self.q_norm(self.q_proj(x))
keys = self.k_norm(self.k_proj(x))
values = self.v_proj(x)

queries = queries.reshape(B, L, self.num_attention_heads, -1).transpose(
0, 2, 1, 3
)
keys = keys.reshape(B, L, self.num_key_value_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.num_key_value_heads, -1).transpose(
0, 2, 1, 3
)

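        # With a cache, rotate queries/keys at the current offset and append the new keys/values for incremental decoding.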
if cache is not None:
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries)
keys = self.rope(keys)

output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output)


class Olmo3MLP(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.gate_proj = nn.Linear(args.hidden_size, args.intermediate_size, bias=False)
self.down_proj = nn.Linear(args.intermediate_size, args.hidden_size, bias=False)
self.up_proj = nn.Linear(args.hidden_size, args.intermediate_size, bias=False)

def __call__(self, x: mx.array) -> mx.array:
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))


class Olmo3DecoderLayer(nn.Module):
def __init__(self, args: ModelArgs, layer_idx: int):
super().__init__()
self.num_attention_heads = args.num_attention_heads
self.hidden_size = args.hidden_size
self.self_attn = Olmo3Attention(args, layer_idx=layer_idx)
self.mlp = Olmo3MLP(args)
self.post_attention_layernorm = nn.RMSNorm(
args.hidden_size, eps=args.rms_norm_eps
)
self.post_feedforward_layernorm = nn.RMSNorm(
args.hidden_size, eps=args.rms_norm_eps
)
self.args = args

def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
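        # Post-norm residuals: each sublayer's output is normalized before being added back to the residual stream.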
r = self.post_attention_layernorm(self.self_attn(x, mask, cache))
h = x + r
r = self.post_feedforward_layernorm(self.mlp(h))
out = h + r
return out


class Olmo3Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.sliding_window = args.sliding_window

self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [
Olmo3DecoderLayer(args=args, layer_idx=i)
for i in range(args.num_hidden_layers)
]
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)

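        # Record one layer index per attention type so a single mask of each kind is built per forward pass.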
self.swa_idx = args.layer_types.index("sliding_attention")
self.ga_idx = args.layer_types.index("full_attention")
self.layer_types = args.layer_types

def __call__(
self,
inputs: mx.array,
cache: Optional[Any] = None,
) -> mx.array:
h = self.embed_tokens(inputs)

if cache is None:
cache = [None] * len(self.layers)

full_mask = create_attention_mask(h, cache[self.ga_idx])
sliding_window_mask = create_attention_mask(
h, cache[self.swa_idx], window_size=self.sliding_window
)

for layer, c, layer_type in zip(self.layers, cache, self.layer_types):
mask = full_mask if layer_type == "full_attention" else sliding_window_mask
h = layer(h, mask, cache=c)

return self.norm(h)


class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.model_type = args.model_type
self.model = Olmo3Model(args)
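        # With tied word embeddings, the embedding matrix doubles as the output projection, so no separate lm_head is created.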
if not args.tie_word_embeddings:
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

def __call__(
self,
inputs: mx.array,
cache: Optional[Any] = None,
) -> mx.array:
out = self.model(inputs, cache)
if self.args.tie_word_embeddings:
out = self.model.embed_tokens.as_linear(out)
else:
out = self.lm_head(out)
return out

@property
def layers(self):
return self.model.layers
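
For reference, the new architecture is exercised through the standard mlx_lm entry points once a checkpoint is converted to MLX format; a minimal sketch (the checkpoint path below is hypothetical, not part of this diff):

# Minimal usage sketch for the new olmo3 model type; the path is a placeholder.
from mlx_lm import load, generate

model, tokenizer = load("path/to/olmo3-mlx-checkpoint")
text = generate(model, tokenizer, prompt="The capital of France is", max_tokens=32)
print(text)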
15 changes: 15 additions & 0 deletions tests/test_models.py
@@ -1824,6 +1824,21 @@ def test_all_models(self):
"time_step_limit": (0.01, 10),
"time_step_rank": "auto",
},
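            # Minimal olmo3 config; with eight layers the default layer_types pattern mixes sliding-window and full-attention layers.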
{
"model_type": "olmo3",
"num_heads": 8,
"head_dim": 16,
"vocab_size": 1000,
"hidden_size": 128,
"intermediate_size": 128,
"num_attention_heads": 8,
"rope_theta": 1000,
"num_hidden_layers": 8,
"rms_norm_eps": 1e-4,
"sliding_window": 128,
"tie_word_embeddings": True,
"max_position_embeddings": 1000,
},
]
for config in test_configs:
model_type = config["model_type"]