Skip to content

Commit a5db0b3

Browse files
committed
Fix tests with CPU backend
1 parent 2b1dda2 commit a5db0b3

File tree

2 files changed

+10
-1
lines changed

2 files changed

+10
-1
lines changed

tests/triton_call_test.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,12 @@
2828

2929
config.parse_flags_with_absl()
3030

31+
try:
32+
jt.get_compute_capability(0)
33+
except AttributeError:
34+
# TODO(stephen-huan): add in jaxlib
35+
jt.get_compute_capability = lambda _: np.inf
36+
3137

3238
def setUpModule():
3339
config.update("jax_enable_x64", True)

tests/triton_test.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,10 @@
2020
import numpy as np
2121
import triton
2222
import triton.language as tl
23-
from triton.language.extra.cuda import libdevice
23+
try:
24+
from triton.language.extra.cuda import libdevice
25+
except ImportError:
26+
from triton.language.extra.cpu import libdevice
2427

2528

2629
@triton.jit

0 commit comments

Comments
 (0)