Skip to content

Commit ab22f91

Browse files
authored
fix fmha fwd kernel name (#2880)
* fix fmha fwd kernel name * if the input and output types are the same, keep the original code
1 parent df97a28 commit ab22f91

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,12 +72,14 @@ struct FmhaFwdKernel
7272
static constexpr std::string_view kPipelineName = FmhaPipeline::name;
7373

7474
// clang-format off
75-
template <typename T> struct t2s;
75+
template <typename T1, typename T2 = T1> struct t2s;
7676
template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
7777
template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
7878
template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
7979
template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
8080
template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
81+
template <> struct t2s<ck_tile::fp8_t, ck_tile::bf16_t> { static constexpr const char * name = "fp8bf16"; };
82+
template <> struct t2s<ck_tile::fp8_t, ck_tile::fp32_t> { static constexpr const char * name = "fp8fp32"; };
8183
// clang-format on
8284

8385
CK_TILE_HOST static std::string GetName()
@@ -99,7 +101,7 @@ struct FmhaFwdKernel
99101
if (kPadHeadDimV) n += "dv";
100102
return n.empty() ? n : std::string("p") + n; }();
101103
return
102-
_SS_("fmha_fwd_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s<QDataType>::name) +
104+
_SS_("fmha_fwd_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s<QDataType, ODataType>::name) +
103105
"_" + (kIsGroupMode ? "group" : "batch") + "_"
104106
"b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" +
105107
_TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kQKHeaddim) + "_" +

0 commit comments

Comments
 (0)