diff --git a/tritonbench/operators/fp8_attention/operator.py b/tritonbench/operators/fp8_attention/operator.py
index 63d1e33b..0131ca73 100644
--- a/tritonbench/operators/fp8_attention/operator.py
+++ b/tritonbench/operators/fp8_attention/operator.py
@@ -10,7 +10,7 @@
 
 import torch
 
-from tritonbench.kernels.triton_fused_attention import attention as triton_attention
+from tritonbench.kernels.triton_fused_attention import attention_opt as triton_attention
 from tritonbench.utils.triton_op import (
     BenchmarkOperator,
     BenchmarkOperatorMetrics,
@@ -110,7 +110,7 @@ def triton_flash_v2(
         triton_q, triton_k, triton_v = self.triton_preprocess(q, k, v)
         # full fp8 will be enabled if type of q,k,v is fp8
         return lambda: triton_attention(
-            triton_q, triton_k, triton_v, False, self.sm_scale
+            triton_q, triton_k, triton_v, False, self.sm_scale, "base"
         )
 
     def get_x_val(self, _example_inputs) -> Tuple[int, int, int, int]:
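
For context, a minimal sketch of how the updated call site would be exercised on its own. Only the call signature `triton_attention(q, k, v, causal, sm_scale, "base")` and the `attention_opt` import are taken from the diff above; the tensor shapes, dtypes, and the interpretation of the trailing `"base"` argument as a kernel-variant selector are assumptions for illustration.

```python
# Sketch only: exercises the attention_opt call signature introduced in this diff.
# Shapes/dtypes are illustrative assumptions; per the comment in the diff, the
# full-fp8 path is taken when q, k, v are fp8 tensors.
import torch

from tritonbench.kernels.triton_fused_attention import attention_opt as triton_attention

B, H, N, D = 4, 8, 1024, 64  # batch, heads, sequence length, head dim (assumed)
q = torch.randn(B, H, N, D, device="cuda", dtype=torch.float16)
k = torch.randn(B, H, N, D, device="cuda", dtype=torch.float16)
v = torch.randn(B, H, N, D, device="cuda", dtype=torch.float16)
sm_scale = 1.0 / (D ** 0.5)

# causal=False; "base" assumed to select the baseline kernel variant
out = triton_attention(q, k, v, False, sm_scale, "base")
```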