Skip to content

Commit

Permalink
+add SSE4.1 optimizations of class ResizerBf16Bilinear.
Browse files Browse the repository at this point in the history
  • Loading branch information
ermig1979 committed Dec 26, 2024
1 parent 3bdb232 commit 579dd37
Show file tree
Hide file tree
Showing 5 changed files with 110 additions and 13 deletions.
2 changes: 1 addition & 1 deletion docs/2025.html
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ <h5>New features</h5>
<li>AMX-BF16 kernel DepthwiseConvolution_k5p2d1s1w4 for class SynetMergedConvolution16b.</li>
<li>AMX-BF16 kernel DepthwiseConvolution_k3p1d1s1w8 for class SynetMergedConvolution16b.</li>
<li>AMX-BF16 kernel DepthwiseConvolution_k3p1d1s1w6 for class SynetMergedConvolution16b.</li>
<li>Base implementation of class ResizerBf16Bilinear.</li>
<li>Base implementation, SSE4.1 optimizations of class ResizerBf16Bilinear.</li>
</ul>
<h5>Improving</h5>
<ul>
Expand Down
9 changes: 9 additions & 0 deletions src/Simd/SimdResizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,15 @@ namespace Simd
public:
ResizerFloatBilinear(const ResParam& param);
};

//-------------------------------------------------------------------------------------------------

class ResizerBf16Bilinear : public Base::ResizerBf16Bilinear
{
virtual void Run(const uint16_t* src, size_t srcStride, uint16_t* dst, size_t dstStride);
public:
ResizerBf16Bilinear(const ResParam& param);
};

//-------------------------------------------------------------------------------------------------

Expand Down
2 changes: 2 additions & 0 deletions src/Simd/SimdSse41Resizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ namespace Simd
return new ResizerShortBilinear(param);
else if (param.IsFloatBilinear())
return new ResizerFloatBilinear(param);
else if (param.IsBf16Bilinear())
return new ResizerBf16Bilinear(param);
else if (param.IsByteBicubic())
return new ResizerByteBicubic(param);
else if (param.IsByteArea2x2())
Expand Down
86 changes: 86 additions & 0 deletions src/Simd/SimdSse41ResizerBilinear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "Simd/SimdStore.h"
#include "Simd/SimdResizer.h"
#include "Simd/SimdResizerCommon.h"
#include "Simd/SimdBFloat16.h"

namespace Simd
{
Expand Down Expand Up @@ -670,6 +671,91 @@ namespace Simd
dst[dx] = pbx[0][dx] * fy0 + pbx[1][dx] * fy1;
}
}

//-------------------------------------------------------------------------------------------------

// Forwards the resize parameters to the Base implementation; this class adds no state of its own here.
ResizerBf16Bilinear::ResizerBf16Bilinear(const ResParam& param) : Base::ResizerBf16Bilinear(param) {}

// Bilinear resize of a BF16 image.
//
// src, dst             - source / destination pixels, one uint16_t per BF16 value.
// srcStride, dstStride - row pitch, applied directly to the uint16_t pointers
//                        (i.e. counted in elements, not bytes).
//
// For every destination row the two contributing source rows are interpolated
// horizontally into the FP32 row buffers _bx[0] / _bx[1]; the buffers are then
// blended vertically and converted back to BF16. Buffers computed for the
// previous destination row are reused when the source row index allows it.
void ResizerBf16Bilinear::Run(const uint16_t* src, size_t srcStride, uint16_t* dst, size_t dstStride)
{
    size_t cn = _param.channels;
    size_t rs = _param.dstW * cn;                   // number of values in one destination row
    float* pbx[2] = { _bx[0].data, _bx[1].data };
    int32_t prev = -2;                              // sentinel: matches neither reuse branch for any sy >= 0
    size_t rsh = AlignLo(rs, Sse41::F);             // part of the row handled by full SSE vectors
    for (size_t dy = 0; dy < _param.dstH; dy++, dst += dstStride)
    {
        float fy1 = _ay[dy];
        float fy0 = 1.0f - fy1;
        int32_t sy = _iy[dy];
        int32_t k = 0;

        if (sy == prev)
            k = 2;                                  // both buffered rows are still valid
        else if (sy == prev + 1)
        {
            Swap(pbx[0], pbx[1]);                   // old lower row becomes the new upper row
            k = 1;                                  // only the lower row has to be recomputed
        }

        prev = sy;

        // Horizontal pass: fill the row buffers that are not cached.
        // TODO: this pass is scalar. SSE paths for cn == 1 and cn == 3 (apparently carried
        // over from the FP32 resizer) were disabled here because they load float data and
        // still have to be ported to BF16 (uint16_t) input.
        for (; k < 2; k++)
        {
            float* pb = pbx[k];
            const uint16_t* ps = src + (sy + k) * srcStride;
            for (size_t dx = 0; dx < rs; dx++)
            {
                int32_t sx = _ix[dx];
                float fx = _ax[dx];
                // Blend the source value with its neighbour one pixel (cn values) to the right.
                pb[dx] = Base::BFloat16ToFloat32(ps[sx]) * (1.0f - fx) + Base::BFloat16ToFloat32(ps[sx + cn]) * fx;
            }
        }

        // Vertical pass: blend the two FP32 rows and store the result as BF16.
        size_t dx = 0;
        __m128 _fy0 = _mm_set1_ps(fy0);
        __m128 _fy1 = _mm_set1_ps(fy1);
        for (; dx < rsh; dx += Sse41::F)
        {
            // Aligned loads: presumes the _bx buffers are 16-byte aligned - TODO confirm.
            __m128 m0 = _mm_mul_ps(_mm_load_ps(pbx[0] + dx), _fy0);
            __m128 m1 = _mm_mul_ps(_mm_load_ps(pbx[1] + dx), _fy1);
            __m128i d0 = Float32ToBFloat16(_mm_add_ps(m0, m1));
            // Pack the four 32-bit results to 16 bit and store only the low 64 bits (4 values).
            _mm_storel_epi64((__m128i*)(dst + dx), _mm_packus_epi32(d0, K_ZERO));
        }
        for (; dx < rs; dx++)
            dst[dx] = Base::Float32ToBFloat16(pbx[0][dx] * fy0 + pbx[1][dx] * fy1);
    }
}
}
#endif
}
Expand Down
24 changes: 12 additions & 12 deletions src/Test/TestResize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,20 +72,15 @@ namespace Test
// Appends the test case signature "[channels:srcWxsrcH->dstWxdstH:method-type]"
// to the current description.
void Update(SimdResizeMethodType method, SimdResizeChannelType type, size_t channels, size_t srcW, size_t srcH, size_t dstW, size_t dstH)
{
    std::stringstream ss;
    ss << description << "[" << channels << ":" << srcW << "x" << srcH << "->" << dstW << "x" << dstH;
    ss << ":" << ToString(method) << "-" << ToString(type) << "]";
    description = ss.str();
}

