Skip to content

Commit df97a28

Browse files
authored
Conv:TF32: add more instances - 1 (#2867)
* conv:tf32:add more instances * add instances of device_grouped_conv_fwd_xdl_f32_comp_instances * add instances of device_grouped_conv_fwd_xdl_f32_tf32_mem_instances * add instances of device_grouped_conv_fwd_xdl_large_tensor_f32_tf32_instances * remove gnhwc/ngchw/ngcdhw instances
1 parent f076f20 commit df97a28

File tree

92 files changed

+4274
-444
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

92 files changed

+4274
-444
lines changed

include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,9 @@ struct BlockwiseGemmXdlops_pipeline_base
5454
static constexpr auto xdlops_gemm =
5555
XdlopsGemm<ComputeDataType, MPerXDL, NPerXDL, KPack, ComputeDataType, TransposeC>{};
5656

57+
using ComputeDataTypeBuf =
58+
conditional_t<std::is_same<ComputeDataType, ck::tf32_t>::value, float, ComputeDataType>;
59+
5760
static constexpr index_t AMmaKStride = KPack;
5861
static constexpr index_t BMmaKStride = KPack;
5962

@@ -376,7 +379,7 @@ struct BlockwiseGemmXdlops_pipeline_base
376379
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, xdlops_gemm.GetRegSizePerXdlops()));
377380

378381
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<ADataType,
379-
ComputeDataType,
382+
ComputeDataTypeBuf,
380383
decltype(a_block_desc_m0_m1_m2_k),
381384
decltype(a_thread_desc_),
382385
Sequence<1, 1, 1, KPack>,
@@ -386,7 +389,7 @@ struct BlockwiseGemmXdlops_pipeline_base
386389
A_K1>;
387390

388391
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<BDataType,
389-
ComputeDataType,
392+
ComputeDataTypeBuf,
390393
decltype(b_block_desc_n0_n1_n2_k),
391394
decltype(b_thread_desc_),
392395
Sequence<1, 1, 1, KPack>,

include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp

Lines changed: 30 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,8 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
140140
using Base::AMmaKStride;
141141
using Base::BMmaKStride;
142142

143+
using ComputeDataTypeBuf = typename Base::ComputeDataTypeBuf;
144+
143145
static constexpr index_t PrefetchStages = 1;
144146
static constexpr index_t PrefillStages = 1;
145147
static constexpr index_t GlobalBufferNum = 1;
@@ -185,9 +187,9 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
185187
CThreadBuffer& c_thread_buf,
186188
index_t num_loop) const
187189
{
188-
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
190+
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
189191
a_thread_desc_.GetElementSpaceSize());
190-
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
192+
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
191193
b_thread_desc_.GetElementSpaceSize());
192194

