@@ -38,8 +38,6 @@
 
 import torch
 
-from einops import rearrange, repeat
-
 import triton
 import triton.language as tl
 
@@ -605,11 +603,7 @@ def _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None):
         else:
             raise RuntimeError('Last 2 dimensions of bias must be (1, seqlen_k)'
                                ' or (seqlen_q, seqlen_k)')
-        if bias.shape[:2] == (1, nheads):
-            bias = repeat(bias, '1 h ... -> b h ...', b=batch)
-        elif bias.shape[:2] == (batch, 1):
-            bias = repeat(bias, 'b 1 ... -> b h ...', h=nheads)
-        assert bias.shape[:2] == (batch, nheads), 'First 2 dimensions of bias must be broadcastible to (batch, nheads)'
+        bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)
     bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)
 
     seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
@@ -684,11 +678,7 @@ def _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, bias=None, causal=Fals
         else:
             raise RuntimeError('Last 2 dimensions of bias must be (1, seqlen_k)'
                                ' or (seqlen_q, seqlen_k)')
-        if bias.shape[:2] == (1, nheads):
-            bias = repeat(bias, '1 h ... -> b h ...', b=batch)
-        elif bias.shape[:2] == (batch, 1):
-            bias = repeat(bias, 'b 1 ... -> b h ...', h=nheads)
-        assert bias.shape[:2] == (batch, nheads), 'First 2 dimensions of bias must be broadcastible to (batch, nheads)'
+        bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)
     bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)
 
     # BLOCK_M = 128
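Both hunks drop the einops dependency and replace the `repeat` calls with `torch.Tensor.expand`. Where `repeat` may materialize the broadcast bias, `expand` always returns a zero-copy view whose broadcast dimensions have stride 0, so the `bias_strides` tuple handed to the Triton kernels still addresses the original storage. The sketch below illustrates the difference; it is a minimal standalone example, and the tensor sizes are placeholders chosen for illustration, not values from this repository.

```python
import torch

# Placeholder sizes for illustration only (not taken from the repo).
batch, nheads, seqlen_q, seqlen_k = 2, 4, 128, 128

# A bias shared across the batch: shape (1, nheads, seqlen_q, seqlen_k).
bias = torch.randn(1, nheads, seqlen_q, seqlen_k)

# Previously: repeat(bias, '1 h ... -> b h ...', b=batch).
# expand produces a broadcast view instead of writing new memory.
expanded = bias.expand(batch, nheads, seqlen_q, seqlen_k)

assert expanded.shape == (batch, nheads, seqlen_q, seqlen_k)
assert expanded.stride(0) == 0                 # broadcast dim has stride 0
assert expanded.data_ptr() == bias.data_ptr()  # same underlying storage

# The (batch, 1, ...) case is covered by the same call, and expand raises a
# RuntimeError for leading dimensions that cannot broadcast to
# (batch, nheads), which subsumes the removed explicit assert.
per_batch = torch.randn(batch, 1, seqlen_q, seqlen_k)
assert per_batch.expand(batch, nheads, seqlen_q, seqlen_k).shape \
    == (batch, nheads, seqlen_q, seqlen_k)
```

Since the kernels compute bias offsets from these strides, a zero stride along a broadcast dimension simply reuses the same rows, so the behaviour for biases that were already broadcastable should be unchanged while the explicit copy is avoided.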