Skip to content

Commit

Permalink
MulRound, MulLower and MulAddLower ops
Browse files Browse the repository at this point in the history
  • Loading branch information
mazimkhan committed Nov 18, 2024
1 parent d77be29 commit 4844618
Show file tree
Hide file tree
Showing 5 changed files with 210 additions and 0 deletions.
10 changes: 10 additions & 0 deletions g3doc/quick_reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -704,6 +704,10 @@ All other ops in this section are only available if `HWY_TARGET != HWY_SCALAR`:
truncating it to the lower half for integer inputs. Currently unavailable on
SVE/RVV; use the equivalent `Mul` instead.

* `V`: `f`
<code>V **MulRound**(V a, V b)</code>: Multiplies `a[i]` by `b[i]` and rounds
the result to the nearest integer, with ties rounding to even.

* `V`: `f`, `VI`: `Vec<RebindToSigned<DFromV<V>>>` \
<code>V **MulByPow2**(V a, VI b)</code>: Multiplies `a[i]` by `2^b[i]`.

Expand Down Expand Up @@ -734,6 +738,9 @@ All other ops in this section are only available if `HWY_TARGET != HWY_SCALAR`:
<code>V **MulHigh**(V a, V b)</code>: returns the upper half of `a[i] *
b[i]` in each lane.

* <code>V **MulLower**(V a, V b)</code>: returns `a[0] * b[0]` in the
first lane and `a[i]` otherwise.

* `V`: `i16` \
<code>V **MulFixedPoint15**(V a, V b)</code>: returns the result of
multiplying two Q1.15 fixed-point numbers. This corresponds to doubling the
Expand Down Expand Up @@ -856,6 +863,9 @@ variants are somewhat slower on Arm, and unavailable for integer inputs; if the
potentially more efficient than `MulAdd(PromoteOddTo(d, a), PromoteOddTo(d,
b), c)`.

* <code>V **MulAddLower**(V a, V b, V c)</code>: returns `a[0] * b[0] + c[0]`
in the first lane and `a[i]` in all other lanes.

#### Masked arithmetic

All ops in this section return `no` for `mask=false` lanes, and suppress any
Expand Down
71 changes: 71 additions & 0 deletions hwy/ops/arm_sve-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,11 @@ HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SPECIALIZE, _, _)
NAME(svbool_t m, HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \
return sv##OP##_##CHAR##BITS##_x(m, a, b); \
}
// Defines NAME(m, a, b), which forwards to the `_m`-suffixed form of the SVE
// intrinsic sv<OP>_<CHAR><BITS>.
// NOTE(review): per the Arm ACLE naming convention, `_m` is merging
// predication, i.e. lanes where `m` is false are expected to take their value
// from `a` - confirm against the ACLE specification.
#define HWY_SVE_RETV_ARGMVV_M(BASE, CHAR, BITS, HALF, NAME, OP) \
HWY_API HWY_SVE_V(BASE, BITS) \
NAME(svbool_t m, HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) { \
return sv##OP##_##CHAR##BITS##_m(m, a, b); \
}

#define HWY_SVE_RETV_ARGVVV(BASE, CHAR, BITS, HALF, NAME, OP) \
HWY_API HWY_SVE_V(BASE, BITS) \
Expand All @@ -260,6 +265,13 @@ HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SPECIALIZE, _, _)
return sv##OP##_##CHAR##BITS(a, b, c); \
}

// Defines NAME(m, a, b, c), the three-operand counterpart of a masked op: it
// forwards to the `_m`-suffixed form of the SVE intrinsic sv<OP>_<CHAR><BITS>.
// NOTE(review): per the Arm ACLE naming convention, `_m` is merging
// predication, i.e. lanes where `m` is false are expected to take their value
// from the first vector operand `a` - confirm against the ACLE specification.
#define HWY_SVE_RETV_ARGMVVV(BASE, CHAR, BITS, HALF, NAME, OP) \
HWY_API HWY_SVE_V(BASE, BITS) \
NAME(svbool_t m, HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b, \
HWY_SVE_V(BASE, BITS) c) { \
return sv##OP##_##CHAR##BITS##_m(m, a, b, c); \
}

