From 87090237c4718ecf7433c570480fb1cbc6db44d0 Mon Sep 17 00:00:00 2001 From: gabrielmbmb Date: Tue, 3 Sep 2024 00:14:58 +0200 Subject: [PATCH] Add `FeatureExtractionPipeline` --- .../feature_extraction_pipeline/README.md | 7 + .../feature_extraction_pipeline/main.rs | 26 +++ .../src/feature_extraction.rs | 165 ++++++++++++++++++ candle-holder-pipelines/src/lib.rs | 2 + 4 files changed, 200 insertions(+) create mode 100644 candle-holder-examples/examples/feature_extraction_pipeline/README.md create mode 100644 candle-holder-examples/examples/feature_extraction_pipeline/main.rs create mode 100644 candle-holder-pipelines/src/feature_extraction.rs diff --git a/candle-holder-examples/examples/feature_extraction_pipeline/README.md b/candle-holder-examples/examples/feature_extraction_pipeline/README.md new file mode 100644 index 0000000..5c9b6fb --- /dev/null +++ b/candle-holder-examples/examples/feature_extraction_pipeline/README.md @@ -0,0 +1,7 @@ +# Feature Extraction Pipeline + +## Running the example + +```bash +cargo run --example feature_extraction_pipeline +``` diff --git a/candle-holder-examples/examples/feature_extraction_pipeline/main.rs b/candle-holder-examples/examples/feature_extraction_pipeline/main.rs new file mode 100644 index 0000000..4f76870 --- /dev/null +++ b/candle-holder-examples/examples/feature_extraction_pipeline/main.rs @@ -0,0 +1,26 @@ +use anyhow::Result; +use candle_holder_examples::get_device_from_args; +use candle_holder_pipelines::{FeatureExtractionOptions, FeatureExtractionPipeline, Pooling}; + +fn main() -> Result<()> { + let device = get_device_from_args()?; + println!("Device: {:?}", device); + + let pipeline = FeatureExtractionPipeline::new( + "sentence-transformers/all-MiniLM-L6-v2", + &device, + None, + None, + )?; + + let results = pipeline.run("This is an example sentence", None)?; + println!("`pipeline.run` results: {}", results); + + let results = pipeline.run_batch( + vec!["This is an example sentence", "Each sentence is converted"], + None, + )?; + println!("`pipeline.run_batch` results: {}", results); + + Ok(()) +} diff --git a/candle-holder-pipelines/src/feature_extraction.rs b/candle-holder-pipelines/src/feature_extraction.rs new file mode 100644 index 0000000..e441f52 --- /dev/null +++ b/candle-holder-pipelines/src/feature_extraction.rs @@ -0,0 +1,165 @@ +use candle_core::{DType, Device, IndexOp, Tensor, D}; +use candle_holder::{FromPretrainedParameters, Result}; +use candle_holder_models::{AutoModel, ForwardParams, PreTrainedModel}; +use candle_holder_tokenizers::{AutoTokenizer, BatchEncoding, Padding, PaddingOptions, Tokenizer}; + +/// The pooling strategy that will be used to pool the outputs of the model. +pub enum Pooling { + /// Mean pooling. + Mean, + /// Use the CLS token. + Cls, +} + +/// Options for the [`FeatureExtractionPipeline`]. +pub struct FeatureExtractionOptions { + /// The pooling strategy that will be used to pool the outputs of the model. + pub pooling: Option, + /// Whether to normalize the outputs. + pub normalize: bool, +} + +impl Default for FeatureExtractionOptions { + fn default() -> Self { + Self { + pooling: Some(Pooling::Mean), + normalize: true, + } + } +} + +/// A pipeline for generating sentence embeddings from input texts. +pub struct FeatureExtractionPipeline { + model: Box, + tokenizer: Box, + device: Device, +} + +impl FeatureExtractionPipeline { + /// Creates a new `FeatureExtractionPipeline`. + /// + /// # Arguments + /// + /// * `identifier` - The repository id of the model to load. + /// * `device` - The device to run the model on. + /// * `params` - Optional parameters to specify the revision, user agent, and auth token. + /// + /// # Returns + /// + /// The `ZeroShotClassificationPipeline` instance. + pub fn new + Copy>( + identifier: S, + device: &Device, + dtype: Option, + params: Option, + ) -> Result { + let model = AutoModel::from_pretrained(identifier, device, dtype, params.clone())?; + let tokenizer = AutoTokenizer::from_pretrained(identifier, None, params)?; + Ok(Self { + model, + tokenizer, + device: device.clone(), + }) + } + + fn preprocess(&self, sequences: Vec) -> Result { + let mut encodings = self.tokenizer.encode( + sequences, + true, + Some(PaddingOptions::new(Padding::Longest, None)), + )?; + encodings.to_device(&self.device)?; + Ok(encodings) + } + + fn postprocess( + &self, + outputs: &Tensor, + encodings: BatchEncoding, + options: FeatureExtractionOptions, + ) -> Result { + let mut outputs = outputs.clone(); + + if let Some(pooling) = options.pooling { + outputs = match pooling { + Pooling::Mean => mean_pooling(&outputs, encodings.get_attention_mask())?, + Pooling::Cls => outputs.i((.., 0))?, + }; + } + + if options.normalize { + outputs = outputs.broadcast_div(&outputs.powf(2.)?.sum(1)?.sqrt()?.unsqueeze(1)?)?; + } + + Ok(outputs) + } + + /// Generates a sentence embedding for the input text. + /// + /// # Arguments + /// + /// * `input` - The input text. + /// * `options` - Optional feature extraction options. + /// + /// # Returns + /// + /// The sentence embedding tensor. + pub fn run>( + &self, + input: I, + options: Option, + ) -> Result { + let options = options.unwrap_or_default(); + let inputs = vec![input.into()]; + let encodings = self.preprocess(inputs)?; + let outputs = self.model.forward(ForwardParams::from(&encodings))?; + let last_hidden_states = outputs.get_last_hidden_state().unwrap(); + Ok(self + .postprocess(last_hidden_states, encodings, options)? + .i(0)?) + } + + /// Generates sentence embeddings for a batch of input texts. + /// + /// # Arguments + /// + /// * `inputs` - The input texts. + /// * `options` - Optional feature extraction options. + /// + /// # Returns + /// + /// A tensor containing the sentence embeddings. + pub fn run_batch>( + &self, + inputs: Vec, + options: Option, + ) -> Result { + let options = options.unwrap_or_default(); + let inputs: Vec = inputs.into_iter().map(|x| x.into()).collect(); + let encodings = self.preprocess(inputs)?; + let outputs = self.model.forward(ForwardParams::from(&encodings))?; + let last_hidden_states = outputs.get_last_hidden_state().unwrap(); + self.postprocess(last_hidden_states, encodings, options) + } +} + +/// Computes the mean pooling of the outputs. +/// +/// # Arguments +/// +/// * `outputs` - The model outputs. +/// * `attention_mask` - The attention mask. +/// +/// # Returns +/// +/// The mean pooled tensor. +fn mean_pooling(outputs: &Tensor, attention_mask: &Tensor) -> Result { + let input_mask_expanded = attention_mask + .unsqueeze(D::Minus1)? + .repeat((1, 1, outputs.dims3()?.2))? + .to_dtype(outputs.dtype())?; + Ok(outputs + .mul(&input_mask_expanded)? + .sum(1)? + .broadcast_div(&input_mask_expanded.sum(1)?.clamp(f64::MIN, f64::MAX)?)?) +} diff --git a/candle-holder-pipelines/src/lib.rs b/candle-holder-pipelines/src/lib.rs index 1238364..35cd141 100644 --- a/candle-holder-pipelines/src/lib.rs +++ b/candle-holder-pipelines/src/lib.rs @@ -1,9 +1,11 @@ +pub mod feature_extraction; pub mod fill_mask; pub mod text_classification; pub mod text_generation; pub mod token_classification; pub mod zero_shot_classification; +pub use feature_extraction::{FeatureExtractionOptions, FeatureExtractionPipeline, Pooling}; pub use fill_mask::{FillMaskOptions, FillMaskPipeline}; pub use text_classification::TextClassificationPipeline; pub use text_generation::{TextGenerationPipeline, TextGenerationPipelineOutput};