
Commit f2963f1

Disable gemm on persistent

1 parent 07000ee commit f2963f1

File tree

2 files changed (+17, −8 lines)


test/test_gpu/skip_tests_h100_pytorch.yaml (+4)
@@ -25,6 +25,10 @@ fp8_gemm:
   - triton_persistent_fp8_gemm
   - triton_tma_persistent_fp8_gemm
 fp8_gemm_rowwise:
+gemm:
+  - triton_persistent_matmul
+  - triton_tma_persistent_matmul
+  - triton_tma_persistent_cached_matmul
 jagged_layer_norm:
 jagged_mean:
 jagged_softmax:
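
For reference, entries in this skip list pair an operator name with an optional list of implementations to skip. What follows is a hypothetical sketch of how such a file could be consumed, assuming PyYAML and the convention suggested by the surrounding entries (a bare key skips the whole operator); the harness code itself is not part of this commit.

# Hypothetical consumer for a skip list shaped like the YAML above;
# tritonbench's real test harness is not shown in this commit.
import yaml

with open("test/test_gpu/skip_tests_h100_pytorch.yaml") as f:
    skip = yaml.safe_load(f) or {}

def is_skipped(op: str, impl: str) -> bool:
    # A bare operator key (value None) skips every implementation;
    # a list value skips only the named implementations.
    if op not in skip:
        return False
    impls = skip[op]
    return impls is None or impl in impls

# After this commit, the three persistent matmul variants are skipped:
assert is_skipped("gemm", "triton_persistent_matmul")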

tritonbench/operators/gemm/operator.py (+13, −8)
@@ -24,11 +24,16 @@

 from .kernels import matmul as kernels
 from .partition_k import matmul_partition_k
-from .persistent_matmul import (
-    matmul_persistent,
-    matmul_tma_persistent,
-    matmul_tma_persistent_cached,
-)
+try:
+    from .persistent_matmul import (
+        matmul_persistent,
+        matmul_tma_persistent,
+        matmul_tma_persistent_cached,
+    )
+    HAS_PRESISTENT = True
+except ModuleNotFoundError:
+    HAS_PRESISTENT = False
+
 from .triton_matmul import (
     matmul as triton_tutorial_matmul,
     matmul_kernel as triton_tutorial_matmul_kernel,
@@ -158,22 +163,22 @@ def matmul_partition_k(self, a, b, bias) -> Callable:
         else:
             return lambda: matmul_partition_k(a, bt)

-    @register_benchmark()
+    @register_benchmark(enabled=HAS_PRESISTENT)
     def triton_persistent_matmul(self, a, b, bias) -> Callable:
         if not bias == None:
             return lambda: matmul_persistent(a, b) + bias
         else:
             return lambda: matmul_persistent(a, b)

-    @register_benchmark(enabled=not IS_FBCODE)
+    @register_benchmark(enabled=not IS_FBCODE and HAS_PRESISTENT)
     def triton_tma_persistent_matmul(self, a, b, bias) -> Callable:
         b = b.T.contiguous()
         if not bias == None:
             return lambda: matmul_tma_persistent(a, b) + bias
         else:
             return lambda: matmul_tma_persistent(a, b)

-    @register_benchmark(enabled=not IS_FBCODE)
+    @register_benchmark(enabled=not IS_FBCODE and HAS_PRESISTENT)
     def triton_tma_persistent_cached_matmul(self, a, b, bias) -> Callable:
         b = b.T.contiguous()
         if not bias == None:
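
The guarded import above follows a standard Python pattern: attempt the optional import, record the outcome in a module-level flag, and gate registration on that flag so benchmarks that depend on the missing module are never registered. Below is a minimal standalone sketch with hypothetical names: optional_kernels, fast_matmul, and the toy register_benchmark stand in for .persistent_matmul and tritonbench's real decorator.

# Self-contained sketch of the optional-import guard; all names here
# are illustrative, not tritonbench's API.
try:
    from optional_kernels import fast_matmul  # hypothetical optional dependency

    HAS_OPTIONAL_KERNELS = True
except ModuleNotFoundError:
    HAS_OPTIONAL_KERNELS = False

BENCHMARKS = []

def register_benchmark(enabled: bool = True):
    # Toy registry: only enabled functions are recorded, so a
    # benchmark whose import failed simply never runs.
    def wrap(fn):
        if enabled:
            BENCHMARKS.append(fn)
        return fn
    return wrap

@register_benchmark(enabled=HAS_OPTIONAL_KERNELS)
def bench_fast_matmul(a, b):
    return fast_matmul(a, b)

print(len(BENCHMARKS))  # 0 when optional_kernels is unavailable

Catching ModuleNotFoundError rather than the broader ImportError keeps the guard narrow: a module that exists but fails during import still surfaces its real error instead of being silently disabled.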
