Skip to content

Commit 6b5f271

Browse files
committed
[Triton] Avoid einops repeat by using Tensor.expand
1 parent 88c4e5d commit 6b5f271

File tree

1 file changed

+2
-12
lines changed

1 file changed

+2
-12
lines changed

flash_attn/flash_attn_triton.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,6 @@
3838

3939
import torch
4040

41-
from einops import rearrange, repeat
42-
4341
import triton
4442
import triton.language as tl
4543

@@ -605,11 +603,7 @@ def _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None):
605603
else:
606604
raise RuntimeError('Last 2 dimensions of bias must be (1, seqlen_k)'
607605
' or (seqlen_q, seqlen_k)')
608-
if bias.shape[:2] == (1, nheads):
609-
bias = repeat(bias, '1 h ... -> b h ...', b=batch)
610-
elif bias.shape[:2] == (batch, 1):
611-
bias = repeat(bias, 'b 1 ... -> b h ...', h=nheads)
612-
assert bias.shape[:2] == (batch, nheads), 'First 2 dimensions of bias must be broadcastible to (batch, nheads)'
606+
bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)
613607
bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)
614608

615609
seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
@@ -684,11 +678,7 @@ def _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, bias=None, causal=Fals
684678
else:
685679
raise RuntimeError('Last 2 dimensions of bias must be (1, seqlen_k)'
686680
' or (seqlen_q, seqlen_k)')
687-
if bias.shape[:2] == (1, nheads):
688-
bias = repeat(bias, '1 h ... -> b h ...', b=batch)
689-
elif bias.shape[:2] == (batch, 1):
690-
bias = repeat(bias, 'b 1 ... -> b h ...', h=nheads)
691-
assert bias.shape[:2] == (batch, nheads), 'First 2 dimensions of bias must be broadcastible to (batch, nheads)'
681+
bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)
692682
bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)
693683

694684
# BLOCK_M = 128

0 commit comments

Comments (0)