Skip to content

Commit 34a3656

Browse files
committed
[Cute] Bench multiple seqlens
1 parent 50e0736 commit 34a3656

File tree

2 files changed

+8
-8
lines changed

2 files changed

+8
-8
lines changed

benchmarks/benchmark_attn.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -240,10 +240,10 @@ def run(*args, **kwargs):
240240
headdim = 256
241241
# for headdim in [64, 128, 256]:
242242
# bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 16384)]
243-
# bs_seqlen_vals = [(16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 16384)]
243+
bs_seqlen_vals = [(32, 1024), (16, 2048), (8, 4096), (4, 8192), (2, 16384), (1, 32768)]
244244
# bs_seqlen_vals = [(32, 512), (16, 1024)]
245245
# bs_seqlen_vals = [(2, 64 * 132)]
246-
bs_seqlen_vals = [(4, 8192)]
246+
# bs_seqlen_vals = [(4, 8192)]
247247
# bs_seqlen_vals = [(1, 16 * 1024)]
248248
time_f = {}
249249
time_b = {}
@@ -254,7 +254,7 @@ def run(*args, **kwargs):
254254
# for headdim in [64, 96, 128]:
255255
# for headdim in [64, 128, 256]:
256256
# for headdim in [64, 96, 128, 192, 256]:
257-
for headdim in [64, 128]:
257+
for headdim in [128]:
258258
nheads = dim // headdim
259259
# nheads = 128
260260
# headdim = 64
@@ -312,8 +312,8 @@ def run(*args, **kwargs):
312312
else:
313313
page_table = None
314314

315-
for causal in [False, True]:
316-
# for causal in [False]:
315+
# for causal in [False, True]:
316+
for causal in [True]:
317317
print(f"\n### {headdim = }, {causal = }, {seqlen = } ###")
318318
nFLOPS = flops(batch_size, nheads, seqlen_q, seqlen, headdim if not has_qv else headdim + headdim_v, headdim_v, causal=causal, window_size=window_size)
319319
if cudnn is not None:

flash_attn/cute/softmax.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -175,9 +175,9 @@ def apply_exp2_convert(
175175
acc_S_row: cute.Tensor,
176176
acc_S_row_converted: cute.Tensor,
177177
e2e: cutlass.Constexpr[bool] = False,
178-
e2e_freq: cutlass.Constexpr[bool] = 32,
179-
e2e_res: cutlass.Constexpr[bool] = 4,
180-
e2e_frg_limit: cutlass.Constexpr[bool] = 1,
178+
e2e_freq: cutlass.Constexpr[int] = 16,
179+
e2e_res: cutlass.Constexpr[int] = 4,
180+
e2e_frg_limit: cutlass.Constexpr[int] = 1,
181181
):
182182
assert cute.size(acc_S_row.shape) % 2 == 0, "acc_S_row must have an even number of elements"
183183
frg_tile = 32

0 commit comments

Comments
 (0)