From 5328b2fd88c10dca31f8da49abc07178c76a9db5 Mon Sep 17 00:00:00 2001 From: Yermalayeu Ihar Date: Tue, 14 Jan 2025 11:01:43 +0300 Subject: [PATCH] +add AVX-512BW optimizations of class ResizerBf16Bilinear (part 7: channels=4). --- src/Simd/SimdAvx512bwResizerBilinear.cpp | 58 ++++++++++++------------ src/Test/TestResize.cpp | 2 + 2 files changed, 32 insertions(+), 28 deletions(-) diff --git a/src/Simd/SimdAvx512bwResizerBilinear.cpp b/src/Simd/SimdAvx512bwResizerBilinear.cpp index f308b057e3..ed135a9b71 100644 --- a/src/Simd/SimdAvx512bwResizerBilinear.cpp +++ b/src/Simd/SimdAvx512bwResizerBilinear.cpp @@ -924,7 +924,7 @@ namespace Simd __m512 _1 = _mm512_set1_ps(1.0f); if (_rowBuf) { - if (cn > 3) + if (cn > 4) { Avx2::ResizerBf16Bilinear::Run(src, srcStride, dst, dstStride); return; @@ -975,6 +975,7 @@ namespace Simd __m512 fx1 = _mm512_maskz_loadu_ps(rsMF, _ax.data + dx); __m512 fx0 = _mm512_sub_ps(_1, fx1); _mm512_mask_storeu_ps(pb + dx, rsMF, _mm512_fmadd_ps(fx0, s0, _mm512_mul_ps(fx1, s1))); + dx = rs; } } else if (cn == 2) @@ -996,20 +997,21 @@ namespace Simd __m512 fx1 = _mm512_loadu_ps(_ax.data + dx); __m512 fx0 = _mm512_sub_ps(_1, fx1); _mm512_storeu_ps(pb + dx, _mm512_fmadd_ps(fx0, s0, _mm512_mul_ps(fx1, s1))); + dx = rs; } } else if (cn == 3 && rs >= 3) { - //for (; dx < rs12; dx += 12) - //{ - // const float *pax = _ax.data + dx; - // __m512 fx1 = Load(pax + 0, pax + 3, pax + 6, pax + 9); - // __m512 fx0 = _mm512_sub_ps(_1, fx1); - // __m512i _src = Load((__m128i*)(ps + _ix[dx + 0]), (__m128i*)(ps + _ix[dx + 3]), (__m128i*)(ps + _ix[dx + 6]), (__m128i*)(ps + _ix[dx + 9])); - // __m512 s0 = _mm512_castsi512_ps(_mm512_shuffle_epi8(_src, K8_IDX_30)); - // __m512 s1 = _mm512_castsi512_ps(_mm512_shuffle_epi8(_src, K8_IDX_31)); - // _mm512_storeu_ps(pb + dx, _mm512_permutexvar_ps(RSB_3_P1, _mm512_fmadd_ps(fx0, s0, _mm512_mul_ps(fx1, s1)))); - //} + for (; dx < rs12; dx += 12) + { + const float *pax = _ax.data + dx; + __m512 fx1 = Load(pax + 0, pax + 3, pax + 6, pax + 9); + __m512 fx0 = _mm512_sub_ps(_1, fx1); + __m512i _src = Load((__m128i*)(ps + _ix[dx + 0]), (__m128i*)(ps + _ix[dx + 3]), (__m128i*)(ps + _ix[dx + 6]), (__m128i*)(ps + _ix[dx + 9])); + __m512 s0 = _mm512_castsi512_ps(_mm512_shuffle_epi8(_src, K8_IDX_30)); + __m512 s1 = _mm512_castsi512_ps(_mm512_shuffle_epi8(_src, K8_IDX_31)); + _mm512_storeu_ps(pb + dx, _mm512_permutexvar_ps(RSB_3_P1, _mm512_fmadd_ps(fx0, s0, _mm512_mul_ps(fx1, s1)))); + } for (; dx < rs6; dx += 6) { __m256 fx1 = Avx2::Load(_ax.data + dx, _ax.data + dx + 3); @@ -1026,23 +1028,23 @@ namespace Simd _mm_storeu_ps(pb + dx, BilinearRowSumBf16(ps + _ix[dx], cn, fx0, fx1)); } } - // else if (cn == 4) - // { - // for (; dx < rsF; dx += F) - // { - // __m256 fx1 = Load(_ax.data + dx, _ax.data + dx + 4); - // __m256 fx0 = _mm256_sub_ps(_1, fx1); - // __m256i _src = Load((__m128i*)(ps + _ix[dx + 0]), (__m128i*)(ps + _ix[dx + 4])); - // _mm256_storeu_ps(pb + dx, _mm256_fmadd_ps(fx0, BFloat16ToFloat32<0>(_src), _mm256_mul_ps(fx1, BFloat16ToFloat32<1>(_src)))); - // } - // for (; dx < rs; dx += 4) - // { - // __m128 fx1 = _mm_set1_ps(_ax[dx]); - // __m128 fx0 = _mm_sub_ps(_mm256_castps256_ps128(_1), fx1); - // __m128i _src = _mm_loadu_si128((__m128i*)(ps + _ix[dx])); - // _mm_storeu_ps(pb + dx, _mm_add_ps(_mm_mul_ps(fx0, Sse41::BFloat16ToFloat32<0>(_src)), _mm_mul_ps(fx1, Sse41::BFloat16ToFloat32<1>(_src)))); - // } - // } + else if (cn == 4) + { + for (; dx < rsF; dx += F) + { + __m512i _src = Load((__m128i*)(ps + _ix[dx + 0]), (__m128i*)(ps + _ix[dx + 4]), (__m128i*)(ps + _ix[dx + 8]), (__m128i*)(ps + _ix[dx + 12])); + __m512 fx1 = _mm512_loadu_ps(_ax.data + dx); + __m512 fx0 = _mm512_sub_ps(_1, fx1); + _mm512_storeu_ps(pb + dx, _mm512_fmadd_ps(fx0, _mm512_castsi512_ps(UnpackU16<0>(K_ZERO, _src)), _mm512_mul_ps(fx1, _mm512_castsi512_ps(UnpackU16<1>(K_ZERO, _src))))); + } + for (; dx < rs; dx += 4) + { + __m128 fx1 = _mm_set1_ps(_ax[dx]); + __m128 fx0 = _mm_sub_ps(_mm512_castps512_ps128(_1), fx1); + __m128i _src = _mm_loadu_si128((__m128i*)(ps + _ix[dx])); + _mm_storeu_ps(pb + dx, _mm_add_ps(_mm_mul_ps(fx0, Sse41::BFloat16ToFloat32<0>(_src)), _mm_mul_ps(fx1, Sse41::BFloat16ToFloat32<1>(_src)))); + } + } // if (cn >= 8) // { // for (; dx < rs;) diff --git a/src/Test/TestResize.cpp b/src/Test/TestResize.cpp index ac2ac3d0c6..671802f08f 100644 --- a/src/Test/TestResize.cpp +++ b/src/Test/TestResize.cpp @@ -246,6 +246,8 @@ namespace Test { bool result = true; + result = result && ResizerAutoTest(SimdResizeMethodBilinear, SimdResizeChannelBf16, 4, f1, f2); + #if 0 #if defined(SIMD_X64_ENABLE) result = result && ResizerAutoTest(SimdResizeMethodBilinear, SimdResizeChannelFloat, 64, f1, f2);