diff --git a/tritonbench/operators/bf16xint16_gemm/kernel.py b/tritonbench/operators/bf16xint16_gemm/kernel.py index c3590ad2..78155fa6 100644 --- a/tritonbench/operators/bf16xint16_gemm/kernel.py +++ b/tritonbench/operators/bf16xint16_gemm/kernel.py @@ -194,7 +194,7 @@ def get_hip_autotune_config(): "waves_per_eu": 2, }, num_warps=4, - num_stages=0, + num_stages=2, ), triton.Config( { @@ -205,7 +205,7 @@ def get_hip_autotune_config(): "waves_per_eu": 2, }, num_warps=8, - num_stages=0, + num_stages=2, ), triton.Config( { @@ -216,7 +216,7 @@ def get_hip_autotune_config(): "waves_per_eu": 2, }, num_warps=8, - num_stages=0, + num_stages=2, ), triton.Config( { @@ -227,7 +227,7 @@ def get_hip_autotune_config(): "waves_per_eu": 3, }, num_warps=4, - num_stages=0, + num_stages=2, ), triton.Config( { @@ -238,7 +238,7 @@ def get_hip_autotune_config(): "waves_per_eu": 8, }, num_warps=4, - num_stages=0, + num_stages=2, ), ] diff --git a/tritonbench/operators/fp8_gemm/tutorial.py b/tritonbench/operators/fp8_gemm/tutorial.py index ed312deb..99ae23d0 100644 --- a/tritonbench/operators/fp8_gemm/tutorial.py +++ b/tritonbench/operators/fp8_gemm/tutorial.py @@ -341,7 +341,7 @@ def get_hip_autotune_config(): "waves_per_eu": 2, }, num_warps=4, - num_stages=0, + num_stages=2, ), triton.Config( { @@ -352,7 +352,7 @@ def get_hip_autotune_config(): "waves_per_eu": 2, }, num_warps=8, - num_stages=0, + num_stages=2, ), triton.Config( { @@ -363,7 +363,7 @@ def get_hip_autotune_config(): "waves_per_eu": 2, }, num_warps=8, - num_stages=0, + num_stages=2, ), triton.Config( { @@ -374,7 +374,7 @@ def get_hip_autotune_config(): "waves_per_eu": 3, }, num_warps=4, - num_stages=0, + num_stages=2, ), triton.Config( { @@ -385,7 +385,7 @@ def get_hip_autotune_config(): "waves_per_eu": 8, }, num_warps=4, - num_stages=0, + num_stages=2, ), ]