Skip to content

Commit

Permalink
Add ImageNet pre-trained weights under pretrained feature flag (#18)
Browse files Browse the repository at this point in the history
* Add ImageNet pre-trained weights under `pretrained` feature flag

* Remove print comments and fix ResNetConfig private methods

* Fix pretrained weights enum usage and expansion=4 for bigger ResNets

* Fix regex for larger models with more than 10 layer blocks

* Fix error message typo

* Add burn/ndarry to dev-dependencies
  • Loading branch information
laggui authored Feb 14, 2024
1 parent b225ed6 commit 323a00d
Show file tree
Hide file tree
Showing 8 changed files with 855 additions and 230 deletions.
13 changes: 9 additions & 4 deletions resnet-burn/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,22 @@ version = "0.1.0"
edition = "2021"

[features]
default = ["burn/default"]

default = []
std = []
pretrained = ["burn/network", "std", "dep:dirs"]

[dependencies]
# Note: default-features = false is needed to disable std
burn = { git = "https://github.com/tracel-ai/burn.git", rev = "75cb5b6d5633c1c6092cf5046419da75e7f74b11", default-features = false }
burn = { git = "https://github.com/tracel-ai/burn.git", rev = "9a2cbadd41161c8aac142bbcb9c2ceaf5ffd6edd", default-features = false }
burn-import = { git = "https://github.com/tracel-ai/burn.git", rev = "9a2cbadd41161c8aac142bbcb9c2ceaf5ffd6edd" }
dirs = { version = "5.0.1", optional = true }
serde = { version = "1.0.192", default-features = false, features = [
"derive",
"alloc",
] } # alloc is for no_std, derive is needed

[dev-dependencies]
burn-import = { git = "https://github.com/tracel-ai/burn.git", rev = "75cb5b6d5633c1c6092cf5046419da75e7f74b11"}
burn = { git = "https://github.com/tracel-ai/burn.git", rev = "9a2cbadd41161c8aac142bbcb9c2ceaf5ffd6edd", features = [
"ndarray",
] }
image = { version = "0.24.7", features = ["png", "jpeg"] }
22 changes: 14 additions & 8 deletions resnet-burn/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,23 @@ Add this to your `Cargo.toml`:
resnet-burn = { git = "https://github.com/burn-rs/models", package = "resnet-burn", default-features = false }
```

If you want to get the pre-trained ImageNet weights, enable the `pretrained` feature flag.

```toml
[dependencies]
resnet-burn = { git = "https://github.com/burn-rs/models", package = "resnet-burn", features = ["pretrained"] }
```

**Important:** this feature requires `std`.

### Example Usage

The [inference example](examples/inference.rs) initializes a ResNet-18 with the `NdArray` backend,
imports the ImageNet pre-trained weights from
[`torchvision`](https://download.pytorch.org/models/resnet18-f37072fd.pth) and performs inference on
the provided input image.
The [inference example](examples/inference.rs) initializes a ResNet-18 from the ImageNet
[pre-trained weights](https://pytorch.org/vision/stable/models/generated/torchvision.models.resnet18.html#torchvision.models.ResNet18_Weights)
with the `NdArray` backend and performs inference on the provided input image.

After downloading the
[pre-trained weights](https://download.pytorch.org/models/resnet18-f37072fd.pth) to the current
directory, you can run the example with the following command:
You can run the example with the following command:

```sh
cargo run --release --example inference samples/dog.jpg
cargo run --release --features pretrained --example inference samples/dog.jpg
```
27 changes: 6 additions & 21 deletions resnet-burn/examples/inference.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,13 @@
use resnet_burn::model::{imagenet, resnet::ResNet};
use resnet_burn::model::{imagenet, resnet::ResNet, weights};

use burn::{
backend::NdArray,
module::Module,
record::{FullPrecisionSettings, NamedMpkFileRecorder, Recorder},
record::{FullPrecisionSettings, NamedMpkFileRecorder},
tensor::{backend::Backend, Data, Device, Element, Shape, Tensor},
};
use burn_import::pytorch::{LoadArgs, PyTorchFileRecorder};

const TORCH_WEIGHTS: &str = "resnet18-f37072fd.pth";
const MODEL_PATH: &str = "resnet18-ImageNet1k";
const NUM_CLASSES: usize = 1000;
const HEIGHT: usize = 224;
const WIDTH: usize = 224;

Expand All @@ -32,22 +29,10 @@ pub fn main() {

// Create ResNet-18
let device = Default::default();
let model: ResNet<NdArray, _> = ResNet::resnet18(NUM_CLASSES, &device);

// Load weights from torch state_dict
let load_args = LoadArgs::new(TORCH_WEIGHTS.into())
// Map *.downsample.0.* -> *.downsample.conv.*
.with_key_remap("(.+)\\.downsample\\.0\\.(.+)", "$1.downsample.conv.$2")
// Map *.downsample.1.* -> *.downsample.bn.*
.with_key_remap("(.+)\\.downsample\\.1\\.(.+)", "$1.downsample.bn.$2")
// Map layer[i].[j].* -> layer[i].blocks.[j].*
.with_key_remap("(layer[1-4])\\.([0-9])\\.(.+)", "$1.blocks.$2.$3");
let record = PyTorchFileRecorder::<FullPrecisionSettings>::new()
.load(load_args, &device)
.map_err(|err| format!("Failed to load weights.\nError: {err}"))
.unwrap();

let model = model.load_record(record);
let model: ResNet<NdArray, _> =
ResNet::resnet18_pretrained(weights::ResNet18::ImageNet1kV1, &device)
.map_err(|err| format!("Failed to load pre-trained weights.\nError: {err}"))
.unwrap();

// Save the model to a supported format and load it back
let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::new();
Expand Down
2 changes: 1 addition & 1 deletion resnet-burn/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
#![no_std]
#![cfg_attr(not(feature = "std"), no_std)]
pub mod model;
extern crate alloc;
Loading

0 comments on commit 323a00d

Please sign in to comment.