// ------------------------------ Lanes

namespace detail {
Expand Down Expand Up @@ -1272,6 +1284,31 @@ HWY_SVE_FOREACH_F(HWY_SVE_FMA, NegMulSub, nmad)

#undef HWY_SVE_FMA

// ------------------------------ MaskedMulAdd
// Internal helper: generates detail::MaskedMulAdd(m, a, b, c) overloads via
// HWY_SVE_FOREACH, each forwarding to the corresponding svmad_*_m intrinsic.
namespace detail {
HWY_SVE_FOREACH(HWY_SVE_RETV_ARGMVVV, MaskedMulAdd, mad)
}

// ------------------------------ MulAddLower
// Toggles HWY_NATIVE_MUL_ADD_LOWER to signal that this target provides its own
// MulAddLower (the same macro guards the generic implementation).
#if (defined(HWY_NATIVE_MUL_ADD_LOWER) == defined(HWY_TARGET_TOGGLE))
#ifdef HWY_NATIVE_MUL_ADD_LOWER
#undef HWY_NATIVE_MUL_ADD_LOWER
#else
#define HWY_NATIVE_MUL_ADD_LOWER
#endif

// Defines NAME(a, b, c) as detail::MaskedMulAdd with a predicate built from
// svptrue_pat_b<BITS>(SV_VL1). Documented contract: lane 0 is
// a[0] * b[0] + c[0], all other lanes are a[i].
// NOTE(review): relies on SV_VL1 selecting only the first lane and on the
// merging (`_m`) intrinsic passing `a` through elsewhere - confirm with ACLE.
#define HWY_SVE_MUL_ADD_LOWER(BASE, CHAR, BITS, HALF, NAME, OP) \
HWY_API HWY_SVE_V(BASE, BITS) \
NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b, \
HWY_SVE_V(BASE, BITS) c) { \
return detail::MaskedMulAdd(svptrue_pat_b##BITS(SV_VL1), a, b, c); \
}

// The OP argument is unused by HWY_SVE_MUL_ADD_LOWER, hence the `_` placeholder.
HWY_SVE_FOREACH(HWY_SVE_MUL_ADD_LOWER, MulAddLower, _)
#undef HWY_SVE_MUL_ADD_LOWER

#endif // HWY_NATIVE_MUL_ADD_LOWER

// ------------------------------ Round etc.

HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPV, Round, rintn)
Expand Down Expand Up @@ -1584,6 +1621,26 @@ HWY_API V MaskedSatSubOr(V no, M m, V a, V b) {
}
#endif

// ------------------------------ MaskedMul_M
// Internal helper: generates detail::MaskedMul_M(m, a, b) overloads via
// HWY_SVE_FOREACH, each forwarding to the corresponding svmul_*_m intrinsic.
// No trailing semicolon: the macro expands to complete function definitions,
// so a `;` here would be an empty declaration (flagged by -Wextra-semi).
namespace detail {
HWY_SVE_FOREACH(HWY_SVE_RETV_ARGMVV_M, MaskedMul_M, mul)
}

// ------------------------------ MulLower
// Toggles HWY_NATIVE_MUL_LOWER to signal that this target provides its own
// MulLower (the same macro guards the generic implementation).
#ifdef HWY_NATIVE_MUL_LOWER
#undef HWY_NATIVE_MUL_LOWER
#else
#define HWY_NATIVE_MUL_LOWER
#endif

