Merge pull request #29 from tracel-ai/resnet/fine-tune
[ResNet] Add fine-tuning example
nathanielsimard committed May 7, 2024
2 parents a00bee0 + fda9ccf commit f689d8e
Showing 22 changed files with 679 additions and 32 deletions.
33 changes: 17 additions & 16 deletions resnet-burn/Cargo.toml
@@ -1,25 +1,26 @@
[package]
authors = ["guillaumelagrange <[email protected]>"]
license = "MIT OR Apache-2.0"
name = "resnet-burn"
version = "0.1.0"
edition = "2021"
[workspace]
# require version 2 to avoid "feature" additiveness for dev-dependencies
# https://doc.rust-lang.org/cargo/reference/resolver.html#feature-resolver-version-2
resolver = "2"

members = [
"resnet",
"examples/*",
]

[features]
default = []
std = []
pretrained = ["burn/network", "std", "dep:dirs"]
[workspace.package]
edition = "2021"
version = "0.2.0"
readme = "README.md"
license = "MIT OR Apache-2.0"

[dependencies]
[workspace.dependencies]
# Note: default-features = false is needed to disable std
burn = { version = "0.13.0", default-features = false }
burn-import = "0.13.0"
dirs = { version = "5.0.1", optional = true }
dirs = "5.0.1"
serde = { version = "1.0.192", default-features = false, features = [
"derive",
"alloc",
] } # alloc is for no_std, derive is needed

[dev-dependencies]
burn = { version = "0.13.0", features = ["ndarray"] }
image = { version = "0.24.9", features = ["png", "jpeg"] }
41 changes: 38 additions & 3 deletions resnet-burn/README.md
@@ -2,7 +2,7 @@

To this day, [ResNet](https://arxiv.org/abs/1512.03385)s are still a strong baseline for your image
classification tasks. You can find the [Burn](https://github.com/tracel-ai/burn) implementation for
the ResNet variants in [src/model/resnet.rs](src/model/resnet.rs).
the ResNet variants in [resnet.rs](resnet/src/resnet.rs).

The model is [no_std compatible](https://docs.rust-embedded.org/book/intro/no-std.html).

@@ -28,12 +28,47 @@ resnet-burn = { git = "https://github.com/tracel-ai/models", package = "resnet-b

### Example Usage

The [inference example](examples/inference.rs) initializes a ResNet-18 from the ImageNet
#### Inference

The [inference example](examples/inference/examples/inference.rs) initializes a ResNet-18 from the
ImageNet
[pre-trained weights](https://pytorch.org/vision/stable/models/generated/torchvision.models.resnet18.html#torchvision.models.ResNet18_Weights)
with the `NdArray` backend and performs inference on the provided input image.

You can run the example with the following command:

```sh
cargo run --release --features pretrained --example inference samples/dog.jpg
cargo run --release --example inference samples/dog.jpg
```
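
For reference, here is a minimal sketch of what the example does in code. The
`resnet_burn::{weights, ResNet}` imports and the `resnet18_pretrained` constructor are assumptions
based on the inference example; exact paths and signatures may differ.

```rust
use burn::backend::NdArray;
use resnet_burn::{weights, ResNet};

fn main() {
    let device = Default::default();

    // Build a ResNet-18 from the ImageNet pre-trained weights (downloaded on first use).
    let model: ResNet<NdArray> =
        ResNet::resnet18_pretrained(weights::ResNet18::ImageNet1kV1, &device)
            .expect("pre-trained weights should load");

    // `input` would be a normalized [1, 3, H, W] image tensor; see the example
    // for the full pre-processing pipeline.
    // let output = model.forward(input);
    let _ = model;
}
```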

#### Fine-tuning

For this [multi-label image classification fine-tuning example](examples/finetune), a sample of the
planets dataset from the Kaggle competition
[Planet: Understanding the Amazon from Space](https://www.kaggle.com/c/planet-understanding-the-amazon-from-space)
is downloaded from a
[fastai mirror](https://github.com/fastai/fastai/blob/master/fastai/data/external.py#L55). The
sample dataset is a collection of satellite images with multiple labels describing the scene, as
illustrated below.

<img src="./samples/dataset.jpg" alt="Planet dataset sample" width="1000"/>

To achieve this task, a ResNet-18 pre-trained on the ImageNet dataset is fine-tuned on the target
planets dataset. The training recipe used is fairly simple. The main objective is to demonstrate how to re-use a
pre-trained model for a different downstream task.

Without any bells and whistles, our model achieves over 90% multi-label accuracy (i.e., hamming
score) on the validation set after just 5 epochs.
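
As an illustration of this metric (not the exact implementation used by the example), the hamming
score is the fraction of individual label predictions that match their targets:

```rust
/// Fraction of label predictions that match the targets, computed over
/// flattened 0/1 label vectors (all samples and classes concatenated).
fn hamming_score(preds: &[u8], targets: &[u8]) -> f32 {
    assert_eq!(preds.len(), targets.len());
    let matches = preds.iter().zip(targets).filter(|(p, t)| p == t).count();
    matches as f32 / preds.len() as f32
}

fn main() {
    // Two samples with 5 labels each; 9 of the 10 predictions match.
    let preds: [u8; 10] = [0, 1, 0, 0, 1, 1, 0, 0, 0, 0];
    let targets: [u8; 10] = [0, 1, 0, 0, 1, 1, 1, 0, 0, 0];
    println!("{}", hamming_score(&preds, &targets)); // 0.9
}
```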

Run the example with the Torch GPU backend:

```sh
export TORCH_CUDA_VERSION=cu121
cargo run --release --example finetune --features tch-gpu
```

Run it with our WGPU backend:

```sh
cargo run --release --example finetune --features wgpu
```
2 changes: 2 additions & 0 deletions resnet-burn/examples/finetune/.gitignore
@@ -0,0 +1,2 @@
# Downloaded files
planet_sample/
24 changes: 24 additions & 0 deletions resnet-burn/examples/finetune/Cargo.toml
@@ -0,0 +1,24 @@
[package]
authors = ["guillaumelagrange <[email protected]>"]
name = "finetune"
license.workspace = true
version.workspace = true
edition.workspace = true

[features]
default = ["burn/default"]
tch-gpu = ["burn/tch"]
wgpu = ["burn/wgpu"]

[dependencies]
resnet-burn = { path = "../../resnet", features = ["pretrained"] }
burn = { workspace = true, features = ["train", "vision", "network"] }

# Dataset files
csv = "1.3.0"
flate2 = "1.0.28"
rand = { version = "0.8.5", default-features = false, features = [
"std_rng",
] }
serde = { version = "1.0.192", features = ["std", "derive"] }
tar = "0.4.40"
44 changes: 44 additions & 0 deletions resnet-burn/examples/finetune/examples/finetune.rs
@@ -0,0 +1,44 @@
use burn::{backend::Autodiff, tensor::backend::Backend};
use finetune::{inference::infer, training::train};

#[allow(dead_code)]
const ARTIFACT_DIR: &str = "/tmp/resnet-finetune";

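/// Train on the planets dataset with the autodiff-enabled backend, then run inference on the
/// saved artifacts (the 0.5 below is the prediction threshold).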
#[allow(dead_code)]
fn run<B: Backend>(device: B::Device) {
train::<Autodiff<B>>(ARTIFACT_DIR, device.clone());
infer::<B>(ARTIFACT_DIR, device, 0.5);
}

#[cfg(feature = "tch-gpu")]
mod tch_gpu {
use burn::backend::libtorch::{LibTorch, LibTorchDevice};

pub fn run() {
#[cfg(not(target_os = "macos"))]
let device = LibTorchDevice::Cuda(0);
#[cfg(target_os = "macos")]
let device = LibTorchDevice::Mps;

super::run::<LibTorch>(device);
}
}

#[cfg(feature = "wgpu")]
mod wgpu {
use burn::backend::wgpu::{Wgpu, WgpuDevice};

pub fn run() {
super::run::<Wgpu>(WgpuDevice::default());
}
}

fn main() {
#[cfg(feature = "tch-gpu")]
tch_gpu::run();
#[cfg(feature = "wgpu")]
wgpu::run();
}
132 changes: 132 additions & 0 deletions resnet-burn/examples/finetune/src/data.rs
@@ -0,0 +1,132 @@
use burn::{
data::{
dataloader::batcher::Batcher,
dataset::vision::{Annotation, ImageDatasetItem, PixelDepth},
},
prelude::*,
};

use super::dataset::CLASSES;

// ImageNet mean and std values
const MEAN: [f32; 3] = [0.485, 0.456, 0.406];
const STD: [f32; 3] = [0.229, 0.224, 0.225];

// Planets patch size
const WIDTH: usize = 256;
const HEIGHT: usize = 256;

/// Create a multi-hot encoded tensor.
///
/// # Example
///
/// ```rust, ignore
/// let multi_hot = multi_hot::<B>(&[2, 5, 8], 10, &device);
/// println!("{}", multi_hot.to_data());
/// // [0, 0, 1, 0, 0, 1, 0, 0, 1, 0]
/// ```
pub fn multi_hot<B: Backend>(
indices: &[usize],
num_classes: usize,
device: &B::Device,
) -> Tensor<B, 1, Int> {
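// Start from a zero vector and scatter 1s at the given class indices.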
Tensor::zeros(Shape::new([num_classes]), device).scatter(
0,
Tensor::from_ints(
indices
.iter()
.map(|i| *i as i32)
.collect::<Vec<_>>()
.as_slice(),
device,
),
Tensor::ones(Shape::new([indices.len()]), device),
)
}

/// Normalizer using the ImageNet mean and standard deviation. Since we fine-tune from ImageNet
/// pre-trained weights, the model expects inputs normalized to the same range, which also helps
/// training converge faster.
#[derive(Clone)]
pub struct Normalizer<B: Backend> {
pub mean: Tensor<B, 4>,
pub std: Tensor<B, 4>,
}

impl<B: Backend> Normalizer<B> {
/// Creates a new normalizer.
pub fn new(device: &Device<B>) -> Self {
let mean = Tensor::from_floats(MEAN, device).reshape([1, 3, 1, 1]);
let std = Tensor::from_floats(STD, device).reshape([1, 3, 1, 1]);
Self { mean, std }
}

/// Normalizes the input image according to the ImageNet dataset.
///
/// The input image should be in the range [0, 1].
/// The output is standardized per channel to approximately zero mean and unit variance.
///
/// The normalization is done according to the following formula:
/// `input = (input - mean) / std`
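///
/// # Example
///
/// ```rust, ignore
/// // Hypothetical usage: `images` is a `Tensor<B, 4>` in NCHW layout, scaled to [0, 1].
/// let normalizer = Normalizer::<B>::new(&device);
/// let images = normalizer.normalize(images);
/// ```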
pub fn normalize(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
(input - self.mean.clone()) / self.std.clone()
}
}

#[derive(Clone)]
pub struct ClassificationBatcher<B: Backend> {
normalizer: Normalizer<B>,
device: B::Device,
}

#[derive(Clone, Debug)]
pub struct ClassificationBatch<B: Backend> {
pub images: Tensor<B, 4>,
pub targets: Tensor<B, 2, Int>,
}

impl<B: Backend> ClassificationBatcher<B> {
pub fn new(device: B::Device) -> Self {
Self {
normalizer: Normalizer::<B>::new(&device),
device,
}
}
}

impl<B: Backend> Batcher<ImageDatasetItem, ClassificationBatch<B>> for ClassificationBatcher<B> {
fn batch(&self, items: Vec<ImageDatasetItem>) -> ClassificationBatch<B> {
fn image_as_vec_u8(item: ImageDatasetItem) -> Vec<u8> {
// Convert Vec<PixelDepth> to Vec<u8> (Planet images are u8)
item.image
.into_iter()
.map(|p: PixelDepth| -> u8 { p.try_into().unwrap() })
.collect::<Vec<u8>>()
}

let targets = items
.iter()
.map(|item| {
// Expect multi-hot encoded class labels as target (e.g., [0, 1, 0, 0, 1])
if let Annotation::MultiLabel(y) = &item.annotation {
multi_hot(y, CLASSES.len(), &self.device)
} else {
panic!("Invalid target type")
}
})
.collect();

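// Convert each HWC byte image into a CHW float tensor.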
let images = items
.into_iter()
.map(|item| Data::new(image_as_vec_u8(item), Shape::new([HEIGHT, WIDTH, 3])))
.map(|data| Tensor::<B, 3>::from_data(data.convert(), &self.device).permute([2, 0, 1]))
.map(|tensor| tensor / 255) // normalize between [0, 1]
.collect();

let images = Tensor::stack(images, 0);
let targets = Tensor::stack(targets, 0);

let images = self.normalizer.normalize(images);

ClassificationBatch { images, targets }
}
}
