[Bert] Upgrade to Burn 0.13.0 (#27)
* [Bert] Upgrade to Burn 0.13.0

* Convert Tensor::arange arguments to i64 values

* Use Param::from_tensor in the loaders

* Use a consistent import style everywhere
bkonkle committed Apr 25, 2024
1 parent 0f9c6c6 commit 29fadae
Showing 6 changed files with 31 additions and 93 deletions.
8 changes: 4 additions & 4 deletions bert-burn/Cargo.toml
@@ -2,7 +2,7 @@
authors = ["Aasheesh Singh [email protected]"]
license = "MIT OR Apache-2.0"
name = "bert-burn"
version="0.1.0"
version = "0.1.0"
edition = "2021"

[features]
@@ -19,14 +19,14 @@ safetensors = ["candle-core/default"]

[dependencies]
# Burn
burn = {version = "0.12.1", default-features = false}
candle-core = { version = "0.3.2", optional = true}
burn = { version = "0.13", default-features = false }
candle-core = { version = "0.3.2", optional = true }
# Tokenizer
tokenizers = { version = "0.15.0", default-features = false, features = [
"onig",
"http",
] }
burn-import = "0.12.1"
burn-import = "0.13"
derive-new = "0.6.0"
hf-hub = { version = "0.3.2", features = ["tokio"] }

7 changes: 5 additions & 2 deletions bert-burn/examples/infer-embedding.rs
@@ -1,7 +1,8 @@
use bert_burn::data::{BertInputBatcher, BertTokenizer};
use bert_burn::loader::{download_hf_model, load_model_config};
use bert_burn::model::BertModel;
use bert_burn::model::{BertModel, BertModelRecord};
use burn::data::dataloader::batcher::Batcher;
use burn::module::Module;
use burn::tensor::backend::Backend;
use burn::tensor::Tensor;
use std::env;
@@ -46,9 +47,11 @@ pub fn launch<B: Backend>(device: B::Device) {
let (config_file, model_file) = download_hf_model(model_variant);
let model_config = load_model_config(config_file);

let model: BertModel<B> =
let model_record: BertModelRecord<B> =
BertModel::from_safetensors(model_file, &device, model_config.clone());

let model = model_config.init(&device).load_record(model_record);

let tokenizer = Arc::new(BertTokenizer::new(
model_variant.to_string(),
model_config.pad_token_id.clone(),
9 changes: 4 additions & 5 deletions bert-burn/src/data/batcher.rs
@@ -1,9 +1,8 @@
use super::tokenizer::Tokenizer;
use burn::{
data::dataloader::batcher::Batcher,
nn::attention::generate_padding_mask,
tensor::{backend::Backend, Bool, Int, Tensor},
};
use burn::data::dataloader::batcher::Batcher;
use burn::nn::attention::generate_padding_mask;
use burn::tensor::backend::Backend;
use burn::tensor::{Bool, Int, Tensor};
use std::sync::Arc;

#[derive(new)]
31 changes: 3 additions & 28 deletions bert-burn/src/embedding.rs
@@ -51,32 +51,6 @@ impl BertEmbeddingsConfig {
pad_token_idx: self.pad_token_idx,
}
}

/// Initializes BertEmbeddings with provided weights
pub fn init_with<B: Backend>(&self, record: BertEmbeddingsRecord<B>) -> BertEmbeddings<B> {
let word_embeddings = EmbeddingConfig::new(self.vocab_size, self.hidden_size)
.init_with(record.word_embeddings);
let position_embeddings =
EmbeddingConfig::new(self.max_position_embeddings, self.hidden_size)
.init_with(record.position_embeddings);
let token_type_embeddings = EmbeddingConfig::new(self.type_vocab_size, self.hidden_size)
.init_with(record.token_type_embeddings);
let layer_norm = LayerNormConfig::new(self.hidden_size)
.with_epsilon(self.layer_norm_eps)
.init_with(record.layer_norm);

let dropout = DropoutConfig::new(self.hidden_dropout_prob).init();

BertEmbeddings {
word_embeddings,
position_embeddings,
token_type_embeddings,
layer_norm,
dropout,
max_position_embeddings: self.max_position_embeddings,
pad_token_idx: self.pad_token_idx,
}
}
}

impl<B: Backend> BertEmbeddings<B> {
@@ -102,14 +76,15 @@ impl<B: Backend> BertEmbeddings<B> {

let seq_length = input_shape.dims[1];
let mut position_ids_tensor: Tensor<B, 2, Int> =
Tensor::arange(0..seq_length, device).reshape([1, seq_length]);
Tensor::arange(0..seq_length as i64, device).reshape([1, seq_length]);

if self.max_position_embeddings != 512 {
            // RoBERTa uses a different scheme than BERT to create position indexes, where padding tokens are given
// a fixed positional index. Check: create_position_ids_from_input_ids() in
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/roberta/modeling_roberta.py
let position_ids = Tensor::arange(
self.pad_token_idx + 1..seq_length + self.pad_token_idx + 1,
(self.pad_token_idx as i64) + 1
..(seq_length as i64) + (self.pad_token_idx as i64) + 1,
device,
)
.reshape([1, seq_length]);
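In Burn 0.13, `Tensor::arange` takes a `Range<i64>` (plus the target device) rather than a `Range<usize>`, which is why the hunk above casts `seq_length` and `pad_token_idx` before building the range. A minimal sketch of the RoBERTa-style position ids under the new signature, assuming `seq_length: usize` and `pad_token_idx: usize` as in this file:

use burn::tensor::{backend::Backend, Int, Tensor};

// Burn 0.13: arange takes a Range<i64> and the device, so usize values
// must be cast before building the range.
fn position_ids<B: Backend>(
    seq_length: usize,
    pad_token_idx: usize,
    device: &B::Device,
) -> Tensor<B, 2, Int> {
    // Positions start after the pad index, mirroring
    // create_position_ids_from_input_ids() in the transformers reference above.
    let start = pad_token_idx as i64 + 1;
    let end = seq_length as i64 + start;
    Tensor::arange(start..end, device).reshape([1, seq_length])
}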
20 changes: 9 additions & 11 deletions bert-burn/src/loader.rs
@@ -5,16 +5,14 @@ use crate::model::BertModelConfig;

use crate::embedding::BertEmbeddingsRecord;
use burn::config::Config;
use burn::module::{ConstantRecord, Param};
use burn::nn::attention::MultiHeadAttentionRecord;
use burn::nn::transformer::{
PositionWiseFeedForwardRecord, TransformerEncoderLayerRecord, TransformerEncoderRecord,
};
use burn::{
module::ConstantRecord,
nn::LayerNormRecord,
nn::{EmbeddingRecord, LinearRecord},
tensor::{backend::Backend, Data, Shape, Tensor},
};
use burn::nn::{EmbeddingRecord, LayerNormRecord, LinearRecord};
use burn::tensor::backend::Backend;
use burn::tensor::{Data, Shape, Tensor};
use candle_core::Tensor as CandleTensor;
use std::collections::HashMap;
use std::path::PathBuf;
@@ -57,8 +55,8 @@ fn load_layer_norm_safetensor<B: Backend>(
let gamma = load_1d_tensor_from_candle::<B>(weight, device);

let layer_norm_record = LayerNormRecord {
beta: beta.into(),
gamma: gamma.into(),
beta: Param::from_tensor(beta),
gamma: Param::from_tensor(gamma),
epsilon: ConstantRecord::new(),
};
layer_norm_record
@@ -75,8 +73,8 @@ fn load_linear_safetensor<B: Backend>(
let weight = weight.transpose();

let linear_record = LinearRecord {
weight: weight.into(),
bias: Some(bias.into()),
weight: Param::from_tensor(weight),
bias: Some(Param::from_tensor(bias)),
};
linear_record
}
@@ -237,7 +235,7 @@ fn load_embedding_safetensor<B: Backend>(
let weight = load_2d_tensor_from_candle(weight, device);

let embedding = EmbeddingRecord {
weight: weight.into(),
weight: Param::from_tensor(weight),
};

embedding
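Throughout this file the record fields switch from `Tensor::into` to `Param::from_tensor`: in Burn 0.13 a `Param` carries its own id, so the plain `From<Tensor>` conversion used under 0.12 is no longer available for record fields. A hedged sketch of the pattern (the real loaders above also pull the named tensors out of the safetensors map first):

use burn::module::Param;
use burn::nn::LinearRecord;
use burn::tensor::{backend::Backend, Tensor};

// Burn 0.13: record fields are Param<Tensor<...>> built via Param::from_tensor,
// not Tensor::into. Weight is assumed already transposed to Burn's layout.
fn linear_record<B: Backend>(weight: Tensor<B, 2>, bias: Tensor<B, 1>) -> LinearRecord<B> {
    LinearRecord {
        weight: Param::from_tensor(weight),
        bias: Some(Param::from_tensor(bias)),
    }
}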
49 changes: 6 additions & 43 deletions bert-burn/src/model.rs
@@ -1,15 +1,14 @@
use crate::data::BertInferenceBatch;
use crate::embedding::{BertEmbeddings, BertEmbeddingsConfig};
use crate::loader::{load_embeddings_from_safetensors, load_encoder_from_safetensors};
use burn::config::Config;
use burn::module::Module;
use burn::nn::transformer::{
TransformerEncoder, TransformerEncoderConfig, TransformerEncoderInput,
};
use burn::nn::Initializer::KaimingUniform;
use burn::{
config::Config,
module::Module,
tensor::{backend::Backend, Tensor},
};
use burn::tensor::backend::Backend;
use burn::tensor::Tensor;
use candle_core::{safetensors, Device, Tensor as CandleTensor};
use std::collections::HashMap;
use std::path::PathBuf;
@@ -84,40 +83,6 @@ impl BertModelConfig {
encoder,
}
}

/// Initializes a Bert model with provided weights
pub fn init_with<B: Backend>(&self, record: BertModelRecord<B>) -> BertModel<B> {
let embeddings = BertEmbeddingsConfig {
vocab_size: self.vocab_size,
max_position_embeddings: self.max_position_embeddings,
type_vocab_size: self.type_vocab_size,
hidden_size: self.hidden_size,
hidden_dropout_prob: self.hidden_dropout_prob,
layer_norm_eps: self.layer_norm_eps,
pad_token_idx: self.pad_token_id,
}
.init_with(record.embeddings);

let encoder = TransformerEncoderConfig {
n_heads: self.num_attention_heads,
n_layers: self.num_hidden_layers,
d_model: self.hidden_size,
d_ff: self.intermediate_size,
dropout: self.hidden_dropout_prob,
norm_first: true,
quiet_softmax: false,
initializer: KaimingUniform {
gain: 1.0 / libm::sqrt(3.0),
fan_out_only: false,
},
}
.init_with(record.encoder);

BertModel {
encoder,
embeddings,
}
}
}

impl<B: Backend> BertModel<B> {
Expand All @@ -137,7 +102,7 @@ impl<B: Backend> BertModel<B> {
file_path: PathBuf,
device: &B::Device,
config: BertModelConfig,
) -> BertModel<B> {
) -> BertModelRecord<B> {
let model_name = config.model_type.as_str();
let weight_result = safetensors::load::<PathBuf>(file_path, &Device::Cpu);

@@ -173,8 +138,6 @@ impl<B: Backend> BertModel<B> {
embeddings: embeddings_record,
encoder: encoder_record,
};

let model = config.init_with::<B>(model_record);
model
model_record
}
}

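Since `from_safetensors` now returns a `BertModelRecord` rather than a built model, constructing the module moves to the caller, as shown in examples/infer-embedding.rs above; the `use burn::module::Module;` added there is what brings `load_record` into scope. The resulting call site, restated as a sketch with `model_file`, `device`, and `model_config` as defined in the example:

// Burn 0.13 loading flow: checkpoint -> record, then init + load_record.
let model_record: BertModelRecord<B> =
    BertModel::from_safetensors(model_file, &device, model_config.clone());
let model = model_config.init(&device).load_record(model_record);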