From 459c109dc032a4f5fe33038851755583a8612fa9 Mon Sep 17 00:00:00 2001
From: Ben Ashbaugh
Date: Fri, 1 Mar 2024 17:14:54 -0800
Subject: [PATCH] increase prefetch distance

add helper functions for tiled kernels
---
 .../99_matrixexperiments/matrix_helpers.cl    |  21 +-
 .../matrix_kernel_tiled.cl                    | 490 ++++++++----------
 2 files changed, 244 insertions(+), 267 deletions(-)

diff --git a/samples/99_matrixexperiments/matrix_helpers.cl b/samples/99_matrixexperiments/matrix_helpers.cl
index 34857c0..5b323e1 100644
--- a/samples/99_matrixexperiments/matrix_helpers.cl
+++ b/samples/99_matrixexperiments/matrix_helpers.cl
@@ -289,7 +289,8 @@ void prefetch_a_rowmajor_d16_m8_k16v2_sg8(global ushort* A, int rowStart, int co
 {
 #if defined(PREFETCH_DEFAULT)
     uint offset = colStart + (rowStart + get_sub_group_local_id()) * stride;
-    prefetch(A + offset, 1);
+    __builtin_assume((ulong)(A + offset) % 4 == 0);
+    prefetch(A + offset, 8);
 #endif // defined(PREFETCH_DEFAULT)
 }
 
@@ -377,7 +378,8 @@ void prefetch_a_rowmajor_d16_m8v2_k16v2_sg16(global ushort* A, int rowStart, int
 {
 #if defined(PREFETCH_DEFAULT)
     uint offset = colStart + (rowStart + get_sub_group_local_id()) * stride;
-    prefetch(A + offset, 1);
+    __builtin_assume((ulong)(A + offset) % 4 == 0);
+    prefetch(A + offset, 8);
 #endif // defined(PREFETCH_DEFAULT)
 }
 
@@ -446,8 +448,10 @@ void prefetch_b_rowmajor_d16_k16_n8v4_sg8(global ushort* B, int rowStart, int co
 {
 #if defined(PREFETCH_DEFAULT)
     uint offset = colStart + (rowStart + get_sub_group_local_id()) * stride;
-    prefetch(B + offset, 1); offset += 8 * stride;
-    prefetch(B + offset, 1); offset += 8 * stride;
+    __builtin_assume((ulong)(B + offset) % 4 == 0);
+    prefetch(B + offset, 8); offset += 8 * stride;
+    __builtin_assume((ulong)(B + offset) % 4 == 0);
+    prefetch(B + offset, 8); offset += 8 * stride;
 #endif // defined(PREFETCH_DEFAULT)
 }
 
@@ -456,7 +460,8 @@ void prefetch_b_rowmajor_d16_k16_n16v2_sg16(global ushort* B, int rowStart, int
 {
 #if defined(PREFETCH_DEFAULT)
     uint offset = colStart + (rowStart + get_sub_group_local_id()) * stride;
-    prefetch(B + offset, 1);
+    __builtin_assume((ulong)(B + offset) % 4 == 0);
+    prefetch(B + offset, 8);
 #endif // defined(PREFETCH_DEFAULT)
 }
 
@@ -466,7 +471,8 @@ void prefetch_b_vnni_d16_k16_n8v2_sg8(global ushort* B, int rowStart, int colSta
 #if defined(PREFETCH_DEFAULT)
     global uint* B_ui = (global uint*)B;
     uint offset_ui = colStart + (rowStart / 2 + get_sub_group_local_id()) * stride;
-    prefetch(B_ui + offset_ui, 1);
+    __builtin_assume((ulong)(B_ui + offset_ui) % 4 == 0);
+    prefetch(B_ui + offset_ui, 4);
 #endif // defined(PREFETCH_DEFAULT)
 }
 
@@ -476,7 +482,8 @@ void prefetch_b_vnni_d16_k16v2_n16_sg16(global ushort* B, int rowStart, int colS
 #if defined(PREFETCH_DEFAULT)
     global uint* B_ui = (global uint*)B;
     uint offset_ui = colStart + (rowStart / 2 + get_sub_group_local_id()) * stride;
-    prefetch(B_ui + offset_ui, 1);
+    __builtin_assume((ulong)(B_ui + offset_ui) % 4 == 0);
+    prefetch(B_ui + offset_ui, 4);
 #endif // defined(PREFETCH_DEFAULT)
 }
 
diff --git a/samples/99_matrixexperiments/matrix_kernel_tiled.cl b/samples/99_matrixexperiments/matrix_kernel_tiled.cl
index 963a2cc..0655d10 100644
--- a/samples/99_matrixexperiments/matrix_kernel_tiled.cl
+++ b/samples/99_matrixexperiments/matrix_kernel_tiled.cl
@@ -28,6 +28,9 @@
 #define MM_KERNEL_NAMEX(PREFIX, tM, tN, MM, NN) PREFIX ## _m ## tM ## _n ## tN ## _ ## MM ## x ## NN
 #define MM_KERNEL_NAME(PREFIX, tM, tN, MM, NN) MM_KERNEL_NAMEX(PREFIX, tM, tN, MM, NN)
 
+#define HELPER_NAMEX(PREFIX, MM, NN) PREFIX ## _m ## MM ## _n ## NN
+#define HELPER_NAME(PREFIX, MM, NN) HELPER_NAMEX(PREFIX, MM, NN)
+
 #if !defined(SGS_PER_WG)
 // Launch four subgroups per work-group, to maximize cache reuse.
 #define SGS_PER_WG 4
@@ -39,6 +42,33 @@
 
 #if HAS_SIMD8
 
+void HELPER_NAME(atile_prefetch_rowmajor_sg8, MM, NN)(global ushort* A, int tM, int K, int m, int prefetch_k)
+{
+    for (int kk = 0; kk < KK; kk+=2) {
+        for (int mm = 0; mm < MM; mm++) {
+            prefetch_a_rowmajor_d16_m8_k16v2_sg8(A, m + mm * tM, prefetch_k + kk * tK, K);
+        }
+    }
+}
+
+void HELPER_NAME(btile_prefetch_rowmajor_sg8, MM, NN)(global ushort* B, int tN, int N, int prefetch_k, int n)
+{
+    for (int kk = 0; kk < KK; kk++) {
+        for (int nn = 0; nn < NN; nn+=4) {
+            prefetch_b_rowmajor_d16_k16_n8v4_sg8(B, prefetch_k + kk * tK, n + nn * tN, N);
+        }
+    }
+}
+
+void HELPER_NAME(btile_prefetch_vnni_sg8, MM, NN)(global ushort* B, int tN, int N, int prefetch_k, int n)
+{
+    for (int kk = 0; kk < KK; kk++) {
+        for (int nn = 0; nn < NN; nn+=2) {
+            prefetch_b_vnni_d16_k16_n8v2_sg8(B, prefetch_k + kk * tK, n + nn * tN, N);
+        }
+    }
+}
+
 __attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, SGS_PER_WG, 1)))
 kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 8, MM, NN)(global float* C, global_aligned_ushort_ptr A, global_aligned_ushort_ptr B, int K)
 {
@@ -49,16 +79,19 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 8, MM, NN)(global fl
     const int n = get_group_id(0) * tN * NN;
 
     // Initial prefetch:
-    const int init_k = 0;
-    for (int kk = 0; kk < KK; kk+=2) {
-        for (int mm = 0; mm < MM; mm++) {
-            prefetch_a_rowmajor_d16_m8_k16v2_sg8(A, m + mm * tM, init_k + kk * tK, K);
+    int prefetch_k = 0;
+    for (int p = 0; p < PREFETCH_DISTANCE; p++) {
+        for (int kk = 0; kk < KK; kk+=2) {
+            for (int mm = 0; mm < MM; mm++) {
+                prefetch_a_rowmajor_d16_m8_k16v2_sg8(A, m + mm * tM, prefetch_k + kk * tK, K);
+            }
         }
-    }
-    for (int kk = 0; kk < KK; kk++) {
-        for (int nn = 0; nn < NN; nn+=4) {
-            prefetch_b_rowmajor_d16_k16_n8v4_sg8(B, init_k + kk * tK, n + nn * tN, N);
+        for (int kk = 0; kk < KK; kk++) {
+            for (int nn = 0; nn < NN; nn+=4) {
+                prefetch_b_rowmajor_d16_k16_n8v4_sg8(B, prefetch_k + kk * tK, n + nn * tN, N);
+            }
         }
+        prefetch_k += tK * KK;
     }
 
     float8 sum[MM][NN];
@@ -72,17 +105,18 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 8, MM, NN)(global fl
 
     for (int k = 0; k < K; k += tK * KK) {
         // Next prefetch:
-        const int next_k = k + tK * KK;
+        // TODO: skip prefetch on the last iterations.
         for (int kk = 0; kk < KK; kk+=2) {
             for (int mm = 0; mm < MM; mm++) {
-                prefetch_a_rowmajor_d16_m8_k16v2_sg8(A, m + mm * tM, next_k + kk * tK, K);
+                prefetch_a_rowmajor_d16_m8_k16v2_sg8(A, m + mm * tM, prefetch_k + kk * tK, K);
             }
         }
         for (int kk = 0; kk < KK; kk++) {
             for (int nn = 0; nn < NN; nn+=4) {
-                prefetch_b_rowmajor_d16_k16_n8v4_sg8(B, next_k + kk * tK, n + nn * tN, N);
+                prefetch_b_rowmajor_d16_k16_n8v4_sg8(B, prefetch_k + kk * tK, n + nn * tN, N);
             }
         }
+        prefetch_k += tK * KK;
 
         int8 aData[KK][MM];
         if (KK % 2 == 0) {
@@ -140,16 +174,19 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 8, MM, NN)(global float*
     const int n = get_group_id(0) * tN * NN;
 
     // Initial prefetch:
-    const int init_k = 0;
-    for (int kk = 0; kk < KK; kk+=2) {
-        for (int mm = 0; mm < MM; mm++) {
-            prefetch_a_rowmajor_d16_m8_k16v2_sg8(A, m + mm * tM, init_k + kk * tK, K);
+    int prefetch_k = 0;
+    for (int p = 0; p < PREFETCH_DISTANCE; p++) {
+        for (int kk = 0; kk < KK; kk+=2) {
+            for (int mm = 0; mm < MM; mm++) {
+                prefetch_a_rowmajor_d16_m8_k16v2_sg8(A, m + mm * tM, prefetch_k + kk * tK, K);
+            }
         }
-    }
-    for (int kk = 0; kk < KK; kk++) {
-        for (int nn = 0; nn < NN; nn+=2) {
-            prefetch_b_vnni_d16_k16_n8v2_sg8(B, init_k + kk * tK, n + nn * tN, N);
+        for (int kk = 0; kk < KK; kk++) {
+            for (int nn = 0; nn < NN; nn+=2) {
+                prefetch_b_vnni_d16_k16_n8v2_sg8(B, prefetch_k + kk * tK, n + nn * tN, N);
+            }
         }
+        prefetch_k += tK * KK;
     }
 
     float8 sum[MM][NN];
@@ -163,17 +200,18 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 8, MM, NN)(global float*
 
     for (int k = 0; k < K; k += tK * KK) {
         // Next prefetch:
-        const int next_k = k + tK * KK;
+        // TODO: skip prefetch on the last iterations.
         for (int kk = 0; kk < KK; kk+=2) {
             for (int mm = 0; mm < MM; mm++) {
-                prefetch_a_rowmajor_d16_m8_k16v2_sg8(A, m + mm * tM, next_k + kk * tK, K);
+                prefetch_a_rowmajor_d16_m8_k16v2_sg8(A, m + mm * tM, prefetch_k + kk * tK, K);
             }
         }
         for (int kk = 0; kk < KK; kk++) {
            for (int nn = 0; nn < NN; nn+=2) {
-                prefetch_b_vnni_d16_k16_n8v2_sg8(B, next_k + kk * tK, n + nn * tN, N);
+                prefetch_b_vnni_d16_k16_n8v2_sg8(B, prefetch_k + kk * tK, n + nn * tN, N);
             }
         }
+        prefetch_k += tK * KK;
 
         int8 aData[KK][MM];
         if (KK % 2 == 0) {
@@ -223,6 +261,70 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 8, MM, NN)(global float*
 
 #endif // HAS_SIMD8
 
+void HELPER_NAME(atile_prefetch_rowmajor, MM, NN)(global ushort* A, int tM, int K, int m, int prefetch_k)
+{
+    for (int kk = 0; kk < KK; kk+=2) {
+        for (int mm = 0; mm < MM; mm+=2) {
+            prefetch_a_rowmajor_d16_m8v2_k16v2_sg16(A, m + mm * tM, prefetch_k + kk * tK, K);
+        }
+    }
+}
+
+void HELPER_NAME(btile_prefetch_rowmajor, MM, NN)(global ushort* B, int tN, int N, int prefetch_k, int n)
+{
+    for (int kk = 0; kk < KK; kk++) {
+        for (int nn = 0; nn < NN; nn+=2) {
+            prefetch_b_rowmajor_d16_k16_n16v2_sg16(B, prefetch_k + kk * tK, n + nn * tN, N);
+        }
+    }
+}
+
+void HELPER_NAME(btile_prefetch_vnni, MM, NN)(global ushort* B, int tN, int N, int prefetch_k, int n)
+{
+    for (int kk = 0; kk < KK; kk+=2) {
+        for (int nn = 0; nn < NN; nn++) {
+            prefetch_b_vnni_d16_k16v2_n16_sg16(B, prefetch_k + kk * tK, n + nn * tN, N);
+        }
+    }
+}
+
+void HELPER_NAME(atile_load_rowmajor, MM, NN)(global ushort* A, int tM, int K, int m, int k, short8 aData[KK][MM])
+{
+    if (KK % 2 == 0) {
+        for (int kk = 0; kk < KK; kk+=2) {
+            for (int mm = 0; mm < MM; mm++) {
+                short16 aTemp = load_a_rowmajor_d16_m8_k16v2_sg16(A, m + mm * tM, k + kk * tK, K);
+                aData[kk + 0][mm] = aTemp.lo;
+                aData[kk + 1][mm] = aTemp.hi;
+            }
+        }
+    } else {
+        for (int kk = 0; kk < KK; kk++) {
+            for (int mm = 0; mm < MM; mm++) {
+                aData[kk][mm] = load_a_rowmajor_d16_m8_k16_sg16(A, m + mm * tM, k + kk * tK, K);
+            }
+        }
+    }
+}
+
+void HELPER_NAME(btile_load_rowmajor, MM, NN)(global ushort* B, int tN, int N, int k, int n, int8 bData[KK][NN])
+{
+    for (int kk = 0; kk < KK; kk++) {
+        for (int nn = 0; nn < NN; nn++) {
+            bData[kk][nn] = load_b_rowmajor_d16_k16_nx(B, k + kk * tK, n + nn * tN, N);
+        }
+    }
+}
+
+void HELPER_NAME(btile_load_vnni, MM, NN)(global ushort* B, int tN, int N, int k, int n, int8 bData[KK][NN])
+{
+    for (int kk = 0; kk < KK; kk++) {
+        for (int nn = 0; nn < NN; nn++) {
+            bData[kk][nn] = load_b_vnni_d16_k16_nx(B, k + kk * tK, n + nn * tN, N);
+        }
+    }
+}
+
 __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, SGS_PER_WG, 1)))
 kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 16, MM, NN)(global float* C, global_aligned_ushort_ptr A, global_aligned_ushort_ptr B, int K)
 {
@@ -235,16 +337,8 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 16, MM, NN)(global f
 
     // Initial prefetch:
     int prefetch_k = 0;
     for (int p = 0; p < PREFETCH_DISTANCE; p++) {
-        for (int kk = 0; kk < KK; kk+=2) {
-            for (int mm = 0; mm < MM; mm+=2) {
-                prefetch_a_rowmajor_d16_m8v2_k16v2_sg16(A, m + mm * tM, prefetch_k + kk * tK, K);
-            }
-        }
-        for (int kk = 0; kk < KK; kk++) {
-            for (int nn = 0; nn < NN; nn+=2) {
-                prefetch_b_rowmajor_d16_k16_n16v2_sg16(B, prefetch_k + kk * tK, n + nn * tN, N);
-            }
-        }
+        HELPER_NAME(atile_prefetch_rowmajor, MM, NN)(A, tM, K, m, prefetch_k);
+        HELPER_NAME(btile_prefetch_rowmajor, MM, NN)(B, tN, N, prefetch_k, n);
         prefetch_k += tK * KK;
     }
 
@@ -260,41 +354,15 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 16, MM, NN)(global f
     for (int k = 0; k < K; k += tK * KK) {
         // Next prefetch:
         // TODO: skip prefetch on the last iterations.
-        for (int kk = 0; kk < KK; kk+=2) {
-            for (int mm = 0; mm < MM; mm+=2) {
-                prefetch_a_rowmajor_d16_m8v2_k16v2_sg16(A, m + mm * tM, prefetch_k + kk * tK, K);
-            }
-        }
-        for (int kk = 0; kk < KK; kk++) {
-            for (int nn = 0; nn < NN; nn+=2) {
-                prefetch_b_rowmajor_d16_k16_n16v2_sg16(B, prefetch_k + kk * tK, n + nn * tN, N);
-            }
-        }
+        HELPER_NAME(atile_prefetch_rowmajor, MM, NN)(A, tM, K, m, prefetch_k);
+        HELPER_NAME(btile_prefetch_rowmajor, MM, NN)(B, tN, N, prefetch_k, n);
         prefetch_k += tK * KK;
 
         short8 aData[KK][MM];
-        if (KK % 2 == 0) {
-            for (int kk = 0; kk < KK; kk+=2) {
-                for (int mm = 0; mm < MM; mm++) {
-                    short16 aTemp = load_a_rowmajor_d16_m8_k16v2_sg16(A, m + mm * tM, k + kk * tK, K);
-                    aData[kk + 0][mm] = aTemp.lo;
-                    aData[kk + 1][mm] = aTemp.hi;
-                }
-            }
-        } else {
-            for (int kk = 0; kk < KK; kk++) {
-                for (int mm = 0; mm < MM; mm++) {
-                    aData[kk][mm] = load_a_rowmajor_d16_m8_k16_sg16(A, m + mm * tM, k + kk * tK, K);
-                }
-            }
-        }
+        HELPER_NAME(atile_load_rowmajor, MM, NN)(A, tM, K, m, k, aData);
 
         int8 bData[KK][NN];
-        for (int kk = 0; kk < KK; kk++) {
-            for (int nn = 0; nn < NN; nn++) {
-                bData[kk][nn] = load_b_rowmajor_d16_k16_nx(B, k + kk * tK, n + nn * tN, N);
-            }
-        }
+        HELPER_NAME(btile_load_rowmajor, MM, NN)(B, tN, N, k, n, bData);
 
         for (int kk = 0; kk < KK; kk++) {
             for (int nn = 0; nn < NN; nn++) {
@@ -330,16 +398,8 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 16, MM, NN)(global float
 
     // Initial prefetch:
     int prefetch_k = 0;
     for (int p = 0; p < PREFETCH_DISTANCE; p++) {
-        for (int kk = 0; kk < KK; kk+=2) {
-            for (int mm = 0; mm < MM; mm+=2) {
-                prefetch_a_rowmajor_d16_m8v2_k16v2_sg16(A, m + mm * tM, prefetch_k + kk * tK, K);
-            }
-        }
-        for (int kk = 0; kk < KK; kk+=2) {
-            for (int nn = 0; nn < NN; nn++) {
-                prefetch_b_vnni_d16_k16v2_n16_sg16(B, prefetch_k + kk * tK, n + nn * tN, N);
-            }
-        }
+        HELPER_NAME(atile_prefetch_rowmajor, MM, NN)(A, tM, K, m, prefetch_k);
+        HELPER_NAME(btile_prefetch_vnni, MM, NN)(B, tN, N, prefetch_k, n);
         prefetch_k += tK * KK;
     }
 
@@ -355,41 +415,15 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 16, MM, NN)(global float
     for (int k = 0; k < K; k += tK * KK) {
         // Next prefetch:
         // TODO: skip prefetch on the last iterations.
-        for (int kk = 0; kk < KK; kk+=2) {
-            for (int mm = 0; mm < MM; mm+=2) {
-                prefetch_a_rowmajor_d16_m8v2_k16v2_sg16(A, m + mm * tM, prefetch_k + kk * tK, K);
-            }
-        }
-        for (int kk = 0; kk < KK; kk+=2) {
-            for (int nn = 0; nn < NN; nn++) {
-                prefetch_b_vnni_d16_k16v2_n16_sg16(B, prefetch_k + kk * tK, n + nn * tN, N);
-            }
-        }
+        HELPER_NAME(atile_prefetch_rowmajor, MM, NN)(A, tM, K, m, prefetch_k);
+        HELPER_NAME(btile_prefetch_vnni, MM, NN)(B, tN, N, prefetch_k, n);
         prefetch_k += tK * KK;
 
         short8 aData[KK][MM];
-        if (KK % 2 == 0) {
-            for (int kk = 0; kk < KK; kk+=2) {
-                for (int mm = 0; mm < MM; mm++) {
-                    short16 aTemp = load_a_rowmajor_d16_m8_k16v2_sg16(A, m + mm * tM, k + kk * tK, K);
-                    aData[kk + 0][mm] = aTemp.lo;
-                    aData[kk + 1][mm] = aTemp.hi;
-                }
-            }
-        } else {
-            for (int kk = 0; kk < KK; kk++) {
-                for (int mm = 0; mm < MM; mm++) {
-                    aData[kk][mm] = load_a_rowmajor_d16_m8_k16_sg16(A, m + mm * tM, k + kk * tK, K);
-                }
-            }
-        }
+        HELPER_NAME(atile_load_rowmajor, MM, NN)(A, tM, K, m, k, aData);
 
         int8 bData[KK][NN];
-        for (int kk = 0; kk < KK; kk++) {
-            for (int nn = 0; nn < NN; nn++) {
-                bData[kk][nn] = load_b_vnni_d16_k16_nx(B, k + kk * tK, n + nn * tN, N);
-            }
-        }
+        HELPER_NAME(btile_load_vnni, MM, NN)(B, tN, N, k, n, bData);
 
         for (int kk = 0; kk < KK; kk++) {
             for (int nn = 0; nn < NN; nn++) {
@@ -415,8 +449,90 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 16, MM, NN)(global float
 
 #ifdef cl_intel_subgroup_extended_block_read
 
+void HELPER_NAME(atile_load_blockread_rowmajor, MM, NN)(global ushort* A, int tM, int M, int K, int m, int k, short8 aData[KK][MM])
+{
+    if (KK % 2 == 0 & MM % 4 == 0) {
+        for (int kk = 0; kk < KK; kk+=2) {
+            for (int mm = 0; mm < MM; mm+=4) {
+                ushort8 tmp[2][4];
+                intel_subgroup_block_read_u16_m32k16v2(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM), tmp);
+                for (int tkk = 0; tkk < 2; tkk++) {
+                    for (int tmm = 0; tmm < 4; tmm++) {
+                        aData[kk + tkk][mm + tmm] = as_short8(tmp[tkk][tmm]);
+                    }
+                }
+            }
+        }
+    } else if (KK % 2 == 0 & MM % 2 == 0) {
+        for (int kk = 0; kk < KK; kk+=2) {
+            for (int mm = 0; mm < MM; mm+=2) {
+                ushort8 tmp[2][2];
+                intel_subgroup_block_read_u16_m16k16v2(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM), tmp);
+                for (int tkk = 0; tkk < 2; tkk++) {
+                    for (int tmm = 0; tmm < 2; tmm++) {
+                        aData[kk + tkk][mm + tmm] = as_short8(tmp[tkk][tmm]);
+                    }
+                }
+            }
+        }
+    } else if (KK % 2 == 0) {
+        for (int kk = 0; kk < KK; kk+=2) {
+            for (int mm = 0; mm < MM; mm++) {
+                short16 aTemp = as_short16(intel_subgroup_block_read_u16_m8k16v2(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)));
+                aData[kk + 0][mm] = aTemp.lo;
+                aData[kk + 1][mm] = aTemp.hi;
+            }
+        }
+    } else if (MM % 4 == 0) {
+        for (int kk = 0; kk < KK; kk++) {
+            for (int mm = 0; mm < MM; mm+=4) {
+                ushort8 tmp[4];
+                intel_subgroup_block_read_u16_m32k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM), tmp);
+                for (int tmm = 0; tmm < 4; tmm++) {
+                    aData[kk][mm + tmm] = as_short8(tmp[tmm]);
+                }
+            }
+        }
+    } else {
+        for (int kk = 0; kk < KK; kk++) {
+            for (int mm = 0; mm < MM; mm++) {
+                aData[kk][mm] = as_short8(intel_subgroup_block_read_u16_m8k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)));
+            }
+        }
+    }
+}
+
+void HELPER_NAME(btile_load_blockread_rowmajor, MM, NN)(global ushort* B, int tN, int K, int N, int k, int n, int8 bData[KK][NN])
+{
+    for (int kk = 0; kk < KK; kk++) {
+        for (int nn = 0; nn < NN; nn++) {
+            bData[kk][nn] = as_int8(intel_subgroup_block_read_transform_u16_k16(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK)));
+        }
+    }
+}
+
+void HELPER_NAME(btile_load_blockread_vnni, MM, NN)(global ushort* B, int tN, int K, int N, int k, int n, int8 bData[KK][NN])
+{
+    if (KK % 2 == 0) {
+        for (int kk = 0; kk < KK; kk+=2) {
+            for (int nn = 0; nn < NN; nn++) {
+                int16 bTemp = as_int16(intel_subgroup_block_read_u32_m16k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, (k + kk * tK) / 2)));
+                bData[kk + 0][nn] = bTemp.lo;
+                bData[kk + 1][nn] = bTemp.hi;
+            }
+        }
+    } else {
+        for (int kk = 0; kk < KK; kk++) {
+            for (int nn = 0; nn < NN; nn++) {
+                bData[kk][nn] = as_int8(intel_subgroup_block_read_u32_m8k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, (k + kk * tK) / 2)));
+            }
+        }
+    }
+}
+
 __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, SGS_PER_WG, 1)))
 kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN)(global float* C, global ushort* A, global ushort* B, int K)
