From 6e52ed20ae5371c331c92d373e548a4516c7be7c Mon Sep 17 00:00:00 2001
From: Xu Zhao
Date: Tue, 19 Nov 2024 06:56:39 -0800
Subject: [PATCH] Fix internal ragged_attention operator and its tests

Summary: Fix ragged_attention to match upstream API changes

Reviewed By: manman-ren

Differential Revision: D66116624

fbshipit-source-id: c808764df77412aaf82bf1bd95c40438abd7a6b5
---
 tritonbench/operators/ragged_attention/hstu.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/tritonbench/operators/ragged_attention/hstu.py b/tritonbench/operators/ragged_attention/hstu.py
index 9a050298..7059f057 100644
--- a/tritonbench/operators/ragged_attention/hstu.py
+++ b/tritonbench/operators/ragged_attention/hstu.py
@@ -8,8 +8,8 @@
     # Internal Import
     from hammer.oss.generative_recommenders.ops.triton.triton_ragged_hstu_attention import (
-        _ragged_hstu_attn_fwd,
         _ragged_hstu_attn_fwd_persistent,
+        RaggedAttentionRelativeBiasFunction,
     )
 except ModuleNotFoundError:
     # OSS Import
@@ -19,8 +19,8 @@
     _ragged_hstu_attn_fwd_persistent = (
         triton_ragged_hstu_attention._ragged_hstu_attn_fwd_persistent
     )
-    _RaggedAttentionRelativeBiasFunction = (
-        triton_ragged_hstu_attention._RaggedAttentionRelativeBiasFunction
+    RaggedAttentionRelativeBiasFunction = (
+        triton_ragged_hstu_attention.RaggedAttentionRelativeBiasFunction
     )
 
 @torch.fx.wrap
@@ -150,7 +150,7 @@ def forward(
             grid = (1216,)
             _ragged_hstu_attn_fwd_persistent[grid](**kwargs)
         else:
-            out = _RaggedAttentionRelativeBiasFunction.apply(
+            out = RaggedAttentionRelativeBiasFunction.apply(
                 self.max_seq_len,  # N
                 kwargs["alpha"],
                 q,
@@ -169,7 +169,6 @@ def forward(
                 kwargs["time_delta"],  # time_delta
                 kwargs["max_pos_ind"],  # max_pos_ind
                 kwargs["num_targets"],
-                None,  # attn_scale
                 kwargs["ATTN_BIAS_TYPE"],  # relative_bias_type
                 kwargs["MAX_ATTN_LEN"],  # max_attn_len
                 kwargs["contextual_seq_len"],  # contextual_seq_len
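
Context: the upstream generative-recommenders code replaced the private _RaggedAttentionRelativeBiasFunction autograd function with the public RaggedAttentionRelativeBiasFunction and dropped the attn_scale positional argument from its apply() call; the patch tracks that change in both import paths and at the call site. The sketch below assembles the dual import pattern as it looks after the fix. It is a minimal sketch only: the importlib-based OSS fallback and the generative_recommenders module path are assumptions, since those lines fall outside the hunks shown above, and it presumes one of the two packages is importable.

    try:
        # Internal (fbcode) build: the symbols are importable directly.
        from hammer.oss.generative_recommenders.ops.triton.triton_ragged_hstu_attention import (
            _ragged_hstu_attn_fwd_persistent,
            RaggedAttentionRelativeBiasFunction,
        )
    except ModuleNotFoundError:
        # OSS build: resolve the same symbols from the upstream module.
        # The module path here is an assumption for illustration.
        import importlib

        triton_ragged_hstu_attention = importlib.import_module(
            "generative_recommenders.ops.triton.triton_ragged_hstu_attention"
        )
        _ragged_hstu_attn_fwd_persistent = (
            triton_ragged_hstu_attention._ragged_hstu_attn_fwd_persistent
        )
        RaggedAttentionRelativeBiasFunction = (
            triton_ragged_hstu_attention.RaggedAttentionRelativeBiasFunction
        )

Binding the same public name on both paths lets the rest of hstu.py call RaggedAttentionRelativeBiasFunction.apply(...) without caring which build it is running in.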