diff --git a/tritonbench/operators/ragged_attention/hstu.py b/tritonbench/operators/ragged_attention/hstu.py index 9a050298..59224ece 100644 --- a/tritonbench/operators/ragged_attention/hstu.py +++ b/tritonbench/operators/ragged_attention/hstu.py @@ -19,8 +19,8 @@ _ragged_hstu_attn_fwd_persistent = ( triton_ragged_hstu_attention._ragged_hstu_attn_fwd_persistent ) - _RaggedAttentionRelativeBiasFunction = ( - triton_ragged_hstu_attention._RaggedAttentionRelativeBiasFunction + RaggedAttentionRelativeBiasFunction = ( + triton_ragged_hstu_attention.RaggedAttentionRelativeBiasFunction ) @torch.fx.wrap @@ -150,7 +150,7 @@ def forward( grid = (1216,) _ragged_hstu_attn_fwd_persistent[grid](**kwargs) else: - out = _RaggedAttentionRelativeBiasFunction.apply( + out = RaggedAttentionRelativeBiasFunction.apply( self.max_seq_len, # N kwargs["alpha"], q,