diff --git a/candle-holder-models/src/from_pretrained.rs b/candle-holder-models/src/from_pretrained.rs
index 49a2463..b5714c2 100644
--- a/candle-holder-models/src/from_pretrained.rs
+++ b/candle-holder-models/src/from_pretrained.rs
@@ -6,10 +6,7 @@ use std::{
 use candle_core::{DType, Device};
 use candle_nn::VarBuilder;
-use hf_hub::{
-    api::sync::{ApiBuilder, ApiRepo},
-    Repo, RepoType,
-};
+use hf_hub::api::sync::ApiRepo;
 use serde::{Deserialize, Serialize};

 use candle_holder::{
@@ -185,7 +182,7 @@ fn load_model_weights(api: &ApiRepo) -> Result<(Vec<PathBuf>, bool)> {
 /// # Arguments
 ///
 /// * `repo_id`: The Hugging Face Hub model repository id.
-/// * `params`:
+/// * `params` - Optional parameters to specify the revision, user agent, and auth token.
 ///
 /// # Returns
 ///
@@ -194,6 +191,7 @@ pub fn from_pretrained<I: AsRef<str>>(
     repo_id: I,
     params: Option<FromPretrainedParameters>,
 ) -> Result<ModelInfo> {
+    // TODO: check if `repo_id` is a local path
     let api = get_repo_api(repo_id.as_ref(), params)?;

     // Get the model configuration from `config.json`
diff --git a/candle-holder-models/src/generation/generate.rs b/candle-holder-models/src/generation/generate.rs
index ce78f44..8846152 100644
--- a/candle-holder-models/src/generation/generate.rs
+++ b/candle-holder-models/src/generation/generate.rs
@@ -1,3 +1,5 @@
+use std::sync::Arc;
+
 use candle_core::{IndexOp, Tensor};
 use candle_holder::{Error, Result};
 use candle_holder_tokenizers::Tokenizer;
@@ -31,13 +33,13 @@ impl GenerateOutput {
 /// # Returns
 ///
 /// A vector containing vectors of token ids for each input sequence.
-pub fn generate<'a, M: PreTrainedModel + ?Sized>(
+pub fn generate<M: PreTrainedModel + ?Sized>(
     model: &M,
     input_ids: &Tensor,
     generation_config: &GenerationConfig,
-    tokenizer: Option<&Box<dyn Tokenizer>>,
+    tokenizer: Option<Arc<dyn Tokenizer>>,
     stopping_criteria: Option<Vec<Box<dyn StoppingCriteria>>>,
-    mut token_streamer: Option<Box<dyn TokenStreamer<'a> + 'a>>,
+    mut token_streamer: Option<Box<dyn TokenStreamer>>,
     seed: Option<u64>,
 ) -> Result<Vec<GenerateOutput>> {
     let num_return_sequences = generation_config.get_num_return_sequences().max(1);
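The move from `&'a Box<dyn Tokenizer>` to `Arc<dyn Tokenizer>` in `generate` is what allows generation to run off the calling thread: an owned, reference-counted handle can be moved across a thread boundary, while the old borrow pinned everything to the caller's stack frame. A minimal sketch of the difference, assuming `Tokenizer` is `Send + Sync` (the `decode` call is the same one the stopping criteria use):

```rust
use std::sync::Arc;

use candle_holder_tokenizers::Tokenizer;

// Sketch: an owned Arc handle can cross a thread boundary, which the previous
// `&'a Box<dyn Tokenizer>` borrow could not.
fn decode_on_worker(tokenizer: Arc<dyn Tokenizer>, ids: Vec<u32>) {
    std::thread::spawn(move || {
        // Same `decode(&[u32], skip_special_tokens)` call used elsewhere in this diff.
        let _text = tokenizer.decode(&ids, true);
    });
}
```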
diff --git a/candle-holder-models/src/generation/stopping_criteria.rs b/candle-holder-models/src/generation/stopping_criteria.rs
index e086728..cffd5f5 100644
--- a/candle-holder-models/src/generation/stopping_criteria.rs
+++ b/candle-holder-models/src/generation/stopping_criteria.rs
@@ -1,3 +1,5 @@
+use std::sync::Arc;
+
 use candle_holder::{Error, Result};
 use candle_holder_tokenizers::Tokenizer;

@@ -40,14 +42,14 @@ impl StoppingCriteria for EosTokenStoppingCriteria {
 }

 /// Stopping criteria that stops generation when a stop string is generated.
-pub struct StopStringStoppingCriteria<'a> {
+pub struct StopStringStoppingCriteria {
     /// The stop strings to check for.
     stop_strings: Vec<String>,
     /// The tokenizer to use to decode the input token IDs.
-    tokenizer: Option<&'a Box<dyn Tokenizer>>,
+    tokenizer: Option<Arc<dyn Tokenizer>>,
 }

-impl<'a> StopStringStoppingCriteria<'a> {
+impl StopStringStoppingCriteria {
     /// Creates a new `StopStringStoppingCriteria` with the provided stop strings.
     ///
     /// # Arguments
     ///
@@ -58,7 +60,7 @@ impl<'a> StopStringStoppingCriteria<'a> {
     /// # Returns
     ///
     /// A new `StopStringStoppingCriteria`.
-    pub fn new(stop_strings: Vec<String>, tokenizer: Option<&'a Box<dyn Tokenizer>>) -> Self {
+    pub fn new(stop_strings: Vec<String>, tokenizer: Option<Arc<dyn Tokenizer>>) -> Self {
         Self {
             stop_strings,
             tokenizer,
@@ -66,9 +68,9 @@ impl<'a> StopStringStoppingCriteria<'a> {
     }
 }

-impl StoppingCriteria for StopStringStoppingCriteria<'_> {
+impl StoppingCriteria for StopStringStoppingCriteria {
     fn should_stop(&self, input_ids: &[u32]) -> Result<bool> {
-        if let Some(tokenizer) = self.tokenizer {
+        if let Some(tokenizer) = self.tokenizer.as_ref() {
             let input_str = tokenizer.decode(input_ids, true)?;
             Ok(self
                 .stop_strings
@@ -103,7 +105,7 @@ impl<'a> StoppingCriteriaApplier<'a> {
     pub fn from_configuration(
         configuration: &GenerationConfig,
         stopping_criteria: Option<Vec<Box<dyn StoppingCriteria>>>,
-        tokenizer: Option<&'a Box<dyn Tokenizer>>,
+        tokenizer: Option<Arc<dyn Tokenizer>>,
     ) -> Result<Self> {
         let mut stopping_criteria = stopping_criteria.unwrap_or_else(|| vec![]);
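With the owned handle, a stop-string criterion no longer borrows the tokenizer, so the criteria list can outlive the scope that loaded it. A sketch, assuming `StoppingCriteria` and `StopStringStoppingCriteria` are re-exported at the crate root (the diff only shows the module file):

```rust
use std::sync::Arc;

use candle_holder_models::{StopStringStoppingCriteria, StoppingCriteria};
use candle_holder_tokenizers::Tokenizer;

fn stop_on_double_newline(tokenizer: Arc<dyn Tokenizer>) -> Vec<Box<dyn StoppingCriteria>> {
    // The criterion takes its own Arc clone; the caller keeps the original handle.
    vec![Box::new(StopStringStoppingCriteria::new(
        vec!["\n\n".to_string()],
        Some(tokenizer),
    ))]
}
```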
diff --git a/candle-holder-models/src/generation/token_streamer.rs b/candle-holder-models/src/generation/token_streamer.rs
index 2027e06..438df1c 100644
--- a/candle-holder-models/src/generation/token_streamer.rs
+++ b/candle-holder-models/src/generation/token_streamer.rs
@@ -1,11 +1,11 @@
 use candle_holder::{Error, Result};
 use candle_holder_tokenizers::Tokenizer;
-use std::{collections::VecDeque, io::{self, Write}};
+use std::{io::{self, Write}, sync::{mpsc, Arc}};

 /// A trait for streamers that receive the tokens generated by the `generate` method of an
 /// auto-regressive model. Streamers can be used to handle tokens as they are generated by the
 /// model.
-pub trait TokenStreamer<'a> {
+pub trait TokenStreamer: Send {
     /// Receives a sequence of tokens generated by the model.
     fn put(&mut self, tokens: &[Vec<u32>]) -> Result<()>;
     /// Called when the generation is finished.
@@ -13,9 +13,9 @@ pub trait TokenStreamer<'a> {
 }

 /// A streamer that prints the generated tokens to the console.
-pub struct TextStreamer<'a> {
+pub struct TextStreamer {
     /// The tokenizer used to decode the tokens into text.
-    tokenizer: &'a Box<dyn Tokenizer>,
+    tokenizer: Arc<dyn Tokenizer>,
     /// Whether to skip the prompt when decoding the tokens into text.
     skip_prompt: bool,
     /// Whether to skip special tokens when decoding the tokens
@@ -28,7 +28,7 @@ pub struct TextStreamer<'a> {
     print_len: usize,
 }

-impl<'a> TextStreamer<'a> {
+impl TextStreamer {
     /// Creates a new `TextStreamer` with the given tokenizer and whether to skip special
     /// tokens when decoding the tokens.
     ///
@@ -42,7 +42,7 @@ impl<'a> TextStreamer<'a> {
     ///
     /// A new `TextStreamer`.
     pub fn new(
-        tokenizer: &'a Box<dyn Tokenizer>,
+        tokenizer: Arc<dyn Tokenizer>,
         skip_prompt: bool,
         skip_special_tokens: bool,
     ) -> Self {
@@ -121,7 +121,7 @@ impl<'a> TextStreamer<'a> {
     }
 }

-impl<'a> TokenStreamer<'a> for TextStreamer<'a> {
+impl TokenStreamer for TextStreamer {
     fn put(&mut self, tokens: &[Vec<u32>]) -> Result<()> {
         if let Some(printable_text) = self.decode(tokens)? {
             self.print(printable_text, false)?;
@@ -137,56 +137,48 @@ impl<'a> TokenStreamer<'a> for TextStreamer<'a> {
     }
 }

-pub struct TextIteratorStreamer<'a> {
-    text_streamer: TextStreamer<'a>,
-    text_queue: VecDeque<String>,
-    stream_end: bool,
+pub struct TextIteratorStreamer {
+    text_streamer: TextStreamer,
+    sender: mpsc::Sender<Option<String>>,
 }

-impl<'a> TextIteratorStreamer<'a> {
+impl TextIteratorStreamer {
     pub fn new(
-        tokenizer: &'a Box<dyn Tokenizer>,
+        tokenizer: Arc<dyn Tokenizer>,
         skip_prompt: bool,
         skip_special_tokens: bool,
-    ) -> Self {
-        Self {
+    ) -> (Self, mpsc::Receiver<Option<String>>) {
+        let (sender, receiver) = mpsc::channel();
+        let streamer = Self {
             text_streamer: TextStreamer::new(tokenizer, skip_prompt, skip_special_tokens),
-            text_queue: VecDeque::new(),
-            stream_end: false,
-        }
+            sender,
+        };
+        (streamer, receiver)
+    }
+
+    fn send(&self, t: Option<String>) -> Result<()> {
+        self.sender.send(t).map_err(|e| Error::msg(e.to_string()))
     }
 }

-impl<'a> TokenStreamer<'a> for TextIteratorStreamer<'a> {
+impl TokenStreamer for TextIteratorStreamer {
     fn put(&mut self, tokens: &[Vec<u32>]) -> Result<()> {
         if let Some(text) = self.text_streamer.decode(tokens)? {
-            self.text_queue.push_back(text)
+            self.send(Some(text))?;
         }
         Ok(())
     }

     fn end(&mut self) -> Result<()> {
-        let final_text = self.text_streamer.flush_tokens()?;
-        if !final_text.is_empty() {
-            self.text_queue.push_back(final_text);
+        let text = self.text_streamer.flush_tokens()?;
+        if !text.is_empty() {
+            self.send(Some(text))?;
         }
-        self.stream_end = true;
+        self.send(None)?;
         Ok(())
     }
 }

-impl<'a> Iterator for TextIteratorStreamer<'a> {
-    type Item = String;
-
-    fn next(&mut self) -> Option<Self::Item> {
-        if self.stream_end {
-            return None;
-        }
-        self.text_queue.pop_front()
-    }
-}
-
 fn is_chinese_char(c: char) -> bool {
     matches!(c,
         '\u{4E00}'..='\u{9FFF}' | '\u{3400}'..='\u{4DBF}'
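The old `VecDeque` + `Iterator` design could only be drained on the generating thread between `put` calls; returning an `mpsc::Receiver` lets a consumer block on `recv` while `generate` runs elsewhere, with the `None` sent by `end` acting as an end-of-stream sentinel. A usage sketch, under the assumption that these types are re-exported from `candle_holder_models`:

```rust
use std::sync::Arc;
use std::thread;

use candle_core::Tensor;
use candle_holder::Result;
use candle_holder_models::{GenerationParams, PreTrainedModel, TextIteratorStreamer};
use candle_holder_tokenizers::Tokenizer;

fn stream_to_stdout(
    model: Arc<dyn PreTrainedModel>,
    tokenizer: Arc<dyn Tokenizer>,
    input_ids: Tensor,
) -> Result<()> {
    // `new` now returns the streamer together with the receiving end of the channel.
    let (streamer, receiver) = TextIteratorStreamer::new(tokenizer.clone(), true, true);

    // The new `TokenStreamer: Send` bound is what makes this move legal.
    let worker = thread::spawn(move || {
        model.generate(
            &input_ids,
            GenerationParams {
                tokenizer: Some(tokenizer),
                token_streamer: Some(Box::new(streamer)),
                ..Default::default()
            },
        )
    });

    // `end()` sends a final `None`, so a `None` message means the stream is done.
    while let Ok(Some(text)) = receiver.recv() {
        print!("{text}");
    }

    worker.join().expect("generation thread panicked")?;
    Ok(())
}
```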
diff --git a/candle-holder-models/src/model.rs b/candle-holder-models/src/model.rs
index 3c180b3..bad6db6 100644
--- a/candle-holder-models/src/model.rs
+++ b/candle-holder-models/src/model.rs
@@ -1,3 +1,5 @@
+use std::sync::Arc;
+
 use candle_core::{DType, Device, Tensor};
 use candle_holder::{utils::from_pretrained::FromPretrainedParameters, Error, Result};
 use candle_holder_tokenizers::Tokenizer;
@@ -88,25 +90,25 @@ impl<'a> From<&'a candle_holder_tokenizers::BatchEncoding> for ForwardParams<'a> {
 }

 /// Parameters for the `generate` method of a `PreTrainedModel`.
-pub struct GenerationParams<'a> {
+pub struct GenerationParams {
     /// The generation configuration to use. If not provided, then the model's default generation
     /// configuration in `generation_config.json` will be used. If that is not available, then
     /// the default generation configuration will be used. The default value is `None`.
     pub generation_config: Option<GenerationConfig>,
     /// The tokenizer to use for decoding the generated tokens. It's not strictly required, but
     /// some stopping criteria may depend on it. The default value is `None`.
-    pub tokenizer: Option<&'a Box<dyn Tokenizer>>,
+    pub tokenizer: Option<Arc<dyn Tokenizer>>,
     /// The list of stopping criteria that will determine when to stop generating tokens.
     /// The default value is `None`.
     pub stopping_criteria: Option<Vec<Box<dyn StoppingCriteria>>>,
     /// The token streamer which will receive the next tokens as they are being generated. The
     /// default value is `None`.
-    pub token_streamer: Option<Box<dyn TokenStreamer<'a> + 'a>>,
+    pub token_streamer: Option<Box<dyn TokenStreamer>>,
     /// A seed that will be used in the sampling of the next token. The default value is `None`.
     pub seed: Option<u64>,
 }

-impl Default for GenerationParams<'_> {
+impl Default for GenerationParams {
     fn default() -> Self {
         Self {
             generation_config: None,
@@ -155,7 +157,7 @@ impl ModelOutput {
 }

 /// Trait for a pre-trained model.
-pub trait PreTrainedModel: Send + Sync {
+pub trait PreTrainedModel: std::fmt::Debug + Send + Sync {
     /// Loads a model from a `VarBuilder` containing the model's parameters and a model
     /// configuration.
     ///
@@ -231,10 +233,10 @@ pub trait PreTrainedModel: Send + Sync {
     /// # Returns
     ///
     /// A vector containing vectors of token ids for each input sequence.
-    fn generate<'a>(
+    fn generate(
         &self,
         input_ids: &Tensor,
-        params: GenerationParams<'a>,
+        params: GenerationParams,
     ) -> Result<Vec<GenerateOutput>> {
         let (mut generation_config, used_model_generation_config) = match params.generation_config {
             Some(config) => (config, false),
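The new `std::fmt::Debug` supertrait on `PreTrainedModel` (backed by the `#[derive(Debug)]` additions in the two modeling files below) means any model trait object can now be pretty-printed, e.g. to inspect its layer tree:

```rust
use candle_holder_models::PreTrainedModel;

// Compiles for any implementor now that Debug is a supertrait of PreTrainedModel.
fn log_model(model: &dyn PreTrainedModel) {
    println!("{model:#?}");
}
```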
diff --git a/candle-holder-models/src/models/bert/modeling.rs b/candle-holder-models/src/models/bert/modeling.rs
index b9d5800..06c4f94 100644
--- a/candle-holder-models/src/models/bert/modeling.rs
+++ b/candle-holder-models/src/models/bert/modeling.rs
@@ -15,6 +15,7 @@ use crate::{

 pub const BERT_DTYPE: DType = DType::F32;

+#[derive(Debug)]
 pub struct HiddenActLayer {
     act: HiddenAct,
 }
@@ -32,6 +33,7 @@ impl HiddenActLayer {
     }
 }

+#[derive(Debug)]
 pub struct BertEmbeddings {
     pub word_embeddings: Arc<Embedding>,
     pub position_embeddings: Option<Embedding>,
@@ -88,6 +90,7 @@ impl BertEmbeddings {
     }
 }

+#[derive(Debug)]
 pub struct BertSelfAttention {
     query: Linear,
     key: Linear,
@@ -152,6 +155,7 @@ impl BertSelfAttention {
     }
 }

+#[derive(Debug)]
 pub struct BertSelfOutput {
     dense: Linear,
     layer_norm: LayerNorm,
@@ -185,6 +189,7 @@ impl BertSelfOutput {
     }
 }

+#[derive(Debug)]
 pub struct BertAttention {
     self_attention: BertSelfAttention,
     self_output: BertSelfOutput,
@@ -211,6 +216,7 @@ impl BertAttention {
     }
 }

+#[derive(Debug)]
 pub struct BertIntermediate {
     dense: Linear,
     intermediate_act: HiddenActLayer,
@@ -234,6 +240,7 @@ impl Module for BertIntermediate {
     }
 }

+#[derive(Debug)]
 pub struct BertOutput {
     dense: Linear,
     layer_norm: LayerNorm,
@@ -267,6 +274,7 @@ impl BertOutput {
     }
 }

+#[derive(Debug)]
 pub struct BertLayer {
     attention: BertAttention,
     intermediate: BertIntermediate,
@@ -299,6 +307,7 @@ impl BertLayer {
     }
 }

+#[derive(Debug)]
 pub struct BertEncoder {
     layers: Vec<BertLayer>,
 }
@@ -324,6 +333,7 @@ impl BertEncoder {
     }
 }

+#[derive(Debug)]
 pub struct BertPooler {
     dense: Linear,
 }
@@ -343,6 +353,7 @@ impl Module for BertPooler {
     }
 }

+#[derive(Debug)]
 pub struct BertPredictionHeadTransform {
     dense: Linear,
     transform_act_fn: HiddenActLayer,
@@ -374,6 +385,7 @@ impl Module for BertPredictionHeadTransform {
     }
 }

+#[derive(Debug)]
 pub struct BertLMPredictionHead {
     transform: BertPredictionHeadTransform,
     // The decoder weights are tied with the embeddings weights so the model only learns one
@@ -408,6 +420,7 @@ impl Module for BertLMPredictionHead {
     }
 }

+#[derive(Debug)]
 pub struct BertOnlyMLMHead {
     predictions: BertLMPredictionHead,
 }
@@ -429,6 +442,7 @@ impl Module for BertOnlyMLMHead {
     }
 }

+#[derive(Debug)]
 pub struct Bert {
     embeddings: BertEmbeddings,
     encoder: BertEncoder,
@@ -481,6 +495,7 @@ impl Bert {
     }
 }

+#[derive(Debug)]
 pub struct BertModel {
     model: Bert,
     config: BertConfig,
@@ -512,6 +527,7 @@ impl PreTrainedModel for BertModel {
     }
 }

+#[derive(Debug)]
 pub struct BertForSequenceClassification {
     model: Bert,
     dropout: Dropout,
@@ -561,6 +577,7 @@ impl PreTrainedModel for BertForSequenceClassification {
     }
 }

+#[derive(Debug)]
 pub struct BertForTokenClassification {
     model: Bert,
     dropout: Dropout,
@@ -611,6 +628,7 @@ impl PreTrainedModel for BertForTokenClassification {
     }
 }

+#[derive(Debug)]
 pub struct BertForMaskedLM {
     model: Bert,
     cls: BertOnlyMLMHead,
diff --git a/candle-holder-models/src/models/llama/modeling.rs b/candle-holder-models/src/models/llama/modeling.rs
index 151e423..ac9fcff 100644
--- a/candle-holder-models/src/models/llama/modeling.rs
+++ b/candle-holder-models/src/models/llama/modeling.rs
@@ -22,6 +22,7 @@ use super::config::{HiddenAct, LlamaConfig};

 pub const LLAMA_DTYPE: DType = DType::F16;

+#[derive(Debug)]
 pub struct LlamaRotaryEmbedding {
     inv_freq: Tensor,
     scaling_factor: Option<f32>,
@@ -67,6 +68,7 @@ impl LlamaRotaryEmbedding {
     }
 }

+#[derive(Debug)]
 pub struct LlamaAttention {
     q_proj: Linear,
     k_proj: Linear,
@@ -194,6 +196,7 @@ impl LlamaAttention {
     }
 }

+#[derive(Debug)]
 pub struct HiddenActLayer {
     act: HiddenAct,
 }
@@ -210,6 +213,7 @@ impl HiddenActLayer {
     }
 }

+#[derive(Debug)]
 pub struct LlamaMLP {
     gate_proj: Linear,
     up_proj: Linear,
@@ -253,6 +257,7 @@ impl LlamaMLP {
     }
 }

+#[derive(Debug)]
 pub struct LlamaDecoderLayer {
     self_attn: LlamaAttention,
     mlp: LlamaMLP,
@@ -314,6 +319,7 @@ impl LlamaDecoderLayer {
     }
 }

+#[derive(Debug)]
 pub struct Llama {
     embed_tokens: Embedding,
     layers: Vec<LlamaDecoderLayer>,
@@ -385,6 +391,7 @@ impl Llama {
     }
 }

+#[derive(Debug)]
 pub struct LlamaModel {
     model: Llama,
     config: LlamaConfig,
@@ -407,6 +414,7 @@ impl PreTrainedModel for LlamaModel {
     }
 }

+#[derive(Debug)]
 pub struct LlamaForCausalLM {
     model: Llama,
     lm_head: Linear,
diff --git a/candle-holder-pipelines/src/feature_extraction.rs b/candle-holder-pipelines/src/feature_extraction.rs
index 8e19977..06a6ca2 100644
--- a/candle-holder-pipelines/src/feature_extraction.rs
+++ b/candle-holder-pipelines/src/feature_extraction.rs
@@ -1,3 +1,5 @@
+use std::sync::Arc;
+
 use candle_core::{DType, Device, IndexOp, Tensor, D};
 use candle_holder::{FromPretrainedParameters, Result};
 use candle_holder_models::{AutoModel, ForwardParams, PreTrainedModel};
@@ -35,7 +37,7 @@ impl Default for FeatureExtractionOptions {
 /// A pipeline for generating sentence embeddings from input texts.
 pub struct FeatureExtractionPipeline {
     model: Box<dyn PreTrainedModel>,
-    tokenizer: Box<dyn Tokenizer>,
+    tokenizer: Arc<dyn Tokenizer>,
     device: Device,
 }

diff --git a/candle-holder-pipelines/src/fill_mask.rs b/candle-holder-pipelines/src/fill_mask.rs
index fb58032..4f35ff7 100644
--- a/candle-holder-pipelines/src/fill_mask.rs
+++ b/candle-holder-pipelines/src/fill_mask.rs
@@ -1,3 +1,5 @@
+use std::sync::Arc;
+
 use candle_core::{DType, Device, Tensor, D};
 use candle_holder::{Error, FromPretrainedParameters, Result};
 use candle_holder_models::{AutoModelForMaskedLM, ForwardParams, PreTrainedModel};
@@ -19,7 +21,7 @@ impl Default for FillMaskOptions {
 /// A pipeline for filling masked tokens in a sentence.
 pub struct FillMaskPipeline {
     model: Box<dyn PreTrainedModel>,
-    tokenizer: Box<dyn Tokenizer>,
+    tokenizer: Arc<dyn Tokenizer>,
     device: Device,
 }
diff --git a/candle-holder-pipelines/src/text_classification.rs b/candle-holder-pipelines/src/text_classification.rs
index cd48ba0..a53a2ff 100644
--- a/candle-holder-pipelines/src/text_classification.rs
+++ b/candle-holder-pipelines/src/text_classification.rs
@@ -1,3 +1,5 @@
+use std::sync::Arc;
+
 use candle_core::{DType, Device, Tensor, D};
 use candle_holder::{Error, FromPretrainedParameters, Result};
 use candle_holder_models::{
@@ -9,7 +11,7 @@ use candle_nn::ops::{sigmoid, softmax};

 /// A pipeline for doing text classification.
 pub struct TextClassificationPipeline {
     model: Box<dyn PreTrainedModel>,
-    tokenizer: Box<dyn Tokenizer>,
+    tokenizer: Arc<dyn Tokenizer>,
     device: Device,
 }

diff --git a/candle-holder-pipelines/src/text_generation.rs b/candle-holder-pipelines/src/text_generation.rs
index b624f32..a0a87b2 100644
--- a/candle-holder-pipelines/src/text_generation.rs
+++ b/candle-holder-pipelines/src/text_generation.rs
@@ -1,3 +1,5 @@
+use std::sync::Arc;
+
 use candle_core::{DType, Device, Tensor};
 use candle_holder::{FromPretrainedParameters, Result};
 use candle_holder_models::{
@@ -78,7 +80,7 @@ impl TextGenerationPipelineOutput {
 /// A pipeline for generating text with a causal language model.
 pub struct TextGenerationPipeline {
     model: Box<dyn PreTrainedModel>,
-    tokenizer: Box<dyn Tokenizer>,
+    tokenizer: Arc<dyn Tokenizer>,
     device: Device,
 }

diff --git a/candle-holder-pipelines/src/token_classification.rs b/candle-holder-pipelines/src/token_classification.rs
index 8265e8e..1d9f778 100644
--- a/candle-holder-pipelines/src/token_classification.rs
+++ b/candle-holder-pipelines/src/token_classification.rs
@@ -3,7 +3,7 @@ use candle_holder::{Error, FromPretrainedParameters, Result};
 use candle_holder_models::{AutoModelForTokenClassification, ForwardParams, PreTrainedModel};
 use candle_holder_tokenizers::{AutoTokenizer, BatchEncoding, Padding, PaddingOptions, Tokenizer};
 use serde::Deserialize;
-use std::collections::HashMap;
+use std::{collections::HashMap, sync::Arc};
 use tokenizers::Encoding;

 #[derive(Debug, Clone)]
@@ -112,7 +112,7 @@ fn substring(s: &str, start: usize, end: usize) -> String {

 /// A pipeline for token classification.
 pub struct TokenClassificationPipeline {
     model: Box<dyn PreTrainedModel>,
-    tokenizer: Box<dyn Tokenizer>,
+    tokenizer: Arc<dyn Tokenizer>,
     device: Device,
     id2label: HashMap<usize, String>,
 }

diff --git a/candle-holder-pipelines/src/zero_shot_classification.rs b/candle-holder-pipelines/src/zero_shot_classification.rs
index df04d41..3b915ae 100644
--- a/candle-holder-pipelines/src/zero_shot_classification.rs
+++ b/candle-holder-pipelines/src/zero_shot_classification.rs
@@ -1,3 +1,5 @@
+use std::sync::Arc;
+
 use candle_core::{DType, Device, IndexOp, Tensor, D};
 use candle_holder::{Error, FromPretrainedParameters, Result};
 use candle_holder_models::{AutoModelForSequenceClassification, ForwardParams, PreTrainedModel};
@@ -24,7 +26,7 @@ impl Default for ZeroShotClassificationOptions {
 /// A pipeline for doing zero-shot classification.
 pub struct ZeroShotClassificationPipeline {
     model: Box<dyn PreTrainedModel>,
-    tokenizer: Box<dyn Tokenizer>,
+    tokenizer: Arc<dyn Tokenizer>,
     device: Device,
     num_labels: usize,
     entailment_id: i8,
diff --git a/candle-holder-tokenizers/src/from_pretrained.rs b/candle-holder-tokenizers/src/from_pretrained.rs
index 972fcd9..ef5a48a 100644
--- a/candle-holder-tokenizers/src/from_pretrained.rs
+++ b/candle-holder-tokenizers/src/from_pretrained.rs
@@ -5,7 +5,6 @@ use candle_holder::{
     utils::from_pretrained::{load_model_config, FromPretrainedParameters, MODEL_CONFIG_FILE},
     Result,
 };
-use hf_hub::{api::sync::Api, Repo, RepoType};
 use lazy_static::lazy_static;
 use serde::{Deserialize, Deserializer, Serialize};
 use tokenizers::{
diff --git a/candle-holder-tokenizers/src/tokenizer.rs b/candle-holder-tokenizers/src/tokenizer.rs
index d03bf0a..399b3d8 100644
--- a/candle-holder-tokenizers/src/tokenizer.rs
+++ b/candle-holder-tokenizers/src/tokenizer.rs
@@ -1,3 +1,5 @@
+use std::sync::Arc;
+
 use candle_core::{DType, Device, Tensor};
 use candle_holder::{Error, FromPretrainedParameters, Result};
 use tokenizers::{
@@ -462,16 +464,16 @@ macro_rules! impl_auto_tokenizer_from_pretrained_method {
             repo_id: S,
             padding_side: Option<PaddingSide>,
             params: Option<FromPretrainedParameters>,
-        ) -> Result<Box<dyn Tokenizer>> {
+        ) -> Result<Arc<dyn Tokenizer>> {
             let tokenizer_info = from_pretrained(repo_id, params)?;
             let tokenizer_class = tokenizer_info.get_tokenizer_class();

-            let tokenizer: Result<Box<dyn Tokenizer>> = match tokenizer_class {
+            let tokenizer: Result<Arc<dyn Tokenizer>> = match tokenizer_class {
                 $(
                     $tokenizer_class => {
                         $tokenizer_builder_struct::new(tokenizer_info, padding_side)
                             .build()
-                            .map(|tokenizer| Box::new(tokenizer) as Box<dyn Tokenizer>)
+                            .map(|tokenizer| Arc::new(tokenizer) as Arc<dyn Tokenizer>)
                     }
                 )*
                 _ => Err(Error::TokenizerNotImplemented(tokenizer_class.to_string())),
@@ -492,11 +494,11 @@ macro_rules! impl_tokenizer_from_pretrained_method {
             repo_id: S,
             padding_side: Option<PaddingSide>,
             params: Option<FromPretrainedParameters>,
-        ) -> Result<Box<dyn Tokenizer>> {
+        ) -> Result<Arc<dyn Tokenizer>> {
             let tokenizer_info = from_pretrained(repo_id, params)?;
             let tokenizer = $tokenizer_builder_struct::new(tokenizer_info, padding_side).build()?;

-            Ok(Box::new(tokenizer))
+            Ok(Arc::new(tokenizer))
         }
     }
 };
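Since both `from_pretrained` constructors now return `Arc<dyn Tokenizer>` instead of `Box<dyn Tokenizer>`, one loaded tokenizer can be shared between, say, a pipeline and a token streamer with a cheap pointer clone. A usage sketch (repo id picked purely for illustration):

```rust
use std::sync::Arc;

use candle_holder::Result;
use candle_holder_tokenizers::{AutoTokenizer, Tokenizer};

fn load_shared() -> Result<(Arc<dyn Tokenizer>, Arc<dyn Tokenizer>)> {
    // `from_pretrained(repo_id, padding_side, params)` per the macro above.
    let tokenizer = AutoTokenizer::from_pretrained("bert-base-uncased", None, None)?;
    // Cloning an Arc bumps a reference count; the vocabulary is not copied.
    let for_streamer = tokenizer.clone();
    Ok((tokenizer, for_streamer))
}
```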