We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent dca6d89 commit d0787acCopy full SHA for d0787ac
tests/test_flash_attn.py
@@ -1430,7 +1430,7 @@ def test_flash_attn_varlen_output(
1430
assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
1431
# With alibi, many of the prob values are 0.0 & -0.0 so dropout_fraction isn't accurate
1432
if not alibi:
1433
- assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)
+ assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.04)
1434
1435
if ((d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90)) and softcap == 0.0:
1436
assert (dq - dq_ref).abs().max().item() <= 3 * (dq_pt - dq_ref).abs().max().item()
0 commit comments