diff --git a/bert-burn/examples/masked.rs b/bert-burn/examples/masked.rs
new file mode 100644
index 0000000..5a4b948
--- /dev/null
+++ b/bert-burn/examples/masked.rs
@@ -0,0 +1,150 @@
+use bert_burn::data::{BertInputBatcher, BertTokenizer};
+use bert_burn::fill_mask::fill_mask;
+use bert_burn::loader::{download_hf_model, load_model_config};
+use bert_burn::model::{BertMaskedLM, BertMaskedLMRecord};
+use burn::data::dataloader::batcher::Batcher;
+use burn::module::Module;
+use burn::tensor::backend::Backend;
+use std::env;
+use std::sync::Arc;
+
+#[cfg(not(feature = "f16"))]
+#[allow(dead_code)]
+type ElemType = f32;
+#[cfg(feature = "f16")]
+type ElemType = burn::tensor::f16;
+
+pub fn launch<B: Backend>(device: B::Device) {
+    let args: Vec<String> = env::args().collect();
+    let default_model = "roberta-base".to_string();
+    let model_variant = if args.len() > 1 {
+        // Use the argument provided by the user
+        // Possible values: "bert-base-uncased", "roberta-large" etc.
+        &args[1]
+    } else {
+        // Use the default value if no argument is provided
+        &default_model
+    };
+
+    println!("Model variant: {}", model_variant);
+
+    let text_samples = vec![
+        "Paris is the <mask> of France.".to_string(),
+        "The goal of life is <mask>.".to_string(),
+    ];
+
+    let (config_file, model_file) = download_hf_model(model_variant);
+    let model_config = load_model_config(config_file);
+
+    let model_record: BertMaskedLMRecord<B> =
+        BertMaskedLM::from_safetensors(model_file, &device, model_config.clone());
+
+    let model = model_config
+        .init_with_lm_head(&device)
+        .load_record(model_record);
+
+    let tokenizer = Arc::new(BertTokenizer::new(
+        model_variant.to_string(),
+        model_config.pad_token_id,
+    ));
+
+    // Batch the input samples to max sequence length with padding
+    let batcher = Arc::new(BertInputBatcher::<B>::new(
+        tokenizer.clone(),
+        device.clone(),
+        model_config.max_seq_len.unwrap(),
+    ));
+
+    // Batch input samples using the batcher. Shape: [Batch size, Seq_len]
+    let input = batcher.batch(text_samples.clone());
+    let [batch_size, _seq_len] = input.tokens.dims();
+    println!("Input: {:?} // (Batch Size, Seq_len)", input.tokens.shape());
+
+    let output = fill_mask(&model, &model_config, tokenizer.as_ref(), input);
+
+    for i in 0..batch_size {
+        let input = &text_samples[i];
+        let result = &output[i];
+        println!("Input: {}", input);
+        for fill_mask_result in result.iter() {
+            let mask_idx = fill_mask_result.mask_idx;
+            let top_k = &fill_mask_result.top_k;
+            for (k, (score, token)) in top_k.iter().enumerate() {
+                println!(
+                    "Top {} Prediction for {}: {} (Score: {:.4})",
+                    k + 1,
+                    mask_idx,
+                    token,
+                    score
+                );
+            }
+        }
+    }
+}
+
+#[cfg(any(
+    feature = "ndarray",
+    feature = "ndarray-blas-netlib",
+    feature = "ndarray-blas-openblas",
+    feature = "ndarray-blas-accelerate",
+))]
+mod ndarray {
+    use burn::backend::ndarray::{NdArray, NdArrayDevice};
+
+    use crate::{launch, ElemType};
+
+    pub fn run() {
+        launch::<NdArray<ElemType>>(NdArrayDevice::Cpu);
+    }
+}
+
+#[cfg(feature = "tch-gpu")]
+mod tch_gpu {
+    use crate::{launch, ElemType};
+    use burn::backend::libtorch::{LibTorch, LibTorchDevice};
+
+    pub fn run() {
+        #[cfg(not(target_os = "macos"))]
+        let device = LibTorchDevice::Cuda(0);
+        #[cfg(target_os = "macos")]
+        let device = LibTorchDevice::Mps;
+
+        launch::<LibTorch<ElemType>>(device);
+    }
+}
+
+#[cfg(feature = "tch-cpu")]
+mod tch_cpu {
+    use crate::{launch, ElemType};
+    use burn::backend::libtorch::{LibTorch, LibTorchDevice};
+
+    pub fn run() {
+        launch::<LibTorch<ElemType>>(LibTorchDevice::Cpu);
+    }
+}
+
+#[cfg(feature = "wgpu")]
+mod wgpu {
+    use crate::{launch, ElemType};
+    use burn::backend::wgpu::{AutoGraphicsApi, Wgpu, WgpuDevice};
+
+    pub fn run() {
+        launch::<Wgpu<AutoGraphicsApi, ElemType, i32>>(WgpuDevice::default());
+    }
+}
+
+fn main() {
+    #[cfg(any(
+        feature = "ndarray",
+        feature = "ndarray-blas-netlib",
+        feature = "ndarray-blas-openblas",
+        feature = "ndarray-blas-accelerate",
+    ))]
+    ndarray::run();
+    #[cfg(feature = "tch-gpu")]
+    tch_gpu::run();
+    #[cfg(feature = "tch-cpu")]
+    tch_cpu::run();
+    #[cfg(feature = "wgpu")]
+    wgpu::run();
+}
diff --git a/bert-burn/src/fill_mask.rs b/bert-burn/src/fill_mask.rs
new file mode 100644
index 0000000..4e53820
--- /dev/null
+++ b/bert-burn/src/fill_mask.rs
@@ -0,0 +1,94 @@
+use crate::{
+    data::Tokenizer,
+    data::{BertInferenceBatch, BertTokenizer},
+    model::BertMaskedLM,
+    model::BertModelConfig,
+};
+use burn::tensor::{activation::softmax, backend::Backend, Data, Element, Tensor};
+
+type TokenType = usize;
+const MASK_TOKEN_ID: TokenType = 50264;
+
+#[derive(Debug, Clone)]
+pub struct FillMaskResult {
+    pub mask_idx: usize,
+    pub top_k: Vec<(f32, String)>,
+}
+
+pub fn fill_mask<B: Backend>(
+    model: &BertMaskedLM<B>,
+    model_config: &BertModelConfig,
+    tokenizer: &BertTokenizer,
+    input: BertInferenceBatch<B>,
+) -> Vec<Vec<FillMaskResult>> {
+    let [batch_size, seq_len] = input.tokens.dims();
+    let output = model.forward(input.clone());
+
+    let mut results = vec![];
+
+    // Vocabulary size (the decoder projects hidden states onto the vocabulary)
+    let d_model = model_config.vocab_size;
+    for i in 0..batch_size {
+        let mut batch_results = vec![];
+        let input_tokens = input
+            .tokens
+            .clone()
+            .slice([i..i + 1, 0..seq_len])
+            .squeeze(0)
+            .to_data();
+        // Find the mask tokens in the input, as a list of indices
+        let masks = find_masks(&input_tokens, MASK_TOKEN_ID);
+        for mask in masks {
+            let logits = output
+                .clone()
+                .slice([i..i + 1, mask..(mask + 1), 0..d_model])
+                .squeeze::<2>(0)
+                .squeeze(0);
+            // Find the top k tokens with the highest probabilities
+            let top_k = top_k(5, logits);
+            batch_results.push(FillMaskResult {
+                mask_idx: mask,
+                top_k: top_k
+                    .iter()
+                    .map(|(k, prob)| (*prob, tokenizer.decode(&[*k])))
+                    .collect(),
+            });
+        }
+        results.push(batch_results);
+    }
+
+    results
+}
+
+fn find_masks<I: Element>(tokens: &Data<I, 1>, mask_token_id: TokenType) -> Vec<usize> {
+    let mut masks = Vec::new();
+    for (i, token) in tokens.value.iter().enumerate() {
+        if token.to_usize() == Some(mask_token_id) {
+            masks.push(i);
+        }
+    }
+    masks
+}
+
+fn data_to_vec_f32<E: Element>(data: &Data<E, 1>) -> Vec<f32> {
+    data.value.iter().map(|x| x.to_f32().unwrap()).collect()
+}
+
+fn data_to_vec_usize<E: Element>(data: &Data<E, 1>) -> Vec<usize> {
+    data.value.iter().map(|x| x.to_usize().unwrap()).collect()
+}
+
+fn top_k<B: Backend>(k: usize, logits: Tensor<B, 1>) -> Vec<(usize, f32)> {
+    let (pre_soft_probs, indices) = logits.sort_with_indices(0);
+    let (probabilities, indices) = (
+        data_to_vec_f32(&softmax(pre_soft_probs, 0).to_data()),
+        data_to_vec_usize(&indices.to_data()),
+    );
+    probabilities
+        .iter()
+        .enumerate()
+        .rev()
+        .take(k)
+        .map(|(i, &p)| (indices[i], p))
+        .collect()
+}
diff --git a/bert-burn/src/lib.rs b/bert-burn/src/lib.rs
index 1dead18..b384d28 100644
--- a/bert-burn/src/lib.rs
+++ b/bert-burn/src/lib.rs
@@ -3,6 +3,7 @@ extern crate derive_new;
 
 pub mod data;
 mod embedding;
+pub mod fill_mask;
 pub mod loader;
 pub mod model;
 pub mod pooler;
diff --git a/bert-burn/src/loader.rs b/bert-burn/src/loader.rs
index 03e717e..1416fba 100644
--- a/bert-burn/src/loader.rs
+++ b/bert-burn/src/loader.rs
@@ -17,7 +17,7 @@ use candle_core::Tensor as CandleTensor;
 use std::collections::HashMap;
 use std::path::PathBuf;
 
-fn load_1d_tensor_from_candle<B: Backend>(
+pub(crate) fn load_1d_tensor_from_candle<B: Backend>(
     tensor: &CandleTensor,
     device: &B::Device,
 ) -> Tensor<B, 1> {
@@ -29,7 +29,7 @@ fn load_1d_tensor_from_candle<B: Backend>(
     weight
 }
 
-fn load_2d_tensor_from_candle<B: Backend>(
+pub(crate) fn load_2d_tensor_from_candle<B: Backend>(
     tensor: &CandleTensor,
     device: &B::Device,
 ) -> Tensor<B, 2> {
@@ -46,7 +46,7 @@ fn load_2d_tensor_from_candle<B: Backend>(
     weight
 }
 
-fn load_layer_norm_safetensor<B: Backend>(
+pub(crate) fn load_layer_norm_safetensor<B: Backend>(
     bias: &CandleTensor,
     weight: &CandleTensor,
     device: &B::Device,
@@ -62,7 +62,7 @@ fn load_layer_norm_safetensor<B: Backend>(
     layer_norm_record
 }
 
-fn load_linear_safetensor<B: Backend>(
+pub(crate) fn load_linear_safetensor<B: Backend>(
     bias: &CandleTensor,
     weight: &CandleTensor,
     device: &B::Device,
@@ -79,7 +79,7 @@ fn load_linear_safetensor<B: Backend>(
     linear_record
 }
 
-fn load_intermediate_layer_safetensor<B: Backend>(
+pub(crate) fn load_intermediate_layer_safetensor<B: Backend>(
     linear_inner_weight: &CandleTensor,
     linear_inner_bias: &CandleTensor,
     linear_outer_weight: &CandleTensor,
@@ -227,6 +227,22 @@ pub fn load_encoder_from_safetensors<B: Backend>(
     encoder_record
 }
 
+pub fn load_decoder_from_safetensors<B: Backend>(
+    bias: &CandleTensor,
+    word_embedding_weights: &CandleTensor,
+    device: &B::Device,
+) -> LinearRecord<B> {
+    let bias = load_1d_tensor_from_candle::<B>(bias, device);
+    let weight = load_2d_tensor_from_candle::<B>(word_embedding_weights, device);
+    let weight = weight.transpose();
+
+    let linear_record = LinearRecord {
+        weight: Param::from_tensor(weight),
+        bias: Some(Param::from_tensor(bias)),
+    };
+    linear_record
+}
+
 fn load_embedding_safetensor<B: Backend>(
     weight: &CandleTensor,
     device: &B::Device,
diff --git a/bert-burn/src/model.rs b/bert-burn/src/model.rs
index 4810d7c..3af85d1 100644
--- a/bert-burn/src/model.rs
+++ b/bert-burn/src/model.rs
@@ -1,7 +1,8 @@
 use crate::data::BertInferenceBatch;
 use crate::embedding::{BertEmbeddings, BertEmbeddingsConfig};
 use crate::loader::{
-    load_embeddings_from_safetensors, load_encoder_from_safetensors, load_pooler_from_safetensors,
+    load_decoder_from_safetensors, load_embeddings_from_safetensors, load_encoder_from_safetensors,
+    load_layer_norm_safetensor, load_linear_safetensor, load_pooler_from_safetensors,
 };
 use crate::pooler::{Pooler, PoolerConfig};
 use burn::config::Config;
@@ -10,6 +11,8 @@ use burn::nn::transformer::{
     TransformerEncoder, TransformerEncoderConfig, TransformerEncoderInput,
 };
 use burn::nn::Initializer::KaimingUniform;
+use burn::nn::{LayerNorm, LayerNormConfig, Linear, LinearConfig};
+use burn::tensor::activation::gelu;
 use burn::tensor::backend::Backend;
 use burn::tensor::Tensor;
 use candle_core::{safetensors, Device, Tensor as CandleTensor};
@@ -85,6 +88,19 @@ impl BertModelConfig {
         }
     }
 
+    pub fn init_with_lm_head<B: Backend>(&self, device: &B::Device) -> BertMaskedLM<B> {
+        let bert = self.init(device);
+        let lm_head = BertLMHead {
+            dense: LinearConfig::new(self.hidden_size, self.hidden_size).init(device),
+            layer_norm: LayerNormConfig::new(self.hidden_size)
+                .with_epsilon(self.layer_norm_eps)
+                .init(device),
+            decoder: LinearConfig::new(self.hidden_size, self.vocab_size).init(device),
+        };
+
+        BertMaskedLM { bert, lm_head }
+    }
+
     fn get_embeddings_config(&self) -> BertEmbeddingsConfig {
         BertEmbeddingsConfig {
             vocab_size: self.vocab_size,
@@ -104,7 +120,7 @@
             d_model: self.hidden_size,
             d_ff: self.intermediate_size,
             dropout: self.hidden_dropout_prob,
-            norm_first: true,
+            norm_first: false,
             quiet_softmax: false,
             initializer: KaimingUniform {
                 gain: 1.0 / libm::sqrt(3.0),
@@ -190,3 +206,89 @@ impl<B: Backend> BertModel<B> {
         model_record
     }
 }
+
+#[derive(Module, Debug)]
+pub struct BertMaskedLM<B: Backend> {
+    pub bert: BertModel<B>,
+    pub lm_head: BertLMHead<B>,
+}
+
+#[derive(Module, Debug)]
+pub struct BertLMHead<B: Backend> {
+    pub dense: Linear<B>,
+    pub layer_norm: LayerNorm<B>,
+    pub decoder: Linear<B>,
+}
+
+impl<B: Backend> BertMaskedLM<B> {
+    pub fn forward(&self, input: BertInferenceBatch<B>) -> Tensor<B, 3> {
+        let output = self.bert.forward(BertInferenceBatch {
+            tokens: input.tokens.clone(),
+            mask_pad: input.mask_pad.clone(),
+        });
+        let output = self.lm_head.forward(output.hidden_states);
+        output
+    }
+
+    pub fn from_safetensors(
+        file_path: PathBuf,
+        device: &B::Device,
+        config: BertModelConfig,
+    ) -> BertMaskedLMRecord<B> {
+        let bert = BertModel::from_safetensors(file_path.clone(), device, config.clone());
+        let lm_head = BertLMHead::from_safetensors(file_path, device, config);
+
+        BertMaskedLMRecord { bert, lm_head }
+    }
+}
+
+impl<B: Backend> BertLMHead<B> {
+    pub fn forward(&self, features: Tensor<B, 3>) -> Tensor<B, 3> {
+        let output = self.dense.forward(features);
+        let output = gelu(output);
+        let output = self.layer_norm.forward(output);
+
+        let output = self.decoder.forward(output);
+        output
+    }
+
+    pub fn from_safetensors(
+        file_path: PathBuf,
+        device: &B::Device,
+        _config: BertModelConfig,
+    ) -> BertLMHeadRecord<B> {
+        let weight_result = safetensors::load::<PathBuf>(file_path, &Device::Cpu);
+
+        // Match on the result of loading the weights
+        let weights = match weight_result {
+            Ok(weights) => weights,
+            Err(e) => panic!("Error loading weights: {:?}", e),
+        };
+
+        let dense = load_linear_safetensor::<B>(
+            &weights["lm_head.dense.bias"],
+            &weights["lm_head.dense.weight"],
+            device,
+        );
+        let layer_norm = load_layer_norm_safetensor::<B>(
+            &weights["lm_head.layer_norm.bias"],
+            &weights["lm_head.layer_norm.weight"],
+            device,
+        );
+        let decoder = load_decoder_from_safetensors::<B>(
+            &weights["lm_head.bias"],
+            &weights
+                .iter()
+                .find(|(k, _)| k.contains("word_embeddings.weight"))
+                .unwrap()
+                .1,
+            device,
+        );
+
+        BertLMHeadRecord {
+            dense,
+            layer_norm,
+            decoder,
+        }
+    }
+}
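A note on the hardcoded mask id: `MASK_TOKEN_ID = 50264` is the `<mask>` id in RoBERTa vocabularies, so `fill_mask` as written will not find any masks for BERT checkpoints (`[MASK]` is id 103 in `bert-base-uncased`). Below is a minimal sketch of resolving the id from the vocabulary instead, assuming the `tokenizers` crate (with its `http` feature) is available; the helper `mask_token_id` is hypothetical and not part of this diff:

```rust
use tokenizers::Tokenizer;

/// Hypothetical helper: look up the mask token id from the model's vocabulary
/// instead of hardcoding RoBERTa's 50264.
fn mask_token_id(model_variant: &str) -> Option<usize> {
    // Requires the `tokenizers` crate's `http` feature to fetch from the Hub.
    let tokenizer = Tokenizer::from_pretrained(model_variant, None).ok()?;
    // RoBERTa-style vocabularies use "<mask>", BERT-style use "[MASK]".
    ["<mask>", "[MASK]"]
        .iter()
        .find_map(|token| tokenizer.token_to_id(token))
        .map(|id| id as usize)
}
```

To try the example, an invocation along the lines of `cargo run --example masked --features tch-cpu -- roberta-base` should work, picking the backend feature from the cfg gates above; the exact flags depend on the crate's Cargo manifest.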