Skip to content

Commit bb0e91a

Browse files
Merge pull request #308 from superbobry:maint-2
PiperOrigin-RevId: 688880108
2 parents aa8c729 + 76972e6 commit bb0e91a

File tree

4 files changed

+17
-22
lines changed

4 files changed

+17
-22
lines changed

docs/index.md

Lines changed: 10 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -24,33 +24,25 @@ Check out the [JAX installation guide](https://github.com/google/jax#pip-install
2424

2525
### Installation at HEAD
2626

27-
JAX-Triton and Pallas are developed at JAX and Jaxlib HEAD and close to Triton HEAD. To get a bleeding edge installation of JAX-Triton, run:
27+
JAX-Triton is developed at JAX and jaxlib HEAD and close to Triton HEAD. To get
28+
a bleeding edge installation of JAX-Triton, run:
29+
2830
```bash
2931
$ pip install 'jax-triton @ git+https://github.com/jax-ml/jax-triton.git'
3032
```
33+
3134
This should install compatible versions of JAX and Triton.
3235

33-
JAX-Triton does depend on Jaxlib but it's usually a more stable dependency. You might be able to get away with using a recent jaxlib release:
34-
```bash
35-
$ pip install jaxlib[cuda]
36-
$ # or
37-
$ pip install jaxlib[cuda11_pip]
38-
$ # or
39-
$ pip install jaxlib[cuda12_pip]
40-
```
36+
JAX-Triton requires jaxlib with GPU support. You could install the latest stable
37+
release via
4138

42-
If you find there are issues with the latest Jaxlib release, you can try using a Jaxlib nightly.
43-
To install a new jaxlib, you can find a link to a [CUDA 11 nightly](https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda_releases.html) or [CUDA 12 nightly](https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda12_releases.html). Then install it via:
44-
```bash
45-
$ pip install 'jaxlib @ <link to nightly>'
46-
```
47-
or to install CUDA via pip automatically, you can do:
4839
```bash
49-
$ pip install 'jaxlib[cuda11_pip] @ <link to nightly>'
50-
$ # or
51-
$ pip install 'jaxlib[cuda12_pip] @ <link to nightly>'
40+
$ pip install jaxlib[cuda12]
5241
```
5342

43+
In rare cases JAX-Triton might need a nighly version of jaxlib. You can install
44+
it following the instructions
45+
[here](https://jax.readthedocs.io/en/latest/installation.html#jax-nightly-installation).
5446

5547
### Quickstart
5648

jax_triton/triton_lib.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -537,7 +537,10 @@ def triton_kernel_call_lowering(
537537
named_args = dict(unsafe_zip(fn.arg_names, args))
538538

539539
if isinstance(fn, autotuner.Autotuner):
540-
key_idxs = [fn.arg_names.index(k) for k in fn.keys]
540+
if hasattr(fn, "key_idx"):
541+
key_idxs = fn.key_idx # Triton <=3.2
542+
else:
543+
key_idxs = [fn.arg_names.index(k) for k in fn.keys]
541544
if any(idx not in key_idxs for idx, _, _ in scalar_args):
542545
logging.warning(
543546
"Auto-tuning key does not include all scalar arguments. "

jax_triton/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,5 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
__version_info__ = (0, 2, 0)
15+
__version_info__ = (0, 3, 0)
1616
__version__ = ".".join(str(v) for v in __version_info__)

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@ readme = "README.md"
66
requires-python = ">=3.10"
77
dependencies = [
88
"absl-py>=1.4.0",
9-
"jax>=0.4.31",
10-
"triton>=3.0",
9+
"jax>=0.4.34",
10+
"triton>=3.1",
1111
]
1212

1313
[project.optional-dependencies]

0 commit comments

Comments
 (0)