Skip to content

Commit

Permalink
Merge pull request #34 from seurimas/main
Browse files Browse the repository at this point in the history
Add LM head and masked fill_mask for bert-burn.
  • Loading branch information
nathanielsimard committed May 27, 2024
2 parents e2f060f + 8880259 commit be48556
Show file tree
Hide file tree
Showing 5 changed files with 370 additions and 7 deletions.
150 changes: 150 additions & 0 deletions bert-burn/examples/masked.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
use bert_burn::data::{BertInputBatcher, BertTokenizer};
use bert_burn::fill_mask::fill_mask;
use bert_burn::loader::{download_hf_model, load_model_config};
use bert_burn::model::{BertMaskedLM, BertMaskedLMRecord};
use burn::data::dataloader::batcher::Batcher;
use burn::module::Module;
use burn::tensor::backend::Backend;
use std::env;
use std::sync::Arc;

#[cfg(not(feature = "f16"))]
#[allow(dead_code)]
type ElemType = f32;
#[cfg(feature = "f16")]
type ElemType = burn::tensor::f16;

pub fn launch<B: Backend>(device: B::Device) {
let args: Vec<String> = env::args().collect();
let default_model = "roberta-base".to_string();
let model_variant = if args.len() > 1 {
// Use the argument provided by the user
// Possible values: "bert-base-uncased", "roberta-large" etc.
&args[1]
} else {
// Use the default value if no argument is provided
&default_model
};

println!("Model variant: {}", model_variant);

let text_samples = vec![
"Paris is the <mask> of France.".to_string(),
"The goal of life is <mask>.".to_string(),
];

let (config_file, model_file) = download_hf_model(model_variant);
let model_config = load_model_config(config_file);

let model_record: BertMaskedLMRecord<B> =
BertMaskedLM::from_safetensors(model_file, &device, model_config.clone());

let model = model_config
.init_with_lm_head(&device)
.load_record(model_record);

let tokenizer = Arc::new(BertTokenizer::new(
model_variant.to_string(),
model_config.pad_token_id,
));

// Batch the input samples to max sequence length with padding
let batcher = Arc::new(BertInputBatcher::<B>::new(
tokenizer.clone(),
device.clone(),
model_config.max_seq_len.unwrap(),
));

// Batch input samples using the batcher Shape: [Batch size, Seq_len]
let input = batcher.batch(text_samples.clone());
let [batch_size, _seq_len] = input.tokens.dims();
println!("Input: {:?} // (Batch Size, Seq_len)", input.tokens.shape());

let output = fill_mask(&model, &model_config, tokenizer.as_ref(), input);

for i in 0..batch_size {
let input = &text_samples[i];
let result = &output[i];
println!("Input: {}", input);
for fill_mask_result in result.iter() {
let mask_idx = fill_mask_result.mask_idx;
let top_k = &fill_mask_result.top_k;
for (k, (score, token)) in top_k.iter().enumerate() {
println!(
"Top {} Prediction for {}: {} (Score: {:.4})",
k + 1,
mask_idx,
token,
score
);
}
}
}
}

#[cfg(any(
feature = "ndarray",
feature = "ndarray-blas-netlib",
feature = "ndarray-blas-openblas",
feature = "ndarray-blas-accelerate",
))]
mod ndarray {
use burn::backend::ndarray::{NdArray, NdArrayDevice};

use crate::{launch, ElemType};

pub fn run() {
launch::<NdArray<ElemType>>(NdArrayDevice::Cpu);
}
}

#[cfg(feature = "tch-gpu")]
mod tch_gpu {
use crate::{launch, ElemType};
use burn::backend::libtorch::{LibTorch, LibTorchDevice};

pub fn run() {
#[cfg(not(target_os = "macos"))]
let device = LibTorchDevice::Cuda(0);
#[cfg(target_os = "macos")]
let device = LibTorchDevice::Mps;

launch::<LibTorch<ElemType>>(device);
}
}

#[cfg(feature = "tch-cpu")]
mod tch_cpu {
use crate::{launch, ElemType};
use burn::backend::libtorch::{LibTorch, LibTorchDevice};

pub fn run() {
launch::<LibTorch<ElemType>>(LibTorchDevice::Cpu);
}
}

#[cfg(feature = "wgpu")]
mod wgpu {
use crate::{launch, ElemType};
use burn::backend::wgpu::{AutoGraphicsApi, Wgpu, WgpuDevice};

pub fn run() {
launch::<Wgpu<AutoGraphicsApi, ElemType, i32>>(WgpuDevice::default());
}
}

fn main() {
#[cfg(any(
feature = "ndarray",
feature = "ndarray-blas-netlib",
feature = "ndarray-blas-openblas",
feature = "ndarray-blas-accelerate",
))]
ndarray::run();
#[cfg(feature = "tch-gpu")]
tch_gpu::run();
#[cfg(feature = "tch-cpu")]
tch_cpu::run();
#[cfg(feature = "wgpu")]
wgpu::run();
}
94 changes: 94 additions & 0 deletions bert-burn/src/fill_mask.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
use crate::{
data::Tokenizer,
data::{BertInferenceBatch, BertTokenizer},
model::BertMaskedLM,
model::BertModelConfig,
};
use burn::tensor::{activation::softmax, backend::Backend, Data, Element, Tensor};

type TokenType = usize;
const MASK_TOKEN_ID: TokenType = 50264;

