diff --git a/tritonbench/operators/ragged_attention/hstu.py b/tritonbench/operators/ragged_attention/hstu.py
index 9a050298..7059f057 100644
--- a/tritonbench/operators/ragged_attention/hstu.py
+++ b/tritonbench/operators/ragged_attention/hstu.py
@@ -8,8 +8,8 @@
     # Internal Import
     from hammer.oss.generative_recommenders.ops.triton.triton_ragged_hstu_attention import (
-        _ragged_hstu_attn_fwd,
         _ragged_hstu_attn_fwd_persistent,
+        RaggedAttentionRelativeBiasFunction,
     )
 except ModuleNotFoundError:
     # OSS Import
@@ -19,8 +19,8 @@
     _ragged_hstu_attn_fwd_persistent = (
         triton_ragged_hstu_attention._ragged_hstu_attn_fwd_persistent
     )
-    _RaggedAttentionRelativeBiasFunction = (
-        triton_ragged_hstu_attention._RaggedAttentionRelativeBiasFunction
+    RaggedAttentionRelativeBiasFunction = (
+        triton_ragged_hstu_attention.RaggedAttentionRelativeBiasFunction
     )
 
 @torch.fx.wrap
@@ -150,7 +150,7 @@ def forward(
             grid = (1216,)
             _ragged_hstu_attn_fwd_persistent[grid](**kwargs)
         else:
-            out = _RaggedAttentionRelativeBiasFunction.apply(
+            out = RaggedAttentionRelativeBiasFunction.apply(
                 self.max_seq_len,  # N
                 kwargs["alpha"],
                 q,
@@ -169,7 +169,6 @@ def forward(
                 kwargs["time_delta"],  # time_delta
                 kwargs["max_pos_ind"],  # max_pos_ind
                 kwargs["num_targets"],
-                None,  # attn_scale
                 kwargs["ATTN_BIAS_TYPE"],  # relative_bias_type
                 kwargs["MAX_ATTN_LEN"],  # max_attn_len
                 kwargs["contextual_seq_len"],  # contextual_seq_len