diff --git a/candle-holder-examples/examples/text_classification_pipeline/main.rs b/candle-holder-examples/examples/text_classification_pipeline/main.rs
index c5900c9..43e3c1a 100644
--- a/candle-holder-examples/examples/text_classification_pipeline/main.rs
+++ b/candle-holder-examples/examples/text_classification_pipeline/main.rs
@@ -19,6 +19,7 @@ fn main() -> Result<()> {
         ],
         Some(1),
     )?;
+    println!("`pipeline.run_batch` results: {:?}", results);
 
     Ok(())
 }
diff --git a/candle-holder-examples/examples/text_generation_pipeline/main.rs b/candle-holder-examples/examples/text_generation_pipeline/main.rs
index ec5ab90..115ea17 100644
--- a/candle-holder-examples/examples/text_generation_pipeline/main.rs
+++ b/candle-holder-examples/examples/text_generation_pipeline/main.rs
@@ -28,5 +28,25 @@ fn main() -> Result<()> {
 
     println!("`pipeline.run` results: {:?}", generations);
 
+    let generations = pipeline.run_batch(
+        vec![
+            vec![Message::user("How much is 2 + 2?")],
+            vec![Message::user("How much is 2 x 3?")],
+        ],
+        Some(GenerationParams {
+            generation_config: Some(GenerationConfig {
+                do_sample: true,
+                max_new_tokens: Some(256),
+                top_p: Some(0.9),
+                top_k: None,
+                temperature: 0.6,
+                ..GenerationConfig::default()
+            }),
+            ..Default::default()
+        }),
+    )?;
+
+    println!("`pipeline.run_batch` results: {:?}", generations);
+
     Ok(())
 }