+
 {
     const int tM = 8;
     const int tN = 16;
@@ -428,16 +544,8 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN
 
     // Initial prefetch:
     int prefetch_k = 0;
     for (int p = 0; p < PREFETCH_DISTANCE; p++) {
-        for (int kk = 0; kk < KK; kk+=2) {
-            for (int mm = 0; mm < MM; mm+=2) {
-                prefetch_a_rowmajor_d16_m8v2_k16v2_sg16(A, m + mm * tM, prefetch_k + kk * tK, K);
-            }
-        }
-        for (int kk = 0; kk < KK; kk++) {
-            for (int nn = 0; nn < NN; nn+=2) {
-                prefetch_b_rowmajor_d16_k16_n16v2_sg16(B, prefetch_k + kk * tK, n + nn * tN, N);
-            }
-        }
+        HELPER_NAME(atile_prefetch_rowmajor, MM, NN)(A, tM, K, m, prefetch_k);
+        HELPER_NAME(btile_prefetch_rowmajor, MM, NN)(B, tN, N, prefetch_k, n);
         prefetch_k += tK * KK;
     }
 
@@ -453,75 +561,15 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN
     for (int k = 0; k < K; k += tK * KK) {
         // Next prefetch:
         // TODO: skip prefetch on the last iterations.
-        for (int kk = 0; kk < KK; kk+=2) {
-            for (int mm = 0; mm < MM; mm+=2) {
-                prefetch_a_rowmajor_d16_m8v2_k16v2_sg16(A, m + mm * tM, prefetch_k + kk * tK, K);
-            }
-        }
-        for (int kk = 0; kk < KK; kk++) {
-            for (int nn = 0; nn < NN; nn+=2) {
-                prefetch_b_rowmajor_d16_k16_n16v2_sg16(B, prefetch_k + kk * tK, n + nn * tN, N);
-            }
-        }
+        HELPER_NAME(atile_prefetch_rowmajor, MM, NN)(A, tM, K, m, prefetch_k);
+        HELPER_NAME(btile_prefetch_rowmajor, MM, NN)(B, tN, N, prefetch_k, n);
         prefetch_k += tK * KK;
 
         short8 aData[KK][MM];
-        if (KK % 2 == 0 & MM % 4 == 0) {
-            for (int kk = 0; kk < KK; kk+=2) {
-                for (int mm = 0; mm < MM; mm+=4) {
-                    ushort8 tmp[2][4];
-                    intel_subgroup_block_read_u16_m32k16v2(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM), tmp);
-                    for (int tkk = 0; tkk < 2; tkk++) {
-                        for (int tmm = 0; tmm < 4; tmm++) {
-                            aData[kk + tkk][mm + tmm] = as_short8(tmp[tkk][tmm]);
-                        }
-                    }
-                }
-            }
-        } else if (KK % 2 == 0 & MM % 2 == 0) {
-            for (int kk = 0; kk < KK; kk+=2) {
-                for (int mm = 0; mm < MM; mm+=2) {
-                    ushort8 tmp[2][2];
-                    intel_subgroup_block_read_u16_m16k16v2(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM), tmp);
-                    for (int tkk = 0; tkk < 2; tkk++) {
-                        for (int tmm = 0; tmm < 2; tmm++) {
-                            aData[kk + tkk][mm + tmm] = as_short8(tmp[tkk][tmm]);
-                        }
-                    }
-                }
-            }
-        } else if (KK % 2 == 0) {
-            for (int kk = 0; kk < KK; kk+=2) {
-                for (int mm = 0; mm < MM; mm++) {
-                    short16 aTemp = as_short16(intel_subgroup_block_read_u16_m8k16v2(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)));
-                    aData[kk + 0][mm] = aTemp.lo;
-                    aData[kk + 1][mm] = aTemp.hi;
-                }
-            }
-        } else if (MM % 4 == 0) {
-            for (int kk = 0; kk < KK; kk++) {
-                for (int mm = 0; mm < MM; mm+=4) {
-                    ushort8 tmp[4];
-                    intel_subgroup_block_read_u16_m32k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM), tmp);
-                    for (int tmm = 0; tmm < 4; tmm++) {
-                        aData[kk][mm + tmm] = as_short8(tmp[tmm]);
-                    }
-                }
-            }
-        } else {
-            for (int kk = 0; kk < KK; kk++) {
-                for (int mm = 0; mm < MM; mm++) {
-                    aData[kk][mm] = as_short8(intel_subgroup_block_read_u16_m8k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)));
-                }
-            }
-        }
+        HELPER_NAME(atile_load_blockread_rowmajor, MM, NN)(A, tM, M, K, m, k, aData);
 
         int8 bData[KK][NN];
