
Commit 5a25417

Take present_key into account when generating mask_shape
1 parent: b2623d4

File tree

1 file changed (+3 −1 lines)


src/ntops/torch.py

Lines changed: 3 additions & 1 deletion

@@ -473,7 +473,9 @@ def scaled_dot_product_attention(
             "Number of heads in `query` must be divisible by number of heads in `key` and `value` when GQA is enabled."
         )

-    mask_shape = query.shape[:-1] + (key.shape[-2],)
+    mask_shape = query.shape[:-1] + (
+        key.shape[-2] if present_key is None else key.shape[-2] + present_key.shape[-2],
+    )

     if attn_mask is not None:
         with_attn_mask = True
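The intent of the change can be illustrated with a standalone sketch. The shapes below are invented for the example (they are not taken from the repository): when a cached `present_key` from a KV cache exists, attention scores are computed against the concatenated key sequence, so the mask's last dimension must cover cached plus new key positions.

```python
# Illustrative shapes only (batch, heads, seq_len, head_dim); hypothetical values.
query_shape = (2, 4, 3, 8)        # 3 query positions
key_shape = (2, 4, 5, 8)          # 5 new key positions
present_key_shape = (2, 4, 7, 8)  # 7 cached key positions (KV cache), or None

# Before the fix: the mask only covered the new keys.
old_mask_shape = query_shape[:-1] + (key_shape[-2],)
assert old_mask_shape == (2, 4, 3, 5)

# After the fix: the mask spans cached + new keys, matching the
# concatenated key sequence the attention scores are computed against.
mask_shape = query_shape[:-1] + (
    key_shape[-2]
    if present_key_shape is None
    else key_shape[-2] + present_key_shape[-2],
)
assert mask_shape == (2, 4, 3, 12)
print(mask_shape)
```

Without the extra `present_key.shape[-2]` term, the mask would be too narrow along the key axis and fail to broadcast against the full attention-score matrix whenever a cache is in use.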

0 commit comments