diff --git a/candle-holder-models/src/generation/generate.rs b/candle-holder-models/src/generation/generate.rs
index 17997bb..576c82b 100644
--- a/candle-holder-models/src/generation/generate.rs
+++ b/candle-holder-models/src/generation/generate.rs
@@ -40,7 +40,8 @@ pub fn generate<'a, M: PreTrainedModel + ?Sized>(
     mut token_streamer: Option<Box<dyn TokenStreamer<'a> + 'a>>,
     seed: Option<u64>,
 ) -> Result<Vec<GenerateOutput>> {
-    let mut input_ids = input_ids.repeat((generation_config.get_num_return_sequences(), 1))?;
+    let num_return_sequences = generation_config.get_num_return_sequences();
+    let mut input_ids = input_ids.repeat((num_return_sequences, 1))?;
     let mut output = input_ids.to_vec2::<u32>()?;
     let input_ids_dims = input_ids.dims2()?;
 
@@ -135,7 +136,7 @@ pub fn generate<'a, M: PreTrainedModel + ?Sized>(
 
     stream_end(&mut token_streamer)?;
 
-    Ok(create_outputs(output, num_sequences))
+    Ok(create_outputs(output, num_return_sequences))
 }
 
 fn stream_tokens(
diff --git a/candle-holder-pipelines/src/lib.rs b/candle-holder-pipelines/src/lib.rs
index 7b65a57..1238364 100644
--- a/candle-holder-pipelines/src/lib.rs
+++ b/candle-holder-pipelines/src/lib.rs
@@ -6,7 +6,7 @@ pub mod zero_shot_classification;
 
 pub use fill_mask::{FillMaskOptions, FillMaskPipeline};
 pub use text_classification::TextClassificationPipeline;
-pub use text_generation::TextGenerationPipeline;
+pub use text_generation::{TextGenerationPipeline, TextGenerationPipelineOutput};
 pub use token_classification::{
     AggregationStrategy, TokenClassificationOptions, TokenClassificationPipeline,
 };
diff --git a/candle-holder-pipelines/src/text_generation.rs b/candle-holder-pipelines/src/text_generation.rs
index 519872b..057ba90 100644
--- a/candle-holder-pipelines/src/text_generation.rs
+++ b/candle-holder-pipelines/src/text_generation.rs
@@ -1,11 +1,12 @@
-use candle_core::{DType, Device};
+use candle_core::{DType, Device, Tensor};
 use candle_holder::{FromPretrainedParameters, Result};
 use candle_holder_models::{
     generation::generate::GenerateOutput, AutoModelForCausalLM, GenerationParams, PreTrainedModel,
 };
 use candle_holder_tokenizers::{AutoTokenizer, BatchEncoding, Message, PaddingSide, Tokenizer};
 
-enum RunTextGenerationInput {
+#[derive(Debug)]
+pub enum RunTextGenerationInput {
     Input(String),
     Messages(Vec<Message>),
 }
@@ -28,7 +29,8 @@ impl From<Vec<Message>> for RunTextGenerationInput {
     }
 }
 
-enum RunBatchTextGenerationInput {
+#[derive(Debug)]
+pub enum RunBatchTextGenerationInput {
     Inputs(Vec<String>),
     Messages(Vec<Vec<Message>>),
 }
@@ -64,6 +66,15 @@ pub struct TextGenerationPipelineOutput {
     messages: Option<Vec<Message>>,
 }
 
+impl TextGenerationPipelineOutput {
+    pub fn get_text(&self) -> Option<String> {
+        self.text.clone()
+    }
+    pub fn get_messages(&self) -> Option<Vec<Message>> {
+        self.messages.clone()
+    }
+}
+
 /// A pipeline for generating text with a causal language model.
 pub struct TextGenerationPipeline {
     model: Box<dyn PreTrainedModel>,
@@ -100,77 +111,108 @@ impl TextGenerationPipeline {
         })
     }
 
-    fn preprocess<I: Into<RunBatchTextGenerationInput>>(
-        &self,
-        inputs: I,
-    ) -> Result<(BatchEncoding, Vec<String>, bool)> {
-        let (mut encodings, text_inputs, are_messages) = match inputs.into() {
-            RunBatchTextGenerationInput::Inputs(inputs) => (
-                self.tokenizer.encode(inputs.clone(), false, None)?,
-                inputs,
-                false,
-            ),
+    fn preprocess<I: Into<RunBatchTextGenerationInput>>(&self, inputs: I) -> Result<BatchEncoding> {
+        let mut encodings = match inputs.into() {
+            RunBatchTextGenerationInput::Inputs(inputs) => {
+                self.tokenizer.encode(inputs.clone(), true, None)?
+            }
             RunBatchTextGenerationInput::Messages(messages) => {
                 let messages: Result<Vec<String>> = messages
                     .into_iter()
-                    .map(|messages| self.tokenizer.apply_chat_template(messages))
+                    .map(|messages| self.tokenizer.apply_chat_template(messages, true))
                     .collect();
                 let inputs = messages?;
-                (
-                    self.tokenizer.encode(inputs.clone(), false, None)?,
-                    inputs,
-                    true,
-                )
+                self.tokenizer.encode(inputs.clone(), false, None)?
             }
         };
         encodings.to_device(&self.device)?;
-        Ok((encodings, text_inputs, are_messages))
+        Ok(encodings)
     }
 
-    fn postprocess(
+    fn postprocess<I: Into<RunBatchTextGenerationInput>>(
         &self,
+        inputs: I,
+        input_ids: &Tensor,
         outputs: Vec<GenerateOutput>,
-        text_inputs: Vec<String>,
-        are_messages: bool,
     ) -> Result<Vec<Vec<TextGenerationPipelineOutput>>> {
-        let mut results = vec![];
-        for (output, text_input) in outputs.iter().zip(text_inputs) {
-            let mut outputs = vec![];
-            let text_input_len = text_input.len();
-            for sequence in output.get_sequences() {
-                let text = self.tokenizer.decode(&sequence[..], true)?;
-                println!("text input: {}", text_input);
-                println!("text: {}", text);
-                let generated: String = text.chars().skip(text_input_len).collect();
-                println!("generated: {}", generated);
-                outputs.push(TextGenerationPipelineOutput {
-                    text: Some(generated),
-                    messages: None,
-                });
-            }
-            results.push(outputs);
+        let inputs = inputs.into();
+        let inputs_prompt_lengths: Vec<usize> = input_ids
+            .to_vec2::<u32>()?
+            .into_iter()
+            .map(|seq_input_ids| {
+                self.tokenizer
+                    .decode(&seq_input_ids[..], true)
+                    .unwrap()
+                    .len()
+            })
+            .collect();
+
+        match inputs {
+            RunBatchTextGenerationInput::Inputs(inputs) => inputs
+                .into_iter()
+                .zip(inputs_prompt_lengths.into_iter())
+                .zip(outputs.into_iter())
+                .map(|((_, input_prompt_length), output)| {
+                    output
+                        .get_sequences()
+                        .iter()
+                        .map(|sequence| {
+                            let text = self.tokenizer.decode(sequence, true)?;
+                            let generated = text.chars().skip(input_prompt_length).collect();
+                            Ok(TextGenerationPipelineOutput {
+                                text: Some(generated),
+                                messages: None,
+                            })
+                        })
+                        .collect()
+                })
+                .collect(),
+            RunBatchTextGenerationInput::Messages(messages_batch) => messages_batch
+                .into_iter()
+                .zip(inputs_prompt_lengths.into_iter())
+                .zip(outputs.into_iter())
+                .map(|((messages, input_prompt_length), output)| {
+                    output
+                        .get_sequences()
+                        .iter()
+                        .map(|sequence| {
+                            let text = self.tokenizer.decode(sequence, true)?;
+                            let generated: String =
+                                text.chars().skip(input_prompt_length).collect();
+                            let mut messages = messages.clone();
+                            messages.push(Message::assistant(generated));
+                            Ok(TextGenerationPipelineOutput {
+                                text: None,
+                                messages: Some(messages),
+                            })
+                        })
+                        .collect()
+                })
+                .collect(),
         }
-        Ok(results)
     }
 
-    pub fn run<I: Into<RunTextGenerationInput>>(
+    pub fn run<I: Into<RunTextGenerationInput> + Clone>(
         &self,
         input: I,
         params: Option<GenerationParams>,
     ) -> Result<Vec<TextGenerationPipelineOutput>> {
-        let (encodings, text_inputs, are_messages) = self.preprocess(input.into())?;
+        let encodings = self.preprocess(input.clone().into())?;
         let outputs = self
             .model
             .generate(encodings.get_input_ids(), params.unwrap_or_default())?;
-        Ok(self.postprocess(outputs, text_inputs, are_messages)?[0].clone())
+        Ok(self.postprocess(input.into(), encodings.get_input_ids(), outputs)?[0].clone())
     }
 
-    // pub fn run_batch<I: Into<RunBatchTextGenerationInput>>(
-    //     &mut self,
-    //     inputs: I,
-    //     params: Option<GenerationParams>,
-    // ) -> Result {
-    //     let (encodings, are_messages) = self.preprocess(inputs)?;
-    //     Ok(())
-    // }
+    pub fn run_batch<I: Into<RunBatchTextGenerationInput> + Clone>(
+        &self,
+        inputs: I,
+        params: Option<GenerationParams>,
+    ) -> Result<Vec<Vec<TextGenerationPipelineOutput>>> {
+        let encodings = self.preprocess(inputs.clone())?;
+        let outputs = self
+            .model
+            .generate(encodings.get_input_ids(), params.unwrap_or_default())?;
+        self.postprocess(inputs, encodings.get_input_ids(), outputs)
+    }
 }
diff --git a/candle-holder-tokenizers/src/chat_template.rs b/candle-holder-tokenizers/src/chat_template.rs
index 43d2301..3832538 100644
--- a/candle-holder-tokenizers/src/chat_template.rs
+++ b/candle-holder-tokenizers/src/chat_template.rs
@@ -43,6 +43,7 @@ struct ChatTemplateInputs {
     messages: Vec<Message>,
     bos_token: Option<String>,
     eos_token: Option<String>,
+    add_generation_prompt: bool,
 }
 
 /// https://github.com/huggingface/text-generation-inference/blob/d9fbbaafb046bb423e31edaf9ccf8eecc2d5c33d/router/src/infer/chat_template.rs#L4
@@ -76,12 +77,13 @@ impl ChatTemplate {
         })
     }
 
-    pub fn apply(&self, messages: Vec<Message>) -> Result<String> {
+    pub fn apply(&self, messages: Vec<Message>, add_generation_prompt: bool) -> Result<String> {
         self.template
             .render(&ChatTemplateInputs {
                 messages,
                 bos_token: self.bos_token.clone(),
                 eos_token: self.eos_token.clone(),
+                add_generation_prompt,
             })
             .map_err(Error::wrap)
     }
diff --git a/candle-holder-tokenizers/src/tokenizer.rs b/candle-holder-tokenizers/src/tokenizer.rs
index 30459d2..6df9fc8 100644
--- a/candle-holder-tokenizers/src/tokenizer.rs
+++ b/candle-holder-tokenizers/src/tokenizer.rs
@@ -284,15 +284,20 @@ pub trait Tokenizer: std::fmt::Debug {
     /// # Arguments
     ///
     /// * `messages` - A list of messages to apply the chat template.
+    /// * `add_generation_prompt` - A flag indicating if the generation prompt should be added.
     ///
     /// # Returns
     ///
    /// The input string for the model in the expected format.
-    fn apply_chat_template(&self, messages: Vec<Message>) -> Result<String> {
+    fn apply_chat_template(
+        &self,
+        messages: Vec<Message>,
+        add_generation_prompt: bool,
+    ) -> Result<String> {
         let chat_template = self.get_chat_template().ok_or_else(|| {
             Error::MissingChatTemplate("Chat template not found in the tokenizer".to_string())
         })?;
-        chat_template.apply(messages)
+        chat_template.apply(messages, add_generation_prompt)
     }
 
     /// Applies the chat template to a list of messages and encodes the result.
@@ -304,12 +309,16 @@ pub trait Tokenizer: std::fmt::Debug {
     /// # Returns
     ///
     /// A `BatchEncoding` containing the encoded sequences.
-    fn apply_chat_template_and_encode(&self, messages: Vec<Message>) -> Result<BatchEncoding> {
+    fn apply_chat_template_and_encode(
+        &self,
+        messages: Vec<Message>,
+        add_generation_prompt: bool,
+    ) -> Result<BatchEncoding> {
         let chat_template = self.get_chat_template().ok_or_else(|| {
             Error::MissingChatTemplate("Chat template not found in the tokenizer".to_string())
         })?;
-        let chat = chat_template.apply(messages)?;
-        self.encode(vec![chat], true, None)
+        let chat = chat_template.apply(messages, add_generation_prompt)?;
+        self.encode(vec![chat], false, None)
     }
 }