
Add chat example
gabrielmbmb committed Nov 13, 2024
1 parent 38df7b5 commit 01c3693
Showing 6 changed files with 213 additions and 20 deletions.
7 changes: 7 additions & 0 deletions candle-holder-examples/examples/chat/README.md
@@ -0,0 +1,7 @@
# Chat

## Running the example

```bash
cargo run --example chat --features cuda,flash-attn -- --device cuda:0 --system-prompt "You are a helpful assistant."
```
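
The example runs an interactive loop: type a message at the `>` prompt and the model's reply is streamed back as it is generated. Enter `/quit` or `/exit` to stop the session.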
153 changes: 153 additions & 0 deletions candle-holder-examples/examples/chat/main.rs
@@ -0,0 +1,153 @@
#[cfg(feature = "accelerate")]
extern crate accelerate_src;

#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

use std::{
    io::{self, Write},
    sync::{mpsc, Arc, Mutex},
};

use anyhow::{Error, Result};
use candle_holder_examples::Cli;
use candle_holder_models::{
    AutoModelForCausalLM, GenerationConfig, GenerationParams, TextIteratorStreamer, TokenStreamer,
};
use candle_holder_tokenizers::{AutoTokenizer, Message};
use clap::Parser;

#[derive(Debug, Parser)]
pub struct GenerationCli {
    #[command(flatten)]
    pub base: Cli,

    #[arg(long, default_value = "meta-llama/Meta-Llama-3.1-8B-Instruct")]
    pub model: String,

    #[arg(short, long)]
    pub do_sample: bool,

    #[arg(long, default_value = "0.6")]
    pub temperature: f64,

    #[arg(long, default_value = "0.9")]
    pub top_p: f32,

    #[arg(long, default_value = "50")]
    pub top_k: usize,

    #[arg(long)]
    pub system_prompt: Option<String>,
}

fn main() -> Result<()> {
    let args = GenerationCli::parse();

    let device = args.base.get_device()?;
    println!("Model: {}", args.model);
    println!("Device: {:?}", device);

    // Load the model and the tokenizer
    let tokenizer = AutoTokenizer::from_pretrained(args.model.clone(), None, None)?;
    let model = AutoModelForCausalLM::from_pretrained(args.model, &device, None, None)?;

    // Create a token streamer to stream the tokens generated by the model
    let (token_streamer, output_receiver) =
        TextIteratorStreamer::new(tokenizer.clone(), true, true);
    let token_streamer: Arc<Mutex<dyn TokenStreamer>> = Arc::new(Mutex::new(token_streamer));

    // Run the model generation loop in a background thread
    let (sender, receiver) = mpsc::channel::<Option<String>>();
    let generation_handle = std::thread::spawn(move || {
        let mut messages = Vec::new();

        if let Some(system_prompt) = args.system_prompt {
            messages.push(Message::system(system_prompt));
        }

        while let Ok(message) = receiver.recv() {
            if message.is_none() {
                println!("Stopping generation loop...");
                break;
            }

            let prompt = message.unwrap();
            messages.push(Message::user(prompt));

            // Apply the chat template to the conversation so far and encode it
            let mut encodings = tokenizer
                .apply_chat_template_and_encode(messages.clone(), true)
                .map_err(Error::msg)
                .unwrap();
            encodings.to_device(&device).unwrap();

            let input_ids = encodings.get_input_ids();

            let params = GenerationParams::new()
                .with_generation_config(GenerationConfig {
                    do_sample: args.do_sample,
                    top_p: Some(args.top_p),
                    top_k: Some(args.top_k),
                    temperature: args.temperature,
                    max_new_tokens: Some(2048),
                    ..GenerationConfig::default()
                })
                .with_tokenizer(tokenizer.clone())
                .with_token_streamer(token_streamer.clone());

            let output = model.generate(input_ids, params).unwrap();

            // Drop the prompt from the decoded output, keep only the newly
            // generated text, and append it to the conversation history
            let inputs_prompt_length: usize = input_ids
                .to_vec2::<u32>()
                .unwrap()
                .first()
                .map(|seq_input_ids| tokenizer.decode(&seq_input_ids[..], true).unwrap().len())
                .unwrap_or(0);
            let sequence = &output[0].get_sequences()[0];
            let system_message: String = tokenizer
                .decode(&sequence, true)
                .unwrap()
                .chars()
                .skip(inputs_prompt_length)
                .collect();
            messages.push(Message::system(system_message));
        }
    });

    // User input loop
    loop {
        // Read user input
        let mut input = String::new();
        print!("> ");
        std::io::stdout().flush().unwrap();
        std::io::stdin().read_line(&mut input).unwrap();
        let input = input.trim().to_string();

        if input.is_empty() {
            continue;
        }

        if input.to_lowercase() == "/quit" || input.to_lowercase() == "/exit" {
            sender.send(None)?;
            break;
        }

        // Send the user message to the background thread
        sender.send(Some(input))?;

        // Print the new tokens generated by the model in the background thread
        while let Ok(message) = output_receiver.recv() {
            if let Some(text) = message {
                print!("{}", text);
                io::stdout().flush().unwrap();
            } else {
                println!();
                break;
            }
        }
    }

    generation_handle.join().unwrap();

    Ok(())
}
14 changes: 7 additions & 7 deletions candle-holder-models/src/generation/generate.rs
@@ -1,4 +1,4 @@
use std::sync::Arc;
use std::sync::{Arc, Mutex};

use candle_core::{IndexOp, Tensor};
use candle_holder::{Error, Result};
@@ -39,7 +39,7 @@ pub fn generate<M: PreTrainedModel + ?Sized>(
    generation_config: &GenerationConfig,
    tokenizer: Option<Arc<dyn Tokenizer>>,
    stopping_criteria: Option<Vec<Box<dyn StoppingCriteria>>>,
    mut token_streamer: Option<Box<dyn TokenStreamer>>,
    mut token_streamer: Option<Arc<Mutex<dyn TokenStreamer>>>,
    seed: Option<u64>,
) -> Result<Vec<GenerateOutput>> {
    let num_return_sequences = generation_config.get_num_return_sequences().max(1);
@@ -143,18 +143,18 @@ pub fn generate<M: PreTrainedModel + ?Sized>(
}

fn stream_tokens(
    token_streamer: &mut Option<Box<dyn TokenStreamer + '_>>,
    token_streamer: &mut Option<Arc<Mutex<dyn TokenStreamer>>>,
    tokens: &[Vec<u32>],
) -> Result<()> {
    if let Some(streamer) = token_streamer.as_mut() {
        streamer.put(tokens)?;
    if let Some(streamer) = token_streamer {
        streamer.lock().unwrap().put(tokens)?;
    }
    Ok(())
}

fn stream_end(token_streamer: &mut Option<Box<dyn TokenStreamer + '_>>) -> Result<()> {
fn stream_end(token_streamer: &mut Option<Arc<Mutex<dyn TokenStreamer>>>) -> Result<()> {
    if let Some(streamer) = token_streamer.as_mut() {
        streamer.end()?;
        streamer.lock().unwrap().end()?;
    }
    Ok(())
}
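Note: moving from `Box<dyn TokenStreamer>` to `Arc<Mutex<dyn TokenStreamer>>` lets the caller keep its own handle to the streamer and reuse it across multiple `generate` calls (as the chat example does on every turn), with `generate` taking the lock only to call `put` and `end`.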
17 changes: 9 additions & 8 deletions candle-holder-models/src/generation/token_streamer.rs
@@ -1,6 +1,9 @@
use candle_holder::{Error, Result};
use candle_holder_tokenizers::Tokenizer;
use std::{io::{self, Write}, sync::{mpsc, Arc}};
use std::{
    io::{self, Write},
    sync::{mpsc, Arc},
};

/// A trait for streamers that receive the tokens generated by the `generate` method using an
/// auto-regressive model. Streamers can be used to handle tokens as they are generated by the
@@ -76,7 +79,6 @@ impl TextStreamer {
        Ok(self.process_text(&text))
    }


    fn process_text(&mut self, full_text: &str) -> Option<String> {
        let new_text = &full_text[self.print_len..];

@@ -132,11 +134,11 @@ impl TokenStreamer for TextStreamer {

    fn end(&mut self) -> Result<()> {
        let printable_text = self.flush_tokens()?;
        self.next_tokens_are_prompt = false;
        self.print(printable_text, true)
    }
}


pub struct TextIteratorStreamer {
    text_streamer: TextStreamer,
    sender: mpsc::Sender<Option<String>>,
@@ -151,12 +153,12 @@ impl TextIteratorStreamer {
        let (sender, receiver) = mpsc::channel();
        let streamer = Self {
            text_streamer: TextStreamer::new(tokenizer, skip_prompt, skip_special_tokens),
            sender
            sender,
        };
        (streamer, receiver)
    }

    fn send(&self, t: Option<String>) -> Result<()>{
    fn send(&self, t: Option<String>) -> Result<()> {
        self.sender.send(t).map_err(|e| Error::msg(e.to_string()))
    }
}
@@ -171,10 +173,9 @@ impl TokenStreamer for TextIteratorStreamer {

    fn end(&mut self) -> Result<()> {
        let text = self.text_streamer.flush_tokens()?;
        if !text.is_empty() {
            self.send(Some(text))?;
        }
        self.send(Some(text))?;
        self.send(None)?;
        self.text_streamer.next_tokens_are_prompt = true;
        Ok(())
    }
}
38 changes: 36 additions & 2 deletions candle-holder-models/src/model.rs
@@ -1,4 +1,4 @@
use std::sync::Arc;
use std::sync::{Arc, Mutex};

use candle_core::{DType, Device, Tensor};
use candle_holder::{utils::from_pretrained::FromPretrainedParameters, Error, Result};
@@ -103,11 +103,45 @@ pub struct GenerationParams {
    pub stopping_criteria: Option<Vec<Box<dyn StoppingCriteria>>>,
    /// The token streamer which will receive the next tokens as they are being generated. The
    /// default value is `None`.
    pub token_streamer: Option<Box<dyn TokenStreamer>>,
    pub token_streamer: Option<Arc<Mutex<dyn TokenStreamer>>>,
    /// A seed that will be used in the sampling of the next token. The default value is `None`.
    pub seed: Option<u64>,
}

impl GenerationParams {
    pub fn new() -> Self {
        Self::default()
    }

    pub fn with_generation_config(mut self, generation_config: GenerationConfig) -> Self {
        self.generation_config = Some(generation_config);
        self
    }

    pub fn with_tokenizer(mut self, tokenizer: Arc<dyn Tokenizer>) -> Self {
        self.tokenizer = Some(tokenizer);
        self
    }

    pub fn with_stopping_criteria(
        mut self,
        stopping_criteria: Vec<Box<dyn StoppingCriteria>>,
    ) -> Self {
        self.stopping_criteria = Some(stopping_criteria);
        self
    }

    pub fn with_token_streamer(mut self, token_streamer: Arc<Mutex<dyn TokenStreamer>>) -> Self {
        self.token_streamer = Some(token_streamer);
        self
    }

    pub fn with_seed(mut self, seed: u64) -> Self {
        self.seed = Some(seed);
        self
    }
}

impl Default for GenerationParams {
    fn default() -> Self {
        Self {
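For quick reference, a minimal sketch of how these builder methods combine with the shared token streamer, mirroring the chat example above. It assumes the streamer and builder accept an `Arc<dyn Tokenizer>`, as the example suggests; the `Tokenizer` import path and the specific generation settings (`do_sample`, `max_new_tokens`, `with_seed`) are illustrative assumptions rather than part of this commit:

```rust
use std::sync::{mpsc, Arc, Mutex};

use candle_holder_models::{
    GenerationConfig, GenerationParams, TextIteratorStreamer, TokenStreamer,
};
use candle_holder_tokenizers::Tokenizer;

// Build `GenerationParams` with the new builder API and wire in a streamer.
// The returned receiver yields `Some(text)` chunks while tokens are generated
// and a final `None` once the streamer's `end` is called.
fn build_streaming_params(
    tokenizer: Arc<dyn Tokenizer>,
) -> (GenerationParams, mpsc::Receiver<Option<String>>) {
    let (streamer, receiver) = TextIteratorStreamer::new(tokenizer.clone(), true, true);
    let streamer: Arc<Mutex<dyn TokenStreamer>> = Arc::new(Mutex::new(streamer));

    let params = GenerationParams::new()
        .with_generation_config(GenerationConfig {
            do_sample: true,
            max_new_tokens: Some(256),
            ..GenerationConfig::default()
        })
        .with_tokenizer(tokenizer)
        .with_token_streamer(streamer)
        .with_seed(42);

    (params, receiver)
}
```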
4 changes: 1 addition & 3 deletions candle-holder-models/src/utils/var_builder.rs
@@ -68,9 +68,7 @@ impl CompatibilityTensorRetrievalBackend {
}

        // Try removing the model name prefix
        let without_prefix = name
            .strip_prefix(&format!("{}.", self.model_name))
            .unwrap_or(name);
        let without_prefix = name.strip_prefix(&self.model_name).unwrap_or(name);

        // Function to replace weight/bias with beta/gamma
        let replace_weight_bias = |s: &str| s.replace("weight", "gamma").replace("bias", "beta");
