Skip to content

Commit

Permalink
Address comments on PR #34.
Browse files Browse the repository at this point in the history
  • Loading branch information
seurimas committed May 24, 2024
1 parent ffff31b commit 8880259
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 36 deletions.
15 changes: 10 additions & 5 deletions bert-burn/examples/masked.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
use bert_burn::data::{BertInputBatcher, BertTokenizer};
use bert_burn::fill_mask::fill_mask;
use bert_burn::loader::{download_hf_model, load_model_config};
use bert_burn::model::{BertMaskedLM, BertMaskedLMRecord, BertModel, BertModelRecord};
use bert_burn::model::{BertMaskedLM, BertMaskedLMRecord};
use burn::data::dataloader::batcher::Batcher;
use burn::module::Module;
use burn::tensor::backend::Backend;
use burn::tensor::Tensor;
use std::env;
use std::sync::Arc;

Expand Down Expand Up @@ -61,17 +60,23 @@ pub fn launch<B: Backend>(device: B::Device) {
let [batch_size, _seq_len] = input.tokens.dims();
println!("Input: {:?} // (Batch Size, Seq_len)", input.tokens.shape());

let output = fill_mask(&model, &model_config, &tokenizer, input);
let output = fill_mask(&model, &model_config, tokenizer.as_ref(), input);

for i in 0..batch_size {
let input = &text_samples[i];
let result = &output[i];
println!("Input: {}", input);
for (j, fill_mask_result) in result.iter().enumerate() {
for fill_mask_result in result.iter() {
let mask_idx = fill_mask_result.mask_idx;
let top_k = &fill_mask_result.top_k;
for (k, (score, token)) in top_k.iter().enumerate() {
println!("Top {} Prediction: {} (Score: {:.4})", k + 1, token, score);
println!(
"Top {} Prediction for {}: {} (Score: {:.4})",
k + 1,
mask_idx,
token,
score
);
}
}
}
Expand Down
66 changes: 36 additions & 30 deletions bert-burn/src/fill_mask.rs
Original file line number Diff line number Diff line change
@@ -1,40 +1,14 @@
use std::sync::Arc;

use crate::{
data::Tokenizer,
data::{BertInferenceBatch, BertTokenizer},
model::BertMaskedLM,
model::BertModelConfig,
};
use burn::tensor::{activation::softmax, backend::Backend, Data, Element};
use burn::tensor::{activation::softmax, backend::Backend, Data, Element, Tensor};

type TokenType = usize;
const MASK_TOKEN_ID: TokenType = 50264;

/// Returns the positions of every token in `tokens` equal to `mask_token_id`.
///
/// Tokens whose value does not convert to `usize` are skipped.
fn find_masks<T: Element>(tokens: &Data<T, 1>, mask_token_id: TokenType) -> Vec<usize> {
    tokens
        .value
        .iter()
        .enumerate()
        .filter(|(_, token)| token.to_usize() == Some(mask_token_id))
        .map(|(position, _)| position)
        .collect()
}

/// Copies a 1-D `Data` buffer into a plain `Vec<f32>`.
///
/// Panics if any element cannot be represented as `f32` (mirrors the
/// conversion contract of `Element::to_f32`).
fn data_to_vec_f32<T: Element>(data: &Data<T, 1>) -> Vec<f32> {
    let mut out = Vec::with_capacity(data.value.len());
    for x in data.value.iter() {
        out.push(x.to_f32().unwrap());
    }
    out
}

/// Returns the indices and values of the `k` largest probabilities,
/// ordered from highest to lowest.
///
/// If `k` exceeds the input length, all entries are returned.
///
/// Fix: the previous implementation used `partial_cmp(..).unwrap()`,
/// which panics if any probability is NaN. `f32::total_cmp` is a total
/// order over all f32 bit patterns, so the sort can never panic.
fn top_k(k: usize, probabilities: Vec<f32>) -> Vec<(usize, f32)> {
    let mut indexed: Vec<(usize, f32)> = probabilities.into_iter().enumerate().collect();
    // Descending by probability; total_cmp handles NaN deterministically
    // instead of panicking.
    indexed.sort_by(|a, b| b.1.total_cmp(&a.1));
    indexed.truncate(k);
    indexed
}

#[derive(Debug, Clone)]
pub struct FillMaskResult {
pub mask_idx: usize,
Expand All @@ -44,7 +18,7 @@ pub struct FillMaskResult {
pub fn fill_mask<B: Backend>(
model: &BertMaskedLM<B>,
model_config: &BertModelConfig,
tokenizer: &Arc<BertTokenizer>,
tokenizer: &BertTokenizer,
input: BertInferenceBatch<B>,
) -> Vec<Vec<FillMaskResult>> {
let [batch_size, seq_len] = input.tokens.dims();
Expand All @@ -71,8 +45,7 @@ pub fn fill_mask<B: Backend>(
.squeeze::<2>(0)
.squeeze(0);
// Find the top k tokens with the highest probabilities
let probs = data_to_vec_f32(&softmax(logits, 0).to_data());
let top_k = top_k(5, probs);
let top_k = top_k(5, logits);
batch_results.push(FillMaskResult {
mask_idx: mask,
top_k: top_k
Expand All @@ -86,3 +59,36 @@ pub fn fill_mask<B: Backend>(

results
}

/// Collects the index of every occurrence of `mask_token_id` in `tokens`.
///
/// Elements that fail the `to_usize` conversion are never treated as masks.
fn find_masks<T: Element>(tokens: &Data<T, 1>, mask_token_id: TokenType) -> Vec<usize> {
    tokens
        .value
        .iter()
        .enumerate()
        .filter_map(|(idx, tok)| {
            if tok.to_usize() == Some(mask_token_id) {
                Some(idx)
            } else {
                None
            }
        })
        .collect()
}

/// Converts a 1-D `Data` buffer into an owned `Vec<f32>`.
///
/// Panics if an element is not convertible to `f32`.
fn data_to_vec_f32<T: Element>(data: &Data<T, 1>) -> Vec<f32> {
    let mut values = Vec::with_capacity(data.value.len());
    values.extend(data.value.iter().map(|x| x.to_f32().unwrap()));
    values
}

/// Converts a 1-D `Data` buffer into an owned `Vec<usize>`.
///
/// Panics if an element is not convertible to `usize` (e.g. negative).
fn data_to_vec_usize<T: Element>(data: &Data<T, 1>) -> Vec<usize> {
    let mut out = Vec::with_capacity(data.value.len());
    for value in data.value.iter() {
        out.push(value.to_usize().unwrap());
    }
    out
}

/// Computes the `k` most probable entries for a 1-D logits tensor.
///
/// Returns `(vocab_index, probability)` pairs ordered from most to least
/// likely; fewer than `k` pairs are returned if the tensor is shorter.
/// NOTE(review): relies on `sort_with_indices` sorting ascending, so the
/// largest entries sit at the end of the buffers — confirm against the
/// burn API docs if this ever changes.
fn top_k<B: Backend>(k: usize, logits: Tensor<B, 1>) -> Vec<(usize, f32)> {
    // Sort once on the tensor, then pull both buffers back as host vectors.
    let (sorted_logits, sorted_indices) = logits.sort_with_indices(0);
    let probs = data_to_vec_f32(&softmax(sorted_logits, 0).to_data());
    let vocab_ids = data_to_vec_usize(&sorted_indices.to_data());
    // Walk from the back (largest first) and keep the top k.
    probs
        .into_iter()
        .zip(vocab_ids)
        .rev()
        .take(k)
        .map(|(p, id)| (id, p))
        .collect()
}
2 changes: 1 addition & 1 deletion bert-burn/src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ impl<B: Backend> BertLMHead<B> {
pub fn from_safetensors(
file_path: PathBuf,
device: &B::Device,
config: BertModelConfig,
_config: BertModelConfig,
) -> BertLMHeadRecord<B> {
let weight_result = safetensors::load::<PathBuf>(file_path, &Device::Cpu);

Expand Down

0 comments on commit 8880259

Please sign in to comment.