Merge pull request #4 from antimora/no-std
Make SqueezeNet model no-std compatible
nathanielsimard authored Sep 25, 2023
2 parents de43abf + 45c3361 commit 5323c22
Showing 5 changed files with 124 additions and 13 deletions.
23 changes: 17 additions & 6 deletions squeezenet-burn/Cargo.toml
@@ -6,20 +6,31 @@ version = "0.1.0"
edition = "2021"

[features]
# TODO make it work with no_std
# TODO export model with embeddible weights
# TODO export model with f16 weights
default = []
default = ["weights_file"]

# Enables Half precision (f16) support
weights_f16 = []

# Embed weights into the binary
weights_embedded = []

# Use weights from a file
weights_file = ["burn/default"]


[dependencies]
burn = { git = "https://github.com/burn-rs/burn", package = "burn" }

# Note: default-features = false is needed to disable std
burn = { git = "https://github.com/burn-rs/burn", package = "burn", default-features = false }

# Used to load weights from a file
serde = { version = "1.0.183", default-features = false, features = [
"derive",
"alloc",
] } # alloc is for no_std, derive is needed

[dev-dependencies]
# Used by examples
# Used by the classify example
burn-ndarray = { git = "https://github.com/burn-rs/burn", package = "burn-ndarray" }
image = { version = "0.24.7", features = ["png", "jpeg"] }

33 changes: 30 additions & 3 deletions squeezenet-burn/README.md
@@ -21,6 +21,8 @@ The labels for the classes are included in the crate and generated from the
The data normalizer for the model is included in the crate. See
[Normalizer](src/model/normalizer.rs).

The model is [no_std compatible](https://docs.rust-embedded.org/book/intro/no-std.html).

See the [classify example](examples/classify.rs) for how to use the model.
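
For a quick orientation, here is a minimal sketch of the inference flow condensed from that example: normalize a `[1, 3, 224, 224]` image tensor and run the forward pass. The module paths and the `Normalizer::new()` constructor are assumptions made for this sketch; the `normalize`, `from_embedded`, and `forward` calls mirror `examples/classify.rs`.

```rust
use burn::tensor::Tensor;
use burn_ndarray::NdArrayBackend;
// Assumed module paths; the generated code may expose these items differently.
use squeezenet_burn::model::normalizer::Normalizer;
use squeezenet_burn::model::squeezenet1::Model;

type Backend = NdArrayBackend<f32>;

/// Classify an already-loaded `[1, 3, 224, 224]` image tensor and return the raw class scores.
fn classify(image_input: Tensor<Backend, 4>) -> Tensor<Backend, 2> {
    // Apply the bundled ImageNet normalization (see src/model/normalizer.rs).
    let normalizer = Normalizer::new(); // constructor assumed for this sketch
    let normalized_image = normalizer.normalize(image_input);

    // Built with the `weights_embedded` feature, so no file access is needed.
    let model = Model::<Backend>::from_embedded();

    // Forward pass: one score per ImageNet class.
    model.forward(normalized_image)
}
```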

## Usage
@@ -31,12 +33,37 @@ Add this to your `Cargo.toml`:

```toml
[dependencies]
squeezenet-burn = { git = "https://github.com/burn-rs/models", package = "squeezenet-burn" }

squeezenet-burn = { git = "https://github.com/burn-rs/models", package = "squeezenet-burn", features = ["weights_embedded"], default-features = false }
```

### To run the example

1. Use the `weights_embedded` feature to embed the weights in the binary.

```shell
cargo r --release --example classify samples/flamingo.jpg
cargo r --release --features weights_embedded --no-default-features --example classify samples/flamingo.jpg
```

2. Use the `weights_file` feature to load the weights from a file.

```shell
cargo r --release --features weights_file --example classify samples/flamingo.jpg
```

3. Add the `weights_f16` feature to use 16-bit floating-point weights, combined with either loading mode.

```shell
cargo r --release --features "weights_embedded, weights_f16" --no-default-features --example classify samples/flamingo.jpg
```

Or, combined with the `weights_file` feature:

```shell
cargo r --release --features "weights_file, weights_f16" --example classify samples/flamingo.jpg
```

## Feature Flags

- `weights_file`: Load the weights from a file placed next to the executable (enabled by default).
- `weights_embedded`: Embed the weights in the binary (see the loading sketch below).
- `weights_f16`: Use 16-bit floating-point numbers for the weights (32-bit by default).
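
In code, the two loading flags select between two constructors on the generated model, mirroring `examples/classify.rs`; `weights_f16` only changes the precision of the stored record. A minimal sketch, with the module path again assumed:

```rust
use burn_ndarray::NdArrayBackend;
// Assumed module path; adjust to the generated code's actual layout.
use squeezenet_burn::model::squeezenet1::Model;

type Backend = NdArrayBackend<f32>;

// `weights_file`: the build script copies `squeezenet1.mpk.gz` next to the
// executable, and the record is read from disk at run time (requires std).
#[cfg(feature = "weights_file")]
fn load_model() -> Model<Backend> {
    let record = std::env::current_exe()
        .expect("no current exe path")
        .parent()
        .expect("no parent directory")
        .join("squeezenet1");
    Model::from_file(record.to_str().expect("non-UTF-8 path"))
}

// `weights_embedded`: the record is compiled into the binary, so the model
// loads without any file system access (no_std friendly).
#[cfg(feature = "weights_embedded")]
fn load_model() -> Model<Backend> {
    Model::from_embedded()
}
```

Exactly one of the two functions exists at compile time, because `build.rs` rejects builds that enable both flags or neither.
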
60 changes: 58 additions & 2 deletions squeezenet-burn/build.rs
@@ -1,23 +1,56 @@
use std::env;
use std::fs;
use std::fs::File;
use std::io::{BufRead, BufReader, Write};
use std::path::Path;

use burn_import::burn::graph::RecordType;
use burn_import::onnx::ModelGen;

const LABEL_SOURCE_FILE: &str = "src/model/label.txt";
const LABEL_DEST_FILE: &str = "model/label.rs";
const GENERATED_MODEL_WEIGHTS_FILE: &str = "squeezenet1.mpk.gz";
const INPUT_ONNX_FILE: &str = "src/model/squeezenet1.onnx";
const OUT_DIR: &str = "model/";

fn main() {
    // Re-run the build script if model files change.
    println!("cargo:rerun-if-changed=src/model");

    // Make sure the weights_file and weights_embedded features are not both enabled.
    if cfg!(feature = "weights_file") && cfg!(feature = "weights_embedded") {
        panic!("Only one of the features weights_file and weights_embedded can be enabled");
    }

    // Make sure at least one of weights_file or weights_embedded is enabled.
    if !cfg!(feature = "weights_file") && !cfg!(feature = "weights_embedded") {
        panic!("One of the features weights_file and weights_embedded must be enabled");
    }

    // Check if the weights are embedded.
    let (record_type, embed_states) = if cfg!(feature = "weights_embedded") {
        (RecordType::Bincode, true)
    } else {
        (RecordType::NamedMpkGz, false)
    };

    // Check if half-precision (f16) weights are enabled.
    let half_precision = cfg!(feature = "weights_f16");

    // Generate the model code from the ONNX file.
    ModelGen::new()
        .input("src/model/squeezenet1.onnx")
        .out_dir("model/")
        .input(INPUT_ONNX_FILE)
        .out_dir(OUT_DIR)
        .record_type(record_type)
        .embed_states(embed_states)
        .half_precision(half_precision)
        .run_from_script();

    // Copy the weights next to the executable.
    if cfg!(feature = "weights_file") {
        copy_weights_next_to_executable();
    }

    // Generate the labels from the synset.txt file.
    generate_labels_from_txt_file().unwrap();
}
@@ -39,3 +72,26 @@ fn generate_labels_from_txt_file() -> std::io::Result<()> {

Ok(())
}

/// Copy the weights file next to the executable.
fn copy_weights_next_to_executable() {
    // Obtain the OUT_DIR path from the environment variable.
    let out_dir = env::var("OUT_DIR").expect("OUT_DIR not defined");

    // Weights file in OUT_DIR that you want to copy.
    let source_path = Path::new(&out_dir)
        .join("model")
        .join(GENERATED_MODEL_WEIGHTS_FILE);

    // Determine the profile (debug or release) to set the appropriate destination directory.
    let profile = env::var("PROFILE").expect("PROFILE not defined");
    let target_dir = format!("target/{}", profile);

    // Specify the destination path.
    let destination_path = Path::new(&target_dir)
        .join("examples")
        .join(GENERATED_MODEL_WEIGHTS_FILE);

    // Copy the file.
    fs::copy(source_path, destination_path).expect("Failed to copy generated file");
}
20 changes: 18 additions & 2 deletions squeezenet-burn/examples/classify.rs
@@ -8,14 +8,17 @@ use image::{self, GenericImageView, Pixel};
const HEIGHT: usize = 224;
const WIDTH: usize = 224;

#[cfg(feature = "weights_file")]
const RECORD_FILE: &str = "squeezenet1";

type Backend = NdArrayBackend<f32>;

fn main() {
    // Path to the image from the main args
    let img_path = std::env::args().nth(1).expect("No image path provided");

    // Load the image
    let img = image::open(&img_path).expect(format!("Failed to load image: {img_path}").as_str());
    let img = image::open(&img_path).unwrap_or_else(|_| panic!("Failed to load image: {img_path}"));

    // Resize it to 224x224
    let resized_img = img.resize_exact(
@@ -47,7 +50,20 @@ fn main() {
    let normalized_image = normalizer.normalize(image_input);

    // Create the model
    let model = Model::<Backend>::default();
    // Load the weights from the file next to the executable
    #[cfg(feature = "weights_file")]
    let weights_file = std::env::current_exe()
        .unwrap()
        .parent()
        .unwrap()
        .join(RECORD_FILE);

    #[cfg(feature = "weights_file")]
    let model = Model::<Backend>::from_file(weights_file.to_str().unwrap());

    // Load model from embedded weights
    #[cfg(feature = "weights_embedded")]
    let model = Model::<Backend>::from_embedded();

    // Run the model
    let output = model.forward(normalized_image);
1 change: 1 addition & 0 deletions squeezenet-burn/src/lib.rs
@@ -1 +1,2 @@
#![no_std]
pub mod model;
