diff --git a/catgrad-llm/examples/llama/main.rs b/catgrad-llm/examples/llama/main.rs
index 650f088..b80946a 100644
--- a/catgrad-llm/examples/llama/main.rs
+++ b/catgrad-llm/examples/llama/main.rs
@@ -2,11 +2,14 @@ use anyhow::Result;
 use catgrad::interpreter::backend::candle::CandleBackend;
 use catgrad::interpreter::backend::ndarray::NdArrayBackend;
 use catgrad::prelude::*;
+use chrono::Local;
 use clap::{Parser, ValueEnum};
+use minijinja::{Environment, context};
+use minijinja_contrib::pycompat::unknown_method_callback;
 use std::io::Write;
 use std::path::PathBuf;
 
-use catgrad_llm::utils::{get_model, load_model};
+use catgrad_llm::utils::{get_model, get_model_chat_template, load_model};
 
 #[derive(Parser, Debug)]
 struct Args {
@@ -23,6 +26,9 @@ struct Args {
     /// Initial prompt
     #[arg(short = 'p', long, default_value = "Category theory is")]
     prompt: String,
+    /// Pass raw prompt without chat template
+    #[arg(long)]
+    raw: bool,
     /// Tokens to generate
     #[arg(short = 's', long, default_value_t = 1)]
     max_seq_len: usize,
@@ -41,6 +47,13 @@ struct Args {
     /// Load model from a previously dumped JSON graph
     #[arg(long)]
     load: Option<PathBuf>,
+    /// Benchmark: prefill (PP) and generation (TG) token counts
+    #[arg(
+        long,
+        num_args = 2,
+        value_names = ["PP", "TG"]
+    )]
+    bench: Option<Vec<usize>>,
 }
 
 #[derive(Copy, Clone, Debug, ValueEnum)]
@@ -49,6 +62,10 @@ enum BackendChoice {
     Candle,
 }
 
+fn strftime_now(format_str: String) -> String {
+    Local::now().format(&format_str).to_string()
+}
+
 /// Construct, shapecheck, and interpret a given LLM using the selected backend.
 fn main() -> Result<()> {
     env_logger::init();
@@ -63,13 +80,43 @@ fn run_with_backend<B: interpreter::Backend>(args: &Args, backend: B) -> Result<
     let (parameter_values, parameter_types, config, tokenizer) =
         load_model(&args.model_name, &args.revision, &backend)?;
 
+    let chat_template =
+        get_model_chat_template(&args.model_name, &args.revision).unwrap_or_default();
+
+    let benchmarking = args.bench.is_some();
+    let mut pp = 0;
+    let mut tg = 0;
+    let mut max_seq_len = args.max_seq_len;
+
+    let prompt = if let Some(bench) = &args.bench {
+        pp = bench[0];
+        tg = bench[1];
+        max_seq_len = tg;
+        println!(
+            "Benchmarking {} with prefill size {} and sequence length {}",
+            args.model_name, pp, tg
+        );
+        "The".repeat(pp)
+    } else if chat_template.is_empty() || args.raw {
+        args.prompt.clone()
+    } else {
+        let mut env = Environment::new();
+        env.set_unknown_method_callback(unknown_method_callback);
+        env.add_function("strftime_now", strftime_now);
+        env.add_template("chat", &chat_template).unwrap();
+        let tmpl = env.get_template("chat").unwrap();
+        tmpl.render(
+            context!(messages => vec![ context!(role => "user",content => args.prompt)], add_generation_prompt => true),
+        )?
+    };
+
     let encoding = tokenizer
-        .encode(args.prompt.clone(), true)
+        .encode(prompt.clone(), true)
         .map_err(|err| anyhow::anyhow!("check error {:?}", err))?;
 
     let mut token_ids = encoding.get_ids().to_vec();
 
-    let max_sequence_length = args.max_seq_len + token_ids.len();
+    let max_sequence_length = max_seq_len + token_ids.len();
     let model = get_model(&config, max_sequence_length)?;
 
     let typed_term = if let Some(load_path) = &args.load {
@@ -101,29 +148,52 @@ fn run_with_backend<B: interpreter::Backend>(args: &Args, backend: B) -> Result<
             .map_err(|err| anyhow::anyhow!("check error {:?}", err))?;
     }
 
-    print!("{}", args.prompt);
-    let start_gen = std::time::Instant::now();
+    let mut generated_tokens = 0;
+    if !benchmarking {
+        print!("{}", prompt);
+    }
+    let mut start_gen = std::time::Instant::now();
+    let mut elapsed_pp = std::time::Duration::ZERO;
 
     let interpreter = interpreter::Interpreter::new(backend, env, parameter_values);
     // Run interpreter
-    for _ in 0..args.max_seq_len {
+    for i in 0..max_seq_len {
         let next_token_id = run_interpreter(&typed_term, &interpreter, &token_ids)?;
-        if config.get_eos_token_ids().contains(&(next_token_id as i32)) {
+        if i == 0 {
+            elapsed_pp = start_gen.elapsed();
+            start_gen = std::time::Instant::now();
+        }
+        generated_tokens += 1;
+        if config.get_eos_token_ids().contains(&(next_token_id as i32)) && !benchmarking {
            break;
        }
-        let decoded_token = tokenizer.decode(&[next_token_id], false).unwrap();
         token_ids.push(next_token_id);
-        print!("{}", decoded_token);
-        std::io::stdout().flush()?;
+        if !benchmarking {
+            let decoded_token = tokenizer.decode(&[next_token_id], false).unwrap();
+            print!("{}", decoded_token);
+            std::io::stdout().flush()?;
+        }
     }
     let elapsed_gen = start_gen.elapsed();
-    let generated_tokens = args.max_seq_len;
     println!(
         "\n{} tokens generated in {} seconds. ({:.2} tps)",
         generated_tokens,
-        elapsed_gen.as_secs(),
-        generated_tokens as f64 / elapsed_gen.as_secs_f64(),
+        (elapsed_pp + elapsed_gen).as_secs(),
+        generated_tokens as f64 / (elapsed_pp + elapsed_gen).as_secs_f64(),
     );
+
+    if benchmarking {
+        println!(
+            "PP {pp} in {} ms {:.2} tps",
+            elapsed_pp.as_millis(),
+            pp as f64 / elapsed_pp.as_secs_f64()
+        );
+        println!(
+            "TG {tg} in {} ms {:.2} tps",
+            elapsed_gen.as_millis(),
+            tg as f64 / elapsed_gen.as_secs_f64()
+        );
+    }
 
     Ok(())
 }
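A note on the chat-template path added above: when the model ships a chat template and `--raw` is not passed, the prompt is rendered through minijinja before tokenization. The sketch below shows that rendering step in isolation under stated assumptions: the template string, message content, and error handling are illustrative stand-ins (the example binary fetches the real template with `get_model_chat_template`), and it omits the `minijinja_contrib::pycompat::unknown_method_callback` hook the diff installs so Python-style string methods used by many Hugging Face templates keep working.

```rust
use chrono::Local;
use minijinja::{context, Environment};

/// Same helper the diff registers: lets a template call strftime_now("<fmt>")
/// to embed the current date/time (some chat templates use it for a dated
/// system prompt).
fn strftime_now(format_str: String) -> String {
    Local::now().format(&format_str).to_string()
}

/// Hypothetical stand-in for a model's chat template.
const CHAT_TEMPLATE: &str = "{% for m in messages %}<|{{ m.role }}|>{{ m.content }}\n{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}";

fn main() -> Result<(), minijinja::Error> {
    let mut env = Environment::new();
    env.add_function("strftime_now", strftime_now);
    env.add_template("chat", CHAT_TEMPLATE)?;

    // Render a single-turn conversation into the string that gets tokenized;
    // add_generation_prompt appends the assistant header the model completes.
    let prompt = env.get_template("chat")?.render(context!(
        messages => vec![context!(role => "user", content => "Category theory is")],
        add_generation_prompt => true
    ))?;
    println!("{prompt}");
    Ok(())
}
```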
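On the `--bench` path, the loop charges the first `run_interpreter` call, which processes the whole prompt, to prefill (PP) and the remaining single-token steps to generation (TG), then reports throughput for each. The toy sketch below uses made-up sleep durations in place of model work; it only illustrates the clock-restart pattern and the tokens-per-second arithmetic, not the example binary itself.

```rust
use std::thread::sleep;
use std::time::{Duration, Instant};

fn main() {
    // Made-up sizes and delays standing in for real model work.
    let (pp, tg) = (512usize, 128usize);
    let fake_prefill = || sleep(Duration::from_millis(40)); // one pass over the whole prompt
    let fake_decode = || sleep(Duration::from_millis(5)); // one pass per generated token

    // Same pattern as the diff: restart the clock after the first iteration so
    // prompt processing (PP) and token generation (TG) are timed separately.
    let mut start = Instant::now();
    let mut elapsed_pp = Duration::ZERO;
    for i in 0..tg {
        if i == 0 {
            fake_prefill();
            elapsed_pp = start.elapsed();
            start = Instant::now();
        } else {
            fake_decode();
        }
    }
    let elapsed_tg = start.elapsed();

    println!(
        "PP {pp} in {} ms ({:.2} tok/s)",
        elapsed_pp.as_millis(),
        pp as f64 / elapsed_pp.as_secs_f64()
    );
    println!(
        "TG {tg} in {} ms ({:.2} tok/s)",
        elapsed_tg.as_millis(),
        tg as f64 / elapsed_tg.as_secs_f64()
    );
}
```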