Skip to content

Commit 47c712c

Browse files
TheDarkchip and catch-twenty-two
authored and committed
Implement optimized bool_select for primary backends (#3710)
* Implement optimized bool_select for each backend

  This commit implements optimized bool_select operations for all Burn backends
  to address GitHub issue #3697.

  ## Changes Made

  ### NdArray Backend
  - Direct boolean array operations using ndarray's select method
  - File: crates/burn-ndarray/src/ops/bool_tensor.rs

  ### Candle Backend
  - Leverages Candle's native index_select operation on boolean tensors
  - File: crates/burn-candle/src/ops/bool_tensor.rs

  ### CubeBackend (WGPU/CUDA/ROCm)
  - GPU kernel-based implementation using kernel::select
  - File: crates/burn-cubecl/src/ops/bool_ops.rs

  ### Tch Backend (PyTorch)
  - Uses PyTorch's efficient index_select_dim operation
  - File: crates/burn-tch/src/ops/bool_tensor.rs

  The default implementation used inefficient type conversions:

  ```rust
  let int_tensor = B::bool_into_int(tensor);
  let selected = B::int_select(int_tensor, dim, indices);
  B::int_equal_elem(selected, 1_i32.elem())
  ```

  The optimized implementations eliminate these conversions by using
  backend-native boolean selection operations.

  Addresses: #3697

* Fix code formatting issues

  - Fix function signature formatting in bool_select implementations
  - Reorder imports in ndarray backend
  - Apply cargo fmt formatting standards

  This resolves the CI code-quality check failure.

* Add comprehensive bool select_assign tests

  Adds focused tests covering:
  - Overlapping indices accumulation behavior
  - Complete boolean truth table coverage
  - Edge cases (empty indices, multiple true accumulations)
  - Verification against original default implementation
  - Proof by contradiction for replacement semantics

  Tests validate that optimized implementations maintain identical semantics to
  the framework's original behavior while providing performance improvements.
* Add comprehensive test suite for bool_select_assign - 6 behavior tests covering edge cases and accumulation semantics - 2 expected failure tests proving replacement semantics would be wrong - 5 comparison tests validating optimized vs default implementation - All tests pass confirming OR accumulation behavior is preserved * Fix formatting
1 parent 343160e commit 47c712c

File tree

5 files changed

+325
-1
lines changed

5 files changed

+325
-1
lines changed

crates/burn-candle/src/ops/bool_tensor.rs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,28 @@ impl<F: FloatCandleElement, I: IntCandleElement> BoolTensorOps<Self> for Candle<
111111
super::base::flip(tensor, axes)
112112
}
113113

114+
fn bool_select(
115+
tensor: BoolTensor<Self>,
116+
dim: usize,
117+
indices: IntTensor<Self>,
118+
) -> BoolTensor<Self> {
119+
CandleTensor::new(tensor.tensor.index_select(&indices.tensor, dim).unwrap())
120+
}
121+
122+
fn bool_select_assign(
123+
tensor: BoolTensor<Self>,
124+
dim: usize,
125+
indices: IntTensor<Self>,
126+
value: BoolTensor<Self>,
127+
) -> BoolTensor<Self> {
128+
CandleTensor::new(
129+
tensor
130+
.tensor
131+
.index_add(&indices.tensor, &value.tensor, dim)
132+
.unwrap(),
133+
)
134+
}
135+
114136
fn bool_expand(tensor: BoolTensor<Self>, shape: Shape) -> BoolTensor<Self> {
115137
expand(tensor, shape)
116138
}

crates/burn-cubecl/src/ops/bool_ops.rs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,23 @@ where
106106
expand(tensor, shape)
107107
}
108108

109+
fn bool_select(
110+
tensor: BoolTensor<Self>,
111+
dim: usize,
112+
indices: IntTensor<Self>,
113+
) -> BoolTensor<Self> {
114+
kernel::select::<R, BT, I>(tensor, dim, indices)
115+
}
116+
117+
fn bool_select_assign(
118+
tensor: BoolTensor<Self>,
119+
dim: usize,
120+
indices: IntTensor<Self>,
121+
value: BoolTensor<Self>,
122+
) -> BoolTensor<Self> {
123+
kernel::select_assign::<R, BT, I>(tensor, dim, indices, value)
124+
}
125+
109126
fn bool_flip(tensor: BoolTensor<Self>, axes: &[usize]) -> BoolTensor<Self> {
110127
kernel::flip::<R, BT, BT>(tensor, axes)
111128
}

