You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
This should install compatible versions of JAX and Triton.
32
35
33
-
JAX-Triton does depend on Jaxlib but it's usually a more stable dependency. You might be able to get away with using a recent jaxlib release:
34
-
```bash
35
-
$ pip install jaxlib[cuda]
36
-
$ # or
37
-
$ pip install jaxlib[cuda11_pip]
38
-
$ # or
39
-
$ pip install jaxlib[cuda12_pip]
40
-
```
36
+
JAX-Triton requires jaxlib with GPU support. You could install the latest stable
37
+
release via
41
38
42
-
If you find there are issues with the latest Jaxlib release, you can try using a Jaxlib nightly.
43
-
To install a new jaxlib, you can find a link to a [CUDA 11 nightly](https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda_releases.html) or [CUDA 12 nightly](https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda12_releases.html). Then install it via:
44
-
```bash
45
-
$ pip install 'jaxlib @ <link to nightly>'
46
-
```
47
-
or to install CUDA via pip automatically, you can do:
48
39
```bash
49
-
$ pip install 'jaxlib[cuda11_pip] @ <link to nightly>'
50
-
$ # or
51
-
$ pip install 'jaxlib[cuda12_pip] @ <link to nightly>'
40
+
$ pip install jaxlib[cuda12]
52
41
```
53
42
43
+
In rare cases JAX-Triton might need a nighly version of jaxlib. You can install
0 commit comments