Skip to content

Commit

Permalink
add AMX-BF16 kernel DepthwiseConvolution_k3p1d1s1w8 for class SynetMergedConvolution16b.
Browse files Browse the repository at this point in the history
  • Loading branch information
ermig1979 committed Dec 24, 2024
1 parent eabc7a1 commit e7dd15c
Show file tree
Hide file tree
Showing 3 changed files with 182 additions and 33 deletions.
1 change: 1 addition & 0 deletions docs/2024.html
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ <h5>New features</h5>
<li>AVX-512BW kernel Convolution32fNhwcDepthwise_k7p3d1s1w6 for framework SynetMergedConvolution32f.</li>
<li>AVX-512BW kernel Convolution32fNhwcDepthwise_k7p3d1s1w8 for framework SynetMergedConvolution32f.</li>
<li>AMX-BF16 kernel DepthwiseConvolution_k5p2d1s1w8 for class SynetMergedConvolution16b.</li>
<li>AMX-BF16 kernel DepthwiseConvolution_k3p1d1s1w8 for class SynetMergedConvolution16b.</li>
<li>Base implementation of function Yuv444pToRgbaV2.</li>
</ul>
<h5>Improving</h5>
Expand Down
205 changes: 173 additions & 32 deletions src/Simd/SimdAmxBf16SynetMergedConvolution16bDepthwise3x3.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -308,23 +308,19 @@ namespace Simd
if (H > 3) buf[I][5] = _mm512_setzero_ps();
}

// Loads one source column (up to H + 2 rows, covering the vertical padding
// rows) into lane I of the rotating buffer buf. s0..s5 are precomputed row
// pointers and offs is the shared column offset in elements. mask0 is zeroed
// at the top image border and mask2 at the bottom one, so padding rows read
// as 0; mask1 is the plain channel tail mask for interior rows.
// NOTE(review): this region of the scraped diff interleaved the removed and
// added bodies; reconstructed here from the added (new) lines only.
template<typename T, int H, int I> static void LoadSrc(const T* s0, const T* s1, const T* s2, const T* s3, const T* s4, const T* s5,
    size_t offs, __mmask16 mask0, __mmask16 mask1, __mmask16 mask2, __m512 buf[3][H + 2])
{
    buf[I][0] = LoadSrc(s0 + offs, mask0);
    buf[I][1] = LoadSrc(s1 + offs, mask1);
    if (H > 0) buf[I][2] = LoadSrc(s2 + offs, H == 1 ? mask2 : mask1);
    if (H > 1) buf[I][3] = LoadSrc(s3 + offs, H == 2 ? mask2 : mask1);
    if (H > 2) buf[I][4] = LoadSrc(s4 + offs, H == 3 ? mask2 : mask1);
    if (H > 3) buf[I][5] = LoadSrc(s5 + offs, mask2);
}

