From e1c87eac081f40bdbe4ff109e873d1885f78645a Mon Sep 17 00:00:00 2001 From: Sudhir Kylasa Date: Tue, 16 Sep 2025 18:31:35 +0000 Subject: [PATCH 1/6] Fix broken ping pong pipeline functionality in the develop branch. --- example/ck_tile/03_gemm/gemm_utils.hpp | 2 +- example/ck_tile/03_gemm/universal_gemm_invoker.hpp | 1 + include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp | 3 ++- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/example/ck_tile/03_gemm/gemm_utils.hpp b/example/ck_tile/03_gemm/gemm_utils.hpp index 588b66ca43e..e9b2798172b 100644 --- a/example/ck_tile/03_gemm/gemm_utils.hpp +++ b/example/ck_tile/03_gemm/gemm_utils.hpp @@ -252,7 +252,7 @@ struct GemmConfigComputeV5 : public GemmConfigBase static constexpr bool DoubleSmemBuffer = false; static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V5; - static constexpr ck_tile::index_t NumWaNumWaveGroups = 2; + static constexpr ck_tile::index_t NumWaveGroups = 2; }; template diff --git a/example/ck_tile/03_gemm/universal_gemm_invoker.hpp b/example/ck_tile/03_gemm/universal_gemm_invoker.hpp index 19855c7f72f..d0762e4970c 100644 --- a/example/ck_tile/03_gemm/universal_gemm_invoker.hpp +++ b/example/ck_tile/03_gemm/universal_gemm_invoker.hpp @@ -102,6 +102,7 @@ struct UniversalInvoker TilePartitioner::NPerBlock, GemmConfig::M_Warp, GemmConfig::N_Warp, + GemmConfig::K_Warp, GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile, diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index 628af0e0b32..b4036cae045 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -40,6 +40,7 @@ template ; using ELayout = remove_cvref_t; using CDElementwise = remove_cvref_t; - static constexpr index_t kBlockSize = MWave_ * NWave_ * get_warp_size(); + static constexpr index_t kBlockSize = MWave_ * NWave_ * KWave_ * get_warp_size(); static constexpr 
index_t kMPerBlock = kM_; static constexpr index_t kNPerBlock = kN_; static constexpr index_t MWave = MWave_; From f5acdb4671c40bdd4ba41695a890e78c3def5e5d Mon Sep 17 00:00:00 2001 From: Sudhir Kylasa Date: Tue, 16 Sep 2025 18:37:52 +0000 Subject: [PATCH 2/6] Document Format tool outputs commit --- example/ck_tile/03_gemm/gemm_utils.hpp | 4 +- .../ops/epilogue/cshuffle_epilogue.hpp | 49 ++++++++++--------- 2 files changed, 28 insertions(+), 25 deletions(-) diff --git a/example/ck_tile/03_gemm/gemm_utils.hpp b/example/ck_tile/03_gemm/gemm_utils.hpp index e9b2798172b..a021514b205 100644 --- a/example/ck_tile/03_gemm/gemm_utils.hpp +++ b/example/ck_tile/03_gemm/gemm_utils.hpp @@ -250,8 +250,8 @@ struct GemmConfigComputeV5 : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V5; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V5; static constexpr ck_tile::index_t NumWaveGroups = 2; }; diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index b4036cae045..2fe89f73b25 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -13,7 +13,10 @@ namespace ck_tile { template -concept HasDataType = requires { typename T::DataType; }; +concept HasDataType = requires +{ + typename T::DataType; +}; template struct GetDataType @@ -22,7 +25,7 @@ struct GetDataType }; template - requires HasDataType +requires HasDataType struct GetDataType { using type = typename T::DataType; // Use T::ScaleN::DataType @@ -52,23 +55,23 @@ template struct CShuffleEpilogueProblem { - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using AccDataType = remove_cvref_t; - 
using ODataType = remove_cvref_t; - using DsDataType = remove_cvref_t; - using DsLayout = remove_cvref_t; - using ELayout = remove_cvref_t; - using CDElementwise = remove_cvref_t; - static constexpr index_t kBlockSize = MWave_ * NWave_ * KWave_ * get_warp_size(); - static constexpr index_t kMPerBlock = kM_; - static constexpr index_t kNPerBlock = kN_; - static constexpr index_t MWave = MWave_; - static constexpr index_t NWave = NWave_; - static constexpr index_t MPerXdl = MPerXdl_; - static constexpr index_t NPerXdl = NPerXdl_; - static constexpr index_t KPerXdl = KPerXdl_; - static constexpr index_t isCTransposed = isCTransposed_; + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using AccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using DsDataType = remove_cvref_t; + using DsLayout = remove_cvref_t; + using ELayout = remove_cvref_t; + using CDElementwise = remove_cvref_t; + static constexpr index_t kBlockSize = MWave_ * NWave_ * KWave_ * get_warp_size(); + static constexpr index_t kMPerBlock = kM_; + static constexpr index_t kNPerBlock = kN_; + static constexpr index_t MWave = MWave_; + static constexpr index_t NWave = NWave_; + static constexpr index_t MPerXdl = MPerXdl_; + static constexpr index_t NPerXdl = NPerXdl_; + static constexpr index_t KPerXdl = KPerXdl_; + static constexpr index_t isCTransposed = isCTransposed_; static constexpr memory_operation_enum MemoryOperation = MemoryOperation_; static constexpr bool FixedVectorSize = FixedVectorSize_; static constexpr index_t VectorSizeC = VectorSizeC_; @@ -239,8 +242,8 @@ struct CShuffleEpilogue using CWarpTensor = typename WG::CWarpTensor; using CWarpDstrEncoding = typename WG::CWarpDstrEncoding; using SFC = space_filling_curve, - sequence<0, 1>, - sequence>; + sequence<0, 1>, + sequence>; template CK_TILE_HOST_DEVICE static constexpr auto MakeLdsBlockDescriptor() @@ -344,8 +347,8 @@ struct CShuffleEpilogue const auto c_ds_tiles = concat_tuple_of_reference( 
tie(c_out_tensor, c_out_tensor), - generate_tie([&](auto idx) -> const auto& { return ds_tensor[idx]; }, - number{})); + generate_tie( + [&](auto idx) -> const auto& { return ds_tensor[idx]; }, number{})); tile_elementwise_inout_unpack(typename Problem::CDElementwise{}, c_ds_tiles); } From c2ec85411be3183e10504c7c57ad25ed4b944be1 Mon Sep 17 00:00:00 2001 From: Sudhir Kylasa Date: Fri, 19 Sep 2025 22:57:01 +0000 Subject: [PATCH 3/6] Include KWave when computing the block size only when ping pong scheduler is used. In other cases, default back to the original value. --- .../ops/epilogue/cshuffle_epilogue.hpp | 35 ++++++++++--------- 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index 2fe89f73b25..47de9af3b5c 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -55,23 +55,24 @@ template struct CShuffleEpilogueProblem { - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using AccDataType = remove_cvref_t; - using ODataType = remove_cvref_t; - using DsDataType = remove_cvref_t; - using DsLayout = remove_cvref_t; - using ELayout = remove_cvref_t; - using CDElementwise = remove_cvref_t; - static constexpr index_t kBlockSize = MWave_ * NWave_ * KWave_ * get_warp_size(); - static constexpr index_t kMPerBlock = kM_; - static constexpr index_t kNPerBlock = kN_; - static constexpr index_t MWave = MWave_; - static constexpr index_t NWave = NWave_; - static constexpr index_t MPerXdl = MPerXdl_; - static constexpr index_t NPerXdl = NPerXdl_; - static constexpr index_t KPerXdl = KPerXdl_; - static constexpr index_t isCTransposed = isCTransposed_; + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using AccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using DsDataType = remove_cvref_t; + using DsLayout = remove_cvref_t; + using 
ELayout = remove_cvref_t; + using CDElementwise = remove_cvref_t; + static constexpr index_t kBlockSize = + MWave_ * NWave_ * (kNumWaveGroups_ > 1 ? KWave_ : 1) * get_warp_size(); + static constexpr index_t kMPerBlock = kM_; + static constexpr index_t kNPerBlock = kN_; + static constexpr index_t MWave = MWave_; + static constexpr index_t NWave = NWave_; + static constexpr index_t MPerXdl = MPerXdl_; + static constexpr index_t NPerXdl = NPerXdl_; + static constexpr index_t KPerXdl = KPerXdl_; + static constexpr index_t isCTransposed = isCTransposed_; static constexpr memory_operation_enum MemoryOperation = MemoryOperation_; static constexpr bool FixedVectorSize = FixedVectorSize_; static constexpr index_t VectorSizeC = VectorSizeC_; From 396f70155881742dda2bfc60182f79939cb905e5 Mon Sep 17 00:00:00 2001 From: Sudhir Kylasa Date: Thu, 25 Sep 2025 22:22:29 +0000 Subject: [PATCH 4/6] Fix compilation issues with other instances of CShuffle usage. --- .../03_gemm/universal_gemm_invoker.hpp | 7 ++++-- .../ops/epilogue/cshuffle_epilogue.hpp | 23 +++++++++---------- 2 files changed, 16 insertions(+), 14 deletions(-) diff --git a/example/ck_tile/03_gemm/universal_gemm_invoker.hpp b/example/ck_tile/03_gemm/universal_gemm_invoker.hpp index d0762e4970c..4d6c97ed370 100644 --- a/example/ck_tile/03_gemm/universal_gemm_invoker.hpp +++ b/example/ck_tile/03_gemm/universal_gemm_invoker.hpp @@ -102,13 +102,16 @@ struct UniversalInvoker TilePartitioner::NPerBlock, GemmConfig::M_Warp, GemmConfig::N_Warp, - GemmConfig::K_Warp, GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile, UniversalGemmProblem::TransposeC, memory_operation, - GemmConfig::NumWaveGroups>>; + GemmConfig::NumWaveGroups, + false, + 1, + false, + GemmConfig::K_Warp>>; using Kernel = ck_tile::GemmKernel; auto kargs = Kernel::MakeKernelArgs(args); diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index b8a65f26bee..9cba02464fc 100644 
--- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -43,7 +43,6 @@ template + bool TiledMMAPermuteN_ = false, + index_t KWave_ = 1> struct CShuffleEpilogueProblem { - using AsDataType = remove_cvref_t; - using BsDataType = remove_cvref_t; - using AccDataType = remove_cvref_t; - using ODataType = remove_cvref_t; - using DsDataType = remove_cvref_t; - using DsLayout = remove_cvref_t; - using ELayout = remove_cvref_t; - using CDElementwise = remove_cvref_t; - static constexpr index_t kBlockSize = - MWave_ * NWave_ * (kNumWaveGroups_ > 1 ? KWave_ : 1) * get_warp_size(); + using AsDataType = remove_cvref_t; + using BsDataType = remove_cvref_t; + using AccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using DsDataType = remove_cvref_t; + using DsLayout = remove_cvref_t; + using ELayout = remove_cvref_t; + using CDElementwise = remove_cvref_t; + static constexpr index_t kBlockSize = MWave_ * NWave_ * KWave_ * get_warp_size(); static constexpr index_t kMPerBlock = kM_; static constexpr index_t kNPerBlock = kN_; From 97a9b657df8f5650eaf6aa705a0e882df8bc9dc7 Mon Sep 17 00:00:00 2001 From: Sudhir Kylasa Date: Fri, 26 Sep 2025 07:37:57 +0000 Subject: [PATCH 5/6] Moved the NumWaveGroups condition to the user level files from the include directory. 
--- .../03_gemm/universal_gemm_invoker.hpp | 46 +++++++++---------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/example/ck_tile/03_gemm/universal_gemm_invoker.hpp b/example/ck_tile/03_gemm/universal_gemm_invoker.hpp index 4d6c97ed370..501238eadbe 100644 --- a/example/ck_tile/03_gemm/universal_gemm_invoker.hpp +++ b/example/ck_tile/03_gemm/universal_gemm_invoker.hpp @@ -89,29 +89,29 @@ struct UniversalInvoker using GemmPipeline = typename PipelineTypeTraits< GemmConfig::Pipeline>::template GemmPipeline; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; + using GemmEpilogue = ck_tile::CShuffleEpilogue 1) ? GemmConfig::K_Warp : 1>>; using Kernel = ck_tile::GemmKernel; auto kargs = Kernel::MakeKernelArgs(args); From bd52e1a07bd3b39f6dcfe7760dcaf458f8717340 Mon Sep 17 00:00:00 2001 From: Sudhir Kylasa Date: Fri, 26 Sep 2025 18:56:15 +0000 Subject: [PATCH 6/6] output of ./scripts/clang-format-overwrite.sh on the files changed in this PR --- include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index aacf1602ff5..d3cfd4c7f36 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -239,8 +239,8 @@ struct CShuffleEpilogue using CWarpTensor = typename WG::CWarpTensor; using CWarpDstrEncoding = typename WG::CWarpDstrEncoding; using SFC = space_filling_curve, - sequence<0, 1>, - sequence>; + sequence<0, 1>, + sequence>; template CK_TILE_HOST_DEVICE static constexpr auto MakeLdsBlockDescriptor() @@ -363,8 +363,8 @@ struct CShuffleEpilogue const auto c_ds_tiles = concat_tuple_of_reference( tie(c_out_tensor, c_out_tensor), - generate_tie( - [&](auto idx) -> const auto& { return ds_tensor[idx]; }, number{})); + generate_tie([&](auto idx) -> const auto& { return ds_tensor[idx]; }, 
number{})); tile_elementwise_inout_unpack(typename Problem::CDElementwise{}, c_ds_tiles); }