@@ -240,10 +240,10 @@ def run(*args, **kwargs):
240240headdim = 256
241241# for headdim in [64, 128, 256]:
242242# bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 16384)]
243- # bs_seqlen_vals = [(16 , 1024), (8 , 2048), (4 , 4096), (2 , 8192), (1 , 16384)]
243+ bs_seqlen_vals = [(32 , 1024 ), (16 , 2048 ), (8 , 4096 ), (4 , 8192 ), (2 , 16384 ), ( 1 , 32768 )]
244244# bs_seqlen_vals = [(32, 512), (16, 1024)]
245245# bs_seqlen_vals = [(2, 64 * 132)]
246- bs_seqlen_vals = [(4 , 8192 )]
246+ # bs_seqlen_vals = [(4, 8192)]
247247# bs_seqlen_vals = [(1, 16 * 1024)]
248248time_f = {}
249249time_b = {}
@@ -254,7 +254,7 @@ def run(*args, **kwargs):
254254# for headdim in [64, 96, 128]:
255255# for headdim in [64, 128, 256]:
256256# for headdim in [64, 96, 128, 192, 256]:
257- for headdim in [64 , 128 ]:
257+ for headdim in [128 ]:
258258 nheads = dim // headdim
259259 # nheads = 128
260260 # headdim = 64
@@ -312,8 +312,8 @@ def run(*args, **kwargs):
312312 else :
313313 page_table = None
314314
315- for causal in [False , True ]:
316- # for causal in [False ]:
315+ # for causal in [False, True]:
316+ for causal in [True ]:
317317 print (f"\n ### { headdim = } , { causal = } , { seqlen = } ###" )
318318 nFLOPS = flops (batch_size , nheads , seqlen_q , seqlen , headdim if not has_qv else headdim + headdim_v , headdim_v , causal = causal , window_size = window_size )
319319 if cudnn is not None :
0 commit comments