@@ -2394,7 +2394,7 @@ template <typename T> struct MMAType {
23942394// / - m8n8k4 (f32.f16.f16.f32)
23952395// / - m8n8k16 (s32.s8.s8.s32)
23962396// / - m16n8k8 (f32.f16.f16.f32 & f32.bf16.bf16.f32)
2397- // / - m16n8k16 (f32.f16.f16.f32 & f32.bf16.bf16.f32 & s32.s8.s8.s32)
2397+ // / - m16n8k16 (f32.f16.f16.f32 & f16.f16.f16.f16 & f32.bf16.bf16.f32 & s32.s8.s8.s32)
23982398// / - m16n8k32 (s32.s8.s8.s32)
23992399// / Here, m, n & k define the shapes of A, B & C matrices respectively
24002400// / (A = [m x k], B = [k x n], C = [m x n]).
@@ -2671,6 +2671,66 @@ void mma(volatile void **d_mat_frag, void *a_mat_frag, void *b_mat_frag,
26712671 static_cast <CDType>(ra[j + 4 ]) * static_cast <CDType>(rb[j + 4 ]);
26722672 }
26732673 }
2674+ } else if constexpr (std::is_same_v<CDType, sycl::half>) {
2675+ // Init D matrix fragment with C matrix fragment
2676+ sycl::half *d0 = const_cast <sycl::half *>(d[0 ]);
2677+ sycl::half *d1 = d0 + 1 ;
2678+ sycl::half *d2 = const_cast <sycl::half *>(d[1 ]);
2679+ sycl::half *d3 = d2 + 1 ;
2680+ *d0 = c[0 ];
2681+ *d1 = c[1 ];
2682+ *d2 = c[2 ];
2683+ *d3 = c[3 ];
2684+
2685+ // Each sub-group is responsible for computing a fragment size of 16*8
2686+ // elements of matrix D.
2687+ // Each work item computes 4 elements of matrix D by gathering
2688+ // their corresponding row & col matrix fragments of length k (16)
2689+ // from A & B matrices respectively using below mapping logic:
2690+ // row0 = (lane >> 2) & row1 = (lane >> 2) + 8
2691+ // col0 = (lane % 4) * 2 & col1 = (lane % 4) * 2 + 1
2692+ // As each row & col fragment of A & B matrices is distributed across
2693+ // 4 work items, each iteration of below loop loads a partial fragment
2694+ // of matrix A (row) and matrix B (col) using the row & col offsets.
2695+ for (int i = 0 ; i < 4 ; i++) {
2696+ typename MMAType<ABType>::PackType recv_a[4 ], recv_b[4 ];
2697+
2698+ // Load partial fragment from row0 of matrix A ({a0, a1})
2699+ recv_a[0 ] = dpct::select_from_sub_group (sg, a[0 ], row_load_offset + i);
2700+ // Load partial fragment from row0 of matrix A ({a2, a3})
2701+ recv_a[1 ] = dpct::select_from_sub_group (sg, a[2 ], row_load_offset + i);
2702+ // Load partial fragment from row1 of matrix A ({a0, a1})
2703+ recv_a[2 ] = dpct::select_from_sub_group (sg, a[1 ], row_load_offset + i);
2704+ // Load partial fragment from row1 of matrix A ({a2, a3})
2705+ recv_a[3 ] = dpct::select_from_sub_group (sg, a[3 ], row_load_offset + i);
2706+
2707+ // Load partial fragment from col0 of matrix B ({b0, b1})
2708+ recv_b[0 ] = dpct::select_from_sub_group (sg, b[0 ], col_load_offset + i);
2709+ // Load partial fragment from col0 of matrix B ({b2, b3})
2710+ recv_b[1 ] = dpct::select_from_sub_group (sg, b[1 ], col_load_offset + i);
2711+ // Load partial fragment from col1 of matrix B ({b0, b1})
2712+ recv_b[2 ] =
2713+ dpct::select_from_sub_group (sg, b[0 ], col_load_offset + 4 + i);
2714+ // Load partial fragment from col1 of matrix B ({b2, b3})
2715+ recv_b[3 ] =
2716+ dpct::select_from_sub_group (sg, b[1 ], col_load_offset + 4 + i);
2717+
2718+ auto ra = reinterpret_cast <ABType *>(recv_a);
2719+ auto rb = reinterpret_cast <ABType *>(recv_b);
2720+
2721+ // Each work item calculates a partial product of A & B matrix fragments
2722+ // and adds it to the corresponding D matrix fragment
2723+ // d0 += row0{ a0, a1, a2, a3 } * col0{ b0, b1, b2, b3 }
2724+ // d1 += row0{ a0, a1, a2, a3 } * col1{ b0, b1, b2, b3 }
2725+ // d2 += row1{ a0, a1, a2, a3 } * col0{ b0, b1, b2, b3 }
2726+ // d3 += row1{ a0, a1, a2, a3 } * col1{ b0, b1, b2, b3 }
2727+ for (int j = 0 ; j < 4 ; j++) {
2728+ *d0 += ra[j] * rb[j];
2729+ *d1 += ra[j] * rb[j + 4 ];
2730+ *d2 += ra[j + 4 ] * rb[j];
2731+ *d3 += ra[j + 4 ] * rb[j + 4 ];
2732+ }
2733+ }
26742734 } else if constexpr (std::is_integral_v<ABType>) {
26752735 // Init D matrix with fragments of C matrix
26762736 *d[0 ] = c[0 ];
0 commit comments