diff --git a/.gitignore b/.gitignore
index cb9bedb..ca59a89 100644
--- a/.gitignore
+++ b/.gitignore
@@ -13,3 +13,9 @@ Cargo.lock
 # MSVC Windows builds of rustc generate these, which store debugging information
 *.pdb
 .DS_Store
+
+# direnv files
+.envrc
+
+# Editor files
+.vscode
diff --git a/bert-burn/examples/infer-embedding.rs b/bert-burn/examples/infer-embedding.rs
index 524f8ae..978f369 100644
--- a/bert-burn/examples/infer-embedding.rs
+++ b/bert-burn/examples/infer-embedding.rs
@@ -54,14 +54,14 @@ pub fn launch<B: Backend>(device: B::Device) {

     let tokenizer = Arc::new(BertTokenizer::new(
         model_variant.to_string(),
-        model_config.pad_token_id.clone(),
+        model_config.pad_token_id,
     ));

     // Batch the input samples to max sequence length with padding
     let batcher = Arc::new(BertInputBatcher::<B>::new(
         tokenizer.clone(),
         device.clone(),
-        model_config.max_seq_len.unwrap().clone(),
+        model_config.max_seq_len.unwrap(),
     ));

     // Batch input samples using the batcher Shape: [Batch size, Seq_len]
@@ -75,11 +75,12 @@ pub fn launch<B: Backend>(device: B::Device) {
     let cls_token_idx = 0;

     // Embedding size
-    let d_model = model_config.hidden_size.clone();
-    let sentence_embedding =
-        output
-            .clone()
-            .slice([0..batch_size, cls_token_idx..cls_token_idx + 1, 0..d_model]);
+    let d_model = model_config.hidden_size;
+    let sentence_embedding = output.hidden_states.clone().slice([
+        0..batch_size,
+        cls_token_idx..cls_token_idx + 1,
+        0..d_model,
+    ]);
     let sentence_embedding: Tensor<B, 2> = sentence_embedding.squeeze(1);

     println!(
diff --git a/bert-burn/src/embedding.rs b/bert-burn/src/embedding.rs
index 2d95776..19051fe 100644
--- a/bert-burn/src/embedding.rs
+++ b/bert-burn/src/embedding.rs
@@ -18,13 +18,13 @@ pub struct BertEmbeddingsConfig {

 #[derive(Module, Debug)]
 pub struct BertEmbeddings<B: Backend> {
+    pub pad_token_idx: usize,
     word_embeddings: Embedding<B>,
     position_embeddings: Embedding<B>,
     token_type_embeddings: Embedding<B>,
     layer_norm: LayerNorm<B>,
     dropout: Dropout,
     max_position_embeddings: usize,
-    pad_token_idx: usize,
 }

 impl BertEmbeddingsConfig {
@@ -89,7 +89,7 @@ impl<B: Backend> BertEmbeddings<B> {
             )
             .reshape([1, seq_length]);
             position_ids_tensor =
-                position_ids.mask_fill(item.mask_pad.clone(), self.pad_token_idx.clone() as i32);
+                position_ids.mask_fill(item.mask_pad.clone(), self.pad_token_idx as i32);
         }

         let position_embeddings = self.position_embeddings.forward(position_ids_tensor);
diff --git a/bert-burn/src/lib.rs b/bert-burn/src/lib.rs
index eede88f..1dead18 100644
--- a/bert-burn/src/lib.rs
+++ b/bert-burn/src/lib.rs
@@ -5,3 +5,4 @@ pub mod data;
 mod embedding;
 pub mod loader;
 pub mod model;
+pub mod pooler;
diff --git a/bert-burn/src/loader.rs b/bert-burn/src/loader.rs
index 60b7c52..03e717e 100644
--- a/bert-burn/src/loader.rs
+++ b/bert-burn/src/loader.rs
@@ -1,9 +1,9 @@
 // This file contains logic to load the BERT Model from the safetensor format available on Hugging Face Hub.
 // Some utility functions are referenced from: https://github.com/tvergho/sentence-transformers-burn/tree/main

-use crate::model::BertModelConfig;
-
 use crate::embedding::BertEmbeddingsRecord;
+use crate::model::BertModelConfig;
+use crate::pooler::PoolerRecord;
 use burn::config::Config;
 use burn::module::{ConstantRecord, Param};
 use burn::nn::attention::MultiHeadAttentionRecord;
@@ -142,6 +142,7 @@ fn load_attention_layer_safetensor(
     attention_record
 }

+/// Load the BERT encoder from the safetensor format available on Hugging Face Hub
 pub fn load_encoder_from_safetensors<B: Backend>(
     encoder_tensors: HashMap<String, CandleTensor>,
     device: &B::Device,
@@ -151,15 +152,13 @@
     let mut layers: HashMap<usize, HashMap<String, CandleTensor>> = HashMap::new();

     for (key, value) in encoder_tensors.iter() {
-        let layer_number = key.split(".").collect::<Vec<&str>>()[2]
+        let layer_number = key.split('.').collect::<Vec<&str>>()[2]
             .parse::<usize>()
             .unwrap();
-        if !layers.contains_key(&layer_number) {
-            layers.insert(layer_number, HashMap::new());
-        }
+
         layers
-            .get_mut(&layer_number)
-            .unwrap()
+            .entry(layer_number)
+            .or_default()
             .insert(key.to_string(), value.clone());
     }

@@ -241,6 +240,7 @@ fn load_embedding_safetensor(
     embedding
 }

+/// Load the BERT embeddings from the safetensor format available on Hugging Face Hub
 pub fn load_embeddings_from_safetensors<B: Backend>(
     embedding_tensors: HashMap<String, CandleTensor>,
     device: &B::Device,
@@ -282,10 +282,24 @@
         max_position_embeddings: ConstantRecord::new(),
         pad_token_idx: ConstantRecord::new(),
     };
-
     embeddings_record
 }

+/// Load the BERT pooler from the safetensor format available on Hugging Face Hub
+pub fn load_pooler_from_safetensors<B: Backend>(
+    pooler_tensors: HashMap<String, CandleTensor>,
+    device: &B::Device,
+) -> PoolerRecord<B> {
+    let output = load_linear_safetensor(
+        &pooler_tensors["pooler.dense.bias"],
+        &pooler_tensors["pooler.dense.weight"],
+        device,
+    );
+
+    PoolerRecord { output }
+}
+
+/// Load the BERT model config from the JSON format available on Hugging Face Hub
 pub fn load_model_config(path: PathBuf) -> BertModelConfig {
     let mut model_config = BertModelConfig::load(path).expect("Config file present");
     model_config.max_seq_len = Some(512);
@@ -299,15 +313,19 @@ pub async fn download_hf_model(model_name: &str) -> (PathBuf, PathBuf) {
     let api = hf_hub::api::tokio::Api::new().unwrap();
     let repo = api.model(model_name.to_string());

-    let model_filepath = repo.get("model.safetensors").await.expect(&format!(
-        "Failed to download: {} weights with name: model.safetensors from HuggingFace Hub",
-        model_name
-    ));
-
-    let config_filepath = repo.get("config.json").await.expect(&format!(
-        "Failed to download: {} config with name: config.json from HuggingFace Hub",
-        model_name
-    ));
+    let model_filepath = repo.get("model.safetensors").await.unwrap_or_else(|_| {
+        panic!(
+            "Failed to download: {} weights with name: model.safetensors from HuggingFace Hub",
+            model_name
+        )
+    });
+
+    let config_filepath = repo.get("config.json").await.unwrap_or_else(|_| {
+        panic!(
+            "Failed to download: {} config with name: config.json from HuggingFace Hub",
+            model_name
+        )
+    });

     (config_filepath, model_filepath)
 }
diff --git a/bert-burn/src/model.rs b/bert-burn/src/model.rs
index b13aaa9..4810d7c 100644
--- a/bert-burn/src/model.rs
+++ b/bert-burn/src/model.rs
@@ -1,6 +1,9 @@
 use crate::data::BertInferenceBatch;
 use crate::embedding::{BertEmbeddings, BertEmbeddingsConfig};
-use crate::loader::{load_embeddings_from_safetensors, load_encoder_from_safetensors};
+use crate::loader::{
+    load_embeddings_from_safetensors, load_encoder_from_safetensors, load_pooler_from_safetensors,
+};
+use crate::pooler::{Pooler, PoolerConfig};
 use burn::config::Config;
 use burn::module::Module;
 use burn::nn::transformer::{
@@ -40,6 +43,8 @@ pub struct BertModelConfig {
     pub pad_token_id: usize,
     /// Maximum sequence length for the tokenizer
     pub max_seq_len: Option<usize>,
+    /// Whether to add a pooling layer to the model
+    pub with_pooling_layer: Option<bool>,
 }

 // Define the Bert model structure
@@ -47,12 +52,41 @@
 pub struct BertModel<B: Backend> {
     pub embeddings: BertEmbeddings<B>,
     pub encoder: TransformerEncoder<B>,
+    pub pooler: Option<Pooler<B>>,
+}
+
+#[derive(Debug, Clone)]
+pub struct BertModelOutput<B: Backend> {
+    pub hidden_states: Tensor<B, 3>,
+    pub pooled_output: Option<Tensor<B, 3>>,
 }

 impl BertModelConfig {
     /// Initializes a Bert model with default weights
     pub fn init<B: Backend>(&self, device: &B::Device) -> BertModel<B> {
-        let embeddings = BertEmbeddingsConfig {
+        let embeddings = self.get_embeddings_config().init(device);
+        let encoder = self.get_encoder_config().init(device);
+
+        let pooler = if self.with_pooling_layer.unwrap_or(false) {
+            Some(
+                PoolerConfig {
+                    hidden_size: self.hidden_size,
+                }
+                .init(device),
+            )
+        } else {
+            None
+        };
+
+        BertModel {
+            embeddings,
+            encoder,
+            pooler,
+        }
+    }
+
+    fn get_embeddings_config(&self) -> BertEmbeddingsConfig {
+        BertEmbeddingsConfig {
             vocab_size: self.vocab_size,
             max_position_embeddings: self.max_position_embeddings,
             type_vocab_size: self.type_vocab_size,
@@ -61,9 +95,10 @@ impl BertModelConfig {
             layer_norm_eps: self.layer_norm_eps,
             pad_token_idx: self.pad_token_id,
         }
-        .init(device);
+    }

-        let encoder = TransformerEncoderConfig {
+    fn get_encoder_config(&self) -> TransformerEncoderConfig {
+        TransformerEncoderConfig {
             n_heads: self.num_attention_heads,
             n_layers: self.num_hidden_layers,
             d_model: self.hidden_size,
@@ -76,26 +111,29 @@ impl BertModelConfig {
                 fan_out_only: false,
             },
         }
-        .init(device);
-
-        BertModel {
-            embeddings,
-            encoder,
-        }
     }
 }

 impl<B: Backend> BertModel<B> {
     /// Defines forward pass
-    pub fn forward(&self, input: BertInferenceBatch<B>) -> Tensor<B, 3> {
+    pub fn forward(&self, input: BertInferenceBatch<B>) -> BertModelOutput<B> {
         let embedding = self.embeddings.forward(input.clone());
         let device = &self.embeddings.devices()[0];

         let mask_pad = input.mask_pad.to_device(device);

         let encoder_input = TransformerEncoderInput::new(embedding).mask_pad(mask_pad);
-        let output = self.encoder.forward(encoder_input);
-        output
+        let hidden_states = self.encoder.forward(encoder_input);
+
+        let pooled_output = self
+            .pooler
+            .as_ref()
+            .map(|pooler| pooler.forward(hidden_states.clone()));
+
+        BertModelOutput {
+            hidden_states,
+            pooled_output,
+        }
     }

     pub fn from_safetensors(
@@ -117,6 +155,7 @@ impl<B: Backend> BertModel<B> {
         // We need to extract both.
         let mut encoder_layers: HashMap<String, CandleTensor> = HashMap::new();
         let mut embeddings_layers: HashMap<String, CandleTensor> = HashMap::new();
+        let mut pooler_layers: HashMap<String, CandleTensor> = HashMap::new();

         for (key, value) in weights.iter() {
             // If model name prefix present in keys, remove it to load keys consistently
@@ -129,14 +168,24 @@
                 encoder_layers.insert(key_without_prefix, value.clone());
             } else if key_without_prefix.starts_with("embeddings.") {
                 embeddings_layers.insert(key_without_prefix, value.clone());
+            } else if key_without_prefix.starts_with("pooler.") {
+                pooler_layers.insert(key_without_prefix, value.clone());
             }
         }

         let embeddings_record = load_embeddings_from_safetensors::<B>(embeddings_layers, device);
         let encoder_record = load_encoder_from_safetensors::<B>(encoder_layers, device);
+
+        let pooler_record = if config.with_pooling_layer.unwrap_or(false) {
+            Some(load_pooler_from_safetensors::<B>(pooler_layers, device))
+        } else {
+            None
+        };
+
         let model_record = BertModelRecord {
             embeddings: embeddings_record,
             encoder: encoder_record,
+            pooler: pooler_record,
         };
         model_record
     }
diff --git a/bert-burn/src/pooler.rs b/bert-burn/src/pooler.rs
new file mode 100644
index 0000000..83a64bc
--- /dev/null
+++ b/bert-burn/src/pooler.rs
@@ -0,0 +1,41 @@
+use burn::{
+    config::Config,
+    module::Module,
+    nn::{Linear, LinearConfig},
+    tensor::{backend::Backend, Tensor},
+};
+use derive_new::new;
+
+/// Pooler
+#[derive(Module, Debug, new)]
+pub struct Pooler<B: Backend> {
+    /// Linear output
+    output: Linear<B>,
+}
+
+impl<B: Backend> Pooler<B> {
+    /// Forward pass
+    pub fn forward(&self, encoder_output: Tensor<B, 3>) -> Tensor<B, 3> {
+        let [batch_size, _, _] = encoder_output.dims();
+
+        self.output
+            .forward(encoder_output.slice([0..batch_size, 0..1]))
+            .tanh()
+    }
+}
+
+/// Pooler Configuration
+#[derive(Config)]
+pub struct PoolerConfig {
+    /// Hidden size
+    pub hidden_size: usize,
+}
+
+impl PoolerConfig {
+    /// Initialize a new Pooler module.
+    pub fn init<B: Backend>(&self, device: &B::Device) -> Pooler<B> {
+        let output = LinearConfig::new(self.hidden_size, self.hidden_size).init(device);
+
+        Pooler::new(output)
+    }
+}
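Usage sketch (not part of the patch): the snippet below shows how the new `BertModelOutput` might be consumed downstream. The `embed` helper is hypothetical, the `bert_burn::` paths assume the crate name from `bert-burn/Cargo.toml`, and the model/batch are assumed to have been built as in `bert-burn/examples/infer-embedding.rs`.

```rust
use bert_burn::data::BertInferenceBatch;
use bert_burn::model::BertModel;
use burn::tensor::{backend::Backend, Tensor};

/// Hypothetical helper: produce one sentence embedding per sample from the
/// new `BertModelOutput`, using the pooler when it was configured.
fn embed<B: Backend>(model: &BertModel<B>, batch: BertInferenceBatch<B>) -> Tensor<B, 2> {
    let output = model.forward(batch);

    match output.pooled_output {
        // With `with_pooling_layer: Some(true)`, the pooler returns the
        // tanh-activated projection of the [CLS] position,
        // shape [batch_size, 1, hidden_size].
        Some(pooled) => pooled.squeeze(1),
        // Otherwise fall back to slicing the [CLS] hidden state directly,
        // as the updated infer-embedding example does.
        None => {
            let [batch_size, _seq_len, d_model] = output.hidden_states.dims();
            output
                .hidden_states
                .slice([0..batch_size, 0..1, 0..d_model])
                .squeeze(1)
        }
    }
}
```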