Skip to content

Commit

Permalink
fix formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
ashdtu committed Feb 3, 2024
1 parent 5451927 commit d264354
Showing 1 changed file with 6 additions and 7 deletions.
13 changes: 6 additions & 7 deletions bert-burn/examples/infer-embedding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@ use bert_burn::data::{BertInputBatcher, BertTokenizer};
use bert_burn::loader::{load_model_config, load_model_from_safetensors};
use burn::data::dataloader::batcher::Batcher;
use burn::tensor::backend::Backend;
use burn::tensor::Tensor;
use serde_json::Value;
use std::collections::HashMap;
use std::fs::File;
use std::sync::Arc;
use burn::tensor::Tensor;
// use burn_import::pytorch::PyTorchFileRecorder;
// use burn::record::{FullPrecisionSettings, Recorder};

Expand Down Expand Up @@ -57,14 +57,13 @@ pub fn launch<B: Backend>(device: B::Device) {

// Batch input samples using the batcher Shape: [Batch size, Seq_len]
let input = batcher.batch(text_samples.clone());
println!("Input shape: {:?} // (Batch Size, Seq_len)", input.tokens.shape());
println!(
"Input shape: {:?} // (Batch Size, Seq_len)",
input.tokens.shape()
);

let model_config: BertModelConfig = load_model_config(config);
let model = load_model_from_safetensors(
"weights/model.safetensors",
&device,
model_config,
);
let model = load_model_from_safetensors("weights/model.safetensors", &device, model_config);

let output = model.forward(input);

Expand Down

0 comments on commit d264354

Please sign in to comment.