Skip to content

Commit

Permalink
Fix review comments
Browse files Browse the repository at this point in the history
Remove MulLower
Use MaskedMulOr instead
Replace MulAddLower with MaskedMulAddOr
  • Loading branch information
wbb-ccl committed Jan 29, 2025
1 parent 4844618 commit 11505cf
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 115 deletions.
8 changes: 2 additions & 6 deletions g3doc/quick_reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -738,9 +738,6 @@ All other ops in this section are only available if `HWY_TARGET != HWY_SCALAR`:
<code>V **MulHigh**(V a, V b)</code>: returns the upper half of `a[i] *
b[i]` in each lane.

* <code>V **MulLower**(V a, V b)</code>: returns `a[0] * b[0]` in the
first lane and `a[i]` otherwise.

* `V`: `i16` \
<code>V **MulFixedPoint15**(V a, V b)</code>: returns the result of
multiplying two Q1.15 fixed-point numbers. This corresponds to doubling the
Expand Down Expand Up @@ -863,9 +860,6 @@ variants are somewhat slower on Arm, and unavailable for integer inputs; if the
potentially more efficient than `MulAdd(PromoteOddTo(d, a), PromoteOddTo(d,
b), c)`.

* <code>V **MulAddLower**(V a, V b, V c)</code>: returns `a[0] * b[0] + c[0]`
and `a[i]` in all other lanes.

#### Masked arithmetic

All ops in this section return `no` for `mask=false` lanes, and suppress any
Expand Down Expand Up @@ -896,6 +890,8 @@ not a concern, these are equivalent to, and potentially more efficient than,
<code>V **MaskedSatSubOr**(V no, M m, V a, V b)</code>: returns `a[i] -
b[i]` saturated to the minimum/maximum representable value, or `no[i]` if
`m[i]` is false.
* <code>V **MaskedMulAddOr**(V no, M m, V mul, V x, V add)</code>: returns
`mul[i] * x[i] + add[i]` or `no[i]` if `m[i]` is false.

#### Shifts

Expand Down
51 changes: 12 additions & 39 deletions hwy/ops/arm_sve-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SPECIALIZE, _, _)
HWY_API HWY_SVE_V(BASE, BITS) \
NAME(svbool_t m, HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b, \
HWY_SVE_V(BASE, BITS) c) { \
return sv##OP##_##CHAR##BITS##_m(m, a, b, c); \
return sv##OP##_##CHAR##BITS##_x(m, a, b, c); \
}

// ------------------------------ Lanes
Expand Down Expand Up @@ -1284,31 +1284,6 @@ HWY_SVE_FOREACH_F(HWY_SVE_FMA, NegMulSub, nmad)

#undef HWY_SVE_FMA

// ------------------------------ MaskedMulAdd
namespace detail {
HWY_SVE_FOREACH(HWY_SVE_RETV_ARGMVVV, MaskedMulAdd, mad)
}

// ------------------------------ MulAddLower
#if (defined(HWY_NATIVE_MUL_ADD_LOWER) == defined(HWY_TARGET_TOGGLE))
#ifdef HWY_NATIVE_MUL_ADD_LOWER
#undef HWY_NATIVE_MUL_ADD_LOWER
#else
#define HWY_NATIVE_MUL_ADD_LOWER
#endif

#define HWY_SVE_MUL_ADD_LOWER(BASE, CHAR, BITS, HALF, NAME, OP) \
HWY_API HWY_SVE_V(BASE, BITS) \
NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b, \
HWY_SVE_V(BASE, BITS) c) { \
return detail::MaskedMulAdd(svptrue_pat_b##BITS(SV_VL1), a, b, c); \
}

HWY_SVE_FOREACH(HWY_SVE_MUL_ADD_LOWER, MulAddLower, _)
#undef HWY_SVE_MUL_ADD_LOWER

#endif // HWY_NATIVE_MUL_ADD_LOWER

// ------------------------------ Round etc.

HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPV, Round, rintn)
Expand Down Expand Up @@ -1621,25 +1596,23 @@ HWY_API V MaskedSatSubOr(V no, M m, V a, V b) {
}
#endif

// ------------------------------ MaskedMul_M
// ------------------------------ MaskedMulAddOr
namespace detail {
HWY_SVE_FOREACH(HWY_SVE_RETV_ARGMVV_M, MaskedMul_M, mul);
HWY_SVE_FOREACH(HWY_SVE_RETV_ARGMVVV, MaskedMulAdd, mad)
}

// ------------------------------ MulLower
#ifdef HWY_NATIVE_MUL_LOWER
#undef HWY_NATIVE_MUL_LOWER
// Per-target flag to prevent generic_ops-inl.h from defining int
// MaskedMulAddOr.
#ifdef HWY_NATIVE_MASKED_INT_FMA
#undef HWY_NATIVE_MASKED_INT_FMA
#else
#define HWY_NATIVE_MUL_LOWER
#define HWY_NATIVE_MASKED_INT_FMA
#endif