// Defines NAME(a, b) as detail::MaskedMul_M with a predicate built from
// svptrue_pat_b<BITS>(SV_VL1). Documented contract: lane 0 is a[0] * b[0],
// all other lanes are a[i].
#define HWY_SVE_MUL_LOWER(BASE, CHAR, BITS, HALF, NAME, OP)        \
  HWY_API HWY_SVE_V(BASE, BITS)                                    \
      NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b) {     \
    return detail::MaskedMul_M(svptrue_pat_b##BITS(SV_VL1), a, b); \
  }

// The OP argument is unused by HWY_SVE_MUL_LOWER, hence the `_` placeholder.
HWY_SVE_FOREACH(HWY_SVE_MUL_LOWER, MulLower, _)
// Do not leak the helper macro past this section (matches MulAddLower).
#undef HWY_SVE_MUL_LOWER

// ================================================== COMPARE

// mask = f(vector, vector)
Expand Down Expand Up @@ -1783,6 +1840,18 @@ HWY_API svbool_t IsFinite(const V v) {
return RebindMask(d, detail::LtN(exp, hwy::MaxExponentField<T>()));
}

// ------------------------------ MulByPow2/MulByFloorPow2

// Defines NAME(v, exp), where `exp` is a signed-integer vector with the same
// lane width as `v`. Forwards to the `_x` form of sv<OP>_<CHAR><BITS> with the
// HWY_SVE_PTRUE(BITS) predicate.
#define HWY_SVE_MUL_BY_POW2(BASE, CHAR, BITS, HALF, NAME, OP) \
HWY_API HWY_SVE_V(BASE, BITS) \
NAME(HWY_SVE_V(BASE, BITS) v, HWY_SVE_V(int, BITS) exp) { \
return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v, exp); \
}

// Only MulByPow2 is generated here (float types, via svscale); no
// MulByFloorPow2 definition is visible in this section.
HWY_SVE_FOREACH_F(HWY_SVE_MUL_BY_POW2, MulByPow2, scale)

#undef HWY_SVE_MUL_BY_POW2

// ================================================== MEMORY

