Since version 0.4.32 of the jax-cuda12-pjrt and jax-cuda12-plugin packages, installing an older version of jax/jaxlib alongside them causes a segfault as soon as an array is instantiated:
This happens even when a GPU is not used (e.g. JAX_PLATFORMS=cpu).
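Because the segfault takes down the whole interpreter, one way to get a readable diagnosis is to perform the crashing step in a child process and inspect how it exited. The Python sketch below is hypothetical: `probe` and `JAX_PROBE` are illustrative names, the one-line `jnp.zeros(1)` script is only a stand-in for the reproducer (which is not included here), and it assumes a POSIX system, where a process killed by a signal reports a negative return code.

```python
# Hypothetical diagnostic helper (not part of JAX): run the crash-prone step in
# a fresh interpreter so a segfault kills only the child process.
# Assumes POSIX semantics: a child killed by signal N has returncode == -N.
import os
import signal
import subprocess
import sys

# Stand-in for the reproducer: create a small JAX array on the CPU platform.
JAX_PROBE = "import jax.numpy as jnp; jnp.zeros(1).block_until_ready()"


def probe(code=JAX_PROBE, platforms="cpu"):
    """Run `code` in a child Python process; return None on success, else a message."""
    env = dict(os.environ, JAX_PLATFORMS=platforms)
    result = subprocess.run(
        [sys.executable, "-c", code], env=env, capture_output=True, text=True
    )
    if result.returncode == 0:
        return None
    if result.returncode < 0:
        # Killed by a signal, e.g. SIGSEGV for the crash described in this issue.
        return f"child process was killed by {signal.Signals(-result.returncode).name}"
    # Ordinary non-zero exit: report the last stderr line (e.g. a Python exception).
    last_line = "".join(result.stderr.strip().splitlines()[-1:])
    return f"child process exited with code {result.returncode}: {last_line}"


if __name__ == "__main__":
    problem = probe()
    print(problem or "JAX array creation succeeded")
```

Run as a script, it prints either a success message or a line such as `child process was killed by SIGSEGV`; an ordinary exception in the child, such as the pre-0.4.32 RuntimeError about mismatched PJRT API versions, is reported by its last stderr line instead.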
Previously (<=0.4.31), the plugin would raise an exception instead, warning of the version mismatch:
RuntimeError: Unable to initialize backend 'cuda': INVALID_ARGUMENT: Mismatched PJRT plugin PJRT API version (0.54) and framework PJRT API version 0.51). (you may need to uninstall the failing plugin package, or set JAX_PLATFORMS=cpu to skip this backend.)
Setting the platform to CPU would also allow you to use JAX normally despite the mismatched CUDA plugin version.
You might ask why anyone would install mismatched versions of the jax and CUDA plugin packages. Because JAX doesn't manage the versions of these dependencies automatically, it can easily happen that JAX+CUDA is installed into an environment first and that another package then requests a lower JAX version from pip. The segfault then makes it very difficult for the user to understand what went wrong.
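Since nothing stops pip from pairing an older jax/jaxlib with newer CUDA plugin packages, a check that runs before `import jax` can turn the mismatch into a readable error. This is a minimal sketch, not part of JAX: `installed_versions` and `find_mismatch` are hypothetical helpers, the package names come from the system info in this report, and the rule that every jax-cuda12-* package must carry exactly the same version as jaxlib is an assumption based on the plugins being versioned alongside jaxlib.

```python
# Hypothetical pre-import consistency check (not part of JAX). It only reads
# package metadata, so it cannot trigger the segfault itself.
# Assumption: jax-cuda12-* plugin versions must equal the jaxlib version.
from importlib.metadata import PackageNotFoundError, version

PACKAGES = ("jax", "jaxlib", "jax-cuda12-pjrt", "jax-cuda12-plugin")


def installed_versions(names=PACKAGES):
    """Return {package: version} for those of `names` that are installed."""
    found = {}
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            pass  # not installed: nothing to compare
    return found


def find_mismatch(versions):
    """Return a message if any jax-cuda* package differs from jaxlib, else None."""
    base = versions.get("jaxlib")
    if base is None:
        return None
    bad = {
        name: v
        for name, v in versions.items()
        if name.startswith("jax-cuda") and v != base
    }
    if not bad:
        return None
    details = ", ".join(f"{n}=={v}" for n, v in sorted(bad.items()))
    return f"jaxlib=={base} does not match CUDA plugin packages: {details}"


if __name__ == "__main__":
    message = find_mismatch(installed_versions())
    if message:
        raise SystemExit(message)
    print("jaxlib and CUDA plugin versions look consistent")
```

With the versions listed under System info, `find_mismatch` would flag jaxlib 0.4.28 against both 0.4.33 plugin packages instead of letting the first array allocation crash.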
System info (python version, jaxlib version, accelerator, etc.)
(import jax; jax.print_environment_info() crashes)
Linux
Python 3.10.12
jax 0.4.28
jaxlib 0.4.28
jax-cuda12-pjrt 0.4.33
jax-cuda12-plugin 0.4.33