From ffff31b4cc7a3ae3436a150ba8b36259923f58e4 Mon Sep 17 00:00:00 2001
From: Nick Anderson
Date: Fri, 17 May 2024 18:25:27 -0500
Subject: [PATCH 1/2] Add LM head and masked fill_mask for bert-burn.

---
 bert-burn/examples/masked.rs | 145 +++++++++++++++++++++++++++++++++++
 bert-burn/src/fill_mask.rs   |  88 +++++++++++++++++++++
 bert-burn/src/lib.rs         |   1 +
 bert-burn/src/loader.rs      |  26 +++++--
 bert-burn/src/model.rs       | 106 ++++++++++++++++++++++++-
 5 files changed, 359 insertions(+), 7 deletions(-)
 create mode 100644 bert-burn/examples/masked.rs
 create mode 100644 bert-burn/src/fill_mask.rs

diff --git a/bert-burn/examples/masked.rs b/bert-burn/examples/masked.rs
new file mode 100644
index 0000000..7866698
--- /dev/null
+++ b/bert-burn/examples/masked.rs
@@ -0,0 +1,145 @@
+use bert_burn::data::{BertInputBatcher, BertTokenizer};
+use bert_burn::fill_mask::fill_mask;
+use bert_burn::loader::{download_hf_model, load_model_config};
+use bert_burn::model::{BertMaskedLM, BertMaskedLMRecord, BertModel, BertModelRecord};
+use burn::data::dataloader::batcher::Batcher;
+use burn::module::Module;
+use burn::tensor::backend::Backend;
+use burn::tensor::Tensor;
+use std::env;
+use std::sync::Arc;
+
+#[cfg(not(feature = "f16"))]
+#[allow(dead_code)]
+type ElemType = f32;
+#[cfg(feature = "f16")]
+type ElemType = burn::tensor::f16;
+
+pub fn launch<B: Backend>(device: B::Device) {
+    let args: Vec<String> = env::args().collect();
+    let default_model = "roberta-base".to_string();
+    let model_variant = if args.len() > 1 {
+        // Use the argument provided by the user
+        // Possible values: "bert-base-uncased", "roberta-large" etc.
+        &args[1]
+    } else {
+        // Use the default value if no argument is provided
+        &default_model
+    };
+
+    println!("Model variant: {}", model_variant);
+
+    let text_samples = vec![
+        "Paris is the <mask> of France.".to_string(),
+        "The goal of life is <mask>.".to_string(),
+    ];
+
+    let (config_file, model_file) = download_hf_model(model_variant);
+    let model_config = load_model_config(config_file);
+
+    let model_record: BertMaskedLMRecord<B> =
+        BertMaskedLM::from_safetensors(model_file, &device, model_config.clone());
+
+    let model = model_config
+        .init_with_lm_head(&device)
+        .load_record(model_record);
+
+    let tokenizer = Arc::new(BertTokenizer::new(
+        model_variant.to_string(),
+        model_config.pad_token_id,
+    ));
+
+    // Batch the input samples to max sequence length with padding
+    let batcher = Arc::new(BertInputBatcher::<B>::new(
+        tokenizer.clone(),
+        device.clone(),
+        model_config.max_seq_len.unwrap(),
+    ));
+
+    // Batch input samples using the batcher. Shape: [Batch size, Seq_len]
+    let input = batcher.batch(text_samples.clone());
+    let [batch_size, _seq_len] = input.tokens.dims();
+    println!("Input: {:?} // (Batch Size, Seq_len)", input.tokens.shape());
+
+    let output = fill_mask(&model, &model_config, &tokenizer, input);
+
+    for i in 0..batch_size {
+        let input = &text_samples[i];
+        let result = &output[i];
+        println!("Input: {}", input);
+        for (j, fill_mask_result) in result.iter().enumerate() {
+            let mask_idx = fill_mask_result.mask_idx;
+            let top_k = &fill_mask_result.top_k;
+            for (k, (score, token)) in top_k.iter().enumerate() {
+                println!("Top {} Prediction: {} (Score: {:.4})", k + 1, token, score);
+            }
+        }
+    }
+}
+
+#[cfg(any(
+    feature = "ndarray",
+    feature = "ndarray-blas-netlib",
+    feature = "ndarray-blas-openblas",
+    feature = "ndarray-blas-accelerate",
+))]
+mod ndarray {
+    use burn::backend::ndarray::{NdArray, NdArrayDevice};
+
+    use crate::{launch, ElemType};
+
+    pub fn run() {
+        launch::<NdArray<ElemType>>(NdArrayDevice::Cpu);
+    }
+}
+
+#[cfg(feature = "tch-gpu")]
+mod tch_gpu {
+    use crate::{launch, ElemType};
+    use burn::backend::libtorch::{LibTorch, LibTorchDevice};
+
+    pub fn run() {
+        #[cfg(not(target_os = "macos"))]
+        let device = LibTorchDevice::Cuda(0);
+        #[cfg(target_os = "macos")]
+        let device = LibTorchDevice::Mps;
+
+        launch::<LibTorch<ElemType>>(device);
+    }
+}
+
+#[cfg(feature = "tch-cpu")]
+mod tch_cpu {
+    use crate::{launch, ElemType};
+    use burn::backend::libtorch::{LibTorch, LibTorchDevice};
+
+    pub fn run() {
+        launch::<LibTorch<ElemType>>(LibTorchDevice::Cpu);
+    }
+}
+
+#[cfg(feature = "wgpu")]
+mod wgpu {
+    use crate::{launch, ElemType};
+    use burn::backend::wgpu::{AutoGraphicsApi, Wgpu, WgpuDevice};
+
+    pub fn run() {
+        launch::<Wgpu<AutoGraphicsApi, ElemType, i32>>(WgpuDevice::default());
+    }
+}
+
+fn main() {
+    #[cfg(any(
+        feature = "ndarray",
+        feature = "ndarray-blas-netlib",
+        feature = "ndarray-blas-openblas",
+        feature = "ndarray-blas-accelerate",
+    ))]
+    ndarray::run();
+    #[cfg(feature = "tch-gpu")]
+    tch_gpu::run();
+    #[cfg(feature = "tch-cpu")]
+    tch_cpu::run();
+    #[cfg(feature = "wgpu")]
+    wgpu::run();
+}
diff --git a/bert-burn/src/fill_mask.rs b/bert-burn/src/fill_mask.rs
new file mode 100644
index 0000000..b2cc903
--- /dev/null
+++ b/bert-burn/src/fill_mask.rs
@@ -0,0 +1,88 @@
+use std::sync::Arc;
+
+use crate::{
+    data::Tokenizer,
+    data::{BertInferenceBatch, BertTokenizer},
+    model::BertMaskedLM,
+    model::BertModelConfig,
+};
+use burn::tensor::{activation::softmax, backend::Backend, Data, Element};
+
+type TokenType = usize;
+const MASK_TOKEN_ID: TokenType = 50264;
+
+fn find_masks<E: Element>(tokens: &Data<E, 1>, mask_token_id: TokenType) -> Vec<usize> {
+    let mut masks = Vec::new();
+    for (i, token) in tokens.value.iter().enumerate() {
+        if token.to_usize() == Some(mask_token_id) {
+            masks.push(i);
+        }
+    }
+    masks
+}
+
+fn data_to_vec_f32<E: Element>(data: &Data<E, 1>) -> Vec<f32> {
+    data.value.iter().map(|x| x.to_f32().unwrap()).collect()
+}
+
+fn top_k(k: usize, probabilities: Vec<f32>) -> Vec<(usize, f32)> {
+    let mut probabilities = probabilities.iter().enumerate().collect::<Vec<_>>();
+
+    probabilities.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap());
+
+    probabilities.truncate(k);
+
+    probabilities.into_iter().map(|(i, &p)| (i, p)).collect()
+}
+
+#[derive(Debug, Clone)]
+pub struct FillMaskResult {
+    pub mask_idx: usize,
+    pub top_k: Vec<(f32, String)>,
+}
+
+pub fn fill_mask<B: Backend>(
+    model: &BertMaskedLM<B>,
+    model_config: &BertModelConfig,
+    tokenizer: &Arc<BertTokenizer>,
+    input: BertInferenceBatch<B>,
+) -> Vec<Vec<FillMaskResult>> {
+    let [batch_size, seq_len] = input.tokens.dims();
+    let output = model.forward(input.clone());
+
+    let mut results = vec![];
+
+    // Output dimension of the LM head (the vocabulary size)
+    let d_model = model_config.vocab_size.clone();
+    for i in 0..batch_size {
+        let mut batch_results = vec![];
+        let input_tokens = input
+            .tokens
+            .clone()
+            .slice([i..i + 1, 0..seq_len])
+            .squeeze(0)
+            .to_data();
+        // Find the mask tokens in the input, as a list of indices
+        let masks = find_masks(&input_tokens, MASK_TOKEN_ID);
+        for mask in masks {
+            let logits = output
+                .clone()
+                .slice([i..i + 1, mask..(mask + 1), 0..d_model])
+                .squeeze::<2>(0)
+                .squeeze(0);
+            // Find the top k tokens with the highest probabilities
+            let probs = data_to_vec_f32(&softmax(logits, 0).to_data());
+            let top_k = top_k(5, probs);
+            batch_results.push(FillMaskResult {
+                mask_idx: mask,
+                top_k: top_k
+                    .iter()
+                    .map(|(k, prob)| (*prob, tokenizer.decode(&[*k])))
+                    .collect(),
+            });
+        }
+        results.push(batch_results);
+    }
+
+    results
+}
diff --git a/bert-burn/src/lib.rs b/bert-burn/src/lib.rs
index 1dead18..b384d28 100644
--- a/bert-burn/src/lib.rs
+++ b/bert-burn/src/lib.rs
@@ -3,6 +3,7 @@ extern crate derive_new;
 
 pub mod data;
 mod embedding;
