96 changes: 83 additions & 13 deletions catgrad-llm/examples/llama/main.rs
@@ -2,11 +2,14 @@ use anyhow::Result;
use catgrad::interpreter::backend::candle::CandleBackend;
use catgrad::interpreter::backend::ndarray::NdArrayBackend;
use catgrad::prelude::*;
use chrono::Local;
use clap::{Parser, ValueEnum};
use minijinja::{Environment, context};
use minijinja_contrib::pycompat::unknown_method_callback;
use std::io::Write;
use std::path::PathBuf;

use catgrad_llm::utils::{get_model, load_model};
use catgrad_llm::utils::{get_model, get_model_chat_template, load_model};

#[derive(Parser, Debug)]
struct Args {
@@ -23,6 +26,9 @@ struct Args {
/// Initial prompt
#[arg(short = 'p', long, default_value = "Category theory is")]
prompt: String,
/// Pass raw prompt without chat template
#[arg(long)]
raw: bool,
/// Tokens to generate
#[arg(short = 's', long, default_value_t = 1)]
max_seq_len: usize,
@@ -41,6 +47,13 @@ struct Args {
/// Load model from a previously dumped JSON graph
#[arg(long)]
load: Option<PathBuf>,
/// Benchmark
#[arg(
long,
num_args = 2,
value_names = ["PP", "TG"]
)]
bench: Option<Vec<usize>>,
}

#[derive(Copy, Clone, Debug, ValueEnum)]
@@ -49,6 +62,10 @@ enum BackendChoice {
Candle,
}

fn strftime_now(format_str: String) -> String {
Local::now().format(&format_str).to_string()
}

/// Construct, shapecheck, and interpret a given LLM using the selected backend.
fn main() -> Result<()> {
env_logger::init();
@@ -63,13 +80,43 @@ fn run_with_backend<B: interpreter::Backend>(args: &Args, backend: B) -> Result<
let (parameter_values, parameter_types, config, tokenizer) =
load_model(&args.model_name, &args.revision, &backend)?;

let chat_template =
get_model_chat_template(&args.model_name, &args.revision).unwrap_or_default();

let benchmarking = args.bench.is_some();
let mut pp = 0;
let mut tg = 0;
let mut max_seq_len = args.max_seq_len;

let prompt = if let Some(bench) = &args.bench {
pp = bench[0];
tg = bench[1];
max_seq_len = tg;
println!(
"Benchmarking {} with prefill size {} and sequence length {}",
args.model_name, pp, tg
);
"The".repeat(pp)
} else if chat_template.is_empty() || args.raw {
args.prompt.clone()
} else {
let mut env = Environment::new();
env.set_unknown_method_callback(unknown_method_callback);
env.add_function("strftime_now", strftime_now);
env.add_template("chat", &chat_template).unwrap();
let tmpl = env.get_template("chat").unwrap();
tmpl.render(
context!(messages => vec![context!(role => "user", content => args.prompt)], add_generation_prompt => true),
)?
};

let encoding = tokenizer
.encode(args.prompt.clone(), true)
.encode(prompt.clone(), true)
.map_err(|err| anyhow::anyhow!("check error {:?}", err))?;

let mut token_ids = encoding.get_ids().to_vec();

let max_sequence_length = args.max_seq_len + token_ids.len();
let max_sequence_length = max_seq_len + token_ids.len();
let model = get_model(&config, max_sequence_length)?;

let typed_term = if let Some(load_path) = &args.load {
@@ -101,29 +148,52 @@ fn run_with_backend<B: interpreter::Backend>(args: &Args, backend: B) -> Result<
.map_err(|err| anyhow::anyhow!("check error {:?}", err))?;
}

print!("{}", args.prompt);
let start_gen = std::time::Instant::now();
let mut generated_tokens = 0;
if !benchmarking {
print!("{}", prompt);
}
let mut start_gen = std::time::Instant::now();
let mut elapsed_pp = std::time::Duration::ZERO;
let interpreter = interpreter::Interpreter::new(backend, env, parameter_values);
// Run interpreter
for _ in 0..args.max_seq_len {
for i in 0..max_seq_len {
let next_token_id = run_interpreter(&typed_term, &interpreter, &token_ids)?;
if config.get_eos_token_ids().contains(&(next_token_id as i32)) {
if i == 0 {
elapsed_pp = start_gen.elapsed();
start_gen = std::time::Instant::now();
}
generated_tokens += 1;
if config.get_eos_token_ids().contains(&(next_token_id as i32)) && !benchmarking {
break;
}
let decoded_token = tokenizer.decode(&[next_token_id], false).unwrap();
token_ids.push(next_token_id);
print!("{}", decoded_token);
std::io::stdout().flush()?;
if !benchmarking {
let decoded_token = tokenizer.decode(&[next_token_id], false).unwrap();
print!("{}", decoded_token);
std::io::stdout().flush()?;
}
}

let elapsed_gen = start_gen.elapsed();
let generated_tokens = args.max_seq_len;
println!(
"\n{} tokens generated in {} seconds. ({:.2} tps)",
generated_tokens,
elapsed_gen.as_secs(),
generated_tokens as f64 / elapsed_gen.as_secs_f64(),
(elapsed_pp + elapsed_gen).as_secs(),
generated_tokens as f64 / (elapsed_pp + elapsed_gen).as_secs_f64(),
);

if benchmarking {
println!(
"PP {pp} in {} ms {:.2} tps",
elapsed_pp.as_millis(),
pp as f64 / elapsed_pp.as_secs_f64()
);
println!(
"TG {tg} in {} ms {:.2} tps",
elapsed_gen.as_millis(),
tg as f64 / elapsed_gen.as_secs_f64()
);
}
Ok(())
}
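
The chat-template path is the main new behaviour in this diff. For reference, here is the same idea restated as a standalone helper, assuming the minijinja, minijinja-contrib, and chrono dependencies the diff adds; the function name `render_prompt`, its error type, and passing the template string as an argument are illustrative choices, not code from the PR.

```rust
use chrono::Local;
use minijinja::{Environment, context};
use minijinja_contrib::pycompat::unknown_method_callback;

// Some Hugging Face chat templates call strftime_now(format), so the helper is
// exposed to the template environment, as in the diff.
fn strftime_now(format_str: String) -> String {
    Local::now().format(&format_str).to_string()
}

// Render a single-turn "user" message through a Jinja-style chat template.
// In the example binary the template comes from `get_model_chat_template`.
fn render_prompt(chat_template: &str, prompt: &str) -> Result<String, minijinja::Error> {
    let mut env = Environment::new();
    // Tolerate Python-style method calls (e.g. `.strip()`, `.items()`) that
    // Hugging Face templates commonly use.
    env.set_unknown_method_callback(unknown_method_callback);
    env.add_function("strftime_now", strftime_now);
    env.add_template("chat", chat_template)?;
    let tmpl = env.get_template("chat")?;
    tmpl.render(context!(
        messages => vec![context!(role => "user", content => prompt)],
        add_generation_prompt => true
    ))
}
```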

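The `--bench PP TG` output splits wall time into a prefill phase (time to the first generated token, reported as PP) and a generation phase (TG), and reports tokens per second for each. A minimal sketch of that bookkeeping, with struct and method names that are illustrative rather than taken from the PR:

```rust
use std::time::{Duration, Instant};

/// Tracks prefill vs. generation time across a token-generation loop.
struct BenchTimer {
    start: Instant,
    prefill: Duration,
}

impl BenchTimer {
    fn new() -> Self {
        Self { start: Instant::now(), prefill: Duration::ZERO }
    }

    /// Call once, immediately after the first token is produced: record the
    /// prefill duration and restart the clock for the generation phase.
    fn mark_prefill_done(&mut self) {
        self.prefill = self.start.elapsed();
        self.start = Instant::now();
    }

    /// Returns (prefill tokens/s, generation tokens/s), mirroring the PP/TG
    /// figures printed in benchmark mode.
    fn throughput(&self, prompt_tokens: usize, generated_tokens: usize) -> (f64, f64) {
        let gen = self.start.elapsed();
        (
            prompt_tokens as f64 / self.prefill.as_secs_f64(),
            generated_tokens as f64 / gen.as_secs_f64(),
        )
    }
}
```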