diff --git a/catgrad-llm/src/models/gemma3.rs b/catgrad-llm/src/models/gemma3.rs index bf90fd8..630f384 100644 --- a/catgrad-llm/src/models/gemma3.rs +++ b/catgrad-llm/src/models/gemma3.rs @@ -146,7 +146,47 @@ pub struct Gemma3Model { pub max_sequence_length: usize, } +// Gemma uses a non-standard RMSNorm implementation. +// Generic because of unpack needing the last dimension and it is being called +// with ranks 2 and 3 too. +fn rmsnorm_raw_gemma(builder: &Builder, eps: f32, x: Var) -> Var { + let x_shape = shape(builder, x.clone()); + let u = unpack::(builder, x_shape.clone()); + let n = u[N - 1].clone(); + let s = sum(builder, x.clone() * x.clone()); + + let constn = nat_to_u32(builder, n); + let constn = cast(builder, constn, dtype(builder, x.clone())); + let sh = shape(builder, s.clone()); + let constn = broadcast(builder, constn, sh); + + let mean = s / constn; + + let epsilon = constant(builder, eps, &shape(builder, mean.clone())); + let rms = sqrt(builder, mean + epsilon); + let denom = broadcast(builder, rms, x_shape); + x / denom +} + +pub fn rmsnorm_gemma(builder: &Builder, eps: f32, p: Path, x: Var) -> Var { + let gamma = param(builder, &p.extend(["weight"]).unwrap()); + let lr = rmsnorm_raw_gemma::(builder, eps, x); + let lr_shape = shape(builder, lr.clone()); + let gamma = broadcast(builder, gamma, lr_shape); + let sh = shape(builder, gamma.clone()); + let one = constant(builder, 1.0, &sh); + lr * (one + gamma) +} + impl Gemma3Model { + pub fn new(root: &str, config: GemmaTextConfig, max_sequence_length: usize) -> Self { + Gemma3Model { + root: root.to_string(), + config, + max_sequence_length, + } + } + fn softcap(&self, builder: &Builder, softcap: f32, x: Var) -> Var { let sh = shape(builder, x.clone()); let s = constant(builder, softcap, &sh); @@ -179,38 +219,6 @@ impl Gemma3Model { ) } - // Gemma uses a non-standard RMSNorm implementation. - // Generic because of unpack needing the last dimension and it is being called - // with ranks 2 and 3 too. - pub fn rmsnorm_raw(&self, builder: &Builder, eps: f32, x: Var) -> Var { - let x_shape = shape(builder, x.clone()); - let u = unpack::(builder, x_shape.clone()); - let n = u[N - 1].clone(); - let s = sum(builder, x.clone() * x.clone()); - - let constn = nat_to_u32(builder, n); - let constn = cast(builder, constn, dtype(builder, x.clone())); - let sh = shape(builder, s.clone()); - let constn = broadcast(builder, constn, sh); - - let mean = s / constn; - - let epsilon = constant(builder, eps, &shape(builder, mean.clone())); - let rms = sqrt(builder, mean + epsilon); - let denom = broadcast(builder, rms, x_shape); - x / denom - } - - fn rmsnorm(&self, builder: &Builder, eps: f32, p: Path, x: Var) -> Var { - let gamma = param(builder, &p.extend(["weight"]).unwrap()); - let lr = self.rmsnorm_raw::(builder, eps, x); - let lr_shape = shape(builder, lr.clone()); - let gamma = broadcast(builder, gamma, lr_shape); - let sh = shape(builder, gamma.clone()); - let one = constant(builder, 1.0, &sh); - lr * (one + gamma) - } - fn attention( &self, builder: &Builder, @@ -279,13 +287,13 @@ impl Gemma3Model { let mut k = reshape(builder, sh, k); if is_gemma3 { - q = self.rmsnorm::<2>( + q = rmsnorm_gemma::<2>( builder, self.config.rms_norm_eps, p.extend(["q_norm"]).unwrap(), q, ); - k = self.rmsnorm::<2>( + k = rmsnorm_gemma::<2>( builder, self.config.rms_norm_eps, p.extend(["k_norm"]).unwrap(), @@ -361,7 +369,7 @@ impl Gemma3Model { x: Var, ) -> Var { let res = x.clone(); - let x = self.rmsnorm::<3>( + let x = rmsnorm_gemma::<3>( builder, self.config.rms_norm_eps, p.extend(["input_layernorm"]).unwrap(), @@ -375,7 +383,7 @@ impl Gemma3Model { p.extend(["self_attn"]).unwrap(), x, ); - let x = self.rmsnorm::<3>( + let x = rmsnorm_gemma::<3>( builder, self.config.rms_norm_eps, p.extend(["post_attention_layernorm"]).unwrap(), @@ -383,14 +391,14 @@ impl Gemma3Model { ); let x = res + x; let res = x.clone(); - let x = self.rmsnorm::<3>( + let x = rmsnorm_gemma::<3>( builder, self.config.rms_norm_eps, p.extend(["pre_feedforward_layernorm"]).unwrap(), x, ); let x = self.mlp(builder, p.extend(["mlp"]).unwrap(), x); - let x = self.rmsnorm::<3>( + let x = rmsnorm_gemma::<3>( builder, self.config.rms_norm_eps, p.extend(["post_feedforward_layernorm"]).unwrap(), @@ -433,7 +441,7 @@ impl Module<1, 1> for Gemma3Model { ); } - x = self.rmsnorm::<3>( + x = rmsnorm_gemma::<3>( builder, self.config.rms_norm_eps, root.extend(["norm"]).unwrap(), diff --git a/catgrad-llm/src/utils.rs b/catgrad-llm/src/utils.rs index ec6fc3f..69b9c20 100644 --- a/catgrad-llm/src/utils.rs +++ b/catgrad-llm/src/utils.rs @@ -242,7 +242,7 @@ pub fn get_model( max_sequence_length, })) } - "LlamaForCausalLM" => Ok(Box::new(llama::LlamaModel { + "MistralForCausalLM" | "LlamaForCausalLM" => Ok(Box::new(llama::LlamaModel { config, max_sequence_length, })),