Fix internal ragged_attention operator and its tests
Summary: Fix ragged_attention after the upstream API change (the relative-bias autograd function was renamed and its attn_scale argument removed)

Reviewed By: manman-ren

Differential Revision: D66116624

fbshipit-source-id: c808764df77412aaf82bf1bd95c40438abd7a6b5
xuzhao9 authored and facebook-github-bot committed Nov 19, 2024
1 parent c0a8479 commit 6e52ed2
Showing 1 changed file with 4 additions and 5 deletions.
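
The substance of the change is that upstream renamed the relative-bias autograd function (dropping its leading underscore) and removed the attn_scale argument from its apply() call. A schematic of the call-site difference, written as comments because it is reconstructed purely from the diff below; arguments the diff does not show are left elided:

    # Old upstream API (what hstu.py called before this commit):
    #   out = _RaggedAttentionRelativeBiasFunction.apply(
    #       max_seq_len, alpha, q, ..., num_targets,
    #       None,                 # attn_scale: removed upstream
    #       ATTN_BIAS_TYPE, MAX_ATTN_LEN, contextual_seq_len, ...)
    #
    # New upstream API (what this commit switches to):
    #   out = RaggedAttentionRelativeBiasFunction.apply(
    #       max_seq_len, alpha, q, ..., num_targets,
    #       ATTN_BIAS_TYPE, MAX_ATTN_LEN, contextual_seq_len, ...)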
tritonbench/operators/ragged_attention/hstu.py: 9 changes (4 additions & 5 deletions)

@@ -8,8 +8,8 @@
     # Internal Import
     from hammer.oss.generative_recommenders.ops.triton.triton_ragged_hstu_attention import (
         _ragged_hstu_attn_fwd,
         _ragged_hstu_attn_fwd_persistent,
-        _RaggedAttentionRelativeBiasFunction,
+        RaggedAttentionRelativeBiasFunction,
     )
 except ModuleNotFoundError:
     # OSS Import
@@ -19,8 +19,8 @@
     _ragged_hstu_attn_fwd_persistent = (
         triton_ragged_hstu_attention._ragged_hstu_attn_fwd_persistent
     )
-    _RaggedAttentionRelativeBiasFunction = (
-        triton_ragged_hstu_attention._RaggedAttentionRelativeBiasFunction
+    RaggedAttentionRelativeBiasFunction = (
+        triton_ragged_hstu_attention.RaggedAttentionRelativeBiasFunction
     )
 
 @torch.fx.wrap
@@ -150,7 +150,7 @@ def forward(
             grid = (1216,)
             _ragged_hstu_attn_fwd_persistent[grid](**kwargs)
         else:
-            out = _RaggedAttentionRelativeBiasFunction.apply(
+            out = RaggedAttentionRelativeBiasFunction.apply(
                 self.max_seq_len, # N
                 kwargs["alpha"],
                 q,
@@ -169,7 +169,6 @@ def forward(
                 kwargs["time_delta"], # time_delta
                 kwargs["max_pos_ind"], # max_pos_ind
                 kwargs["num_targets"],
-                None, # attn_scale
                 kwargs["ATTN_BIAS_TYPE"], # relative_bias_type
                 kwargs["MAX_ATTN_LEN"], # max_attn_len
                 kwargs["contextual_seq_len"], # contextual_seq_len
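
For readers adapting similar call sites, the OSS branch above resolves these symbols from a dynamically imported module object. A minimal sketch of that loader pattern, tolerant of both the old and new upstream names; the OSS module path is assumed to mirror the internal one minus the hammer.oss. prefix, and the getattr fallback is an illustration rather than part of this commit:

    import importlib

    def load_ragged_hstu_symbols():
        # Assumed OSS module path; verify against the installed
        # generative_recommenders package before relying on it.
        mod = importlib.import_module(
            "generative_recommenders.ops.triton.triton_ragged_hstu_attention"
        )
        fwd = mod._ragged_hstu_attn_fwd
        fwd_persistent = mod._ragged_hstu_attn_fwd_persistent
        # Prefer the renamed public class; fall back to the old
        # underscore-prefixed name for pre-rename checkouts.
        relative_bias_fn = getattr(
            mod,
            "RaggedAttentionRelativeBiasFunction",
            getattr(mod, "_RaggedAttentionRelativeBiasFunction", None),
        )
        if relative_bias_fn is None:
            raise ImportError(
                "ragged HSTU relative-bias autograd function not found"
            )
        return fwd, fwd_persistent, relative_bias_fn

At forward time the operator then either launches _ragged_hstu_attn_fwd_persistent directly on a fixed grid or goes through the autograd function, as the last two hunks above show.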
