From ecbdceec730339d46006e0cba63acf7c5790403f Mon Sep 17 00:00:00 2001
From: gabrielmbmb
Date: Sat, 31 Aug 2024 20:34:20 +0200
Subject: [PATCH] WIP `TextGenerationPipeline`

---
 .../text_classification_pipeline/main.rs      |   1 -
 .../text_generation_pipeline/README.md        |   7 +
 .../examples/text_generation_pipeline/main.rs |  32 ++++
 .../src/text_generation.rs                    | 144 ++++++++++++++++--
 candle-holder-tokenizers/src/chat_template.rs |   2 +-
 .../src/from_pretrained.rs                    |   4 +-
 6 files changed, 173 insertions(+), 17 deletions(-)
 create mode 100644 candle-holder-examples/examples/text_generation_pipeline/README.md
 create mode 100644 candle-holder-examples/examples/text_generation_pipeline/main.rs

diff --git a/candle-holder-examples/examples/text_classification_pipeline/main.rs b/candle-holder-examples/examples/text_classification_pipeline/main.rs
index fc8fa51..c5900c9 100644
--- a/candle-holder-examples/examples/text_classification_pipeline/main.rs
+++ b/candle-holder-examples/examples/text_classification_pipeline/main.rs
@@ -1,5 +1,4 @@
 use anyhow::Result;
-use candle_core::Device;
 use candle_holder_examples::get_device_from_args;
 use candle_holder_pipelines::TextClassificationPipeline;
 
diff --git a/candle-holder-examples/examples/text_generation_pipeline/README.md b/candle-holder-examples/examples/text_generation_pipeline/README.md
new file mode 100644
index 0000000..4bd5f84
--- /dev/null
+++ b/candle-holder-examples/examples/text_generation_pipeline/README.md
@@ -0,0 +1,7 @@
+# Text Generation Pipeline
+
+## Running the example
+
+```bash
+cargo run --example text_generation_pipeline
+```
diff --git a/candle-holder-examples/examples/text_generation_pipeline/main.rs b/candle-holder-examples/examples/text_generation_pipeline/main.rs
new file mode 100644
index 0000000..ec5ab90
--- /dev/null
+++ b/candle-holder-examples/examples/text_generation_pipeline/main.rs
@@ -0,0 +1,32 @@
+use anyhow::Result;
+use candle_holder_examples::get_device_from_args;
+use candle_holder_models::{GenerationConfig, GenerationParams};
+use candle_holder_pipelines::TextGenerationPipeline;
+use candle_holder_tokenizers::Message;
+
+fn main() -> Result<()> {
+    let device = get_device_from_args()?;
+    println!("Device: {:?}", device);
+
+    let pipeline =
+        TextGenerationPipeline::new("meta-llama/Meta-Llama-3.1-8B-Instruct", &device, None, None)?;
+
+    let generations = pipeline.run(
+        vec![Message::user("How much is 2 + 2?")],
+        Some(GenerationParams {
+            generation_config: Some(GenerationConfig {
+                do_sample: true,
+                max_new_tokens: Some(256),
+                top_p: Some(0.9),
+                top_k: None,
+                temperature: 0.6,
+                ..GenerationConfig::default()
+            }),
+            ..Default::default()
+        }),
+    )?;
+
+    println!("`pipeline.run` results: {:?}", generations);
+
+    Ok(())
+}
diff --git a/candle-holder-pipelines/src/text_generation.rs b/candle-holder-pipelines/src/text_generation.rs
index e3451dd..519872b 100644
--- a/candle-holder-pipelines/src/text_generation.rs
+++ b/candle-holder-pipelines/src/text_generation.rs
@@ -1,9 +1,68 @@
 use candle_core::{DType, Device};
 use candle_holder::{FromPretrainedParameters, Result};