#define HWY_SVE_MUL_LOWER(BASE, CHAR, BITS, HALF, NAME, OP) \
HWY_API HWY_SVE_V(BASE, BITS) \
NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \
return detail::MaskedMul_M(svptrue_pat_b##BITS(SV_VL1), a, b); \
}

HWY_SVE_FOREACH(HWY_SVE_MUL_LOWER, MulLower, _)
// Returns `mul[i] * x[i] + add[i]` for lanes where `m[i]` is true, and
// `no[i]` otherwise. Uses the predicated SVE fused multiply-add; the final
// select is still required because `_x` predication leaves inactive lanes
// undefined rather than passing `no` through.
template <class V, class M>
HWY_API V MaskedMulAddOr(V no, M m, V mul, V x, V add) {
  const V fma = detail::MaskedMulAdd(m, mul, x, add);
  return IfThenElse(m, fma, no);
}

// ================================================== COMPARE

Expand Down
36 changes: 9 additions & 27 deletions hwy/ops/generic_ops-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -520,22 +520,6 @@ HWY_API V AddSub(V a, V b) {
return Add(a, negated_even_b);
}

#if (defined(HWY_NATIVE_MUL_LOWER) == defined(HWY_TARGET_TOGGLE))
#ifdef HWY_NATIVE_MUL_LOWER
#undef HWY_NATIVE_MUL_LOWER
#else
#define HWY_NATIVE_MUL_LOWER
#endif

template <class V>
HWY_API V MulLower(V a, V b) {
const DFromV<V> d;
const auto first_mask = FirstN(d, 1);
return MaskedMulOr(a, first_mask, a, b);
}

#endif // HWY_NATIVE_MUL_LOWER

// ------------------------------ MaskedAddOr etc.
#if (defined(HWY_NATIVE_MASKED_ARITH) == defined(HWY_TARGET_TOGGLE))
#ifdef HWY_NATIVE_MASKED_ARITH
Expand Down Expand Up @@ -4327,22 +4311,20 @@ HWY_API V MulSub(V mul, V x, V sub) {
}
#endif // HWY_NATIVE_INT_FMA

// ------------------------------ MulAddLower
#if (defined(HWY_NATIVE_MUL_ADD_LOWER) == defined(HWY_TARGET_TOGGLE))
#ifdef HWY_NATIVE_MUL_ADD_LOWER
#undef HWY_NATIVE_MUL_ADD_LOWER
// ------------------------------ MaskedMulAddOr
#if (defined(HWY_NATIVE_MASKED_INT_FMA) == defined(HWY_TARGET_TOGGLE))
#ifdef HWY_NATIVE_MASKED_INT_FMA
#undef HWY_NATIVE_MASKED_INT_FMA
#else
#define HWY_NATIVE_MUL_ADD_LOWER
#define HWY_NATIVE_MASKED_INT_FMA
#endif

