Skip to content
6 changes: 3 additions & 3 deletions example/ck_tile/03_gemm/gemm_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -250,9 +250,9 @@ struct GemmConfigComputeV5 : public GemmConfigBase
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();

static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V5;
static constexpr ck_tile::index_t NumWaNumWaveGroups = 2;
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V5;
static constexpr ck_tile::index_t NumWaveGroups = 2;
};

template <typename PrecType>
Expand Down
1 change: 1 addition & 0 deletions example/ck_tile/03_gemm/universal_gemm_invoker.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ struct UniversalInvoker
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::K_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
Expand Down
50 changes: 27 additions & 23 deletions include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@
namespace ck_tile {

template <typename T>
concept HasDataType = requires { typename T::DataType; };
concept HasDataType = requires
{
typename T::DataType;
};

template <typename T>
struct GetDataType
Expand All @@ -22,7 +25,7 @@ struct GetDataType
};

template <typename T>
requires HasDataType<T>
requires HasDataType<T>
struct GetDataType<T>
{
using type = typename T::DataType; // Use T::ScaleN::DataType
Expand All @@ -40,6 +43,7 @@ template <typename ADataType_,
index_t kN_,
index_t MWave_,
index_t NWave_,
index_t KWave_,
index_t MPerXdl_,
index_t NPerXdl_,
index_t KPerXdl_,
Expand All @@ -51,23 +55,23 @@ template <typename ADataType_,
bool TiledMMAPermuteN_ = false>
struct CShuffleEpilogueProblem
{
using ADataType = remove_cvref_t<ADataType_>;
using BDataType = remove_cvref_t<BDataType_>;
using AccDataType = remove_cvref_t<AccDataType_>;
using ODataType = remove_cvref_t<ODataType_>;
using DsDataType = remove_cvref_t<DsDataType_>;
using DsLayout = remove_cvref_t<DsLayout_>;
using ELayout = remove_cvref_t<ELayout_>;
using CDElementwise = remove_cvref_t<CDElementwise_>;
static constexpr index_t kBlockSize = MWave_ * NWave_ * get_warp_size();
static constexpr index_t kMPerBlock = kM_;
static constexpr index_t kNPerBlock = kN_;
static constexpr index_t MWave = MWave_;
static constexpr index_t NWave = NWave_;
static constexpr index_t MPerXdl = MPerXdl_;
static constexpr index_t NPerXdl = NPerXdl_;
static constexpr index_t KPerXdl = KPerXdl_;
static constexpr index_t isCTransposed = isCTransposed_;
using ADataType = remove_cvref_t<ADataType_>;
using BDataType = remove_cvref_t<BDataType_>;
using AccDataType = remove_cvref_t<AccDataType_>;
using ODataType = remove_cvref_t<ODataType_>;
using DsDataType = remove_cvref_t<DsDataType_>;
using DsLayout = remove_cvref_t<DsLayout_>;
using ELayout = remove_cvref_t<ELayout_>;
using CDElementwise = remove_cvref_t<CDElementwise_>;
static constexpr index_t kBlockSize = MWave_ * NWave_ * KWave_ * get_warp_size();
static constexpr index_t kMPerBlock = kM_;
static constexpr index_t kNPerBlock = kN_;
static constexpr index_t MWave = MWave_;
static constexpr index_t NWave = NWave_;
static constexpr index_t MPerXdl = MPerXdl_;
static constexpr index_t NPerXdl = NPerXdl_;
static constexpr index_t KPerXdl = KPerXdl_;
static constexpr index_t isCTransposed = isCTransposed_;
static constexpr memory_operation_enum MemoryOperation = MemoryOperation_;
static constexpr bool FixedVectorSize = FixedVectorSize_;
static constexpr index_t VectorSizeC = VectorSizeC_;
Expand Down Expand Up @@ -238,8 +242,8 @@ struct CShuffleEpilogue
using CWarpTensor = typename WG::CWarpTensor;
using CWarpDstrEncoding = typename WG::CWarpDstrEncoding;
using SFC = space_filling_curve<sequence<kMPerBlock, kNPerBlock>,
sequence<0, 1>,
sequence<MPerIterationShuffle, NPerIterationShuffle>>;
sequence<0, 1>,
sequence<MPerIterationShuffle, NPerIterationShuffle>>;

template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsBlockDescriptor()
Expand Down Expand Up @@ -343,8 +347,8 @@ struct CShuffleEpilogue

const auto c_ds_tiles = concat_tuple_of_reference(
tie(c_out_tensor, c_out_tensor),
generate_tie([&](auto idx) -> const auto& { return ds_tensor[idx]; },
number<NumDTensor>{}));
generate_tie(
[&](auto idx) -> const auto& { return ds_tensor[idx]; }, number<NumDTensor>{}));

tile_elementwise_inout_unpack(typename Problem::CDElementwise{}, c_ds_tiles);
}
Expand Down
Loading