present_key
mask_shape
1 parent b2623d4 commit 5a25417
src/ntops/torch.py
@@ -473,7 +473,9 @@ def scaled_dot_product_attention(
             "Number of heads in `query` must be divisible by number of heads in `key` and `value` when GQA is enabled."
         )
 
-    mask_shape = query.shape[:-1] + (key.shape[-2],)
+    mask_shape = query.shape[:-1] + (
+        key.shape[-2] if present_key is None else key.shape[-2] + present_key.shape[-2],
+    )
 
     if attn_mask is not None:
         with_attn_mask = True
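
The shape arithmetic in this hunk can be checked in isolation. The sketch below is not part of the commit: `compute_mask_shape` is a hypothetical helper that works on plain tuples instead of tensors, and the example shapes assume a `(batch, heads, seq_len, head_dim)` layout. It presumably reflects the intent of the change, namely that when `present_key` is supplied the mask's last dimension has to cover the sequence length of `key` plus that of `present_key`.

```python
def compute_mask_shape(query_shape, key_shape, present_key_shape=None):
    """Hypothetical stand-in for the `mask_shape` expression in the diff.

    Shapes are plain tuples, assumed to be (batch, heads, seq_len, head_dim).
    """
    # Sequence length contributed by `key` alone.
    key_len = key_shape[-2]

    # With `present_key`, the mask must also span its sequence length.
    if present_key_shape is not None:
        key_len += present_key_shape[-2]

    # Drop the head dimension of `query` and append the total key length.
    return tuple(query_shape[:-1]) + (key_len,)


# Assumed example shapes: one new query token, 16 tokens in `key`,
# and one token in `present_key`.
query_shape = (2, 8, 1, 64)
key_shape = (2, 8, 16, 64)
present_key_shape = (2, 8, 1, 64)

without_present = compute_mask_shape(query_shape, key_shape)
with_present = compute_mask_shape(query_shape, key_shape, present_key_shape)

print(without_present)  # (2, 8, 1, 16)
print(with_present)     # (2, 8, 1, 17)
```

Before this change the expression always produced the first result, so a mask built from it would be one column short in the second case.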