Use Arc instead of Box for tokenizer
gabrielmbmb committed Nov 11, 2024
1 parent ba703bb commit eb1b8e7
Showing 15 changed files with 104 additions and 71 deletions.
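The theme of the diff: every API that previously borrowed the tokenizer as `&'a Box<dyn Tokenizer>` now takes an owned `Arc<dyn Tokenizer>` handle, which removes the lifetime parameters threaded through `generate`, the stopping criteria, and the streamers. A minimal sketch of why that matters, using a dummy trait in place of the real `candle_holder_tokenizers::Tokenizer` (all names here are illustrative only): an `Arc` clone is a cheap refcount bump and can be moved into another thread, which a `&Box` borrow cannot.

use std::sync::Arc;
use std::thread;

// Illustrative stand-in for `candle_holder_tokenizers::Tokenizer`.
trait Tokenizer: Send + Sync {
    fn decode(&self, ids: &[u32]) -> String;
}

struct DummyTokenizer;

impl Tokenizer for DummyTokenizer {
    fn decode(&self, ids: &[u32]) -> String {
        format!("{ids:?}")
    }
}

fn main() {
    let tokenizer: Arc<dyn Tokenizer> = Arc::new(DummyTokenizer);

    // Each clone is an owned, thread-safe handle to the same tokenizer.
    let for_worker = Arc::clone(&tokenizer);
    let handle = thread::spawn(move || for_worker.decode(&[1, 2, 3]));

    println!("{}", handle.join().unwrap());
    println!("{}", tokenizer.decode(&[4, 5])); // original handle still usable
}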
8 changes: 3 additions & 5 deletions candle-holder-models/src/from_pretrained.rs
@@ -6,10 +6,7 @@ use std::{
 
 use candle_core::{DType, Device};
 use candle_nn::VarBuilder;
-use hf_hub::{
-    api::sync::{ApiBuilder, ApiRepo},
-    Repo, RepoType,
-};
+use hf_hub::api::sync::ApiRepo;
 use serde::{Deserialize, Serialize};
 
 use candle_holder::{
@@ -185,7 +182,7 @@ fn load_model_weights(api: &ApiRepo) -> Result<(Vec<PathBuf>, bool)> {
 /// # Arguments
 ///
 /// * `repo_id`: The Hugging Face Hub model repository id.
-/// * `params`:
+/// * `params` - Optional parameters to specify the revision, user agent, and auth token.
 ///
 /// # Returns
 ///
@@ -194,6 +191,7 @@ pub fn from_pretrained<I: AsRef<str>>(
     repo_id: I,
     params: Option<FromPretrainedParameters>,
 ) -> Result<ModelInfo> {
+    // TODO: check if `repo_id` is a local path
     let api = get_repo_api(repo_id.as_ref(), params)?;
 
     // Get the model configuration from `config.json`
8 changes: 5 additions & 3 deletions candle-holder-models/src/generation/generate.rs
@@ -1,3 +1,5 @@
+use std::sync::Arc;
+
 use candle_core::{IndexOp, Tensor};
 use candle_holder::{Error, Result};
 use candle_holder_tokenizers::Tokenizer;
@@ -31,13 +33,13 @@ impl GenerateOutput {
 /// # Returns
 ///
 /// A vector containing vectors of token ids for each input sequence.
-pub fn generate<'a, M: PreTrainedModel + ?Sized>(
+pub fn generate<M: PreTrainedModel + ?Sized>(
     model: &M,
     input_ids: &Tensor,
     generation_config: &GenerationConfig,
-    tokenizer: Option<&Box<dyn Tokenizer>>,
+    tokenizer: Option<Arc<dyn Tokenizer>>,
     stopping_criteria: Option<Vec<Box<dyn StoppingCriteria>>>,
-    mut token_streamer: Option<Box<dyn TokenStreamer<'a> + 'a>>,
+    mut token_streamer: Option<Box<dyn TokenStreamer>>,
     seed: Option<u64>,
 ) -> Result<Vec<GenerateOutput>> {
     let num_return_sequences = generation_config.get_num_return_sequences().max(1);
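With nothing in the signature borrowing the tokenizer, the `'a` parameter on `generate` goes away. A hedged sketch of the resulting call pattern, with `fake_generate` and the dummy trait standing in for the crate's API (they are not part of it): callers pass an `Arc` clone by value and keep their own handle for reuse.

use std::sync::Arc;

// Illustrative stand-in for `candle_holder_tokenizers::Tokenizer`.
trait Tokenizer: Send + Sync {
    fn decode(&self, ids: &[u32]) -> String;
}

struct DummyTokenizer;

impl Tokenizer for DummyTokenizer {
    fn decode(&self, ids: &[u32]) -> String {
        ids.iter().map(|i| i.to_string()).collect::<Vec<_>>().join(" ")
    }
}

// Mirrors the new parameter shape: `Option<Arc<dyn Tokenizer>>` by value
// instead of `Option<&Box<dyn Tokenizer>>`, so no lifetime is needed.
fn fake_generate(ids: &[u32], tokenizer: Option<Arc<dyn Tokenizer>>) -> String {
    match tokenizer {
        Some(t) => t.decode(ids),
        None => format!("{ids:?}"),
    }
}

fn main() {
    let tokenizer: Arc<dyn Tokenizer> = Arc::new(DummyTokenizer);
    // Cloning bumps a refcount; the caller keeps its handle across calls.
    println!("{}", fake_generate(&[1, 2, 3], Some(Arc::clone(&tokenizer))));
    println!("{}", fake_generate(&[4, 5], Some(tokenizer)));
}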
16 changes: 9 additions & 7 deletions candle-holder-models/src/generation/stopping_criteria.rs
@@ -1,3 +1,5 @@
+use std::sync::Arc;
+
 use candle_holder::{Error, Result};
 use candle_holder_tokenizers::Tokenizer;
 
@@ -40,14 +42,14 @@ impl StoppingCriteria for EosTokenStoppingCriteria {
 }
 
 /// Stopping criteria that stops generation when a stop string is generated.
-pub struct StopStringStoppingCriteria<'a> {
+pub struct StopStringStoppingCriteria {
     /// The stop strings to check for.
     stop_strings: Vec<String>,
     /// The tokenizer to use to decode the input token IDs.
-    tokenizer: Option<&'a Box<dyn Tokenizer>>,
+    tokenizer: Option<Arc<dyn Tokenizer>>,
 }
 
-impl<'a> StopStringStoppingCriteria<'a> {
+impl StopStringStoppingCriteria {
     /// Creates a new `StopStringStoppingCriteria` with the provided stop strings.
     ///
     /// # Arguments
@@ -58,17 +60,17 @@ impl<'a> StopStringStoppingCriteria<'a> {
     /// # Returns
     ///
     /// A new `StopStringStoppingCriteria`.
-    pub fn new(stop_strings: Vec<String>, tokenizer: Option<&'a Box<dyn Tokenizer>>) -> Self {
+    pub fn new(stop_strings: Vec<String>, tokenizer: Option<Arc<dyn Tokenizer>>) -> Self {
         Self {
             stop_strings,
             tokenizer,
         }
     }
 }
 
-impl StoppingCriteria for StopStringStoppingCriteria<'_> {
+impl StoppingCriteria for StopStringStoppingCriteria {
     fn should_stop(&self, input_ids: &[u32]) -> Result<bool> {
-        if let Some(tokenizer) = self.tokenizer {
+        if let Some(tokenizer) = self.tokenizer.as_ref() {
             let input_str = tokenizer.decode(input_ids, true)?;
             Ok(self
                 .stop_strings
@@ -103,7 +105,7 @@ impl<'a> StoppingCriteriaApplier<'a> {
     pub fn from_configuration(
         configuration: &GenerationConfig,
         stopping_criteria: Option<Vec<Box<dyn StoppingCriteria>>>,
-        tokenizer: Option<&'a Box<dyn Tokenizer>>,
+        tokenizer: Option<Arc<dyn Tokenizer>>,
     ) -> Result<Self> {
         let mut stopping_criteria = stopping_criteria.unwrap_or_else(|| vec![]);
 
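Owning the tokenizer is what drops the lifetime from `StopStringStoppingCriteria`; the `as_ref()` in `should_stop` is needed because `Option<Arc<dyn Tokenizer>>` is not `Copy` the way the old `Option<&Box<dyn Tokenizer>>` was. The diff cuts off inside `should_stop`, so the exact predicate over `stop_strings` is not shown; the standalone sketch below assumes an `ends_with` check and uses a toy decoder in place of the real trait.

use std::sync::Arc;

// Illustrative stand-in for the real `Tokenizer` trait.
trait Tokenizer: Send + Sync {
    fn decode(&self, ids: &[u32], skip_special_tokens: bool) -> Result<String, String>;
}

struct DummyTokenizer;

impl Tokenizer for DummyTokenizer {
    fn decode(&self, ids: &[u32], _skip_special_tokens: bool) -> Result<String, String> {
        // Toy decoder: each id maps to one lowercase letter.
        Ok(ids.iter().map(|&i| (b'a' + (i % 26) as u8) as char).collect())
    }
}

struct StopStringCheck {
    stop_strings: Vec<String>,
    tokenizer: Option<Arc<dyn Tokenizer>>,
}

impl StopStringCheck {
    fn should_stop(&self, input_ids: &[u32]) -> Result<bool, String> {
        // `as_ref()` borrows the Arc out of the Option, as in the patch.
        if let Some(tokenizer) = self.tokenizer.as_ref() {
            let input_str = tokenizer.decode(input_ids, true)?;
            // Assumed predicate; the diff truncates before the real one.
            return Ok(self.stop_strings.iter().any(|s| input_str.ends_with(s)));
        }
        Ok(false)
    }
}

fn main() {
    let check = StopStringCheck {
        stop_strings: vec!["cd".to_string()],
        tokenizer: Some(Arc::new(DummyTokenizer)),
    };
    // ids [0, 1, 2, 3] decode to "abcd", which ends with "cd".
    println!("{:?}", check.should_stop(&[0, 1, 2, 3])); // Ok(true)
}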
64 changes: 28 additions & 36 deletions candle-holder-models/src/generation/token_streamer.rs
@@ -1,21 +1,21 @@
 use candle_holder::{Error, Result};
 use candle_holder_tokenizers::Tokenizer;
-use std::{collections::VecDeque, io::{self, Write}};
+use std::{io::{self, Write}, sync::{mpsc, Arc}};
 
 /// A trait for streamers that receives tokens generated by `generate` method using an
 /// auto-regressive model. Streamers can be used to handle tokens as they are generated by the
 /// model.
-pub trait TokenStreamer<'a> {
+pub trait TokenStreamer: Send {
     /// Receives a sequence of tokens generated by the model.
     fn put(&mut self, tokens: &[Vec<u32>]) -> Result<()>;
     /// Called when the generation is finished
     fn end(&mut self) -> Result<()>;
 }
 
 /// A streamer that prints the generated tokens to the console.
-pub struct TextStreamer<'a> {
+pub struct TextStreamer {
     /// The tokenizer used to decode the tokens into text.
-    tokenizer: &'a Box<dyn Tokenizer>,
+    tokenizer: Arc<dyn Tokenizer>,
     /// Whether to skip the prompt when decoding the tokens into text.
     skip_prompt: bool,
     /// Whether to skip special tokens when decoding the tokens
@@ -28,7 +28,7 @@ pub struct TextStreamer<'a> {
     print_len: usize,
 }
 
-impl<'a> TextStreamer<'a> {
+impl TextStreamer {
     /// Creates a new `TextStreamer` with the given tokenizer and whether to skip special
     /// tokens when decoding the tokens.
     ///
@@ -42,7 +42,7 @@ impl<'a> TextStreamer<'a> {
     ///
     /// A new `TextStreamer`.
     pub fn new(
-        tokenizer: &'a Box<dyn Tokenizer>,
+        tokenizer: Arc<dyn Tokenizer>,
         skip_prompt: bool,
         skip_special_tokens: bool,
     ) -> Self {
@@ -121,7 +121,7 @@ impl<'a> TextStreamer<'a> {
     }
 }
 
-impl<'a> TokenStreamer<'a> for TextStreamer<'a> {
+impl TokenStreamer for TextStreamer {
     fn put(&mut self, tokens: &[Vec<u32>]) -> Result<()> {
         if let Some(printable_text) = self.decode(tokens)? {
             self.print(printable_text, false)?;
@@ -137,56 +137,48 @@ impl<'a> TokenStreamer<'a> for TextStreamer<'a> {
 }
 
 
-pub struct TextIteratorStreamer<'a> {
-    text_streamer: TextStreamer<'a>,
-    text_queue: VecDeque<String>,
-    stream_end: bool
+pub struct TextIteratorStreamer {
+    text_streamer: TextStreamer,
+    sender: mpsc::Sender<Option<String>>,
 }
 
-impl<'a> TextIteratorStreamer<'a> {
+impl TextIteratorStreamer {
     pub fn new(
-        tokenizer: &'a Box<dyn Tokenizer>,
+        tokenizer: Arc<dyn Tokenizer>,
         skip_prompt: bool,
         skip_special_tokens: bool,
-    ) -> Self {
-        Self {
+    ) -> (Self, mpsc::Receiver<Option<String>>) {
+        let (sender, receiver) = mpsc::channel();
+        let streamer = Self {
             text_streamer: TextStreamer::new(tokenizer, skip_prompt, skip_special_tokens),
-            text_queue: VecDeque::new(),
-            stream_end: false,
-        }
+            sender
+        };
+        (streamer, receiver)
     }
+
+    fn send(&self, t: Option<String>) -> Result<()>{
+        self.sender.send(t).map_err(|e| Error::msg(e.to_string()))
+    }
 }
 
-impl<'a> TokenStreamer<'a> for TextIteratorStreamer<'a> {
+impl TokenStreamer for TextIteratorStreamer {
     fn put(&mut self, tokens: &[Vec<u32>]) -> Result<()> {
         if let Some(text) = self.text_streamer.decode(tokens)? {
-            self.text_queue.push_back(text)
+            self.send(Some(text))?;
         }
         Ok(())
     }
 
     fn end(&mut self) -> Result<()> {
-        let final_text = self.text_streamer.flush_tokens()?;
-        if !final_text.is_empty() {
-            self.text_queue.push_back(final_text);
+        let text = self.text_streamer.flush_tokens()?;
+        if !text.is_empty() {
+            self.send(Some(text))?;
        }
-        self.stream_end = true;
+        self.send(None)?;
         Ok(())
     }
 }
 
-impl<'a> Iterator for TextIteratorStreamer<'a> {
-    type Item = String;
-
-    fn next(&mut self) -> Option<Self::Item> {
-        if self.stream_end {
-            return None;
-        }
-        self.text_queue.pop_front()
-    }
-
-}
-
 fn is_chinese_char(c: char) -> bool {
     matches!(c,
         '\u{4E00}'..='\u{9FFF}' | '\u{3400}'..='\u{4DBF}'
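This is the most substantive change in the commit. The old `TextIteratorStreamer` buffered chunks in a `VecDeque` and implemented `Iterator` on itself, so it could only be polled from the thread driving generation, and its `next()` returned `None` whenever the queue was momentarily empty, ending a `for` loop early. The new version sends `Some(text)` chunks (and a terminal `None`) through an `mpsc` channel and returns the `Receiver` from `new`. A standalone sketch of that producer/consumer protocol, with plain strings standing in for decoded model output:

use std::sync::mpsc;
use std::thread;

fn main() {
    // Same channel item type as the patched streamer: `None` marks the end.
    let (sender, receiver) = mpsc::channel::<Option<String>>();

    // Stand-in for the generation loop calling `put` and then `end`.
    let producer = thread::spawn(move || {
        for chunk in ["Hello", ", ", "world"] {
            sender.send(Some(chunk.to_string())).unwrap();
        }
        sender.send(None).unwrap(); // what `end()` sends after flushing
    });

    // The consumer can live on another thread and blocks until text
    // arrives; `map_while` stops cleanly at the terminal `None`.
    for text in receiver.iter().map_while(|msg| msg) {
        print!("{text}");
    }
    println!();

    producer.join().unwrap();
}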
16 changes: 9 additions & 7 deletions candle-holder-models/src/model.rs
@@ -1,3 +1,5 @@
+use std::sync::Arc;
+
 use candle_core::{DType, Device, Tensor};
 use candle_holder::{utils::from_pretrained::FromPretrainedParameters, Error, Result};
 use candle_holder_tokenizers::Tokenizer;
@@ -88,25 +90,25 @@ impl<'a> From<&'a candle_holder_tokenizers::BatchEncoding> for ForwardParams<'a>
 }
 
 /// Parameters for the `generate` method of a `PreTrainedModel`.
-pub struct GenerationParams<'a> {
+pub struct GenerationParams {
     /// The generation configuration to use. If not provided, then the model's default generation
     /// configuration in `generation_config.json` will be used. If that is not available, then
     /// default generation configuration will be used. The default value is `None`.
     pub generation_config: Option<GenerationConfig>,
     /// The tokenizer to use for decoding the generated tokens. It's not strictly required, but
     /// some stopping criteria may depend on it. The default value is `None`.
-    pub tokenizer: Option<&'a Box<dyn Tokenizer>>,
+    pub tokenizer: Option<Arc<dyn Tokenizer>>,
     /// The list of stopping criteria to use that will determine when to stop generating tokens.
     /// The default value is `None`.
     pub stopping_criteria: Option<Vec<Box<dyn StoppingCriteria>>>,
     /// The token streamer which will receive the next tokens as they are being generated. The
     /// default value is `None`.
-    pub token_streamer: Option<Box<dyn TokenStreamer<'a> + 'a>>,
+    pub token_streamer: Option<Box<dyn TokenStreamer>>,
     /// A seed that will be used in the sampling of the next token. The default value is `None`.
     pub seed: Option<u64>,
 }
 
-impl Default for GenerationParams<'_> {
+impl Default for GenerationParams {
     fn default() -> Self {
         Self {
             generation_config: None,
@@ -155,7 +157,7 @@ impl ModelOutput {
 }
 
 /// Trait for a pre-trained model.
-pub trait PreTrainedModel: Send + Sync {
+pub trait PreTrainedModel: std::fmt::Debug + Send + Sync {
     /// Loads a model from a `VarBuilder` containing the model's parameters and a model
     /// configuration.
     ///
@@ -231,10 +233,10 @@ pub trait PreTrainedModel: Send + Sync {
     /// # Returns
     ///
     /// A vector containing vectors of token ids for each input sequence.
-    fn generate<'a>(
+    fn generate(
         &self,
         input_ids: &Tensor,
-        params: GenerationParams<'a>,
+        params: GenerationParams,
     ) -> Result<Vec<GenerateOutput>> {
         let (mut generation_config, used_model_generation_config) = match params.generation_config {
             Some(config) => (config, false),
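With the lifetime gone, `GenerationParams` is an ordinary owned struct that can be built once, stored, or sent across threads. A simplified mirror of it below (two fields only, and a derived `Default` where the crate implements it by hand) shows the struct-update pattern a call site would use:

use std::sync::Arc;

// Illustrative stand-in for the real `Tokenizer` trait.
trait Tokenizer: Send + Sync {}

struct DummyTokenizer;
impl Tokenizer for DummyTokenizer {}

// Simplified mirror of the patched `GenerationParams`: no lifetime
// parameter, tokenizer held as an owned `Arc`. Other fields omitted.
#[derive(Default)]
struct GenerationParams {
    tokenizer: Option<Arc<dyn Tokenizer>>,
    seed: Option<u64>,
}

fn main() {
    let tokenizer: Arc<dyn Tokenizer> = Arc::new(DummyTokenizer);

    // Fill in what you need; everything else takes its default.
    let params = GenerationParams {
        tokenizer: Some(Arc::clone(&tokenizer)),
        ..Default::default()
    };

    assert!(params.tokenizer.is_some());
    assert_eq!(params.seed, None);
}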