From fa8c22f5b69b07e362a5d2588e1edffe5eb42829 Mon Sep 17 00:00:00 2001 From: Yermalayeu Ihar Date: Sun, 5 Jan 2025 08:55:44 +0300 Subject: [PATCH] +add AVX2 optimizations of class ResizerBf16Bilinear (part 3). --- src/Simd/SimdAvx2ResizerBilinear.cpp | 134 +++++++++++++++++---------- src/Simd/SimdBaseResizerBilinear.cpp | 2 +- 2 files changed, 87 insertions(+), 49 deletions(-) diff --git a/src/Simd/SimdAvx2ResizerBilinear.cpp b/src/Simd/SimdAvx2ResizerBilinear.cpp index 183bbf90af..6af67d37f2 100644 --- a/src/Simd/SimdAvx2ResizerBilinear.cpp +++ b/src/Simd/SimdAvx2ResizerBilinear.cpp @@ -954,13 +954,22 @@ namespace Simd { __m128 s0 = Sse41::BFloat16ToFloat32(Sse41::UnpackU16<0>(_mm_loadl_epi64((__m128i*)src))); __m128 s1 = Sse41::BFloat16ToFloat32(Sse41::UnpackU16<0>(_mm_loadl_epi64((__m128i*)(src + channels)))); - return _mm_add_ps(_mm_mul_ps(fx0, s0), _mm_mul_ps(fx1, s1)); + return _mm_fmadd_ps(fx0, s0, _mm_mul_ps(fx1, s1)); + } + + SIMD_INLINE __m256 BilinearRowSumBf16(const uint16_t* src, size_t channels, __m256 fx0, __m256 fx1) + { + __m256 s0 = BFloat16ToFloat32(_mm256_cvtepu16_epi32(_mm_loadu_si128((__m128i*)src))); + __m256 s1 = BFloat16ToFloat32(_mm256_cvtepu16_epi32(_mm_loadu_si128((__m128i*)(src + channels)))); + return _mm256_fmadd_ps(fx0, s0, _mm256_mul_ps(fx1, s1)); } void ResizerBf16Bilinear::Run(const uint16_t* src, size_t srcStride, uint16_t* dst, size_t dstStride) { - size_t cn = _param.channels, cnF = AlignLo(cn, Sse41::F), cnT = cn - cnF, cnL = cnT - Sse41::F; - __m128 _1 = _mm_set1_ps(1.0f); + size_t cn = _param.channels, + cnH = AlignLo(cn, Sse41::F), cnTH = cn - cnH, cnLH = cnTH - Sse41::F, + cnF = AlignLo(cn, F), cnTF = cn - cnF, cnLF = cnTF - F; + __m256 _1 = _mm256_set1_ps(1.0f); if (_rowBuf) { size_t rs = _param.dstW * cn, rsH = AlignLo(rs, Sse41::F), rsF = AlignLo(rs, F); @@ -988,6 +997,19 @@ namespace Simd float* pb = pbx[k]; const uint16_t* ps = src + (sy + k) * srcStride; size_t dx = 0; + if (cn >= 4) + { + for (; dx < rs;) + { + const uint16_t* ps0 = ps + _ix[dx]; + __m128 fx1 = _mm_set1_ps(_ax[dx]); + __m128 fx0 = _mm_sub_ps(_mm256_castps256_ps128(_1), fx1); + for (size_t end = dx + cnH; dx < end; dx += Sse41::F, ps0 += Sse41::F) + _mm_storeu_ps(pb + dx, BilinearRowSumBf16(ps0, cn, fx0, fx1)); + if (cnTH) + _mm_storeu_ps(pb + dx + cnLH, BilinearRowSumBf16(ps0 + cnLH, cn, fx0, fx1)), dx += cnTH; + } + } if (cn == 1) { for (; dx < rsH; dx += Sse41::F) @@ -1001,10 +1023,8 @@ namespace Simd __m128 s0 = Sse41::BFloat16ToFloat32Even(_src); __m128 s1 = Sse41::BFloat16ToFloat32Odd(_src); __m128 fx1 = _mm_loadu_ps(_ax.data + dx); - __m128 fx0 = _mm_sub_ps(_1, fx1); - __m128 m0 = _mm_mul_ps(fx0, s0); - __m128 m1 = _mm_mul_ps(fx1, s1); - _mm_storeu_ps(pb + dx, _mm_add_ps(m0, m1)); + __m128 fx0 = _mm_sub_ps(_mm256_castps256_ps128(_1), fx1); + _mm_storeu_ps(pb + dx, _mm_fmadd_ps(fx0, s0, _mm_mul_ps(fx1, s1))); } } if (cn == 2) @@ -1015,10 +1035,8 @@ namespace Simd __m128 s0 = _mm_castsi128_ps(_mm_shuffle_epi8(_src, K8_IDX_20)); __m128 s1 = _mm_castsi128_ps(_mm_shuffle_epi8(_src, K8_IDX_21)); __m128 fx1 = _mm_loadu_ps(_ax.data + dx); - __m128 fx0 = _mm_sub_ps(_1, fx1); - __m128 m0 = _mm_mul_ps(fx0, s0); - __m128 m1 = _mm_mul_ps(fx1, s1); - _mm_storeu_ps(pb + dx, _mm_add_ps(m0, m1)); + __m128 fx0 = _mm_sub_ps(_mm256_castps256_ps128(_1), fx1); + _mm_storeu_ps(pb + dx, _mm_fmadd_ps(fx0, s0, _mm_mul_ps(fx1, s1))); } } if (cn == 3 && rs > 3) @@ -1027,23 +1045,10 @@ namespace Simd for (; dx < rs3; dx += 3) { __m128 fx1 = _mm_set1_ps(_ax.data[dx]); - __m128 fx0 = _mm_sub_ps(_1, fx1); + __m128 fx0 = _mm_sub_ps(_mm256_castps256_ps128(_1), fx1); _mm_storeu_ps(pb + dx, BilinearRowSumBf16(ps + _ix[dx], cn, fx0, fx1)); } } - if (cn >= 4) - { - for (; dx < rs;) - { - const uint16_t* ps0 = ps + _ix[dx]; - __m128 fx1 = _mm_set1_ps(_ax[dx]); - __m128 fx0 = _mm_sub_ps(_1, fx1); - for (size_t end = dx + cnF; dx < end; dx += Sse41::F, ps0 += Sse41::F) - _mm_storeu_ps(pb + dx, BilinearRowSumBf16(ps0, cn, fx0, fx1)); - if (cnT) - _mm_storeu_ps(pb + dx + cnL, BilinearRowSumBf16(ps0 + cnL, cn, fx0, fx1)), dx += cnT; - } - } for (; dx < rs; dx++) { int32_t sx = _ix[dx]; @@ -1062,9 +1067,8 @@ namespace Simd } for (; dx < rsH; dx += Sse41::F) { - __m128 m0 = _mm_mul_ps(_mm_loadu_ps(pbx[0] + dx), _mm256_castps256_ps128(_fy0)); - __m128 m1 = _mm_mul_ps(_mm_loadu_ps(pbx[1] + dx), _mm256_castps256_ps128(_fy1)); - __m128i d0 = Sse41::Float32ToBFloat16(_mm_add_ps(m0, m1)); + __m128i d0 = Sse41::Float32ToBFloat16(_mm_fmadd_ps(_mm_loadu_ps(pbx[0] + dx), _mm256_castps256_ps128(_fy0), + _mm_mul_ps(_mm_loadu_ps(pbx[1] + dx), _mm256_castps256_ps128(_fy1)))); _mm_storel_epi64((__m128i*)(dst + dx), _mm_packus_epi32(d0, Sse41::K_ZERO)); } for (; dx < rs; dx++) @@ -1073,31 +1077,65 @@ namespace Simd } else { - for (size_t dy = 0; dy < _param.dstH; dy++, dst += dstStride) + if (cnF) { - __m128 fy1 = _mm_set1_ps(_ay[dy]); - __m128 fy0 = _mm_sub_ps(_1, fy1); - const uint16_t* src0 = src + _iy[dy] * srcStride, * src1 = src0 + srcStride; - for (size_t dx = 0; dx < _param.dstW; dx++) + for (size_t dy = 0; dy < _param.dstH; dy++, dst += dstStride) { - size_t os = _ix[dx], end = os + cnF, od = dx * cn; - __m128 fx1 = _mm_set1_ps(_ax[dx]); - __m128 fx0 = _mm_sub_ps(_1, fx1); - for (; os < end; os += Sse41::F, od += Sse41::F) + __m256 fy1 = _mm256_set1_ps(_ay[dy]); + __m256 fy0 = _mm256_sub_ps(_1, fy1); + const uint16_t* src0 = src + _iy[dy] * srcStride, * src1 = src0 + srcStride; + for (size_t dx = 0; dx < _param.dstW; dx++) { - __m128 r0 = BilinearRowSumBf16(src0 + os, cn, fx0, fx1); - __m128 r1 = BilinearRowSumBf16(src1 + os, cn, fx0, fx1); - __m128i d0 = Sse41::Float32ToBFloat16(_mm_add_ps(_mm_mul_ps(r0, fy0), _mm_mul_ps(r1, fy1))); - _mm_storel_epi64((__m128i*)(dst + od), _mm_packus_epi32(d0, Sse41::K_ZERO)); + size_t os = _ix[dx], end = os + cnF, od = dx * cn; + __m256 fx1 = _mm256_set1_ps(_ax[dx]); + __m256 fx0 = _mm256_sub_ps(_1, fx1); + for (; os < end; os += F, od += F) + { + __m256 r0 = BilinearRowSumBf16(src0 + os, cn, fx0, fx1); + __m256 r1 = BilinearRowSumBf16(src1 + os, cn, fx0, fx1); + __m256i d0 = Float32ToBFloat16(_mm256_fmadd_ps(r0, fy0, _mm256_mul_ps(r1, fy1))); + _mm_storeu_si128((__m128i*)(dst + od), _mm256_castsi256_si128(_mm256_permute4x64_epi64(_mm256_packus_epi32(d0, K_ZERO), 0xD8))); + } + if (cnTH) + { + os += cnLH; + od += cnLH; + __m256 r0 = BilinearRowSumBf16(src0 + os, cn, fx0, fx1); + __m256 r1 = BilinearRowSumBf16(src1 + os, cn, fx0, fx1); + __m256i d0 = Float32ToBFloat16(_mm256_fmadd_ps(r0, fy0, _mm256_mul_ps(r1, fy1))); + _mm_storeu_si128((__m128i*)(dst + od), _mm256_castsi256_si128(_mm256_permute4x64_epi64(_mm256_packus_epi32(d0, K_ZERO), 0xD8))); + } } - if (cnT) + } + } + else + { + for (size_t dy = 0; dy < _param.dstH; dy++, dst += dstStride) + { + __m128 fy1 = _mm_set1_ps(_ay[dy]); + __m128 fy0 = _mm_sub_ps(_mm256_castps256_ps128(_1), fy1); + const uint16_t* src0 = src + _iy[dy] * srcStride, * src1 = src0 + srcStride; + for (size_t dx = 0; dx < _param.dstW; dx++) { - os += cnL; - od += cnL; - __m128 r0 = BilinearRowSumBf16(src0 + os, cn, fx0, fx1); - __m128 r1 = BilinearRowSumBf16(src1 + os, cn, fx0, fx1); - __m128i d0 = Sse41::Float32ToBFloat16(_mm_add_ps(_mm_mul_ps(r0, fy0), _mm_mul_ps(r1, fy1))); - _mm_storel_epi64((__m128i*)(dst + od), _mm_packus_epi32(d0, Sse41::K_ZERO)); + size_t os = _ix[dx], end = os + cnH, od = dx * cn; + __m128 fx1 = _mm_set1_ps(_ax[dx]); + __m128 fx0 = _mm_sub_ps(_mm256_castps256_ps128(_1), fx1); + for (; os < end; os += Sse41::F, od += Sse41::F) + { + __m128 r0 = BilinearRowSumBf16(src0 + os, cn, fx0, fx1); + __m128 r1 = BilinearRowSumBf16(src1 + os, cn, fx0, fx1); + __m128i d0 = Sse41::Float32ToBFloat16(_mm_fmadd_ps(r0, fy0, _mm_mul_ps(r1, fy1))); + _mm_storel_epi64((__m128i*)(dst + od), _mm_packus_epi32(d0, Sse41::K_ZERO)); + } + if (cnTH) + { + os += cnLH; + od += cnLH; + __m128 r0 = BilinearRowSumBf16(src0 + os, cn, fx0, fx1); + __m128 r1 = BilinearRowSumBf16(src1 + os, cn, fx0, fx1); + __m128i d0 = Sse41::Float32ToBFloat16(_mm_fmadd_ps(r0, fy0, _mm_mul_ps(r1, fy1))); + _mm_storel_epi64((__m128i*)(dst + od), _mm_packus_epi32(d0, Sse41::K_ZERO)); + } } } } diff --git a/src/Simd/SimdBaseResizerBilinear.cpp b/src/Simd/SimdBaseResizerBilinear.cpp index 42398402fb..1774a2119d 100644 --- a/src/Simd/SimdBaseResizerBilinear.cpp +++ b/src/Simd/SimdBaseResizerBilinear.cpp @@ -395,7 +395,7 @@ namespace Simd ResizerBf16Bilinear::ResizerBf16Bilinear(const ResParam& param) : Resizer(param) { - _rowBuf = !(_param.align >= 16 && (_param.channels >= _param.align / 4 || _param.channels == 64)) || _param.dstH >= _param.srcH; + _rowBuf = _param.align < 16 || _param.channels < 4 || _param.dstH >= _param.srcH; _ay.Resize(_param.dstH, false, _param.align); _iy.Resize(_param.dstH, false, _param.align); EstimateIndexAlpha(_param, _param.srcH, _param.dstH, 1, 1, _iy.data, _ay.data);