diff --git a/Cargo.lock b/Cargo.lock
index ea43fd7a0f..fd85fb585d 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1199,6 +1199,7 @@ dependencies = [
 name = "burn-vision"
 version = "0.20.0-pre.6"
 dependencies = [
+ "aligned-vec",
  "bon",
  "burn-candle",
  "burn-cubecl",
diff --git a/crates/burn-vision/Cargo.toml b/crates/burn-vision/Cargo.toml
index 01f6407f0f..06d46f15b4 100644
--- a/crates/burn-vision/Cargo.toml
+++ b/crates/burn-vision/Cargo.toml
@@ -20,7 +20,7 @@ workspace = true
 
 [features]
 default = ["ndarray", "cubecl-backend", "fusion", "std"]
-std = []
+std = ["aligned-vec/std"]
 tracing = [
     "burn-candle?/tracing",
     "burn-cubecl?/tracing",
@@ -46,6 +46,7 @@ test-vulkan = ["burn-wgpu/vulkan", "test-wgpu"]
 test-metal = ["burn-wgpu/metal", "test-wgpu"]
 
 [dependencies]
+aligned-vec = { version = "0.6", default-features = false }
 bon = { workspace = true }
 burn-candle = { path = "../burn-candle", version = "=0.20.0-pre.6", optional = true }
 burn-cubecl = { path = "../burn-cubecl", version = "=0.20.0-pre.6", optional = true }
diff --git a/crates/burn-vision/src/backends/cpu/mod.rs b/crates/burn-vision/src/backends/cpu/mod.rs
index 897e2c2953..e6096437b8 100644
--- a/crates/burn-vision/src/backends/cpu/mod.rs
+++ b/crates/burn-vision/src/backends/cpu/mod.rs
@@ -1,8 +1,10 @@
 mod base;
 mod connected_components;
 mod morphology;
+mod nms;
 mod ops;
 
 pub use base::*;
 pub use connected_components::*;
 pub use morphology::*;
+pub use nms::*;
diff --git a/crates/burn-vision/src/backends/cpu/nms.rs b/crates/burn-vision/src/backends/cpu/nms.rs
new file mode 100644
index 0000000000..b916dce20c
--- /dev/null
+++ b/crates/burn-vision/src/backends/cpu/nms.rs
@@ -0,0 +1,251 @@
+use crate::NmsOptions;
+use aligned_vec::{AVec, ConstAlign};
+use alloc::vec::Vec;
+use burn_tensor::{Int, Shape, Tensor, TensorData, backend::Backend};
+use macerator::{Scalar, Simd, Vector, vload};
+
+/// Perform NMS on CPU using SIMD acceleration.
+///
+/// This implementation:
+/// 1. Sorts boxes by score (descending)
+/// 2. Iteratively selects the highest-scoring non-suppressed box
+/// 3. Suppresses all boxes with IoU > threshold using SIMD
+///
+/// # Arguments
+/// * `boxes` - Bounding boxes as \[N, 4\] tensor in (x1, y1, x2, y2) format
+/// * `scores` - Confidence scores as \[N\] tensor
+/// * `options` - NMS options (IoU threshold, score threshold, max boxes)
+///
+/// # Returns
+/// Original indices of the kept boxes, in descending score order.
+///
+/// # Panics
+/// Panics if `boxes` does not hold exactly four values per score.
+pub fn nms<B: Backend>(
+    boxes: Tensor<B, 2>,
+    scores: Tensor<B, 1>,
+    options: NmsOptions,
+) -> Tensor<B, 1, Int> {
+    let device = boxes.device();
+    let [n_boxes, _] = boxes.shape().dims();
+    if n_boxes == 0 {
+        return Tensor::<B, 1, Int>::empty([0], &device);
+    }
+
+    // Convert to f32 before extracting: the backend's float element may be f16/bf16/f64,
+    // in which case a bare `to_vec::<f32>()` fails on the element-type mismatch.
+    let boxes_vec: Vec<f32> = boxes
+        .into_data()
+        .convert::<f32>()
+        .to_vec()
+        .expect("box data was just converted to f32");
+    let scores_vec: Vec<f32> = scores
+        .into_data()
+        .convert::<f32>()
+        .to_vec()
+        .expect("score data was just converted to f32");
+
+    let keep = nms_vec(boxes_vec, scores_vec, options);
+    let n_kept = keep.len();
+    let indices_data = TensorData::new(keep, Shape::new([n_kept]));
+    Tensor::<B, 1, Int>::from_data(indices_data, &device)
+}
+
+/// Perform NMS on flat host buffers.
+///
+/// `boxes_vec` holds `(x1, y1, x2, y2)` per box, back to back; `scores_vec` holds one score
+/// per box. Returns the original indices of the kept boxes in descending score order.
+///
+/// # Panics
+/// Panics if `boxes_vec.len() != 4 * scores_vec.len()`.
+fn nms_vec(boxes_vec: Vec<f32>, scores_vec: Vec<f32>, options: NmsOptions) -> Vec<i32> {
+    let n_boxes = scores_vec.len();
+    assert_eq!(
+        boxes_vec.len(),
+        n_boxes * 4,
+        "NMS expects four box coordinates per score"
+    );
+
+    // Filter by score threshold first, keeping each survivor's original index.
+    // NaN scores fail the comparison and are dropped here.
+    let mut filtered_indices: Vec<usize> = scores_vec
+        .iter()
+        .enumerate()
+        .filter(|&(_, &score)| score >= options.score_threshold)
+        .map(|(i, _)| i)
+        .collect();
+
+    let n_filtered = filtered_indices.len();
+    if n_filtered == 0 {
+        return Vec::new();
+    }
+
+    // Sort by score descending; the sort is stable, so ties keep their original order.
+    filtered_indices.sort_by(|&a, &b| scores_vec[b].total_cmp(&scores_vec[a]));
+
+    // One 64-byte-aligned buffer split into five planes (x1, y1, x2, y2, area). Each plane
+    // is padded to a multiple of 16 floats, so the SIMD loop needs no scalar tail; the
+    // padding stays zero.
+    const ALIGN: usize = 64;
+    const FLOATS_PER_ALIGN: usize = ALIGN / size_of::<f32>(); // 16
+    let stride = n_filtered.div_ceil(FLOATS_PER_ALIGN) * FLOATS_PER_ALIGN;
+    let mut buf: AVec<f32, ConstAlign<ALIGN>> = AVec::with_capacity(ALIGN, stride * 5);
+    buf.resize(stride * 5, 0.0);
+
+    let (x1s, rest) = buf.split_at_mut(stride);
+    let (y1s, rest) = rest.split_at_mut(stride);
+    let (x2s, rest) = rest.split_at_mut(stride);
+    let (y2s, areas) = rest.split_at_mut(stride);
+
+    // Convert filtered boxes to SoA format
+    for (j, &orig_idx) in filtered_indices.iter().enumerate() {
+        let x1 = boxes_vec[orig_idx * 4];
+        let y1 = boxes_vec[orig_idx * 4 + 1];
+        let x2 = boxes_vec[orig_idx * 4 + 2];
+        let y2 = boxes_vec[orig_idx * 4 + 3];
+        x1s[j] = x1;
+        y1s[j] = y1;
+        x2s[j] = x2;
+        y2s[j] = y2;
+        areas[j] = (x2 - x1) * (y2 - y1);
+    }
+
+    // Apply NMS with SIMD dispatch
+    let mut suppressed = vec![false; stride];
+    let mut keep = Vec::new();
+
+    for i in 0..n_filtered {
+        if suppressed[i] {
+            continue;
+        }
+
+        // Mark the kept box too, so chunks that are fully handled can be skipped later
+        suppressed[i] = true;
+        keep.push(filtered_indices[i] as i32); // original index
+
+        if options.max_output_boxes > 0 && keep.len() >= options.max_output_boxes {
+            break;
+        }
+
+        // Suppress overlapping boxes using SIMD
+        suppress_overlapping(
+            x1s[i],
+            y1s[i],
+            x2s[i],
+            y2s[i],
+            areas[i],
+            x1s,
+            y1s,
+            x2s,
+            y2s,
+            areas,
+            &mut suppressed,
+            stride,
+            options.iou_threshold,
+        );
+    }
+
+    keep
+}
+
+/// SIMD-accelerated suppression of overlapping boxes.
+///
+/// Sets `suppressed[k]` for every `k < n_boxes` whose box has IoU > `threshold` with the
+/// reference box; entries that are already set are left alone.
+///
+/// Every slice must hold at least `n_boxes` elements, and `n_boxes` must be a multiple of
+/// the SIMD lane count (a trailing partial chunk is not processed).
+#[allow(clippy::too_many_arguments)]
+#[inline(always)]
+#[macerator::with_simd]
+fn suppress_overlapping<'a, S: Simd>(
+    ref_x1: f32,
+    ref_y1: f32,
+    ref_x2: f32,
+    ref_y2: f32,
+    ref_area: f32,
+    x1s: &'a [f32],
+    y1s: &'a [f32],
+    x2s: &'a [f32],
+    y2s: &'a [f32],
+    areas: &'a [f32],
+    suppressed: &'a mut [bool],
+    n_boxes: usize, // stride, always multiple of lanes
+    threshold: f32,
+) where
+    'a: 'a,
+{
+    // Capacity of the stack buffer the comparison mask is unpacked into.
+    const MAX_LANES: usize = 16;
+
+    let lanes = f32::lanes::<S>();
+    // A wider vector would overflow `mask_buf`; fail loudly instead.
+    assert!(lanes <= MAX_LANES, "unsupported SIMD width: {lanes} lanes");
+    // The raw-pointer loads below depend on these bounds.
+    assert!(
+        [x1s.len(), y1s.len(), x2s.len(), y2s.len(), areas.len()]
+            .iter()
+            .all(|&len| len >= n_boxes),
+        "coordinate planes are shorter than `n_boxes`"
+    );
+
+    // Splat reference values
+    let ref_x1_v: Vector<S, f32> = ref_x1.splat();
+    let ref_y1_v: Vector<S, f32> = ref_y1.splat();
+    let ref_x2_v: Vector<S, f32> = ref_x2.splat();
+    let ref_y2_v: Vector<S, f32> = ref_y2.splat();
+    let ref_area_v: Vector<S, f32> = ref_area.splat();
+    let thresh_v: Vector<S, f32> = threshold.splat();
+    let zero_v: Vector<S, f32> = 0.0f32.splat();
+
+    // Fully initialised, so lanes a narrower vector never writes still read as `false`.
+    let mut mask_buf = [false; MAX_LANES];
+    let mut i = 0;
+
+    // Process lanes boxes at a time with SIMD
+    while i + lanes <= n_boxes {
+        let start = i;
+        i += lanes;
+
+        // Skip if all boxes in this chunk are already suppressed
+        if suppressed[start..i].iter().all(|&done| done) {
+            continue;
+        }
+
+        // SAFETY: `start + lanes <= n_boxes` and every plane holds at least `n_boxes`
+        // floats (asserted above), so each load reads in bounds.
+        // NOTE(review): if `vload` requires vector alignment, this relies on the caller's
+        // 64-byte-aligned, 16-float-padded planes — confirm against macerator.
+        let x1_v: Vector<S, f32> = unsafe { vload(x1s.as_ptr().add(start)) };
+        let y1_v: Vector<S, f32> = unsafe { vload(y1s.as_ptr().add(start)) };
+        let x2_v: Vector<S, f32> = unsafe { vload(x2s.as_ptr().add(start)) };
+        let y2_v: Vector<S, f32> = unsafe { vload(y2s.as_ptr().add(start)) };
+        let area_v: Vector<S, f32> = unsafe { vload(areas.as_ptr().add(start)) };
+
+        // Compute intersection coordinates
+        let xx1 = ref_x1_v.max(x1_v);
+        let yy1 = ref_y1_v.max(y1_v);
+        let xx2 = ref_x2_v.min(x2_v);
+        let yy2 = ref_y2_v.min(y2_v);
+
+        // Compute intersection area (clamp to 0 for non-overlapping)
+        let w = (xx2 - xx1).max(zero_v);
+        let h = (yy2 - yy1).max(zero_v);
+        let inter = w * h;
+
+        // Compute IoU
+        let union = ref_area_v + area_v - inter;
+        let iou = inter / union;
+
+        // Get suppression mask (IoU > threshold)
+        let suppress_mask = iou.gt(thresh_v);
+
+        // SAFETY: the store is expected to write `lanes` bools and `mask_buf` has room
+        // for `MAX_LANES >= lanes` of them.
+        unsafe { f32::mask_store_as_bool::<S>(mask_buf.as_mut_ptr(), suppress_mask) };
+
+        for (flag, &hit) in suppressed[start..i].iter_mut().zip(&mask_buf[..lanes]) {
+            *flag |= hit;
+        }
+    }
+}
diff --git a/crates/burn-vision/src/lib.rs b/crates/burn-vision/src/lib.rs
index 4ae0a8110d..7e88c680fa 100644
--- a/crates/burn-vision/src/lib.rs
+++ b/crates/burn-vision/src/lib.rs
@@ -6,6 +6,7 @@
 //! Currently implemented are:
 //! - `connected_components`
 //! - `connected_components_with_stats`
+//! - `nms` (Non-Maximum Suppression)
 //!
 
 #![warn(missing_docs)]
diff --git a/crates/burn-vision/src/ops/base.rs b/crates/burn-vision/src/ops/base.rs
index 4a11d82f57..6ecaa941a4 100644
--- a/crates/burn-vision/src/ops/base.rs
+++ b/crates/burn-vision/src/ops/base.rs
@@ -163,6 +163,29 @@ impl ConnectedStatsOptions {
     }
 }
 
+/// Non-Maximum Suppression options.
+#[derive(Clone, Copy, Debug)]
+pub struct NmsOptions {
+    /// IoU threshold for suppression (default: 0.5).
+    /// Boxes with IoU > threshold with a higher-scoring box are suppressed.
+    pub iou_threshold: f32,
+    /// Score threshold to filter boxes before NMS (default: 0.0, i.e., no filtering).
+    /// Boxes with score < score_threshold are discarded.
+    pub score_threshold: f32,
+    /// Maximum number of boxes to keep (0 = unlimited).
+    pub max_output_boxes: usize,
+}
+
+impl Default for NmsOptions {
+    fn default() -> Self {
+        Self {
+            iou_threshold: 0.5,
+            score_threshold: 0.0,
+            max_output_boxes: 0,
+        }
+    }
+}
+
 /// Vision capable backend, implemented by each backend
 pub trait VisionBackend:
     BoolVisionOps + IntVisionOps + FloatVisionOps + QVisionOps + Backend
@@ -262,6 +285,29 @@ pub trait FloatVisionOps: Backend {
             .into_primitive()
             .tensor()
     }
+
+    /// Perform Non-Maximum Suppression on bounding boxes.
+    ///
+    /// Returns indices of kept boxes after suppressing overlapping detections.
+    /// Boxes are processed in descending score order; a box suppresses all
+    /// lower-scoring boxes with IoU > threshold.
+    ///
+    /// # Arguments
+    /// * `boxes` - Bounding boxes as \[N, 4\] tensor in (x1, y1, x2, y2) format
+    /// * `scores` - Confidence scores as \[N\] tensor
+    /// * `options` - NMS options (IoU threshold, score threshold, max boxes)
+    ///
+    /// # Returns
+    /// Indices of kept boxes as \[M\] tensor where M <= N
+    fn nms(
+        boxes: FloatTensor<Self>,
+        scores: FloatTensor<Self>,
+        options: NmsOptions,
+    ) -> IntTensor<Self> {
+        let boxes = Tensor::<Self, 2>::from_primitive(TensorPrimitive::Float(boxes));
+        let scores = Tensor::<Self, 1>::from_primitive(TensorPrimitive::Float(scores));
+        cpu::nms::<Self>(boxes, scores, options).into_primitive()
+    }
 }
 
 /// Vision ops on quantized float tensors