193195
// Global prefetch 1
@@ -240,20 +242,20 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
240242
static_for<0, KRepeat, 1>{}([&](auto k0) {
241243
static_for<0, MRepeat, 1>{}([&](auto m0) {
242244
static_for<0, NRepeat, 1>{}([&](auto n0) {
243-
vector_type<ComputeDataType, KPack> a_thread_vec;
244-
vector_type<ComputeDataType, KPack> b_thread_vec;
245+
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
246+
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
245247

246248
static_for<0, KPack, 1>{}([&](auto ik) {
247-
a_thread_vec.template AsType<ComputeDataType>()(ik) =
249+
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
248250
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
249251
make_tuple(m0, I0, k0, ik))>{}];
250-
b_thread_vec.template AsType<ComputeDataType>()(ik) =
252+
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
251253
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
252254
make_tuple(n0, I0, k0, ik))>{}];
253255
});
254256

255257
using mfma_input_type =
256-
typename vector_type<ComputeDataType,
258+
typename vector_type<ComputeDataTypeBuf,
257259
xdlops_gemm.K1PerXdlops>::type;
258260

259261
constexpr index_t c_offset =
@@ -301,20 +303,20 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
301303
static_for<0, KRepeat, 1>{}([&](auto k0) {
302304
static_for<0, MRepeat, 1>{}([&](auto m0) {
303305
static_for<0, NRepeat, 1>{}([&](auto n0) {
304-
vector_type<ComputeDataType, KPack> a_thread_vec;
305-
vector_type<ComputeDataType, KPack> b_thread_vec;
306+
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
307+
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
306308

307309
static_for<0, KPack, 1>{}([&](auto ik) {
308-
a_thread_vec.template AsType<ComputeDataType>()(ik) =
310+
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
309311
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
310312
make_tuple(m0, I0, k0, ik))>{}];
311-
b_thread_vec.template AsType<ComputeDataType>()(ik) =
313+
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
312314
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
313315
make_tuple(n0, I0, k0, ik))>{}];
314316
});
315317

316318
using mfma_input_type =
317-
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
319+
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
318320

319321
constexpr index_t c_offset =
320322
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
@@ -439,6 +441,8 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
439441
using Base::a_block_desc_m0_m1_m2_k;
440442
using Base::b_block_desc_n0_n1_n2_k;
441443

444+
using ComputeDataTypeBuf = typename Base::ComputeDataTypeBuf;
445+
442446
static constexpr index_t NumMacClusters = CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS;
443447
static constexpr index_t KPerInnerLoop = math::max(KPerThread / NumMacClusters, KPack);
444448
static constexpr index_t KRepeat = KPerThread / KPerInnerLoop;
@@ -486,9 +490,9 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
486490
CThreadBuffer& c_thread_buf,
487491
index_t num_loop) const
488492
{
489-
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
493+
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
490494
a_thread_desc_.GetElementSpaceSize());
491-
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
495+
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
492496
b_thread_desc_.GetElementSpaceSize());
493497

494498
// Global prefetch 1
@@ -551,20 +555,20 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
551555
static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
552556
static_for<0, MRepeat, 1>{}([&](auto m0) {
553557
static_for<0, NRepeat, 1>{}([&](auto n0) {
554-
vector_type<ComputeDataType, KPack> a_thread_vec;
555-
vector_type<ComputeDataType, KPack> b_thread_vec;
558+
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
559+
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
556560

557561
static_for<0, KPack, 1>{}([&](auto ik) {
558-
a_thread_vec.template AsType<ComputeDataType>()(ik) =
562+
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
559563
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
560564
make_tuple(m0, I0, k0, k_ + ik))>{}];
561-
b_thread_vec.template AsType<ComputeDataType>()(ik) =
565+
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
562566
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
563567
make_tuple(n0, I0, k0, k_ + ik))>{}];
564568
});
565569

566570
using mfma_input_type =
567-
typename vector_type<ComputeDataType,
571+
typename vector_type<ComputeDataTypeBuf,
568572
xdlops_gemm.K1PerXdlops>::type;
569573

570574
constexpr index_t c_offset =
@@ -640,20 +644,20 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
640644
static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
641645
static_for<0, MRepeat, 1>{}([&](auto m0) {
642646
static_for<0, NRepeat, 1>{}([&](auto n0) {
643-
vector_type<ComputeDataType, KPack> a_thread_vec;
644-
vector_type<ComputeDataType, KPack> b_thread_vec;
647+
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
648+
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
645649

646650
static_for<0, KPack, 1>{}([&](auto ik) {
647-
a_thread_vec.template AsType<ComputeDataType>()(ik) =
651+
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
648652
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
649653
make_tuple(m0, I0, k0, k_ + ik))>{}];
650-
b_thread_vec.template AsType<ComputeDataType>()(ik) =
654+
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
651655
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
652656
make_tuple(n0, I0, k0, k_ + ik))>{}];
653657
});
654658

655659
using mfma_input_type =
656-
typename vector_type<ComputeDataType,
660+
typename vector_type<ComputeDataTypeBuf,
657661
xdlops_gemm.K1PerXdlops>::type;
658662

659663
constexpr index_t c_offset =
@@ -704,7 +708,7 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
704708
I1));
705709

706710
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<ADataType,
707-
ComputeDataType,
711+
ComputeDataTypeBuf,
708712
decltype(a_block_desc_m0_m1_m2_k),
709713
decltype(a_thread_desc_),
710714
Sequence<1, 1, 1, KPerInnerLoop>,
@@ -714,7 +718,7 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
714718
A_K1>;
715719

716720
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<BDataType,
717-
ComputeDataType,
721+
ComputeDataTypeBuf,
718722
decltype(b_block_desc_n0_n1_n2_k),
719723
decltype(b_thread_desc_),
720724
Sequence<1, 1, 1, KPerInnerLoop>,

0 commit comments

Comments
 (0)