diff --git a/resnet-burn/Cargo.toml b/resnet-burn/Cargo.toml index 9d6979a..18c5194 100644 --- a/resnet-burn/Cargo.toml +++ b/resnet-burn/Cargo.toml @@ -1,25 +1,26 @@ -[package] -authors = ["guillaumelagrange "] -license = "MIT OR Apache-2.0" -name = "resnet-burn" -version = "0.1.0" -edition = "2021" +[workspace] +# Try +# require version 2 to avoid "feature" additiveness for dev-dependencies +# https://doc.rust-lang.org/cargo/reference/resolver.html#feature-resolver-version-2 +resolver = "2" + +members = [ + "resnet", + "examples/*", +] -[features] -default = [] -std = [] -pretrained = ["burn/network", "std", "dep:dirs"] +[workspace.package] +edition = "2021" +version = "0.2.0" +readme = "README.md" +license = "MIT OR Apache-2.0" -[dependencies] +[workspace.dependencies] # Note: default-features = false is needed to disable std burn = { version = "0.13.0", default-features = false } burn-import = "0.13.0" -dirs = { version = "5.0.1", optional = true } +dirs = "5.0.1" serde = { version = "1.0.192", default-features = false, features = [ "derive", "alloc", ] } # alloc is for no_std, derive is needed - -[dev-dependencies] -burn = { version = "0.13.0", features = ["ndarray"] } -image = { version = "0.24.9", features = ["png", "jpeg"] } diff --git a/resnet-burn/README.md b/resnet-burn/README.md index 9af2d62..e472743 100644 --- a/resnet-burn/README.md +++ b/resnet-burn/README.md @@ -2,7 +2,7 @@ To this day, [ResNet](https://arxiv.org/abs/1512.03385)s are still a strong baseline for your image classification tasks. You can find the [Burn](https://github.com/tracel-ai/burn) implementation for -the ResNet variants in [src/model/resnet.rs](src/model/resnet.rs). +the ResNet variants in [resnet.rs](resnet/src/resnet.rs). The model is [no_std compatible](https://docs.rust-embedded.org/book/intro/no-std.html). @@ -28,12 +28,47 @@ resnet-burn = { git = "https://github.com/tracel-ai/models", package = "resnet-b ### Example Usage -The [inference example](examples/inference.rs) initializes a ResNet-18 from the ImageNet +#### Inference + +The [inference example](examples/inference/examples/inference.rs) initializes a ResNet-18 from the +ImageNet [pre-trained weights](https://pytorch.org/vision/stable/models/generated/torchvision.models.resnet18.html#torchvision.models.ResNet18_Weights) with the `NdArray` backend and performs inference on the provided input image. You can run the example with the following command: ```sh -cargo run --release --features pretrained --example inference samples/dog.jpg +cargo run --release --example inference samples/dog.jpg --release +``` + +#### Fine-tuning + +For this [multi-label image classification fine-tuning example](examples/finetune), a sample of the +planets dataset from the Kaggle competition +[Planet: Understanding the Amazon from Space](https://www.kaggle.com/c/planet-understanding-the-amazon-from-space) +is downloaded from a +[fastai mirror](https://github.com/fastai/fastai/blob/master/fastai/data/external.py#L55). The +sample dataset is a collection of satellite images with multiple labels describing the scene, as +illustrated below. + +Planet dataset sample + +To achieve this task, a ResNet-18 pre-trained on the ImageNet dataset is fine-tuned on the target +planets dataset. The training recipe used is fairly simple. The main objective is to demonstrate how to re-use a +pre-trained model for a different downstream task. + +Without any bells and whistle, our model achieves over 90% multi-label accuracy (i.e., hamming +score) on the validation set after just 5 epochs. + +Run the example with the Torch GPU backend: + +```sh +export TORCH_CUDA_VERSION=cu121 +cargo run --release --example finetune --features tch-gpu +``` + +Run it with our WGPU backend: + +```sh +cargo run --release --example finetune --features wgpu ``` diff --git a/resnet-burn/examples/finetune/.gitignore b/resnet-burn/examples/finetune/.gitignore new file mode 100644 index 0000000..af39006 --- /dev/null +++ b/resnet-burn/examples/finetune/.gitignore @@ -0,0 +1,2 @@ +# Downloaded files +planet_sample/ \ No newline at end of file diff --git a/resnet-burn/examples/finetune/Cargo.toml b/resnet-burn/examples/finetune/Cargo.toml new file mode 100644 index 0000000..84a7e15 --- /dev/null +++ b/resnet-burn/examples/finetune/Cargo.toml @@ -0,0 +1,24 @@ +[package] +authors = ["guillaumelagrange "] +name = "finetune" +license.workspace = true +version.workspace = true +edition.workspace = true + +[features] +default = ["burn/default"] +tch-gpu = ["burn/tch"] +wgpu = ["burn/wgpu"] + +[dependencies] +resnet-burn = { path = "../../resnet", features = ["pretrained"] } +burn = { workspace = true, features = ["train", "vision", "network"] } + +# Dataset files +csv = "1.3.0" +flate2 = "1.0.28" +rand = { version = "0.8.5", default-features = false, features = [ + "std_rng", +] } +serde = { version = "1.0.192", features = ["std", "derive"] } +tar = "0.4.40" \ No newline at end of file diff --git a/resnet-burn/examples/finetune/examples/finetune.rs b/resnet-burn/examples/finetune/examples/finetune.rs new file mode 100644 index 0000000..5a0b033 --- /dev/null +++ b/resnet-burn/examples/finetune/examples/finetune.rs @@ -0,0 +1,44 @@ +use burn::{backend::Autodiff, tensor::backend::Backend}; +use finetune::{inference::infer, training::train}; + +#[allow(dead_code)] +const ARTIFACT_DIR: &str = "/tmp/resnet-finetune"; + +#[allow(dead_code)] +fn run(device: B::Device) { + train::>(ARTIFACT_DIR, device.clone()); + infer::(ARTIFACT_DIR, device, 0.5); +} + +#[cfg(feature = "tch-gpu")] +mod tch_gpu { + use burn::backend::libtorch::{LibTorch, LibTorchDevice}; + + pub fn run() { + #[cfg(not(target_os = "macos"))] + let device = LibTorchDevice::Cuda(0); + #[cfg(target_os = "macos")] + let device = LibTorchDevice::Mps; + + super::run::(device); + } +} + +#[cfg(feature = "wgpu")] +mod wgpu { + use burn::{ + backend::wgpu::{Wgpu, WgpuDevice}, + Wgpu, + }; + + pub fn run() { + super::run::(WgpuDevice::default()); + } +} + +fn main() { + #[cfg(feature = "tch-gpu")] + tch_gpu::run(); + #[cfg(feature = "wgpu")] + wgpu::run(); +} diff --git a/resnet-burn/examples/finetune/src/data.rs b/resnet-burn/examples/finetune/src/data.rs new file mode 100644 index 0000000..ce701e3 --- /dev/null +++ b/resnet-burn/examples/finetune/src/data.rs @@ -0,0 +1,132 @@ +use burn::{ + data::{ + dataloader::batcher::Batcher, + dataset::vision::{Annotation, ImageDatasetItem, PixelDepth}, + }, + prelude::*, +}; + +use super::dataset::CLASSES; + +// ImageNet mean and std values +const MEAN: [f32; 3] = [0.485, 0.456, 0.406]; +const STD: [f32; 3] = [0.229, 0.224, 0.225]; + +// Planets patch size +const WIDTH: usize = 256; +const HEIGHT: usize = 256; + +/// Create a multi-hot encoded tensor. +/// +/// # Example +/// +/// ```rust, ignore +/// let multi_hot = multi_hot::(&[2, 5, 8], 10, &device); +/// println!("{}", multi_hot.to_data()); +/// // [0, 0, 1, 0, 0, 1, 0, 0, 1, 0] +/// ``` +pub fn multi_hot( + indices: &[usize], + num_classes: usize, + device: &B::Device, +) -> Tensor { + Tensor::zeros(Shape::new([num_classes]), device).scatter( + 0, + Tensor::from_ints( + indices + .iter() + .map(|i| *i as i32) + .collect::>() + .as_slice(), + device, + ), + Tensor::ones(Shape::new([indices.len()]), device), + ) +} + +/// Normalizer with ImageNet values as it helps accelerate training since we are fine-tuning from +/// ImageNet pre-trained weights and the model expects the data to be in this normalized range. +#[derive(Clone)] +pub struct Normalizer { + pub mean: Tensor, + pub std: Tensor, +} + +impl Normalizer { + /// Creates a new normalizer. + pub fn new(device: &Device) -> Self { + let mean = Tensor::from_floats(MEAN, device).reshape([1, 3, 1, 1]); + let std = Tensor::from_floats(STD, device).reshape([1, 3, 1, 1]); + Self { mean, std } + } + + /// Normalizes the input image according to the ImageNet dataset. + /// + /// The input image should be in the range [0, 1]. + /// The output image will be in the range [-1, 1]. + /// + /// The normalization is done according to the following formula: + /// `input = (input - mean) / std` + pub fn normalize(&self, input: Tensor) -> Tensor { + (input - self.mean.clone()) / self.std.clone() + } +} + +#[derive(Clone)] +pub struct ClassificationBatcher { + normalizer: Normalizer, + device: B::Device, +} + +#[derive(Clone, Debug)] +pub struct ClassificationBatch { + pub images: Tensor, + pub targets: Tensor, +} + +impl ClassificationBatcher { + pub fn new(device: B::Device) -> Self { + Self { + normalizer: Normalizer::::new(&device), + device, + } + } +} + +impl Batcher> for ClassificationBatcher { + fn batch(&self, items: Vec) -> ClassificationBatch { + fn image_as_vec_u8(item: ImageDatasetItem) -> Vec { + // Convert Vec to Vec (Planet images are u8) + item.image + .into_iter() + .map(|p: PixelDepth| -> u8 { p.try_into().unwrap() }) + .collect::>() + } + + let targets = items + .iter() + .map(|item| { + // Expect multi-hot encoded class labels as target (e.g., [0, 1, 0, 0, 1]) + if let Annotation::MultiLabel(y) = &item.annotation { + multi_hot(y, CLASSES.len(), &self.device) + } else { + panic!("Invalid target type") + } + }) + .collect(); + + let images = items + .into_iter() + .map(|item| Data::new(image_as_vec_u8(item), Shape::new([HEIGHT, WIDTH, 3]))) + .map(|data| Tensor::::from_data(data.convert(), &self.device).permute([2, 0, 1])) + .map(|tensor| tensor / 255) // normalize between [0, 1] + .collect(); + + let images = Tensor::stack(images, 0); + let targets = Tensor::stack(targets, 0); + + let images = self.normalizer.normalize(images); + + ClassificationBatch { images, targets } + } +} diff --git a/resnet-burn/examples/finetune/src/dataset.rs b/resnet-burn/examples/finetune/src/dataset.rs new file mode 100644 index 0000000..c340d26 --- /dev/null +++ b/resnet-burn/examples/finetune/src/dataset.rs @@ -0,0 +1,157 @@ +use flate2::read::GzDecoder; +use rand::{rngs::StdRng, seq::SliceRandom, SeedableRng}; +use serde::{Deserialize, Serialize}; +use std::{ + collections::HashSet, + path::{Path, PathBuf}, +}; +use tar::Archive; + +use burn::data::{ + dataset::vision::{ImageFolderDataset, ImageLoaderError}, + network::downloader, +}; + +/// Planets dataset sample mirror from [fastai](https://github.com/fastai/fastai/blob/master/fastai/data/external.py#L55). +/// Licensed under the [Appache License](https://github.com/fastai/fastai/blob/master/LICENSE). +const URL: &str = "https://s3.amazonaws.com/fast-ai-sample/planet_sample.tgz"; +const LABELS: &str = "labels.csv"; +pub const CLASSES: [&str; 17] = [ + "agriculture", + "artisinal_mine", + "bare_ground", + "blooming", + "blow_down", + "clear", + "cloudy", + "conventional_mine", + "cultivation", + "habitation", + "haze", + "partly_cloudy", + "primary", + "road", + "selective_logging", + "slash_burn", + "water", +]; + +/// A sample of the planets dataset from the Kaggle competition +/// [Planet: Understanding the Amazon from Space](https://www.kaggle.com/c/planet-understanding-the-amazon-from-space). +/// +/// This version of the multi-label classification dataset contains 1,000 256x256 image patches +/// with possibly multiple labels per patch. The labels can broadly be broken into three groups: +/// atmospheric conditions, common land cover/land use phenomena, and rare land cover/land use +/// phenomena. Each patch will have one and potentially more than one atmospheric label and zero +/// or more common and rare labels. +/// +/// The data is downloaded from the web from the [fastai mirror](https://github.com/fastai/fastai/blob/master/fastai/data/external.py#L55). +pub trait PlanetLoader: Sized { + fn planet_train_val_split( + train_percentage: u8, + seed: u64, + ) -> Result<(Self, Self), ImageLoaderError>; +} + +#[derive(Deserialize, Serialize, Debug, Clone)] +struct PlanetSample { + image_name: String, + tags: String, +} + +impl PlanetLoader for ImageFolderDataset { + /// Creates new Planet dataset for train and validation splits. + /// + /// # Arguments + /// + /// * `train_percentage` - Percentage of the training split. The remainder will be used for the validation split. + /// * `seed` - Controls the shuffling applied to the data before applying the split. + /// + fn planet_train_val_split( + train_percentage: u8, + seed: u64, + ) -> Result<(Self, Self), ImageLoaderError> { + assert!( + train_percentage > 0 && train_percentage < 100, + "Training split percentage must be between (0, 100)" + ); + let root = download(); + + // Load items from csv + let mut rdr = csv::ReaderBuilder::new() + .from_path(root.join(LABELS)) + .map_err(|err| ImageLoaderError::Unknown(err.to_string()))?; + + // Collect items (image path, labels) + let mut classes = HashSet::new(); + let mut items = rdr + .deserialize() + .map(|result| { + let item: PlanetSample = + result.map_err(|err| ImageLoaderError::Unknown(err.to_string()))?; + let tags = item + .tags + .split(' ') + .map(|s| s.to_string()) + .collect::>(); + + for tag in tags.iter() { + classes.insert(tag.clone()); + } + + Ok(( + // Full path to image + root.join("train") + .join(item.image_name) + .with_extension("jpg"), + // Multiple labels per image (e.g., ["haze", "primary", "water"]) + tags, + )) + }) + .collect::, _>>()?; + + // Sort class names + let mut classes = classes.iter().collect::>(); + classes.sort(); + assert_eq!(classes, CLASSES, "Invalid categories"); // just in case the labels unexpectedly change + + // Shuffle items + items.shuffle(&mut StdRng::seed_from_u64(seed)); + + // Split train and validation + let size = items.len(); + let train_slice = (size as f32 * (train_percentage as f32 / 100.0)) as usize; + + let train = Self::new_multilabel_classification_with_items( + items[..train_slice].to_vec(), + &classes, + )?; + let valid = Self::new_multilabel_classification_with_items( + items[train_slice..].to_vec(), + &classes, + )?; + + Ok((train, valid)) + } +} + +/// Download the Planet dataset from the web to the current example directory. +fn download() -> PathBuf { + // Point to current example directory + let example_dir = Path::new(file!()).parent().unwrap().parent().unwrap(); + let planet_dir = example_dir.join("planet_sample"); + + // Check for already downloaded content + let labels_file = planet_dir.join(LABELS); + if !labels_file.exists() { + // Download gzip file + let bytes = downloader::download_file_as_bytes(URL, "planet_sample.tgz"); + + // Decode gzip file content and unpack archive + let gz_buffer = GzDecoder::new(&bytes[..]); + let mut archive = Archive::new(gz_buffer); + archive.unpack(example_dir).unwrap(); + } + + planet_dir +} diff --git a/resnet-burn/examples/finetune/src/inference.rs b/resnet-burn/examples/finetune/src/inference.rs new file mode 100644 index 0000000..aa0defb --- /dev/null +++ b/resnet-burn/examples/finetune/src/inference.rs @@ -0,0 +1,55 @@ +use crate::{ + data::ClassificationBatcher, + dataset::{PlanetLoader, CLASSES}, + training::TrainingConfig, +}; +use burn::{ + data::{ + dataloader::batcher::Batcher, + dataset::{ + vision::{Annotation, ImageFolderDataset}, + Dataset, + }, + }, + prelude::*, + record::{CompactRecorder, Recorder}, + tensor::activation::sigmoid, +}; +use resnet_burn::ResNet; + +pub fn infer(artifact_dir: &str, device: B::Device, threshold: f32) { + // Load trained ResNet-18 + let config = TrainingConfig::load(format!("{artifact_dir}/config.json")) + .expect("Config should exist for the model"); + let record = CompactRecorder::new() + .load(format!("{artifact_dir}/model").into(), &device) + .expect("Trained model should exist"); + + let model: ResNet = ResNet::resnet18(config.num_classes, &device).load_record(record); + + // Get an item from validation split with multiple labels + let (_train, valid) = + ImageFolderDataset::planet_train_val_split(config.train_percentage, config.seed).unwrap(); + let item = valid.get(20).unwrap(); + + let label = if let Annotation::MultiLabel(ref categories) = item.annotation { + categories.iter().map(|&i| CLASSES[i]).collect::>() + } else { + panic!("Annotation should be multilabel") + }; + + // Forward pass with sigmoid activation function + let batcher = ClassificationBatcher::new(device); + let batch = batcher.batch(vec![item]); + let output = sigmoid(model.forward(batch.images)); + + // Get predicted class names over the specified threshold + let predicted = output.greater_equal_elem(threshold).nonzero()[1] + .to_data() + .value + .iter() + .map(|i| CLASSES[i.elem::() as usize]) + .collect::>(); + + println!("Predicted: {:?}\nExpected: {:?}", predicted, label); +} diff --git a/resnet-burn/examples/finetune/src/lib.rs b/resnet-burn/examples/finetune/src/lib.rs new file mode 100644 index 0000000..59066af --- /dev/null +++ b/resnet-burn/examples/finetune/src/lib.rs @@ -0,0 +1,4 @@ +pub mod data; +pub mod dataset; +pub mod inference; +pub mod training; diff --git a/resnet-burn/examples/finetune/src/training.rs b/resnet-burn/examples/finetune/src/training.rs new file mode 100644 index 0000000..0dd5d22 --- /dev/null +++ b/resnet-burn/examples/finetune/src/training.rs @@ -0,0 +1,159 @@ +use std::time::Instant; + +use crate::{ + data::{ClassificationBatch, ClassificationBatcher}, + dataset::{PlanetLoader, CLASSES}, +}; +use burn::{ + data::{dataloader::DataLoaderBuilder, dataset::vision::ImageFolderDataset}, + nn::loss::BinaryCrossEntropyLossConfig, + optim::{decay::WeightDecayConfig, AdamConfig}, + prelude::*, + record::CompactRecorder, + tensor::backend::AutodiffBackend, + train::{ + metric::{HammingScore, LossMetric}, + LearnerBuilder, MultiLabelClassificationOutput, TrainOutput, TrainStep, ValidStep, + }, +}; +use resnet_burn::{weights, ResNet}; + +const NUM_CLASSES: usize = CLASSES.len(); + +pub trait MultiLabelClassification { + fn forward_classification( + &self, + images: Tensor, + targets: Tensor, + ) -> MultiLabelClassificationOutput; +} + +impl MultiLabelClassification for ResNet { + fn forward_classification( + &self, + images: Tensor, + targets: Tensor, + ) -> MultiLabelClassificationOutput { + let output = self.forward(images); + let loss = BinaryCrossEntropyLossConfig::new() + .with_logits(true) + .init(&output.device()) + .forward(output.clone(), targets.clone()); + + MultiLabelClassificationOutput::new(loss, output, targets) + } +} + +impl TrainStep, MultiLabelClassificationOutput> + for ResNet +{ + fn step( + &self, + batch: ClassificationBatch, + ) -> TrainOutput> { + let item = self.forward_classification(batch.images, batch.targets); + + TrainOutput::new(self, item.loss.backward(), item) + } +} + +impl ValidStep, MultiLabelClassificationOutput> + for ResNet +{ + fn step(&self, batch: ClassificationBatch) -> MultiLabelClassificationOutput { + self.forward_classification(batch.images, batch.targets) + } +} + +#[derive(Config)] +pub struct TrainingConfig { + #[config(default = 5)] + pub num_epochs: usize, + + #[config(default = 128)] + pub batch_size: usize, + + #[config(default = 4)] + pub num_workers: usize, + + #[config(default = 42)] + pub seed: u64, + + #[config(default = 1e-3)] + pub learning_rate: f64, + + #[config(default = 5e-5)] + pub weight_decay: f64, + + #[config(default = 70)] + pub train_percentage: u8, + + pub num_classes: usize, +} + +fn create_artifact_dir(artifact_dir: &str) { + // Remove existing artifacts before to get an accurate learner summary + std::fs::remove_dir_all(artifact_dir).ok(); + std::fs::create_dir_all(artifact_dir).ok(); +} + +pub fn train(artifact_dir: &str, device: B::Device) { + create_artifact_dir(artifact_dir); + + // Config + let config = TrainingConfig::new(NUM_CLASSES); + let optimizer = AdamConfig::new() + .with_weight_decay(Some(WeightDecayConfig::new(config.weight_decay))) + .init(); + + config + .save(format!("{artifact_dir}/config.json")) + .expect("Config should be saved successfully"); + + B::seed(config.seed); + + // Dataloaders + let batcher_train = ClassificationBatcher::::new(device.clone()); + let batcher_valid = ClassificationBatcher::::new(device.clone()); + + let (train, valid) = + ImageFolderDataset::planet_train_val_split(config.train_percentage, config.seed).unwrap(); + + let dataloader_train = DataLoaderBuilder::new(batcher_train) + .batch_size(config.batch_size) + .shuffle(config.seed) + .num_workers(config.num_workers) + .build(train); + + let dataloader_test = DataLoaderBuilder::new(batcher_valid) + .batch_size(config.batch_size) + .num_workers(config.num_workers) + .build(valid); + + // Pre-trained ResNet-18 adapted for num_classes in this task + let model = ResNet::resnet18_pretrained(weights::ResNet18::ImageNet1kV1, &device) + .unwrap() + .with_classes(NUM_CLASSES); + + // Learner config + let learner = LearnerBuilder::new(artifact_dir) + .metric_train_numeric(HammingScore::new()) + .metric_valid_numeric(HammingScore::new()) + .metric_train_numeric(LossMetric::new()) + .metric_valid_numeric(LossMetric::new()) + .with_file_checkpointer(CompactRecorder::new()) + .devices(vec![device.clone()]) + .num_epochs(config.num_epochs) + .summary() + .build(model, optimizer, config.learning_rate); + + // Training + let now = Instant::now(); + let model_trained = learner.fit(dataloader_train, dataloader_test); + let elapsed = now.elapsed().as_secs(); + println!("Training completed in {}m{}s", (elapsed / 60), elapsed % 60); + + model_trained + .save_file(format!("{artifact_dir}/model"), &CompactRecorder::new()) + .expect("Trained model should be saved successfully"); +} diff --git a/resnet-burn/examples/inference/Cargo.toml b/resnet-burn/examples/inference/Cargo.toml new file mode 100644 index 0000000..8fc5573 --- /dev/null +++ b/resnet-burn/examples/inference/Cargo.toml @@ -0,0 +1,13 @@ +[package] +authors = ["guillaumelagrange "] +name = "inference" +license.workspace = true +version.workspace = true +edition.workspace = true +publish = false + +[dependencies] +resnet-burn = { path = "../../resnet", features = ["pretrained"] } +burn = { workspace = true, features = ["ndarray"] } +image = { version = "0.24.9", features = ["png", "jpeg"] } + diff --git a/resnet-burn/examples/inference.rs b/resnet-burn/examples/inference/examples/inference.rs similarity index 97% rename from resnet-burn/examples/inference.rs rename to resnet-burn/examples/inference/examples/inference.rs index 055bbc5..a26e597 100644 --- a/resnet-burn/examples/inference.rs +++ b/resnet-burn/examples/inference/examples/inference.rs @@ -1,4 +1,5 @@ -use resnet_burn::model::{imagenet, resnet::ResNet, weights}; +use inference::imagenet; +use resnet_burn::{weights, ResNet}; use burn::{ backend::NdArray, diff --git a/resnet-burn/src/model/imagenet.rs b/resnet-burn/examples/inference/src/imagenet.rs similarity index 100% rename from resnet-burn/src/model/imagenet.rs rename to resnet-burn/examples/inference/src/imagenet.rs diff --git a/resnet-burn/examples/inference/src/lib.rs b/resnet-burn/examples/inference/src/lib.rs new file mode 100644 index 0000000..d605ca9 --- /dev/null +++ b/resnet-burn/examples/inference/src/lib.rs @@ -0,0 +1 @@ +pub mod imagenet; diff --git a/resnet-burn/resnet/Cargo.toml b/resnet-burn/resnet/Cargo.toml new file mode 100644 index 0000000..dc984c8 --- /dev/null +++ b/resnet-burn/resnet/Cargo.toml @@ -0,0 +1,17 @@ +[package] +authors = ["guillaumelagrange "] +name = "resnet-burn" +version = "0.1.0" +edition.workspace = true +license.workspace = true + +[features] +default = [] +std = [] +pretrained = ["burn/network", "std", "dep:dirs"] + +[dependencies] +burn = { workspace = true } +burn-import = { workspace = true } +dirs = { workspace = true, optional = true } +serde = { workspace = true } diff --git a/resnet-burn/src/model/block.rs b/resnet-burn/resnet/src/block.rs similarity index 100% rename from resnet-burn/src/model/block.rs rename to resnet-burn/resnet/src/block.rs diff --git a/resnet-burn/resnet/src/lib.rs b/resnet-burn/resnet/src/lib.rs new file mode 100644 index 0000000..1b34e41 --- /dev/null +++ b/resnet-burn/resnet/src/lib.rs @@ -0,0 +1,9 @@ +#![cfg_attr(not(feature = "std"), no_std)] +mod block; +pub mod resnet; +pub mod weights; + +pub use resnet::*; +pub use weights::*; + +extern crate alloc; diff --git a/resnet-burn/src/model/resnet.rs b/resnet-burn/resnet/src/resnet.rs similarity index 97% rename from resnet-burn/src/model/resnet.rs rename to resnet-burn/resnet/src/resnet.rs index 43b6e85..e7f8787 100644 --- a/resnet-burn/src/model/resnet.rs +++ b/resnet-burn/resnet/src/resnet.rs @@ -95,7 +95,7 @@ impl ResNet { ) -> Result { let weights = weights.weights(); let record = Self::load_weights_record(&weights, device)?; - let model = ResNet::::resnet18(weights.num_classes, &device).load_record(record); + let model = ResNet::::resnet18(weights.num_classes, device).load_record(record); Ok(model) } @@ -132,7 +132,7 @@ impl ResNet { ) -> Result { let weights = weights.weights(); let record = Self::load_weights_record(&weights, device)?; - let model = ResNet::::resnet34(weights.num_classes, &device).load_record(record); + let model = ResNet::::resnet34(weights.num_classes, device).load_record(record); Ok(model) } @@ -169,7 +169,7 @@ impl ResNet { ) -> Result { let weights = weights.weights(); let record = Self::load_weights_record(&weights, device)?; - let model = ResNet::::resnet50(weights.num_classes, &device).load_record(record); + let model = ResNet::::resnet50(weights.num_classes, device).load_record(record); Ok(model) } @@ -206,7 +206,7 @@ impl ResNet { ) -> Result { let weights = weights.weights(); let record = Self::load_weights_record(&weights, device)?; - let model = ResNet::::resnet101(weights.num_classes, &device).load_record(record); + let model = ResNet::::resnet101(weights.num_classes, device).load_record(record); Ok(model) } @@ -243,7 +243,7 @@ impl ResNet { ) -> Result { let weights = weights.weights(); let record = Self::load_weights_record(&weights, device)?; - let model = ResNet::::resnet152(weights.num_classes, &device).load_record(record); + let model = ResNet::::resnet152(weights.num_classes, device).load_record(record); Ok(model) } diff --git a/resnet-burn/src/model/weights.rs b/resnet-burn/resnet/src/weights.rs similarity index 100% rename from resnet-burn/src/model/weights.rs rename to resnet-burn/resnet/src/weights.rs diff --git a/resnet-burn/samples/dataset.jpg b/resnet-burn/samples/dataset.jpg new file mode 100755 index 0000000..954e6d1 Binary files /dev/null and b/resnet-burn/samples/dataset.jpg differ diff --git a/resnet-burn/src/lib.rs b/resnet-burn/src/lib.rs deleted file mode 100644 index c85bcc1..0000000 --- a/resnet-burn/src/lib.rs +++ /dev/null @@ -1,3 +0,0 @@ -#![cfg_attr(not(feature = "std"), no_std)] -pub mod model; -extern crate alloc; diff --git a/resnet-burn/src/model/mod.rs b/resnet-burn/src/model/mod.rs deleted file mode 100644 index b9e16f1..0000000 --- a/resnet-burn/src/model/mod.rs +++ /dev/null @@ -1,4 +0,0 @@ -mod block; -pub mod imagenet; -pub mod resnet; -pub mod weights;