diff --git a/catgrad-llm/examples/siglip/main.rs b/catgrad-llm/examples/siglip/main.rs index 756d89a..eedd70f 100644 --- a/catgrad-llm/examples/siglip/main.rs +++ b/catgrad-llm/examples/siglip/main.rs @@ -575,10 +575,7 @@ pub fn main() -> Result<(), Box> { let interp = Interpreter::new(backend, env, parameters); - let results = interp.eval( - interp.environment.to_core(typed_term.term), - vec![input_tensor, image_tensor], - )?; + let results = interp.run(typed_term.term, vec![input_tensor, image_tensor])?; let result_tensor = match &results[1] { interpreter::Value::Tensor(t) => t, _ => panic!("Expected tensor output"), diff --git a/catgrad-llm/src/models/gemma3.rs b/catgrad-llm/src/models/gemma3.rs index 630f384..c32499e 100644 --- a/catgrad-llm/src/models/gemma3.rs +++ b/catgrad-llm/src/models/gemma3.rs @@ -11,6 +11,7 @@ pub enum GemmaConfig { VLM { text_config: GemmaTextConfig, image_token_index: usize, + #[serde(default)] mm_tokens_per_image: usize, }, #[serde(untagged)] diff --git a/catgrad-llm/src/models/llama.rs b/catgrad-llm/src/models/llama.rs index f03a365..88102b1 100644 --- a/catgrad-llm/src/models/llama.rs +++ b/catgrad-llm/src/models/llama.rs @@ -4,6 +4,7 @@ use catgrad::prelude::ops::*; use catgrad::prelude::*; use nn::*; pub struct LlamaModel { + pub root: String, pub config: Config, pub max_sequence_length: usize, } @@ -65,8 +66,8 @@ impl LlamaModel { let sh = shape!(builder, b, s, num_kv_heads, head_dim); let k = reshape(builder, sh.clone(), k); - let v = reshape(builder, sh, v); + let sh = shape!(builder, b, s, num_heads, head_dim); let q = reshape(builder, sh, q); @@ -157,7 +158,12 @@ impl Module<1, 1> for LlamaModel { } fn def(&self, builder: &Builder, [x]: [Var; 1]) -> [Var; 1] { - let root = self.path(); + let mut root = self.path(); + if !self.root.is_empty() { + root = root + .extend(self.root.split('.').collect::>()) + .unwrap(); + } let mut cache = Cache::init(builder, &self.config, self.max_sequence_length); diff --git a/catgrad-llm/src/utils.rs b/catgrad-llm/src/utils.rs index 306cd5a..32ea9d6 100644 --- a/catgrad-llm/src/utils.rs +++ b/catgrad-llm/src/utils.rs @@ -243,6 +243,7 @@ pub fn get_model( )) } "MistralForCausalLM" | "LlamaForCausalLM" => Box::new(llama::LlamaModel { + root: "".to_string(), config: config.clone(), max_sequence_length, }),