Skip to content

Commit 326dba3

Browse files
Merge pull request google#2428 from cambridgeconsultants:cc_up_masked_ops
PiperOrigin-RevId: 722985349
2 parents f129551 + af4183d commit 326dba3

6 files changed

+245
-0
lines changed

g3doc/quick_reference.md

+21
Original file line numberDiff line numberDiff line change
@@ -1125,6 +1125,9 @@ types, and on SVE/RVV.
11251125

11261126
* <code>V **AndNot**(V a, V b)</code>: returns `~a[i] & b[i]`.
11271127

1128+
* <code>V **MaskedOr**(M m, V a, V b)</code>: returns `a[i] | b[i]`
1129+
or zero if `m[i]` is false.
1130+
11281131
The following three-argument functions may be more efficient than assembling
11291132
them from 2-argument functions:
11301133

@@ -2491,6 +2494,24 @@ more efficient on some targets.
24912494
* <code>T **ReduceMin**(D, V v)</code>: returns the minimum of all lanes.
24922495
* <code>T **ReduceMax**(D, V v)</code>: returns the maximum of all lanes.
24932496
2497+
### Masked reductions
2498+
2499+
**Note**: Horizontal operations (across lanes of the same vector) such as
2500+
reductions are slower than normal SIMD operations and are typically used outside
2501+
critical loops.
2502+
2503+
All ops in this section ignore lanes where `mask=false`. These are equivalent
2504+
to, and potentially more efficient than, `GetLane(SumOfLanes(d,
2505+
IfThenElseZero(m, v)))` for the sum (min/max instead replace masked-off lanes with a value that cannot be selected). The result is implementation-defined when all mask
2506+
elements are false.
2507+
2508+
* <code>T **MaskedReduceSum**(D, M m, V v)</code>: returns the sum of all lanes
2509+
where `m[i]` is `true`.
2510+
* <code>T **MaskedReduceMin**(D, M m, V v)</code>: returns the minimum of all
2511+
lanes where `m[i]` is `true`.
2512+
* <code>T **MaskedReduceMax**(D, M m, V v)</code>: returns the maximum of all
2513+
lanes where `m[i]` is `true`.
2514+
24942515
### Crypto
24952516
24962517
Ops in this section are only available if `HWY_TARGET != HWY_SCALAR`:

hwy/ops/arm_sve-inl.h

+47
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,12 @@ HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SPECIALIZE, _, _)
260260
NAME(svbool_t m, HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \
261261
return sv##OP##_##CHAR##BITS##_x(m, a, b); \
262262
}
263+
// User-specified mask. Lanes where the mask is false are set to zero (`_z` form).
264+
#define HWY_SVE_RETV_ARGMVV_Z(BASE, CHAR, BITS, HALF, NAME, OP) \
265+
HWY_API HWY_SVE_V(BASE, BITS) \
266+
NAME(svbool_t m, HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \
267+
return sv##OP##_##CHAR##BITS##_z(m, a, b); \
268+
}
263269

264270
#define HWY_SVE_RETV_ARGVVV(BASE, CHAR, BITS, HALF, NAME, OP) \
265271
HWY_API HWY_SVE_V(BASE, BITS) \
@@ -763,6 +769,9 @@ HWY_API V Or(const V a, const V b) {
763769
return BitCast(df, Or(BitCast(du, a), BitCast(du, b)));
764770
}
765771

772+
// ------------------------------ MaskedOr
773+
HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGMVV_Z, MaskedOr, orr)
774+
766775
// ------------------------------ Xor
767776

