Add LM head and masked fill_mask for bert-burn. #34

Merged (2 commits) on May 27, 2024
150 changes: 150 additions & 0 deletions bert-burn/examples/masked.rs
@@ -0,0 +1,150 @@
use bert_burn::data::{BertInputBatcher, BertTokenizer};
use bert_burn::fill_mask::fill_mask;
use bert_burn::loader::{download_hf_model, load_model_config};
use bert_burn::model::{BertMaskedLM, BertMaskedLMRecord};
use burn::data::dataloader::batcher::Batcher;
use burn::module::Module;
use burn::tensor::backend::Backend;
use std::env;
use std::sync::Arc;

#[cfg(not(feature = "f16"))]
#[allow(dead_code)]
type ElemType = f32;
#[cfg(feature = "f16")]
type ElemType = burn::tensor::f16;

pub fn launch<B: Backend>(device: B::Device) {
let args: Vec<String> = env::args().collect();
let default_model = "roberta-base".to_string();
let model_variant = if args.len() > 1 {
// Use the argument provided by the user
// Possible values: "bert-base-uncased", "roberta-large" etc.
&args[1]
} else {
// Use the default value if no argument is provided
&default_model
};

println!("Model variant: {}", model_variant);

let text_samples = vec![
"Paris is the <mask> of France.".to_string(),
"The goal of life is <mask>.".to_string(),
];

let (config_file, model_file) = download_hf_model(model_variant);
let model_config = load_model_config(config_file);

let model_record: BertMaskedLMRecord<B> =
BertMaskedLM::from_safetensors(model_file, &device, model_config.clone());

let model = model_config
.init_with_lm_head(&device)
.load_record(model_record);

let tokenizer = Arc::new(BertTokenizer::new(
model_variant.to_string(),
model_config.pad_token_id,
));

// Batch the input samples to max sequence length with padding
let batcher = Arc::new(BertInputBatcher::<B>::new(
tokenizer.clone(),
device.clone(),
model_config.max_seq_len.unwrap(),
));

// Batch the input samples with the batcher. Shape: [batch_size, seq_len]
let input = batcher.batch(text_samples.clone());
let [batch_size, _seq_len] = input.tokens.dims();
println!("Input: {:?} // (Batch Size, Seq_len)", input.tokens.shape());

let output = fill_mask(&model, &model_config, tokenizer.as_ref(), input);

for i in 0..batch_size {
let input = &text_samples[i];
let result = &output[i];
println!("Input: {}", input);
for fill_mask_result in result.iter() {
let mask_idx = fill_mask_result.mask_idx;
let top_k = &fill_mask_result.top_k;
for (k, (score, token)) in top_k.iter().enumerate() {
println!(
"Top {} Prediction for {}: {} (Score: {:.4})",
k + 1,
mask_idx,
token,
score
);
}
}
}
}

#[cfg(any(
feature = "ndarray",
feature = "ndarray-blas-netlib",
feature = "ndarray-blas-openblas",
feature = "ndarray-blas-accelerate",
))]
mod ndarray {
use burn::backend::ndarray::{NdArray, NdArrayDevice};

use crate::{launch, ElemType};

pub fn run() {
launch::<NdArray<ElemType>>(NdArrayDevice::Cpu);
}
}

#[cfg(feature = "tch-gpu")]
mod tch_gpu {
use crate::{launch, ElemType};
use burn::backend::libtorch::{LibTorch, LibTorchDevice};

pub fn run() {
#[cfg(not(target_os = "macos"))]
let device = LibTorchDevice::Cuda(0);
#[cfg(target_os = "macos")]
let device = LibTorchDevice::Mps;

launch::<LibTorch<ElemType>>(device);
}
}

#[cfg(feature = "tch-cpu")]
mod tch_cpu {
use crate::{launch, ElemType};
use burn::backend::libtorch::{LibTorch, LibTorchDevice};

pub fn run() {
launch::<LibTorch<ElemType>>(LibTorchDevice::Cpu);
}
}

#[cfg(feature = "wgpu")]
mod wgpu {
use crate::{launch, ElemType};
use burn::backend::wgpu::{AutoGraphicsApi, Wgpu, WgpuDevice};

pub fn run() {
launch::<Wgpu<AutoGraphicsApi, ElemType, i32>>(WgpuDevice::default());
}
}

fn main() {
#[cfg(any(
feature = "ndarray",
feature = "ndarray-blas-netlib",
feature = "ndarray-blas-openblas",
feature = "ndarray-blas-accelerate",
))]
ndarray::run();
#[cfg(feature = "tch-gpu")]
tch_gpu::run();
#[cfg(feature = "tch-cpu")]
tch_cpu::run();
#[cfg(feature = "wgpu")]
wgpu::run();
}
94 changes: 94 additions & 0 deletions bert-burn/src/fill_mask.rs
@@ -0,0 +1,94 @@
use crate::{
data::Tokenizer,
data::{BertInferenceBatch, BertTokenizer},
model::BertMaskedLM,
model::BertModelConfig,
};
use burn::tensor::{activation::softmax, backend::Backend, Data, Element, Tensor};

type TokenType = usize;
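// <mask> token id in RoBERTa-style vocabularies (e.g. roberta-base); other checkpoints may use a different id.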
const MASK_TOKEN_ID: TokenType = 50264;

#[derive(Debug, Clone)]
pub struct FillMaskResult {
pub mask_idx: usize,
pub top_k: Vec<(f32, String)>,
}

pub fn fill_mask<B: Backend>(
model: &BertMaskedLM<B>,
model_config: &BertModelConfig,
tokenizer: &BertTokenizer,
input: BertInferenceBatch<B>,
) -> Vec<Vec<FillMaskResult>> {
let [batch_size, seq_len] = input.tokens.dims();
let output = model.forward(input.clone());

let mut results = vec![];

// Vocabulary size (the LM head outputs one logit per vocabulary token)
let vocab_size = model_config.vocab_size;
for i in 0..batch_size {
let mut batch_results = vec![];
let input_tokens = input
.tokens
.clone()
.slice([i..i + 1, 0..seq_len])
.squeeze(0)
.to_data();
// Find the mask tokens in the input, as a list of indices
let masks = find_masks(&input_tokens, MASK_TOKEN_ID);
for mask in masks {
let logits = output
.clone()
.slice([i..i + 1, mask..(mask + 1), 0..vocab_size])
.squeeze::<2>(0)
.squeeze(0);
// Find the top k tokens with the highest probabilities
let top_k = top_k(5, logits);
batch_results.push(FillMaskResult {
mask_idx: mask,
top_k: top_k
.iter()
.map(|(k, prob)| (*prob, tokenizer.decode(&[*k])))
.collect(),
});
}
results.push(batch_results);
}

results
}

fn find_masks<T: Element>(tokens: &Data<T, 1>, mask_token_id: TokenType) -> Vec<usize> {
let mut masks = Vec::new();
for (i, token) in tokens.value.iter().enumerate() {
if token.to_usize() == Some(mask_token_id) {
masks.push(i);
}
}
masks
}

fn data_to_vec_f32<T: Element>(data: &Data<T, 1>) -> Vec<f32> {
data.value.iter().map(|x| x.to_f32().unwrap()).collect()
}

fn data_to_vec_usize<T: Element>(data: &Data<T, 1>) -> Vec<usize> {
data.value.iter().map(|x| x.to_usize().unwrap()).collect()
}

fn top_k<B: Backend>(k: usize, logits: Tensor<B, 1>) -> Vec<(usize, f32)> {
let (pre_soft_probs, indices) = logits.sort_with_indices(0);
let (probabilities, indices) = (
data_to_vec_f32(&softmax(pre_soft_probs, 0).to_data()),
data_to_vec_usize(&indices.to_data()),
);
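// Values are sorted in ascending order, so iterating from the end (rev) yields the k most probable tokens first.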
probabilities
.iter()
.enumerate()
.rev()
.take(k)
.map(|(i, &p)| (indices[i], p))
.collect()
}
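
The selection in top_k above reads as: sort the logits ascending together with their vocabulary indices, softmax the sorted values (a monotonic map, so the order is unchanged), then walk the probabilities from the end to collect the k most likely token ids. A minimal standalone sketch of the same selection over a plain Vec<f32>, independent of the tensor API (top_k_plain is an illustrative name, not part of this change):

fn top_k_plain(k: usize, logits: &[f32]) -> Vec<(usize, f32)> {
    // Pair each logit with its vocabulary index, then sort ascending by value.
    let mut indexed: Vec<(usize, f32)> = logits.iter().copied().enumerate().collect();
    indexed.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
    // Softmax over the sorted values; subtracting the max keeps exp() numerically stable.
    let max = indexed.last().map(|&(_, v)| v).unwrap_or(0.0);
    let exps: Vec<f32> = indexed.iter().map(|&(_, v)| (v - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    // Sorted order is ascending, so the last k entries are the most probable tokens.
    indexed
        .iter()
        .zip(exps.iter())
        .map(|(p, e)| (p.0, *e / sum))
        .rev()
        .take(k)
        .collect()
}
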
1 change: 1 addition & 0 deletions bert-burn/src/lib.rs
@@ -3,6 +3,7 @@ extern crate derive_new;

pub mod data;
mod embedding;
pub mod fill_mask;
pub mod loader;
pub mod model;
pub mod pooler;
26 changes: 21 additions & 5 deletions bert-burn/src/loader.rs
@@ -17,7 +17,7 @@ use candle_core::Tensor as CandleTensor;
use std::collections::HashMap;
use std::path::PathBuf;

fn load_1d_tensor_from_candle<B: Backend>(
pub(crate) fn load_1d_tensor_from_candle<B: Backend>(
tensor: &CandleTensor,
device: &B::Device,
) -> Tensor<B, 1> {
@@ -29,7 +29,7 @@ fn load_1d_tensor_from_candle<B: Backend>(
weight
}

fn load_2d_tensor_from_candle<B: Backend>(
pub(crate) fn load_2d_tensor_from_candle<B: Backend>(
tensor: &CandleTensor,
device: &B::Device,
) -> Tensor<B, 2> {
@@ -46,7 +46,7 @@ fn load_2d_tensor_from_candle<B: Backend>(
weight
}

fn load_layer_norm_safetensor<B: Backend>(
pub(crate) fn load_layer_norm_safetensor<B: Backend>(
bias: &CandleTensor,
weight: &CandleTensor,
device: &B::Device,
@@ -62,7 +62,7 @@ fn load_layer_norm_safetensor<B: Backend>(
layer_norm_record
}

fn load_linear_safetensor<B: Backend>(
pub(crate) fn load_linear_safetensor<B: Backend>(
bias: &CandleTensor,
weight: &CandleTensor,
device: &B::Device,
@@ -79,7 +79,7 @@ fn load_linear_safetensor<B: Backend>(
linear_record
}

fn load_intermediate_layer_safetensor<B: Backend>(
pub(crate) fn load_intermediate_layer_safetensor<B: Backend>(
linear_inner_weight: &CandleTensor,
linear_inner_bias: &CandleTensor,
linear_outer_weight: &CandleTensor,
@@ -227,6 +227,22 @@ pub fn load_encoder_from_safetensors<B: Backend>(
encoder_record
}

pub fn load_decoder_from_safetensors<B: Backend>(
bias: &CandleTensor,
word_embedding_weights: &CandleTensor,
device: &B::Device,
) -> LinearRecord<B> {
let bias = load_1d_tensor_from_candle::<B>(bias, device);
let weight = load_2d_tensor_from_candle::<B>(word_embedding_weights, device);
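// Tie the decoder weight to the word embeddings: transpose [vocab_size, d_model] to match the Linear weight layout.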
let weight = weight.transpose();

let linear_record = LinearRecord {
weight: Param::from_tensor(weight),
bias: Some(Param::from_tensor(bias)),
};
linear_record
}

fn load_embedding_safetensor<B: Backend>(
weight: &CandleTensor,
device: &B::Device,