We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 2b1dda2 commit a5db0b3Copy full SHA for a5db0b3
tests/triton_call_test.py
@@ -28,6 +28,12 @@
28
29
config.parse_flags_with_absl()
30
31
+try:
32
+ jt.get_compute_capability(0)
33
+except AttributeError:
34
+ # TODO(stephen-huan): add in jaxlib
35
+ jt.get_compute_capability = lambda _: np.inf
36
+
37
38
def setUpModule():
39
config.update("jax_enable_x64", True)
tests/triton_test.py
@@ -20,7 +20,10 @@
20
import numpy as np
21
import triton
22
import triton.language as tl
23
-from triton.language.extra.cuda import libdevice
24
+ from triton.language.extra.cuda import libdevice
25
+except ImportError:
26
+ from triton.language.extra.cpu import libdevice
27
@triton.jit
0 commit comments