// ------------------------------ LoadU/MaskedLoad/LoadDup128/StoreU/Stream
Expand Down Expand Up @@ -6297,7 +6366,9 @@ HWY_API V HighestSetBitIndex(V v) {
#undef HWY_SVE_RETV_ARGV
#undef HWY_SVE_RETV_ARGVN
#undef HWY_SVE_RETV_ARGVV
#undef HWY_SVE_RETV_ARGMVV_M
#undef HWY_SVE_RETV_ARGVVV
#undef HWY_SVE_RETV_ARGMVVV
#undef HWY_SVE_T
#undef HWY_SVE_UNDEFINED
#undef HWY_SVE_V
Expand Down
39 changes: 39 additions & 0 deletions hwy/ops/generic_ops-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,22 @@ HWY_API V AddSub(V a, V b) {
return Add(a, negated_even_b);
}

// ------------------------------ MulLower
#if (defined(HWY_NATIVE_MUL_LOWER) == defined(HWY_TARGET_TOGGLE))
#ifdef HWY_NATIVE_MUL_LOWER
#undef HWY_NATIVE_MUL_LOWER
#else
#define HWY_NATIVE_MUL_LOWER
#endif

// Generic fallback: multiplies only the first lane. Returns a[0] * b[0] in
// lane 0 and a[i] in every other lane.
template <class V>
HWY_API V MulLower(V a, V b) {
  const DFromV<V> d;
  // FirstN(d, 1) selects lane 0; `a` also serves as the pass-through value.
  return MaskedMulOr(a, FirstN(d, 1), a, b);
}

#endif  // HWY_NATIVE_MUL_LOWER

// ------------------------------ MaskedAddOr etc.
#if (defined(HWY_NATIVE_MASKED_ARITH) == defined(HWY_TARGET_TOGGLE))
#ifdef HWY_NATIVE_MASKED_ARITH
Expand Down Expand Up @@ -4274,6 +4290,12 @@ HWY_API V operator*(V x, V y) {

#endif // HWY_NATIVE_MUL_64

// ------------------------------ MulRound
// Floating-point only: computes the lane-wise product a[i] * b[i] and then
// applies Round to it.
template <class V, HWY_IF_FLOAT_V(V)>
HWY_API V MulRound(V a, V b) {
  const V product = Mul(a, b);
  return Round(product);
}

// ------------------------------ MulAdd / NegMulAdd

#if (defined(HWY_NATIVE_INT_FMA) == defined(HWY_TARGET_TOGGLE))
Expand Down Expand Up @@ -4305,6 +4327,23 @@ HWY_API V MulSub(V mul, V x, V sub) {
}
#endif // HWY_NATIVE_INT_FMA

// ------------------------------ MulAddLower
#if (defined(HWY_NATIVE_MUL_ADD_LOWER) == defined(HWY_TARGET_TOGGLE))
#ifdef HWY_NATIVE_MUL_ADD_LOWER
#undef HWY_NATIVE_MUL_ADD_LOWER
#else
#define HWY_NATIVE_MUL_ADD_LOWER
#endif

// Generic fallback: returns a[0] * b[0] + c[0] in lane 0 and a[i] in every
// other lane. MulAdd is evaluated on all lanes; only lane 0 of it is kept.
template <class V>
HWY_API V MulAddLower(const V a, const V b, const V c) {
  const DFromV<V> d;
  const MFromD<DFromV<V>> first_lane = FirstN(d, 1);
  const V mul_add = MulAdd(a, b, c);
  return IfThenElse(first_lane, mul_add, a);
}

#endif  // HWY_NATIVE_MUL_ADD_LOWER

// ------------------------------ Integer MulSub / NegMulSub
#if (defined(HWY_NATIVE_INT_FMSUB) == defined(HWY_TARGET_TOGGLE))
#ifdef HWY_NATIVE_INT_FMSUB
Expand Down
62 changes: 62 additions & 0 deletions hwy/tests/masked_arithmetic_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,66 @@ HWY_NOINLINE void TestAllFloatExceptions() {
ForFloatTypes(ForPartialVectors<TestFloatExceptions>());
}

// Checks MulLower: lane 0 holds the product, all other lanes pass `a` through.
struct TestMulLower {
  template <typename T, class D>
  HWY_NOINLINE void operator()(T /*unused*/, D d) {
    // All-zero inputs must yield all zeros.
    const auto zero = Zero(d);
    HWY_ASSERT_VEC_EQ(d, zero, MulLower(zero, zero));

    // a = {2, 3, 4, ...}, b = {3, 4, 5, ...}.
    const auto a = Iota(d, 2);
    const auto b = Iota(d, 3);

    const size_t num_lanes = Lanes(d);
    auto expected = AllocateAligned<T>(num_lanes);
    for (size_t lane = 0; lane < num_lanes; ++lane) {
      // Lane 0: a[0] * b[0] = 2 * 3; other lanes: a[lane] = lane + 2.
      expected[lane] = (lane == 0) ? ConvertScalarTo<T>(2 * 3)
                                   : ConvertScalarTo<T>(lane + 2);
    }

    HWY_ASSERT_VEC_EQ(d, expected.get(), MulLower(a, b));
  }
};

// Test entry point: instantiates TestMulLower through
// ForAllTypes(ForPartialVectors<...>).
HWY_NOINLINE void TestAllMulLower() {
ForAllTypes(ForPartialVectors<TestMulLower>());
}

// Checks MulAddLower: lane 0 holds a[0] * b[0] + c[0], all other lanes pass
// `a` through unchanged.
struct TestMulAddLower {
  template <typename T, class D>
  HWY_NOINLINE void operator()(T /*unused*/, D d) {
    const Vec<D> v0 = Zero(d);

    // Test all zeros
    HWY_ASSERT_VEC_EQ(d, v0, MulAddLower(v0, v0, v0));

    // Test upper lanes of a being passed through
    const Vec<D> v1 = Iota(d, 1);
    const Vec<D> v2 = Iota(d, 2);
    const Vec<D> v3 = Iota(d, 3);

    const size_t N = Lanes(d);
    auto expected = AllocateAligned<T>(N);

    for (size_t i = 0; i < N; ++i) {
      if (i == 0) {
        // a[0] * b[0] + c[0] = 1 * 2 + 3
        expected[i] = ConvertScalarTo<T>(5);
      } else {
        // a[i] = i + 1. Use ConvertScalarTo rather than static_cast so the
        // conversion also works for non-builtin lane types, matching the
        // other tests in this file.
        expected[i] = ConvertScalarTo<T>(i + 1);
      }
    }

    HWY_ASSERT_VEC_EQ(d, expected.get(), MulAddLower(v1, v2, v3));
  }
};

// Test entry point: instantiates TestMulAddLower through
// ForAllTypes(ForPartialVectors<...>).
// NOTE(review): the name repeats "Test" (TestAllTest...), unlike the sibling
// TestAllMulLower; renaming would also require updating the
// HWY_EXPORT_AND_TEST_P registration.
HWY_NOINLINE void TestAllTestMulAddLower() {
ForAllTypes(ForPartialVectors<TestMulAddLower>());
}
} // namespace
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
Expand All @@ -394,6 +454,8 @@ HWY_EXPORT_AND_TEST_P(HwyMaskedArithmeticTest, TestAllSatAddSub);
HWY_EXPORT_AND_TEST_P(HwyMaskedArithmeticTest, TestAllDiv);
HWY_EXPORT_AND_TEST_P(HwyMaskedArithmeticTest, TestAllIntegerDivMod);
HWY_EXPORT_AND_TEST_P(HwyMaskedArithmeticTest, TestAllFloatExceptions);
HWY_EXPORT_AND_TEST_P(HwyMaskedArithmeticTest, TestAllMulLower);
HWY_EXPORT_AND_TEST_P(HwyMaskedArithmeticTest, TestAllTestMulAddLower);
HWY_AFTER_TEST();
} // namespace
} // namespace hwy
Expand Down
28 changes: 28 additions & 0 deletions hwy/tests/mul_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,33 @@ HWY_NOINLINE void TestAllMulOdd() {
// uint64_t MulOdd is already tested in TestMulEvenOdd64
}

