8 changes: 5 additions & 3 deletions catgrad-llm/examples/llama/main.rs
@@ -9,7 +9,7 @@ use minijinja_contrib::pycompat::unknown_method_callback;
use std::io::Write;
use std::path::PathBuf;

use catgrad_llm::utils::{get_model, get_model_chat_template, load_model};
use catgrad_llm::utils::{get_model, get_model_chat_template, load_model, parse_config};

#[derive(Parser, Debug)]
struct Args {
@@ -77,9 +77,11 @@ fn main() -> Result<()> {
}

fn run_with_backend<B: interpreter::Backend>(args: &Args, backend: B) -> Result<()> {
let (parameter_values, parameter_types, config, tokenizer, total_params) =
let (parameter_values, parameter_types, config_json, tokenizer, total_params) =
load_model(&args.model_name, &args.revision, &backend)?;

let config = parse_config(&config_json)?;

let chat_template =
get_model_chat_template(&args.model_name, &args.revision).unwrap_or_default();

@@ -117,7 +119,7 @@ fn run_with_backend<B: interpreter::Backend>(args: &Args, backend: B) -> Result<
let mut token_ids = encoding.get_ids().to_vec();

let max_sequence_length = max_seq_len + token_ids.len();
let model = get_model(&config, max_sequence_length)?;
let model = get_model(&config_json, max_sequence_length)?;

let typed_term = if let Some(load_path) = &args.load {
let file = std::fs::File::open(load_path)?;
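In main.rs, load_model now returns the raw configuration JSON instead of a parsed Config; the caller turns it into a typed Config with parse_config and still passes the raw JSON to get_model. The body of parse_config is not shown in this diff; assuming the raw config is carried as a JSON string and the function simply deserializes it into Config (which derives Deserialize), a minimal sketch could look like this (the signature and body here are assumptions, not the PR's actual code):

use anyhow::Result;
use catgrad_llm::config::Config;

// Hypothetical sketch only: deserialize the raw config.json contents into the
// Config struct. Serde ignores unknown fields by default, so one struct can
// cover the union of fields across all supported models.
pub fn parse_config(config_json: &str) -> Result<Config> {
    Ok(serde_json::from_str::<Config>(config_json)?)
}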
69 changes: 51 additions & 18 deletions catgrad-llm/src/config.rs
@@ -39,6 +39,27 @@ pub enum RopeScaling {
Yarn(YarnRopeScaling),
}

pub trait LLMConfig {
fn num_hidden_layers(&self) -> usize;
fn num_key_value_heads(&self) -> usize;
fn num_local_experts(&self) -> usize;
fn rope_theta(&self) -> f32;
fn rope_scaling(&self) -> Option<RopeScaling>;
fn partial_rotary_factor(&self) -> f32;

fn get_head_dim(&self) -> usize;
fn get_qk_head_dim(&self) -> usize;
fn get_v_head_dim(&self) -> usize;
fn eos_token_id(&self) -> Option<EosTokenId>;
fn get_eos_token_ids(&self) -> Vec<i32> {
match self.eos_token_id() {
Some(EosTokenId::Single(id)) => vec![id],
Some(EosTokenId::Multiple(ref ids)) => ids.clone(),
None => vec![],
}
}
}

// This configuration contains the union of relevant fields from all supported models.
// Models ignore fields they don't need. The aliases are for GPT-2 alternative names.
#[derive(Debug, Clone, Default, Deserialize)]
@@ -104,9 +125,33 @@ fn default_partial_rotary_factor() -> f32 {
1.0
}

impl Config {
impl LLMConfig for Config {
fn num_hidden_layers(&self) -> usize {
self.num_hidden_layers
}
fn num_key_value_heads(&self) -> usize {
if self.num_key_value_heads == 0 {
self.num_attention_heads
} else {
self.num_key_value_heads
}
}

fn num_local_experts(&self) -> usize {
self.num_local_experts
}
fn rope_theta(&self) -> f32 {
self.rope_theta
}
fn rope_scaling(&self) -> Option<RopeScaling> {
self.rope_scaling.clone()
}
fn partial_rotary_factor(&self) -> f32 {
self.partial_rotary_factor
}

// Sometimes the head_dim field is missing
pub fn get_head_dim(&self) -> usize {
fn get_head_dim(&self) -> usize {
if self.qk_rope_head_dim != 0 {
self.qk_rope_head_dim
} else if self.head_dim == 0 {
@@ -117,7 +162,7 @@ impl Config {
}

// DeepSeek Multihead Latent Attention uses different head dimensions for queries, keys and values
pub fn get_qk_head_dim(&self) -> usize {
fn get_qk_head_dim(&self) -> usize {
let qk_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim;
if qk_head_dim != 0 {
qk_head_dim
@@ -126,27 +171,15 @@ impl Config {
}
}

pub fn get_v_head_dim(&self) -> usize {
fn get_v_head_dim(&self) -> usize {
if self.v_head_dim != 0 {
self.v_head_dim
} else {
self.get_head_dim()
}
}

pub fn get_num_kv_heads(&self) -> usize {
if self.num_key_value_heads == 0 {
self.num_attention_heads
} else {
self.num_key_value_heads
}
}

pub fn get_eos_token_ids(&self) -> Vec<i32> {
match self.eos_token_id {
Some(EosTokenId::Single(id)) => vec![id],
Some(EosTokenId::Multiple(ref ids)) => ids.clone(),
None => vec![],
}
fn eos_token_id(&self) -> Option<EosTokenId> {
self.eos_token_id.clone()
}
}
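Moving the accessors from inherent methods on Config into the LLMConfig trait (with get_eos_token_ids as a provided method, and get_num_kv_heads renamed to num_key_value_heads while keeping the fallback to num_attention_heads) lets downstream code depend on &dyn LLMConfig rather than the concrete struct, as Cache::init in helpers.rs below does. A small usage sketch, not taken from this PR:

use catgrad_llm::config::LLMConfig;

// Illustrative helper only: check a sampled token against the model's EOS set.
// get_eos_token_ids is a provided trait method, so any implementor gets it for free.
fn is_eos(config: &dyn LLMConfig, token_id: i32) -> bool {
    config.get_eos_token_ids().contains(&token_id)
}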
35 changes: 23 additions & 12 deletions catgrad-llm/src/helpers.rs
@@ -1,4 +1,4 @@
use crate::config::{Config, Llama3RopeScaling, RopeScaling, YarnRopeScaling};
use crate::config::{LLMConfig, Llama3RopeScaling, RopeScaling, YarnRopeScaling};
use catgrad::prelude::ops::*;
use catgrad::prelude::*;
use catgrad::stdlib::nn::*;
@@ -30,27 +30,28 @@ pub struct Cache {
}

impl Cache {
pub fn init(builder: &Builder, config: &Config, positions: usize) -> Self {
let (cos, sin) = match &config.rope_scaling {
pub fn init(builder: &Builder, config: &dyn LLMConfig, positions: usize) -> Self {
let (cos, sin) = match config.rope_scaling() {
Some(RopeScaling::Yarn(params)) => rope_tables_yarn(
builder,
config.rope_theta,
params,
config.rope_theta(),
&params,
positions,
config.get_head_dim(),
),
Some(RopeScaling::Llama3(params)) => rope_tables_llama3(
builder,
config.rope_theta,
params,
config.rope_theta(),
&params,
positions,
config.get_head_dim(),
),
_ => rope_tables(
builder,
config.rope_theta,
config.rope_theta(),
positions.to_nat(builder),
((config.get_head_dim() as f32) * config.partial_rotary_factor) as usize,
((config.get_head_dim() as f32) * config.partial_rotary_factor()) as usize,
1.0,
),
};

@@ -177,7 +178,13 @@ pub fn repeat_kv(builder: &Builder, rep: usize, x: Var) -> Var {
}

// Generate rope tables. This part is usually precomputed
pub fn rope_tables(builder: &Builder, theta: f32, seq_len: Var, head_dim: usize) -> (Var, Var) {
pub fn rope_tables(
builder: &Builder,
theta: f32,
seq_len: Var,
head_dim: usize,
factor: f32,
) -> (Var, Var) {
let half_dim = head_dim / 2;

let f = arange(builder, half_dim);
@@ -190,7 +197,10 @@ pub fn rope_tables(builder: &Builder, theta: f32, seq_len: Var, head_dim: usize)
let inv_freq = inverse(builder, freq);

let sh = shape!(builder, seq_len, half_dim);
let inv_freq = broadcast(builder, inv_freq, sh);
let inv_freq = broadcast(builder, inv_freq, sh.clone());

let factor = constant(builder, factor, &sh);
let inv_freq = inv_freq / factor;

let pos = arange(builder, seq_len.clone());
let pos = cast(builder, pos, Dtype::F32);
@@ -443,11 +453,12 @@ pub fn rope(
pos: impl IntoNatVar,
seq_len: &impl IntoNatVar,
head_dim: usize,
factor: f32,
x: Var,
) -> Var {
let pos = pos.to_nat(builder);
let seq_len = seq_len.to_nat(builder);
let (cos, sin) = rope_tables(builder, theta, pos.clone() + seq_len, head_dim);
let (cos, sin) = rope_tables(builder, theta, pos.clone() + seq_len, head_dim, factor);

apply_rope_embedding(builder, pos, head_dim, cos, sin, x)
}
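The rope_tables and rope helpers gain a factor argument that divides the inverse frequencies before the position/frequency product; Cache::init passes 1.0, so the default behaviour is unchanged. As a standalone illustration of the arithmetic (plain Rust, no catgrad types; the function name and shapes are chosen for the example only):

// inv_freq[i] = theta^(-2i / head_dim) / factor, then angle = position * inv_freq[i];
// the cos/sin tables are the element-wise cosine and sine of those angles.
fn rope_tables_scaled(theta: f32, positions: usize, head_dim: usize, factor: f32) -> (Vec<Vec<f32>>, Vec<Vec<f32>>) {
    let half_dim = head_dim / 2;
    let inv_freq: Vec<f32> = (0..half_dim)
        .map(|i| 1.0 / (theta.powf(2.0 * i as f32 / head_dim as f32) * factor))
        .collect();
    let (mut cos, mut sin) = (Vec::with_capacity(positions), Vec::with_capacity(positions));
    for p in 0..positions {
        cos.push(inv_freq.iter().map(|f| (p as f32 * f).cos()).collect());
        sin.push(inv_freq.iter().map(|f| (p as f32 * f).sin()).collect());
    }
    (cos, sin)
}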