Add streamed text generation example
gabrielmbmb committed Sep 1, 2024
1 parent b55e50a commit fe07d24
Showing 6 changed files with 151 additions and 34 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
@@ -33,3 +33,4 @@ accelerate-src = { version = "0.3.2" }
intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"] }
minijinja = "2.2.0"
minijinja-contrib = { version = "2.2.0", features = ["pycompat"] }
clap = { version = "4.5.16", features = ["derive"] }
1 change: 1 addition & 0 deletions candle-holder-examples/Cargo.toml
@@ -20,6 +20,7 @@ candle-nn = { workspace = true }
accelerate-src = { workspace = true, optional = true }
intel-mkl-src = { workspace = true, optional = true }
anyhow = "1.0.86"
clap = { workspace = true }

[features]
metal = ["candle-core/metal", "candle-holder-models/metal"]
@@ -0,0 +1,7 @@
# Streamed Text Generation

## Running the example

```bash
cargo run --example streamed_text_generation --features cuda,flash-attn -- --device cuda:0 --prompt "What's the three body problem?" --apply-chat-template
```
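
The CUDA feature flags are optional: `--device` defaults to `cpu` (see `candle-holder-examples/src/lib.rs` below), so the example can also be run without a GPU. A minimal sketch, assuming the default features and that the default 8B model (or a smaller one passed via `--model`) fits in memory; on macOS, `--features metal -- --device metal` should work analogously:

```bash
cargo run --example streamed_text_generation -- --prompt "What's the three body problem?" --apply-chat-template
```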
85 changes: 85 additions & 0 deletions candle-holder-examples/examples/streamed_text_generation/main.rs
@@ -0,0 +1,85 @@
use anyhow::{Error, Result};
use candle_holder_examples::Cli;
use candle_holder_models::{
    AutoModelForCausalLM, GenerationConfig, GenerationParams, TextStreamer, TokenStreamer,
};
use candle_holder_tokenizers::{AutoTokenizer, Message};
use clap::Parser;

#[derive(Debug, Parser)]
pub struct GenerationCli {
    #[command(flatten)]
    pub base: Cli,

    #[arg(long, default_value = "meta-llama/Meta-Llama-3.1-8B-Instruct")]
    pub model: String,

    #[arg(long, default_value = "0.6")]
    pub temperature: f64,

    #[arg(long, default_value = "0.9")]
    pub top_p: f32,

    #[arg(long, default_value = "50")]
    pub top_k: usize,

    #[arg(long, default_value = "1024")]
    pub max_new_tokens: usize,

    #[arg(long)]
    pub system_prompt: Option<String>,

    #[arg(long, required = true)]
    pub prompt: String,

    #[arg(long, default_value = "false")]
    pub apply_chat_template: bool,
}

fn main() -> Result<()> {
    let args = GenerationCli::parse();

    let device = args.base.get_device()?;
    println!("Device: {:?}", device);

    let tokenizer = AutoTokenizer::from_pretrained(args.model.clone(), None, None)?;
    let model = AutoModelForCausalLM::from_pretrained(args.model, &device, None, None)?;

    let mut encodings = if args.apply_chat_template {
        tokenizer
            .apply_chat_template_and_encode(vec![Message::user(args.prompt)], true)
            .map_err(Error::msg)?
    } else {
        tokenizer
            .encode(vec![args.prompt], true, None)
            .map_err(Error::msg)?
    };
    encodings.to_device(&device)?;

    let start = std::time::Instant::now();

    let token_streamer: Box<dyn TokenStreamer> =
        Box::new(TextStreamer::new(&tokenizer, true, true));

    let input_ids = encodings.get_input_ids();
    model.generate(
        input_ids,
        GenerationParams {
            generation_config: Some(GenerationConfig {
                do_sample: true,
                top_p: Some(args.top_p),
                top_k: Some(args.top_k),
                temperature: args.temperature,
                max_new_tokens: Some(args.max_new_tokens),
                ..GenerationConfig::default()
            }),
            tokenizer: Some(&tokenizer),
            token_streamer: Some(token_streamer),
            ..Default::default()
        },
    )?;

    println!("\nTook: {:?}", start.elapsed());

    Ok(())
}
73 changes: 40 additions & 33 deletions candle-holder-examples/src/lib.rs
@@ -1,49 +1,56 @@
-use anyhow::Result;
-use std::env;
-
+use anyhow::{anyhow, Result};
 use candle_core::Device;
+use clap::Parser;
+use std::str::FromStr;
 
-pub enum DeviceOption {
-    Cpu,
-    Metal,
-    Cuda(usize),
-}
+#[derive(Debug, Parser)]
+#[command(author, version, about, long_about = None)]
+pub struct Cli {
+    #[arg(long, value_parser = parse_device, default_value = "cpu")]
+    pub device: DeviceOption,
+}
 
-pub fn get_device(device: Option<DeviceOption>) -> Result<Device> {
-    let device = match device {
-        Some(DeviceOption::Cuda(device_id)) if cfg!(feature = "cuda") => {
-            Device::new_cuda(device_id)?
-        }
-        Some(DeviceOption::Metal) if cfg!(feature = "metal") => Device::new_metal(0)?,
-        _ => Device::Cpu,
-    };
-
-    Ok(device)
-}
+impl Cli {
+    pub fn get_device(&self) -> Result<Device> {
+        match self.device {
+            DeviceOption::Cuda(device_id) if cfg!(feature = "cuda") => {
+                Ok(Device::new_cuda(device_id)?)
+            }
+            DeviceOption::Metal if cfg!(feature = "metal") => Ok(Device::new_metal(0)?),
+            DeviceOption::Cpu => Ok(Device::Cpu),
+            _ => Err(anyhow!("Requested device is not available")),
+        }
+    }
+}
 
-pub fn parse_device_option() -> Option<DeviceOption> {
-    let args: Vec<String> = env::args().collect();
-
-    // Expecting something like: --device cpu, --device metal, or --device cuda:<id>
-    if args.len() > 2 && args[1] == "--device" {
-        match args[2].as_str() {
-            "metal" => Some(DeviceOption::Metal),
-            cuda if cuda.starts_with("cuda:") => {
-                let id_part = &cuda["cuda:".len()..];
-                if let Ok(device_id) = id_part.parse::<usize>() {
-                    Some(DeviceOption::Cuda(device_id))
-                } else {
-                    eprintln!("Error: Invalid CUDA device id: {}", id_part);
-                    None
-                }
-            }
-            _ => Some(DeviceOption::Cpu),
-        }
-    } else {
-        Some(DeviceOption::Cpu)
-    }
-}
+#[derive(Clone, Debug)]
+pub enum DeviceOption {
+    Cpu,
+    Metal,
+    Cuda(usize),
+}
+
+impl FromStr for DeviceOption {
+    type Err = anyhow::Error;
+
+    fn from_str(s: &str) -> Result<Self, Self::Err> {
+        match s {
+            "cpu" => Ok(DeviceOption::Cpu),
+            "metal" => Ok(DeviceOption::Metal),
+            s if s.starts_with("cuda:") => {
+                let id = s.strip_prefix("cuda:").unwrap().parse::<usize>()?;
+                Ok(DeviceOption::Cuda(id))
+            }
+            _ => Err(anyhow!("Invalid device option: {}", s)),
+        }
+    }
+}
+
+fn parse_device(s: &str) -> Result<DeviceOption, anyhow::Error> {
+    DeviceOption::from_str(s)
+}
 
 pub fn get_device_from_args() -> Result<Device> {
-    get_device(parse_device_option())
+    let cli = Cli::parse();
+    cli.get_device()
 }
18 changes: 17 additions & 1 deletion candle-holder-models/src/generation/token_streamer.rs
@@ -16,8 +16,12 @@ pub trait TokenStreamer<'a> {
 pub struct TextStreamer<'a> {
     /// The tokenizer used to decode the tokens into text.
     tokenizer: &'a Box<dyn Tokenizer>,
+    /// Whether to skip the prompt when decoding the tokens into text.
+    skip_prompt: bool,
     /// Whether to skip special tokens when decoding the tokens
     skip_special_tokens: bool,
+    /// Whether the next tokens are part of the prompt.
+    next_tokens_are_prompt: bool,
     /// A cache to store the tokens until a printable text is found.
     token_cache: Vec<u32>,
     /// The length of text that can be printed from the token cache.
@@ -31,15 +35,22 @@ impl<'a> TextStreamer<'a> {
     /// # Arguments
     ///
     /// * `tokenizer` - The tokenizer used to decode the tokens into text.
+    /// * `skip_prompt` - Whether to skip the prompt when decoding the tokens into text.
     /// * `skip_special_tokens` - Whether to skip special tokens when decoding the tokens.
     ///
     /// # Returns
     ///
     /// A new `TextStreamer`.
-    pub fn new(tokenizer: &'a Box<dyn Tokenizer>, skip_special_tokens: bool) -> Self {
+    pub fn new(
+        tokenizer: &'a Box<dyn Tokenizer>,
+        skip_prompt: bool,
+        skip_special_tokens: bool,
+    ) -> Self {
         TextStreamer {
             tokenizer,
+            skip_prompt,
             skip_special_tokens,
+            next_tokens_are_prompt: true,
             token_cache: vec![],
             print_len: 0,
         }
@@ -85,6 +96,11 @@ impl<'a> TokenStreamer<'a> for TextStreamer<'a> {
             ));
         }
 
+        if self.skip_prompt && self.next_tokens_are_prompt {
+            self.next_tokens_are_prompt = false;
+            return Ok(());
+        }
+
         self.token_cache.extend_from_slice(&tokens[0]);
         let text = self
             .tokenizer
