Skip to content

Commit

Permalink
switch the tiled dpas order
Browse files Browse the repository at this point in the history
We want to prioritize reuse of the A matrix to make best use of
read suppression buffers.
  • Loading branch information
bashbaug committed Jan 17, 2024
1 parent d76df7e commit 756d2e9
Showing 1 changed file with 12 additions and 12 deletions.
24 changes: 12 additions & 12 deletions samples/99_matrixexperiments/matrix_kernel_tiled.cl
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 8, MM, NN)(global fl
}
}

-for (int mm = 0; mm < MM; mm++) {
-for (int nn = 0; nn < NN; nn++) {
+for (int nn = 0; nn < NN; nn++) {
+for (int mm = 0; mm < MM; mm++) {
store_c_rowmajor_fp32_m8_nx(C, sum[mm][nn], m + mm * tM, n + nn * tN, N);
}
}
Expand Down Expand Up @@ -83,8 +83,8 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 8, MM, NN)(global float*
bData[nn] = load_b_vnni_d16_k16_nx(B, k, n + nn * tN, N);
}

-for (int mm = 0; mm < MM; mm++) {
-for (int nn = 0; nn < NN; nn++) {
+for (int nn = 0; nn < NN; nn++) {
+for (int mm = 0; mm < MM; mm++) {
sum[mm][nn] = mat_mul_sg8(aData[mm], bData[nn], sum[mm][nn]);
}
}
Expand Down Expand Up @@ -126,8 +126,8 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 16, MM, NN)(global f
bData[nn] = load_b_rowmajor_d16_k16_nx(B, k, n + nn * tN, N);
}

-for (int mm = 0; mm < MM; mm++) {
-for (int nn = 0; nn < NN; nn++) {
+for (int nn = 0; nn < NN; nn++) {
+for (int mm = 0; mm < MM; mm++) {
sum[mm][nn] = mat_mul_sg16(aData[mm], bData[nn], sum[mm][nn]);
}
}
Expand Down Expand Up @@ -167,8 +167,8 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 16, MM, NN)(global float
bData[nn] = load_b_vnni_d16_k16_nx(B, k, n + nn * tN, N);
}

-for (int mm = 0; mm < MM; mm++) {
-for (int nn = 0; nn < NN; nn++) {
+for (int nn = 0; nn < NN; nn++) {
+for (int mm = 0; mm < MM; mm++) {
sum[mm][nn] = mat_mul_sg16(aData[mm], bData[nn], sum[mm][nn]);
}
}
Expand Down Expand Up @@ -219,8 +219,8 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN
bData[nn] = as_int8(intel_subgroup_block_read_transform_u16_k16(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k)));
}

-for (int mm = 0; mm < MM; mm++) {
-for (int nn = 0; nn < NN; nn++) {
+for (int nn = 0; nn < NN; nn++) {
+for (int mm = 0; mm < MM; mm++) {
sum[mm][nn] = mat_mul_sg16(aData[mm], bData[nn], sum[mm][nn]);
}
}
Expand Down Expand Up @@ -269,8 +269,8 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(gl
bData[nn] = as_int8(intel_subgroup_block_read_u32_m8k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, k / 2)));
}

-for (int mm = 0; mm < MM; mm++) {
-for (int nn = 0; nn < NN; nn++) {
+for (int nn = 0; nn < NN; nn++) {
+for (int mm = 0; mm < MM; mm++) {
sum[mm][nn] = mat_mul_sg16(aData[mm], bData[nn], sum[mm][nn]);
}
}
Expand Down

0 comments on commit 756d2e9

Please sign in to comment.