Finalize TextGenerationPipeline
gabrielmbmb committed Sep 1, 2024
1 parent ecbdcee commit 8722bd0
Showing 7 changed files with 155 additions and 60 deletions.
@@ -19,6 +19,7 @@ fn main() -> Result<()> {
],
Some(1),
)?;

println!("`pipeline.run_batch` results: {:?}", results);

Ok(())
20 changes: 20 additions & 0 deletions candle-holder-examples/examples/text_generation_pipeline/main.rs
@@ -28,5 +28,25 @@ fn main() -> Result<()> {

println!("`pipeline.run` results: {:?}", generations);

let generations = pipeline.run_batch(
vec![
vec![Message::user("How much is 2 + 2?")],
vec![Message::user("How much is 2 x 3?")],
],
Some(GenerationParams {
generation_config: Some(GenerationConfig {
do_sample: true,
max_new_tokens: Some(256),
top_p: Some(0.9),
top_k: None,
temperature: 0.6,
..GenerationConfig::default()
}),
..Default::default()
}),
)?;

println!("`pipeline.run_batch` results: {:?}", generations);

Ok(())
}
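For reference, here is a minimal sketch (not part of this commit) of how the batched results above could be consumed with the accessors this change adds to `TextGenerationPipelineOutput`; it assumes `generations` has the `Vec<Vec<TextGenerationPipelineOutput>>` shape returned by `run_batch`.

```rust
// Sketch: iterate over the per-input outputs returned by `pipeline.run_batch`.
for (i, outputs) in generations.iter().enumerate() {
    for output in outputs {
        // Message inputs come back as `messages` (with the assistant reply
        // appended); plain string inputs come back as `text`.
        if let Some(messages) = output.get_messages() {
            println!("conversation {}: {:?}", i, messages);
        } else if let Some(text) = output.get_text() {
            println!("generated text {}: {}", i, text);
        }
    }
}
```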
5 changes: 3 additions & 2 deletions candle-holder-models/src/generation/generate.rs
@@ -40,7 +40,8 @@ pub fn generate<'a, M: PreTrainedModel + ?Sized>(
mut token_streamer: Option<Box<dyn TokenStreamer<'a> + 'a>>,
seed: Option<u64>,
) -> Result<Vec<GenerateOutput>> {
let mut input_ids = input_ids.repeat((generation_config.get_num_return_sequences(), 1))?;
let num_return_sequences = generation_config.get_num_return_sequences();
let mut input_ids = input_ids.repeat((num_return_sequences, 1))?;
let mut output = input_ids.to_vec2::<u32>()?;
let input_ids_dims = input_ids.dims2()?;

@@ -135,7 +136,7 @@ pub fn generate<'a, M: PreTrainedModel + ?Sized>(

stream_end(&mut token_streamer)?;

Ok(create_outputs(output, num_sequences))
Ok(create_outputs(output, num_return_sequences))
}

fn stream_tokens(
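As an aside, a small illustration (assumed shapes, not from the repository) of what the `repeat` call above does: with `num_return_sequences = 2`, a `(1, 4)` prompt batch becomes `(2, 4)`, so two candidate sequences are generated for the same input.

```rust
use candle_core::{Device, Tensor};

fn main() -> candle_core::Result<()> {
    // A single prompt of four token ids.
    let input_ids = Tensor::new(&[[1u32, 2, 3, 4]], &Device::Cpu)?;
    // Repeat along the batch dimension, as `generate` does with `num_return_sequences`.
    let repeated = input_ids.repeat((2, 1))?;
    assert_eq!(repeated.dims2()?, (2, 4));
    Ok(())
}
```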
2 changes: 1 addition & 1 deletion candle-holder-pipelines/src/lib.rs
@@ -6,7 +6,7 @@ pub mod zero_shot_classification;

pub use fill_mask::{FillMaskOptions, FillMaskPipeline};
pub use text_classification::TextClassificationPipeline;
pub use text_generation::TextGenerationPipeline;
pub use text_generation::{TextGenerationPipeline, TextGenerationPipelineOutput};
pub use token_classification::{
AggregationStrategy, TokenClassificationOptions, TokenClassificationPipeline,
};
164 changes: 113 additions & 51 deletions candle-holder-pipelines/src/text_generation.rs
@@ -1,11 +1,12 @@
use candle_core::{DType, Device};
use candle_core::{DType, Device, Tensor};
use candle_holder::{FromPretrainedParameters, Result};
use candle_holder_models::{
generation::generate::GenerateOutput, AutoModelForCausalLM, GenerationParams, PreTrainedModel,
};
use candle_holder_tokenizers::{AutoTokenizer, BatchEncoding, Message, PaddingSide, Tokenizer};

enum RunTextGenerationInput {
#[derive(Debug)]
pub enum RunTextGenerationInput {
Input(String),
Messages(Vec<Message>),
}
@@ -28,7 +29,8 @@ impl From<Vec<Message>> for RunTextGenerationInput {
}
}

enum RunBatchTextGenerationInput {
#[derive(Debug)]
pub enum RunBatchTextGenerationInput {
Inputs(Vec<String>),
Messages(Vec<Vec<Message>>),
}
@@ -64,6 +66,15 @@ pub struct TextGenerationPipelineOutput {
messages: Option<Vec<Message>>,
}

impl TextGenerationPipelineOutput {
pub fn get_text(&self) -> Option<String> {
self.text.clone()
}
pub fn get_messages(&self) -> Option<Vec<Message>> {
self.messages.clone()
}
}

/// A pipeline for generating text with a causal language model.
pub struct TextGenerationPipeline {
model: Box<dyn PreTrainedModel>,
@@ -100,77 +111,128 @@ impl TextGenerationPipeline {
})
}

fn preprocess<I: Into<RunBatchTextGenerationInput>>(
&self,
inputs: I,
) -> Result<(BatchEncoding, Vec<String>, bool)> {
let (mut encodings, text_inputs, are_messages) = match inputs.into() {
RunBatchTextGenerationInput::Inputs(inputs) => (
self.tokenizer.encode(inputs.clone(), false, None)?,
inputs,
false,
),
fn preprocess<I: Into<RunBatchTextGenerationInput>>(&self, inputs: I) -> Result<BatchEncoding> {
let mut encodings = match inputs.into() {
RunBatchTextGenerationInput::Inputs(inputs) => {
self.tokenizer.encode(inputs.clone(), true, None)?
}
RunBatchTextGenerationInput::Messages(messages) => {
let messages: Result<Vec<String>> = messages
.into_iter()
.map(|messages| self.tokenizer.apply_chat_template(messages))
.map(|messages| self.tokenizer.apply_chat_template(messages, true))
.collect();
let inputs = messages?;
(
self.tokenizer.encode(inputs.clone(), false, None)?,
inputs,
true,
)
self.tokenizer.encode(inputs.clone(), false, None)?
}
};
encodings.to_device(&self.device)?;
Ok((encodings, text_inputs, are_messages))
Ok(encodings)
}

fn postprocess(
fn postprocess<I: Into<RunBatchTextGenerationInput>>(
&self,
inputs: I,
input_ids: &Tensor,
outputs: Vec<GenerateOutput>,
text_inputs: Vec<String>,
are_messages: bool,
) -> Result<Vec<Vec<TextGenerationPipelineOutput>>> {
let mut results = vec![];
for (output, text_input) in outputs.iter().zip(text_inputs) {
let mut outputs = vec![];
let text_input_len = text_input.len();
for sequence in output.get_sequences() {
let text = self.tokenizer.decode(&sequence[..], true)?;
println!("text input: {}", text_input);
println!("text: {}", text);
let generated: String = text.chars().skip(text_input_len).collect();
println!("generated: {}", generated);
outputs.push(TextGenerationPipelineOutput {
text: Some(generated),
messages: None,
});
}
results.push(outputs);
let inputs = inputs.into();
let inputs_prompt_lengths: Vec<usize> = input_ids
.to_vec2::<u32>()?
.into_iter()
.map(|seq_input_ids| {
self.tokenizer
.decode(&seq_input_ids[..], true)
.unwrap()
.len()
})
.collect();

match inputs {
RunBatchTextGenerationInput::Inputs(inputs) => inputs
.into_iter()
.zip(inputs_prompt_lengths.into_iter())
.zip(outputs.into_iter())
.map(|((_, input_prompt_length), output)| {
output
.get_sequences()
.iter()
.map(|sequence| {
let text = self.tokenizer.decode(sequence, true)?;
let generated = text.chars().skip(input_prompt_length).collect();
Ok(TextGenerationPipelineOutput {
text: Some(generated),
messages: None,
})
})
.collect()
})
.collect(),
RunBatchTextGenerationInput::Messages(messages_batch) => messages_batch
.into_iter()
.zip(inputs_prompt_lengths.into_iter())
.zip(outputs.into_iter())
.map(|((messages, input_prompt_length), output)| {
output
.get_sequences()
.iter()
.map(|sequence| {
let text = self.tokenizer.decode(sequence, true)?;
let generated: String =
text.chars().skip(input_prompt_length).collect();
let mut messages = messages.clone();
messages.push(Message::assistant(generated));
Ok(TextGenerationPipelineOutput {
text: None,
messages: Some(messages),
})
})
.collect()
})
.collect(),
}
Ok(results)
}

pub fn run<I: Into<RunTextGenerationInput>>(
/// Generates text from the given input.
///
/// # Arguments
///
/// * `input` - The input to generate text from.
/// * `params` - Optional parameters to specify the generation configuration.
///
/// # Returns
///
/// The generated text.
pub fn run<I: Into<RunTextGenerationInput> + Clone>(
&self,
input: I,
params: Option<GenerationParams>,
) -> Result<Vec<TextGenerationPipelineOutput>> {
let (encodings, text_inputs, are_messages) = self.preprocess(input.into())?;
let encodings = self.preprocess(input.clone().into())?;
let outputs = self
.model
.generate(encodings.get_input_ids(), params.unwrap_or_default())?;
Ok(self.postprocess(outputs, text_inputs, are_messages)?[0].clone())
Ok(self.postprocess(input.into(), encodings.get_input_ids(), outputs)?[0].clone())
}

// pub fn run_batch<I: Into<RunBatchTextGenerationInput>>(
// &mut self,
// inputs: I,
// params: Option<GenerationParams>,
// ) -> Result<TextGenerationPipelineOutput> {
// let (encodings, are_messages) = self.preprocess(inputs)?;
// Ok(())
// }
/// Generates text from the given inputs.
///
/// # Arguments
///
/// * `inputs` - The inputs to generate text from.
/// * `params` - Optional parameters to specify the generation configuration.
///
/// # Returns
///
/// The generated texts.
pub fn run_batch<I: Into<RunBatchTextGenerationInput> + Clone>(
&self,
inputs: I,
params: Option<GenerationParams>,
) -> Result<Vec<Vec<TextGenerationPipelineOutput>>> {
let encodings = self.preprocess(inputs.clone())?;
let outputs = self
.model
.generate(encodings.get_input_ids(), params.unwrap_or_default())?;
self.postprocess(inputs, encodings.get_input_ids(), outputs)
}
}
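The rewritten `postprocess` recovers only the newly generated continuation by decoding the prompt ids on their own, measuring the decoded prompt's length, and skipping that prefix in each fully decoded sequence. A simplified, standalone sketch of that trimming idea (plain strings here; the real code decodes token ids with the pipeline's tokenizer):

```rust
// Sketch of the prompt-trimming step: drop the decoded prompt prefix so only
// the generated continuation remains.
fn trim_prompt(decoded_sequence: &str, decoded_prompt: &str) -> String {
    decoded_sequence
        .chars()
        .skip(decoded_prompt.chars().count())
        .collect()
}

// trim_prompt("Q: 2 + 2?\nA: 4", "Q: 2 + 2?\n") == "A: 4"
```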
4 changes: 3 additions & 1 deletion candle-holder-tokenizers/src/chat_template.rs
@@ -43,6 +43,7 @@ struct ChatTemplateInputs {
messages: Vec<Message>,
bos_token: Option<String>,
eos_token: Option<String>,
add_generation_prompt: bool,
}

/// https://github.com/huggingface/text-generation-inference/blob/d9fbbaafb046bb423e31edaf9ccf8eecc2d5c33d/router/src/infer/chat_template.rs#L4
@@ -76,12 +77,13 @@ impl ChatTemplate {
})
}

pub fn apply(&self, messages: Vec<Message>) -> Result<String> {
pub fn apply(&self, messages: Vec<Message>, add_generation_prompt: bool) -> Result<String> {
self.template
.render(&ChatTemplateInputs {
messages,
bos_token: self.bos_token.clone(),
eos_token: self.eos_token.clone(),
add_generation_prompt,
})
.map_err(Error::wrap)
}
19 changes: 14 additions & 5 deletions candle-holder-tokenizers/src/tokenizer.rs
@@ -284,15 +284,20 @@ pub trait Tokenizer: std::fmt::Debug {
/// # Arguments
///
/// * `messages` - A list of messages to apply the chat template.
/// * `add_generation_prompt` - A flag indicating if the generation prompt should be added.
///
/// # Returns
///
/// The input string for the model in the expected format.
fn apply_chat_template(&self, messages: Vec<Message>) -> Result<String> {
fn apply_chat_template(
&self,
messages: Vec<Message>,
add_generation_prompt: bool,
) -> Result<String> {
let chat_template = self.get_chat_template().ok_or_else(|| {
Error::MissingChatTemplate("Chat template not found in the tokenizer".to_string())
})?;
chat_template.apply(messages)
chat_template.apply(messages, add_generation_prompt)
}

/// Applies the chat template to a list of messages and encodes the result.
@@ -304,12 +309,16 @@ pub trait Tokenizer: std::fmt::Debug {
/// # Returns
///
/// A `BatchEncoding` containing the encoded sequences.
fn apply_chat_template_and_encode(&self, messages: Vec<Message>) -> Result<BatchEncoding> {
fn apply_chat_template_and_encode(
&self,
messages: Vec<Message>,
add_generation_prompt: bool,
) -> Result<BatchEncoding> {
let chat_template = self.get_chat_template().ok_or_else(|| {
Error::MissingChatTemplate("Chat template not found in the tokenizer".to_string())
})?;
let chat = chat_template.apply(messages)?;
self.encode(vec![chat], true, None)
let chat = chat_template.apply(messages, add_generation_prompt)?;
self.encode(vec![chat], false, None)
}
}
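A short sketch (the helper name `build_prompt` is hypothetical) of how a caller passes the new `add_generation_prompt` flag through the updated trait method:

```rust
use candle_holder::Result;
use candle_holder_tokenizers::{Message, Tokenizer};

fn build_prompt<T: Tokenizer>(tokenizer: &T) -> Result<String> {
    let messages = vec![Message::user("How much is 2 + 2?")];
    // `true` appends the generation prompt so the model continues as the assistant.
    tokenizer.apply_chat_template(messages, true)
}
```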
