Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ACKNOWLEDGMENTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ MLX LM was developed with contributions from the following individuals:
THUKEG's `GLM`, `GLM4`, Rednote `dots.llm1`, Baisu's `Ernie4.5 MoE`, inclusionAI's
`Bailing MoE e.g. Ling-family`, Klear team - Kuaishou Technology's `Klear`,
IBM's `Granite MoE`, Meituan's `LongCat`, Nvidia's `Nemotron H`, Swiss-AI's
`Apertus`, Nikity's `Lille130m`, Alibaba Qwen's `Qwen3Next`, and Allenai's `OLMoE`;
`Apertus`, Nikity's `Lille130m`, Alibaba Qwen's `Qwen3Next`, and Allenai's `OLMoE` and `OLMo3`;
Helped add support for the following model architectures: Alibaba Qwen's `Qwen3 & Qwen3MoE)`;
Added support for the following training algorithms: `Full Weight Fine-Tuning`, and the `Muon`
optimizer; Added support for the following other features: `Multiple Optimizers
Expand Down
211 changes: 211 additions & 0 deletions mlx_lm/models/olmo3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
# Copyright © 2025 Apple Inc.

from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union

import mlx.core as mx
import mlx.nn as nn

from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
from .cache import KVCache, RotatingKVCache


@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
hidden_size: int
num_hidden_layers: int
intermediate_size: int
num_attention_heads: int
rms_norm_eps: float
vocab_size: int
max_position_embeddings: Optional[int]
num_key_value_heads: Optional[int]
attention_bias: bool
mlp_bias: bool
rope_theta: float
layer_types: List[str]
sliding_window: int
rope_traditional: bool = False
num_key_value_heads: Optional[int] = None
head_dim: Optional[int] = None
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
tie_word_embeddings: bool = True

def __post_init__(self):
if self.num_key_value_heads is None:
self.num_key_value_heads = self.num_attention_heads


class Olmo3Attention(nn.Module):
def __init__(self, args: ModelArgs, layer_idx: int):
super().__init__()

dim = args.hidden_size
self.n_heads = n_heads = args.num_attention_heads
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
self.head_dim = args.head_dim or args.hidden_size // n_heads
self.layer_idx = layer_idx

self.scale = self.head_dim**-0.5

self.q_proj = nn.Linear(dim, n_heads * self.head_dim, bias=args.attention_bias)
self.k_proj = nn.Linear(
dim, n_kv_heads * self.head_dim, bias=args.attention_bias
)
self.v_proj = nn.Linear(
dim, n_kv_heads * self.head_dim, bias=args.attention_bias
)
self.o_proj = nn.Linear(n_heads * self.head_dim, dim, bias=args.attention_bias)

self.q_norm = nn.RMSNorm(dims=self.head_dim, eps=args.rms_norm_eps)
self.k_norm = nn.RMSNorm(dims=self.head_dim, eps=args.rms_norm_eps)
self.is_sliding = (layer_idx + 1) % args.layer_types[layer_idx] != 0

rope_base = args.rope_theta
if self.is_sliding:
self.rope = nn.RoPE(
self.head_dim, traditional=args.rope_traditional, base=rope_base
)
else:
self.rope = nn.RoPE(
self.head_dim,
traditional=args.rope_traditional,
base=rope_base,
scale=args.rope_scaling if hasattr(args, "rope_scaling") else 1.0,
)

def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, _ = x.shape
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)

keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)

queries = self.q_norm(queries)
keys = self.k_norm(keys)

if cache is not None:
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries)
keys = self.rope(keys)

if isinstance(mask, mx.array) and mask.shape[-1] != keys.shape[-2]:
mask = mask[..., -keys.shape[-2] :]
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output)


class Olmo3MLP(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.gate_proj = nn.Linear(
args.hidden_size, args.intermediate_size, bias=args.mlp_bias
)
self.down_proj = nn.Linear(
args.intermediate_size, args.hidden_size, bias=args.mlp_bias
)
self.up_proj = nn.Linear(
args.hidden_size, args.intermediate_size, bias=args.mlp_bias
)

def __call__(self, x) -> mx.array:
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))


class Olmo3DecoderLayer(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.num_attention_heads = args.num_attention_heads
self.hidden_size = args.hidden_size
self.self_attn = Olmo3Attention(args)
self.mlp = Olmo3MLP(args)
self.post_attention_layernorm = nn.RMSNorm(
args.hidden_size, eps=args.rms_norm_eps
)
self.post_feedforward_layernorm = nn.RMSNorm(
args.hidden_size, eps=args.rms_norm_eps
)
self.args = args

def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.post_attention_layernorm(self.self_attn(x, mask, cache))
h = x + r
r = self.post_feedforward_layernorm(self.mlp(h))
out = h + r
return out


class Olmo3Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [
Olmo3DecoderLayer(args=args) for _ in range(args.num_hidden_layers)
]
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)

def __call__(
self,
inputs: mx.array,
cache: Optional[Any] = None,
) -> mx.array:
h = self.embed_tokens(inputs)

if cache is None:
cache = [None] * len(self.layers)

mask = create_attention_mask(h, cache[0])

for layer, c in zip(self.layers, cache):
h = layer(h, mask, cache=c)

return self.norm(h)


class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.model_type = args.model_type
self.model = Olmo3Model(args)
if not args.tie_word_embeddings:
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

def __call__(
self,
inputs: mx.array,
cache: Optional[Any] = None,
) -> mx.array:
out = self.model(inputs, cache)
if self.args.tie_word_embeddings:
out = self.model.embed_tokens.as_linear(out)
else:
out = self.lm_head(out)
return out

def sanitize(self, weights):
return {
k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
}

@property
def layers(self):
return self.model.layers
1 change: 1 addition & 0 deletions mlx_lm/tuner/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ def to_lora(layer):
"qwen3_next",
"Klear",
"lille-130m",
"olmo3",
}:
keys = {"self_attn.q_proj", "self_attn.v_proj"}
if model.model_type in ["mixtral", "phimoe"]:
Expand Down