Skip to content

Commit

Permalink
add support for split barriers
Browse files Browse the repository at this point in the history
This may also be helpful to keep subgroups running approximately
together, which could also improve cache utilization.
  • Loading branch information
bashbaug committed Jan 19, 2024
1 parent 0fb3d66 commit d09b982
Showing 1 changed file with 53 additions and 0 deletions.
53 changes: 53 additions & 0 deletions samples/99_matrixexperiments/matrix_kernel_tiled.cl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,17 @@
#error "NN is undefined! This should be defined as the number of matrix tiles in the N dimension."
#endif

// Split work-group barrier helpers.
//
// split_barrier_arrive() / split_barrier_wait() forward to the
// cl_intel_split_work_group_barrier built-ins when the extension macro is
// defined and the build has not opted out by defining NO_SPLIT_BARRIERS.
// Otherwise both expand to nothing, so call sites compile unchanged.
// NOTE(review): the flags argument 0 presumably requests no memory fence --
// confirm against the extension specification.
#if defined(cl_intel_split_work_group_barrier) && !defined(NO_SPLIT_BARRIERS)
#define split_barrier_arrive() intel_work_group_barrier_arrive(0)
#define split_barrier_wait() intel_work_group_barrier_wait(0)
#else
// Warn only when the extension itself is missing; an explicit
// NO_SPLIT_BARRIERS opt-out is silent.
#if !defined(cl_intel_split_work_group_barrier)
#warning "Unexpected: cl_intel_split_work_group_barrier is not supported?"
#endif
#define split_barrier_arrive()
#define split_barrier_wait()
#endif

// Builds a kernel identifier of the form
//   <BASE>_m<TILE_M>_n<TILE_N>_<COUNT_M>x<COUNT_N>
// MM_KERNEL_NAME is the entry point: it forwards to MM_KERNEL_NAMEX so that
// arguments which are themselves macros (such as NN) are expanded before the
// ## operators paste them. Calling MM_KERNEL_NAMEX directly pastes the
// argument spellings as written.
#define MM_KERNEL_NAMEX(BASE, TILE_M, TILE_N, COUNT_M, COUNT_N) BASE ## _m ## TILE_M ## _n ## TILE_N ## _ ## COUNT_M ## x ## COUNT_N
#define MM_KERNEL_NAME(BASE, TILE_M, TILE_N, COUNT_M, COUNT_N) MM_KERNEL_NAMEX(BASE, TILE_M, TILE_N, COUNT_M, COUNT_N)

Expand All @@ -36,6 +47,8 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 8, MM, NN)(global fl
}
}

split_barrier_arrive();

for (int k = 0; k < K; k += tK) {
int8 aData[MM];
for (int mm = 0; mm < MM; mm++) {
Expand All @@ -52,8 +65,13 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 8, MM, NN)(global fl
sum[mm][nn] = mat_mul_sg8(aData[mm], bData[nn], sum[mm][nn]);
}
}

split_barrier_wait();
split_barrier_arrive();
}

split_barrier_wait();

for (int nn = 0; nn < NN; nn++) {
for (int mm = 0; mm < MM; mm++) {
store_c_rowmajor_fp32_m8_nx(C, sum[mm][nn], m + mm * tM, n + nn * tN, N);
Expand All @@ -77,6 +95,8 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 8, MM, NN)(global float*
}
}

split_barrier_arrive();

for (int k = 0; k < K; k += tK) {
int8 aData[MM];
for (int mm = 0; mm < MM; mm++) {
Expand All @@ -93,8 +113,13 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 8, MM, NN)(global float*
sum[mm][nn] = mat_mul_sg8(aData[mm], bData[nn], sum[mm][nn]);
}
}

split_barrier_wait();
split_barrier_arrive();
}

split_barrier_wait();

for (int mm = 0; mm < MM; mm++) {
for (int nn = 0; nn < NN; nn++) {
store_c_rowmajor_fp32_m8_nx(C, sum[mm][nn], m + mm * tM, n + nn * tN, N);
Expand All @@ -120,6 +145,8 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 16, MM, NN)(global f
}
}

split_barrier_arrive();

for (int k = 0; k < K; k += tK) {
short8 aData[MM];
for (int mm = 0; mm < MM; mm++) {
Expand All @@ -136,8 +163,13 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 16, MM, NN)(global f
sum[mm][nn] = mat_mul_sg16(aData[mm], bData[nn], sum[mm][nn]);
}
}

split_barrier_wait();
split_barrier_arrive();
}

split_barrier_wait();

for (int mm = 0; mm < MM; mm++) {
for (int nn = 0; nn < NN; nn++) {
store_c_rowmajor_fp32_m8_nx(C, sum[mm][nn], m + mm * tM, n + nn * tN, N);
Expand All @@ -161,6 +193,8 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 16, MM, NN)(global float
}
}

split_barrier_arrive();

for (int k = 0; k < K; k += tK) {
short8 aData[MM];
for (int mm = 0; mm < MM; mm++) {
Expand All @@ -177,8 +211,13 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 16, MM, NN)(global float
sum[mm][nn] = mat_mul_sg16(aData[mm], bData[nn], sum[mm][nn]);
}
}

split_barrier_wait();
split_barrier_arrive();
}

split_barrier_wait();

for (int mm = 0; mm < MM; mm++) {
for (int nn = 0; nn < NN; nn++) {
store_c_rowmajor_fp32_m8_nx(C, sum[mm][nn], m + mm * tM, n + nn * tN, N);
Expand All @@ -205,6 +244,8 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN
}
}

split_barrier_arrive();

for (int k = 0; k < K; k += tK) {
short8 aData[MM];
//if (MM % 2 == 0) {
Expand All @@ -229,8 +270,13 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN
sum[mm][nn] = mat_mul_sg16(aData[mm], bData[nn], sum[mm][nn]);
}
}

split_barrier_wait();
split_barrier_arrive();
}

split_barrier_wait();

for (int mm = 0; mm < MM; mm++) {
for (int nn = 0; nn < NN; nn++) {
intel_subgroup_block_write_u32_m8k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n + nn * tN, m + mm * tM), as_uint8(sum[mm][nn]));
Expand All @@ -255,6 +301,8 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(gl
}
}

split_barrier_arrive();

for (int k = 0; k < K; k += tK) {
short8 aData[MM];
//if (MM % 2 == 0) {
Expand All @@ -279,8 +327,13 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(gl
sum[mm][nn] = mat_mul_sg16(aData[mm], bData[nn], sum[mm][nn]);
}
}

split_barrier_wait();
split_barrier_arrive();
}

split_barrier_wait();

for (int mm = 0; mm < MM; mm++) {
for (int nn = 0; nn < NN; nn++) {
intel_subgroup_block_write_u32_m8k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n + nn * tN, m + mm * tM), as_uint8(sum[mm][nn]));
Expand Down

0 comments on commit d09b982

Please sign in to comment.