void Call(const View & src, View & dst, size_t channels, SimdResizeChannelType type, SimdResizeMethodType method) const
{
void * resizer = NULL;
if(src.format == View::Float || src.format == View::Int16)
if (src.format == View::Float || src.format == View::Int16)
resizer = func(src.width / channels, src.height, dst.width / channels, dst.height, channels, type, method);
else
resizer = func(src.width, src.height, dst.width, dst.height, channels, type, method);
Expand Down Expand Up @@ -113,7 +108,7 @@ namespace Test
f1.Update(method, type, channels, srcW, srcH, dstW, dstH);
f2.Update(method, type, channels, srcW, srcH, dstW, dstH);

TEST_LOG_SS(Info, "Test " << f1.description << " & " << f2.description << " [" << srcW << ", " << srcH << "] -> [" << dstW << ", " << dstH << "].");
TEST_LOG_SS(Info, "Test " << f1.description << " & " << f2.description << ".");

View::Format format;
if (type == SimdResizeChannelFloat)
Expand Down Expand Up @@ -143,15 +138,15 @@ namespace Test
else
assert(0);

View src(srcW, srcH, format, NULL, TEST_ALIGN(srcW));
View src(srcW, srcH, format);
if (type == SimdResizeChannelFloat)
FillRandom32f(src);
else if (type == SimdResizeChannelShort)
FillRandom16u(src);
else if (type == SimdResizeChannelBf16)
{
View src32f(srcW, srcH, View::Float);
FillRandom32f(src32f);
FillRandom32f(src32f, 0.0f, 10.0f);
for (size_t row = 0; row < srcH; row++)
SimdFloat32ToBFloat16(src32f.Row<float>(row), srcW, src.Row<uint16_t>(row));
}
Expand All @@ -165,12 +160,12 @@ namespace Test
#endif
}

View dst1(dstW, dstH, format, NULL, TEST_ALIGN(dstW));
View dst2(dstW, dstH, format, NULL, TEST_ALIGN(dstW));
View dst1(dstW, dstH, format);
View dst2(dstW, dstH, format);
if (format == View::Int16)
{
Simd::FillPixel(dst1, uint16_t(0x0001));
Simd::FillPixel(dst1, uint16_t(0x0002));
Simd::FillPixel(dst2, uint16_t(0x0002));
}
else
{
Expand Down Expand Up @@ -251,7 +246,12 @@ namespace Test
{
bool result = true;

result = result && ResizerAutoTest(SimdResizeMethodBilinear, SimdResizeChannelFloat, 16, f1, f2);
result = result && ResizerAutoTest(SimdResizeMethodBilinear, SimdResizeChannelBf16, 16, f1, f2);
result = result && ResizerAutoTest(SimdResizeMethodNearest, SimdResizeChannelFloat, 16, f1, f2);
result = result && ResizerAutoTest(SimdResizeMethodNearest, SimdResizeChannelBf16, 16, f1, f2);

return result;

result = result && ResizerAutoTest(SimdResizeMethodNearest, SimdResizeChannelBf16, 1, f1, f2);
result = result && ResizerAutoTest(SimdResizeMethodNearest, SimdResizeChannelBf16, 3, f1, f2);
Expand Down

0 comments on commit 579dd37

Please sign in to comment.