@@ -220,6 +220,244 @@ mod tests {
220
220
output. into_data ( ) . assert_eq ( & expected, false ) ;
221
221
}
222
222
223
+ #[ test]
224
+ fn should_select_assign_bool_overlapping_indices ( ) {
225
+ // Test accumulation behavior with overlapping indices
226
+ let device = Default :: default ( ) ;
227
+ let tensor = TestTensorBool :: < 1 > :: from_data ( [ false , true ] , & device) ;
228
+ let indices = TestTensorInt :: from_data ( [ 0 , 0 ] , & device) ;
229
+ let values = TestTensorBool :: < 1 > :: from_data ( [ true , false ] , & device) ;
230
+
231
+ let output = tensor. select_assign ( 0 , indices, values) ;
232
+ // Index 0: false OR true OR false = true
233
+ let expected = TensorData :: from ( [ true , true ] ) ;
234
+
235
+ output. into_data ( ) . assert_eq ( & expected, false ) ;
236
+ }
237
+
238
+ #[ test]
239
+ fn should_select_assign_bool_false_to_true_case ( ) {
240
+ // Test false OR true = true
241
+ let device = Default :: default ( ) ;
242
+ let tensor = TestTensorBool :: < 1 > :: from_data ( [ false ] , & device) ;
243
+ let indices = TestTensorInt :: from_data ( [ 0 ] , & device) ;
244
+ let values = TestTensorBool :: < 1 > :: from_data ( [ true ] , & device) ;
245
+
246
+ let output = tensor. select_assign ( 0 , indices, values) ;
247
+ let expected = TensorData :: from ( [ true ] ) ;
248
+
249
+ output. into_data ( ) . assert_eq ( & expected, false ) ;
250
+ }
251
+
252
+ #[ test]
253
+ fn should_select_assign_bool_empty_indices ( ) {
254
+ // Test empty indices array
255
+ let device = Default :: default ( ) ;
256
+ let tensor = TestTensorBool :: < 1 > :: from_data ( [ true , false , true ] , & device) ;
257
+ let indices = TestTensorInt :: < 1 > :: from_data ( [ ] as [ i32 ; 0 ] , & device) ;
258
+ let values = TestTensorBool :: < 1 > :: from_data ( [ ] as [ bool ; 0 ] , & device) ;
259
+
260
+ let output = tensor. select_assign ( 0 , indices, values) ;
261
+ let expected = TensorData :: from ( [ true , false , true ] ) ;
262
+
263
+ output. into_data ( ) . assert_eq ( & expected, false ) ;
264
+ }
265
+
266
+ #[ test]
267
+ fn should_select_assign_bool_true_or_true_accumulation ( ) {
268
+ // Test multiple true accumulations
269
+ let device = Default :: default ( ) ;
270
+ let tensor = TestTensorBool :: < 1 > :: from_data ( [ true , false ] , & device) ;
271
+ let indices = TestTensorInt :: from_data ( [ 0 , 0 , 0 ] , & device) ;
272
+ let values = TestTensorBool :: < 1 > :: from_data ( [ true , true , true ] , & device) ;
273
+
274
+ let output = tensor. select_assign ( 0 , indices, values) ;
275
+ let expected = TensorData :: from ( [ true , false ] ) ;
276
+
277
+ output. into_data ( ) . assert_eq ( & expected, false ) ;
278
+ }
279
+
280
+ #[ test]
281
+ fn should_match_default_implementation_behavior ( ) {
282
+ // Verify optimized implementation matches original default logic
283
+ use burn_tensor:: backend:: Backend ;
284
+
285
+ let device = Default :: default ( ) ;
286
+ let tensor = TestTensorBool :: < 1 > :: from_data ( [ true , false , true ] , & device) ;
287
+ let indices = TestTensorInt :: from_data ( [ 0 , 1 , 0 ] , & device) ;
288
+ let values = TestTensorBool :: < 1 > :: from_data ( [ false , true , true ] , & device) ;
289
+
290
+ let optimized_result = tensor
291
+ . clone ( )
292
+ . select_assign ( 0 , indices. clone ( ) , values. clone ( ) ) ;
293
+
294
+ // Manual default implementation logic
295
+ let int_tensor = tensor. int ( ) ;
296
+ let int_values = values. int ( ) ;
297
+ let assigned = int_tensor. select_assign ( 0 , indices, int_values) ;
298
+ let default_result = assigned. greater_elem ( 0 ) ;
299
+
300
+ optimized_result
301
+ . into_data ( )
302
+ . assert_eq ( & default_result. into_data ( ) , false ) ;
303
+ }
304
+
305
+ #[ test]
306
+ fn should_select_assign_bool_overlapping_indices_vs_default ( ) {
307
+ // Test overlapping indices against default implementation
308
+ use burn_tensor:: backend:: Backend ;
309
+
310
+ let device = Default :: default ( ) ;
311
+ let tensor = TestTensorBool :: < 1 > :: from_data ( [ false , true ] , & device) ;
312
+ let indices = TestTensorInt :: from_data ( [ 0 , 0 ] , & device) ;
313
+ let values = TestTensorBool :: < 1 > :: from_data ( [ true , false ] , & device) ;
314
+
315
+ let optimized_result = tensor
316
+ . clone ( )
317
+ . select_assign ( 0 , indices. clone ( ) , values. clone ( ) ) ;
318
+
319
+ let int_tensor = tensor. int ( ) ;
320
+ let int_values = values. int ( ) ;
321
+ let assigned = int_tensor. select_assign ( 0 , indices, int_values) ;
322
+ let default_result = assigned. greater_elem ( 0 ) ;
323
+
324
+ optimized_result
325
+ . into_data ( )
326
+ . assert_eq ( & default_result. into_data ( ) , false ) ;
327
+ }
328
+
329
+ #[ test]
330
+ fn should_select_assign_bool_true_or_true_accumulation_vs_default ( ) {
331
+ // Test multiple true accumulations against default implementation
332
+ use burn_tensor:: backend:: Backend ;
333
+
334
+ let device = Default :: default ( ) ;
335
+ let tensor = TestTensorBool :: < 1 > :: from_data ( [ true , false ] , & device) ;
336
+ let indices = TestTensorInt :: from_data ( [ 0 , 0 , 0 ] , & device) ;
337
+ let values = TestTensorBool :: < 1 > :: from_data ( [ true , true , true ] , & device) ;
338
+
339
+ let optimized_result = tensor
340
+ . clone ( )
341
+ . select_assign ( 0 , indices. clone ( ) , values. clone ( ) ) ;
342
+
343
+ let int_tensor = tensor. int ( ) ;
344
+ let int_values = values. int ( ) ;
345
+ let assigned = int_tensor. select_assign ( 0 , indices, int_values) ;
346
+ let default_result = assigned. greater_elem ( 0 ) ;
347
+
348
+ optimized_result
349
+ . into_data ( )
350
+ . assert_eq ( & default_result. into_data ( ) , false ) ;
351
+ }
352
+
353
+ #[ test]
354
+ fn should_select_assign_bool_false_to_true_case_vs_default ( ) {
355
+ // Test false OR true case against default implementation
356
+ use burn_tensor:: backend:: Backend ;
357
+
358
+ let device = Default :: default ( ) ;
359
+ let tensor = TestTensorBool :: < 1 > :: from_data ( [ false ] , & device) ;
360
+ let indices = TestTensorInt :: from_data ( [ 0 ] , & device) ;
361
+ let values = TestTensorBool :: < 1 > :: from_data ( [ true ] , & device) ;
362
+
363
+ let optimized_result = tensor
364
+ . clone ( )
365
+ . select_assign ( 0 , indices. clone ( ) , values. clone ( ) ) ;
366
+
367
+ let int_tensor = tensor. int ( ) ;
368
+ let int_values = values. int ( ) ;
369
+ let assigned = int_tensor. select_assign ( 0 , indices, int_values) ;
370
+ let default_result = assigned. greater_elem ( 0 ) ;
371
+
372
+ optimized_result
373
+ . into_data ( )
374
+ . assert_eq ( & default_result. into_data ( ) , false ) ;
375
+ }
376
+
377
+ #[ test]
378
+ fn should_select_assign_bool_empty_indices_vs_default ( ) {
379
+ // Test empty indices against default implementation
380
+ use burn_tensor:: backend:: Backend ;
381
+
382
+ let device = Default :: default ( ) ;
383
+ let tensor = TestTensorBool :: < 1 > :: from_data ( [ true , false , true ] , & device) ;
384
+ let indices = TestTensorInt :: < 1 > :: from_data ( [ ] as [ i32 ; 0 ] , & device) ;
385
+ let values = TestTensorBool :: < 1 > :: from_data ( [ ] as [ bool ; 0 ] , & device) ;
386
+
387
+ let optimized_result = tensor
388
+ . clone ( )
389
+ . select_assign ( 0 , indices. clone ( ) , values. clone ( ) ) ;
390
+
391
+ let int_tensor = tensor. int ( ) ;
392
+ let int_values = values. int ( ) ;
393
+ let assigned = int_tensor. select_assign ( 0 , indices, int_values) ;
394
+ let default_result = assigned. greater_elem ( 0 ) ;
395
+
396
+ optimized_result
397
+ . into_data ( )
398
+ . assert_eq ( & default_result. into_data ( ) , false ) ;
399
+ }
400
+
401
+ #[ test]
402
+ fn should_select_assign_bool_tensor_vs_default ( ) {
403
+ // Test existing basic case against default implementation
404
+ use burn_tensor:: backend:: Backend ;
405
+
406
+ let device = Default :: default ( ) ;
407
+ let tensor = TestTensorBool :: < 1 > :: from_data ( [ true , false , true ] , & device) ;
408
+ let indices = TestTensorInt :: from_data ( [ 0 , 2 ] , & device) ;
409
+ let values = TestTensorBool :: < 1 > :: from_data ( [ false , false ] , & device) ;
410
+
411
+ let optimized_result = tensor
412
+ . clone ( )
413
+ . select_assign ( 0 , indices. clone ( ) , values. clone ( ) ) ;
414
+
415
+ let int_tensor = tensor. int ( ) ;
416
+ let int_values = values. int ( ) ;
417
+ let assigned = int_tensor. select_assign ( 0 , indices, int_values) ;
418
+ let default_result = assigned. greater_elem ( 0 ) ;
419
+
420
+ optimized_result
421
+ . into_data ( )
422
+ . assert_eq ( & default_result. into_data ( ) , false ) ;
423
+ }
424
+
425
+ #[ test]
426
+ #[ should_panic( expected = "Tensors are not eq" ) ]
427
+ fn should_fail_if_replacement_semantics_were_used ( ) {
428
+ // Test that framework uses accumulation, not replacement
429
+ let device = Default :: default ( ) ;
430
+ let tensor = TestTensorBool :: < 1 > :: from_data ( [ true ] , & device) ;
431
+ let indices = TestTensorInt :: from_data ( [ 0 ] , & device) ;
432
+ let values = TestTensorBool :: < 1 > :: from_data ( [ false ] , & device) ;
433
+
434
+ let output = tensor. select_assign ( 0 , indices, values) ;
435
+ let replacement_expected = TensorData :: from ( [ false ] ) ;
436
+
437
+ output. into_data ( ) . assert_eq ( & replacement_expected, false ) ;
438
+ }
439
+
440
+ #[ test]
441
+ #[ should_panic( expected = "Tensors are not eq" ) ]
442
+ fn should_fail_if_replacement_semantics_were_used_vs_default ( ) {
443
+ // Test that default implementation also uses accumulation, not replacement
444
+ use burn_tensor:: backend:: Backend ;
445
+ let device = Default :: default ( ) ;
446
+ let tensor = TestTensorBool :: < 1 > :: from_data ( [ true ] , & device) ;
447
+ let indices = TestTensorInt :: from_data ( [ 0 ] , & device) ;
448
+ let values = TestTensorBool :: < 1 > :: from_data ( [ false ] , & device) ;
449
+
450
+ let int_tensor = tensor. int ( ) ;
451
+ let int_values = values. int ( ) ;
452
+ let assigned = int_tensor. select_assign ( 0 , indices, int_values) ;
453
+ let default_result = assigned. greater_elem ( 0 ) ;
454
+ let replacement_expected = TensorData :: from ( [ false ] ) ;
455
+
456
+ default_result
457
+ . into_data ( )
458
+ . assert_eq ( & replacement_expected, false ) ;
459
+ }
460
+
223
461
#[ test]
224
462
fn should_select_with_negative_dim_2d ( ) {
225
463
// Test using negative dimension indexing on 2D tensor
0 commit comments