template<int H, int I0, int I1, int I2> static void Convolution3x3(__m512 src[3][H + 2], const __m512* weight, __m512 dst[H])
{
if (H > 0) dst[0] = _mm512_setzero_ps();
if (H > 1) dst[1] = _mm512_setzero_ps();
if (H > 2) dst[2] = _mm512_setzero_ps();
if (H > 3) dst[3] = _mm512_setzero_ps();

if (H > 0) dst[0] = _mm512_fmadd_ps(src[I0][0], weight[0], dst[0]);
if (H > 1) dst[1] = _mm512_fmadd_ps(src[I0][1], weight[0], dst[1]);
if (H > 2) dst[2] = _mm512_fmadd_ps(src[I0][2], weight[0], dst[2]);
Expand Down Expand Up @@ -376,52 +372,66 @@ namespace Simd
// Computes H output rows of a 3x3 depthwise convolution (pad 1, stride 1)
// in a single left-to-right column sweep. A three-lane rotating buffer
// s[3][H + 2] caches the last three source columns; each iteration loads only
// the one new column and the dx % 3 switch rotates which lane plays the
// left/middle/right role, so no data is ever shuffled between lanes.
// NOTE(review): this region of the scraped diff interleaved the removed and
// added lines; reconstructed here from the added (new) lines. The d[...]
// zeroing before the switch is kept — it is required if Convolution3x3
// accumulates into dst, and harmless if Convolution3x3 zeroes dst itself.
template<typename T, Term16bType term, SimdConvolutionActivationType type, int H> static void DepthwiseConvolution3x3xH(const T* src, size_t dstH, size_t dstW,
    size_t yB, size_t sM, size_t sY, size_t sX, __mmask16 tailS, const __m512* weight, const __m512* bias, const __m512* params, uint8_t* dst, size_t dY, size_t dX, __mmask32 tailD)
{
    // Precompute the H + 2 row pointers (rows yB - 1 .. yB + 4); sM masks the
    // row index into the ring buffer of source rows.
    const T* src0 = src + ((yB - 1) & sM) * sY;
    const T* src1 = src + ((yB + 0) & sM) * sY;
    const T* src2 = src + ((yB + 1) & sM) * sY;
    const T* src3 = src + ((yB + 2) & sM) * sY;
    const T* src4 = src + ((yB + 3) & sM) * sY;
    const T* src5 = src + ((yB + 4) & sM) * sY;
    size_t endW = dstW - 1;
    __m512 s[3][H + 2], d[H];
    // Vertical padding masks: zero the load above row 0 / below the last row.
    __mmask16 mask0 = yB == 0 ? 0 : tailS;
    __mmask16 mask2 = yB + H == dstH ? 0 : tailS;

    ZeroSrc<H, 0>(s);                                                                  // lane 0: left padding column (x = -1)
    LoadSrc<T, H, 1>(src0, src1, src2, src3, src4, src5, 0, mask0, tailS, mask2, s);   // lane 1: column x = 0
    for (size_t dx = 0, offs = sX; dx < dstW; dx += 1, offs += sX)
    {
        if (H > 0) d[0] = _mm512_setzero_ps();
        if (H > 1) d[1] = _mm512_setzero_ps();
        if (H > 2) d[2] = _mm512_setzero_ps();
        if (H > 3) d[3] = _mm512_setzero_ps();
        switch (dx % 3)
        {
        case 0:
        {
            if (dx == endW)
                ZeroSrc<H, 2>(s);    // right padding column
            else
                LoadSrc<T, H, 2>(src0, src1, src2, src3, src4, src5, offs, mask0, tailS, mask2, s);
            Convolution3x3<H, 0, 1, 2>(s, weight, d);
            break;
        }
        case 1:
        {
            if (dx == endW)
                ZeroSrc<H, 0>(s);
            else
                LoadSrc<T, H, 0>(src0, src1, src2, src3, src4, src5, offs, mask0, tailS, mask2, s);
            Convolution3x3<H, 1, 2, 0>(s, weight, d);
            break;
        }
        case 2:
        {
            if (dx == endW)
                ZeroSrc<H, 1>(s);
            else
                LoadSrc<T, H, 1>(src0, src1, src2, src3, src4, src5, offs, mask0, tailS, mask2, s);
            Convolution3x3<H, 2, 0, 1>(s, weight, d);
        }
        break;
        }
        if (H > 0) Save1<term, type>(dst + 0 * dY, 0, d[0], bias, params, tailD);
        if (H > 1) Save1<term, type>(dst + 1 * dY, 0, d[1], bias, params, tailD);
        if (H > 2) Save1<term, type>(dst + 2 * dY, 0, d[2], bias, params, tailD);
        if (H > 3) Save1<term, type>(dst + 3 * dY, 0, d[3], bias, params, tailD);
        dst += dX;
    }
}

template<typename T> using DepthwiseConvolution3x3xH_Ptr = void (*)(const T * src, size_t dstH, size_t dstW, size_t yB, size_t sM, size_t sY, size_t sX, __mmask16 tailS,
const __m512 * weight, const __m512 * bias, const __m512 * params, uint8_t * dst, size_t dY, size_t dX, __mmask32 tailD);

template<typename T, Term16bType term, SimdConvolutionActivationType type> DepthwiseConvolution3x3xH_Ptr<T> GetDepthwiseConvolution3x3xH(int H)
template<typename T, Term16bType term, SimdConvolutionActivationType type> DepthwiseConvolution3x3xH_Ptr<T> GetDepthwiseConvolution3x3xH(size_t H)
{
switch (H)
{
Expand All @@ -439,14 +449,12 @@ namespace Simd
{
assert(p.IsKernel(3) && p.IsPad(1) && p.IsStride(1) && p.IsDilation(1));
const T* src = (T*)src8;
size_t N = 4, M = (yEnd - yBeg) % N, yBody = AlignLoAny(yEnd - yBeg, N) + yBeg;
size_t N = 3, M = (yEnd - yBeg) % N, yBody = AlignLoAny(yEnd - yBeg, N) + yBeg;
DepthwiseConvolution3x3xH_Ptr<T> body = GetDepthwiseConvolution3x3xH<T, term, type>(N);
DepthwiseConvolution3x3xH_Ptr<T> tail = GetDepthwiseConvolution3x3xH<T, term, type>(M);
size_t srcH = p.srcH, srcW = p.srcW;
size_t sM = (a.bufH[1] - 1), sD = a.bufH[1] ? a.bufH[1] * p.srcW * F : F, sX = a.bufH[1] ? F : p.srcC, sY = sX * p.srcW, dstC = maC;
size_t dX = (a.bufH[2] ? a.maC * 2 : p.dstC * a.elem[1]), dY = p.dstW * dX, dy0 = a.bufH[2] ? yBeg : 0, dD = a.bufH[2] ? F * 2 : F * a.elem[1];
size_t wD = 9 * F, dstCF = AlignLo(dstC, F), dstW = p.dstW, endW = dstW - 8;
size_t dstCe = a.bufH[2] ? AlignHi(dstC, DF) : dstC;
size_t wD = 9 * F, dstCF = AlignLo(dstC, F), dstCe = (a.bufH[2] ? AlignHi(dstC, DF) : dstC);

__m512 _params[2], _bias[1], _weight[9];
_params[0] = _mm512_set1_ps(params[0]);
Expand Down Expand Up @@ -474,11 +482,144 @@ namespace Simd
}
}


//-------------------------------------------------------------------------------------------------

// Tells whether the specialized 8-pixel-wide 3x3 kernel should be preferred
// for the given convolution parameters: geometry must match the kernel's
// assumptions exactly, and the source width must tile into 8-pixel strips
// without too much overlap (a multiple of 8, or a remainder of at least 6).
static SIMD_INLINE bool Preferable_k3p1d1s1w8(const ConvParam& p)
{
    bool geometryFits = p.IsKernel(3) && p.IsPad(1) && p.IsStride(1) && p.IsDilation(1);
    bool widthFits = p.srcW >= 8 && (p.srcW % 8 == 0 || p.srcW % 8 >= 6);
    return geometryFits && widthFits;
}

// Specialized 3x3 depthwise convolution (pad 1, stride 1, dilation 1) that
// processes each output row in horizontal strips of 8 pixels, keeping eight
// accumulators d0..d7 live so every loaded source column vector feeds up to
// three neighboring outputs. Template params: T — source element type;
// term — output storage scheme; type — activation applied on store.
// src8: raw source; p/a: conv and algorithm params; maC: channels handled in
// this pass; [yBeg, yEnd): output rows to compute; weight/bias/params:
// per-channel coefficients; dst: output base.
template<typename T, Term16bType term, SimdConvolutionActivationType type> static void DepthwiseConvolution_k3p1d1s1w8(const uint8_t* src8,
    const ConvParam& p, const AlgParam& a, size_t maC, size_t yBeg, size_t yEnd, const float* weight, const float* bias, const float* params, uint8_t* dst)
{
    assert(p.IsKernel(3) && p.IsPad(1) && p.IsStride(1) && p.IsDilation(1) && p.srcW >= 8);
    const T* src = (T*)src8;
    size_t srcH = p.srcH, srcW = p.srcW;
    // Source addressing: when a.bufH[1] != 0 the input lives in a ring buffer
    // of bufH[1] rows and sM masks the row index (assumes bufH[1] is a power
    // of two — TODO confirm); otherwise rows are addressed directly. sX/sY are
    // per-pixel/per-row strides in elements, sD the per-channel-slice stride.
    size_t sM = (a.bufH[1] - 1), sD = a.bufH[1] ? a.bufH[1] * p.srcW * F : F, sX = a.bufH[1] ? F : p.srcC, sY = sX * p.srcW, dstC = maC;
    // Destination addressing mirrors the source: dX/dY are per-pixel/per-row
    // byte strides, dy0 rebases dy when writing to an intermediate buffer,
    // dD is the per-channel-slice stride in bytes.
    size_t dX = (a.bufH[2] ? a.maC * 2 : p.dstC * a.elem[1]), dY = p.dstW * dX, dy0 = a.bufH[2] ? yBeg : 0, dD = a.bufH[2] ? F * 2 : F * a.elem[1];
    // wD: 9 weight vectors (3x3) per channel slice. endW: start of the last
    // 8-wide strip; the final strip is right-aligned and may overlap its
    // predecessor, recomputing a few pixels.
    size_t wD = 9 * F, dstCF = AlignLo(dstC, F), dstW = p.dstW, endW = dstW - 8;
    size_t dstCe = a.bufH[2] ? AlignHi(dstC, DF) : dstC;

    __m512 s0, s1, w0, w1, w2, d0, d1, d2, d3, d4, d5, d6, d7;

    __m512 _params[2], _bias[1];
    _params[0] = _mm512_set1_ps(params[0]);
    // Only these activation types read a second scalar parameter.
    if (type == SimdConvolutionActivationRestrictRange ||
        type == SimdConvolutionActivationHswish ||
        type == SimdConvolutionActivationHardSigmoid)
        _params[1] = _mm512_set1_ps(params[1]);
    for (size_t dc = 0; dc < dstCe; dc += F)    // channel slices of F lanes
    {
        _bias[0] = _mm512_loadu_ps(bias + dc);
        if (type == ::SimdConvolutionActivationPrelu)
            _params[0] = _mm512_loadu_ps(params + dc);    // PReLU slope is per-channel
        __mmask16 tailS = TailMask16(dstC - dc);
        __mmask32 tailC = (dc == dstCF && a.bufH[2]) ? TailMask32(dstCe - dstCF) : tailS;
        for (size_t dy = yBeg; dy < yEnd; ++dy)
        {
            // Walk the row in strips of 8 outputs; the step shrinks on the
            // final iteration so the last strip starts exactly at endW.
            for (size_t dx = 0;; dx += Min<size_t>(8, endW - dx))
            {
                d0 = _mm512_setzero_ps();
                d1 = _mm512_setzero_ps();
                d2 = _mm512_setzero_ps();
                d3 = _mm512_setzero_ps();
                d4 = _mm512_setzero_ps();
                d5 = _mm512_setzero_ps();
                d6 = _mm512_setzero_ps();
                d7 = _mm512_setzero_ps();
                // Horizontal padding masks: tailS0 zeroes the column at
                // dx - 1 on the left border, tailS1 the column at dx + 8 on
                // the right border (pad == 1).
                __mmask16 tailS0 = dx == 0 ? 0 : tailS;
                __mmask16 tailS1 = dx == endW ? 0 : tailS;
                for (size_t ky = 0; ky < 3; ++ky)
                {
                    // For dy == 0, ky == 0 this underflows to a huge value,
                    // so the unsigned 'sy < srcH' test below also rejects the
                    // top padding row.
                    size_t sy = dy + ky - 1;
                    const T* ps = src + (sy & sM) * sY + (dx - 1) * sX;
                    const float* pw = weight + ky * 3 * F;
                    if (sy < srcH)
                    {
                        // Ten overlapping source columns feed eight outputs:
                        // each column vector contributes to up to three
                        // adjacent accumulators with weights w0..w2.
                        w0 = _mm512_maskz_loadu_ps(tailS, pw + 0 * F);
                        s0 = LoadSrc(ps + 0 * sX, tailS0);
                        d0 = _mm512_fmadd_ps(s0, w0, d0);

                        w1 = _mm512_maskz_loadu_ps(tailS, pw + 1 * F);
                        s1 = LoadSrc(ps + 1 * sX, tailS);
                        d0 = _mm512_fmadd_ps(s1, w1, d0);
                        d1 = _mm512_fmadd_ps(s1, w0, d1);

                        s0 = LoadSrc(ps + 2 * sX, tailS);
                        w2 = _mm512_maskz_loadu_ps(tailS, pw + 2 * F);
                        d0 = _mm512_fmadd_ps(s0, w2, d0);
                        d1 = _mm512_fmadd_ps(s0, w1, d1);
                        d2 = _mm512_fmadd_ps(s0, w0, d2);

                        s1 = LoadSrc(ps + 3 * sX, tailS);
                        d1 = _mm512_fmadd_ps(s1, w2, d1);
                        d2 = _mm512_fmadd_ps(s1, w1, d2);
                        d3 = _mm512_fmadd_ps(s1, w0, d3);

                        s0 = LoadSrc(ps + 4 * sX, tailS);
                        d2 = _mm512_fmadd_ps(s0, w2, d2);
                        d3 = _mm512_fmadd_ps(s0, w1, d3);
                        d4 = _mm512_fmadd_ps(s0, w0, d4);

                        s1 = LoadSrc(ps + 5 * sX, tailS);
                        d3 = _mm512_fmadd_ps(s1, w2, d3);
                        d4 = _mm512_fmadd_ps(s1, w1, d4);
                        d5 = _mm512_fmadd_ps(s1, w0, d5);

                        s0 = LoadSrc(ps + 6 * sX, tailS);
                        d4 = _mm512_fmadd_ps(s0, w2, d4);
                        d5 = _mm512_fmadd_ps(s0, w1, d5);
                        d6 = _mm512_fmadd_ps(s0, w0, d6);

                        s1 = LoadSrc(ps + 7 * sX, tailS);
                        d5 = _mm512_fmadd_ps(s1, w2, d5);
                        d6 = _mm512_fmadd_ps(s1, w1, d6);
                        d7 = _mm512_fmadd_ps(s1, w0, d7);

                        s0 = LoadSrc(ps + 8 * sX, tailS);
                        d6 = _mm512_fmadd_ps(s0, w2, d6);
                        d7 = _mm512_fmadd_ps(s0, w1, d7);

                        s1 = LoadSrc(ps + 9 * sX, tailS1);
                        d7 = _mm512_fmadd_ps(s1, w2, d7);
                    }
                }
                // Store the 8 outputs of this strip with bias + activation.
                uint8_t* pd = dst + (dy - dy0) * dY + dx * dX;
                Save1<term, type>(pd + 0 * dX, dD, d0, _bias, _params, tailC);
                Save1<term, type>(pd + 1 * dX, dD, d1, _bias, _params, tailC);
                Save1<term, type>(pd + 2 * dX, dD, d2, _bias, _params, tailC);
                Save1<term, type>(pd + 3 * dX, dD, d3, _bias, _params, tailC);
                Save1<term, type>(pd + 4 * dX, dD, d4, _bias, _params, tailC);
                Save1<term, type>(pd + 5 * dX, dD, d5, _bias, _params, tailC);
                Save1<term, type>(pd + 6 * dX, dD, d6, _bias, _params, tailC);
                Save1<term, type>(pd + 7 * dX, dD, d7, _bias, _params, tailC);
                if (dx == endW)
                    break;
            }
        }
        src += sD;          // advance to next channel slice of the source
        dst += dD;
        weight += wD;       // 9 weight vectors consumed per slice
    }
}

//-------------------------------------------------------------------------------------------------

template<typename T, Term16bType term, SimdConvolutionActivationType type> static bool SetDepthwise3x3(const ConvParam& p, DepthwisePtr& depthwise)
{
if (IsKernel(p, 3) && IsDilation(p, 1) && Aligned(p.dstC, F))
if (Preferable_k3p1d1s1w8(p))
{
depthwise = DepthwiseConvolution_k3p1d1s1w8<T, term, type>;
return true;
}
else if (IsKernel(p, 3) && IsDilation(p, 1) && IsStride(p, 1))
{
depthwise = DepthwiseConvolution3x3_V2<T, term, type>;
return true;
}
else if (IsKernel(p, 3) && IsDilation(p, 1) && Aligned(p.dstC, F))
{
depthwise = DepthwiseConvolution3x3<T, term, type>;
return true;
Expand Down
9 changes: 8 additions & 1 deletion src/Test/TestSynetMergedConvolution16b.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ namespace Test
result = result && SynetMergedConvolution16bForwardAutoTest(eps, p, f1, f2);
}
#endif
#if 1
#if 0
//result = result && SynetMergedConvolution16bForwardAutoTest(eps, Param(Shp(1, 68, 56, 56), Cnv(aSw, 1, 1, 84), Cnv(aSw, 5, 2), Cnv(aSw, 1, 1, 100), f, f32, f32), f1, f2);
//result = result && SynetMergedConvolution16bForwardAutoTest(eps, Param(Shp(1, 68, 56, 56), Cnv(aSw, 1, 1, 84), Cnv(aSw, 5, 1), Cnv(aSw, 1, 1, 100), f, f32, f32), f1, f2);
//result = result && SynetMergedConvolution16bForwardAutoTest(eps, Param(Shp(1, 134, 28, 28), Cnv(aSw, 1, 1, 168), Cnv(aSw, 5, 1), Cnv(aSw, 1, 1, 200), f, f32, f32), f1, f2);
Expand All @@ -281,6 +281,13 @@ namespace Test
//result = result && SynetMergedConvolution16bForwardAutoTest(eps, Param(Shp(2, 128, 15, 15), Cnv(a0, 1, 1, 256), Cnv(a1, 3, 1), Cnv(a2, 1, 1, 128), t, b16, b16), f1, f2);
//result = result && SynetMergedConvolution16bForwardAutoTest(eps, Param(Shp(2, 512, 127, 127), Cnv(a0, 1, 1, 1024), Cnv(a1, 3, 1), Cnv(a2, 1, 1, 512), t, f32, f32), f1, f2);
#endif
#if 1
result = result && SynetMergedConvolution16bForwardAutoTest(eps, Param(Shp(1, 512, 6, 6), Cnv(aSw, 1, 1, 512), Cnv(aId, 3, 1), b16, f32), f1, f2);
result = result && SynetMergedConvolution16bForwardAutoTest(eps, Param(Shp(1, 512, 4, 8), Cnv(aSw, 1, 1, 512), Cnv(aSw, 3, 1), b16, f32), f1, f2);
result = result && SynetMergedConvolution16bForwardAutoTest(eps, Param(Shp(1, 512, 12, 12), Cnv(aSw, 1, 1, 512), Cnv(aSw, 3, 1), b16, b16), f1, f2);
result = result && SynetMergedConvolution16bForwardAutoTest(eps, Param(Shp(1, 128, 24, 24), Cnv(aSw, 1, 1, 128), Cnv(aSw, 3, 1), b16, b16), f1, f2);
result = result && SynetMergedConvolution16bForwardAutoTest(eps, Param(Shp(1, 64, 48, 48), Cnv(aSw, 1, 1, 64), Cnv(aSw, 3, 1), b16, b16), f1, f2);
#endif
#else
{
Param p(Shp(1, 68, 56, 56), Cnv(aHs, 1, 1, 84), Cnv(aId, 5, 2), Cnv(aHs, 1, 1, 100), f, f32, f32);
Expand Down

0 comments on commit e7dd15c

Please sign in to comment.