Add YOLOX object detection (#24)
* Add yolox base models

* Fix 2d grid for anchors

* Change sample image

* Add post-processing and inference results

* Cleanup

* Fix 2d grid repeat with dim=1

* Default to ndarray backend

* Switch to YOLOX-tiny for example

* Remove dead comment

* Use tensor.dims()

* Use the new tensor.permute()

* Fix comments

- Pre-trained weights are from COCO (README)
- Remove training outputs TODO
- Current example uses YOLOX-Tiny (Nano WIP)

* Add YOLOX-Nano w/ depthwise separable conv (enum)

* Remove dead code comment

* Remove incorrect return comment

* Add YOLOX to models README

* Fix dead comments and add enum variants doc

* Rephrase enum variants doc

* Change burn-rs -> tracel-ai links

* Upgrade to Burn 0.13.0

- Removed init_with methods
- Fixed empty MaxPool2d vec initialization

* Update image version
laggui committed Apr 25, 2024
1 parent b3abbfd commit 0f9c6c6
Showing 19 changed files with 1,885 additions and 6 deletions.
13 changes: 7 additions & 6 deletions README.md
@@ -5,12 +5,13 @@ examples constructed using the [Burn](https://github.com/burn-rs/burn) deep lear

## Collection of Official Models

-| Model | Description | Repository Link |
-|-------------------------------------------------|-------------------------------------------------------|----------------------------------------------|
-| [MobileNetV2](https://arxiv.org/abs/1801.04381) | A CNN model targeted at mobile devices. | [mobilenetv2-burn](mobilenetv2-burn/README.md) |
-| [SqueezeNet](https://arxiv.org/abs/1602.07360) | A small CNN-based model for image classification. | [squeezenet-burn](squeezenet-burn/README.md) |
-| [ResNet](https://arxiv.org/abs/1512.03385) | A CNN based on residual blocks with skip connections. | [resnet-burn](resnet-burn/README.md) |
-| [RoBERTa](https://arxiv.org/abs/1907.11692) | A robustly optimized BERT pretraining approach. | [bert-burn](bert-burn/README.md) |
+| Model | Description | Repository Link |
+|-------------------------------------------------|----------------------------------------------------------|------------------------------------------------|
+| [MobileNetV2](https://arxiv.org/abs/1801.04381) | A CNN model targeted at mobile devices. | [mobilenetv2-burn](mobilenetv2-burn/README.md) |
+| [SqueezeNet](https://arxiv.org/abs/1602.07360) | A small CNN-based model for image classification. | [squeezenet-burn](squeezenet-burn/README.md) |
+| [ResNet](https://arxiv.org/abs/1512.03385) | A CNN based on residual blocks with skip connections. | [resnet-burn](resnet-burn/README.md) |
+| [RoBERTa](https://arxiv.org/abs/1907.11692) | A robustly optimized BERT pretraining approach. | [bert-burn](bert-burn/README.md) |
+| [YOLOX](https://arxiv.org/abs/2107.08430) | A single-stage object detector based on the YOLO series. | [yolox-burn](yolox-burn/README.md) |

## Community Contributions

2 changes: 2 additions & 0 deletions yolox-burn/.gitignore
@@ -0,0 +1,2 @@
# Output image
*.output.png
28 changes: 28 additions & 0 deletions yolox-burn/Cargo.toml
@@ -0,0 +1,28 @@
[package]
authors = ["guillaumelagrange <[email protected]>"]
license = "MIT OR Apache-2.0"
name = "yolox-burn"
version = "0.1.0"
edition = "2021"

[features]
default = []
std = []
pretrained = ["burn/network", "std", "dep:dirs"]

[dependencies]
# Note: default-features = false is needed to disable std
burn = { version = "0.13.0", default-features = false }
burn-import = { version = "0.13.0" }
itertools = { version = "0.12.1", default-features = false, features = [
    "use_alloc",
] }
dirs = { version = "5.0.1", optional = true }
serde = { version = "1.0.192", default-features = false, features = [
    "derive",
    "alloc",
] } # alloc is for no_std, derive is needed

[dev-dependencies]
burn = { version = "0.13.0", features = ["ndarray"] }
image = { version = "0.24.9", features = ["png", "jpeg"] }
1 change: 1 addition & 0 deletions yolox-burn/LICENSE-APACHE
1 change: 1 addition & 0 deletions yolox-burn/LICENSE-MIT
16 changes: 16 additions & 0 deletions yolox-burn/NOTICES.md
@@ -0,0 +1,16 @@
# NOTICES AND INFORMATION

This file contains notices and information required by libraries that this repository copied or derived from. The use of the following resources complies with the licenses provided.

## Sample Image

Image Title: Man with Bike and Pet Dog circa 1900 (archive ref DDX1319-2-3)
Author: East Riding Archives
Source: https://commons.wikimedia.org/wiki/File:Man_with_Bike_and_Pet_Dog_circa_1900_%28archive_ref_DDX1319-2-3%29_%2826507570321%29.jpg
License: [Creative Commons](https://www.flickr.com/commons/usage/)

## Pre-trained Model

The COCO pre-trained model was ported from the original [YOLOX implementation](https://github.com/Megvii-BaseDetection/YOLOX).

As opposed to other YOLO variants (YOLOv8, YOLO-NAS, etc.), both the code and pre-trained weights are distributed under the [Apache 2.0](https://github.com/Megvii-BaseDetection/YOLOX/blob/main/LICENSE) open source license.
44 changes: 44 additions & 0 deletions yolox-burn/README.md
@@ -0,0 +1,44 @@
# YOLOX Burn

Many object detection models with the YOLO prefix have been released in recent years, but most of
them carry a GPL or AGPL license, which restricts their usage. For this reason, we selected
[YOLOX](https://arxiv.org/abs/2107.08430) as the first object detection architecture, since both
the original code and pre-trained weights are released under the
[Apache 2.0](https://github.com/Megvii-BaseDetection/YOLOX/blob/main/LICENSE) open source license.

You can find the [Burn](https://github.com/tracel-ai/burn) implementation for the YOLOX variants in
[src/model/yolox.rs](src/model/yolox.rs).

The model is [no_std compatible](https://docs.rust-embedded.org/book/intro/no-std.html).

## Usage

### `Cargo.toml`

Add this to your `Cargo.toml`:

```toml
[dependencies]
yolox-burn = { git = "https://github.com/tracel-ai/models", package = "yolox-burn", default-features = false }
```

If you want to get the COCO pre-trained weights, enable the `pretrained` feature flag.

```toml
[dependencies]
yolox-burn = { git = "https://github.com/tracel-ai/models", package = "yolox-burn", features = ["pretrained"] }
```

**Important:** this feature requires `std`.

### Example Usage

The [inference example](examples/inference.rs) initializes a YOLOX-Tiny from the COCO
[pre-trained weights](https://github.com/Megvii-BaseDetection/YOLOX?tab=readme-ov-file#standard-models)
with the `NdArray` backend and performs inference on the provided input image.

You can run the example with the following command:

```sh
cargo run --release --features pretrained --example inference samples/dog_bike_man.jpg
```
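
The example saves its annotated result next to the input image, replacing the extension with
`output.png` (hence the `*.output.png` entry in `.gitignore`). As a minimal standalone sketch of
that naming, using only the standard library (`output_path` is a hypothetical helper name, not
part of this crate):

```rust
use std::path::{Path, PathBuf};

/// Hypothetical helper: derives the output file name the same way the
/// example does, by swapping the input extension for `output.png`.
fn output_path(input: &str) -> PathBuf {
    Path::new(input).with_extension("output.png")
}
```

For instance, `output_path("samples/dog_bike_man.jpg")` yields `samples/dog_bike_man.output.png`.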
145 changes: 145 additions & 0 deletions yolox-burn/examples/inference.rs
@@ -0,0 +1,145 @@
use std::path::Path;

use image::{DynamicImage, ImageBuffer};
use yolox_burn::model::{boxes::nms, weights, yolox::Yolox, BoundingBox};

use burn::{
    backend::NdArray,
    tensor::{backend::Backend, Data, Device, Element, Shape, Tensor},
};

const HEIGHT: usize = 640;
const WIDTH: usize = 640;

fn to_tensor<B: Backend, T: Element>(
    data: Vec<T>,
    shape: [usize; 3],
    device: &Device<B>,
) -> Tensor<B, 3> {
    Tensor::<B, 3>::from_data(Data::new(data, Shape::new(shape)).convert(), device)
        // [H, W, C] -> [C, H, W]
        .permute([2, 0, 1])
}

/// Draws bounding boxes on the given image.
///
/// # Arguments
///
/// * `image` - Original input image.
/// * `boxes` - Bounding boxes, grouped per class.
/// * `color` - [R, G, B] color values to draw the boxes.
/// * `ratio` - [x, y] aspect ratio to scale the predicted boxes.
///
/// # Returns
///
/// The image annotated with bounding boxes.
fn draw_boxes(
    image: DynamicImage,
    boxes: &[Vec<BoundingBox>],
    color: &[u8; 3],
    ratio: &[f32; 2], // (x, y) ratio
) -> DynamicImage {
    // Assumes x1 <= x2 and y1 <= y2
    fn draw_rect(
        image: &mut ImageBuffer<image::Rgb<u8>, Vec<u8>>,
        x1: u32,
        x2: u32,
        y1: u32,
        y2: u32,
        color: &[u8; 3],
    ) {
        for x in x1..=x2 {
            let pixel = image.get_pixel_mut(x, y1);
            *pixel = image::Rgb(*color);
            let pixel = image.get_pixel_mut(x, y2);
            *pixel = image::Rgb(*color);
        }
        for y in y1..=y2 {
            let pixel = image.get_pixel_mut(x1, y);
            *pixel = image::Rgb(*color);
            let pixel = image.get_pixel_mut(x2, y);
            *pixel = image::Rgb(*color);
        }
    }

    // Annotate the original image and print boxes information.
    let (image_h, image_w) = (image.height(), image.width());
    let mut image = image.to_rgb8();
    for (class_index, bboxes_for_class) in boxes.iter().enumerate() {
        for b in bboxes_for_class.iter() {
            let xmin = (b.xmin * ratio[0]).clamp(0., image_w as f32 - 1.);
            let ymin = (b.ymin * ratio[1]).clamp(0., image_h as f32 - 1.);
            let xmax = (b.xmax * ratio[0]).clamp(0., image_w as f32 - 1.);
            let ymax = (b.ymax * ratio[1]).clamp(0., image_h as f32 - 1.);

            println!(
                "Predicted {} ({:.2}) at [{:.2}, {:.2}, {:.2}, {:.2}]",
                class_index, b.confidence, xmin, ymin, xmax, ymax,
            );

            draw_rect(
                &mut image,
                xmin as u32,
                xmax as u32,
                ymin as u32,
                ymax as u32,
                color,
            );
        }
    }
    DynamicImage::ImageRgb8(image)
}

pub fn main() {
    // Parse arguments
    let img_path = std::env::args().nth(1).expect("No image path provided");

    // Create YOLOX-Tiny
    let device = Default::default();
    let model: Yolox<NdArray> = Yolox::yolox_tiny_pretrained(weights::YoloxTiny::Coco, &device)
        .map_err(|err| format!("Failed to load pre-trained weights.\nError: {err}"))
        .unwrap();

    // Load image
    let img = image::open(&img_path)
        .map_err(|err| format!("Failed to load image {img_path}.\nError: {err}"))
        .unwrap();

    // Resize to 640x640
    let resized_img = img.resize_exact(
        WIDTH as u32,
        HEIGHT as u32,
        image::imageops::FilterType::Triangle, // also known as bilinear in 2D
    );

    // Create tensor from image data
    let x = to_tensor(
        resized_img.into_rgb8().into_raw(),
        [HEIGHT, WIDTH, 3],
        &device,
    )
    .unsqueeze::<4>(); // [B, C, H, W]

    // Forward pass
    let out = model.forward(x);

    // Post-processing
    let [_, num_boxes, num_outputs] = out.dims();
    let boxes = out.clone().slice([0..1, 0..num_boxes, 0..4]);
    let obj_scores = out.clone().slice([0..1, 0..num_boxes, 4..5]);
    let cls_scores = out.slice([0..1, 0..num_boxes, 5..num_outputs]);
    let scores = cls_scores * obj_scores;
    let boxes = nms(boxes, scores, 0.65, 0.5);

    // Draw outputs and save results
    let (h, w) = (img.height(), img.width());
    let img_out = draw_boxes(
        img,
        &boxes[0],
        &[239u8, 62u8, 5u8],
        &[w as f32 / WIDTH as f32, h as f32 / HEIGHT as f32],
    );

    let img_path = Path::new(&img_path);
    let _ = img_out.save(img_path.with_extension("output.png"));
}
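
The post-processing step above multiplies class scores by objectness and then calls `nms`, whose
implementation is not shown in this excerpt. As a rough illustration of what greedy non-maximum
suppression generally does, here is a standalone sketch in plain Rust. The `Detection` struct,
`iou`, and `greedy_nms` are hypothetical names, and reading `0.65` as the IoU threshold and `0.5`
as the score threshold is an assumption; the crate's `nms` works on tensors and may differ in
detail.

```rust
/// Hypothetical box type for this sketch (not the crate's `BoundingBox`).
#[derive(Clone, Copy, Debug, PartialEq)]
struct Detection {
    xmin: f32,
    ymin: f32,
    xmax: f32,
    ymax: f32,
    score: f32,
}

/// Intersection over union of two axis-aligned boxes.
fn iou(a: &Detection, b: &Detection) -> f32 {
    // Overlap extents, clamped to zero when the boxes do not intersect.
    let w = (a.xmax.min(b.xmax) - a.xmin.max(b.xmin)).max(0.0);
    let h = (a.ymax.min(b.ymax) - a.ymin.max(b.ymin)).max(0.0);
    let inter = w * h;
    let area_a = (a.xmax - a.xmin) * (a.ymax - a.ymin);
    let area_b = (b.xmax - b.xmin) * (b.ymax - b.ymin);
    let union = area_a + area_b - inter;
    if union > 0.0 {
        inter / union
    } else {
        0.0
    }
}

/// Greedy NMS for a single class: drop low scores, then keep the highest
/// scoring boxes that do not overlap an already kept box too much.
fn greedy_nms(
    mut dets: Vec<Detection>,
    iou_threshold: f32,
    score_threshold: f32,
) -> Vec<Detection> {
    dets.retain(|d| d.score >= score_threshold);
    // Sort by descending score.
    dets.sort_by(|a, b| b.score.total_cmp(&a.score));

    let mut kept: Vec<Detection> = Vec::new();
    for d in dets {
        if kept.iter().all(|k| iou(k, &d) <= iou_threshold) {
            kept.push(d);
        }
    }
    kept
}
```

With thresholds of 0.65 and 0.5, two boxes that overlap with an IoU of about 0.68 collapse to the
higher-scoring one, while a distant box above the score threshold is kept.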
Binary file added yolox-burn/samples/dog_bike_man.jpg
3 changes: 3 additions & 0 deletions yolox-burn/src/lib.rs
@@ -0,0 +1,3 @@
#![cfg_attr(not(feature = "std"), no_std)]
pub mod model;
extern crate alloc;