@@ -140,6 +140,8 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
140
140
using Base::AMmaKStride;
141
141
using Base::BMmaKStride;
142
142
143
+ using ComputeDataTypeBuf = typename Base::ComputeDataTypeBuf;
144
+
143
145
static constexpr index_t PrefetchStages = 1 ;
144
146
static constexpr index_t PrefillStages = 1 ;
145
147
static constexpr index_t GlobalBufferNum = 1 ;
@@ -185,9 +187,9 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
185
187
CThreadBuffer& c_thread_buf,
186
188
index_t num_loop) const
187
189
{
188
- auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType >(
190
+ auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf >(
189
191
a_thread_desc_.GetElementSpaceSize ());
190
- auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType >(
192
+ auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf >(
191
193
b_thread_desc_.GetElementSpaceSize ());
192
194
193
195
// Global prefetch 1
@@ -240,20 +242,20 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
240
242
static_for<0 , KRepeat, 1 >{}([&](auto k0) {
241
243
static_for<0 , MRepeat, 1 >{}([&](auto m0) {
242
244
static_for<0 , NRepeat, 1 >{}([&](auto n0) {
243
- vector_type<ComputeDataType , KPack> a_thread_vec;
244
- vector_type<ComputeDataType , KPack> b_thread_vec;
245
+ vector_type<ComputeDataTypeBuf , KPack> a_thread_vec;
246
+ vector_type<ComputeDataTypeBuf , KPack> b_thread_vec;
245
247
246
248
static_for<0 , KPack, 1 >{}([&](auto ik) {
247
- a_thread_vec.template AsType <ComputeDataType >()(ik) =
249
+ a_thread_vec.template AsType <ComputeDataTypeBuf >()(ik) =
248
250
a_thread_buf[Number<a_thread_desc_.CalculateOffset (
249
251
make_tuple (m0, I0, k0, ik))>{}];
250
- b_thread_vec.template AsType <ComputeDataType >()(ik) =
252
+ b_thread_vec.template AsType <ComputeDataTypeBuf >()(ik) =
251
253
b_thread_buf[Number<b_thread_desc_.CalculateOffset (
252
254
make_tuple (n0, I0, k0, ik))>{}];
253
255
});
254
256
255
257
using mfma_input_type =
256
- typename vector_type<ComputeDataType ,
258
+ typename vector_type<ComputeDataTypeBuf ,
257
259
xdlops_gemm.K1PerXdlops >::type;
258
260
259
261
constexpr index_t c_offset =
@@ -301,20 +303,20 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
301
303
static_for<0 , KRepeat, 1 >{}([&](auto k0) {
302
304
static_for<0 , MRepeat, 1 >{}([&](auto m0) {
303
305
static_for<0 , NRepeat, 1 >{}([&](auto n0) {
304
- vector_type<ComputeDataType , KPack> a_thread_vec;
305
- vector_type<ComputeDataType , KPack> b_thread_vec;
306
+ vector_type<ComputeDataTypeBuf , KPack> a_thread_vec;
307
+ vector_type<ComputeDataTypeBuf , KPack> b_thread_vec;
306
308
307
309
static_for<0 , KPack, 1 >{}([&](auto ik) {
308
- a_thread_vec.template AsType <ComputeDataType >()(ik) =
310
+ a_thread_vec.template AsType <ComputeDataTypeBuf >()(ik) =
309
311
a_thread_buf[Number<a_thread_desc_.CalculateOffset (
310
312
make_tuple (m0, I0, k0, ik))>{}];
311
- b_thread_vec.template AsType <ComputeDataType >()(ik) =
313
+ b_thread_vec.template AsType <ComputeDataTypeBuf >()(ik) =
312
314
b_thread_buf[Number<b_thread_desc_.CalculateOffset (
313
315
make_tuple (n0, I0, k0, ik))>{}];
314
316
});
315
317
316
318
using mfma_input_type =
317
- typename vector_type<ComputeDataType , xdlops_gemm.K1PerXdlops >::type;
319
+ typename vector_type<ComputeDataTypeBuf , xdlops_gemm.K1PerXdlops >::type;
318
320
319
321
constexpr index_t c_offset =
320
322
c_thread_desc_.CalculateOffset (make_tuple (m0, n0, 0 ));
@@ -439,6 +441,8 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
439
441
using Base::a_block_desc_m0_m1_m2_k;
440
442
using Base::b_block_desc_n0_n1_n2_k;
441
443
444
+ using ComputeDataTypeBuf = typename Base::ComputeDataTypeBuf;
445
+
442
446
static constexpr index_t NumMacClusters = CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS;
443
447
static constexpr index_t KPerInnerLoop = math::max(KPerThread / NumMacClusters, KPack);
444
448
static constexpr index_t KRepeat = KPerThread / KPerInnerLoop;
@@ -486,9 +490,9 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
486
490
CThreadBuffer& c_thread_buf,
487
491
index_t num_loop) const
488
492
{
489
- auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType >(
493
+ auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf >(
490
494
a_thread_desc_.GetElementSpaceSize ());
491
- auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType >(
495
+ auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf >(
492
496
b_thread_desc_.GetElementSpaceSize ());
493
497
494
498
// Global prefetch 1
@@ -551,20 +555,20 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
551
555
static_for<0 , KPerInnerLoop, KPack>{}([&](auto k_) {
552
556
static_for<0 , MRepeat, 1 >{}([&](auto m0) {
553
557
static_for<0 , NRepeat, 1 >{}([&](auto n0) {
554
- vector_type<ComputeDataType , KPack> a_thread_vec;
555
- vector_type<ComputeDataType , KPack> b_thread_vec;
558
+ vector_type<ComputeDataTypeBuf , KPack> a_thread_vec;
559
+ vector_type<ComputeDataTypeBuf , KPack> b_thread_vec;
556
560
557
561
static_for<0 , KPack, 1 >{}([&](auto ik) {
558
- a_thread_vec.template AsType <ComputeDataType >()(ik) =
562
+ a_thread_vec.template AsType <ComputeDataTypeBuf >()(ik) =
559
563
a_thread_buf[Number<a_thread_desc_.CalculateOffset (
560
564
make_tuple (m0, I0, k0, k_ + ik))>{}];
561
- b_thread_vec.template AsType <ComputeDataType >()(ik) =
565
+ b_thread_vec.template AsType <ComputeDataTypeBuf >()(ik) =
562
566
b_thread_buf[Number<b_thread_desc_.CalculateOffset (
563
567
make_tuple (n0, I0, k0, k_ + ik))>{}];
564
568
});
565
569
566
570
using mfma_input_type =
567
- typename vector_type<ComputeDataType ,
571
+ typename vector_type<ComputeDataTypeBuf ,
568
572
xdlops_gemm.K1PerXdlops >::type;
569
573
570
574
constexpr index_t c_offset =
@@ -640,20 +644,20 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
640
644
static_for<0 , KPerInnerLoop, KPack>{}([&](auto k_) {
641
645
static_for<0 , MRepeat, 1 >{}([&](auto m0) {
642
646
static_for<0 , NRepeat, 1 >{}([&](auto n0) {
643
- vector_type<ComputeDataType , KPack> a_thread_vec;
644
- vector_type<ComputeDataType , KPack> b_thread_vec;
647
+ vector_type<ComputeDataTypeBuf , KPack> a_thread_vec;
648
+ vector_type<ComputeDataTypeBuf , KPack> b_thread_vec;
645
649
646
650
static_for<0 , KPack, 1 >{}([&](auto ik) {
647
- a_thread_vec.template AsType <ComputeDataType >()(ik) =
651
+ a_thread_vec.template AsType <ComputeDataTypeBuf >()(ik) =
648
652
a_thread_buf[Number<a_thread_desc_.CalculateOffset (
649
653
make_tuple (m0, I0, k0, k_ + ik))>{}];
650
- b_thread_vec.template AsType <ComputeDataType >()(ik) =
654
+ b_thread_vec.template AsType <ComputeDataTypeBuf >()(ik) =
651
655
b_thread_buf[Number<b_thread_desc_.CalculateOffset (
652
656
make_tuple (n0, I0, k0, k_ + ik))>{}];
653
657
});
654
658
655
659
using mfma_input_type =
656
- typename vector_type<ComputeDataType ,
660
+ typename vector_type<ComputeDataTypeBuf ,
657
661
xdlops_gemm.K1PerXdlops >::type;
658
662
659
663
constexpr index_t c_offset =
@@ -704,7 +708,7 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
704
708
I1));
705
709
706
710
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<ADataType,
707
- ComputeDataType ,
711
+ ComputeDataTypeBuf ,
708
712
decltype (a_block_desc_m0_m1_m2_k),
709
713
decltype(a_thread_desc_),
710
714
Sequence<1, 1, 1, KPerInnerLoop>,
@@ -714,7 +718,7 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
714
718
A_K1>;
715
719
716
720
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<BDataType,
717
- ComputeDataType ,
721
+ ComputeDataTypeBuf ,
718
722
decltype (b_block_desc_n0_n1_n2_k),
719
723
decltype(b_thread_desc_),
720
724
Sequence<1, 1, 1, KPerInnerLoop>,
0 commit comments