crates/burn-ndarray/src/ops/bool_tensor.rs

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use ndarray::IntoDimension;
88

99
// Current crate
1010
use crate::element::{FloatNdArrayElement, IntNdArrayElement, QuantElement};
11-
use crate::{NdArray, tensor::NdArrayTensor};
11+
use crate::{NdArray, execute_with_int_dtype, tensor::NdArrayTensor};
1212
use crate::{NdArrayDevice, SharedArray};
1313

1414
// Workspace crates
@@ -117,6 +117,40 @@ where
117117
NdArrayOps::expand(tensor.bool(), shape).into()
118118
}
119119

120+
fn bool_select(tensor: NdArrayTensor, dim: usize, indices: NdArrayTensor) -> NdArrayTensor {
121+
execute_with_int_dtype!(indices, I, |indices: SharedArray<I>| -> NdArrayTensor {
122+
let tensor_bool = tensor.bool();
123+
let indices_vec: Vec<usize> = indices
124+
.into_iter()
125+
.map(|i| i.elem::<i64>() as usize)
126+
.collect();
127+
128+
let selected = tensor_bool.select(ndarray::Axis(dim), &indices_vec);
129+
selected.into_shared().into()
130+
})
131+
}
132+
133+
fn bool_select_assign(
134+
tensor: NdArrayTensor,
135+
dim: usize,
136+
indices: NdArrayTensor,
137+
value: NdArrayTensor,
138+
) -> NdArrayTensor {
139+
execute_with_int_dtype!(indices, I, |indices: SharedArray<I>| -> NdArrayTensor {
140+
let mut output_array = tensor.bool().into_owned();
141+
let value_bool = value.bool();
142+
143+
for (index_value, index) in indices.into_iter().enumerate() {
144+
let index_usize = index.elem::<i64>() as usize;
145+
let mut view = output_array.index_axis_mut(ndarray::Axis(dim), index_usize);
146+
let value_slice = value_bool.index_axis(ndarray::Axis(dim), index_value);
147+
// For boolean tensors, select_assign should use logical OR operation
148+
view.zip_mut_with(&value_slice, |a, b| *a = *a || *b);
149+
}
150+
output_array.into_shared().into()
151+
})
152+
}
153+
120154
fn bool_flip(tensor: NdArrayTensor, axes: &[usize]) -> NdArrayTensor {
121155
NdArrayOps::flip(tensor.bool(), axes).into()
122156
}

crates/burn-tch/src/ops/bool_tensor.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,19 @@ impl<E: TchElement> BoolTensorOps<Self> for LibTorch<E> {
134134
TchTensor::new(tensor.tensor.argwhere())
135135
}
136136

137+
fn bool_select(tensor: TchTensor, dim: usize, indices: TchTensor) -> TchTensor {
138+
TchOps::index_select_dim(tensor, dim, indices)
139+
}
140+
141+
fn bool_select_assign(
142+
tensor: TchTensor,
143+
dim: usize,
144+
indices: TchTensor,
145+
value: TchTensor,
146+
) -> TchTensor {
147+
TchOps::select_assign(tensor, dim, indices, value)
148+
}
149+
137150
fn bool_expand(tensor: TchTensor, shape: Shape) -> TchTensor {
138151
TchOps::expand(tensor, shape)
139152
}

crates/burn-tensor/src/tests/ops/select.rs

Lines changed: 238 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,244 @@ mod tests {
220220
output.into_data().assert_eq(&expected, false);
221221
}
222222

