Don't expand mask when not necessary.
Expanding seems to slow down inference.
comfyanonymous committed Dec 16, 2024
1 parent 61b5072 commit 19ee5d9
Showing 1 changed file with 6 additions and 3 deletions.
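
The change relies on torch.nn.functional.scaled_dot_product_attention accepting any attn_mask that is broadcastable against the (batch, heads, query_len, key_len) attention weights, so a mask with a singleton heads dimension does not need to be materialized with expand(). A minimal sketch of that equivalence, with illustrative shapes that are not taken from the commit:

import torch
import torch.nn.functional as F

# Illustrative shapes only; the real tensors come from attention_pytorch.
b, heads, lq, lk, dim_head = 2, 8, 16, 16, 64
q = torch.randn(b, heads, lq, dim_head)
k = torch.randn(b, heads, lk, dim_head)
v = torch.randn(b, heads, lk, dim_head)

# Per-batch mask with a singleton heads dimension, as produced by unsqueeze(1).
mask = torch.randn(b, 1, lq, lk)

# SDPA broadcasts attn_mask over the heads dimension, so the un-expanded mask
# gives the same result as one expanded to (b, heads, lq, lk).
out_broadcast = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
out_expanded = F.scaled_dot_product_attention(q, k, v, attn_mask=mask.expand(b, heads, -1, -1))
assert torch.allclose(out_broadcast, out_expanded)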
comfy/ldm/modules/attention.py: 9 changes (6 additions & 3 deletions)
@@ -423,8 +423,6 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
         # add a heads dimension if there isn't already one
         if mask.ndim == 3:
             mask = mask.unsqueeze(1)
-        mask = mask.expand(b, heads, -1, -1)
-
 
     if SDP_BATCH_LIMIT >= b:
         out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
@@ -434,11 +432,16 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
     else:
         out = torch.empty((b, q.shape[2], heads * dim_head), dtype=q.dtype, layout=q.layout, device=q.device)
         for i in range(0, b, SDP_BATCH_LIMIT):
+            m = mask
+            if mask is not None:
+                if mask.shape[0] > 1:
+                    m = mask[i : i + SDP_BATCH_LIMIT]
+
             out[i : i + SDP_BATCH_LIMIT] = torch.nn.functional.scaled_dot_product_attention(
                 q[i : i + SDP_BATCH_LIMIT],
                 k[i : i + SDP_BATCH_LIMIT],
                 v[i : i + SDP_BATCH_LIMIT],
-                attn_mask=None if mask is None else mask[i : i + SDP_BATCH_LIMIT],
+                attn_mask=m,
                 dropout_p=0.0, is_causal=False
             ).transpose(1, 2).reshape(-1, q.shape[2], heads * dim_head)
     return out
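
Because the mask is no longer expanded along the batch dimension, it can reach the SDP_BATCH_LIMIT loop with a batch size of 1, which is why the new guard only slices masks whose first dimension actually matches the batch; a batch-1 mask is passed through and broadcasts over every slice. A standalone sketch of that loop logic, using an illustrative SDP_BATCH_LIMIT rather than the value defined in attention.py:

import torch
import torch.nn.functional as F

SDP_BATCH_LIMIT = 2  # illustrative value only

def batched_sdpa(q, k, v, mask=None):
    # q, k, v: (b, heads, seq, dim_head); mask: None or broadcastable to (b, heads, seq, seq).
    b, heads, seq, dim_head = q.shape
    out = torch.empty((b, seq, heads * dim_head), dtype=q.dtype, device=q.device)
    for i in range(0, b, SDP_BATCH_LIMIT):
        m = mask
        if mask is not None and mask.shape[0] > 1:
            # Only slice masks with a real per-batch dimension;
            # a batch-1 mask broadcasts over the slice as-is.
            m = mask[i : i + SDP_BATCH_LIMIT]
        out[i : i + SDP_BATCH_LIMIT] = F.scaled_dot_product_attention(
            q[i : i + SDP_BATCH_LIMIT],
            k[i : i + SDP_BATCH_LIMIT],
            v[i : i + SDP_BATCH_LIMIT],
            attn_mask=m,
            dropout_p=0.0, is_causal=False,
        ).transpose(1, 2).reshape(-1, seq, heads * dim_head)
    return out

# Example: a batch of 5 sharing one (batch-1) mask.
q = torch.randn(5, 8, 16, 64)
k = torch.randn(5, 8, 16, 64)
v = torch.randn(5, 8, 16, 64)
mask = torch.randn(1, 1, 16, 16)
print(batched_sdpa(q, k, v, mask).shape)  # torch.Size([5, 16, 512])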
