Skip to content

Commit d0787ac

Browse files
committed
Relax dropout_fraction test
1 parent dca6d89 commit d0787ac

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

tests/test_flash_attn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1430,7 +1430,7 @@ def test_flash_attn_varlen_output(
14301430
assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
14311431
# With alibi, many of the prob values are 0.0 & -0.0 so dropout_fraction isn't accurate
14321432
if not alibi:
1433-
assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)
1433+
assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.04)
14341434

14351435
if ((d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90)) and softcap == 0.0:
14361436
assert (dq - dq_ref).abs().max().item() <= 3 * (dq_pt - dq_ref).abs().max().item()

0 commit comments

Comments
 (0)