768777
namespace detail {
@@ -1678,6 +1687,7 @@ namespace detail {
16781687
return sv##OP##_##CHAR##BITS(pg, v); \
16791688
}
16801689

1690+
// TODO: Remove SumOfLanesM in favor of using MaskedReduceSum
16811691
HWY_SVE_FOREACH_UI(HWY_SVE_REDUCE_ADD, SumOfLanesM, addv)
16821692
HWY_SVE_FOREACH_F(HWY_SVE_REDUCE, SumOfLanesM, addv)
16831693

@@ -1725,6 +1735,25 @@ HWY_API TFromD<D> ReduceMax(D d, VFromD<D> v) {
17251735
return detail::MaxOfLanesM(detail::MakeMask(d), v);
17261736
}
17271737

1738+
// ------------------------------ MaskedReduceSum/Min/Max
// Toggles HWY_NATIVE_MASKED_REDUCE_SCALAR; the fallback implementations in
// generic_ops-inl.h are guarded by a comparison of this macro against
// HWY_TARGET_TOGGLE.
#ifdef HWY_NATIVE_MASKED_REDUCE_SCALAR
1739+
#undef HWY_NATIVE_MASKED_REDUCE_SCALAR
1740+
#else
1741+
#define HWY_NATIVE_MASKED_REDUCE_SCALAR
1742+
#endif
1743+
1744+
// Returns the sum of the lanes of `v` selected by mask `m`, by passing `m` as
// the predicate of detail::SumOfLanesM. The tag argument is unused.
template <class D, class M>
1745+
HWY_API TFromD<D> MaskedReduceSum(D /*d*/, M m, VFromD<D> v) {
1746+
return detail::SumOfLanesM(m, v);
1747+
}
1748+
// Returns the minimum of the lanes of `v` selected by mask `m`, by passing `m`
// as the predicate of detail::MinOfLanesM. The tag argument is unused.
template <class D, class M>
1749+
HWY_API TFromD<D> MaskedReduceMin(D /*d*/, M m, VFromD<D> v) {
1750+
return detail::MinOfLanesM(m, v);
1751+
}
1752+
// Returns the maximum of the lanes of `v` selected by mask `m`, by passing `m`
// as the predicate of detail::MaxOfLanesM. The tag argument is unused.
template <class D, class M>
1753+
HWY_API TFromD<D> MaskedReduceMax(D /*d*/, M m, VFromD<D> v) {
1754+
return detail::MaxOfLanesM(m, v);
1755+
}
1756+
17281757
// ------------------------------ SumOfLanes
17291758

17301759
template <class D, HWY_IF_LANES_GT_D(D, 1)>
@@ -5056,6 +5085,23 @@ HWY_API V IfNegativeThenElse(V v, V yes, V no) {
50565085
static_assert(IsSigned<TFromV<V>>(), "Only works for signed/float");
50575086
return IfThenElse(IsNegative(v), yes, no);
50585087
}
5088+
// ------------------------------ IfNegativeThenNegOrUndefIfZero
5089+
5090+
#ifdef HWY_NATIVE_INTEGER_IF_NEGATIVE_THEN_NEG
5091+
#undef HWY_NATIVE_INTEGER_IF_NEGATIVE_THEN_NEG
5092+
#else
5093+
#define HWY_NATIVE_INTEGER_IF_NEGATIVE_THEN_NEG
5094+
#endif
5095+
5096+
// Defines NAME(mask, v). Both arguments are vectors; the predicate is derived
// from `mask` via IsNegative(mask). The `_m` intrinsic receives `v` both as
// its first argument and as its operand, so (per the ACLE merging form) lanes
// whose `mask` lane is not negative presumably keep `v` unchanged - confirm.
#define HWY_SVE_NEG_IF(BASE, CHAR, BITS, HALF, NAME, OP) \
5097+
HWY_API HWY_SVE_V(BASE, BITS) \
5098+
NAME(HWY_SVE_V(BASE, BITS) mask, HWY_SVE_V(BASE, BITS) v) { \
5099+
return sv##OP##_##CHAR##BITS##_m(v, IsNegative(mask), v); \
5100+
}
5101+
5102+
// Instantiates IfNegativeThenNegOrUndefIfZero with OP = neg for every type
// enumerated by HWY_SVE_FOREACH_IF.
HWY_SVE_FOREACH_IF(HWY_SVE_NEG_IF, IfNegativeThenNegOrUndefIfZero, neg)
5103+
5104+
// The helper macro is only needed for the instantiation above.
#undef HWY_SVE_NEG_IF
50595105

50605106
// ------------------------------ AverageRound (ShiftRight)
50615107

@@ -6610,6 +6656,7 @@ HWY_SVE_FOREACH_UI(HWY_SVE_MASKED_LEADING_ZERO_COUNT, MaskedLeadingZeroCount,
66106656
#undef HWY_SVE_IF_NOT_EMULATED_D
66116657
#undef HWY_SVE_PTRUE
66126658
#undef HWY_SVE_RETV_ARGMVV
6659+
#undef HWY_SVE_RETV_ARGMVV_Z
66136660
#undef HWY_SVE_RETV_ARGMV_Z
66146661
#undef HWY_SVE_RETV_ARGMV
66156662
#undef HWY_SVE_RETV_ARGPV

hwy/ops/generic_ops-inl.h

+26
Original file line numberDiff line numberDiff line change
@@ -1013,6 +1013,28 @@ HWY_API TFromD<D> ReduceMax(D d, VFromD<D> v) {
10131013
}
10141014
#endif // HWY_NATIVE_REDUCE_MINMAX_4_UI8
10151015

1016+
// ------------------------------ MaskedReduceSum/Min/Max
// Fallback implementations, compiled only when the comparison of
// HWY_NATIVE_MASKED_REDUCE_SCALAR against HWY_TARGET_TOGGLE below holds.
#if (defined(HWY_NATIVE_MASKED_REDUCE_SCALAR) == defined(HWY_TARGET_TOGGLE))
1017+
#ifdef HWY_NATIVE_MASKED_REDUCE_SCALAR
1018+
#undef HWY_NATIVE_MASKED_REDUCE_SCALAR
1019+
#else
1020+
#define HWY_NATIVE_MASKED_REDUCE_SCALAR
1021+
#endif
1022+
1023+
// Returns the sum of the lanes of `v` where `m` is true: masked-off lanes are
// zeroed via IfThenElseZero before the ordinary ReduceSum.
template <class D, class M>
1024+
HWY_API TFromD<D> MaskedReduceSum(D d, M m, VFromD<D> v) {
1025+
return ReduceSum(d, IfThenElseZero(m, v));
1026+
}
1027+
// Returns the minimum of the lanes of `v` where `m` is true: masked-off lanes
// are replaced with hwy::PositiveInfOrHighestValue<T>() before ReduceMin, so
// an all-false mask yields that replacement value.
template <class D, class M>
1028+
HWY_API TFromD<D> MaskedReduceMin(D d, M m, VFromD<D> v) {
1029+
return ReduceMin(d, IfThenElse(m, v, Set(d, hwy::PositiveInfOrHighestValue <TFromD<D>>())));
1030+
}
1031+
// Returns the maximum of the lanes of `v` where `m` is true: masked-off lanes
// are replaced with hwy::NegativeInfOrLowestValue<T>() before ReduceMax, so
// an all-false mask yields that replacement value.
template <class D, class M>
1032+
HWY_API TFromD<D> MaskedReduceMax(D d, M m, VFromD<D> v) {
1033+
return ReduceMax(d, IfThenElse(m, v, Set(d, hwy::NegativeInfOrLowestValue<TFromD<D>>())));
1034+
}
1035+
1036+
#endif // HWY_NATIVE_MASKED_REDUCE_SCALAR
1037+
10161038
// ------------------------------ IsEitherNaN
10171039
#if (defined(HWY_NATIVE_IS_EITHER_NAN) == defined(HWY_TARGET_TOGGLE))
10181040
#ifdef HWY_NATIVE_IS_EITHER_NAN
@@ -7568,6 +7590,10 @@ HWY_API V BitShuffle(V v, VI idx) {
75687590

75697591
#endif // HWY_NATIVE_BITSHUFFLE
75707592

7593+
// ------------------------------ MaskedOr
// Returns `a[i] | b[i]` in lanes where `m` is true and zero elsewhere, by
// computing the full Or and then zeroing masked-off lanes.
// NOTE(review): unlike neighbouring ops this is not wrapped in a HWY_NATIVE_*
// guard; arm_sve-inl.h also defines MaskedOr for integer types - confirm the
// two overload sets coexist as intended.
template <class V, class M>
7594+
HWY_API V MaskedOr(M m, V a, V b) {
7595+
return IfThenElseZero(m, Or(a, b));
7596+
}
75717597
// ------------------------------ AllBits1/AllBits0
75727598
#if (defined(HWY_NATIVE_ALLONES) == defined(HWY_TARGET_TOGGLE))
75737599
#ifdef HWY_NATIVE_ALLONES

hwy/ops/rvv-inl.h

+2
Original file line numberDiff line numberDiff line change
@@ -4755,6 +4755,8 @@ HWY_API T ReduceMax(D d, const VFromD<D> v) {
47554755

47564756
#undef HWY_RVV_REDUCE
47574757

4758+
// TODO: add native RVV MaskedReduceSum/Min/Max (presumably the generic_ops-inl.h fallback is used until then).
4759+
47584760
// ------------------------------ SumOfLanes
47594761

47604762
template <class D, HWY_IF_LANES_GT_D(D, 1)>

hwy/tests/logical_test.cc

+23
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,28 @@ HWY_NOINLINE void TestAllTestBit() {
146146
ForIntegerTypes(ForPartialVectors<TestTestBit>());
147147
}
148148

149+
// Checks MaskedOr against Or for an all-true mask, and against
// IfThenElse(mask, Or, Zero) for a mask built with FirstN(d, 5).
struct TestMaskedOr {
150+
template <typename T, class D>
151+
HWY_NOINLINE void operator()(T /*unused*/, D d) {
152+
const MFromD<D> all_true = MaskTrue(d);
153+
const auto v1 = Iota(d, 1);
154+
const auto v2 = Iota(d, 2);
155+
156+
// With every lane enabled, MaskedOr must equal the plain Or.
HWY_ASSERT_VEC_EQ(d, Or(v2, v1), MaskedOr(all_true, v1, v2));
157+
158+
const MFromD<D> first_five = FirstN(d, 5);
159+
const Vec<D> v0 = Zero(d);
160+
161+
// Expected result: Or in lanes selected by first_five, zero elsewhere.
const Vec<D> v1_exp = IfThenElse(first_five, Or(v2, v1), v0);
162+
163+
HWY_ASSERT_VEC_EQ(d, v1_exp, MaskedOr(first_five, v1, v2));
164+
}
165+
};
166+
167+
// Runs TestMaskedOr via ForAllTypes over ForPartialVectors.
HWY_NOINLINE void TestAllMaskedLogical() {
168+
ForAllTypes(ForPartialVectors<TestMaskedOr>());
169+
}
170+
149171
struct TestAllBits {
150172
template <class T, class D>
151173
HWY_NOINLINE void operator()(T /*unused*/, D d) {
@@ -185,6 +207,7 @@ HWY_BEFORE_TEST(HwyLogicalTest);
185207
HWY_EXPORT_AND_TEST_P(HwyLogicalTest, TestAllNot);
186208
HWY_EXPORT_AND_TEST_P(HwyLogicalTest, TestAllLogical);
187209
HWY_EXPORT_AND_TEST_P(HwyLogicalTest, TestAllTestBit);
210+
HWY_EXPORT_AND_TEST_P(HwyLogicalTest, TestAllMaskedLogical);
188211
HWY_EXPORT_AND_TEST_P(HwyLogicalTest, TestAllAllBits);
189212

190213
HWY_AFTER_TEST();

hwy/tests/reduction_test.cc

+126
Original file line numberDiff line numberDiff line change
@@ -352,6 +352,128 @@ HWY_NOINLINE void TestAllSumsOf8() {
352352
ForGEVectors<64, TestSumsOf8>()(uint8_t());
353353
}
354354

355+
// Compares MaskedReduceSum on Iota(d, 2) under pseudo-random masks against a
// scalar sum accumulated over the enabled lanes.
struct TestMaskedReduceSum {
356+
template <typename T, class D>
357+
HWY_NOINLINE void operator()(T /*unused*/, D d) {
358+
RandomState rng;
359+
360+
// Mask lanes are staged as signed integers of the same width as T.
using TI = MakeSigned<T>;
361+
const Rebind<TI, D> di;
362+
const Vec<D> v2 = Iota(d, 2);
363+
364+
const size_t N = Lanes(d);
365+
auto bool_lanes = AllocateAligned<TI>(N);
366+
HWY_ASSERT(bool_lanes);
367+
368+
for (size_t rep = 0; rep < AdjustedReps(200); ++rep) {
369+
T expected = 0;
370+
for (size_t i = 0; i < N; ++i) {
371+
// Enable the lane when bit 10 of the random value is set.
bool_lanes[i] = (Random32(&rng) & 1024) ? TI(1) : TI(0);
372+
if (bool_lanes[i]) {
373+
// Lane i of v2 was initialized via Iota(d, 2), i.e. i + 2.
expected += ConvertScalarTo<T>(i + 2);
374+
}
375+
}
376+
377+
// Convert the 0/1 lanes into a mask for D: true where the lane is > 0.
const auto mask_i = Load(di, bool_lanes.get());
378+
const Mask<D> mask = RebindMask(d, Gt(mask_i, Zero(di)));
379+
380+
// If all lanes are disabled the result is implementation-defined, so there
// is nothing to compare against; skip this repetition.
if (AllFalse(d, mask)) {
382+
continue;
383+
}
384+
385+
HWY_ASSERT_EQ(expected, MaskedReduceSum(d, mask, v2));
386+
}
387+
}
388+
};
389+
390+
// Runs TestMaskedReduceSum via ForAllTypes over ForPartialVectors.
HWY_NOINLINE void TestAllMaskedReduceSum() {
391+
ForAllTypes(ForPartialVectors<TestMaskedReduceSum>());
392+
}
393+
394+
// Compares MaskedReduceMin on Iota(d, 2) under pseudo-random masks against a
// scalar minimum computed over the enabled lanes.
struct TestMaskedReduceMin {
395+
template <typename T, class D>
396+
HWY_NOINLINE void operator()(T /*unused*/, D d) {
397+
RandomState rng;
398+
399+
// Mask lanes are staged as signed integers of the same width as T.
using TI = MakeSigned<T>;
400+
const Rebind<TI, D> di;
401+
const Vec<D> v2 = Iota(d, 2);
402+
403+
const size_t N = Lanes(d);
404+
auto bool_lanes = AllocateAligned<TI>(N);
405+
HWY_ASSERT(bool_lanes);
406+
407+
for (size_t rep = 0; rep < AdjustedReps(200); ++rep) {
408+
T expected =
409+
ConvertScalarTo<T>(N + 3); // start above the lane values i + 2 (assumes no wraparound in T)
410+
for (size_t i = 0; i < N; ++i) {
411+
// Enable the lane when bit 10 of the random value is set.
bool_lanes[i] = (Random32(&rng) & 1024) ? TI(1) : TI(0);
412+
if (bool_lanes[i]) {
413+
// Track the smallest enabled lane value (lane i holds i + 2).
if (expected > ConvertScalarTo<T>(i + 2)) {
414+
expected = ConvertScalarTo<T>(i + 2);
415+
}
416+
}
417+
}
418+
419+
// Convert the 0/1 lanes into a mask for D: true where the lane is > 0.
const auto mask_i = Load(di, bool_lanes.get());
420+
const Mask<D> mask = RebindMask(d, Gt(mask_i, Zero(di)));
421+
422+
// If all lanes are disabled the result is implementation-defined, so there
// is nothing to compare against; skip this repetition.
if (AllFalse(d, mask)) {
424+
continue;
425+
}
426+
427+
HWY_ASSERT_EQ(expected, MaskedReduceMin(d, mask, v2));
428+
}
429+
}
430+
};
431+
432+
// Runs TestMaskedReduceMin via ForAllTypes over ForPartialVectors.
HWY_NOINLINE void TestAllMaskedReduceMin() {
433+
ForAllTypes(ForPartialVectors<TestMaskedReduceMin>());
434+
}
435+
436+
// Compares MaskedReduceMax on Iota(d, 2) under pseudo-random masks against a
// scalar maximum computed over the enabled lanes.
struct TestMaskedReduceMax {
437+
template <typename T, class D>
438+
HWY_NOINLINE void operator()(T /*unused*/, D d) {
439+
RandomState rng;
440+
441+
// Mask lanes are staged as signed integers of the same width as T.
using TI = MakeSigned<T>;
442+
const Rebind<TI, D> di;
443+
const Vec<D> v2 = Iota(d, 2);
444+
445+
const size_t N = Lanes(d);
446+
auto bool_lanes = AllocateAligned<TI>(N);
447+
HWY_ASSERT(bool_lanes);
448+
449+
for (size_t rep = 0; rep < AdjustedReps(200); ++rep) {
450+
// Starts at 0, below the lane values i + 2 used for comparison.
T expected = 0;
451+
for (size_t i = 0; i < N; ++i) {
452+
// Enable the lane when bit 10 of the random value is set.
bool_lanes[i] = (Random32(&rng) & 1024) ? TI(1) : TI(0);
453+
if (bool_lanes[i]) {
454+
// Track the largest enabled lane value (lane i holds i + 2).
if (expected < ConvertScalarTo<T>(i + 2)) {
455+
expected = ConvertScalarTo<T>(i + 2);
456+
}
457+
}
458+
}
459+
460+
// Convert the 0/1 lanes into a mask for D: true where the lane is > 0.
const auto mask_i = Load(di, bool_lanes.get());
461+
const Mask<D> mask = RebindMask(d, Gt(mask_i, Zero(di)));
462+
463+
// If all lanes are disabled the result is implementation-defined, so there
// is nothing to compare against; skip this repetition.
if (AllFalse(d, mask)) {
465+
continue;
466+
}
467+
468+
HWY_ASSERT_EQ(expected, MaskedReduceMax(d, mask, v2));
469+
}
470+
}
471+
};
472+
473+
// Runs TestMaskedReduceMax via ForAllTypes over ForPartialVectors.
HWY_NOINLINE void TestAllMaskedReduceMax() {
474+
ForAllTypes(ForPartialVectors<TestMaskedReduceMax>());
475+
}
476+
355477
} // namespace
356478
// NOLINTNEXTLINE(google-readability-namespace-comments)
357479
} // namespace HWY_NAMESPACE
@@ -367,6 +489,10 @@ HWY_EXPORT_AND_TEST_P(HwyReductionTest, TestAllMinMaxOfLanes);
367489
HWY_EXPORT_AND_TEST_P(HwyReductionTest, TestAllSumsOf2);
368490
HWY_EXPORT_AND_TEST_P(HwyReductionTest, TestAllSumsOf4);
369491
HWY_EXPORT_AND_TEST_P(HwyReductionTest, TestAllSumsOf8);
492+
493+
HWY_EXPORT_AND_TEST_P(HwyReductionTest, TestAllMaskedReduceSum);
494+
HWY_EXPORT_AND_TEST_P(HwyReductionTest, TestAllMaskedReduceMin);
495+
HWY_EXPORT_AND_TEST_P(HwyReductionTest, TestAllMaskedReduceMax);
370496
HWY_AFTER_TEST();
371497
} // namespace
372498
} // namespace hwy

0 commit comments

Comments
 (0)