File tree Expand file tree Collapse file tree 1 file changed +4
-2
lines changed
include/ck_tile/ops/fmha/kernel Expand file tree Collapse file tree 1 file changed +4
-2
lines changed Original file line number Diff line number Diff line change @@ -72,12 +72,14 @@ struct FmhaFwdKernel
72
72
static constexpr std::string_view kPipelineName = FmhaPipeline::name;
73
73
74
74
// clang-format off
75
- template <typename T > struct t2s ;
75
+ template <typename T1, typename T2 = T1 > struct t2s ;
76
76
template <> struct t2s <float > { static constexpr const char * name = " fp32" ; };
77
77
template <> struct t2s <ck_tile::fp16_t > { static constexpr const char * name = " fp16" ; };
78
78
template <> struct t2s <ck_tile::bf16_t > { static constexpr const char * name = " bf16" ; };
79
79
template <> struct t2s <ck_tile::fp8_t > { static constexpr const char * name = " fp8" ; };
80
80
template <> struct t2s <ck_tile::bf8_t > { static constexpr const char * name = " bf8" ; };
81
+ template <> struct t2s <ck_tile::fp8_t , ck_tile::bf16_t > { static constexpr const char * name = " fp8bf16" ; };
82
+ template <> struct t2s <ck_tile::fp8_t , ck_tile::fp32_t > { static constexpr const char * name = " fp8fp32" ; };
81
83
// clang-format on
82
84
83
85
CK_TILE_HOST static std::string GetName ()
@@ -99,7 +101,7 @@ struct FmhaFwdKernel
99
101
if (kPadHeadDimV ) n += " dv" ;
100
102
return n.empty () ? n : std::string (" p" ) + n; }();
101
103
return
102
- _SS_ (" fmha_fwd_d" ) + _TS_ (bfs::kQKHeaddim ) + " _" + _SS_ (t2s<QDataType>::name) +
104
+ _SS_ (" fmha_fwd_d" ) + _TS_ (bfs::kQKHeaddim ) + " _" + _SS_ (t2s<QDataType, ODataType >::name) +
103
105
" _" + (kIsGroupMode ? " group" : " batch" ) + " _"
104
106
" b" + _TS_ (bfs::kM0 ) + " x" + _TS_ (bfs::kN0 ) + " x" + _TS_ (bfs::kK0 ) + " x" +
105
107
_TS_ (bfs::kN1 ) + " x" + _TS_ (bfs::kK1 ) + " x" + _TS_ (bfs::kQKHeaddim ) + " _" +
You can’t perform that action at this time.
0 commit comments