diff --git a/example/ck_tile/03_gemm/gemm_utils.hpp b/example/ck_tile/03_gemm/gemm_utils.hpp index 588b66ca43e..a021514b205 100644 --- a/example/ck_tile/03_gemm/gemm_utils.hpp +++ b/example/ck_tile/03_gemm/gemm_utils.hpp @@ -250,9 +250,9 @@ struct GemmConfigComputeV5 : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V5; - static constexpr ck_tile::index_t NumWaNumWaveGroups = 2; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V5; + static constexpr ck_tile::index_t NumWaveGroups = 2; }; template diff --git a/example/ck_tile/03_gemm/universal_gemm_invoker.hpp b/example/ck_tile/03_gemm/universal_gemm_invoker.hpp index 19855c7f72f..501238eadbe 100644 --- a/example/ck_tile/03_gemm/universal_gemm_invoker.hpp +++ b/example/ck_tile/03_gemm/universal_gemm_invoker.hpp @@ -89,25 +89,29 @@ struct UniversalInvoker using GemmPipeline = typename PipelineTypeTraits< GemmConfig::Pipeline>::template GemmPipeline; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; + using GemmEpilogue = ck_tile::CShuffleEpilogue 1) ? GemmConfig::K_Warp : 1>>; using Kernel = ck_tile::GemmKernel; auto kargs = Kernel::MakeKernelArgs(args); diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index e0a39a5aea2..d3cfd4c7f36 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -32,18 +32,20 @@ template + bool TiledMMAPermuteN_ = false, + index_t KWave_ = 1> struct CShuffleEpilogueProblem { - using AsDataType = remove_cvref_t; - using BsDataType = remove_cvref_t; - using AccDataType = remove_cvref_t; - using ODataType = remove_cvref_t; - using DsDataType = remove_cvref_t; - using DsLayout = remove_cvref_t; - using ELayout = remove_cvref_t; - using CDElementwise = remove_cvref_t; - static constexpr index_t kBlockSize = MWave_ * NWave_ * get_warp_size(); + using AsDataType = remove_cvref_t; + using BsDataType = remove_cvref_t; + using AccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using DsDataType = remove_cvref_t; + using DsLayout = remove_cvref_t; + using ELayout = remove_cvref_t; + using CDElementwise = remove_cvref_t; + static constexpr index_t kBlockSize = MWave_ * NWave_ * KWave_ * get_warp_size(); + static constexpr index_t kMPerBlock = kM_; static constexpr index_t kNPerBlock = kN_; static constexpr index_t MWave = MWave_;