From e2f060fa816ba544db91251714eace59e5fa9f9c Mon Sep 17 00:00:00 2001 From: vincent Date: Wed, 15 May 2024 21:10:40 +0800 Subject: [PATCH] [Squeezenet] Upgrade to burn 0.13.0 (#33) * [Squeezenet] Upgrade to burn 0.13.0 * Upgrade to burn 0.13.0 * Add feature flag to disable copy the weights file Signed-off-by: vincent * [CI] Update the minimum Rust version to 1.75 Signed-off-by: vincent --------- Signed-off-by: vincent --- .github/workflows/rust.yml | 2 +- squeezenet-burn/Cargo.toml | 11 +++++++---- squeezenet-burn/build.rs | 2 +- squeezenet-burn/examples/classify.rs | 10 ++++++---- squeezenet-burn/src/model/normalizer.rs | 12 +++--------- 5 files changed, 18 insertions(+), 19 deletions(-) diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index c01d36b..10facb0 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -15,7 +15,7 @@ jobs: strategy: matrix: - rust: [stable, 1.74.0] + rust: [stable, 1.75.0] steps: - uses: actions/checkout@v4 diff --git a/squeezenet-burn/Cargo.toml b/squeezenet-burn/Cargo.toml index 40e3977..8721ecf 100644 --- a/squeezenet-burn/Cargo.toml +++ b/squeezenet-burn/Cargo.toml @@ -6,7 +6,7 @@ version = "0.1.0" edition = "2021" [features] -default = ["weights_file"] +default = ["weights_file", "weights_file_dump"] # Enables Half precision (f16) support weights_f16 = [] @@ -17,11 +17,14 @@ weights_embedded = [] # Use weights from a file weights_file = ["burn/default"] +# Copy weights file to specif folder +weights_file_dump = [] + [dependencies] # Note: default-features = false is needed to disable std -burn = { version = "0.11.1", default-features = false } +burn = { version = "0.13.0", default-features = false } # Used to load weights from a file serde = { version = "1.0.183", default-features = false, features = [ @@ -31,9 +34,9 @@ serde = { version = "1.0.183", default-features = false, features = [ [dev-dependencies] # Used by the classify example -burn = { version = "0.11.1", features = ["ndarray"] } +burn = { version = "0.13.0", features = ["ndarray"] } image = { version = "0.24.7", features = ["png", "jpeg"] } [build-dependencies] # Used to generate code from ONNX model -burn-import = { version = "0.11.1", package = "burn-import" } +burn-import = { version = "0.13.2", package = "burn-import" } diff --git a/squeezenet-burn/build.rs b/squeezenet-burn/build.rs index bb8b3f8..b9dffd6 100644 --- a/squeezenet-burn/build.rs +++ b/squeezenet-burn/build.rs @@ -47,7 +47,7 @@ fn main() { .run_from_script(); // Copy the weights next to the executable. - if cfg!(feature = "weights_file") { + if cfg!(feature = "weights_file") && cfg!(feature = "weights_file_dump") { copy_weights_next_to_executable(); } diff --git a/squeezenet-burn/examples/classify.rs b/squeezenet-burn/examples/classify.rs index 9900e9c..b367c4e 100644 --- a/squeezenet-burn/examples/classify.rs +++ b/squeezenet-burn/examples/classify.rs @@ -42,11 +42,13 @@ fn main() { } } - // Create a tensor from the array - let image_input = Tensor::::from_data(img_array).reshape([1, 3, HEIGHT, WIDTH]); + let device = Default::default(); + // Create a tensor from the array + let image_input = + Tensor::::from_data(img_array, &device).reshape([1, 3, HEIGHT, WIDTH]); // Normalize the image - let normalizer = Normalizer::new(); + let normalizer = Normalizer::new(&device); let normalized_image = normalizer.normalize(image_input); // Create the model @@ -59,7 +61,7 @@ fn main() { .join(RECORD_FILE); #[cfg(feature = "weights_file")] - let model = Model::::from_file(weights_file.to_str().unwrap()); + let model = Model::::from_file(weights_file.to_str().unwrap(), &device); #[cfg(feature = "weights_embedded")] // Load model from embedded weights diff --git a/squeezenet-burn/src/model/normalizer.rs b/squeezenet-burn/src/model/normalizer.rs index 7e1e601..f14439e 100644 --- a/squeezenet-burn/src/model/normalizer.rs +++ b/squeezenet-burn/src/model/normalizer.rs @@ -13,9 +13,9 @@ pub struct Normalizer { impl Normalizer { /// Creates a new normalizer. - pub fn new() -> Self { - let mean = Tensor::from_floats(MEAN).reshape([1, 3, 1, 1]); - let std = Tensor::from_floats(STD).reshape([1, 3, 1, 1]); + pub fn new(device: &B::Device) -> Self { + let mean = Tensor::from_floats(MEAN, device).reshape([1, 3, 1, 1]); + let std = Tensor::from_floats(STD, device).reshape([1, 3, 1, 1]); Self { mean, std } } @@ -30,9 +30,3 @@ impl Normalizer { (input - self.mean.clone()) / self.std.clone() } } - -impl Default for Normalizer { - fn default() -> Self { - Self::new() - } -}