#[derive(Debug, Clone)]
pub struct FillMaskResult {
pub mask_idx: usize,
pub top_k: Vec<(f32, String)>,
}

pub fn fill_mask<B: Backend>(
model: &BertMaskedLM<B>,
model_config: &BertModelConfig,
tokenizer: &BertTokenizer,
input: BertInferenceBatch<B>,
) -> Vec<Vec<FillMaskResult>> {
let [batch_size, seq_len] = input.tokens.dims();
let output = model.forward(input.clone());

let mut results = vec![];

// Embedding size
let d_model = model_config.vocab_size.clone();
for i in 0..batch_size {
let mut batch_results = vec![];
let input_tokens = input
.tokens
.clone()
.slice([i..i + 1, 0..seq_len])
.squeeze(0)
.to_data();
// Find the mask tokens in the input, as a list of indices
let masks = find_masks(&input_tokens, MASK_TOKEN_ID);
for mask in masks {
let logits = output
.clone()
.slice([i..i + 1, mask..(mask + 1), 0..d_model])
.squeeze::<2>(0)
.squeeze(0);
// Find the top k tokens with the highest probabilities
let top_k = top_k(5, logits);
batch_results.push(FillMaskResult {
mask_idx: mask,
top_k: top_k
.iter()
.map(|(k, prob)| (*prob, tokenizer.decode(&[*k])))
.collect(),
});
}
results.push(batch_results);
}

results
}

fn find_masks<T: Element>(tokens: &Data<T, 1>, mask_token_id: TokenType) -> Vec<usize> {
let mut masks = Vec::new();
for (i, token) in tokens.value.iter().enumerate() {
if token.to_usize() == Some(mask_token_id) {
masks.push(i);
}
}
masks
}

fn data_to_vec_f32<T: Element>(data: &Data<T, 1>) -> Vec<f32> {
data.value.iter().map(|x| x.to_f32().unwrap()).collect()
}

fn data_to_vec_usize<T: Element>(data: &Data<T, 1>) -> Vec<usize> {
data.value.iter().map(|x| x.to_usize().unwrap()).collect()
}

fn top_k<B: Backend>(k: usize, logits: Tensor<B, 1>) -> Vec<(usize, f32)> {
let (pre_soft_probs, indices) = logits.sort_with_indices(0);
let (probabilities, indices) = (
data_to_vec_f32(&softmax(pre_soft_probs, 0).to_data()),
data_to_vec_usize(&indices.to_data()),
);
probabilities
.iter()
.enumerate()
.rev()
.take(k)
.map(|(i, &p)| (indices[i], p))
.collect()
}
1 change: 1 addition & 0 deletions bert-burn/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ extern crate derive_new;

pub mod data;
mod embedding;
pub mod fill_mask;
pub mod loader;
pub mod model;
pub mod pooler;
26 changes: 21 additions & 5 deletions bert-burn/src/loader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use candle_core::Tensor as CandleTensor;
use std::collections::HashMap;
use std::path::PathBuf;

fn load_1d_tensor_from_candle<B: Backend>(
pub(crate) fn load_1d_tensor_from_candle<B: Backend>(
tensor: &CandleTensor,
device: &B::Device,
) -> Tensor<B, 1> {
Expand All @@ -29,7 +29,7 @@ fn load_1d_tensor_from_candle<B: Backend>(
weight
}

fn load_2d_tensor_from_candle<B: Backend>(
pub(crate) fn load_2d_tensor_from_candle<B: Backend>(
tensor: &CandleTensor,
device: &B::Device,
) -> Tensor<B, 2> {
Expand All @@ -46,7 +46,7 @@ fn load_2d_tensor_from_candle<B: Backend>(
weight
}

fn load_layer_norm_safetensor<B: Backend>(
pub(crate) fn load_layer_norm_safetensor<B: Backend>(
bias: &CandleTensor,
weight: &CandleTensor,
device: &B::Device,
Expand All @@ -62,7 +62,7 @@ fn load_layer_norm_safetensor<B: Backend>(
layer_norm_record
}

fn load_linear_safetensor<B: Backend>(
pub(crate) fn load_linear_safetensor<B: Backend>(
bias: &CandleTensor,
weight: &CandleTensor,
device: &B::Device,
Expand All @@ -79,7 +79,7 @@ fn load_linear_safetensor<B: Backend>(
linear_record
}

fn load_intermediate_layer_safetensor<B: Backend>(
pub(crate) fn load_intermediate_layer_safetensor<B: Backend>(
linear_inner_weight: &CandleTensor,
linear_inner_bias: &CandleTensor,
linear_outer_weight: &CandleTensor,
Expand Down Expand Up @@ -227,6 +227,22 @@ pub fn load_encoder_from_safetensors<B: Backend>(
encoder_record
}

pub fn load_decoder_from_safetensors<B: Backend>(
bias: &CandleTensor,
word_embedding_weights: &CandleTensor,
device: &B::Device,
) -> LinearRecord<B> {
let bias = load_1d_tensor_from_candle::<B>(bias, device);
let weight = load_2d_tensor_from_candle::<B>(word_embedding_weights, device);
let weight = weight.transpose();

let linear_record = LinearRecord {
weight: Param::from_tensor(weight),
bias: Some(Param::from_tensor(bias)),
};
linear_record
}

fn load_embedding_safetensor<B: Backend>(
weight: &CandleTensor,
device: &B::Device,
Expand Down
Loading

0 comments on commit be48556

Please sign in to comment.