
Commit 67fd2a7
Add num_return_sequences
gabrielmbmb committed Aug 31, 2024
1 parent 972d412 commit 67fd2a7
Showing 4 changed files with 65 additions and 14 deletions.
10 changes: 9 additions & 1 deletion candle-holder-models/src/generation/config.rs
@@ -29,7 +29,7 @@ pub struct GenerationConfig {
     pub min_length: Option<usize>,
     /// A string or a list of strings that should terminate the generation if the model outputs
     /// them.
-    #[serde(deserialize_with = "deserialize_single_or_vec")]
+    #[serde(default, deserialize_with = "deserialize_single_or_vec")]
     pub stop_strings: Option<Vec<String>>,
     /// Whether or not to use sampling. If `false`, then greedy decoding will be used. The default
     /// is `false`.
@@ -53,6 +53,9 @@ pub struct GenerationConfig {
     /// `0.0` to `infinity`. A value of `1.0` means no penalty. The default is `1.0`.
     #[serde(default)]
     pub repetition_penalty: Option<f64>,
+    /// The number of sequences to generate. The default is `1`.
+    #[serde(default)]
+    pub num_return_sequences: usize,
     /// The ID of the PAD token.
     #[serde(default)]
     pub pad_token_id: Option<u32>,
@@ -105,6 +108,10 @@ impl GenerationConfig {
         self.repetition_penalty
     }
 
+    pub fn get_num_return_sequences(&self) -> usize {
+        self.num_return_sequences
+    }
+
     pub fn get_pad_token_id(&self) -> Option<u32> {
         self.pad_token_id
     }
@@ -131,6 +138,7 @@ impl Default for GenerationConfig {
             top_k: Some(50),
             top_p: Some(1.0),
             repetition_penalty: Some(1.0),
+            num_return_sequences: 1,
             pad_token_id: None,
             bos_token_id: None,
             eos_token_id: None,
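Usage sketch for the new field (not part of the commit; the crate/module path is assumed from the file layout, and deserializing this sparse a config assumes the remaining fields also have serde defaults). One subtlety worth noting: `#[serde(default)]` on a bare `usize` falls back to `usize::default()`, i.e. `0`, when the field is missing from the serialized config, while the `Default` impl above sets it to `1`.

use candle_holder_models::generation::config::GenerationConfig;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // An explicit value round-trips through serde.
    let config: GenerationConfig = serde_json::from_str(r#"{ "num_return_sequences": 3 }"#)?;
    assert_eq!(config.get_num_return_sequences(), 3);

    // Constructing via Default uses the value from the Default impl above.
    let config = GenerationConfig::default();
    assert_eq!(config.get_num_return_sequences(), 1);
    Ok(())
}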
36 changes: 28 additions & 8 deletions candle-holder-models/src/generation/generate.rs
@@ -8,6 +8,17 @@ use super::{
 };
 use crate::{utils::cache::DynamicCache, ForwardParams, PreTrainedModel};
 
+#[derive(Debug)]
+pub struct GenerateOutput {
+    sequences: Vec<Vec<u32>>,
+}
+
+impl GenerateOutput {
+    pub fn get_sequences(&self) -> &Vec<Vec<u32>> {
+        &self.sequences
+    }
+}
+
 /// Generates a completion of the input sequences using the provided `model`.
 ///
 /// # Arguments
@@ -28,9 +39,9 @@ pub fn generate<'a, M: PreTrainedModel + ?Sized>(
     stopping_criteria: Option<Vec<Box<dyn StoppingCriteria>>>,
     mut token_streamer: Option<Box<dyn TokenStreamer<'a> + 'a>>,
     seed: Option<u64>,
-) -> Result<Vec<Vec<u32>>> {
+) -> Result<Vec<GenerateOutput>> {
+    let mut input_ids = input_ids.repeat((generation_config.get_num_return_sequences(), 1))?;
     let mut output = input_ids.to_vec2::<u32>()?;
-    let mut input_ids = input_ids.clone();
     let input_ids_dims = input_ids.dims2()?;
 
     if input_ids_dims.1.ge(&generation_config.get_max_length())
@@ -55,7 +66,7 @@
     };
 
     // TODO: refactor this into `LogitProcessor` trait
-    let mut sampling_config = LogitSampler::from_generation_config(generation_config, seed);
+    let mut logit_sampler = LogitSampler::from_generation_config(generation_config, seed);
 
     // Initialize the stopping criteria applier that will be used to determine when to stop
     let stopping_criteria_applier = StoppingCriteriaApplier::from_configuration(
@@ -67,8 +78,6 @@
     let num_sequences = input_ids_dims.0;
     let mut active_sequences = num_sequences;
 
-    // TODO: if `generation_config.num_return_sequences>1` then we need to expand the
-    // `input_ids` tensor to have `num_return_sequences` times the number of sequences
     stream_tokens(&mut token_streamer, &input_ids.to_vec2::<u32>()?)?;
 
     // Generation loop
@@ -101,7 +110,7 @@
             }
 
             // Sample next token
-            let next_token_id = sampling_config.sample(&seq_logits)?;
+            let next_token_id = logit_sampler.sample(&seq_logits)?;
 
             // Update the sequences with the next token
             output[i].push(next_token_id);
@@ -120,12 +129,13 @@
 
         // Build the next `input_ids` vectors with the last token of each sequence
         let sequences_last_tokens: Vec<u32> = sequences_last_tokens.into_iter().flatten().collect();
-        input_ids = Tensor::new(&sequences_last_tokens[..], input_ids.device())?.unsqueeze(0)?;
+        input_ids = Tensor::new(&sequences_last_tokens[..], input_ids.device())?
+            .reshape((num_sequences, 1))?;
     }
 
     stream_end(&mut token_streamer)?;
 
-    Ok(output)
+    Ok(create_outputs(output, num_sequences))
 }
 
 fn stream_tokens(
@@ -144,3 +154,13 @@ fn stream_end(token_streamer: &mut Option<Box<dyn TokenStreamer + '_>>) -> Resul
     }
     Ok(())
 }
+
+fn create_outputs(outputs: Vec<Vec<u32>>, num_return_sequences: usize) -> Vec<GenerateOutput> {
+    let mut generate_outputs = Vec::new();
+    for generated_sequences in outputs.chunks(num_return_sequences) {
+        let sequences = generated_sequences.to_vec();
+        let output = GenerateOutput { sequences };
+        generate_outputs.push(output);
+    }
+    generate_outputs
+}
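Caller-side sketch of the new return type (not part of the commit; the module path is assumed from the file layout). `generate` now returns a `Vec<GenerateOutput>` grouping the generated token sequences, so downstream code reads them through `get_sequences()` instead of indexing a flat `Vec<Vec<u32>>`.

use candle_holder_models::generation::generate::GenerateOutput;

// Hypothetical helper: summarize each output's generated sequences.
fn summarize(outputs: &[GenerateOutput]) {
    for (i, output) in outputs.iter().enumerate() {
        for (j, sequence) in output.get_sequences().iter().enumerate() {
            println!("output {i}, return sequence {j}: {} tokens", sequence.len());
        }
    }
}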
2 changes: 2 additions & 0 deletions candle-holder-models/src/generation/sampling.rs
@@ -8,6 +8,7 @@ use crate::generation::config::GenerationConfig;
 
 /// An enum containing the available sampling strategies for selecting the next token id of a
 /// sequence using the outputs logits of an auto-regressive model.
+#[derive(Debug)]
 pub enum SamplingStrategy {
     /// Greedy sampling selects the token with the highest probability.
     Greedy,
@@ -23,6 +24,7 @@
 }
 
 /// A struct for sampling the next token id from the logits of an auto-regressive model.
+#[derive(Debug)]
 pub struct LogitSampler {
     /// The sampling strategy to use.
     strategy: SamplingStrategy,
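A small sketch of what the new derives enable (not part of the commit; the `from_generation_config` signature is assumed from its call site in generate.rs, and module paths are assumed from the file layout):

use candle_holder_models::generation::{config::GenerationConfig, sampling::LogitSampler};

let config = GenerationConfig::default();
let sampler = LogitSampler::from_generation_config(&config, Some(42));
// Compiles only with the new #[derive(Debug)] on LogitSampler and SamplingStrategy.
println!("sampler: {:?}", sampler);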
31 changes: 26 additions & 5 deletions candle-holder-models/src/model.rs
@@ -6,7 +6,11 @@ use candle_nn::VarBuilder;
 use crate::{
     config::PretrainedConfig,
     from_pretrained::from_pretrained,
-    generation::{config::GenerationConfig, generate::generate, StoppingCriteria, TokenStreamer},
+    generation::{
+        config::GenerationConfig,
+        generate::{generate, GenerateOutput},
+        StoppingCriteria, TokenStreamer,
+    },
     models::{
         bert::{
             BertForMaskedLM, BertForSequenceClassification, BertForTokenClassification, BertModel,
@@ -195,10 +199,27 @@ pub trait PreTrainedModel {
         &self,
         input_ids: &Tensor,
         params: GenerationParams<'a>,
-    ) -> Result<Vec<Vec<u32>>> {
-        let generation_config = params
-            .generation_config
-            .unwrap_or_else(|| self.get_generation_config().clone());
+    ) -> Result<Vec<GenerateOutput>> {
+        let (mut generation_config, used_model_generation_config) = match params.generation_config {
+            Some(config) => (config, false),
+            None => (self.get_generation_config().clone(), true),
+        };
+
+        if !used_model_generation_config {
+            if generation_config.get_bos_token_id().is_none() {
+                generation_config.bos_token_id = self.get_generation_config().get_bos_token_id();
+            }
+
+            if generation_config.get_eos_token_id().is_none() {
+                generation_config.eos_token_id =
+                    self.get_generation_config().get_eos_token_id().cloned();
+            }
+
+            if generation_config.get_pad_token_id().is_none() {
+                generation_config.pad_token_id = self.get_generation_config().get_pad_token_id();
+            }
+        }
+
         generate(
             self,
             input_ids,
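Sketch of the behavior this hunk adds (not part of the commit; public field access, a `Default` impl for `GenerationParams`, and the exact `GenerationConfig` field names are assumptions): a user-supplied config that leaves the special-token IDs unset now inherits them from the model's own generation config instead of generating without them.

let custom = GenerationConfig {
    do_sample: true,
    num_return_sequences: 2,
    // bos/eos/pad token IDs left as None; the logic above backfills them
    // from the model's generation config before generation starts.
    ..GenerationConfig::default()
};

let outputs = model.generate(
    &input_ids,
    GenerationParams {
        generation_config: Some(custom),
        ..Default::default()
    },
)?;
// With one input row and num_return_sequences = 2, the single
// GenerateOutput should hold two sequences.
assert_eq!(outputs[0].get_sequences().len(), 2);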
