[Bert] Feature: Custom Model Outputs (#31)
* [Bert] Feature: Custom Model Outputs

* Make the with_pooling_layer flag optional, since it isn't present in the original model
bkonkle committed May 8, 2024
1 parent 00d5dfa commit 14ae737
Showing 7 changed files with 156 additions and 40 deletions.
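This commit replaces the bare `Tensor<B, 3>` returned by `BertModel::forward` with a `BertModelOutput` struct carrying the encoder `hidden_states` plus an optional `pooled_output`. A minimal sketch of how calling code might consume the new output, assuming the crate is imported as `bert_burn` and a batched `BertInferenceBatch` has already been prepared:

use bert_burn::data::BertInferenceBatch;
use bert_burn::model::{BertModel, BertModelOutput};
use burn::tensor::backend::Backend;

/// Hypothetical helper: run the model and report the shapes of both outputs.
fn embed<B: Backend>(model: &BertModel<B>, batch: BertInferenceBatch<B>) -> BertModelOutput<B> {
    let output = model.forward(batch);

    // Always present: full sequence embeddings, [batch_size, seq_len, hidden_size]
    println!("hidden_states dims: {:?}", output.hidden_states.dims());

    // Only present when the config sets `with_pooling_layer: Some(true)`
    if let Some(pooled) = &output.pooled_output {
        // Pooled CLS representation, [batch_size, 1, hidden_size]
        println!("pooled_output dims: {:?}", pooled.dims());
    }

    output
}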
6 changes: 6 additions & 0 deletions .gitignore
@@ -13,3 +13,9 @@ Cargo.lock
# MSVC Windows builds of rustc generate these, which store debugging information
*.pdb
.DS_Store

# direnv files
.envrc

# Editor files
.vscode
15 changes: 8 additions & 7 deletions bert-burn/examples/infer-embedding.rs
@@ -54,14 +54,14 @@ pub fn launch<B: Backend>(device: B::Device) {

let tokenizer = Arc::new(BertTokenizer::new(
model_variant.to_string(),
model_config.pad_token_id.clone(),
model_config.pad_token_id,
));

// Batch the input samples to max sequence length with padding
let batcher = Arc::new(BertInputBatcher::<B>::new(
tokenizer.clone(),
device.clone(),
model_config.max_seq_len.unwrap().clone(),
model_config.max_seq_len.unwrap(),
));

// Batch input samples using the batcher. Shape: [Batch size, Seq_len]
@@ -75,11 +75,12 @@ pub fn launch<B: Backend>(device: B::Device) {
let cls_token_idx = 0;

// Embedding size
let d_model = model_config.hidden_size.clone();
let sentence_embedding =
output
.clone()
.slice([0..batch_size, cls_token_idx..cls_token_idx + 1, 0..d_model]);
let d_model = model_config.hidden_size;
let sentence_embedding = output.hidden_states.clone().slice([
0..batch_size,
cls_token_idx..cls_token_idx + 1,
0..d_model,
]);

let sentence_embedding: Tensor<B, 2> = sentence_embedding.squeeze(1);
println!(
4 changes: 2 additions & 2 deletions bert-burn/src/embedding.rs
@@ -18,13 +18,13 @@ pub struct BertEmbeddingsConfig {

#[derive(Module, Debug)]
pub struct BertEmbeddings<B: Backend> {
pub pad_token_idx: usize,
word_embeddings: Embedding<B>,
position_embeddings: Embedding<B>,
token_type_embeddings: Embedding<B>,
layer_norm: LayerNorm<B>,
dropout: Dropout,
max_position_embeddings: usize,
pad_token_idx: usize,
}

impl BertEmbeddingsConfig {
@@ -89,7 +89,7 @@ impl<B: Backend> BertEmbeddings<B> {
)
.reshape([1, seq_length]);
position_ids_tensor =
position_ids.mask_fill(item.mask_pad.clone(), self.pad_token_idx.clone() as i32);
position_ids.mask_fill(item.mask_pad.clone(), self.pad_token_idx as i32);
}

let position_embeddings = self.position_embeddings.forward(position_ids_tensor);
1 change: 1 addition & 0 deletions bert-burn/src/lib.rs
@@ -5,3 +5,4 @@ pub mod data;
mod embedding;
pub mod loader;
pub mod model;
pub mod pooler;
54 changes: 36 additions & 18 deletions bert-burn/src/loader.rs
@@ -1,9 +1,9 @@
// This file contains logic to load the BERT Model from the safetensor format available on Hugging Face Hub.
// Some utility functions are referenced from: https://github.com/tvergho/sentence-transformers-burn/tree/main

use crate::model::BertModelConfig;

use crate::embedding::BertEmbeddingsRecord;
use crate::model::BertModelConfig;
use crate::pooler::PoolerRecord;
use burn::config::Config;
use burn::module::{ConstantRecord, Param};
use burn::nn::attention::MultiHeadAttentionRecord;
@@ -142,6 +142,7 @@ fn load_attention_layer_safetensor<B: Backend>(
attention_record
}

/// Load the BERT encoder from the safetensor format available on Hugging Face Hub
pub fn load_encoder_from_safetensors<B: Backend>(
encoder_tensors: HashMap<String, CandleTensor>,
device: &B::Device,
@@ -151,15 +151,13 @@ pub fn load_encoder_from_safetensors<B: Backend>(
let mut layers: HashMap<usize, HashMap<String, CandleTensor>> = HashMap::new();

for (key, value) in encoder_tensors.iter() {
let layer_number = key.split(".").collect::<Vec<&str>>()[2]
let layer_number = key.split('.').collect::<Vec<&str>>()[2]
.parse::<usize>()
.unwrap();
if !layers.contains_key(&layer_number) {
layers.insert(layer_number, HashMap::new());
}

layers
.get_mut(&layer_number)
.unwrap()
.entry(layer_number)
.or_default()
.insert(key.to_string(), value.clone());
}

@@ -241,6 +240,7 @@ fn load_embedding_safetensor<B: Backend>(
embedding
}

/// Load the BERT embeddings from the safetensor format available on Hugging Face Hub
pub fn load_embeddings_from_safetensors<B: Backend>(
embedding_tensors: HashMap<String, CandleTensor>,
device: &B::Device,
@@ -282,10 +282,24 @@ pub fn load_embeddings_from_safetensors<B: Backend>(
max_position_embeddings: ConstantRecord::new(),
pad_token_idx: ConstantRecord::new(),
};

embeddings_record
}

/// Load the BERT pooler from the safetensor format available on Hugging Face Hub
pub fn load_pooler_from_safetensors<B: Backend>(
pooler_tensors: HashMap<String, CandleTensor>,
device: &B::Device,
) -> PoolerRecord<B> {
let output = load_linear_safetensor(
&pooler_tensors["pooler.dense.bias"],
&pooler_tensors["pooler.dense.weight"],
device,
);

PoolerRecord { output }
}

/// Load the BERT model config from the JSON format available on Hugging Face Hub
pub fn load_model_config(path: PathBuf) -> BertModelConfig {
let mut model_config = BertModelConfig::load(path).expect("Config file present");
model_config.max_seq_len = Some(512);
@@ -299,15 +313,19 @@ pub async fn download_hf_model(model_name: &str) -> (PathBuf, PathBuf) {
let api = hf_hub::api::tokio::Api::new().unwrap();
let repo = api.model(model_name.to_string());

let model_filepath = repo.get("model.safetensors").await.expect(&format!(
"Failed to download: {} weights with name: model.safetensors from HuggingFace Hub",
model_name
));

let config_filepath = repo.get("config.json").await.expect(&format!(
"Failed to download: {} config with name: config.json from HuggingFace Hub",
model_name
));
let model_filepath = repo.get("model.safetensors").await.unwrap_or_else(|_| {
panic!(
"Failed to download: {} weights with name: model.safetensors from HuggingFace Hub",
model_name
)
});

let config_filepath = repo.get("config.json").await.unwrap_or_else(|_| {
panic!(
"Failed to download: {} config with name: config.json from HuggingFace Hub",
model_name
)
});

(config_filepath, model_filepath)
}
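The `with_pooling_layer` flag is optional because stock BERT `config.json` files do not contain it; a missing field deserializes to `None` and is treated as "no pooler". A sketch of fetching a Hub config with the loader functions above and opting in to the pooler, assuming the crate is imported as `bert_burn`:

use bert_burn::loader::{download_hf_model, load_model_config};
use bert_burn::model::BertModelConfig;

/// Hypothetical helper: fetch a Hub config and enable the new pooling layer.
async fn config_with_pooler(model_name: &str) -> BertModelConfig {
    // `download_hf_model` returns (config_path, weights_path)
    let (config_path, _weights_path) = download_hf_model(model_name).await;

    let mut config = load_model_config(config_path);
    // The flag is absent in the original BERT config.json, so it parses as None
    // and the model defaults to building no pooler (`unwrap_or(false)`).
    config.with_pooling_layer = Some(true);
    config
}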
75 changes: 62 additions & 13 deletions bert-burn/src/model.rs
@@ -1,6 +1,9 @@
use crate::data::BertInferenceBatch;
use crate::embedding::{BertEmbeddings, BertEmbeddingsConfig};
use crate::loader::{load_embeddings_from_safetensors, load_encoder_from_safetensors};
use crate::loader::{
load_embeddings_from_safetensors, load_encoder_from_safetensors, load_pooler_from_safetensors,
};
use crate::pooler::{Pooler, PoolerConfig};
use burn::config::Config;
use burn::module::Module;
use burn::nn::transformer::{
@@ -40,19 +43,50 @@ pub struct BertModelConfig {
pub pad_token_id: usize,
/// Maximum sequence length for the tokenizer
pub max_seq_len: Option<usize>,
/// Whether to add a pooling layer to the model
pub with_pooling_layer: Option<bool>,
}

// Define the Bert model structure
#[derive(Module, Debug)]
pub struct BertModel<B: Backend> {
pub embeddings: BertEmbeddings<B>,
pub encoder: TransformerEncoder<B>,
pub pooler: Option<Pooler<B>>,
}

#[derive(Debug, Clone)]
pub struct BertModelOutput<B: Backend> {
pub hidden_states: Tensor<B, 3>,
pub pooled_output: Option<Tensor<B, 3>>,
}

impl BertModelConfig {
/// Initializes a Bert model with default weights
pub fn init<B: Backend>(&self, device: &B::Device) -> BertModel<B> {
let embeddings = BertEmbeddingsConfig {
let embeddings = self.get_embeddings_config().init(device);
let encoder = self.get_encoder_config().init(device);

let pooler = if self.with_pooling_layer.unwrap_or(false) {
Some(
PoolerConfig {
hidden_size: self.hidden_size,
}
.init(device),
)
} else {
None
};

BertModel {
embeddings,
encoder,
pooler,
}
}

fn get_embeddings_config(&self) -> BertEmbeddingsConfig {
BertEmbeddingsConfig {
vocab_size: self.vocab_size,
max_position_embeddings: self.max_position_embeddings,
type_vocab_size: self.type_vocab_size,
@@ -61,9 +95,10 @@ impl BertModelConfig {
layer_norm_eps: self.layer_norm_eps,
pad_token_idx: self.pad_token_id,
}
.init(device);
}

let encoder = TransformerEncoderConfig {
fn get_encoder_config(&self) -> TransformerEncoderConfig {
TransformerEncoderConfig {
n_heads: self.num_attention_heads,
n_layers: self.num_hidden_layers,
d_model: self.hidden_size,
@@ -76,26 +111,29 @@
fan_out_only: false,
},
}
.init(device);

BertModel {
embeddings,
encoder,
}
}
}

impl<B: Backend> BertModel<B> {
/// Defines forward pass
pub fn forward(&self, input: BertInferenceBatch<B>) -> Tensor<B, 3> {
pub fn forward(&self, input: BertInferenceBatch<B>) -> BertModelOutput<B> {
let embedding = self.embeddings.forward(input.clone());
let device = &self.embeddings.devices()[0];

let mask_pad = input.mask_pad.to_device(device);

let encoder_input = TransformerEncoderInput::new(embedding).mask_pad(mask_pad);
let output = self.encoder.forward(encoder_input);
output
let hidden_states = self.encoder.forward(encoder_input);

let pooled_output = self
.pooler
.as_ref()
.map(|pooler| pooler.forward(hidden_states.clone()));

BertModelOutput {
hidden_states,
pooled_output,
}
}

pub fn from_safetensors(
@@ -117,6 +155,7 @@ impl<B: Backend> BertModel<B> {
// We need to extract both.
let mut encoder_layers: HashMap<String, CandleTensor> = HashMap::new();
let mut embeddings_layers: HashMap<String, CandleTensor> = HashMap::new();
let mut pooler_layers: HashMap<String, CandleTensor> = HashMap::new();

for (key, value) in weights.iter() {
// If model name prefix present in keys, remove it to load keys consistently
@@ -129,14 +168,24 @@
encoder_layers.insert(key_without_prefix, value.clone());
} else if key_without_prefix.starts_with("embeddings.") {
embeddings_layers.insert(key_without_prefix, value.clone());
} else if key_without_prefix.starts_with("pooler.") {
pooler_layers.insert(key_without_prefix, value.clone());
}
}

let embeddings_record = load_embeddings_from_safetensors::<B>(embeddings_layers, device);
let encoder_record = load_encoder_from_safetensors::<B>(encoder_layers, device);

let pooler_record = if config.with_pooling_layer.unwrap_or(false) {
Some(load_pooler_from_safetensors::<B>(pooler_layers, device))
} else {
None
};

let model_record = BertModelRecord {
embeddings: embeddings_record,
encoder: encoder_record,
pooler: pooler_record,
};
model_record
}
41 changes: 41 additions & 0 deletions bert-burn/src/pooler.rs
@@ -0,0 +1,41 @@
use burn::{
config::Config,
module::Module,
nn::{Linear, LinearConfig},
tensor::{backend::Backend, Tensor},
};
use derive_new::new;

/// Pooler
#[derive(Module, Debug, new)]
pub struct Pooler<B: Backend> {
/// Linear output
output: Linear<B>,
}

impl<B: Backend> Pooler<B> {
/// Forward pass
pub fn forward(&self, encoder_output: Tensor<B, 3>) -> Tensor<B, 3> {
let [batch_size, _, _] = encoder_output.dims();

self.output
.forward(encoder_output.slice([0..batch_size, 0..1]))
.tanh()
}
}

/// Pooler Configuration
#[derive(Config)]
pub struct PoolerConfig {
/// Hidden size
pub hidden_size: usize,
}

impl PoolerConfig {
/// Initialize a new Pooler module.
pub fn init<B: Backend>(&self, device: &B::Device) -> Pooler<B> {
let output = LinearConfig::new(self.hidden_size, self.hidden_size).init(device);

Pooler::new(output)
}
}
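The new `Pooler` keeps only the first ([CLS]) position of the encoder output, passes it through a dense layer of size `hidden_size`, and applies tanh. A small sketch of initializing and applying it directly, assuming the crate is imported as `bert_burn` and the encoder output's last dimension matches the configured `hidden_size` (in practice the pooler weights come from `load_pooler_from_safetensors` rather than a fresh init):

use bert_burn::pooler::PoolerConfig;
use burn::tensor::{backend::Backend, Tensor};

/// Hypothetical sketch: pool the CLS position of an encoder output.
fn pool_cls<B: Backend>(
    encoder_output: Tensor<B, 3>, // [batch_size, seq_len, hidden_size]
    hidden_size: usize,
    device: &B::Device,
) -> Tensor<B, 3> {
    // Freshly initialized (random) weights, for illustration only.
    let pooler = PoolerConfig { hidden_size }.init(device);

    // Slices out position 0, runs the dense layer, applies tanh;
    // the result has shape [batch_size, 1, hidden_size].
    pooler.forward(encoder_output)
}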