+pub mod fill_mask;
 pub mod loader;
 pub mod model;
 pub mod pooler;
diff --git a/bert-burn/src/loader.rs b/bert-burn/src/loader.rs
index 03e717e..1416fba 100644
--- a/bert-burn/src/loader.rs
+++ b/bert-burn/src/loader.rs
@@ -17,7 +17,7 @@ use candle_core::Tensor as CandleTensor;
 use std::collections::HashMap;
 use std::path::PathBuf;
 
-fn load_1d_tensor_from_candle<B: Backend>(
+pub(crate) fn load_1d_tensor_from_candle<B: Backend>(
     tensor: &CandleTensor,
     device: &B::Device,
 ) -> Tensor<B, 1> {
@@ -29,7 +29,7 @@ fn load_1d_tensor_from_candle<B: Backend>(
     weight
 }
 
-fn load_2d_tensor_from_candle<B: Backend>(
+pub(crate) fn load_2d_tensor_from_candle<B: Backend>(
     tensor: &CandleTensor,
     device: &B::Device,
 ) -> Tensor<B, 2> {
@@ -46,7 +46,7 @@ fn load_2d_tensor_from_candle<B: Backend>(
     weight
 }
 
-fn load_layer_norm_safetensor<B: Backend>(
+pub(crate) fn load_layer_norm_safetensor<B: Backend>(
     bias: &CandleTensor,
     weight: &CandleTensor,
     device: &B::Device,
@@ -62,7 +62,7 @@ fn load_layer_norm_safetensor<B: Backend>(
     layer_norm_record
 }
 
-fn load_linear_safetensor<B: Backend>(
+pub(crate) fn load_linear_safetensor<B: Backend>(
     bias: &CandleTensor,
     weight: &CandleTensor,
     device: &B::Device,
@@ -79,7 +79,7 @@ fn load_linear_safetensor<B: Backend>(
     linear_record
 }
 
-fn load_intermediate_layer_safetensor<B: Backend>(
+pub(crate) fn load_intermediate_layer_safetensor<B: Backend>(
     linear_inner_weight: &CandleTensor,
     linear_inner_bias: &CandleTensor,
     linear_outer_weight: &CandleTensor,
@@ -227,6 +227,22 @@ pub fn load_encoder_from_safetensors<B: Backend>(
     encoder_record
 }
 
+pub fn load_decoder_from_safetensors<B: Backend>(
+    bias: &CandleTensor,
+    word_embedding_weights: &CandleTensor,
+    device: &B::Device,
+) -> LinearRecord<B> {
+    let bias = load_1d_tensor_from_candle::<B>(bias, device);
+    let weight = load_2d_tensor_from_candle::<B>(word_embedding_weights, device);
+    let weight = weight.transpose();
+
+    let linear_record = LinearRecord {
+        weight: Param::from_tensor(weight),
+        bias: Some(Param::from_tensor(bias)),
+    };
+    linear_record
+}
+
 fn load_embedding_safetensor<B: Backend>(
     weight: &CandleTensor,
     device: &B::Device,
diff --git a/bert-burn/src/model.rs b/bert-burn/src/model.rs
index 4810d7c..7d3c663 100644
--- a/bert-burn/src/model.rs
+++ b/bert-burn/src/model.rs
@@ -1,7 +1,8 @@
 use crate::data::BertInferenceBatch;
 use crate::embedding::{BertEmbeddings, BertEmbeddingsConfig};
 use crate::loader::{
-    load_embeddings_from_safetensors, load_encoder_from_safetensors, load_pooler_from_safetensors,
+    load_decoder_from_safetensors, load_embeddings_from_safetensors, load_encoder_from_safetensors,
+    load_layer_norm_safetensor, load_linear_safetensor, load_pooler_from_safetensors,
 };
 use crate::pooler::{Pooler, PoolerConfig};
 use burn::config::Config;
@@ -10,6 +11,8 @@ use burn::nn::transformer::{
     TransformerEncoder, TransformerEncoderConfig, TransformerEncoderInput,
 };
 use burn::nn::Initializer::KaimingUniform;
+use burn::nn::{LayerNorm, LayerNormConfig, Linear, LinearConfig};
+use burn::tensor::activation::gelu;
 use burn::tensor::backend::Backend;
 use burn::tensor::Tensor;
 use candle_core::{safetensors, Device, Tensor as CandleTensor};
@@ -85,6 +88,19 @@
         }
     }
 
+    pub fn init_with_lm_head<B: Backend>(&self, device: &B::Device) -> BertMaskedLM<B> {
+        let bert = self.init(device);
+        let lm_head = BertLMHead {
+            dense: LinearConfig::new(self.hidden_size, self.hidden_size).init(device),
+            layer_norm: LayerNormConfig::new(self.hidden_size)
+                .with_epsilon(self.layer_norm_eps)
+                .init(device),
+            decoder: LinearConfig::new(self.hidden_size, self.vocab_size).init(device),
+        };
+
+        BertMaskedLM { bert, lm_head }
+    }
+
     fn get_embeddings_config(&self) -> BertEmbeddingsConfig {
         BertEmbeddingsConfig {
             vocab_size: self.vocab_size,
@@ -104,7 +120,7 @@
             d_model: self.hidden_size,
             d_ff: self.intermediate_size,
             dropout: self.hidden_dropout_prob,
-            norm_first: true,
+            norm_first: false,
             quiet_softmax: false,
             initializer: KaimingUniform {
                 gain: 1.0 / libm::sqrt(3.0),
@@ -190,3 +206,89 @@
         model_record
     }
 }
+
+#[derive(Module, Debug)]
+pub struct BertMaskedLM<B: Backend> {
+    pub bert: BertModel<B>,
+    pub lm_head: BertLMHead<B>,
+}
+
+#[derive(Module, Debug)]
+pub struct BertLMHead<B: Backend> {
+    pub dense: Linear<B>,
+    pub layer_norm: LayerNorm<B>,
+    pub decoder: Linear<B>,
+}
+
+impl<B: Backend> BertMaskedLM<B> {
+    pub fn forward(&self, input: BertInferenceBatch<B>) -> Tensor<B, 3> {
+        let output = self.bert.forward(BertInferenceBatch {
+            tokens: input.tokens.clone(),
+            mask_pad: input.mask_pad.clone(),
+        });
+        let output = self.lm_head.forward(output.hidden_states);
+        output
+    }
+
+    pub fn from_safetensors(
+        file_path: PathBuf,
+        device: &B::Device,
+        config: BertModelConfig,
+    ) -> BertMaskedLMRecord<B> {
+        let bert = BertModel::from_safetensors(file_path.clone(), device, config.clone());
+        let lm_head = BertLMHead::from_safetensors(file_path, device, config);
+
+        BertMaskedLMRecord { bert, lm_head }
+    }
+}
+
+impl<B: Backend> BertLMHead<B> {
+    pub fn forward(&self, features: Tensor<B, 3>) -> Tensor<B, 3> {
+        let output = self.dense.forward(features);
+        let output = gelu(output);
+        let output = self.layer_norm.forward(output);
+
+        let output = self.decoder.forward(output);
+        output
+    }
+
+    pub fn from_safetensors(
+        file_path: PathBuf,
+        device: &B::Device,
+        config: BertModelConfig,
+    ) -> BertLMHeadRecord<B> {
+        let weight_result = safetensors::load::<PathBuf>(file_path, &Device::Cpu);
+
+        // Match on the result of loading the weights
+        let weights = match weight_result {
+            Ok(weights) => weights,
+            Err(e) => panic!("Error loading weights: {:?}", e),
+        };
+
+        let dense = load_linear_safetensor::<B>(
+            &weights["lm_head.dense.bias"],
+            &weights["lm_head.dense.weight"],
+            device,
+        );
+        let layer_norm = load_layer_norm_safetensor::<B>(
+            &weights["lm_head.layer_norm.bias"],
+            &weights["lm_head.layer_norm.weight"],
+            device,
+        );
+        let decoder = load_decoder_from_safetensors::<B>(
+            &weights["lm_head.bias"],
+            &weights
+                .iter()
+                .find(|(k, _)| k.contains("word_embeddings.weight"))
+                .unwrap()
+                .1,
+            device,
+        );
+
+        BertLMHeadRecord {
+            dense,
+            layer_norm,
+            decoder,
+        }
+    }
+}

From 888025919fe11704cb20f217a269c03be4615044 Mon Sep 17 00:00:00 2001
From: Nick Anderson
Date: Thu, 23 May 2024 23:00:08 -0500
Subject: [PATCH 2/2] Address comments on PR #34.

---
 bert-burn/examples/masked.rs | 15 +++++---
 bert-burn/src/fill_mask.rs   | 66 ++++++++++++++++++++----------------
 bert-burn/src/model.rs       |  2 +-
 3 files changed, 47 insertions(+), 36 deletions(-)

diff --git a/bert-burn/examples/masked.rs b/bert-burn/examples/masked.rs
index 7866698..5a4b948 100644
--- a/bert-burn/examples/masked.rs
+++ b/bert-burn/examples/masked.rs
@@ -1,11 +1,10 @@
 use bert_burn::data::{BertInputBatcher, BertTokenizer};
 use bert_burn::fill_mask::fill_mask;
 use bert_burn::loader::{download_hf_model, load_model_config};
-use bert_burn::model::{BertMaskedLM, BertMaskedLMRecord, BertModel, BertModelRecord};
+use bert_burn::model::{BertMaskedLM, BertMaskedLMRecord};
 use burn::data::dataloader::batcher::Batcher;
 use burn::module::Module;
 use burn::tensor::backend::Backend;
-use burn::tensor::Tensor;
 use std::env;
 use std::sync::Arc;
 
@@ -61,17 +60,23 @@ pub fn launch<B: Backend>(device: B::Device) {
     let [batch_size, _seq_len] = input.tokens.dims();
     println!("Input: {:?} // (Batch Size, Seq_len)", input.tokens.shape());
 
-    let output = fill_mask(&model, &model_config, &tokenizer, input);
+    let output = fill_mask(&model, &model_config, tokenizer.as_ref(), input);
 
     for i in 0..batch_size {
         let input = &text_samples[i];
         let result = &output[i];
         println!("Input: {}", input);
-        for (j, fill_mask_result) in result.iter().enumerate() {
+        for fill_mask_result in result.iter() {
             let mask_idx = fill_mask_result.mask_idx;
             let top_k = &fill_mask_result.top_k;
             for (k, (score, token)) in top_k.iter().enumerate() {
-                println!("Top {} Prediction: {} (Score: {:.4})", k + 1, token, score);
+                println!(
+                    "Top {} Prediction for {}: {} (Score: {:.4})",
+                    k + 1,
+                    mask_idx,
+                    token,
+                    score
+                );
             }
         }
     }
diff --git a/bert-burn/src/fill_mask.rs b/bert-burn/src/fill_mask.rs
index b2cc903..4e53820 100644
--- a/bert-burn/src/fill_mask.rs
+++ b/bert-burn/src/fill_mask.rs
@@ -1,40 +1,14 @@
-use std::sync::Arc;
-
 use crate::{
     data::Tokenizer,
     data::{BertInferenceBatch, BertTokenizer},
     model::BertMaskedLM,
     model::BertModelConfig,
 };
-use burn::tensor::{activation::softmax, backend::Backend, Data, Element};
+use burn::tensor::{activation::softmax, backend::Backend, Data, Element, Tensor};
 
 type TokenType = usize;
 const MASK_TOKEN_ID: TokenType = 50264;
 
-fn find_masks<E: Element>(tokens: &Data<E, 1>, mask_token_id: TokenType) -> Vec<usize> {
-    let mut masks = Vec::new();
-    for (i, token) in tokens.value.iter().enumerate() {
-        if token.to_usize() == Some(mask_token_id) {
-            masks.push(i);
-        }
-    }
-    masks
-}
-
-fn data_to_vec_f32<E: Element>(data: &Data<E, 1>) -> Vec<f32> {
-    data.value.iter().map(|x| x.to_f32().unwrap()).collect()
-}
-
-fn top_k(k: usize, probabilities: Vec<f32>) -> Vec<(usize, f32)> {
-    let mut probabilities = probabilities.iter().enumerate().collect::<Vec<_>>();
-
-    probabilities.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap());
-
-    probabilities.truncate(k);
-
-    probabilities.into_iter().map(|(i, &p)| (i, p)).collect()
-}
-
 #[derive(Debug, Clone)]
 pub struct FillMaskResult {
     pub mask_idx: usize,
@@ -44,7 +18,7 @@
 pub fn fill_mask<B: Backend>(
     model: &BertMaskedLM<B>,
     model_config: &BertModelConfig,
-    tokenizer: &Arc<BertTokenizer>,
+    tokenizer: &BertTokenizer,
     input: BertInferenceBatch<B>,
 ) -> Vec<Vec<FillMaskResult>> {
     let [batch_size, seq_len] = input.tokens.dims();
@@ -71,8 +45,7 @@
                 .squeeze::<2>(0)
                 .squeeze(0);
             // Find the top k tokens with the highest probabilities
-            let probs = data_to_vec_f32(&softmax(logits, 0).to_data());
-            let top_k = top_k(5, probs);
+            let top_k = top_k(5, logits);
             batch_results.push(FillMaskResult {
                 mask_idx: mask,
                 top_k: top_k
@@ -86,3 +59,36 @@ pub fn fill_mask<B: Backend>(
 
     results
 }
+
+fn find_masks<E: Element>(tokens: &Data<E, 1>, mask_token_id: TokenType) -> Vec<usize> {
+    let mut masks = Vec::new();
+    for (i, token) in tokens.value.iter().enumerate() {
+        if token.to_usize() == Some(mask_token_id) {
+            masks.push(i);
+        }
+    }
+    masks
+}
+
+fn data_to_vec_f32<E: Element>(data: &Data<E, 1>) -> Vec<f32> {
+    data.value.iter().map(|x| x.to_f32().unwrap()).collect()
+}
+
+fn data_to_vec_usize<E: Element>(data: &Data<E, 1>) -> Vec<usize> {
+    data.value.iter().map(|x| x.to_usize().unwrap()).collect()
+}
+
+fn top_k<B: Backend>(k: usize, logits: Tensor<B, 1>) -> Vec<(usize, f32)> {
+    let (pre_soft_probs, indices) = logits.sort_with_indices(0);
+    let (probabilities, indices) = (
+        data_to_vec_f32(&softmax(pre_soft_probs, 0).to_data()),
+        data_to_vec_usize(&indices.to_data()),
+    );
+    probabilities
+        .iter()
+        .enumerate()
+        .rev()
+        .take(k)
+        .map(|(i, &p)| (indices[i], p))
+        .collect()
+}
diff --git a/bert-burn/src/model.rs b/bert-burn/src/model.rs
index 7d3c663..3af85d1 100644
--- a/bert-burn/src/model.rs
+++ b/bert-burn/src/model.rs
@@ -255,7 +255,7 @@ impl<B: Backend> BertLMHead<B> {
     pub fn from_safetensors(
         file_path: PathBuf,
        device: &B::Device,
-        config: BertModelConfig,
+        _config: BertModelConfig,
     ) -> BertLMHeadRecord<B> {
         let weight_result = safetensors::load::<PathBuf>(file_path, &Device::Cpu);