diff --git a/candle-holder-examples/examples/chat/README.md b/candle-holder-examples/examples/chat/README.md
new file mode 100644
index 0000000..807e1cc
--- /dev/null
+++ b/candle-holder-examples/examples/chat/README.md
@@ -0,0 +1,7 @@
+# Chat
+
+## Running the example
+
+```bash
+cargo run --example chat --features cuda,flash-attn -- --device cuda:0 --system-prompt "You are a helpful assistant."
+```
diff --git a/candle-holder-examples/examples/chat/main.rs b/candle-holder-examples/examples/chat/main.rs
new file mode 100644
index 0000000..337fae6
--- /dev/null
+++ b/candle-holder-examples/examples/chat/main.rs
@@ -0,0 +1,153 @@
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+use std::{
+    io::{self, Write},
+    sync::{mpsc, Arc, Mutex},
+};
+
+use anyhow::{Error, Result};
+use candle_holder_examples::Cli;
+use candle_holder_models::{
+    AutoModelForCausalLM, GenerationConfig, GenerationParams, TextIteratorStreamer, TokenStreamer,
+};
+use candle_holder_tokenizers::{AutoTokenizer, Message};
+use clap::Parser;
+
+#[derive(Debug, Parser)]
+pub struct GenerationCli {
+    #[command(flatten)]
+    pub base: Cli,
+
+    #[arg(long, default_value = "meta-llama/Meta-Llama-3.1-8B-Instruct")]
+    pub model: String,
+
+    #[arg(short, long)]
+    pub do_sample: bool,
+
+    #[arg(long, default_value = "0.6")]
+    pub temperature: f64,
+
+    #[arg(long, default_value = "0.9")]
+    pub top_p: f32,
+
+    #[arg(long, default_value = "50")]
+    pub top_k: usize,
+
+    #[arg(long)]
+    pub system_prompt: Option<String>,
+}
+
+fn main() -> Result<()> {
+    let args = GenerationCli::parse();
+
+    let device = args.base.get_device()?;
+    println!("Model: {}", args.model);
+    println!("Device: {:?}", device);
+
+    // Load the model and the tokenizer
+    let tokenizer = AutoTokenizer::from_pretrained(args.model.clone(), None, None)?;
+    let model = AutoModelForCausalLM::from_pretrained(args.model, &device, None, None)?;
+
+    // Create a token streamer to stream the tokens generated by the model
+    let (token_streamer, output_receiver) =
+        TextIteratorStreamer::new(tokenizer.clone(), true, true);
+    let token_streamer: Arc<Mutex<dyn TokenStreamer>> = Arc::new(Mutex::new(token_streamer));
+
+    // Run the model generation loop in a background thread
+    let (sender, receiver) = mpsc::channel::<Option<String>>();
+    let generation_handle = std::thread::spawn(move || {
+        let mut messages = Vec::new();
+
+        if let Some(system_prompt) = args.system_prompt {
+            messages.push(Message::system(system_prompt));
+        }
+
+        while let Ok(message) = receiver.recv() {
+            if message.is_none() {
+                println!("Stopping generation loop...");
+                break;
+            }
+
+            let prompt = message.unwrap();
+            messages.push(Message::user(prompt));
+
+            let mut encodings = tokenizer
+                .apply_chat_template_and_encode(messages.clone(), true)
+                .map_err(Error::msg)
+                .unwrap();
+            encodings.to_device(&device).unwrap();
+
+            let input_ids = encodings.get_input_ids();
+
+            let params = GenerationParams::new()
+                .with_generation_config(GenerationConfig {
+                    do_sample: args.do_sample,
+                    top_p: Some(args.top_p),
+                    top_k: Some(args.top_k),
+                    temperature: args.temperature,
+                    max_new_tokens: Some(2048),
+                    ..GenerationConfig::default()
+                })
+                .with_tokenizer(tokenizer.clone())
+                .with_token_streamer(token_streamer.clone());
+
+            let output = model.generate(input_ids, params).unwrap();
+
+            let inputs_prompt_length: usize = input_ids
+                .to_vec2::<u32>()
+                .unwrap()
+                .first()
+                .map(|seq_input_ids| tokenizer.decode(&seq_input_ids[..], true).unwrap().chars().count())
+                .unwrap_or(0);
+            let sequence = &output[0].get_sequences()[0];
+            let system_message: String = tokenizer
+                .decode(&sequence, true)
+                .unwrap()
+                .chars()
+                .skip(inputs_prompt_length)
+                .collect();
+            messages.push(Message::system(system_message));
+        }
+    });
+
+    // User input loop
+    loop {
+        // Read user input
+        let mut input = String::new();
+        print!("> ");
+        std::io::stdout().flush().unwrap();
+        std::io::stdin().read_line(&mut input).unwrap();
+        let input = input.trim().to_string();
+
+        if input.is_empty() {
+            continue;
+        }
+
+        if input.to_lowercase() == "/quit" || input.to_lowercase() == "/exit" {
+            sender.send(None)?;
+            break;
+        }
+
+        // Send the user message to the background thread
+        sender.send(Some(input))?;
+
+        // Print the new tokens generated by the model in the background thread
+        while let Ok(message) = output_receiver.recv() {
+            if let Some(text) = message {
+                print!("{}", text);
+                io::stdout().flush().unwrap();
+            } else {
+                println!();
+                break;
+            }
+        }
+    }
+
+    generation_handle.join().unwrap();
+
+    Ok(())
+}
diff --git a/candle-holder-models/src/generation/generate.rs b/candle-holder-models/src/generation/generate.rs
index 8846152..07da563 100644
--- a/candle-holder-models/src/generation/generate.rs
+++ b/candle-holder-models/src/generation/generate.rs
@@ -1,4 +1,4 @@
-use std::sync::Arc;
+use std::sync::{Arc, Mutex};
 
 use candle_core::{IndexOp, Tensor};
 use candle_holder::{Error, Result};
@@ -39,7 +39,7 @@ pub fn generate(
     generation_config: &GenerationConfig,
     tokenizer: Option<Arc<dyn Tokenizer>>,
     stopping_criteria: Option<Vec<Box<dyn StoppingCriteria>>>,
-    mut token_streamer: Option<Box<dyn TokenStreamer>>,
+    mut token_streamer: Option<Arc<Mutex<dyn TokenStreamer>>>,
     seed: Option<u64>,
 ) -> Result<Vec<GenerateOutput>> {
     let num_return_sequences = generation_config.get_num_return_sequences().max(1);
@@ -143,18 +143,18 @@
 }
 
 fn stream_tokens(
-    token_streamer: &mut Option<Box<dyn TokenStreamer>>,
+    token_streamer: &mut Option<Arc<Mutex<dyn TokenStreamer>>>,
     tokens: &[Vec<u32>],
 ) -> Result<()> {
-    if let Some(streamer) = token_streamer.as_mut() {
-        streamer.put(tokens)?;
+    if let Some(streamer) = token_streamer {
+        streamer.lock().unwrap().put(tokens)?;
     }
     Ok(())
 }
 
-fn stream_end(token_streamer: &mut Option<Box<dyn TokenStreamer>>) -> Result<()> {
+fn stream_end(token_streamer: &mut Option<Arc<Mutex<dyn TokenStreamer>>>) -> Result<()> {
     if let Some(streamer) = token_streamer.as_mut() {
-        streamer.end()?;
+        streamer.lock().unwrap().end()?;
     }
     Ok(())
 }
diff --git a/candle-holder-models/src/generation/token_streamer.rs b/candle-holder-models/src/generation/token_streamer.rs
index 438df1c..89b2378 100644
--- a/candle-holder-models/src/generation/token_streamer.rs
+++ b/candle-holder-models/src/generation/token_streamer.rs
@@ -1,6 +1,9 @@
 use candle_holder::{Error, Result};
 use candle_holder_tokenizers::Tokenizer;
-use std::{io::{self, Write}, sync::{mpsc, Arc}};
+use std::{
+    io::{self, Write},
+    sync::{mpsc, Arc},
+};
 
 /// A trait for streamers that receives tokens generated by `generate` method using an
 /// auto-regressive model. Streamers can be used to handle tokens as they are generated by the
@@ -76,7 +79,6 @@ impl TextStreamer {
         Ok(self.process_text(&text))
     }
 
-
     fn process_text(&mut self, full_text: &str) -> Option<String> {
         let new_text = &full_text[self.print_len..];
 
@@ -132,11 +134,11 @@
 
     fn end(&mut self) -> Result<()> {
         let printable_text = self.flush_tokens()?;
+        self.next_tokens_are_prompt = false;
         self.print(printable_text, true)
     }
 }
 
-
 pub struct TextIteratorStreamer {
     text_streamer: TextStreamer,
     sender: mpsc::Sender<Option<String>>,
@@ -151,12 +153,12 @@ impl TextIteratorStreamer {
         let (sender, receiver) = mpsc::channel();
         let streamer = Self {
             text_streamer: TextStreamer::new(tokenizer, skip_prompt, skip_special_tokens),
-            sender
+            sender,
        };
         (streamer, receiver)
     }
 
-    fn send(&self, t: Option<String>) -> Result<()>{
+    fn send(&self, t: Option<String>) -> Result<()> {
         self.sender.send(t).map_err(|e| Error::msg(e.to_string()))
     }
 }
@@ -171,10 +173,9 @@
 
     fn end(&mut self) -> Result<()> {
         let text = self.text_streamer.flush_tokens()?;
-        if !text.is_empty() {
-            self.send(Some(text))?;
-        }
+        self.send(Some(text))?;
         self.send(None)?;
+        self.text_streamer.next_tokens_are_prompt = true;
         Ok(())
     }
 }
diff --git a/candle-holder-models/src/model.rs b/candle-holder-models/src/model.rs
index bad6db6..5de3c25 100644
--- a/candle-holder-models/src/model.rs
+++ b/candle-holder-models/src/model.rs
@@ -1,4 +1,4 @@
-use std::sync::Arc;
+use std::sync::{Arc, Mutex};
 
 use candle_core::{DType, Device, Tensor};
 use candle_holder::{utils::from_pretrained::FromPretrainedParameters, Error, Result};
@@ -103,11 +103,45 @@ pub struct GenerationParams {
     pub stopping_criteria: Option<Vec<Box<dyn StoppingCriteria>>>,
     /// The token streamer which will receive the next tokens as they are being generated. The
     /// default value is `None`.
-    pub token_streamer: Option<Box<dyn TokenStreamer>>,
+    pub token_streamer: Option<Arc<Mutex<dyn TokenStreamer>>>,
     /// A seed that will be used in the sampling of the next token. The default value is `None`.
     pub seed: Option<u64>,
 }
 
+impl GenerationParams {
+    pub fn new() -> Self {
+        Self::default()
+    }
+
+    pub fn with_generation_config(mut self, generation_config: GenerationConfig) -> Self {
+        self.generation_config = Some(generation_config);
+        self
+    }
+
+    pub fn with_tokenizer(mut self, tokenizer: Arc<dyn Tokenizer>) -> Self {
+        self.tokenizer = Some(tokenizer);
+        self
+    }
+
+    pub fn with_stopping_criteria(
+        mut self,
+        stopping_criteria: Vec<Box<dyn StoppingCriteria>>,
+    ) -> Self {
+        self.stopping_criteria = Some(stopping_criteria);
+        self
+    }
+
+    pub fn with_token_streamer(mut self, token_streamer: Arc<Mutex<dyn TokenStreamer>>) -> Self {
+        self.token_streamer = Some(token_streamer);
+        self
+    }
+
+    pub fn with_seed(mut self, seed: u64) -> Self {
+        self.seed = Some(seed);
+        self
+    }
+}
+
 impl Default for GenerationParams {
     fn default() -> Self {
         Self {
diff --git a/candle-holder-models/src/utils/var_builder.rs b/candle-holder-models/src/utils/var_builder.rs
index 25d5e5f..b462eff 100644
--- a/candle-holder-models/src/utils/var_builder.rs
+++ b/candle-holder-models/src/utils/var_builder.rs
@@ -68,9 +68,7 @@ impl CompatibilityTensorRetrievalBackend {
         }
 
         // Try removing the model name prefix
-        let without_prefix = name
-            .strip_prefix(&format!("{}.", self.model_name))
-            .unwrap_or(name);
+        let without_prefix = name.strip_prefix(&self.model_name).unwrap_or(name);
 
         // Function to replace weight/bias with beta/gamma
         let replace_weight_bias = |s: &str| s.replace("weight", "gamma").replace("bias", "beta");
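The move from `Box<dyn TokenStreamer>` to `Arc<Mutex<dyn TokenStreamer>>` is what lets the chat example hand the same streamer to `generate` on a background thread while the main thread drains the paired channel. Below is a minimal, self-contained sketch of that sharing pattern only; it uses a toy `Streamer` trait and `ChannelStreamer` type invented for illustration, not candle-holder's actual `TokenStreamer` API.

```rust
use std::sync::{mpsc, Arc, Mutex};
use std::thread;

// Toy stand-in for a token streamer trait; only the sharing pattern matters here.
trait Streamer: Send {
    fn put(&mut self, tokens: &[u32]);
    fn end(&mut self);
}

// Forwards token batches over an mpsc channel, ending the stream with `None`.
struct ChannelStreamer {
    sender: mpsc::Sender<Option<Vec<u32>>>,
}

impl Streamer for ChannelStreamer {
    fn put(&mut self, tokens: &[u32]) {
        let _ = self.sender.send(Some(tokens.to_vec()));
    }
    fn end(&mut self) {
        let _ = self.sender.send(None);
    }
}

fn main() {
    let (sender, receiver) = mpsc::channel();
    // Shared, mutable streamer: one clone is moved into the "generation" thread,
    // while the receiver half of the channel stays with the caller.
    let streamer: Arc<Mutex<dyn Streamer>> = Arc::new(Mutex::new(ChannelStreamer { sender }));

    let worker = {
        let streamer = Arc::clone(&streamer);
        thread::spawn(move || {
            // Stand-in for the generation loop pushing token batches as they are sampled.
            for t in 0u32..5 {
                streamer.lock().unwrap().put(&[t]);
            }
            streamer.lock().unwrap().end();
        })
    };

    // Caller drains tokens as they arrive, like the chat example's output loop.
    while let Ok(Some(tokens)) = receiver.recv() {
        println!("got tokens: {:?}", tokens);
    }
    worker.join().unwrap();
}
```

Because the consumer only reads from the plain `mpsc::Receiver`, it never takes the mutex itself; the lock is held only for the duration of each `put`/`end` call on the generation side.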