@@ -104,7 +104,7 @@ void run_flash_splitkv_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) {
104104 // We want kBlockM to be as small as possible for more parallelism.
105105 // With 128 threads we can load 512 elements at a time, so if headdim is divisible by 128, kBlockM = 4.
106106 // If headdim is divisible by 64, then we set kBlockM = 8, etc.
107- constexpr int kBlockM = Kernel_traits::kHeadDim % 128 == 0 ? 4 : (Kernel_traits::kHeadDim % 64 == 0 ? 8 : 16 );
107+ constexpr static int kBlockM = Kernel_traits::kHeadDim % 128 == 0 ? 4 : (Kernel_traits::kHeadDim % 64 == 0 ? 8 : 16 );
108108 dim3 grid_combine ((params.b * params.h * params.seqlen_q + kBlockM - 1 ) / kBlockM );
109109 BOOL_SWITCH (is_even_K, IsEvenKConst, [&] {
110110 if (params.num_splits <= 2 ) {
@@ -129,17 +129,17 @@ void run_flash_splitkv_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) {
129129
130130template <typename T, int Headdim>
131131void run_mha_fwd_splitkv_dispatch (Flash_fwd_params ¶ms, cudaStream_t stream) {
132- constexpr int kBlockM = 64 ; // Fixed for all head dimensions
132+ constexpr static int kBlockM = 64 ; // Fixed for all head dimensions
133133 // TD [2023-08-28]: nvcc segfaults for headdim 96 with block size 64 x 256,
134134 // and for headdim 192 with block size 64 x 128.
135135 // Also for headdim 160 with block size 64 x 128 after the rotary addition.
136- constexpr int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64 );
136+ constexpr static int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64 );
137137 run_flash_splitkv_fwd<Flash_fwd_kernel_traits<Headdim, kBlockM , kBlockN , 4 , false , false , T>>(params, stream);
138138}
139139
140140template <typename T>
141141void run_mha_fwd_hdim32 (Flash_fwd_params ¶ms, cudaStream_t stream) {
142- constexpr int Headdim = 32 ;
142+ constexpr static int Headdim = 32 ;
143143 BOOL_SWITCH (params.p_dropout < 1 .f , Is_dropout, [&] {
144144 BOOL_SWITCH (params.is_causal , Is_causal, [&] {
145145 run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128 , 128 , 4 , false , false , T>, Is_dropout, Is_causal>(params, stream);
@@ -149,7 +149,7 @@ void run_mha_fwd_hdim32(Flash_fwd_params ¶ms, cudaStream_t stream) {
149149
150150template <typename T>
151151void run_mha_fwd_hdim64 (Flash_fwd_params ¶ms, cudaStream_t stream) {
152- constexpr int Headdim = 64 ;
152+ constexpr static int Headdim = 64 ;
153153 BOOL_SWITCH (params.p_dropout < 1 .f , Is_dropout, [&] {
154154 BOOL_SWITCH (params.is_causal , Is_causal, [&] {
155155 if constexpr (!Is_dropout) {
@@ -171,7 +171,7 @@ void run_mha_fwd_hdim64(Flash_fwd_params ¶ms, cudaStream_t stream) {
171171
172172template <typename T>
173173void run_mha_fwd_hdim96 (Flash_fwd_params ¶ms, cudaStream_t stream) {
174- constexpr int Headdim = 96 ;
174+ constexpr static int Headdim = 96 ;
175175 auto dprops = at::cuda::getCurrentDeviceProperties ();
176176 bool is_sm8x = dprops->major == 8 && dprops->minor > 0 ;
177177 BOOL_SWITCH (params.p_dropout < 1 .f , Is_dropout, [&] {
@@ -197,7 +197,7 @@ void run_mha_fwd_hdim96(Flash_fwd_params ¶ms, cudaStream_t stream) {
197197
198198template <typename T>
199199void run_mha_fwd_hdim128 (Flash_fwd_params ¶ms, cudaStream_t stream) {
200- constexpr int Headdim = 128 ;
200+ constexpr static int Headdim = 128 ;
201201 auto dprops = at::cuda::getCurrentDeviceProperties ();
202202 bool is_sm8x = dprops->major == 8 && dprops->minor > 0 ;
203203 BOOL_SWITCH (params.p_dropout < 1 .f , Is_dropout, [&] {
@@ -234,7 +234,7 @@ void run_mha_fwd_hdim128(Flash_fwd_params ¶ms, cudaStream_t stream) {
234234
235235template <typename T>
236236void run_mha_fwd_hdim160 (Flash_fwd_params ¶ms, cudaStream_t stream) {
237- constexpr int Headdim = 160 ;
237+ constexpr static int Headdim = 160 ;
238238 auto dprops = at::cuda::getCurrentDeviceProperties ();
239239 bool is_sm8x = dprops->major == 8 && dprops->minor > 0 ;
240240 BOOL_SWITCH (params.p_dropout < 1 .f , Is_dropout, [&] {
@@ -264,7 +264,7 @@ void run_mha_fwd_hdim160(Flash_fwd_params ¶ms, cudaStream_t stream) {
264264
265265template <typename T>
266266void run_mha_fwd_hdim192 (Flash_fwd_params ¶ms, cudaStream_t stream) {
267- constexpr int Headdim = 192 ;
267+ constexpr static int Headdim = 192 ;
268268 BOOL_SWITCH (params.p_dropout < 1 .f , Is_dropout, [&] {
269269 BOOL_SWITCH (params.is_causal , Is_causal, [&] {
270270 if constexpr (!Is_dropout) {
@@ -283,7 +283,7 @@ void run_mha_fwd_hdim192(Flash_fwd_params ¶ms, cudaStream_t stream) {
283283
284284template <typename T>
285285void run_mha_fwd_hdim224 (Flash_fwd_params ¶ms, cudaStream_t stream) {
286- constexpr int Headdim = 224 ;
286+ constexpr static int Headdim = 224 ;
287287 int device;
288288 cudaGetDevice (&device);
289289 int max_smem_per_block;
@@ -309,7 +309,7 @@ void run_mha_fwd_hdim224(Flash_fwd_params ¶ms, cudaStream_t stream) {
309309
310310template <typename T>
311311void run_mha_fwd_hdim256 (Flash_fwd_params ¶ms, cudaStream_t stream) {
312- constexpr int Headdim = 256 ;
312+ constexpr static int Headdim = 256 ;
313313 int device;
314314 cudaGetDevice (&device);
315315 int max_smem_per_sm, max_smem_per_block;
0 commit comments