Add LM head and masked fill_mask for bert-burn. #34

Merged · 2 commits · May 27, 2024
Changes shown are from 1 commit.
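For orientation before the diff: the new example at bert-burn/examples/masked.rs takes the model variant as its first CLI argument and falls back to roberta-base. A rough usage sketch, run from the bert-burn crate directory — the wgpu feature flag is only an assumption based on the cfg gates in the example, not something stated in the PR:

    cargo run --example masked --features wgpu -- roberta-base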
145 changes: 145 additions & 0 deletions bert-burn/examples/masked.rs
@@ -0,0 +1,145 @@
use bert_burn::data::{BertInputBatcher, BertTokenizer};
use bert_burn::fill_mask::fill_mask;
use bert_burn::loader::{download_hf_model, load_model_config};
use bert_burn::model::{BertMaskedLM, BertMaskedLMRecord, BertModel, BertModelRecord};
use burn::data::dataloader::batcher::Batcher;
use burn::module::Module;
use burn::tensor::backend::Backend;
use burn::tensor::Tensor;
use std::env;
use std::sync::Arc;

#[cfg(not(feature = "f16"))]
#[allow(dead_code)]
type ElemType = f32;
#[cfg(feature = "f16")]
type ElemType = burn::tensor::f16;

pub fn launch<B: Backend>(device: B::Device) {
let args: Vec<String> = env::args().collect();
let default_model = "roberta-base".to_string();
let model_variant = if args.len() > 1 {
// Use the argument provided by the user
// Possible values: "bert-base-uncased", "roberta-large" etc.
&args[1]
} else {
// Use the default value if no argument is provided
&default_model
};

println!("Model variant: {}", model_variant);

let text_samples = vec![
"Paris is the <mask> of France.".to_string(),
"The goal of life is <mask>.".to_string(),
];

let (config_file, model_file) = download_hf_model(model_variant);
let model_config = load_model_config(config_file);

let model_record: BertMaskedLMRecord<B> =
BertMaskedLM::from_safetensors(model_file, &device, model_config.clone());

let model = model_config
.init_with_lm_head(&device)
.load_record(model_record);

let tokenizer = Arc::new(BertTokenizer::new(
model_variant.to_string(),
model_config.pad_token_id,
));

// Batch the input samples to max sequence length with padding
let batcher = Arc::new(BertInputBatcher::<B>::new(
tokenizer.clone(),
device.clone(),
model_config.max_seq_len.unwrap(),
));

// Batch the input samples using the batcher. Shape: [batch_size, seq_len]
let input = batcher.batch(text_samples.clone());
let [batch_size, _seq_len] = input.tokens.dims();
println!("Input: {:?} // (Batch Size, Seq_len)", input.tokens.shape());

let output = fill_mask(&model, &model_config, &tokenizer, input);

for i in 0..batch_size {
let input = &text_samples[i];
let result = &output[i];
println!("Input: {}", input);
for (j, fill_mask_result) in result.iter().enumerate() {
let mask_idx = fill_mask_result.mask_idx;
let top_k = &fill_mask_result.top_k;
for (k, (score, token)) in top_k.iter().enumerate() {
println!("Top {} Prediction: {} (Score: {:.4})", k + 1, token, score);
}
}
}
}

#[cfg(any(
feature = "ndarray",
feature = "ndarray-blas-netlib",
feature = "ndarray-blas-openblas",
feature = "ndarray-blas-accelerate",
))]
mod ndarray {
use burn::backend::ndarray::{NdArray, NdArrayDevice};

use crate::{launch, ElemType};

pub fn run() {
launch::<NdArray<ElemType>>(NdArrayDevice::Cpu);
}
}

#[cfg(feature = "tch-gpu")]
mod tch_gpu {
use crate::{launch, ElemType};
use burn::backend::libtorch::{LibTorch, LibTorchDevice};

pub fn run() {
#[cfg(not(target_os = "macos"))]
let device = LibTorchDevice::Cuda(0);
#[cfg(target_os = "macos")]
let device = LibTorchDevice::Mps;

launch::<LibTorch<ElemType>>(device);
}
}

#[cfg(feature = "tch-cpu")]
mod tch_cpu {
use crate::{launch, ElemType};
use burn::backend::libtorch::{LibTorch, LibTorchDevice};

pub fn run() {
launch::<LibTorch<ElemType>>(LibTorchDevice::Cpu);
}
}

#[cfg(feature = "wgpu")]
mod wgpu {
use crate::{launch, ElemType};
use burn::backend::wgpu::{AutoGraphicsApi, Wgpu, WgpuDevice};

pub fn run() {
launch::<Wgpu<AutoGraphicsApi, ElemType, i32>>(WgpuDevice::default());
}
}

fn main() {
#[cfg(any(
feature = "ndarray",
feature = "ndarray-blas-netlib",
feature = "ndarray-blas-openblas",
feature = "ndarray-blas-accelerate",
))]
ndarray::run();
#[cfg(feature = "tch-gpu")]
tch_gpu::run();
#[cfg(feature = "tch-cpu")]
tch_cpu::run();
#[cfg(feature = "wgpu")]
wgpu::run();
}
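One note on the inner loop above: mask_idx is bound from fill_mask_result but never printed (and the j counter is unused), so the output does not say which mask a block of predictions belongs to. A hypothetical tweak, not part of this PR, that surfaces the position:

    println!("Predictions for mask at token index {}:", mask_idx);
    for (k, (score, token)) in top_k.iter().enumerate() {
        println!("  Top {} Prediction: {} (Score: {:.4})", k + 1, token, score);
    }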
88 changes: 88 additions & 0 deletions bert-burn/src/fill_mask.rs
@@ -0,0 +1,88 @@
use std::sync::Arc;

use crate::{
data::Tokenizer,
data::{BertInferenceBatch, BertTokenizer},
model::BertMaskedLM,
model::BertModelConfig,
};
use burn::tensor::{activation::softmax, backend::Backend, Data, Element};

type TokenType = usize;
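// Hard-coded id of RoBERTa's <mask> token; BERT-style vocabularies use a different id,
// so this constant is specific to RoBERTa-family checkpoints.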
const MASK_TOKEN_ID: TokenType = 50264;

fn find_masks<T: Element>(tokens: &Data<T, 1>, mask_token_id: TokenType) -> Vec<usize> {
let mut masks = Vec::new();
for (i, token) in tokens.value.iter().enumerate() {
if token.to_usize() == Some(mask_token_id) {
masks.push(i);
}
}
masks
}

fn data_to_vec_f32<T: Element>(data: &Data<T, 1>) -> Vec<f32> {
data.value.iter().map(|x| x.to_f32().unwrap()).collect()
}

fn top_k(k: usize, probabilities: Vec<f32>) -> Vec<(usize, f32)> {
let mut probabilities = probabilities.iter().enumerate().collect::<Vec<_>>();

probabilities.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap());

probabilities.truncate(k);

probabilities.into_iter().map(|(i, &p)| (i, p)).collect()
}

#[derive(Debug, Clone)]
pub struct FillMaskResult {
pub mask_idx: usize,
pub top_k: Vec<(f32, String)>,
}

pub fn fill_mask<B: Backend>(
model: &BertMaskedLM<B>,
model_config: &BertModelConfig,
tokenizer: &Arc<BertTokenizer>,
input: BertInferenceBatch<B>,
) -> Vec<Vec<FillMaskResult>> {
let [batch_size, seq_len] = input.tokens.dims();
let output = model.forward(input.clone());

let mut results = vec![];

// Output dimension of the LM head, i.e. the vocabulary size
let d_model = model_config.vocab_size.clone();
for i in 0..batch_size {
let mut batch_results = vec![];
let input_tokens = input
.tokens
.clone()
.slice([i..i + 1, 0..seq_len])
.squeeze(0)
.to_data();
// Find the mask tokens in the input, as a list of indices
let masks = find_masks(&input_tokens, MASK_TOKEN_ID);
for mask in masks {
let logits = output
.clone()
.slice([i..i + 1, mask..(mask + 1), 0..d_model])
.squeeze::<2>(0)
.squeeze(0);
// Find the top k tokens with the highest probabilities
let probs = data_to_vec_f32(&softmax(logits, 0).to_data());
let top_k = top_k(5, probs);
batch_results.push(FillMaskResult {
mask_idx: mask,
top_k: top_k
.iter()
.map(|(k, prob)| (*prob, tokenizer.decode(&[*k])))
.collect(),
});
}
results.push(batch_results);
}

results
}
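Since top_k is private to this module, the quickest way to sanity-check it is an in-module test. A minimal sketch, assuming it were appended to fill_mask.rs (not part of this PR):

#[cfg(test)]
mod tests {
    use super::top_k;

    #[test]
    fn top_k_returns_indices_by_descending_probability() {
        let probs = vec![0.1_f32, 0.7, 0.2];
        // Highest-probability indices first, truncated to k = 2.
        assert_eq!(top_k(2, probs), vec![(1, 0.7), (2, 0.2)]);
    }
}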
1 change: 1 addition & 0 deletions bert-burn/src/lib.rs
@@ -3,6 +3,7 @@ extern crate derive_new;

pub mod data;
mod embedding;
pub mod fill_mask;
pub mod loader;
pub mod model;
pub mod pooler;
26 changes: 21 additions & 5 deletions bert-burn/src/loader.rs
@@ -17,7 +17,7 @@ use candle_core::Tensor as CandleTensor;
use std::collections::HashMap;
use std::path::PathBuf;

fn load_1d_tensor_from_candle<B: Backend>(
pub(crate) fn load_1d_tensor_from_candle<B: Backend>(
tensor: &CandleTensor,
device: &B::Device,
) -> Tensor<B, 1> {
@@ -29,7 +29,7 @@ fn load_1d_tensor_from_candle<B: Backend>(
weight
}

fn load_2d_tensor_from_candle<B: Backend>(
pub(crate) fn load_2d_tensor_from_candle<B: Backend>(
tensor: &CandleTensor,
device: &B::Device,
) -> Tensor<B, 2> {
@@ -46,7 +46,7 @@ fn load_2d_tensor_from_candle<B: Backend>(
weight
}

fn load_layer_norm_safetensor<B: Backend>(
pub(crate) fn load_layer_norm_safetensor<B: Backend>(
bias: &CandleTensor,
weight: &CandleTensor,
device: &B::Device,
@@ -62,7 +62,7 @@ fn load_layer_norm_safetensor<B: Backend>(
layer_norm_record
}

fn load_linear_safetensor<B: Backend>(
pub(crate) fn load_linear_safetensor<B: Backend>(
bias: &CandleTensor,
weight: &CandleTensor,
device: &B::Device,
@@ -79,7 +79,7 @@ fn load_linear_safetensor<B: Backend>(
linear_record
}

fn load_intermediate_layer_safetensor<B: Backend>(
pub(crate) fn load_intermediate_layer_safetensor<B: Backend>(
linear_inner_weight: &CandleTensor,
linear_inner_bias: &CandleTensor,
linear_outer_weight: &CandleTensor,
@@ -227,6 +227,22 @@ pub fn load_encoder_from_safetensors<B: Backend>(
encoder_record
}

pub fn load_decoder_from_safetensors<B: Backend>(
bias: &CandleTensor,
word_embedding_weights: &CandleTensor,
device: &B::Device,
) -> LinearRecord<B> {
let bias = load_1d_tensor_from_candle::<B>(bias, device);
let weight = load_2d_tensor_from_candle::<B>(word_embedding_weights, device);
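// The checkpoint stores word embeddings as [vocab_size, d_model]; transposing ties the
// LM head to the input embeddings so it projects hidden states back onto the vocabulary.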
let weight = weight.transpose();

let linear_record = LinearRecord {
weight: Param::from_tensor(weight),
bias: Some(Param::from_tensor(bias)),
};
linear_record
}

fn load_embedding_safetensor<B: Backend>(
weight: &CandleTensor,
device: &B::Device,