template <class V>
HWY_API V MulAddLower(const V a, const V b, const V c) {
const DFromV<V> d;
const MFromD<DFromV<V>> LowerMask = FirstN(d, 1);
return IfThenElse(LowerMask, MulAdd(a, b, c), a);
// Generic fallback: returns `mul[i] * x[i] + add[i]` for lanes where `m[i]`
// is true, and `no[i]` otherwise. Computes the unmasked FMA first, then
// blends with `no` under the mask.
template <class V, class M>
HWY_API V MaskedMulAddOr(V no, M m, V mul, V x, V add) {
  const V fma = MulAdd(mul, x, add);
  return IfThenElse(m, fma, no);
}

#endif // HWY_NATIVE_MUL_ADD_LOWER
#endif // HWY_NATIVE_MASKED_INT_FMA

// ------------------------------ Integer MulSub / NegMulSub
#if (defined(HWY_NATIVE_INT_FMSUB) == defined(HWY_TARGET_TOGGLE))
Expand Down
74 changes: 31 additions & 43 deletions hwy/tests/masked_arithmetic_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -379,65 +379,54 @@ HWY_NOINLINE void TestAllFloatExceptions() {
ForFloatTypes(ForPartialVectors<TestFloatExceptions>());
}

struct TestMulLower {
struct TestMaskedMulAdd {
template <typename T, class D>
HWY_NOINLINE void operator()(T /*unused*/, D d) {
const auto v0 = Zero(d);

HWY_ASSERT_VEC_EQ(d, v0, MulLower(v0, v0));

const auto v2 = Iota(d, 2);
const auto v3 = Iota(d, 3);
RandomState rng;
const Vec<D> k0 = Zero(d);
const Vec<D> v1 = Iota(d, 1);
const Vec<D> v2 = Iota(d, 2);

using TI = MakeSigned<T>; // For mask > 0 comparison
const Rebind<TI, D> di;
using VI = Vec<decltype(di)>;
const size_t N = Lanes(d);
auto bool_lanes = AllocateAligned<TI>(N);
auto expected = AllocateAligned<T>(N);
HWY_ASSERT(bool_lanes && expected);
HWY_ASSERT_VEC_EQ(d, k0, MaskedMulAddOr(v1, MaskTrue(d), k0, k0, k0));
HWY_ASSERT_VEC_EQ(d, v2, MaskedMulAddOr(v1, MaskTrue(d), k0, v1, v2));
HWY_ASSERT_VEC_EQ(d, v2, MaskedMulAddOr(v1, MaskTrue(d), v1, k0, v2));
HWY_ASSERT_VEC_EQ(d, v1, MaskedMulAddOr(v1, MaskFalse(d), k0, k0, k0));
HWY_ASSERT_VEC_EQ(d, v1, MaskedMulAddOr(v1, MaskFalse(d), k0, v1, v2));
HWY_ASSERT_VEC_EQ(d, v1, MaskedMulAddOr(v1, MaskFalse(d), v1, k0, v2));

for (size_t i = 0; i < N; ++i) {
if (i == 0) {
expected[i] = ConvertScalarTo<T>(2 * 3);
bool_lanes[i] = (Random32(&rng) & 1024) ? TI(1) : TI(0);
if (bool_lanes[i]) {
expected[i] = ConvertScalarTo<T>((i + 1) * (i + 2));
} else {
expected[i] = ConvertScalarTo<T>(i + 2);
expected[i] = ConvertScalarTo<T>(i + 1);
}
}

HWY_ASSERT_VEC_EQ(d, expected.get(), MulLower(v2, v3));
}
};

HWY_NOINLINE void TestAllMulLower() {
ForAllTypes(ForPartialVectors<TestMulLower>());
}

struct TestMulAddLower {
template <typename T, class D>
HWY_NOINLINE void operator()(T /*unused*/, D d) {
const Vec<D> v0 = Zero(d);

// Test all zeros
HWY_ASSERT_VEC_EQ(d, v0, MulAddLower(v0, v0, v0));

// Test upper lanes of a being passed through
const Vec<D> v1 = Iota(d, 1);
const Vec<D> v2 = Iota(d, 2);
const Vec<D> v3 = Iota(d, 3);

const size_t N = Lanes(d);
auto expected = AllocateAligned<T>(N);
const VI mask_i = Load(di, bool_lanes.get());
const Mask<D> mask = RebindMask(d, Gt(mask_i, Zero(di)));
HWY_ASSERT_VEC_EQ(d, expected.get(), MaskedMulAddOr(v1, mask, v2, v1, k0));
HWY_ASSERT_VEC_EQ(d, expected.get(), MaskedMulAddOr(v1, mask, v1, v2, k0));

for (size_t i = 0; i < N; ++i) {
if (i == 0) {
expected[i] = ConvertScalarTo<T>(5);
if (bool_lanes[i]) {
expected[i] = ConvertScalarTo<T>((i + 2) * (i + 2) + (i + 1));
} else {
expected[i] = static_cast<T>(i + 1);
expected[i] = ConvertScalarTo<T>(i + 2);
}
}

HWY_ASSERT_VEC_EQ(d, expected.get(), MulAddLower(v1, v2, v3));
HWY_ASSERT_VEC_EQ(d, expected.get(), MaskedMulAddOr(v2, mask, v2, v2, v1));
}
};

HWY_NOINLINE void TestAllTestMulAddLower() {
ForAllTypes(ForPartialVectors<TestMulAddLower>());
HWY_NOINLINE void TestAllMaskedMulAdd() {
ForAllTypes(ForPartialVectors<TestMaskedMulAdd>());
}
} // namespace
// NOLINTNEXTLINE(google-readability-namespace-comments)
Expand All @@ -454,8 +443,7 @@ HWY_EXPORT_AND_TEST_P(HwyMaskedArithmeticTest, TestAllSatAddSub);
HWY_EXPORT_AND_TEST_P(HwyMaskedArithmeticTest, TestAllDiv);
HWY_EXPORT_AND_TEST_P(HwyMaskedArithmeticTest, TestAllIntegerDivMod);
HWY_EXPORT_AND_TEST_P(HwyMaskedArithmeticTest, TestAllFloatExceptions);
HWY_EXPORT_AND_TEST_P(HwyMaskedArithmeticTest, TestAllMulLower);
HWY_EXPORT_AND_TEST_P(HwyMaskedArithmeticTest, TestAllTestMulAddLower);
HWY_EXPORT_AND_TEST_P(HwyMaskedArithmeticTest, TestAllMaskedMulAdd);
HWY_AFTER_TEST();
} // namespace
} // namespace hwy
Expand Down

0 comments on commit 11505cf

Please sign in to comment.