Skip to content

Commit

Permalink
try a cooperative prefetch for the A matrix tile
Browse files Browse the repository at this point in the history
  • Loading branch information
bashbaug committed Mar 14, 2024
1 parent 90b23b0 commit ab01142
Showing 1 changed file with 12 additions and 1 deletion.
13 changes: 12 additions & 1 deletion samples/99_matrixexperiments/matrix_kernel_tiled.cl
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,9 @@ void HELPER_NAME(atile_block_load_rowmajor, MM, NN)(global ushort* A, int tM, in
if (KK % 2 == 0 & MM % 4 == 0) {
for (int kk = 0; kk < KK; kk+=2) {
for (int mm = 0; mm < MM; mm+=4) {
//if (get_sub_group_local_id() == 0) {
// printf("atile block load : %d, %d, %2d: m = %3d, k = %3d, mm = %2d, kk = %2d, coord = %3d, %3d\n", (int)get_group_id(1), (int)get_group_id(0), get_sub_group_id(), m, k, mm, kk, k + kk * tK, m + mm * tM);
//}
ushort8 tmp[2][4];
intel_subgroup_block_read_u16_m32k16v2(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM), tmp);
for (int tkk = 0; tkk < 2; tkk++) {
Expand Down Expand Up @@ -523,7 +526,15 @@ void HELPER_NAME(btile_block_load_vnni, MM, NN)(global ushort* B, int tN, int K,

void HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(global ushort* A, int tM, int M, int K, int m, int k)
{
if (KK % 2 == 0 & MM % 4 == 0) {
if (KK == 2 & MM == 4 & SGS_PER_WG_X >= 4) {
const int sg_index_x = get_sub_group_id() % SGS_PER_WG_X; // index in [0, SGS_PER_WG_X)
const int kk = 0;
const int mm = sg_index_x % 2 * 2;
//if (get_sub_group_local_id() == 0) {
// printf("atile block prefetch: %d, %d, %2d: sg_x = %d, m = %3d, k = %3d, mm = %2d, kk = %2d, coord = %3d, %3d\n", (int)get_group_id(1), (int)get_group_id(0), get_sub_group_id(), sg_index_x, m, k, mm, kk, k + kk * tK, m + mm * tM);
//}
intel_subgroup_block_prefetch_u16_m16k16v2(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM));
} else if (KK % 2 == 0 & MM % 4 == 0) {
for (int kk = 0; kk < KK; kk+=2) {
for (int mm = 0; mm < MM; mm+=4) {
intel_subgroup_block_prefetch_u16_m32k16v2(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM));
Expand Down

0 comments on commit ab01142

Please sign in to comment.