Skip to content

Commit

Permalink
[Squeezenet] Upgrade to burn 0.13.0 (#33)
Browse files Browse the repository at this point in the history
* [Squeezenet] Upgrade to burn 0.13.0

* Upgrade to burn 0.13.0
* Add feature flag to disable copy the weights file

Signed-off-by: vincent <[email protected]>

* [CI] Update the minimum Rust version to 1.75

Signed-off-by: vincent <[email protected]>

---------

Signed-off-by: vincent <[email protected]>
  • Loading branch information
CaptainVincent committed May 15, 2024
1 parent 14ae737 commit e2f060f
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 19 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/rust.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:

strategy:
matrix:
rust: [stable, 1.74.0]
rust: [stable, 1.75.0]

steps:
- uses: actions/checkout@v4
Expand Down
11 changes: 7 additions & 4 deletions squeezenet-burn/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ version = "0.1.0"
edition = "2021"

[features]
default = ["weights_file"]
default = ["weights_file", "weights_file_dump"]

# Enables Half precision (f16) support
weights_f16 = []
Expand All @@ -17,11 +17,14 @@ weights_embedded = []
# Use weights from a file
weights_file = ["burn/default"]

# Copy weights file to specif folder
weights_file_dump = []


[dependencies]

# Note: default-features = false is needed to disable std
burn = { version = "0.11.1", default-features = false }
burn = { version = "0.13.0", default-features = false }

# Used to load weights from a file
serde = { version = "1.0.183", default-features = false, features = [
Expand All @@ -31,9 +34,9 @@ serde = { version = "1.0.183", default-features = false, features = [

[dev-dependencies]
# Used by the classify example
burn = { version = "0.11.1", features = ["ndarray"] }
burn = { version = "0.13.0", features = ["ndarray"] }
image = { version = "0.24.7", features = ["png", "jpeg"] }

[build-dependencies]
# Used to generate code from ONNX model
burn-import = { version = "0.11.1", package = "burn-import" }
burn-import = { version = "0.13.2", package = "burn-import" }
2 changes: 1 addition & 1 deletion squeezenet-burn/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ fn main() {
.run_from_script();

// Copy the weights next to the executable.
if cfg!(feature = "weights_file") {
if cfg!(feature = "weights_file") && cfg!(feature = "weights_file_dump") {
copy_weights_next_to_executable();
}

Expand Down
10 changes: 6 additions & 4 deletions squeezenet-burn/examples/classify.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,13 @@ fn main() {
}
}

// Create a tensor from the array
let image_input = Tensor::<Backend, 3>::from_data(img_array).reshape([1, 3, HEIGHT, WIDTH]);
let device = Default::default();

// Create a tensor from the array
let image_input =
Tensor::<Backend, 3>::from_data(img_array, &device).reshape([1, 3, HEIGHT, WIDTH]);
// Normalize the image
let normalizer = Normalizer::new();
let normalizer = Normalizer::new(&device);
let normalized_image = normalizer.normalize(image_input);

// Create the model
Expand All @@ -59,7 +61,7 @@ fn main() {
.join(RECORD_FILE);

#[cfg(feature = "weights_file")]
let model = Model::<Backend>::from_file(weights_file.to_str().unwrap());
let model = Model::<Backend>::from_file(weights_file.to_str().unwrap(), &device);

#[cfg(feature = "weights_embedded")]
// Load model from embedded weights
Expand Down
12 changes: 3 additions & 9 deletions squeezenet-burn/src/model/normalizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ pub struct Normalizer<B: Backend> {

impl<B: Backend> Normalizer<B> {
/// Creates a new normalizer.
pub fn new() -> Self {
let mean = Tensor::from_floats(MEAN).reshape([1, 3, 1, 1]);
let std = Tensor::from_floats(STD).reshape([1, 3, 1, 1]);
pub fn new(device: &B::Device) -> Self {
let mean = Tensor::from_floats(MEAN, device).reshape([1, 3, 1, 1]);
let std = Tensor::from_floats(STD, device).reshape([1, 3, 1, 1]);
Self { mean, std }
}

Expand All @@ -30,9 +30,3 @@ impl<B: Backend> Normalizer<B> {
(input - self.mean.clone()) / self.std.clone()
}
}

impl<B: Backend> Default for Normalizer<B> {
fn default() -> Self {
Self::new()
}
}

0 comments on commit e2f060f

Please sign in to comment.