
[Bert] Feature: Custom Model Outputs #31

Merged · 2 commits · May 8, 2024
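This PR changes `BertModel::forward` to return a `BertModelOutput` struct, carrying the full sequence of hidden states plus an optional pooled output, instead of a bare tensor, and adds an opt-in pooling layer controlled by a new `with_pooling_layer` config flag. Below is a minimal sketch of the new calling convention, assuming a model built with pooling enabled and a prepared inference batch; the `embed` helper and its variables are illustrative, not taken from the diff:

```rust
use bert_burn::data::BertInferenceBatch;
use bert_burn::model::BertModel;
use burn::tensor::backend::Backend;

// Sketch: consume the new BertModelOutput returned by forward().
fn embed<B: Backend>(model: &BertModel<B>, batch: BertInferenceBatch<B>) {
    let output = model.forward(batch);

    // Full sequence output: [batch_size, seq_len, hidden_size]
    println!("hidden states: {:?}", output.hidden_states.dims());

    // Pooled output (first token through Linear + tanh); present only when
    // the model was built with `with_pooling_layer: Some(true)`.
    if let Some(pooled) = output.pooled_output {
        println!("pooled: {:?}", pooled.dims()); // [batch_size, 1, hidden_size]
    }
}
```

Existing callers that only need the raw encoder output switch from using the returned tensor directly to reading `output.hidden_states`, as the updated infer-embedding example below shows.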
6 changes: 6 additions & 0 deletions .gitignore
@@ -13,3 +13,9 @@ Cargo.lock
# MSVC Windows builds of rustc generate these, which store debugging information
*.pdb
.DS_Store

# direnv files
.envrc

# Editor files
.vscode
15 changes: 8 additions & 7 deletions bert-burn/examples/infer-embedding.rs
@@ -54,14 +54,14 @@ pub fn launch<B: Backend>(device: B::Device) {

let tokenizer = Arc::new(BertTokenizer::new(
model_variant.to_string(),
model_config.pad_token_id.clone(),
model_config.pad_token_id,
));

// Batch the input samples to max sequence length with padding
let batcher = Arc::new(BertInputBatcher::<B>::new(
tokenizer.clone(),
device.clone(),
model_config.max_seq_len.unwrap().clone(),
model_config.max_seq_len.unwrap(),
));

// Batch input samples using the batcher Shape: [Batch size, Seq_len]
@@ -75,11 +75,12 @@ pub fn launch<B: Backend>(device: B::Device) {
let cls_token_idx = 0;

// Embedding size
let d_model = model_config.hidden_size.clone();
let sentence_embedding =
output
.clone()
.slice([0..batch_size, cls_token_idx..cls_token_idx + 1, 0..d_model]);
let d_model = model_config.hidden_size;
let sentence_embedding = output.hidden_states.clone().slice([
0..batch_size,
cls_token_idx..cls_token_idx + 1,
0..d_model,
]);

let sentence_embedding: Tensor<B, 2> = sentence_embedding.squeeze(1);
println!(
4 changes: 2 additions & 2 deletions bert-burn/src/embedding.rs
@@ -18,13 +18,13 @@ pub struct BertEmbeddingsConfig {

#[derive(Module, Debug)]
pub struct BertEmbeddings<B: Backend> {
pub pad_token_idx: usize,
Member: Any particular reason why this is now public?

Contributor (Author): This is specifically so that I can use it here, in the batcher for my text classification pipeline (and later for the token classification pipeline): https://github.com/bkonkle/burn-transformers/blob/0.1.0/src/pipelines/text_classification/batcher.rs#L72

word_embeddings: Embedding<B>,
position_embeddings: Embedding<B>,
token_type_embeddings: Embedding<B>,
layer_norm: LayerNorm<B>,
dropout: Dropout,
max_position_embeddings: usize,
pad_token_idx: usize,
}

impl BertEmbeddingsConfig {
@@ -89,7 +89,7 @@ impl<B: Backend> BertEmbeddings<B> {
)
.reshape([1, seq_length]);
position_ids_tensor =
position_ids.mask_fill(item.mask_pad.clone(), self.pad_token_idx.clone() as i32);
position_ids.mask_fill(item.mask_pad.clone(), self.pad_token_idx as i32);
}

let position_embeddings = self.position_embeddings.forward(position_ids_tensor);
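As discussed in the review thread above, `pad_token_idx` was made public so downstream crates (such as the linked burn-transformers batcher) can read the padding index straight from the embeddings module. A hypothetical sketch of that kind of downstream usage; `pad_sample` is illustrative and not code from either repository:

```rust
use bert_burn::model::BertModel;
use burn::tensor::backend::Backend;

// Hypothetical downstream helper: pad a tokenized sample to a fixed length
// using the model's own padding index, read via the now-public field.
fn pad_sample<B: Backend>(model: &BertModel<B>, mut tokens: Vec<usize>, max_len: usize) -> Vec<usize> {
    let pad = model.embeddings.pad_token_idx;
    tokens.resize(max_len, pad); // grows with the pad id, or truncates if already longer
    tokens
}
```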
1 change: 1 addition & 0 deletions bert-burn/src/lib.rs
@@ -5,3 +5,4 @@ pub mod data;
mod embedding;
pub mod loader;
pub mod model;
pub mod pooler;
54 changes: 36 additions & 18 deletions bert-burn/src/loader.rs
@@ -1,9 +1,9 @@
// This file contains logic to load the BERT Model from the safetensor format available on Hugging Face Hub.
// Some utility functions are referenced from: https://github.com/tvergho/sentence-transformers-burn/tree/main

use crate::model::BertModelConfig;

use crate::embedding::BertEmbeddingsRecord;
use crate::model::BertModelConfig;
use crate::pooler::PoolerRecord;
use burn::config::Config;
use burn::module::{ConstantRecord, Param};
use burn::nn::attention::MultiHeadAttentionRecord;
@@ -142,6 +142,7 @@ fn load_attention_layer_safetensor<B: Backend>(
attention_record
}

/// Load the BERT encoder from the safetensor format available on Hugging Face Hub
pub fn load_encoder_from_safetensors<B: Backend>(
encoder_tensors: HashMap<String, CandleTensor>,
device: &B::Device,
@@ -151,15 +152,13 @@ pub fn load_encoder_from_safetensors<B: Backend>(
let mut layers: HashMap<usize, HashMap<String, CandleTensor>> = HashMap::new();

for (key, value) in encoder_tensors.iter() {
let layer_number = key.split(".").collect::<Vec<&str>>()[2]
let layer_number = key.split('.').collect::<Vec<&str>>()[2]
.parse::<usize>()
.unwrap();
if !layers.contains_key(&layer_number) {
layers.insert(layer_number, HashMap::new());
}

layers
.get_mut(&layer_number)
.unwrap()
.entry(layer_number)
.or_default()
.insert(key.to_string(), value.clone());
}

@@ -241,6 +240,7 @@ fn load_embedding_safetensor<B: Backend>(
embedding
}

/// Load the BERT embeddings from the safetensor format available on Hugging Face Hub
pub fn load_embeddings_from_safetensors<B: Backend>(
embedding_tensors: HashMap<String, CandleTensor>,
device: &B::Device,
@@ -282,10 +282,24 @@ pub fn load_embeddings_from_safetensors<B: Backend>(
max_position_embeddings: ConstantRecord::new(),
pad_token_idx: ConstantRecord::new(),
};

embeddings_record
}

/// Load the BERT pooler from the safetensor format available on Hugging Face Hub
pub fn load_pooler_from_safetensors<B: Backend>(
pooler_tensors: HashMap<String, CandleTensor>,
device: &B::Device,
) -> PoolerRecord<B> {
let output = load_linear_safetensor(
&pooler_tensors["pooler.dense.bias"],
&pooler_tensors["pooler.dense.weight"],
device,
);

PoolerRecord { output }
}

/// Load the BERT model config from the JSON format available on Hugging Face Hub
pub fn load_model_config(path: PathBuf) -> BertModelConfig {
let mut model_config = BertModelConfig::load(path).expect("Config file present");
model_config.max_seq_len = Some(512);
@@ -299,15 +313,19 @@ pub async fn download_hf_model(model_name: &str) -> (PathBuf, PathBuf) {
let api = hf_hub::api::tokio::Api::new().unwrap();
let repo = api.model(model_name.to_string());

let model_filepath = repo.get("model.safetensors").await.expect(&format!(
"Failed to download: {} weights with name: model.safetensors from HuggingFace Hub",
model_name
));

let config_filepath = repo.get("config.json").await.expect(&format!(
"Failed to download: {} config with name: config.json from HuggingFace Hub",
model_name
));
let model_filepath = repo.get("model.safetensors").await.unwrap_or_else(|_| {
panic!(
"Failed to download: {} weights with name: model.safetensors from HuggingFace Hub",
model_name
)
});

let config_filepath = repo.get("config.json").await.unwrap_or_else(|_| {
panic!(
"Failed to download: {} config with name: config.json from HuggingFace Hub",
model_name
)
});

(config_filepath, model_filepath)
}
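For reference, the two public helpers above compose roughly as follows; a minimal sketch assuming a tokio runtime and network access, with the model name chosen only as an example:

```rust
use bert_burn::loader::{download_hf_model, load_model_config};

// Sketch: fetch weights and config from the Hugging Face Hub, then read the config.
async fn fetch_model_config() {
    let (config_path, model_path) = download_hf_model("bert-base-uncased").await;
    let config = load_model_config(config_path);
    println!("hidden size: {}", config.hidden_size);
    // `model_path` would then feed the safetensors loading path in model.rs below.
    let _ = model_path;
}
```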
75 changes: 62 additions & 13 deletions bert-burn/src/model.rs
@@ -1,6 +1,9 @@
use crate::data::BertInferenceBatch;
use crate::embedding::{BertEmbeddings, BertEmbeddingsConfig};
use crate::loader::{load_embeddings_from_safetensors, load_encoder_from_safetensors};
use crate::loader::{
load_embeddings_from_safetensors, load_encoder_from_safetensors, load_pooler_from_safetensors,
};
use crate::pooler::{Pooler, PoolerConfig};
use burn::config::Config;
use burn::module::Module;
use burn::nn::transformer::{
@@ -40,19 +43,50 @@ pub struct BertModelConfig {
pub pad_token_id: usize,
/// Maximum sequence length for the tokenizer
pub max_seq_len: Option<usize>,
/// Whether to add a pooling layer to the model
pub with_pooling_layer: Option<bool>,
}

// Define the Bert model structure
#[derive(Module, Debug)]
pub struct BertModel<B: Backend> {
pub embeddings: BertEmbeddings<B>,
pub encoder: TransformerEncoder<B>,
pub pooler: Option<Pooler<B>>,
}

#[derive(Debug, Clone)]
pub struct BertModelOutput<B: Backend> {
pub hidden_states: Tensor<B, 3>,
pub pooled_output: Option<Tensor<B, 3>>,
}

impl BertModelConfig {
/// Initializes a Bert model with default weights
pub fn init<B: Backend>(&self, device: &B::Device) -> BertModel<B> {
let embeddings = BertEmbeddingsConfig {
let embeddings = self.get_embeddings_config().init(device);
let encoder = self.get_encoder_config().init(device);

let pooler = if self.with_pooling_layer.unwrap_or(false) {
Some(
PoolerConfig {
hidden_size: self.hidden_size,
}
.init(device),
)
} else {
None
};

BertModel {
embeddings,
encoder,
pooler,
}
}

fn get_embeddings_config(&self) -> BertEmbeddingsConfig {
BertEmbeddingsConfig {
vocab_size: self.vocab_size,
max_position_embeddings: self.max_position_embeddings,
type_vocab_size: self.type_vocab_size,
@@ -61,9 +95,10 @@ impl BertModelConfig {
layer_norm_eps: self.layer_norm_eps,
pad_token_idx: self.pad_token_id,
}
.init(device);
}

let encoder = TransformerEncoderConfig {
fn get_encoder_config(&self) -> TransformerEncoderConfig {
TransformerEncoderConfig {
n_heads: self.num_attention_heads,
n_layers: self.num_hidden_layers,
d_model: self.hidden_size,
@@ -76,26 +111,29 @@
fan_out_only: false,
},
}
.init(device);

BertModel {
embeddings,
encoder,
}
}
}

impl<B: Backend> BertModel<B> {
/// Defines forward pass
pub fn forward(&self, input: BertInferenceBatch<B>) -> Tensor<B, 3> {
pub fn forward(&self, input: BertInferenceBatch<B>) -> BertModelOutput<B> {
let embedding = self.embeddings.forward(input.clone());
let device = &self.embeddings.devices()[0];

let mask_pad = input.mask_pad.to_device(device);

let encoder_input = TransformerEncoderInput::new(embedding).mask_pad(mask_pad);
let output = self.encoder.forward(encoder_input);
output
let hidden_states = self.encoder.forward(encoder_input);

let pooled_output = self
.pooler
.as_ref()
.map(|pooler| pooler.forward(hidden_states.clone()));

BertModelOutput {
hidden_states,
pooled_output,
}
}

pub fn from_safetensors(
@@ -117,6 +155,7 @@ impl<B: Backend> BertModel<B> {
// We need to extract both.
let mut encoder_layers: HashMap<String, CandleTensor> = HashMap::new();
let mut embeddings_layers: HashMap<String, CandleTensor> = HashMap::new();
let mut pooler_layers: HashMap<String, CandleTensor> = HashMap::new();

for (key, value) in weights.iter() {
// If model name prefix present in keys, remove it to load keys consistently
@@ -129,14 +168,24 @@ impl<B: Backend> BertModel<B> {
encoder_layers.insert(key_without_prefix, value.clone());
} else if key_without_prefix.starts_with("embeddings.") {
embeddings_layers.insert(key_without_prefix, value.clone());
} else if key_without_prefix.starts_with("pooler.") {
pooler_layers.insert(key_without_prefix, value.clone());
}
}

let embeddings_record = load_embeddings_from_safetensors::<B>(embeddings_layers, device);
let encoder_record = load_encoder_from_safetensors::<B>(encoder_layers, device);

let pooler_record = if config.with_pooling_layer.unwrap_or(false) {
Some(load_pooler_from_safetensors::<B>(pooler_layers, device))
} else {
None
};

let model_record = BertModelRecord {
embeddings: embeddings_record,
encoder: encoder_record,
pooler: pooler_record,
};
model_record
}
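The pooler stays opt-in: a `with_pooling_layer` of `None` is treated as `false` both in `init` and when loading safetensors, so existing configs keep their old behaviour. A minimal sketch of enabling it before initializing a model with default weights; the ndarray backend and the `build_model` helper are assumptions for illustration, not part of this PR:

```rust
use bert_burn::loader::load_model_config;
use bert_burn::model::BertModel;
use burn::backend::NdArray;
use std::path::PathBuf;

type MyBackend = NdArray<f32>;

// Sketch: opt in to the new pooling layer, then initialize with default weights.
fn build_model(config_path: PathBuf) -> BertModel<MyBackend> {
    let device = Default::default();
    let mut config = load_model_config(config_path);
    config.with_pooling_layer = Some(true); // None is treated as false
    config.init(&device)
}
```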
41 changes: 41 additions & 0 deletions bert-burn/src/pooler.rs
@@ -0,0 +1,41 @@
use burn::{
config::Config,
module::Module,
nn::{Linear, LinearConfig},
tensor::{backend::Backend, Tensor},
};
use derive_new::new;

/// Pooler
#[derive(Module, Debug, new)]
pub struct Pooler<B: Backend> {
/// Linear output
output: Linear<B>,
}

impl<B: Backend> Pooler<B> {
/// Forward pass
pub fn forward(&self, encoder_output: Tensor<B, 3>) -> Tensor<B, 3> {
let [batch_size, _, _] = encoder_output.dims();

self.output
.forward(encoder_output.slice([0..batch_size, 0..1]))
.tanh()
}
}

/// Pooler Configuration
#[derive(Config)]
pub struct PoolerConfig {
/// Hidden size
pub hidden_size: usize,
}

impl PoolerConfig {
/// Initialize a new Pooler module.
pub fn init<B: Backend>(&self, device: &B::Device) -> Pooler<B> {
let output = LinearConfig::new(self.hidden_size, self.hidden_size).init(device);

Pooler::new(output)
}
}
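This mirrors the standard BERT pooling head: only the first ([CLS]) position of the encoder output is passed through the dense layer and a tanh. A small shape sketch, assuming a freshly initialized (untrained) pooler on burn's ndarray backend; the names and sizes are illustrative:

```rust
use bert_burn::pooler::PoolerConfig;
use burn::backend::NdArray;
use burn::tensor::{Distribution, Tensor};

type B = NdArray<f32>;

fn pooler_shapes() {
    let device = Default::default();
    let hidden_size = 768; // illustrative BERT-base width
    let pooler = PoolerConfig { hidden_size }.init::<B>(&device);

    // Fake encoder output: [batch_size, seq_len, hidden_size]
    let encoder_output: Tensor<B, 3> =
        Tensor::random([2, 16, hidden_size], Distribution::Default, &device);

    // Only the first-token slice is pooled: [2, 1, 768]
    let pooled = pooler.forward(encoder_output);
    assert_eq!(pooled.dims(), [2, 1, hidden_size]);
}
```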