From ab01142a5052d9d0b64b8df64a2642355e6eaa44 Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Thu, 14 Mar 2024 13:51:52 -0700 Subject: [PATCH] try a cooperative prefetch for the A matrix tile --- samples/99_matrixexperiments/matrix_kernel_tiled.cl | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/samples/99_matrixexperiments/matrix_kernel_tiled.cl b/samples/99_matrixexperiments/matrix_kernel_tiled.cl index a8d6b4d..2f7fdc7 100644 --- a/samples/99_matrixexperiments/matrix_kernel_tiled.cl +++ b/samples/99_matrixexperiments/matrix_kernel_tiled.cl @@ -412,6 +412,9 @@ void HELPER_NAME(atile_block_load_rowmajor, MM, NN)(global ushort* A, int tM, in if (KK % 2 == 0 & MM % 4 == 0) { for (int kk = 0; kk < KK; kk+=2) { for (int mm = 0; mm < MM; mm+=4) { + //if (get_sub_group_local_id() == 0) { + // printf("atile block load : %d, %d, %2d: m = %3d, k = %3d, mm = %2d, kk = %2d, coord = %3d, %3d\n", (int)get_group_id(1), (int)get_group_id(0), get_sub_group_id(), m, k, mm, kk, k + kk * tK, m + mm * tM); + //} ushort8 tmp[2][4]; intel_subgroup_block_read_u16_m32k16v2(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM), tmp); for (int tkk = 0; tkk < 2; tkk++) { @@ -523,7 +526,15 @@ void HELPER_NAME(btile_block_load_vnni, MM, NN)(global ushort* B, int tN, int K, void HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(global ushort* A, int tM, int M, int K, int m, int k) { - if (KK % 2 == 0 & MM % 4 == 0) { + if (KK == 2 & MM == 4 & SGS_PER_WG_X >= 4) { + const int sg_index_x = get_sub_group_id() % SGS_PER_WG_X; // index in [0, SGS_PER_WG_X) + const int kk = 0; + const int mm = sg_index_x % 2 * 2; + //if (get_sub_group_local_id() == 0) { + // printf("atile block prefetch: %d, %d, %2d: sg_x = %d, m = %3d, k = %3d, mm = %2d, kk = %2d, coord = %3d, %3d\n", (int)get_group_id(1), (int)get_group_id(0), get_sub_group_id(), sg_index_x, m, k, mm, kk, k + kk * tK, m + mm * tM); + //} + intel_subgroup_block_prefetch_u16_m16k16v2(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)); + } else if (KK % 2 == 0 & MM % 4 == 0) { for (int kk = 0; kk < KK; kk+=2) { for (int mm = 0; mm < MM; mm+=4) { intel_subgroup_block_prefetch_u16_m32k16v2(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM));