Commit 5a83425

Change constexpr int to constexpr static int
1 parent 3a9fe7b commit 5a83425
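
For context, the sketch below (illustrative only, not code from this repository) shows the two declaration forms the commit swaps. Both declare a compile-time constant usable as a template argument inside the dispatch lambdas (the BOOL_SWITCH blocks in the diffs below); the only change the commit makes is adding static, which moves the constant from automatic to static storage duration.

// Illustrative sketch, not code from this repository.
#include <cstdio>

template <int HeadDim>
void launch() { std::printf("HeadDim = %d\n", HeadDim); }

void run_before() {
    constexpr int Headdim = 128;         // automatic storage duration, one per call
    auto dispatch = [&] { launch<Headdim>(); };
    dispatch();
}

void run_after() {
    constexpr static int Headdim = 128;  // static storage duration, one per program;
                                         // visible inside the lambda without capture
    auto dispatch = [] { launch<Headdim>(); };
    dispatch();
}

int main() {
    run_before();
    run_after();
}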

3 files changed: 23 additions & 23 deletions

README.md

Lines changed: 4 additions & 4 deletions
@@ -198,7 +198,7 @@ includes QKV projection, output projection), see the MHA [implementation](https:
 
 ## Changelog
 
-### 2.0
+### 2.0: Complete rewrite, 2x faster
 Upgrading from FlashAttention (1.x) to FlashAttention-2
 
 These functions have been renamed:
@@ -214,7 +214,7 @@ flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False)
 ```python
 flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False)
 ```
-### 2.1
+### 2.1: Change behavior of causal flag
 
 If seqlen_q != seqlen_k and causal=True, the causal mask is aligned to the
 bottom right corner of the attention matrix, instead of the top-left corner.
@@ -243,7 +243,7 @@ v2.1:
 1 1
 If the row of the mask is all zero, the output will be zero.
 
-### 2.2
+### 2.2: Optimize for inference
 
 Optimize for inference (iterative decoding) when query has very small sequence
 length (e.g., query sequence length = 1). The bottleneck here is to load KV
@@ -256,7 +256,7 @@ See the function `flash_attn_with_kvcache` with more features for inference
 Thanks to the xformers team, and in particular Daniel Haziza, for this
 collaboration.
 
-### 2.3
+### 2.3: Local (i.e., sliding window) attention
 
 Implement sliding window attention (i.e., local attention). Thanks to [Mistral
 AI](https://mistral.ai/) and in particular Timothée Lacroix for this

csrc/flash_attn/src/flash_bwd_launch_template.h

Lines changed: 8 additions & 8 deletions
@@ -137,7 +137,7 @@ void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream, const bool con
 
 template<typename T>
 void run_mha_bwd_hdim32(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    constexpr int Headdim = 32;
+    constexpr static int Headdim = 32;
     int device;
     cudaGetDevice(&device);
     int max_smem_per_block;
@@ -158,7 +158,7 @@ void run_mha_bwd_hdim32(Flash_bwd_params &params, cudaStream_t stream, const boo
 
 template<typename T>
 void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    constexpr int Headdim = 64;
+    constexpr static int Headdim = 64;
     int device;
     cudaGetDevice(&device);
     int max_smem_per_block;
@@ -201,7 +201,7 @@ void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream, const boo
 
 template<typename T>
 void run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    constexpr int Headdim = 96;
+    constexpr static int Headdim = 96;
     int device;
     cudaGetDevice(&device);
     int max_smem_per_block;
@@ -228,7 +228,7 @@ void run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream, const boo
 
 template<typename T>
 void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    constexpr int Headdim = 128;
+    constexpr static int Headdim = 128;
     int device;
     cudaGetDevice(&device);
     int max_smem_per_block;
@@ -264,7 +264,7 @@ void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream, const bo
 
 template<typename T>
 void run_mha_bwd_hdim160(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    constexpr int Headdim = 160;
+    constexpr static int Headdim = 160;
     int device;
     cudaGetDevice(&device);
     int max_smem_per_block;
@@ -281,7 +281,7 @@ void run_mha_bwd_hdim160(Flash_bwd_params &params, cudaStream_t stream, const bo
 
 template<typename T>
 void run_mha_bwd_hdim192(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    constexpr int Headdim = 192;
+    constexpr static int Headdim = 192;
     int device;
     cudaGetDevice(&device);
     int max_smem_per_block;
@@ -298,15 +298,15 @@ void run_mha_bwd_hdim192(Flash_bwd_params &params, cudaStream_t stream, const bo
 
 template<typename T>
 void run_mha_bwd_hdim224(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    constexpr int Headdim = 224;
+    constexpr static int Headdim = 224;
     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream, configure);
     });
 }
 
 template<typename T>
 void run_mha_bwd_hdim256(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
-    constexpr int Headdim = 256;
+    constexpr static int Headdim = 256;
     int device;
     cudaGetDevice(&device);
     int max_smem_per_block;
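
A note on the BOOL_SWITCH calls that appear throughout these launchers: they turn a runtime flag (for example params.p_dropout < 1.f) into a compile-time template parameter by compiling the lambda body under both values and running the matching branch. The macro's actual definition is not part of this diff; the sketch below shows a typical implementation of the idiom and is an assumption, not the repository's code.

// Sketch of the boolean-dispatch idiom behind BOOL_SWITCH; assumed, not taken from this repo.
#include <cstdio>

#define BOOL_SWITCH(COND, CONST_NAME, ...)            \
    [&] {                                             \
        if (COND) {                                   \
            constexpr static bool CONST_NAME = true;  \
            return __VA_ARGS__();                     \
        } else {                                      \
            constexpr static bool CONST_NAME = false; \
            return __VA_ARGS__();                     \
        }                                             \
    }()

template <bool Is_dropout>
void run_kernel() { std::printf("Is_dropout = %d\n", Is_dropout); }

int main() {
    float p_dropout = 0.9f;
    // The runtime condition selects which compile-time instantiation runs.
    BOOL_SWITCH(p_dropout < 1.f, Is_dropout, [&] { run_kernel<Is_dropout>(); });
    return 0;
}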

csrc/flash_attn/src/flash_fwd_launch_template.h

Lines changed: 11 additions & 11 deletions
@@ -104,7 +104,7 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     // We want kBlockM to be as small as possible for more parallelism.
     // With 128 threads we can load 512 elements at a time, so if headdim is divisible by 128, kBlockM = 4.
     // If headdim is divisible by 64, then we set kBlockM = 8, etc.
-    constexpr int kBlockM = Kernel_traits::kHeadDim % 128 == 0 ? 4 : (Kernel_traits::kHeadDim % 64 == 0 ? 8 : 16);
+    constexpr static int kBlockM = Kernel_traits::kHeadDim % 128 == 0 ? 4 : (Kernel_traits::kHeadDim % 64 == 0 ? 8 : 16);
     dim3 grid_combine((params.b * params.h * params.seqlen_q + kBlockM - 1) / kBlockM);
     BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
         if (params.num_splits <= 2) {
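
One way to read the comment in the hunk above: kBlockM is the smallest value in {4, 8, 16} for which kBlockM * headdim is a multiple of 512 elements (128 threads times 4 elements per load), assuming the head dimension is a multiple of 32. The standalone sketch below tabulates that rule; it merely restates the ternary for illustration and is not code from the repository.

// Illustrative restatement of the kBlockM selection rule; not repository code.
#include <cstdio>

constexpr int pick_kBlockM(int head_dim) {
    // Same ternary as the diff above: the smallest row-block such that
    // kBlockM * head_dim is a whole number of 512-element chunks.
    return head_dim % 128 == 0 ? 4 : (head_dim % 64 == 0 ? 8 : 16);
}

int main() {
    for (int hd : {32, 64, 96, 128, 160, 192, 224, 256}) {
        int m = pick_kBlockM(hd);
        std::printf("headdim %3d -> kBlockM %2d (row tile = %4d elements = %d x 512)\n",
                    hd, m, m * hd, (m * hd) / 512);
    }
    return 0;
}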
@@ -129,17 +129,17 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
 
 template<typename T, int Headdim>
 void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream) {
-    constexpr int kBlockM = 64;  // Fixed for all head dimensions
+    constexpr static int kBlockM = 64;  // Fixed for all head dimensions
     // TD [2023-08-28]: nvcc segfaults for headdim 96 with block size 64 x 256,
     // and for headdim 192 with block size 64 x 128.
     // Also for headdim 160 with block size 64 x 128 after the rotary addition.
-    constexpr int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64);
+    constexpr static int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64);
     run_flash_splitkv_fwd<Flash_fwd_kernel_traits<Headdim, kBlockM, kBlockN, 4, false, false, T>>(params, stream);
 }
 
 template<typename T>
 void run_mha_fwd_hdim32(Flash_fwd_params &params, cudaStream_t stream) {
-    constexpr int Headdim = 32;
+    constexpr static int Headdim = 32;
     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         BOOL_SWITCH(params.is_causal, Is_causal, [&] {
             run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
@@ -149,7 +149,7 @@ void run_mha_fwd_hdim32(Flash_fwd_params &params, cudaStream_t stream) {
 
 template<typename T>
 void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
-    constexpr int Headdim = 64;
+    constexpr static int Headdim = 64;
     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         BOOL_SWITCH(params.is_causal, Is_causal, [&] {
             if constexpr(!Is_dropout) {
@@ -171,7 +171,7 @@ void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
 
 template<typename T>
 void run_mha_fwd_hdim96(Flash_fwd_params &params, cudaStream_t stream) {
-    constexpr int Headdim = 96;
+    constexpr static int Headdim = 96;
     auto dprops = at::cuda::getCurrentDeviceProperties();
     bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
@@ -197,7 +197,7 @@ void run_mha_fwd_hdim96(Flash_fwd_params &params, cudaStream_t stream) {
 
 template<typename T>
 void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
-    constexpr int Headdim = 128;
+    constexpr static int Headdim = 128;
     auto dprops = at::cuda::getCurrentDeviceProperties();
     bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
@@ -234,7 +234,7 @@ void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
 
 template<typename T>
 void run_mha_fwd_hdim160(Flash_fwd_params &params, cudaStream_t stream) {
-    constexpr int Headdim = 160;
+    constexpr static int Headdim = 160;
     auto dprops = at::cuda::getCurrentDeviceProperties();
     bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
@@ -264,7 +264,7 @@ void run_mha_fwd_hdim160(Flash_fwd_params &params, cudaStream_t stream) {
 
 template<typename T>
 void run_mha_fwd_hdim192(Flash_fwd_params &params, cudaStream_t stream) {
-    constexpr int Headdim = 192;
+    constexpr static int Headdim = 192;
     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         BOOL_SWITCH(params.is_causal, Is_causal, [&] {
             if constexpr(!Is_dropout) {
@@ -283,7 +283,7 @@ void run_mha_fwd_hdim192(Flash_fwd_params &params, cudaStream_t stream) {
 
 template<typename T>
 void run_mha_fwd_hdim224(Flash_fwd_params &params, cudaStream_t stream) {
-    constexpr int Headdim = 224;
+    constexpr static int Headdim = 224;
     int device;
     cudaGetDevice(&device);
     int max_smem_per_block;
@@ -309,7 +309,7 @@ void run_mha_fwd_hdim224(Flash_fwd_params &params, cudaStream_t stream) {
 
 template<typename T>
 void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
-    constexpr int Headdim = 256;
+    constexpr static int Headdim = 256;
     int device;
     cudaGetDevice(&device);
     int max_smem_per_sm, max_smem_per_block;
