|
24 | 24 |
|
25 | 25 | from .kernels import matmul as kernels
|
26 | 26 | from .partition_k import matmul_partition_k
|
27 |
| -from .persistent_matmul import ( |
28 |
| - matmul_persistent, |
29 |
| - matmul_tma_persistent, |
30 |
| - matmul_tma_persistent_cached, |
31 |
| -) |
# Optional dependency: the persistent matmul kernels are not available in
# every build/environment, so probe for them and record availability in a
# flag that benchmark registration can consult.
try:
    from .persistent_matmul import (
        matmul_persistent,
        matmul_tma_persistent,
        matmul_tma_persistent_cached,
    )

    HAS_PERSISTENT = True
# ImportError (superset of ModuleNotFoundError) also covers the case where
# the module exists but one of the expected symbols is missing.
except ImportError:
    HAS_PERSISTENT = False

# NOTE(review): the flag was originally misspelled; keep the old name as an
# alias so existing references to HAS_PRESISTENT continue to work.
HAS_PRESISTENT = HAS_PERSISTENT

32 | 37 | from .triton_matmul import (
|
33 | 38 | matmul as triton_tutorial_matmul,
|
34 | 39 | matmul_kernel as triton_tutorial_matmul_kernel,
|
@@ -158,22 +163,22 @@ def matmul_partition_k(self, a, b, bias) -> Callable:
|
158 | 163 | else:
|
159 | 164 | return lambda: matmul_partition_k(a, bt)
|
160 | 165 |
|
161 |
| - @register_benchmark() |
| 166 | + @register_benchmark(enabled=HAS_PRESISTENT) |
162 | 167 | def triton_persistent_matmul(self, a, b, bias) -> Callable:
|
163 | 168 | if not bias == None:
|
164 | 169 | return lambda: matmul_persistent(a, b) + bias
|
165 | 170 | else:
|
166 | 171 | return lambda: matmul_persistent(a, b)
|
167 | 172 |
|
168 |
| - @register_benchmark(enabled=not IS_FBCODE) |
| 173 | + @register_benchmark(enabled=not IS_FBCODE and HAS_PRESISTENT) |
169 | 174 | def triton_tma_persistent_matmul(self, a, b, bias) -> Callable:
|
170 | 175 | b = b.T.contiguous()
|
171 | 176 | if not bias == None:
|
172 | 177 | return lambda: matmul_tma_persistent(a, b) + bias
|
173 | 178 | else:
|
174 | 179 | return lambda: matmul_tma_persistent(a, b)
|
175 | 180 |
|
176 |
| - @register_benchmark(enabled=not IS_FBCODE) |
| 181 | + @register_benchmark(enabled=not IS_FBCODE and HAS_PRESISTENT) |
177 | 182 | def triton_tma_persistent_cached_matmul(self, a, b, bias) -> Callable:
|
178 | 183 | b = b.T.contiguous()
|
179 | 184 | if not bias == None:
|
|
0 commit comments