add AMX-BF16 kernel DepthwiseConvolution_k3p1d1s1w6 for class SynetMergedConvolution16b.
ermig1979 committed Dec 24, 2024
1 parent e7dd15c commit 2452284
Showing 4 changed files with 117 additions and 4 deletions.
1 change: 0 additions & 1 deletion docs/2024.html
@@ -71,7 +71,6 @@ <h5>New features</h5>
<li>AVX-512BW kernel Convolution32fNhwcDepthwise_k7p3d1s1w6 for framework SynetMergedConvolution32f.</li>
<li>AVX-512BW kernel Convolution32fNhwcDepthwise_k7p3d1s1w8 for framework SynetMergedConvolution32f.</li>
<li>AMX-BF16 kernel DepthwiseConvolution_k5p2d1s1w8 for class SynetMergedConvolution16b.</li>
- <li>AMX-BF16 kernel DepthwiseConvolution_k3p1d1s1w8 for class SynetMergedConvolution16b.</li>
<li>Base implementation of function Yuv444pToRgbaV2.</li>
</ul>
<h5>Improving</h5>
2 changes: 2 additions & 0 deletions docs/2025.html
@@ -44,6 +44,8 @@ <h5>New features</h5>
<li>Base implementation, SSE4.1, AVX2, AVX-512BW optimizations of function SynetTiledScale2D32f.</li>
<li>AMX-BF16 kernel DepthwiseConvolution_k5p2d1s1w6 for class SynetMergedConvolution16b.</li>
<li>AMX-BF16 kernel DepthwiseConvolution_k5p2d1s1w4 for class SynetMergedConvolution16b.</li>
+ <li>AMX-BF16 kernel DepthwiseConvolution_k3p1d1s1w8 for class SynetMergedConvolution16b.</li>
+ <li>AMX-BF16 kernel DepthwiseConvolution_k3p1d1s1w6 for class SynetMergedConvolution16b.</li>
</ul>
<h5>Improving</h5>
<ul>
114 changes: 113 additions & 1 deletion src/Simd/SimdAmxBf16SynetMergedConvolution16bDepthwise3x3.cpp
@@ -482,13 +482,120 @@ namespace Simd
}
}

//-------------------------------------------------------------------------------------------------

static SIMD_INLINE bool Preferable_k3p1d1s1w6(const ConvParam& p)
{
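// True when the 3x3, pad=1, stride=1, dilation=1 convolution is at least 6 columns
// wide and the width splits into 6-column blocks with a tail of at least 4 columns,
// so the overlapping last block recomputes at most 2 columns.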
return p.IsKernel(3) && p.IsPad(1) && p.IsStride(1) && p.IsDilation(1) &&
(p.srcW >= 6 && (p.srcW % 6 == 0 || p.srcW % 6 >= 4));
}

template<typename T, Term16bType term, SimdConvolutionActivationType type> static void DepthwiseConvolution_k3p1d1s1w6(const uint8_t* src8,
const ConvParam& p, const AlgParam& a, size_t maC, size_t yBeg, size_t yEnd, const float* weight, const float* bias, const float* params, uint8_t* dst)
{
assert(p.IsKernel(3) && p.IsPad(1) && p.IsStride(1) && p.IsDilation(1) && p.srcW >= 6);
const T* src = (T*)src8;
size_t srcH = p.srcH, srcW = p.srcW;
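// When a.bufH[1] != 0 the source is a circular row buffer of a.bufH[1] rows (sM masks
// the row index); sX/sY are the column/row strides and sD is the stride between
// consecutive groups of F channels.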
size_t sM = (a.bufH[1] - 1), sD = a.bufH[1] ? a.bufH[1] * p.srcW * F : F, sX = a.bufH[1] ? F : p.srcC, sY = sX * p.srcW, dstC = maC;
size_t dX = (a.bufH[2] ? a.maC * 2 : p.dstC * a.elem[1]), dY = p.dstW * dX, dy0 = a.bufH[2] ? yBeg : 0, dD = a.bufH[2] ? F * 2 : F * a.elem[1];
size_t wD = 9 * F, dstCF = AlignLo(dstC, F), dstW = p.dstW, endW = dstW - 6;
size_t dstCe = a.bufH[2] ? AlignHi(dstC, DF) : dstC;

__m512 s0, s1, w0, w1, w2, d0, d1, d2, d3, d4, d5;
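// d0..d5 accumulate one output row of 6 consecutive columns for the current group of F channels.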

__m512 _params[2], _bias[1];
_params[0] = _mm512_set1_ps(params[0]);
if (type == SimdConvolutionActivationRestrictRange ||
type == SimdConvolutionActivationHswish ||
type == SimdConvolutionActivationHardSigmoid)
_params[1] = _mm512_set1_ps(params[1]);
for (size_t dc = 0; dc < dstCe; dc += F)
{
_bias[0] = _mm512_loadu_ps(bias + dc);
if (type == ::SimdConvolutionActivationPrelu)
_params[0] = _mm512_loadu_ps(params + dc);
__mmask16 tailS = TailMask16(dstC - dc);
__mmask32 tailC = (dc == dstCF && a.bufH[2]) ? TailMask32(dstCe - dstCF) : tailS;
for (size_t dy = yBeg; dy < yEnd; ++dy)
{
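// Process the row in 6-column blocks; when dstW is not a multiple of 6 the last block
// starts at endW and overlaps the previous one.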
for (size_t dx = 0;; dx += Min<size_t>(6, endW - dx))
{
d0 = _mm512_setzero_ps();
d1 = _mm512_setzero_ps();
d2 = _mm512_setzero_ps();
d3 = _mm512_setzero_ps();
d4 = _mm512_setzero_ps();
d5 = _mm512_setzero_ps();
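// Zero the source mask for the left padding column in the first block and for the
// right padding column in the last block.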
__mmask16 tailS0 = dx == 0 ? 0 : tailS;
__mmask16 tailS1 = dx == endW ? 0 : tailS;
for (size_t ky = 0; ky < 3; ++ky)
{
size_t sy = dy + ky - 1;
const T* ps = src + (sy & sM) * sY + (dx - 1) * sX;
const float* pw = weight + ky * 3 * F;
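// sy underflows for the top padding row, so a single unsigned compare skips both
// the top and bottom padding rows.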
if (sy < srcH)
{
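// Slide the 3-tap filter over 8 source columns (dx-1 .. dx+6): each loaded column
// feeds up to three of the six accumulators.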
w0 = _mm512_maskz_loadu_ps(tailS, pw + 0 * F);
s0 = LoadSrc(ps + 0 * sX, tailS0);
d0 = _mm512_fmadd_ps(s0, w0, d0);

w1 = _mm512_maskz_loadu_ps(tailS, pw + 1 * F);
s1 = LoadSrc(ps + 1 * sX, tailS);
d0 = _mm512_fmadd_ps(s1, w1, d0);
d1 = _mm512_fmadd_ps(s1, w0, d1);

s0 = LoadSrc(ps + 2 * sX, tailS);
w2 = _mm512_maskz_loadu_ps(tailS, pw + 2 * F);
d0 = _mm512_fmadd_ps(s0, w2, d0);
d1 = _mm512_fmadd_ps(s0, w1, d1);
d2 = _mm512_fmadd_ps(s0, w0, d2);

s1 = LoadSrc(ps + 3 * sX, tailS);
d1 = _mm512_fmadd_ps(s1, w2, d1);
d2 = _mm512_fmadd_ps(s1, w1, d2);
d3 = _mm512_fmadd_ps(s1, w0, d3);

s0 = LoadSrc(ps + 4 * sX, tailS);
d2 = _mm512_fmadd_ps(s0, w2, d2);
d3 = _mm512_fmadd_ps(s0, w1, d3);
d4 = _mm512_fmadd_ps(s0, w0, d4);

s1 = LoadSrc(ps + 5 * sX, tailS);
d3 = _mm512_fmadd_ps(s1, w2, d3);
d4 = _mm512_fmadd_ps(s1, w1, d4);
d5 = _mm512_fmadd_ps(s1, w0, d5);

s0 = LoadSrc(ps + 6 * sX, tailS);
d4 = _mm512_fmadd_ps(s0, w2, d4);
d5 = _mm512_fmadd_ps(s0, w1, d5);

s1 = LoadSrc(ps + 7 * sX, tailS1);
d5 = _mm512_fmadd_ps(s1, w2, d5);
}
}
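// Apply bias and activation, then store the 6 output columns.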
uint8_t* pd = dst + (dy - dy0) * dY + dx * dX;
Save1<term, type>(pd + 0 * dX, dD, d0, _bias, _params, tailC);
Save1<term, type>(pd + 1 * dX, dD, d1, _bias, _params, tailC);
Save1<term, type>(pd + 2 * dX, dD, d2, _bias, _params, tailC);
Save1<term, type>(pd + 3 * dX, dD, d3, _bias, _params, tailC);
Save1<term, type>(pd + 4 * dX, dD, d4, _bias, _params, tailC);
Save1<term, type>(pd + 5 * dX, dD, d5, _bias, _params, tailC);
if (dx == endW)
break;
}
}
src += sD;
dst += dD;
weight += wD;
}
}

//-------------------------------------------------------------------------------------------------

static SIMD_INLINE bool Preferable_k3p1d1s1w8(const ConvParam& p)
{
return p.IsKernel(3) && p.IsPad(1) && p.IsStride(1) && p.IsDilation(1) &&
- (p.srcW >= 8 && (p.srcW % 8 == 0 || p.srcW % 8 >= 6)/*&& AlignHiAny(p.srcW, 8) < AlignHiAny(p.srcW, 6) * 1.2*/);
+ (p.srcW >= 8 && (p.srcW % 8 == 0 || p.srcW % 8 >= 6));
}

template<typename T, Term16bType term, SimdConvolutionActivationType type> static void DepthwiseConvolution_k3p1d1s1w8(const uint8_t* src8,
@@ -614,6 +721,11 @@ namespace Simd
depthwise = DepthwiseConvolution_k3p1d1s1w8<T, term, type>;
return true;
}
else if (Preferable_k3p1d1s1w6(p))
{
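// The 6-wide kernel covers widths for which the 8-wide variant above is not preferable.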
depthwise = DepthwiseConvolution_k3p1d1s1w6<T, term, type>;
return true;
}
else if (IsKernel(p, 3) && IsDilation(p, 1) && IsStride(p, 1))
{
depthwise = DepthwiseConvolution3x3_V2<T, term, type>;
4 changes: 2 additions & 2 deletions src/Test/TestSynetMergedConvolution16b.cpp
@@ -282,8 +282,8 @@ namespace Test
//result = result && SynetMergedConvolution16bForwardAutoTest(eps, Param(Shp(2, 512, 127, 127), Cnv(a0, 1, 1, 1024), Cnv(a1, 3, 1), Cnv(a2, 1, 1, 512), t, f32, f32), f1, f2);
#endif
#if 1
- result = result && SynetMergedConvolution16bForwardAutoTest(eps, Param(Shp(1, 512, 6, 6), Cnv(aSw, 1, 1, 512), Cnv(aId, 3, 1), b16, f32), f1, f2);
- result = result && SynetMergedConvolution16bForwardAutoTest(eps, Param(Shp(1, 512, 4, 8), Cnv(aSw, 1, 1, 512), Cnv(aSw, 3, 1), b16, f32), f1, f2);
+ result = result && SynetMergedConvolution16bForwardAutoTest(eps, Param(Shp(1, 512, 8, 6), Cnv(aSw, 1, 1, 512), Cnv(aId, 3, 1), b16, f32), f1, f2);
+ result = result && SynetMergedConvolution16bForwardAutoTest(eps, Param(Shp(1, 512, 6, 8), Cnv(aSw, 1, 1, 512), Cnv(aSw, 3, 1), b16, f32), f1, f2);
result = result && SynetMergedConvolution16bForwardAutoTest(eps, Param(Shp(1, 512, 12, 12), Cnv(aSw, 1, 1, 512), Cnv(aSw, 3, 1), b16, b16), f1, f2);
result = result && SynetMergedConvolution16bForwardAutoTest(eps, Param(Shp(1, 128, 24, 24), Cnv(aSw, 1, 1, 128), Cnv(aSw, 3, 1), b16, b16), f1, f2);
result = result && SynetMergedConvolution16bForwardAutoTest(eps, Param(Shp(1, 64, 48, 48), Cnv(aSw, 1, 1, 64), Cnv(aSw, 3, 1), b16, b16), f1, f2);