// Checks MulRound: the product is rounded to the nearest integer, with ties
// going to the even neighbor.
struct TestMulRound {
  template <typename T, class D>
  HWY_NOINLINE void operator()(T /*unused*/, D d) {
    const Vec<D> v0 = Zero(d);
    const Vec<D> v1 = Set(d, ConvertScalarTo<T>(1));
    const Vec<D> v2 = Set(d, ConvertScalarTo<T>(2));

    // Test that we correctly get all zeros
    HWY_ASSERT_VEC_EQ(d, v0, MulRound(v0, v0));

    // Test that we round to closest even in case of tie. 0.5 -> 0 alone would
    // also pass with truncation, so additionally check 1.5 -> 2 (rounds up to
    // even) and 2.5 -> 2 (rounds down to even). All inputs are exactly
    // representable in every tested float type.
    const Vec<D> v_half = Set(d, ConvertScalarTo<T>(0.5f));
    const Vec<D> v_one_and_half = Set(d, ConvertScalarTo<T>(1.5f));
    const Vec<D> v_two_and_half = Set(d, ConvertScalarTo<T>(2.5f));

    HWY_ASSERT_VEC_EQ(d, v0, MulRound(v_half, v1));
    HWY_ASSERT_VEC_EQ(d, v2, MulRound(v_one_and_half, v1));
    HWY_ASSERT_VEC_EQ(d, v2, MulRound(v_two_and_half, v1));

    // Test arbitrary multiplication: 6.75 * 3.33 = 22.4775 -> 22
    const Vec<D> v_a = Set(d, ConvertScalarTo<T>(6.75));
    const Vec<D> v_b = Set(d, ConvertScalarTo<T>(3.33));
    const Vec<D> expected = Set(d, ConvertScalarTo<T>(22));

    HWY_ASSERT_VEC_EQ(d, expected, MulRound(v_a, v_b));
  }
};

// Test entry point: instantiates TestMulRound through
// ForFloatTypes(ForPartialVectors<...>); MulRound is only defined for
// floating-point vectors.
HWY_NOINLINE void TestAllMulRound() {
ForFloatTypes(ForPartialVectors<TestMulRound>());
}

} // namespace
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
Expand All @@ -439,6 +466,7 @@ HWY_EXPORT_AND_TEST_P(HwyMulTest, TestAllMulHigh);
HWY_EXPORT_AND_TEST_P(HwyMulTest, TestAllMulFixedPoint15);
HWY_EXPORT_AND_TEST_P(HwyMulTest, TestAllMulEven);
HWY_EXPORT_AND_TEST_P(HwyMulTest, TestAllMulOdd);
HWY_EXPORT_AND_TEST_P(HwyMulTest, TestAllMulRound);
HWY_AFTER_TEST();
} // namespace
} // namespace hwy
Expand Down

0 comments on commit 4844618

Please sign in to comment.