-use candle_holder_models::{AutoModelForCausalLM, PreTrainedModel};
-use candle_holder_tokenizers::{
-    AutoTokenizer, BatchEncoding, Padding, PaddingOptions, PaddingSide, Tokenizer,
+use candle_holder_models::{
+    generation::generate::GenerateOutput, AutoModelForCausalLM, GenerationParams, PreTrainedModel,
 };
+use candle_holder_tokenizers::{AutoTokenizer, BatchEncoding, Message, PaddingSide, Tokenizer};
+
+enum RunTextGenerationInput {
+    Input(String),
+    Messages(Vec<Message>),
+}
+
+impl From<&str> for RunTextGenerationInput {
+    fn from(input: &str) -> Self {
+        RunTextGenerationInput::Input(input.to_owned())
+    }
+}
+
+impl From<String> for RunTextGenerationInput {
+    fn from(input: String) -> Self {
+        RunTextGenerationInput::Input(input)
+    }
+}
+
+impl From<Vec<Message>> for RunTextGenerationInput {
+    fn from(messages: Vec<Message>) -> Self {
+        RunTextGenerationInput::Messages(messages)
+    }
+}
+
+enum RunBatchTextGenerationInput {
+    Inputs(Vec<String>),
+    Messages(Vec<Vec<Message>>),
+}
+
+impl From<Vec<String>> for RunBatchTextGenerationInput {
+    fn from(inputs: Vec<String>) -> Self {
+        RunBatchTextGenerationInput::Inputs(inputs)
+    }
+}
+
+impl From<Vec<Vec<Message>>> for RunBatchTextGenerationInput {
+    fn from(messages: Vec<Vec<Message>>) -> Self {
+        RunBatchTextGenerationInput::Messages(messages)
+    }
+}
+
+impl From<RunTextGenerationInput> for RunBatchTextGenerationInput {
+    fn from(input: RunTextGenerationInput) -> Self {
+        match input {
+            RunTextGenerationInput::Input(input) => {
+                RunBatchTextGenerationInput::Inputs(vec![input])
+            }
+            RunTextGenerationInput::Messages(messages) => {
+                RunBatchTextGenerationInput::Messages(vec![messages])
+            }
+        }
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct TextGenerationPipelineOutput {
+    text: Option<String>,
+    messages: Option<Vec<Message>>,
+}
 
 /// A pipeline for generating text with a causal language model.
 pub struct TextGenerationPipeline {
@@ -41,20 +100,77 @@ impl TextGenerationPipeline {
         })
     }
 
-    fn preprocess(&self, inputs: Vec<String>) -> Result<BatchEncoding> {
-        let mut encodings = self.tokenizer.encode(
-            inputs,
-            true,
-            Some(PaddingOptions::new(Padding::Longest, None)),
-        )?;
+    fn preprocess<I: Into<RunBatchTextGenerationInput>>(
+        &self,
+        inputs: I,
+    ) -> Result<(BatchEncoding, Vec<String>, bool)> {
+        let (mut encodings, text_inputs, are_messages) = match inputs.into() {
+            RunBatchTextGenerationInput::Inputs(inputs) => (
+                self.tokenizer.encode(inputs.clone(), false, None)?,
+                inputs,
+                false,
+            ),
+            RunBatchTextGenerationInput::Messages(messages) => {
+                let messages: Result<Vec<String>> = messages
+                    .into_iter()
+                    .map(|messages| self.tokenizer.apply_chat_template(messages))
+                    .collect();
+                let inputs = messages?;
+                (
+                    self.tokenizer.encode(inputs.clone(), false, None)?,
+                    inputs,
+                    true,
+                )
+            }
+        };
         encodings.to_device(&self.device)?;
-        Ok(encodings)
+        Ok((encodings, text_inputs, are_messages))
+    }
+
+    fn postprocess(
+        &self,
+        outputs: Vec<GenerateOutput>,
+        text_inputs: Vec<String>,
+        are_messages: bool,
+    ) -> Result<Vec<Vec<TextGenerationPipelineOutput>>> {
+        let mut results = vec![];
+        for (output, text_input) in outputs.iter().zip(text_inputs) {
+            let mut outputs = vec![];
+            let text_input_len = text_input.len();
+            for sequence in output.get_sequences() {
+                let text = self.tokenizer.decode(&sequence[..], true)?;
+                println!("text input: {}", text_input);
+                println!("text: {}", text);
+                let generated: String = text.chars().skip(text_input_len).collect();
+                println!("generated: {}", generated);
+                outputs.push(TextGenerationPipelineOutput {
+                    text: Some(generated),
+                    messages: None,
+                });
+            }
+            results.push(outputs);
+        }
+        Ok(results)
     }
 
-    pub fn run<I: Into<String>>(&self, input: I) -> Result<()> {
-        let _encodings = self.preprocess(vec![input.into()])?;
-        Ok(())
+    pub fn run<I: Into<RunTextGenerationInput>>(
+        &self,
+        input: I,
+        params: Option<GenerationParams>,
+    ) -> Result<Vec<TextGenerationPipelineOutput>> {
+        let (encodings, text_inputs, are_messages) = self.preprocess(input.into())?;
+        let outputs = self
+            .model
+            .generate(encodings.get_input_ids(), params.unwrap_or_default())?;
+        Ok(self.postprocess(outputs, text_inputs, are_messages)?[0].clone())
    }
 
-    pub fn run_batch<I: Into<String>>(&mut self, inputs: Vec<I>) {}
+    // pub fn run_batch<I: Into<RunBatchTextGenerationInput>>(
+    //     &mut self,
+    //     inputs: I,
+    //     params: Option<GenerationParams>,
+    // ) -> Result {
+    //     let (encodings, are_messages) = self.preprocess(inputs)?;
+    //     Ok(())
+    // }
 }
diff --git a/candle-holder-tokenizers/src/chat_template.rs b/candle-holder-tokenizers/src/chat_template.rs
index b2ec68b..43d2301 100644
--- a/candle-holder-tokenizers/src/chat_template.rs
+++ b/candle-holder-tokenizers/src/chat_template.rs
@@ -4,7 +4,7 @@ use minijinja_contrib::pycompat;
 use serde::{Deserialize, Serialize};
 
 /// Represents a message in a conversation between a user and an assistant.
-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Debug, Clone, Serialize, Deserialize)]
 pub struct Message {
     /// The role of the message. Can be "system", "user", or "assistant".
     role: String,
diff --git a/candle-holder-tokenizers/src/from_pretrained.rs b/candle-holder-tokenizers/src/from_pretrained.rs
index f0d3644..f0a6bbb 100644
--- a/candle-holder-tokenizers/src/from_pretrained.rs
+++ b/candle-holder-tokenizers/src/from_pretrained.rs
@@ -174,7 +174,9 @@ impl TokenizerInfo {
     pub fn get_tokenizer_class(&self) -> &str {
         if let Some(config) = &self.config {
             if let Some(tokenizer_class) = &config.tokenizer_class {
-                return tokenizer_class;
+                if tokenizer_class != "PreTrainedTokenizerFast" {
+                    return tokenizer_class;
+                }
             }
         }