86 changes: 47 additions & 39 deletions catgrad-llm/src/models/gemma3.rs
@@ -146,7 +146,47 @@ pub struct Gemma3Model {
pub max_sequence_length: usize,
}

// Gemma uses a non-standard RMSNorm implementation.
// Generic over the rank N because `unpack` needs the last dimension, and the
// function is called with both rank-2 and rank-3 tensors.
fn rmsnorm_raw_gemma<const N: usize>(builder: &Builder, eps: f32, x: Var) -> Var {
let x_shape = shape(builder, x.clone());
let u = unpack::<N>(builder, x_shape.clone());
let n = u[N - 1].clone();
let s = sum(builder, x.clone() * x.clone());

let constn = nat_to_u32(builder, n);
let constn = cast(builder, constn, dtype(builder, x.clone()));
let sh = shape(builder, s.clone());
let constn = broadcast(builder, constn, sh);

let mean = s / constn;

let epsilon = constant(builder, eps, &shape(builder, mean.clone()));
let rms = sqrt(builder, mean + epsilon);
let denom = broadcast(builder, rms, x_shape);
x / denom
}

pub fn rmsnorm_gemma<const N: usize>(builder: &Builder, eps: f32, p: Path, x: Var) -> Var {
let gamma = param(builder, &p.extend(["weight"]).unwrap());
let lr = rmsnorm_raw_gemma::<N>(builder, eps, x);
let lr_shape = shape(builder, lr.clone());
let gamma = broadcast(builder, gamma, lr_shape);
let sh = shape(builder, gamma.clone());
let one = constant(builder, 1.0, &sh);
lr * (one + gamma)
}
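
For orientation: the two helpers above implement Gemma's RMSNorm variant. The input is divided by sqrt(mean(x^2) + eps) over the last dimension, and the result is scaled by (1 + weight) rather than by weight alone, which is where Gemma departs from standard RMSNorm. A minimal sketch of the same arithmetic on a plain slice, independent of the catgrad builder API (the function name and signature are illustrative only):

// Sketch only: Gemma-style RMSNorm over one row, mirroring the math of
// rmsnorm_raw_gemma + rmsnorm_gemma above but on plain f32 slices.
fn rmsnorm_gemma_sketch(x: &[f32], weight: &[f32], eps: f32) -> Vec<f32> {
    assert_eq!(x.len(), weight.len());
    let n = x.len() as f32;
    let mean_sq = x.iter().map(|v| v * v).sum::<f32>() / n;
    let rms = (mean_sq + eps).sqrt();
    x.iter()
        .zip(weight)
        .map(|(v, w)| (v / rms) * (1.0 + w)) // (1 + weight) is Gemma's deviation from standard RMSNorm
        .collect()
}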

impl Gemma3Model {
pub fn new(root: &str, config: GemmaTextConfig, max_sequence_length: usize) -> Self {
Gemma3Model {
root: root.to_string(),
config,
max_sequence_length,
}
}

fn softcap(&self, builder: &Builder, softcap: f32, x: Var) -> Var {
let sh = shape(builder, x.clone());
let s = constant(builder, softcap, &sh);
@@ -179,38 +219,6 @@ impl Gemma3Model {
)
}
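
The rest of the softcap helper's body is collapsed in this view. For reference, Gemma-style soft-capping is conventionally softcap * tanh(x / softcap), which bounds logits to (-softcap, softcap); a standalone scalar sketch under that assumption (not the project's builder graph, whose remaining ops are not shown here):

// Sketch only: scalar soft-capping, assumed to match what softcap() builds.
fn softcap_sketch(softcap: f32, x: f32) -> f32 {
    softcap * (x / softcap).tanh()
}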

// Gemma uses a non-standard RMSNorm implementation.
// Generic because of unpack needing the last dimension and it is being called
// with ranks 2 and 3 too.
pub fn rmsnorm_raw<const N: usize>(&self, builder: &Builder, eps: f32, x: Var) -> Var {
let x_shape = shape(builder, x.clone());
let u = unpack::<N>(builder, x_shape.clone());
let n = u[N - 1].clone();
let s = sum(builder, x.clone() * x.clone());

let constn = nat_to_u32(builder, n);
let constn = cast(builder, constn, dtype(builder, x.clone()));
let sh = shape(builder, s.clone());
let constn = broadcast(builder, constn, sh);

let mean = s / constn;

let epsilon = constant(builder, eps, &shape(builder, mean.clone()));
let rms = sqrt(builder, mean + epsilon);
let denom = broadcast(builder, rms, x_shape);
x / denom
}

fn rmsnorm<const N: usize>(&self, builder: &Builder, eps: f32, p: Path, x: Var) -> Var {
let gamma = param(builder, &p.extend(["weight"]).unwrap());
let lr = self.rmsnorm_raw::<N>(builder, eps, x);
let lr_shape = shape(builder, lr.clone());
let gamma = broadcast(builder, gamma, lr_shape);
let sh = shape(builder, gamma.clone());
let one = constant(builder, 1.0, &sh);
lr * (one + gamma)
}

fn attention(
&self,
builder: &Builder,
@@ -279,13 +287,13 @@ impl Gemma3Model {
let mut k = reshape(builder, sh, k);

if is_gemma3 {
q = self.rmsnorm::<2>(
q = rmsnorm_gemma::<2>(
builder,
self.config.rms_norm_eps,
p.extend(["q_norm"]).unwrap(),
q,
);
k = self.rmsnorm::<2>(
k = rmsnorm_gemma::<2>(
builder,
self.config.rms_norm_eps,
p.extend(["k_norm"]).unwrap(),
@@ -361,7 +369,7 @@ impl Gemma3Model {
x: Var,
) -> Var {
let res = x.clone();
let x = self.rmsnorm::<3>(
let x = rmsnorm_gemma::<3>(
builder,
self.config.rms_norm_eps,
p.extend(["input_layernorm"]).unwrap(),
@@ -375,22 +383,22 @@ impl Gemma3Model {
p.extend(["self_attn"]).unwrap(),
x,
);
let x = self.rmsnorm::<3>(
let x = rmsnorm_gemma::<3>(
builder,
self.config.rms_norm_eps,
p.extend(["post_attention_layernorm"]).unwrap(),
x,
);
let x = res + x;
let res = x.clone();
let x = self.rmsnorm::<3>(
let x = rmsnorm_gemma::<3>(
builder,
self.config.rms_norm_eps,
p.extend(["pre_feedforward_layernorm"]).unwrap(),
x,
);
let x = self.mlp(builder, p.extend(["mlp"]).unwrap(), x);
let x = self.rmsnorm::<3>(
let x = rmsnorm_gemma::<3>(
builder,
self.config.rms_norm_eps,
p.extend(["post_feedforward_layernorm"]).unwrap(),
@@ -433,7 +441,7 @@ impl Module<1, 1> for Gemma3Model {
);
}

x = self.rmsnorm::<3>(
x = rmsnorm_gemma::<3>(
builder,
self.config.rms_norm_eps,
root.extend(["norm"]).unwrap(),
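Taken together, the hunks above wire each Gemma 3 decoder layer as a "sandwich" of norms: every sub-block is normalized both before and after, and only then added back to the residual stream, and the attention hunk additionally applies the same RMSNorm to the query and key projections (q_norm / k_norm). A compact sketch of that order of operations, with closures standing in for the real graph ops (names and the scalar stand-in are illustrative only):

// Sketch only: control flow of one Gemma 3 decoder layer as assembled above;
// `norm`, `attn` and `mlp` are placeholders, not the project's API.
fn decoder_layer_order(x: f32) -> f32 {
    let norm = |v: f32| v; // rmsnorm_gemma with the respective *_layernorm weights
    let attn = |v: f32| v; // self_attn
    let mlp = |v: f32| v;  // mlp

    let res = x;
    let x = res + norm(attn(norm(x))); // input_layernorm -> attention -> post_attention_layernorm, then residual
    let res = x;
    res + norm(mlp(norm(x)))           // pre_feedforward_layernorm -> mlp -> post_feedforward_layernorm, then residual
}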
2 changes: 1 addition & 1 deletion catgrad-llm/src/utils.rs
@@ -242,7 +242,7 @@ pub fn get_model(
max_sequence_length,
}))
}
"LlamaForCausalLM" => Ok(Box::new(llama::LlamaModel {
"MistralForCausalLM" | "LlamaForCausalLM" => Ok(Box::new(llama::LlamaModel {
config,
max_sequence_length,
})),
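The utils.rs change routes checkpoints whose config reports MistralForCausalLM through the existing llama::LlamaModel; this presumably relies on Mistral's decoder stack being Llama-compatible for this loader's purposes (RMSNorm, rotary embeddings, gated MLP, grouped-query attention), with the remaining differences carried by configuration values rather than by a separate model implementation.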