1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default.

3 changes: 2 additions & 1 deletion crates/burn-vision/Cargo.toml
@@ -20,7 +20,7 @@ workspace = true

[features]
default = ["ndarray", "cubecl-backend", "fusion", "std"]
std = []
std = ["aligned-vec/std"]
tracing = [
"burn-candle?/tracing",
"burn-cubecl?/tracing",
@@ -46,6 +46,7 @@ test-vulkan = ["burn-wgpu/vulkan", "test-wgpu"]
test-metal = ["burn-wgpu/metal", "test-wgpu"]

[dependencies]
aligned-vec = { version = "0.6", default-features = false }
bon = { workspace = true }
burn-candle = { path = "../burn-candle", version = "=0.20.0-pre.6", optional = true }
burn-cubecl = { path = "../burn-cubecl", version = "=0.20.0-pre.6", optional = true }
2 changes: 2 additions & 0 deletions crates/burn-vision/src/backends/cpu/mod.rs
@@ -1,8 +1,10 @@
mod base;
mod connected_components;
mod morphology;
mod nms;
mod ops;

pub use base::*;
pub use connected_components::*;
pub use morphology::*;
pub use nms::*;
212 changes: 212 additions & 0 deletions crates/burn-vision/src/backends/cpu/nms.rs
@@ -0,0 +1,212 @@
use crate::NmsOptions;
use aligned_vec::{AVec, ConstAlign};
use alloc::vec::Vec;
use burn_tensor::{Int, Shape, Tensor, TensorData, backend::Backend};
use macerator::{Scalar, Simd, Vector, vload};

/// Perform NMS on CPU using SIMD acceleration.
///
/// This implementation:
/// 1. Sorts boxes by score (descending)
/// 2. Iteratively selects the highest-scoring non-suppressed box
/// 3. Suppresses all boxes with IoU > threshold using SIMD
pub fn nms<B: Backend>(
boxes: Tensor<B, 2>,
scores: Tensor<B, 1>,
options: NmsOptions,
) -> Tensor<B, 1, Int> {
let device = boxes.device();
let [n_boxes, _] = boxes.shape().dims();
if n_boxes == 0 {
return Tensor::<B, 1, Int>::empty([0], &device);
}

// Get raw data
let boxes_data = boxes.to_data();
let boxes_vec: Vec<f32> = boxes_data.to_vec().unwrap();

let scores_data = scores.to_data();
let scores_vec: Vec<f32> = scores_data.to_vec().unwrap();

let keep = nms_vec(boxes_vec, scores_vec, options);
let n_kept = keep.len();
let indices_data = TensorData::new(keep, Shape::new([n_kept]));
Tensor::<B, 1, Int>::from_data(indices_data, &device)
}

/// Perform NMS on CPU using SIMD acceleration.
fn nms_vec(boxes_vec: Vec<f32>, scores_vec: Vec<f32>, options: NmsOptions) -> Vec<i32> {
let n_boxes = scores_vec.len();

if n_boxes == 0 {
return vec![];
}

// Filter by score threshold first
let mut filtered_indices = Vec::with_capacity(n_boxes);
for (i, &score) in scores_vec.iter().enumerate() {
if score >= options.score_threshold {
filtered_indices.push(i); // original index
}
}

let n_filtered = filtered_indices.len();
if n_filtered == 0 {
return vec![];
}

// Sort by score descending
filtered_indices.sort_by(|&a, &b| scores_vec[b].total_cmp(&scores_vec[a]));

const ALIGN: usize = 64;
const FLOATS_PER_ALIGN: usize = ALIGN / size_of::<f32>(); // 16
let stride = n_filtered.div_ceil(FLOATS_PER_ALIGN) * FLOATS_PER_ALIGN;
let mut buf: AVec<f32, ConstAlign<64>> = AVec::with_capacity(ALIGN, stride * 5);
buf.resize(stride * 5, 0.0);

let (x1s, rest) = buf.split_at_mut(stride);
let (y1s, rest) = rest.split_at_mut(stride);
let (x2s, rest) = rest.split_at_mut(stride);
let (y2s, areas) = rest.split_at_mut(stride);

// Convert filtered boxes to SoA format
for (j, &orig_idx) in filtered_indices.iter().enumerate() {
let x1 = boxes_vec[orig_idx * 4];
let y1 = boxes_vec[orig_idx * 4 + 1];
let x2 = boxes_vec[orig_idx * 4 + 2];
let y2 = boxes_vec[orig_idx * 4 + 3];
x1s[j] = x1;
y1s[j] = y1;
x2s[j] = x2;
y2s[j] = y2;
areas[j] = (x2 - x1) * (y2 - y1);
}

// Apply NMS with SIMD dispatch
let mut suppressed = vec![false; stride];
let mut keep = Vec::new();

for i in 0..n_filtered {
if suppressed[i] {
continue;
}

// Mark the kept box as suppressed too, so the all-suppressed fast path below can skip its chunk
suppressed[i] = true;
keep.push(filtered_indices[i] as i32); // original index

if options.max_output_boxes > 0 && keep.len() >= options.max_output_boxes {
break;
}

// Suppress overlapping boxes using SIMD
suppress_overlapping(
x1s[i],
y1s[i],
x2s[i],
y2s[i],
areas[i],
x1s,
y1s,
x2s,
y2s,
areas,
&mut suppressed,
stride,
options.iou_threshold,
);
}

keep
}

/// SIMD-accelerated suppression of overlapping boxes.
#[allow(clippy::too_many_arguments)]
#[inline(always)]
#[macerator::with_simd]
fn suppress_overlapping<'a, S: Simd>(
ref_x1: f32,
ref_y1: f32,
ref_x2: f32,
ref_y2: f32,
ref_area: f32,
x1s: &'a [f32],
y1s: &'a [f32],
x2s: &'a [f32],
y2s: &'a [f32],
areas: &'a [f32],
suppressed: &'a mut [bool],
n_boxes: usize, // stride, always multiple of lanes
threshold: f32,
) where
'a: 'a,
{
let lanes = f32::lanes::<S>();

// Splat reference values
let ref_x1_v: Vector<S, f32> = ref_x1.splat();
let ref_y1_v: Vector<S, f32> = ref_y1.splat();
let ref_x2_v: Vector<S, f32> = ref_x2.splat();
let ref_y2_v: Vector<S, f32> = ref_y2.splat();
let ref_area_v: Vector<S, f32> = ref_area.splat();
let thresh_v: Vector<S, f32> = threshold.splat();
let zero_v: Vector<S, f32> = 0.0f32.splat();

let mut i = 0;

let mut mask_buf = core::mem::MaybeUninit::<[bool; 16]>::uninit();
// Process `lanes` boxes at a time with SIMD
while i + lanes <= n_boxes {
// Skip if all boxes in this chunk are already suppressed
let all_suppressed = unsafe {
match lanes {
// Unaligned reads: the Vec<bool> backing `suppressed` only guarantees 1-byte alignment.
4 => (suppressed.as_ptr().add(i) as *const u32).read_unaligned() == 0x01010101,
8 => (suppressed.as_ptr().add(i) as *const u64).read_unaligned() == 0x0101010101010101,
16 => {
(suppressed.as_ptr().add(i) as *const u128).read_unaligned()
== 0x01010101010101010101010101010101
}
}
_ => unreachable!(),
}
};

if !all_suppressed {
let x1_v: Vector<S, f32> = unsafe { vload(x1s.as_ptr().add(i)) };
let y1_v: Vector<S, f32> = unsafe { vload(y1s.as_ptr().add(i)) };
let x2_v: Vector<S, f32> = unsafe { vload(x2s.as_ptr().add(i)) };
let y2_v: Vector<S, f32> = unsafe { vload(y2s.as_ptr().add(i)) };
let area_v: Vector<S, f32> = unsafe { vload(areas.as_ptr().add(i)) };

// Compute intersection coordinates
let xx1 = ref_x1_v.max(x1_v);
let yy1 = ref_y1_v.max(y1_v);
let xx2 = ref_x2_v.min(x2_v);
let yy2 = ref_y2_v.min(y2_v);

// Compute intersection area (clamp to 0 for non-overlapping)
let w = (xx2 - xx1).max(zero_v);
let h = (yy2 - yy1).max(zero_v);
let inter = w * h;

// Compute IoU
let union = ref_area_v + area_v - inter;
let iou = inter / union;

// Get suppression mask (IoU > threshold)
let suppress_mask = iou.gt(thresh_v);

// Extract mask to bool array and apply to suppressed
// SAFETY: mask_store_as_bool writes exactly `lanes` bools, we only read 0..lanes
unsafe { f32::mask_store_as_bool::<S>(mask_buf.as_mut_ptr().cast(), suppress_mask) };
let mask_buf = unsafe { mask_buf.assume_init() };

for k in 0..lanes {
if mask_buf[k] {
suppressed[i + k] = true;
}
}
}

i += lanes;
}
}
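
As a reading aid for the SIMD kernel above, here is a minimal scalar sketch of the same IoU test and greedy suppression loop. It is not part of the diff; `iou` and `nms_scalar` are hypothetical helper names, boxes are assumed to be in (x1, y1, x2, y2) format as in the PR, and the score filter, output cap, and SoA layout are omitted.

```rust
/// Scalar reference for the IoU computed per lane in `suppress_overlapping`.
fn iou(a: [f32; 4], b: [f32; 4]) -> f32 {
    // Intersection rectangle, clamped to zero width/height when the boxes do not overlap.
    let w = (a[2].min(b[2]) - a[0].max(b[0])).max(0.0);
    let h = (a[3].min(b[3]) - a[1].max(b[1])).max(0.0);
    let inter = w * h;
    let area_a = (a[2] - a[0]) * (a[3] - a[1]);
    let area_b = (b[2] - b[0]) * (b[3] - b[1]);
    inter / (area_a + area_b - inter)
}

/// Scalar greedy NMS: `order` holds box indices sorted by descending score.
fn nms_scalar(boxes: &[[f32; 4]], order: &[usize], iou_threshold: f32) -> Vec<usize> {
    let mut suppressed = vec![false; boxes.len()];
    let mut keep = Vec::new();
    for &i in order {
        if suppressed[i] {
            continue;
        }
        keep.push(i);
        // Suppress every remaining box that overlaps the kept box too strongly.
        for &j in order {
            if !suppressed[j] && iou(boxes[i], boxes[j]) > iou_threshold {
                suppressed[j] = true;
            }
        }
    }
    keep
}
```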
1 change: 1 addition & 0 deletions crates/burn-vision/src/lib.rs
@@ -6,6 +6,7 @@
//! Currently implemented are:
//! - `connected_components`
//! - `connected_components_with_stats`
//! - `nms` (Non-Maximum Suppression)
//!
#![warn(missing_docs)]
46 changes: 46 additions & 0 deletions crates/burn-vision/src/ops/base.rs
@@ -163,6 +163,29 @@ impl ConnectedStatsOptions {
}
}

/// Non-Maximum Suppression options.
#[derive(Clone, Copy, Debug)]
pub struct NmsOptions {
/// IoU threshold for suppression (default: 0.5).
/// A box is suppressed when its IoU with a higher-scoring kept box exceeds this threshold.
pub iou_threshold: f32,
/// Score threshold to filter boxes before NMS (default: 0.0, i.e., no filtering).
/// Boxes with score < score_threshold are discarded.
pub score_threshold: f32,
/// Maximum number of boxes to keep (0 = unlimited).
pub max_output_boxes: usize,
}
Comment on lines +166 to +177
Future notes: might be useful to extend box format to [cx, cy, w, h] on top of the current [x1, y1, x2, y2] format.


impl Default for NmsOptions {
fn default() -> Self {
Self {
iou_threshold: 0.5,
score_threshold: 0.0,
max_output_boxes: 0,
}
}
}

/// Vision capable backend, implemented by each backend
pub trait VisionBackend:
BoolVisionOps + IntVisionOps + FloatVisionOps + QVisionOps + Backend
@@ -262,6 +285,29 @@ pub trait FloatVisionOps: Backend {
.into_primitive()
.tensor()
}

/// Perform Non-Maximum Suppression on bounding boxes.
///
/// Returns indices of kept boxes after suppressing overlapping detections.
/// Boxes are processed in descending score order; a box suppresses all
/// lower-scoring boxes with IoU > threshold.
///
/// # Arguments
/// * `boxes` - Bounding boxes as \[N, 4\] tensor in (x1, y1, x2, y2) format
/// * `scores` - Confidence scores as \[N\] tensor
/// * `options` - NMS options (IoU threshold, score threshold, max boxes)
///
/// # Returns
/// Indices of kept boxes as \[M\] tensor where M <= N
fn nms(
boxes: FloatTensor<Self>,
scores: FloatTensor<Self>,
options: NmsOptions,
) -> IntTensor<Self> {
let boxes = Tensor::<Self, 2>::from_primitive(TensorPrimitive::Float(boxes));
let scores = Tensor::<Self, 1>::from_primitive(TensorPrimitive::Float(scores));
cpu::nms::<Self>(boxes, scores, options).into_primitive()
}
}

/// Vision ops on quantized float tensors
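A minimal usage sketch for the `NmsOptions` struct added in `ops/base.rs` above. The non-default field values are illustrative, and the import path assumes the struct is re-exported at the crate root alongside the other option types.

```rust
use burn_vision::NmsOptions;

fn main() {
    // Defaults from the diff: iou_threshold = 0.5, score_threshold = 0.0 (no filtering),
    // max_output_boxes = 0 (unlimited).
    let _defaults = NmsOptions::default();

    // Typical detection post-processing settings (illustrative values).
    let _opts = NmsOptions {
        iou_threshold: 0.45,   // suppress boxes overlapping a kept box by more than 45% IoU
        score_threshold: 0.25, // drop low-confidence boxes before running NMS
        max_output_boxes: 100, // keep at most 100 boxes; 0 means unlimited
    };
}
```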
32 changes: 31 additions & 1 deletion crates/burn-vision/src/tensor.rs
@@ -4,7 +4,8 @@ use burn_tensor::{
};

use crate::{
BoolVisionOps, ConnectedStats, ConnectedStatsOptions, Connectivity, MorphOptions, VisionBackend,
BoolVisionOps, ConnectedStats, ConnectedStatsOptions, Connectivity, MorphOptions, NmsOptions,
VisionBackend,
};

/// Connected components tensor extensions
@@ -53,6 +54,24 @@ pub trait MorphologyKind<B: Backend>: BasicOps<B> {
) -> Self::Primitive;
}

/// Non-maximum suppression tensor operations
pub trait Nms<B: Backend> {
/// Perform Non-Maximum Suppression on this tensor of bounding boxes.
///
/// Returns indices of kept boxes after suppressing overlapping detections.
/// Boxes are processed in descending score order; a box suppresses all
/// lower-scoring boxes with IoU > threshold.
///
/// # Arguments
/// * `self` - Bounding boxes as \[N, 4\] tensor in (x1, y1, x2, y2) format
/// * `scores` - Confidence scores as \[N\] tensor
/// * `options` - NMS options (IoU threshold, score threshold, max boxes)
///
/// # Returns
/// Indices of kept boxes as \[M\] tensor where M <= N
fn nms(self, scores: Tensor<B, 1, Float>, opts: NmsOptions) -> Tensor<B, 1, Int>;
}

impl<B: BoolVisionOps> ConnectedComponents<B> for Tensor<B, 2, Bool> {
fn connected_components(self, connectivity: Connectivity) -> Tensor<B, 2, Int> {
Tensor::from_primitive(B::connected_components(self.into_primitive(), connectivity))
@@ -154,3 +173,14 @@ impl<B: VisionBackend> MorphologyKind<B> for Bool {
B::bool_dilate(tensor, kernel, opts)
}
}

impl<B: VisionBackend> Nms<B> for Tensor<B, 2> {
fn nms(self, scores: Tensor<B, 1>, options: NmsOptions) -> Tensor<B, 1, Int> {
match (self.into_primitive(), scores.into_primitive()) {
(TensorPrimitive::Float(boxes), TensorPrimitive::Float(scores)) => {
Tensor::<B, 1, Int>::from_primitive(B::nms(boxes, scores, options))
}
_ => todo!("Quantized inputs are not yet supported"),
}
}
}
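
A hedged end-to-end sketch of the new `Nms` tensor extension. The backend choice and import paths are assumptions (any backend implementing `VisionBackend` should work); the expected output follows from the default IoU threshold of 0.5.

```rust
use burn_ndarray::NdArray; // assumption: the ndarray backend implements VisionBackend
use burn_tensor::Tensor;
use burn_vision::{Nms, NmsOptions};

fn main() {
    type B = NdArray;
    let device = Default::default();

    // Three boxes in (x1, y1, x2, y2) format: the first two overlap heavily
    // (IoU = 81 / 119, roughly 0.68), the third is disjoint.
    let boxes = Tensor::<B, 2>::from_floats(
        [
            [0.0, 0.0, 10.0, 10.0],
            [1.0, 1.0, 11.0, 11.0],
            [50.0, 50.0, 60.0, 60.0],
        ],
        &device,
    );
    let scores = Tensor::<B, 1>::from_floats([0.9, 0.8, 0.7], &device);

    // With the default IoU threshold of 0.5, box 1 is suppressed by box 0.
    let keep = boxes.nms(scores, NmsOptions::default());
    println!("{keep}"); // expected indices: [0, 2], ordered by descending score
}
```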