diff --git a/crates/burn-vision/src/tensor.rs b/crates/burn-vision/src/tensor.rs
index 0c889f44ba..a79d4434c5 100644
--- a/crates/burn-vision/src/tensor.rs
+++ b/crates/burn-vision/src/tensor.rs
@@ -4,7 +4,8 @@ use burn_tensor::{
 };
 
 use crate::{
-    BoolVisionOps, ConnectedStats, ConnectedStatsOptions, Connectivity, MorphOptions, VisionBackend,
+    BoolVisionOps, ConnectedStats, ConnectedStatsOptions, Connectivity, MorphOptions, NmsOptions,
+    VisionBackend,
 };
 
 /// Connected components tensor extensions
@@ -53,6 +54,24 @@ pub trait MorphologyKind: BasicOps {
     ) -> Self::Primitive;
 }
 
+/// Non-maximum suppression tensor operations
+pub trait Nms<B: Backend> {
+    /// Perform Non-Maximum Suppression on this tensor of bounding boxes.
+    ///
+    /// Returns indices of kept boxes after suppressing overlapping detections.
+    /// Boxes are processed in descending score order; a box suppresses all
+    /// lower-scoring boxes with IoU > threshold.
+    ///
+    /// # Arguments
+    /// * `self` - Bounding boxes as \[N, 4\] tensor in (x1, y1, x2, y2) format
+    /// * `scores` - Confidence scores as \[N\] tensor
+    /// * `options` - NMS options (IoU threshold, score threshold, max boxes)
+    ///
+    /// # Returns
+    /// Indices of kept boxes as \[M\] tensor where M <= N
+    fn nms(self, scores: Tensor<B, 1>, options: NmsOptions) -> Tensor<B, 1, Int>;
+}
+
 impl ConnectedComponents for Tensor {
     fn connected_components(self, connectivity: Connectivity) -> Tensor {
         Tensor::from_primitive(B::connected_components(self.into_primitive(), connectivity))
@@ -154,3 +173,14 @@ impl MorphologyKind for Bool {
         B::bool_dilate(tensor, kernel, opts)
     }
 }
+
+impl<B: VisionBackend> Nms<B> for Tensor<B, 2> {
+    fn nms(self, scores: Tensor<B, 1>, options: NmsOptions) -> Tensor<B, 1, Int> {
+        match (self.into_primitive(), scores.into_primitive()) {
+            (TensorPrimitive::Float(boxes), TensorPrimitive::Float(scores)) => {
+                Tensor::<B, 1, Int>::from_primitive(B::nms(boxes, scores, options))
+            }
+            _ => todo!("Quantized inputs are not yet supported"),
+        }
+    }
+}
diff --git a/crates/burn-vision/tests/nms.rs b/crates/burn-vision/tests/nms.rs
new file mode 100644
index 0000000000..5c3bb76371
--- /dev/null
+++ b/crates/burn-vision/tests/nms.rs
@@ -0,0 +1,92 @@
+use burn_vision::{Nms, NmsOptions};
+
+mod common;
+use common::*;
+
+#[test]
+fn should_suppress_non_maximum() {
+    let boxes = TestTensor::<2>::from([
+        [0, 0, 100, 100],
+        [0, 1, 100, 100],
+        [0, 101, 200, 200],
+        [0, 100, 200, 200],
+        [0, 170, 300, 300],
+    ]);
+    let scores = TestTensor::<1>::from([0.1, 0.2, 0.4, 0.3, 0.5]);
+    let options = NmsOptions {
+        iou_threshold: 0.5,
+        score_threshold: 0.0,
+        max_output_boxes: 0,
+    };
+
+    let output = boxes.nms(scores, options);
+
+    let expected = TestTensorInt::<1>::from([4, 2, 1]);
+    output.into_data().assert_eq(&expected.into_data(), true);
+}
+
+#[test]
+fn should_apply_score_threshold() {
+    let boxes = TestTensor::<2>::from([
+        [0, 0, 100, 100],
+        [0, 1, 100, 100],
+        [0, 101, 200, 200],
+        [0, 100, 200, 200],
+        [0, 170, 300, 300],
+    ]);
+    let scores = TestTensor::<1>::from([0.1, 0.2, 0.4, 0.3, 0.5]);
+    let options = NmsOptions {
+        iou_threshold: 0.5,
+        score_threshold: 0.3,
+        max_output_boxes: 0,
+    };
+
+    let output = boxes.nms(scores, options);
+
+    let expected = TestTensorInt::<1>::from([4, 2]);
+    output.into_data().assert_eq(&expected.into_data(), true);
+}
+
+#[test]
+fn should_apply_iou_threshold() {
+    let boxes = TestTensor::<2>::from([
+        [0, 0, 100, 100],
+        [0, 1, 100, 100],
+        [0, 101, 200, 200],
+        [0, 100, 200, 200],
+        [0, 170, 300, 300],
+    ]);
+    let scores = TestTensor::<1>::from([0.1, 0.2, 0.4, 0.3, 0.5]);
+    let options = NmsOptions {
+        iou_threshold: 0.1,
+        score_threshold: 0.0,
+        max_output_boxes: 0,
+    };
+
+    let output = boxes.nms(scores, options);
+
+    let expected = TestTensorInt::<1>::from([4, 1]);
+    output.into_data().assert_eq(&expected.into_data(), true);
+}
+
+#[test]
+fn should_apply_max_output_boxes() {
+    let boxes = TestTensor::<2>::from([
+        [0, 0, 100, 100],
+        [0, 1, 100, 100],
+        [0, 101, 200, 200],
+        [0, 100, 200, 200],
+        [0, 170, 300, 300],
+    ]);
+    let scores = TestTensor::<1>::from([0.1, 0.2, 0.4, 0.3, 0.5]);
+    let options = NmsOptions {
+        iou_threshold: 0.5,
+        score_threshold: 0.0,
+        max_output_boxes: 1,
+    };
+
+    let output = boxes.nms(scores, options);
+
+    let expected = TestTensorInt::<1>::from([4]);
+    output.into_data().assert_eq(&expected.into_data(), true);
+}