-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add streamed text generation example
- Loading branch information
1 parent
b55e50a
commit fe07d24
Showing
6 changed files
with
151 additions
and
34 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
7 changes: 7 additions & 0 deletions
7
candle-holder-examples/examples/streamed_text_generation/README.md
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
# Streamed Text Generation | ||
|
||
## Running the example | ||
|
||
```bash | ||
cargo run --example streamed_text_generation --features cuda,flash-attn -- --device cuda:0 --prompt "What's the three body problem?" --apply-chat-template | ||
``` |
85 changes: 85 additions & 0 deletions
85
candle-holder-examples/examples/streamed_text_generation/main.rs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
use anyhow::{Error, Result}; | ||
use candle_holder_examples::Cli; | ||
use candle_holder_models::{ | ||
AutoModelForCausalLM, GenerationConfig, GenerationParams, TextStreamer, TokenStreamer, | ||
}; | ||
use candle_holder_tokenizers::{AutoTokenizer, Message}; | ||
use clap::Parser; | ||
|
||
#[derive(Debug, Parser)] | ||
pub struct GenerationCli { | ||
#[command(flatten)] | ||
pub base: Cli, | ||
|
||
#[arg(long, default_value = "meta-llama/Meta-Llama-3.1-8B-Instruct")] | ||
pub model: String, | ||
|
||
#[arg(long, default_value = "0.6")] | ||
pub temperature: f64, | ||
|
||
#[arg(long, default_value = "0.9")] | ||
pub top_p: f32, | ||
|
||
#[arg(long, default_value = "50")] | ||
pub top_k: usize, | ||
|
||
#[arg(long, default_value = "1024")] | ||
pub max_new_tokens: usize, | ||
|
||
#[arg(long)] | ||
pub system_prompt: Option<String>, | ||
|
||
#[arg(long, required = true)] | ||
pub prompt: String, | ||
|
||
#[arg(long, default_value = "false")] | ||
pub apply_chat_template: bool, | ||
} | ||
|
||
fn main() -> Result<()> { | ||
let args = GenerationCli::parse(); | ||
|
||
let device = args.base.get_device()?; | ||
println!("Device: {:?}", device); | ||
|
||
let tokenizer = AutoTokenizer::from_pretrained(args.model.clone(), None, None)?; | ||
let model = AutoModelForCausalLM::from_pretrained(args.model, &device, None, None)?; | ||
|
||
let mut encodings = if args.apply_chat_template { | ||
tokenizer | ||
.apply_chat_template_and_encode(vec![Message::user(args.prompt)], true) | ||
.map_err(Error::msg)? | ||
} else { | ||
tokenizer | ||
.encode(vec![args.prompt], true, None) | ||
.map_err(Error::msg)? | ||
}; | ||
encodings.to_device(&device)?; | ||
|
||
let start = std::time::Instant::now(); | ||
|
||
let token_streamer: Box<dyn TokenStreamer> = | ||
Box::new(TextStreamer::new(&tokenizer, true, true)); | ||
|
||
let input_ids = encodings.get_input_ids(); | ||
model.generate( | ||
input_ids, | ||
GenerationParams { | ||
generation_config: Some(GenerationConfig { | ||
do_sample: true, | ||
top_p: Some(args.top_p), | ||
top_k: Some(args.top_k), | ||
temperature: args.temperature, | ||
max_new_tokens: Some(args.max_new_tokens), | ||
..GenerationConfig::default() | ||
}), | ||
tokenizer: Some(&tokenizer), | ||
token_streamer: Some(token_streamer), | ||
..Default::default() | ||
}, | ||
)?; | ||
|
||
println!("\nTook: {:?}", start.elapsed()); | ||
|
||
Ok(()) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,49 +1,56 @@ | ||
use anyhow::Result; | ||
use std::env; | ||
|
||
use anyhow::{anyhow, Result}; | ||
use candle_core::Device; | ||
use clap::Parser; | ||
use std::str::FromStr; | ||
|
||
pub enum DeviceOption { | ||
Cpu, | ||
Metal, | ||
Cuda(usize), | ||
#[derive(Debug, Parser)] | ||
#[command(author, version, about, long_about = None)] | ||
pub struct Cli { | ||
#[arg(long, value_parser = parse_device, default_value = "cpu")] | ||
pub device: DeviceOption, | ||
} | ||
|
||
pub fn get_device(device: Option<DeviceOption>) -> Result<Device> { | ||
let device = match device { | ||
Some(DeviceOption::Cuda(device_id)) if cfg!(feature = "cuda") => { | ||
Device::new_cuda(device_id)? | ||
impl Cli { | ||
pub fn get_device(&self) -> Result<Device> { | ||
match self.device { | ||
DeviceOption::Cuda(device_id) if cfg!(feature = "cuda") => { | ||
Ok(Device::new_cuda(device_id)?) | ||
} | ||
DeviceOption::Metal if cfg!(feature = "metal") => Ok(Device::new_metal(0)?), | ||
DeviceOption::Cpu => Ok(Device::Cpu), | ||
_ => Err(anyhow!("Requested device is not available")), | ||
} | ||
Some(DeviceOption::Metal) if cfg!(feature = "metal") => Device::new_metal(0)?, | ||
_ => Device::Cpu, | ||
}; | ||
} | ||
} | ||
|
||
Ok(device) | ||
/// Device requested on the command line.
///
/// The enum is a small plain value (at most one `usize`), so it is `Copy` and
/// comparable in addition to `Clone` and `Debug`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum DeviceOption {
    /// Run on the CPU.
    Cpu,
    /// Run on a Metal device.
    Metal,
    /// Run on the CUDA device with the given index.
    Cuda(usize),
}
|
||
pub fn parse_device_option() -> Option<DeviceOption> { | ||
let args: Vec<String> = env::args().collect(); | ||
impl FromStr for DeviceOption { | ||
type Err = anyhow::Error; | ||
|
||
// Expecting something like: --device cpu, --device metal, or --device cuda:<id> | ||
if args.len() > 2 && args[1] == "--device" { | ||
match args[2].as_str() { | ||
"metal" => Some(DeviceOption::Metal), | ||
cuda if cuda.starts_with("cuda:") => { | ||
let id_part = &cuda["cuda:".len()..]; | ||
if let Ok(device_id) = id_part.parse::<usize>() { | ||
Some(DeviceOption::Cuda(device_id)) | ||
} else { | ||
eprintln!("Error: Invalid CUDA device id: {}", id_part); | ||
None | ||
} | ||
fn from_str(s: &str) -> Result<Self, Self::Err> { | ||
match s { | ||
"cpu" => Ok(DeviceOption::Cpu), | ||
"metal" => Ok(DeviceOption::Metal), | ||
s if s.starts_with("cuda:") => { | ||
let id = s.strip_prefix("cuda:").unwrap().parse::<usize>()?; | ||
Ok(DeviceOption::Cuda(id)) | ||
} | ||
_ => Some(DeviceOption::Cpu), | ||
_ => Err(anyhow!("Invalid device option: {}", s)), | ||
} | ||
} else { | ||
Some(DeviceOption::Cpu) | ||
} | ||
} | ||
|
||
fn parse_device(s: &str) -> Result<DeviceOption, anyhow::Error> { | ||
DeviceOption::from_str(s) | ||
} | ||
|
||
pub fn get_device_from_args() -> Result<Device> { | ||
get_device(parse_device_option()) | ||
let cli = Cli::parse(); | ||
cli.get_device() | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters