Skip to content

Commit

Permalink
Merge pull request #6 from antimora/fix-bug
Browse files Browse the repository at this point in the history
Change record type to NamedMpk to avoid overflow
  • Loading branch information
nathanielsimard authored Dec 12, 2023
2 parents 99be61c + e0cd789 commit b9c77d7
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
6 changes: 3 additions & 3 deletions squeezenet-burn/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ weights_file = ["burn/default"]
[dependencies]

# Note: default-features = false is needed to disable std
burn = { git = "https://github.com/burn-rs/burn", package = "burn", default-features = false }
burn = { version = "0.10.0", default-features = false }

# Used to load weights from a file
serde = { version = "1.0.183", default-features = false, features = [
Expand All @@ -31,9 +31,9 @@ serde = { version = "1.0.183", default-features = false, features = [

[dev-dependencies]
# Used by the classify example
burn-ndarray = { git = "https://github.com/burn-rs/burn", package = "burn-ndarray" }
burn-ndarray = { version = "0.10.0", package = "burn-ndarray" }
image = { version = "0.24.7", features = ["png", "jpeg"] }

[build-dependencies]
# Used to generate code from ONNX model
burn-import = { git = "https://github.com/burn-rs/burn", package = "burn-import" }
burn-import = { version = "0.10.0", package = "burn-import" }
6 changes: 3 additions & 3 deletions squeezenet-burn/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use burn_import::onnx::ModelGen;

const LABEL_SOURCE_FILE: &str = "src/model/label.txt";
const LABEL_DEST_FILE: &str = "model/label.rs";
const GENERATED_MODEL_WEIGHTS_FILE: &str = "squeezenet1.mpk.gz";
const GENERATED_MODEL_WEIGHTS_FILE: &str = "squeezenet1.mpk";
const INPUT_ONNX_FILE: &str = "src/model/squeezenet1.onnx";
const OUT_DIR: &str = "model/";

Expand All @@ -31,7 +31,7 @@ fn main() {
let (record_type, embed_states) = if cfg!(feature = "weights_embedded") {
(RecordType::Bincode, true)
} else {
(RecordType::NamedMpkGz, false)
(RecordType::NamedMpk, false)
};

// Check if half precision is enabled.
Expand Down Expand Up @@ -59,7 +59,7 @@ fn main() {
fn generate_labels_from_txt_file() -> std::io::Result<()> {
let out_dir = env::var("OUT_DIR").unwrap();
let dest_path = Path::new(&out_dir).join(LABEL_DEST_FILE);
let mut f = File::create(&dest_path)?;
let mut f = File::create(dest_path)?;

let file = File::open(LABEL_SOURCE_FILE)?;
let reader = BufReader::new(file);
Expand Down

0 comments on commit b9c77d7

Please sign in to comment.