Skip to content

Commit e99df68

Browse files
Merge pull request google#2458 from johnplatts:hwy_get_exp_enh_013125
PiperOrigin-RevId: 721763477
2 parents 140f307 + 45a3804 commit e99df68

8 files changed

+125
-10
lines changed

g3doc/quick_reference.md

+4
Original file line numberDiff line numberDiff line change
@@ -674,6 +674,10 @@ from left to right, of the arguments passed to `Create{2-4}`.
674674
<code>V **GetExponent**(V v)</code>: returns the exponent of `v[i]` as a floating point value.
675675
Essentially calculates `floor(log2(x))`.
676676

677+
* `V`: `{f}`, `VU`: `Vec<RebindToUnsigned<DFromV<V>>>` \
678+
<code>VU **GetBiasedExponent**(V v)</code>: returns the biased exponent of
679+
`v[i]` as an unsigned integer value.
680+
677681
#### Min/Max
678682

679683
**Note**: Min/Max corner cases are target-specific and may change. If either

hwy/ops/generic_ops-inl.h

+22-3
Original file line numberDiff line numberDiff line change
@@ -1249,6 +1249,27 @@ HWY_API V MulByFloorPow2(V v, V exp) {
12491249

12501250
#endif // HWY_NATIVE_MUL_BY_POW2
12511251

1252+
// ------------------------------ GetBiasedExponent
1253+
#if (defined(HWY_NATIVE_GET_BIASED_EXPONENT) == defined(HWY_TARGET_TOGGLE))
1254+
#ifdef HWY_NATIVE_GET_BIASED_EXPONENT
1255+
#undef HWY_NATIVE_GET_BIASED_EXPONENT
1256+
#else
1257+
#define HWY_NATIVE_GET_BIASED_EXPONENT
1258+
#endif
1259+
1260+
template <class V, HWY_IF_FLOAT_V(V)>
1261+
HWY_API VFromD<RebindToUnsigned<DFromV<V>>> GetBiasedExponent(V v) {
1262+
using T = TFromV<V>;
1263+
1264+
const DFromV<V> d;
1265+
const RebindToUnsigned<decltype(d)> du;
1266+
1267+
constexpr int kNumOfMantBits = MantissaBits<T>();
1268+
return ShiftRight<kNumOfMantBits>(BitCast(du, Abs(v)));
1269+
}
1270+
1271+
#endif
1272+
12521273
// ------------------------------ GetExponent
12531274

12541275
#if (defined(HWY_NATIVE_GET_EXPONENT) == defined(HWY_TARGET_TOGGLE))
@@ -1262,14 +1283,12 @@ template <class V, HWY_IF_FLOAT_V(V)>
12621283
HWY_API V GetExponent(V v) {
12631284
const DFromV<V> d;
12641285
using T = TFromV<V>;
1265-
const RebindToUnsigned<decltype(d)> du;
12661286
const RebindToSigned<decltype(d)> di;
12671287

1268-
constexpr uint8_t mantissa_bits = MantissaBits<T>();
12691288
const auto exponent_offset = Set(di, MaxExponentField<T>() >> 1);
12701289

12711290
// extract exponent bits as integer
1272-
const auto encoded_exponent = ShiftRight<mantissa_bits>(BitCast(du, Abs(v)));
1291+
const auto encoded_exponent = GetBiasedExponent(v);
12731292
const auto exponent_int = Sub(BitCast(di, encoded_exponent), exponent_offset);
12741293

12751294
// convert integer to original type

hwy/ops/ppc_vsx-inl.h

+17-3
Original file line numberDiff line numberDiff line change
@@ -1939,9 +1939,6 @@ HWY_API Vec128<T, N> ApproximateReciprocal(Vec128<T, N> v) {
19391939
#endif
19401940
}
19411941

1942-
// TODO: Implement GetExponent using vec_extract_exp (which returns the biased
1943-
// exponent) followed by a subtraction by MaxExponentField<T>() >> 1
1944-
19451942
// ------------------------------ Floating-point square root
19461943

19471944
#if HWY_S390X_HAVE_Z14
@@ -1979,6 +1976,23 @@ HWY_API Vec128<T, N> Sqrt(Vec128<T, N> v) {
19791976
return Vec128<T, N>{vec_sqrt(v.raw)};
19801977
}
19811978

1979+
// ------------------------------ GetBiasedExponent
1980+
1981+
#if HWY_PPC_HAVE_9
1982+
1983+
#ifdef HWY_NATIVE_GET_BIASED_EXPONENT
1984+
#undef HWY_NATIVE_GET_BIASED_EXPONENT
1985+
#else
1986+
#define HWY_NATIVE_GET_BIASED_EXPONENT
1987+
#endif
1988+
1989+
template <class V, HWY_IF_FLOAT3264_V(V)>
1990+
HWY_API VFromD<RebindToUnsigned<DFromV<V>>> GetBiasedExponent(V v) {
1991+
return VFromD<RebindToUnsigned<DFromV<V>>>{vec_extract_exp(v.raw)};
1992+
}
1993+
1994+
#endif // HWY_PPC_HAVE_9
1995+
19821996
// ------------------------------ Min (Gt, IfThenElse)
19831997

19841998
template <typename T, size_t N, HWY_IF_NOT_SPECIAL_FLOAT(T)>

hwy/ops/shared-inl.h

+3
Original file line numberDiff line numberDiff line change
@@ -631,8 +631,11 @@ HWY_API bool IsAligned(D d, T* ptr) {
631631
#define HWY_IF_SIGNED_V(V) HWY_IF_SIGNED(hwy::HWY_NAMESPACE::TFromV<V>)
632632
#define HWY_IF_FLOAT_V(V) HWY_IF_FLOAT(hwy::HWY_NAMESPACE::TFromV<V>)
633633
#define HWY_IF_NOT_FLOAT_V(V) HWY_IF_NOT_FLOAT(hwy::HWY_NAMESPACE::TFromV<V>)
634+
#define HWY_IF_FLOAT3264_V(V) HWY_IF_FLOAT3264(hwy::HWY_NAMESPACE::TFromV<V>)
634635
#define HWY_IF_SPECIAL_FLOAT_V(V) \
635636
HWY_IF_SPECIAL_FLOAT(hwy::HWY_NAMESPACE::TFromV<V>)
637+
#define HWY_IF_FLOAT_OR_SPECIAL_V(V) \
638+
HWY_IF_FLOAT_OR_SPECIAL(hwy::HWY_NAMESPACE::TFromV<V>)
636639
#define HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V) \
637640
HWY_IF_NOT_FLOAT_NOR_SPECIAL(hwy::HWY_NAMESPACE::TFromV<V>)
638641

hwy/ops/x86_128-inl.h

+27
Original file line numberDiff line numberDiff line change
@@ -5170,6 +5170,33 @@ HWY_API V AbsDiff(V a, V b) {
51705170
return Abs(a - b);
51715171
}
51725172

5173+
// ------------------------------ GetExponent
5174+
5175+
#if HWY_TARGET <= HWY_AVX3
5176+
5177+
#ifdef HWY_NATIVE_GET_EXPONENT
5178+
#undef HWY_NATIVE_GET_EXPONENT
5179+
#else
5180+
#define HWY_NATIVE_GET_EXPONENT
5181+
#endif
5182+
5183+
#if HWY_HAVE_FLOAT16
5184+
template <class V, HWY_IF_F16(TFromV<V>), HWY_IF_V_SIZE_LE_V(V, 16)>
5185+
HWY_API V GetExponent(V v) {
5186+
return V{_mm_getexp_ph(v.raw)};
5187+
}
5188+
#endif
5189+
template <class V, HWY_IF_F32(TFromV<V>), HWY_IF_V_SIZE_LE_V(V, 16)>
5190+
HWY_API V GetExponent(V v) {
5191+
return V{_mm_getexp_ps(v.raw)};
5192+
}
5193+
template <class V, HWY_IF_F64(TFromV<V>), HWY_IF_V_SIZE_LE_V(V, 16)>
5194+
HWY_API V GetExponent(V v) {
5195+
return V{_mm_getexp_pd(v.raw)};
5196+
}
5197+
5198+
#endif
5199+
51735200
// ------------------------------ MaskedMinOr
51745201

51755202
#if HWY_TARGET <= HWY_AVX3

hwy/ops/x86_256-inl.h

+21
Original file line numberDiff line numberDiff line change
@@ -2734,6 +2734,27 @@ HWY_API Vec256<double> ApproximateReciprocal(Vec256<double> v) {
27342734
}
27352735
#endif
27362736

2737+
// ------------------------------ GetExponent
2738+
2739+
#if HWY_TARGET <= HWY_AVX3
2740+
2741+
#if HWY_HAVE_FLOAT16
2742+
template <class V, HWY_IF_F16(TFromV<V>), HWY_IF_V_SIZE_V(V, 32)>
2743+
HWY_API V GetExponent(V v) {
2744+
return V{_mm256_getexp_ph(v.raw)};
2745+
}
2746+
#endif
2747+
template <class V, HWY_IF_F32(TFromV<V>), HWY_IF_V_SIZE_V(V, 32)>
2748+
HWY_API V GetExponent(V v) {
2749+
return V{_mm256_getexp_ps(v.raw)};
2750+
}
2751+
template <class V, HWY_IF_F64(TFromV<V>), HWY_IF_V_SIZE_V(V, 32)>
2752+
HWY_API V GetExponent(V v) {
2753+
return V{_mm256_getexp_pd(v.raw)};
2754+
}
2755+
2756+
#endif
2757+
27372758
// ------------------------------ MaskedMinOr
27382759

27392760
#if HWY_TARGET <= HWY_AVX3

hwy/ops/x86_512-inl.h

+16-1
Original file line numberDiff line numberDiff line change
@@ -1842,7 +1842,22 @@ HWY_API Vec512<double> ApproximateReciprocal(Vec512<double> v) {
18421842
return Vec512<double>{_mm512_rcp14_pd(v.raw)};
18431843
}
18441844

1845-
// TODO: Implement GetExponent using _mm_getexp_ps/_mm_getexp_pd/_mm_getexp_ph
1845+
// ------------------------------ GetExponent
1846+
1847+
#if HWY_HAVE_FLOAT16
1848+
template <class V, HWY_IF_F16(TFromV<V>), HWY_IF_V_SIZE_V(V, 64)>
1849+
HWY_API V GetExponent(V v) {
1850+
return V{_mm512_getexp_ph(v.raw)};
1851+
}
1852+
#endif
1853+
template <class V, HWY_IF_F32(TFromV<V>), HWY_IF_V_SIZE_V(V, 64)>
1854+
HWY_API V GetExponent(V v) {
1855+
return V{_mm512_getexp_ps(v.raw)};
1856+
}
1857+
template <class V, HWY_IF_F64(TFromV<V>), HWY_IF_V_SIZE_V(V, 64)>
1858+
HWY_API V GetExponent(V v) {
1859+
return V{_mm512_getexp_pd(v.raw)};
1860+
}
18461861

18471862
// ------------------------------ MaskedMinOr
18481863

hwy/tests/float_test.cc

+15-3
Original file line numberDiff line numberDiff line change
@@ -605,18 +605,30 @@ HWY_NOINLINE void TestAllAbsDiff() {
605605
struct TestGetExponent {
606606
template <typename T, class D>
607607
HWY_NOINLINE void operator()(T /*unused*/, D d) {
608+
const RebindToUnsigned<decltype(d)> du;
609+
610+
using TFArith = If<IsSpecialFloat<T>(), float, T>;
611+
using TU = MakeUnsigned<T>;
608612
const size_t N = Lanes(d);
609613

610614
auto v = Iota(d, 1);
611615

612616
auto expected = AllocateAligned<T>(N);
613-
HWY_ASSERT(expected);
617+
auto expected_biased = AllocateAligned<TU>(N);
618+
HWY_ASSERT(expected && expected_biased);
619+
620+
constexpr int kNumOfMantBits = MantissaBits<T>();
614621

615622
for (size_t i = 0; i < N; ++i) {
616-
auto test_val = (float)(i + 1);
617-
expected[i] = ConvertScalarTo<T>(std::floor(std::log2(test_val)));
623+
const T test_val = ConvertScalarTo<T>(i + 1);
624+
expected[i] = ConvertScalarTo<T>(
625+
std::floor(std::log2(ConvertScalarTo<TFArith>(test_val))));
626+
expected_biased[i] =
627+
static_cast<TU>((BitCastScalar<TU>(test_val) >> kNumOfMantBits) &
628+
static_cast<TU>(MaxExponentField<T>()));
618629
}
619630
HWY_ASSERT_VEC_EQ(d, expected.get(), GetExponent(v));
631+
HWY_ASSERT_VEC_EQ(du, expected_biased.get(), GetBiasedExponent(v));
620632
}
621633
};
622634

0 commit comments

Comments
 (0)