-        for (int kk = 0; kk < KK; kk++) {
-            for (int nn = 0; nn < NN; nn++) {
-                bData[kk][nn] = as_int8(intel_subgroup_block_read_transform_u16_k16(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK)));
-            }
-        }
+        HELPER_NAME(btile_load_blockread_rowmajor, MM, NN)(B, tN, K, N, k, n, bData);
 
         for (int kk = 0; kk < KK; kk++) {
             for (int nn = 0; nn < NN; nn++) {
@@ -558,16 +606,8 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(gl
 
     // Initial prefetch:
     int prefetch_k = 0;
     for (int p = 0; p < PREFETCH_DISTANCE; p++) {
-        for (int kk = 0; kk < KK; kk+=2) {
-            for (int mm = 0; mm < MM; mm+=2) {
-                prefetch_a_rowmajor_d16_m8v2_k16v2_sg16(A, m + mm * tM, prefetch_k + kk * tK, K);
-            }
-        }
-        for (int kk = 0; kk < KK; kk+=2) {
-            for (int nn = 0; nn < NN; nn++) {
-                prefetch_b_vnni_d16_k16v2_n16_sg16(B, prefetch_k + kk * tK, n + nn * tN, N);
-            }
-        }
+        HELPER_NAME(atile_prefetch_rowmajor, MM, NN)(A, tM, K, m, prefetch_k);
+        HELPER_NAME(btile_prefetch_vnni, MM, NN)(B, tN, N, prefetch_k, n);
         prefetch_k += tK * KK;
     }
 
@@ -583,85 +623,15 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(gl
     for (int k = 0; k < K; k += tK * KK) {
         // Next prefetch:
         // TODO: skip prefetch on the last iterations.
-        for (int kk = 0; kk < KK; kk+=2) {
-            for (int mm = 0; mm < MM; mm+=2) {
-                prefetch_a_rowmajor_d16_m8v2_k16v2_sg16(A, m + mm * tM, prefetch_k + kk * tK, K);
-            }
-        }
-        for (int kk = 0; kk < KK; kk+=2) {
-            for (int nn = 0; nn < NN; nn++) {
-                prefetch_b_vnni_d16_k16v2_n16_sg16(B, prefetch_k + kk * tK, n + nn * tN, N);
-            }
-        }
+        HELPER_NAME(atile_prefetch_rowmajor, MM, NN)(A, tM, K, m, prefetch_k);
+        HELPER_NAME(btile_prefetch_vnni, MM, NN)(B, tN, N, prefetch_k, n);
        prefetch_k += tK * KK;
 
         short8 aData[KK][MM];
-        if (KK % 2 == 0 & MM % 4 == 0) {
-            for (int kk = 0; kk < KK; kk+=2) {
-                for (int mm = 0; mm < MM; mm+=4) {
-                    ushort8 tmp[2][4];
-                    intel_subgroup_block_read_u16_m32k16v2(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM), tmp);
-                    for (int tkk = 0; tkk < 2; tkk++) {
-                        for (int tmm = 0; tmm < 4; tmm++) {
-                            aData[kk + tkk][mm + tmm] = as_short8(tmp[tkk][tmm]);
-                        }
-                    }
-                }
-            }
-        } else if (KK % 2 == 0 & MM % 2 == 0) {
-            for (int kk = 0; kk < KK; kk+=2) {
-                for (int mm = 0; mm < MM; mm+=2) {
-                    ushort8 tmp[2][2];
-                    intel_subgroup_block_read_u16_m16k16v2(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM), tmp);
-                    for (int tkk = 0; tkk < 2; tkk++) {
-                        for (int tmm = 0; tmm < 2; tmm++) {
-                            aData[kk + tkk][mm + tmm] = as_short8(tmp[tkk][tmm]);
-                        }
-                    }
-                }
-            }
-        } else if (KK % 2 == 0) {
-            for (int kk = 0; kk < KK; kk+=2) {
-                for (int mm = 0; mm < MM; mm++) {
-                    short16 aTemp = as_short16(intel_subgroup_block_read_u16_m8k16v2(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)));
-                    aData[kk + 0][mm] = aTemp.lo;
-                    aData[kk + 1][mm] = aTemp.hi;
-                }
-            }
-        } else if (MM % 4 == 0) {
-            for (int kk = 0; kk < KK; kk++) {
-                for (int mm = 0; mm < MM; mm+=4) {
-                    ushort8 tmp[4];
-                    intel_subgroup_block_read_u16_m32k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM), tmp);
-                    for (int tmm = 0; tmm < 4; tmm++) {
-                        aData[kk][mm + tmm] = as_short8(tmp[tmm]);
-                    }
-                }
-            }
-        } else {
-            for (int kk = 0; kk < KK; kk++) {
-                for (int mm = 0; mm < MM; mm++) {
-                    aData[kk][mm] = as_short8(intel_subgroup_block_read_u16_m8k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)));
-                }
-            }
-        }
+        HELPER_NAME(atile_load_blockread_rowmajor, MM, NN)(A, tM, M, K, m, k, aData);
 
         int8 bData[KK][NN];
-        if (KK % 2 == 0) {
-            for (int kk = 0; kk < KK; kk+=2) {
-                for (int nn = 0; nn < NN; nn++) {
-                    int16 bTemp = as_int16(intel_subgroup_block_read_u32_m16k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, (k + kk * tK) / 2)));
-                    bData[kk + 0][nn] = bTemp.lo;
-                    bData[kk + 1][nn] = bTemp.hi;
-                }
-            }
-        } else {
-            for (int kk = 0; kk < KK; kk++) {
-                for (int nn = 0; nn < NN; nn++) {
-                    bData[kk][nn] = as_int8(intel_subgroup_block_read_u32_m8k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, (k + kk * tK) / 2)));
-                }
-            }
-        }
+        HELPER_NAME(btile_load_blockread_vnni, MM, NN)(B, tN, K, N, k, n, bData);
 
         for (int kk = 0; kk < KK; kk++) {
             for (int nn = 0; nn < NN; nn++) {