Skip to content

Commit

Permalink
Merge branch 'main' into yolox
Browse files Browse the repository at this point in the history
  • Loading branch information
laggui committed Apr 25, 2024
2 parents 5e793df + b3abbfd commit e3d35e8
Show file tree
Hide file tree
Showing 20 changed files with 1,853 additions and 304 deletions.
13 changes: 7 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@ examples constructed using the [Burn](https://github.com/burn-rs/burn) deep lear

## Collection of Official Models

| Model | Description | Repository Link |
|------------------------------------------------|----------------------------------------------------------|----------------------------------------------|
| [SqueezeNet](https://arxiv.org/abs/1602.07360) | A small CNN-based model for image classification. | [squeezenet-burn](squeezenet-burn/README.md) |
| [ResNet](https://arxiv.org/abs/1512.03385) | A CNN based on residual blocks with skip connections. | [resnet-burn](resnet-burn/README.md) |
| [RoBERTa](https://arxiv.org/abs/1907.11692) | A robustly optimized BERT pretraining approach. | [bert-burn](bert-burn/README.md) |
| [YOLOX](https://arxiv.org/abs/2107.08430) | A single-stage object detector based on the YOLO series. | [yolox-burn](yolox-burn/README.md) |
| Model | Description | Repository Link |
|-------------------------------------------------|----------------------------------------------------------|------------------------------------------------|
| [MobileNetV2](https://arxiv.org/abs/1801.04381) | A CNN model targeted at mobile devices. | [mobilenetv2-burn](mobilenetv2-burn/README.md) |
| [SqueezeNet](https://arxiv.org/abs/1602.07360) | A small CNN-based model for image classification. | [squeezenet-burn](squeezenet-burn/README.md) |
| [ResNet](https://arxiv.org/abs/1512.03385) | A CNN based on residual blocks with skip connections. | [resnet-burn](resnet-burn/README.md) |
| [RoBERTa](https://arxiv.org/abs/1907.11692) | A robustly optimized BERT pretraining approach. | [bert-burn](bert-burn/README.md) |
| [YOLOX](https://arxiv.org/abs/2107.08430) | A single-stage object detector based on the YOLO series. | [yolox-burn](yolox-burn/README.md) |

## Community Contributions

Expand Down
2 changes: 1 addition & 1 deletion bert-burn/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ Add this to your `Cargo.toml`:

```toml
[dependencies]
bert-burn = { git = "https://github.com/burn-rs/models", package = "bert-burn", default-features = false }
bert-burn = { git = "https://github.com/tracel-ai/models", package = "bert-burn", default-features = false }
```

## Example Usage
Expand Down
25 changes: 25 additions & 0 deletions mobilenetv2-burn/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
[package]
authors = ["Arjun31415", "guillaumelagrange <[email protected]>"]
license = "MIT OR Apache-2.0"
name = "mobilenetv2-burn"
version = "0.1.0"
edition = "2021"

[features]
default = []
std = []
pretrained = ["burn/network", "std", "dep:dirs"]

[dependencies]
# Note: default-features = false is needed to disable std
burn = { version = "0.13.0" }
burn-import = { version = "0.13.0" }
dirs = { version = "5.0.1", optional = true }
serde = { version = "1.0.192", default-features = false, features = [
"derive",
"alloc",
] } # alloc is for no_std, derive is needed

[dev-dependencies]
burn = { version = "0.13.0", features = ["ndarray"] }
image = { version = "0.24.9", features = ["png", "jpeg"] }
16 changes: 16 additions & 0 deletions mobilenetv2-burn/NOTICES.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# NOTICES AND INFORMATION

This file contains notices and information required by libraries that this repository copied or derived from. The use of the following resources complies with the licenses provided.

## Sample Image

Image Title: Standing yellow Labrador Retriever dog.
Author: Djmirko
Source: https://commons.wikimedia.org/wiki/File:YellowLabradorLooking_new.jpg
License: https://creativecommons.org/licenses/by-sa/3.0/

## Pre-trained Model

The ImageNet pre-trained model was ported from [`torchvision.models.MobileNet_V2_Weights.IMAGENET1K_V2`](https://pytorch.org/vision/main/models/generated/torchvision.models.mobilenet_v2.html#torchvision.models.MobileNet_V2_Weights).

As opposed to [other pre-trained models](https://pytorch.org/vision/stable/models/generated/torchvision.models.regnet_y_128gf.html#torchvision.models.RegNet_Y_128GF_Weights) in `torchvision`, no specific license was linked to the weights, which are assumed to be under the library's [BSD-3-Clause license](https://github.com/pytorch/vision/blob/main/LICENSE) ([ref](https://github.com/pytorch/vision/issues/160)).
40 changes: 40 additions & 0 deletions mobilenetv2-burn/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# MobileNetV2 Burn

[MobileNetV2](https://arxiv.org/abs/1801.04381) is a convolutional neural network architecture for
classification tasks which seeks to perform well on mobile devices. You can find the
[Burn](https://github.com/tracel-ai/burn) implementation for the MobileNetV2 in
[src/model/mobilenetv2.rs](src/model/mobilenetv2.rs).

The model is [no_std compatible](https://docs.rust-embedded.org/book/intro/no-std.html).

## Usage

### `Cargo.toml`

Add this to your `Cargo.toml`:

```toml
[dependencies]
mobilenetv2-burn = { git = "https://github.com/tracel-ai/models", package = "mobilenetv2-burn", default-features = false }
```

If you want to get the pre-trained ImageNet weights, enable the `pretrained` feature flag.

```toml
[dependencies]
mobilenetv2-burn = { git = "https://github.com/tracel-ai/models", package = "mobilenetv2-burn", features = ["pretrained"] }
```

**Important:** this feature requires `std`.

### Example Usage

The [inference example](examples/inference.rs) initializes a MobileNetV2 from the ImageNet
[pre-trained weights](https://pytorch.org/vision/main/models/generated/torchvision.models.mobilenet_v2.html#torchvision.models.MobileNet_V2_Weights)
with the `NdArray` backend and performs inference on the provided input image.

You can run the example with the following command:

```sh
cargo run --release --features pretrained --example inference samples/dog.jpg
```
69 changes: 69 additions & 0 deletions mobilenetv2-burn/examples/inference.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
use mobilenetv2_burn::model::{imagenet, mobilenetv2::MobileNetV2, weights};

use burn::{
backend::NdArray,
tensor::{backend::Backend, Data, Device, Element, Shape, Tensor},
};

const HEIGHT: usize = 224;
const WIDTH: usize = 224;

fn to_tensor<B: Backend, T: Element>(
data: Vec<T>,
shape: [usize; 3],
device: &Device<B>,
) -> Tensor<B, 3> {
Tensor::<B, 3>::from_data(Data::new(data, Shape::new(shape)).convert(), device)
// [H, W, C] -> [C, H, W]
.permute([2, 0, 1])
/ 255 // normalize between [0, 1]
}

pub fn main() {
// Parse arguments
let img_path = std::env::args().nth(1).expect("No image path provided");

// Create MobileNetV2
let device = Default::default();
let model: MobileNetV2<NdArray> =
MobileNetV2::pretrained(weights::MobileNetV2::ImageNet1kV2, &device)
.map_err(|err| format!("Failed to load pre-trained weights.\nError: {err}"))
.unwrap();

// Load image
let img = image::open(&img_path)
.map_err(|err| format!("Failed to load image {img_path}.\nError: {err}"))
.unwrap();

// Resize to 224x224
let resized_img = img.resize_exact(
WIDTH as u32,
HEIGHT as u32,
image::imageops::FilterType::Triangle, // also known as bilinear in 2D
);

// Create tensor from image data
let img_tensor = to_tensor(
resized_img.into_rgb8().into_raw(),
[HEIGHT, WIDTH, 3],
&device,
)
.unsqueeze::<4>(); // [B, C, H, W]

// Normalize the image
let x = imagenet::Normalizer::new(&device).normalize(img_tensor);

// Forward pass
let out = model.forward(x);

// Output class index w/ score (raw)
let (score, idx) = out.max_dim_with_indices(1);
let idx = idx.into_scalar() as usize;

println!(
"Predicted: {}\nCategory Id: {}\nScore: {:.4}",
imagenet::CLASSES[idx],
idx,
score.into_scalar()
);
}
Binary file added mobilenetv2-burn/samples/dog.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
3 changes: 3 additions & 0 deletions mobilenetv2-burn/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
#![cfg_attr(not(feature = "std"), no_std)]
pub mod model;
extern crate alloc;
83 changes: 83 additions & 0 deletions mobilenetv2-burn/src/model/conv_norm.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
use burn::{
config::Config,
module::Module,
nn::{
conv::{Conv2d, Conv2dConfig},
BatchNorm, BatchNormConfig, PaddingConfig2d,
},
tensor::{self, backend::Backend, Tensor},
};

/// A rectified linear unit where the activation is limited to a maximum of 6.
#[derive(Module, Debug, Clone, Default)]
pub struct ReLU6 {}
impl ReLU6 {
pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
tensor::activation::relu(input).clamp_max(6)
}
}

/// A Conv2d -> BatchNorm -> activation block.
#[derive(Module, Debug)]
pub struct Conv2dNormActivation<B: Backend> {
conv: Conv2d<B>,
norm: BatchNorm<B, 2>,
activation: ReLU6,
}

/// [Conv2dNormActivation] configuration.
#[derive(Config, Debug)]
pub struct Conv2dNormActivationConfig {
pub in_channels: usize,
pub out_channels: usize,

#[config(default = "3")]
pub kernel_size: usize,

#[config(default = "1")]
pub stride: usize,

#[config(default = "None")]
pub padding: Option<usize>,

#[config(default = "1")]
pub groups: usize,

#[config(default = "1")]
pub dilation: usize,

#[config(default = false)]
pub bias: bool,
}

impl Conv2dNormActivationConfig {
pub fn init<B: Backend>(&self, device: &B::Device) -> Conv2dNormActivation<B> {
let padding = if let Some(padding) = self.padding {
padding
} else {
(self.kernel_size - 1) / 2 * self.dilation
};

Conv2dNormActivation {
conv: Conv2dConfig::new(
[self.in_channels, self.out_channels],
[self.kernel_size, self.kernel_size],
)
.with_padding(PaddingConfig2d::Explicit(padding, padding))
.with_stride([self.stride, self.stride])
.with_bias(self.bias)
.with_dilation([self.dilation, self.dilation])
.with_groups(self.groups)
.init(device),
norm: BatchNormConfig::new(self.out_channels).init(device),
activation: ReLU6 {},
}
}
}
impl<B: Backend> Conv2dNormActivation<B> {
pub fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
let x = self.conv.forward(input);
let x = self.norm.forward(x);
self.activation.forward(x)
}
}
Loading

0 comments on commit e3d35e8

Please sign in to comment.