Skip to content

Commit

Permalink
Add FeatureExtractionPipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
gabrielmbmb committed Sep 2, 2024
1 parent 77f2fd1 commit 8709023
Show file tree
Hide file tree
Showing 4 changed files with 200 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Feature Extraction Pipeline

## Running the example

```bash
cargo run --example feature_extraction_pipeline
```
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
use anyhow::Result;
use candle_holder_examples::get_device_from_args;
use candle_holder_pipelines::{FeatureExtractionOptions, FeatureExtractionPipeline, Pooling};

fn main() -> Result<()> {
let device = get_device_from_args()?;
println!("Device: {:?}", device);

let pipeline = FeatureExtractionPipeline::new(
"sentence-transformers/all-MiniLM-L6-v2",
&device,
None,
None,
)?;

let results = pipeline.run("This is an example sentence", None)?;
println!("`pipeline.run` results: {}", results);

let results = pipeline.run_batch(
vec!["This is an example sentence", "Each sentence is converted"],
None,
)?;
println!("`pipeline.run_batch` results: {}", results);

Ok(())
}
165 changes: 165 additions & 0 deletions candle-holder-pipelines/src/feature_extraction.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
use candle_core::{DType, Device, IndexOp, Tensor, D};
use candle_holder::{FromPretrainedParameters, Result};
use candle_holder_models::{AutoModel, ForwardParams, PreTrainedModel};
use candle_holder_tokenizers::{AutoTokenizer, BatchEncoding, Padding, PaddingOptions, Tokenizer};

/// The pooling strategy that will be used to pool the outputs of the model.
pub enum Pooling {
/// Mean pooling.
Mean,
/// Use the CLS token.
Cls,
}

/// Options for the [`FeatureExtractionPipeline`].
pub struct FeatureExtractionOptions {
/// The pooling strategy that will be used to pool the outputs of the model.
pub pooling: Option<Pooling>,
/// Whether to normalize the outputs.
pub normalize: bool,
}

impl Default for FeatureExtractionOptions {
fn default() -> Self {
Self {
pooling: Some(Pooling::Mean),
normalize: true,
}
}
}

/// A pipeline for generating sentence embeddings from input texts.
pub struct FeatureExtractionPipeline {
model: Box<dyn PreTrainedModel>,
tokenizer: Box<dyn Tokenizer>,
device: Device,
}

impl FeatureExtractionPipeline {
/// Creates a new `FeatureExtractionPipeline`.
///
/// # Arguments
///
/// * `identifier` - The repository id of the model to load.
/// * `device` - The device to run the model on.
/// * `params` - Optional parameters to specify the revision, user agent, and auth token.
///
/// # Returns
///
/// The `ZeroShotClassificationPipeline` instance.
pub fn new<S: AsRef<str> + Copy>(
identifier: S,
device: &Device,
dtype: Option<DType>,
params: Option<FromPretrainedParameters>,
) -> Result<Self> {
let model = AutoModel::from_pretrained(identifier, device, dtype, params.clone())?;
let tokenizer = AutoTokenizer::from_pretrained(identifier, None, params)?;
Ok(Self {
model,
tokenizer,
device: device.clone(),
})
}

fn preprocess(&self, sequences: Vec<String>) -> Result<BatchEncoding> {
let mut encodings = self.tokenizer.encode(
sequences,
true,
Some(PaddingOptions::new(Padding::Longest, None)),
)?;
encodings.to_device(&self.device)?;
Ok(encodings)
}

fn postprocess(
&self,
outputs: &Tensor,
encodings: BatchEncoding,
options: FeatureExtractionOptions,
) -> Result<Tensor> {
let mut outputs = outputs.clone();

if let Some(pooling) = options.pooling {
outputs = match pooling {
Pooling::Mean => mean_pooling(&outputs, encodings.get_attention_mask())?,
Pooling::Cls => outputs.i((.., 0))?,
};
}

if options.normalize {
outputs = outputs.broadcast_div(&outputs.powf(2.)?.sum(1)?.sqrt()?.unsqueeze(1)?)?;
}

Ok(outputs)
}

/// Generates a sentence embedding for the input text.
///
/// # Arguments
///
/// * `input` - The input text.
/// * `options` - Optional feature extraction options.
///
/// # Returns
///
/// The sentence embedding tensor.
pub fn run<I: Into<String>>(
&self,
input: I,
options: Option<FeatureExtractionOptions>,
) -> Result<Tensor> {
let options = options.unwrap_or_default();
let inputs = vec![input.into()];
let encodings = self.preprocess(inputs)?;
let outputs = self.model.forward(ForwardParams::from(&encodings))?;
let last_hidden_states = outputs.get_last_hidden_state().unwrap();
Ok(self
.postprocess(last_hidden_states, encodings, options)?
.i(0)?)
}

/// Generates sentence embeddings for a batch of input texts.
///
/// # Arguments
///
/// * `inputs` - The input texts.
/// * `options` - Optional feature extraction options.
///
/// # Returns
///
/// A tensor containing the sentence embeddings.
pub fn run_batch<I: Into<String>>(
&self,
inputs: Vec<I>,
options: Option<FeatureExtractionOptions>,
) -> Result<Tensor> {
let options = options.unwrap_or_default();
let inputs: Vec<String> = inputs.into_iter().map(|x| x.into()).collect();
let encodings = self.preprocess(inputs)?;
let outputs = self.model.forward(ForwardParams::from(&encodings))?;
let last_hidden_states = outputs.get_last_hidden_state().unwrap();
self.postprocess(last_hidden_states, encodings, options)
}
}

/// Computes the mean pooling of the outputs.
///
/// # Arguments
///
/// * `outputs` - The model outputs.
/// * `attention_mask` - The attention mask.
///
/// # Returns
///
/// The mean pooled tensor.
fn mean_pooling(outputs: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
let input_mask_expanded = attention_mask
.unsqueeze(D::Minus1)?
.repeat((1, 1, outputs.dims3()?.2))?
.to_dtype(outputs.dtype())?;
Ok(outputs
.mul(&input_mask_expanded)?
.sum(1)?
.broadcast_div(&input_mask_expanded.sum(1)?.clamp(f64::MIN, f64::MAX)?)?)
}
2 changes: 2 additions & 0 deletions candle-holder-pipelines/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
pub mod feature_extraction;
pub mod fill_mask;
pub mod text_classification;
pub mod text_generation;
pub mod token_classification;
pub mod zero_shot_classification;

pub use feature_extraction::{FeatureExtractionOptions, FeatureExtractionPipeline, Pooling};
pub use fill_mask::{FillMaskOptions, FillMaskPipeline};
pub use text_classification::TextClassificationPipeline;
pub use text_generation::{TextGenerationPipeline, TextGenerationPipelineOutput};
Expand Down

0 comments on commit 8709023

Please sign in to comment.