-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
77f2fd1
commit 8709023
Showing
4 changed files
with
200 additions
and
0 deletions.
There are no files selected for viewing
7 changes: 7 additions & 0 deletions
7
candle-holder-examples/examples/feature_extraction_pipeline/README.md
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
# Feature Extraction Pipeline | ||
|
||
## Running the example | ||
|
||
```bash | ||
cargo run --example feature_extraction_pipeline | ||
``` |
26 changes: 26 additions & 0 deletions
26
candle-holder-examples/examples/feature_extraction_pipeline/main.rs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
use anyhow::Result; | ||
use candle_holder_examples::get_device_from_args; | ||
use candle_holder_pipelines::{FeatureExtractionOptions, FeatureExtractionPipeline, Pooling}; | ||
|
||
fn main() -> Result<()> { | ||
let device = get_device_from_args()?; | ||
println!("Device: {:?}", device); | ||
|
||
let pipeline = FeatureExtractionPipeline::new( | ||
"sentence-transformers/all-MiniLM-L6-v2", | ||
&device, | ||
None, | ||
None, | ||
)?; | ||
|
||
let results = pipeline.run("This is an example sentence", None)?; | ||
println!("`pipeline.run` results: {}", results); | ||
|
||
let results = pipeline.run_batch( | ||
vec!["This is an example sentence", "Each sentence is converted"], | ||
None, | ||
)?; | ||
println!("`pipeline.run_batch` results: {}", results); | ||
|
||
Ok(()) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,165 @@ | ||
use candle_core::{DType, Device, IndexOp, Tensor, D}; | ||
use candle_holder::{FromPretrainedParameters, Result}; | ||
use candle_holder_models::{AutoModel, ForwardParams, PreTrainedModel}; | ||
use candle_holder_tokenizers::{AutoTokenizer, BatchEncoding, Padding, PaddingOptions, Tokenizer}; | ||
|
||
/// The pooling strategy that will be used to pool the outputs of the model. | ||
pub enum Pooling { | ||
/// Mean pooling. | ||
Mean, | ||
/// Use the CLS token. | ||
Cls, | ||
} | ||
|
||
/// Options for the [`FeatureExtractionPipeline`]. | ||
pub struct FeatureExtractionOptions { | ||
/// The pooling strategy that will be used to pool the outputs of the model. | ||
pub pooling: Option<Pooling>, | ||
/// Whether to normalize the outputs. | ||
pub normalize: bool, | ||
} | ||
|
||
impl Default for FeatureExtractionOptions { | ||
fn default() -> Self { | ||
Self { | ||
pooling: Some(Pooling::Mean), | ||
normalize: true, | ||
} | ||
} | ||
} | ||
|
||
/// A pipeline for generating sentence embeddings from input texts. | ||
pub struct FeatureExtractionPipeline { | ||
model: Box<dyn PreTrainedModel>, | ||
tokenizer: Box<dyn Tokenizer>, | ||
device: Device, | ||
} | ||
|
||
impl FeatureExtractionPipeline { | ||
/// Creates a new `FeatureExtractionPipeline`. | ||
/// | ||
/// # Arguments | ||
/// | ||
/// * `identifier` - The repository id of the model to load. | ||
/// * `device` - The device to run the model on. | ||
/// * `params` - Optional parameters to specify the revision, user agent, and auth token. | ||
/// | ||
/// # Returns | ||
/// | ||
/// The `ZeroShotClassificationPipeline` instance. | ||
pub fn new<S: AsRef<str> + Copy>( | ||
identifier: S, | ||
device: &Device, | ||
dtype: Option<DType>, | ||
params: Option<FromPretrainedParameters>, | ||
) -> Result<Self> { | ||
let model = AutoModel::from_pretrained(identifier, device, dtype, params.clone())?; | ||
let tokenizer = AutoTokenizer::from_pretrained(identifier, None, params)?; | ||
Ok(Self { | ||
model, | ||
tokenizer, | ||
device: device.clone(), | ||
}) | ||
} | ||
|
||
fn preprocess(&self, sequences: Vec<String>) -> Result<BatchEncoding> { | ||
let mut encodings = self.tokenizer.encode( | ||
sequences, | ||
true, | ||
Some(PaddingOptions::new(Padding::Longest, None)), | ||
)?; | ||
encodings.to_device(&self.device)?; | ||
Ok(encodings) | ||
} | ||
|
||
fn postprocess( | ||
&self, | ||
outputs: &Tensor, | ||
encodings: BatchEncoding, | ||
options: FeatureExtractionOptions, | ||
) -> Result<Tensor> { | ||
let mut outputs = outputs.clone(); | ||
|
||
if let Some(pooling) = options.pooling { | ||
outputs = match pooling { | ||
Pooling::Mean => mean_pooling(&outputs, encodings.get_attention_mask())?, | ||
Pooling::Cls => outputs.i((.., 0))?, | ||
}; | ||
} | ||
|
||
if options.normalize { | ||
outputs = outputs.broadcast_div(&outputs.powf(2.)?.sum(1)?.sqrt()?.unsqueeze(1)?)?; | ||
} | ||
|
||
Ok(outputs) | ||
} | ||
|
||
/// Generates a sentence embedding for the input text. | ||
/// | ||
/// # Arguments | ||
/// | ||
/// * `input` - The input text. | ||
/// * `options` - Optional feature extraction options. | ||
/// | ||
/// # Returns | ||
/// | ||
/// The sentence embedding tensor. | ||
pub fn run<I: Into<String>>( | ||
&self, | ||
input: I, | ||
options: Option<FeatureExtractionOptions>, | ||
) -> Result<Tensor> { | ||
let options = options.unwrap_or_default(); | ||
let inputs = vec![input.into()]; | ||
let encodings = self.preprocess(inputs)?; | ||
let outputs = self.model.forward(ForwardParams::from(&encodings))?; | ||
let last_hidden_states = outputs.get_last_hidden_state().unwrap(); | ||
Ok(self | ||
.postprocess(last_hidden_states, encodings, options)? | ||
.i(0)?) | ||
} | ||
|
||
/// Generates sentence embeddings for a batch of input texts. | ||
/// | ||
/// # Arguments | ||
/// | ||
/// * `inputs` - The input texts. | ||
/// * `options` - Optional feature extraction options. | ||
/// | ||
/// # Returns | ||
/// | ||
/// A tensor containing the sentence embeddings. | ||
pub fn run_batch<I: Into<String>>( | ||
&self, | ||
inputs: Vec<I>, | ||
options: Option<FeatureExtractionOptions>, | ||
) -> Result<Tensor> { | ||
let options = options.unwrap_or_default(); | ||
let inputs: Vec<String> = inputs.into_iter().map(|x| x.into()).collect(); | ||
let encodings = self.preprocess(inputs)?; | ||
let outputs = self.model.forward(ForwardParams::from(&encodings))?; | ||
let last_hidden_states = outputs.get_last_hidden_state().unwrap(); | ||
self.postprocess(last_hidden_states, encodings, options) | ||
} | ||
} | ||
|
||
/// Computes the mean pooling of the outputs. | ||
/// | ||
/// # Arguments | ||
/// | ||
/// * `outputs` - The model outputs. | ||
/// * `attention_mask` - The attention mask. | ||
/// | ||
/// # Returns | ||
/// | ||
/// The mean pooled tensor. | ||
fn mean_pooling(outputs: &Tensor, attention_mask: &Tensor) -> Result<Tensor> { | ||
let input_mask_expanded = attention_mask | ||
.unsqueeze(D::Minus1)? | ||
.repeat((1, 1, outputs.dims3()?.2))? | ||
.to_dtype(outputs.dtype())?; | ||
Ok(outputs | ||
.mul(&input_mask_expanded)? | ||
.sum(1)? | ||
.broadcast_div(&input_mask_expanded.sum(1)?.clamp(f64::MIN, f64::MAX)?)?) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters