@@ -352,6 +352,128 @@ HWY_NOINLINE void TestAllSumsOf8() {
352
352
ForGEVectors<64 , TestSumsOf8>()(uint8_t ());
353
353
}
354
354
355
+ struct TestMaskedReduceSum {
356
+ template <typename T, class D >
357
+ HWY_NOINLINE void operator ()(T /* unused*/ , D d) {
358
+ RandomState rng;
359
+
360
+ using TI = MakeSigned<T>;
361
+ const Rebind<TI, D> di;
362
+ const Vec<D> v2 = Iota (d, 2 );
363
+
364
+ const size_t N = Lanes (d);
365
+ auto bool_lanes = AllocateAligned<TI>(N);
366
+ HWY_ASSERT (bool_lanes);
367
+
368
+ for (size_t rep = 0 ; rep < AdjustedReps (200 ); ++rep) {
369
+ T expected = 0 ;
370
+ for (size_t i = 0 ; i < N; ++i) {
371
+ bool_lanes[i] = (Random32 (&rng) & 1024 ) ? TI (1 ) : TI (0 );
372
+ if (bool_lanes[i]) {
373
+ expected += ConvertScalarTo<T>(i + 2 );
374
+ }
375
+ }
376
+
377
+ const auto mask_i = Load (di, bool_lanes.get ());
378
+ const Mask<D> mask = RebindMask (d, Gt (mask_i, Zero (di)));
379
+
380
+ // If all elements are disabled the result is implementation defined
381
+ if (AllFalse (d, mask)) {
382
+ continue ;
383
+ }
384
+
385
+ HWY_ASSERT_EQ (expected, MaskedReduceSum (d, mask, v2));
386
+ }
387
+ }
388
+ };
389
+
390
+ HWY_NOINLINE void TestAllMaskedReduceSum () {
391
+ ForAllTypes (ForPartialVectors<TestMaskedReduceSum>());
392
+ }
393
+
394
+ struct TestMaskedReduceMin {
395
+ template <typename T, class D >
396
+ HWY_NOINLINE void operator ()(T /* unused*/ , D d) {
397
+ RandomState rng;
398
+
399
+ using TI = MakeSigned<T>;
400
+ const Rebind<TI, D> di;
401
+ const Vec<D> v2 = Iota (d, 2 );
402
+
403
+ const size_t N = Lanes (d);
404
+ auto bool_lanes = AllocateAligned<TI>(N);
405
+ HWY_ASSERT (bool_lanes);
406
+
407
+ for (size_t rep = 0 ; rep < AdjustedReps (200 ); ++rep) {
408
+ T expected =
409
+ ConvertScalarTo<T>(N + 3 ); // larger than any values in the vector
410
+ for (size_t i = 0 ; i < N; ++i) {
411
+ bool_lanes[i] = (Random32 (&rng) & 1024 ) ? TI (1 ) : TI (0 );
412
+ if (bool_lanes[i]) {
413
+ if (expected > ConvertScalarTo<T>(i + 2 )) {
414
+ expected = ConvertScalarTo<T>(i + 2 );
415
+ }
416
+ }
417
+ }
418
+
419
+ const auto mask_i = Load (di, bool_lanes.get ());
420
+ const Mask<D> mask = RebindMask (d, Gt (mask_i, Zero (di)));
421
+
422
+ // If all elements are disabled the result is implementation defined
423
+ if (AllFalse (d, mask)) {
424
+ continue ;
425
+ }
426
+
427
+ HWY_ASSERT_EQ (expected, MaskedReduceMin (d, mask, v2));
428
+ }
429
+ }
430
+ };
431
+
432
+ HWY_NOINLINE void TestAllMaskedReduceMin () {
433
+ ForAllTypes (ForPartialVectors<TestMaskedReduceMin>());
434
+ }
435
+
436
+ struct TestMaskedReduceMax {
437
+ template <typename T, class D >
438
+ HWY_NOINLINE void operator ()(T /* unused*/ , D d) {
439
+ RandomState rng;
440
+
441
+ using TI = MakeSigned<T>;
442
+ const Rebind<TI, D> di;
443
+ const Vec<D> v2 = Iota (d, 2 );
444
+
445
+ const size_t N = Lanes (d);
446
+ auto bool_lanes = AllocateAligned<TI>(N);
447
+ HWY_ASSERT (bool_lanes);
448
+
449
+ for (size_t rep = 0 ; rep < AdjustedReps (200 ); ++rep) {
450
+ T expected = 0 ;
451
+ for (size_t i = 0 ; i < N; ++i) {
452
+ bool_lanes[i] = (Random32 (&rng) & 1024 ) ? TI (1 ) : TI (0 );
453
+ if (bool_lanes[i]) {
454
+ if (expected < ConvertScalarTo<T>(i + 2 )) {
455
+ expected = ConvertScalarTo<T>(i + 2 );
456
+ }
457
+ }
458
+ }
459
+
460
+ const auto mask_i = Load (di, bool_lanes.get ());
461
+ const Mask<D> mask = RebindMask (d, Gt (mask_i, Zero (di)));
462
+
463
+ // If all elements are disabled the result is implementation defined
464
+ if (AllFalse (d, mask)) {
465
+ continue ;
466
+ }
467
+
468
+ HWY_ASSERT_EQ (expected, MaskedReduceMax (d, mask, v2));
469
+ }
470
+ }
471
+ };
472
+
473
+ HWY_NOINLINE void TestAllMaskedReduceMax () {
474
+ ForAllTypes (ForPartialVectors<TestMaskedReduceMax>());
475
+ }
476
+
355
477
} // namespace
356
478
// NOLINTNEXTLINE(google-readability-namespace-comments)
357
479
} // namespace HWY_NAMESPACE
@@ -367,6 +489,10 @@ HWY_EXPORT_AND_TEST_P(HwyReductionTest, TestAllMinMaxOfLanes);
367
489
HWY_EXPORT_AND_TEST_P (HwyReductionTest, TestAllSumsOf2);
368
490
HWY_EXPORT_AND_TEST_P (HwyReductionTest, TestAllSumsOf4);
369
491
HWY_EXPORT_AND_TEST_P (HwyReductionTest, TestAllSumsOf8);
492
+
493
+ HWY_EXPORT_AND_TEST_P (HwyReductionTest, TestAllMaskedReduceSum);
494
+ HWY_EXPORT_AND_TEST_P (HwyReductionTest, TestAllMaskedReduceMin);
495
+ HWY_EXPORT_AND_TEST_P (HwyReductionTest, TestAllMaskedReduceMax);
370
496
HWY_AFTER_TEST ();
371
497
} // namespace
372
498
} // namespace hwy
0 commit comments