Skip to content

Commit

Permalink
Fix apply input
Browse files Browse the repository at this point in the history
  • Loading branch information
xuzhao9 committed Nov 14, 2024
1 parent 81ea0eb commit 09b3f38
Showing 1 changed file with 10 additions and 11 deletions.
21 changes: 10 additions & 11 deletions tritonbench/operators/ragged_attention/hstu.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,17 +142,16 @@ def forward(
grid = (1216,)
_ragged_hstu_attn_fwd_persistent[grid](**kwargs)
else:
kwargs = {
"max_seq_len": self.max_seq_len,
"alpha": kwargs["alpha"],
"q": kwargs["Q"],
"k": kwargs["K"],
"v":kwargs["V"],
"seq_offsets": kwargs["seq_offsets"],
"invalid_attn_mask_type": kwargs["INVALID_MASK_TYPE"],
"num_targets": kwargs["num_targets"],
}
_RaggedAttentionRelativeBiasFunction.apply(**kwargs)
_RaggedAttentionRelativeBiasFunction.apply(
self.max_seq_len,
kwargs["alpha"],
q,
k,
v,
kwargs["seq_offsets"],
kwargs["INVALID_MASK_TYPE"],
kwargs["num_targets"]
)

return out

Expand Down

0 comments on commit 09b3f38

Please sign in to comment.