223+
#[test]
fn should_select_assign_bool_overlapping_indices() {
    // Repeated indices must accumulate (logical OR), not overwrite.
    let device = Default::default();
    let base = TestTensorBool::<1>::from_data([false, true], &device);
    let idx = TestTensorInt::from_data([0, 0], &device);
    let vals = TestTensorBool::<1>::from_data([true, false], &device);

    // Index 0 receives: false OR true OR false = true.
    let result = base.select_assign(0, idx, vals);

    result
        .into_data()
        .assert_eq(&TensorData::from([true, true]), false);
}
237+
238+
#[test]
fn should_select_assign_bool_false_to_true_case() {
    // A single `true` contribution flips a `false` element: false OR true = true.
    let device = Default::default();
    let base = TestTensorBool::<1>::from_data([false], &device);
    let idx = TestTensorInt::from_data([0], &device);
    let vals = TestTensorBool::<1>::from_data([true], &device);

    let result = base.select_assign(0, idx, vals);

    result
        .into_data()
        .assert_eq(&TensorData::from([true]), false);
}
251+
252+
#[test]
fn should_select_assign_bool_empty_indices() {
    // With no indices, select_assign must leave the tensor untouched.
    let device = Default::default();
    let base = TestTensorBool::<1>::from_data([true, false, true], &device);
    let idx = TestTensorInt::<1>::from_data([] as [i32; 0], &device);
    let vals = TestTensorBool::<1>::from_data([] as [bool; 0], &device);

    let result = base.select_assign(0, idx, vals);

    result
        .into_data()
        .assert_eq(&TensorData::from([true, false, true]), false);
}
265+
266+
#[test]
fn should_select_assign_bool_true_or_true_accumulation() {
    // Repeated `true` contributions saturate: the element stays `true`,
    // and untouched elements are unchanged.
    let device = Default::default();
    let base = TestTensorBool::<1>::from_data([true, false], &device);
    let idx = TestTensorInt::from_data([0, 0, 0], &device);
    let vals = TestTensorBool::<1>::from_data([true, true, true], &device);

    let result = base.select_assign(0, idx, vals);

    result
        .into_data()
        .assert_eq(&TensorData::from([true, false]), false);
}
279+
280+
#[test]
fn should_match_default_implementation_behavior() {
    // The optimized bool select_assign must agree with the original default
    // implementation: int select_assign followed by a non-zero test.
    // (Removed `use burn_tensor::backend::Backend;` — nothing in this test
    // references the trait, so the import only produced an unused warning.)
    let device = Default::default();
    let tensor = TestTensorBool::<1>::from_data([true, false, true], &device);
    let indices = TestTensorInt::from_data([0, 1, 0], &device);
    let values = TestTensorBool::<1>::from_data([false, true, true], &device);

    let optimized = tensor
        .clone()
        .select_assign(0, indices.clone(), values.clone());

    // Reproduce the default implementation's logic manually.
    let default = tensor
        .int()
        .select_assign(0, indices, values.int())
        .greater_elem(0);

    optimized.into_data().assert_eq(&default.into_data(), false);
}
304+
305+
#[test]
fn should_select_assign_bool_overlapping_indices_vs_default() {
    // Overlapping indices: optimized result must match the default
    // implementation (int select_assign + non-zero test).
    // (Removed unused `use burn_tensor::backend::Backend;`.)
    let device = Default::default();
    let tensor = TestTensorBool::<1>::from_data([false, true], &device);
    let indices = TestTensorInt::from_data([0, 0], &device);
    let values = TestTensorBool::<1>::from_data([true, false], &device);

    let optimized = tensor
        .clone()
        .select_assign(0, indices.clone(), values.clone());

    let default = tensor
        .int()
        .select_assign(0, indices, values.int())
        .greater_elem(0);

    optimized.into_data().assert_eq(&default.into_data(), false);
}
328+
329+
#[test]
fn should_select_assign_bool_true_or_true_accumulation_vs_default() {
    // Multiple `true` accumulations: optimized result must match the default
    // implementation (int select_assign + non-zero test).
    // (Removed unused `use burn_tensor::backend::Backend;`.)
    let device = Default::default();
    let tensor = TestTensorBool::<1>::from_data([true, false], &device);
    let indices = TestTensorInt::from_data([0, 0, 0], &device);
    let values = TestTensorBool::<1>::from_data([true, true, true], &device);

    let optimized = tensor
        .clone()
        .select_assign(0, indices.clone(), values.clone());

    let default = tensor
        .int()
        .select_assign(0, indices, values.int())
        .greater_elem(0);

    optimized.into_data().assert_eq(&default.into_data(), false);
}
352+
353+
#[test]
fn should_select_assign_bool_false_to_true_case_vs_default() {
    // false OR true case: optimized result must match the default
    // implementation (int select_assign + non-zero test).
    // (Removed unused `use burn_tensor::backend::Backend;`.)
    let device = Default::default();
    let tensor = TestTensorBool::<1>::from_data([false], &device);
    let indices = TestTensorInt::from_data([0], &device);
    let values = TestTensorBool::<1>::from_data([true], &device);

    let optimized = tensor
        .clone()
        .select_assign(0, indices.clone(), values.clone());

    let default = tensor
        .int()
        .select_assign(0, indices, values.int())
        .greater_elem(0);

    optimized.into_data().assert_eq(&default.into_data(), false);
}
376+
377+
#[test]
fn should_select_assign_bool_empty_indices_vs_default() {
    // Empty indices: optimized result must match the default implementation
    // (int select_assign + non-zero test), i.e. a no-op.
    // (Removed unused `use burn_tensor::backend::Backend;`.)
    let device = Default::default();
    let tensor = TestTensorBool::<1>::from_data([true, false, true], &device);
    let indices = TestTensorInt::<1>::from_data([] as [i32; 0], &device);
    let values = TestTensorBool::<1>::from_data([] as [bool; 0], &device);

    let optimized = tensor
        .clone()
        .select_assign(0, indices.clone(), values.clone());

    let default = tensor
        .int()
        .select_assign(0, indices, values.int())
        .greater_elem(0);

    optimized.into_data().assert_eq(&default.into_data(), false);
}
400+
401+
#[test]
fn should_select_assign_bool_tensor_vs_default() {
    // Basic case: optimized result must match the default implementation
    // (int select_assign + non-zero test).
    // (Removed unused `use burn_tensor::backend::Backend;`.)
    let device = Default::default();
    let tensor = TestTensorBool::<1>::from_data([true, false, true], &device);
    let indices = TestTensorInt::from_data([0, 2], &device);
    let values = TestTensorBool::<1>::from_data([false, false], &device);

    let optimized = tensor
        .clone()
        .select_assign(0, indices.clone(), values.clone());

    let default = tensor
        .int()
        .select_assign(0, indices, values.int())
        .greater_elem(0);

    optimized.into_data().assert_eq(&default.into_data(), false);
}
424+
425+
#[test]
#[should_panic(expected = "Tensors are not eq")]
fn should_fail_if_replacement_semantics_were_used() {
    // select_assign accumulates: assigning `false` over `true` must NOT
    // replace the element, so expecting `[false]` has to panic.
    let device = Default::default();
    let base = TestTensorBool::<1>::from_data([true], &device);
    let idx = TestTensorInt::from_data([0], &device);
    let vals = TestTensorBool::<1>::from_data([false], &device);

    let result = base.select_assign(0, idx, vals);

    result
        .into_data()
        .assert_eq(&TensorData::from([false]), false);
}
439+
440+
#[test]
#[should_panic(expected = "Tensors are not eq")]
fn should_fail_if_replacement_semantics_were_used_vs_default() {
    // The default implementation also accumulates: `true` OR `false` stays
    // `true`, so expecting replacement semantics (`[false]`) must panic.
    // (Removed unused `use burn_tensor::backend::Backend;`.)
    let device = Default::default();
    let tensor = TestTensorBool::<1>::from_data([true], &device);
    let indices = TestTensorInt::from_data([0], &device);
    let values = TestTensorBool::<1>::from_data([false], &device);

    let default = tensor
        .int()
        .select_assign(0, indices, values.int())
        .greater_elem(0);

    default
        .into_data()
        .assert_eq(&TensorData::from([false]), false);
}
460+
223461
#[test]
224462
fn should_select_with_negative_dim_2d() {
225463
// Test using negative dimension indexing on 2D tensor

0 commit comments

Comments
 (0)