CUDA12 plugin segfaults when older version of JAX is installed #24901

Open
dime10 opened this issue Nov 14, 2024 · 0 comments
Labels
bug Something isn't working

dime10 commented Nov 14, 2024

Description

Since version 0.4.32 of the jax-cuda12-pjrt and jax-cuda12-plugin packages, installing them alongside an older version of jax/jaxlib causes a segfault as soon as an array is created:

>>> import jax
>>> jax.numpy.array(0)
Segmentation fault (core dumped)

This happens even when a GPU is not used (e.g. JAX_PLATFORMS=cpu).
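Because the crash takes down the whole interpreter, one way to check an environment for this problem without crashing the calling process is to run the reproducer in a child interpreter and inspect how it exited. The following is a minimal sketch, assuming a POSIX system (where a process killed by a signal reports a negative return code); the helper name `probe` and its parameters are hypothetical, not part of JAX:

```python
import os
import signal
import subprocess
import sys

# Hypothetical default snippet: the reproducer from this issue.
PROBE = "import jax; jax.numpy.array(0)"


def probe(code=PROBE, platforms="cpu", timeout=120):
    """Run `code` in a fresh interpreter and classify how it ended.

    Returns "segfault" if the child was killed by SIGSEGV (POSIX only),
    "error" for any other non-zero exit (e.g. an uncaught exception),
    and "ok" for a clean exit.
    """
    # Copy the current environment and force the requested JAX platform.
    env = dict(os.environ, JAX_PLATFORMS=platforms)
    result = subprocess.run(
        [sys.executable, "-c", code],
        env=env,
        capture_output=True,
        text=True,
        timeout=timeout,
    )
    # On POSIX, subprocess reports "killed by signal N" as returncode -N.
    if result.returncode == -signal.SIGSEGV:
        return "segfault"
    if result.returncode != 0:
        return "error"
    return "ok"
```

In an environment like the one reported below, `probe()` would be expected to return `"segfault"` even though `JAX_PLATFORMS=cpu` is set, whereas the parent process keeps running and can report the problem.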

Previously (<=0.4.31), the plugin would raise an exception instead, warning of the version mismatch:

RuntimeError: Unable to initialize backend 'cuda': INVALID_ARGUMENT: Mismatched PJRT plugin PJRT API version (0.54) and framework PJRT API version 0.51). (you may need to uninstall the failing plugin package, or set JAX_PLATFORMS=cpu to skip this backend.)

Setting the platform to CPU (JAX_PLATFORMS=cpu) would then let you use JAX normally despite the mismatched CUDA plugin version.

You might ask why anyone would install mismatched versions of the jax and CUDA plugin packages. Because JAX doesn't manage the versions of these dependencies automatically, it can easily happen that JAX+CUDA is installed into an environment first, and that another package then requests a lower JAX version from pip. With a segfault instead of an error message, it is very difficult for the user to understand what went wrong.
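Until the plugin reports the mismatch again, a user could detect it without importing jax at all, by reading the installed package versions from pip's metadata. The following is a minimal sketch; the helper names are hypothetical, and the rule that the CUDA plugin packages should carry exactly the same version as jaxlib is an assumption made for illustration, not a rule taken from JAX itself:

```python
from importlib import metadata

# Plugin distributions to compare against jaxlib. Assumption: a mismatch
# means "version differs from jaxlib's version".
PLUGIN_PACKAGES = ("jax-cuda12-plugin", "jax-cuda12-pjrt")


def installed_versions(names):
    """Return {name: version} for the given distributions, skipping missing ones."""
    found = {}
    for name in names:
        try:
            found[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            pass  # not installed; nothing to compare
    return found


def find_version_mismatches(versions):
    """Return (package, version) pairs for plugins whose version differs from jaxlib."""
    base = versions.get("jaxlib")
    if base is None:
        return []
    return [
        (name, versions[name])
        for name in PLUGIN_PACKAGES
        if name in versions and versions[name] != base
    ]


if __name__ == "__main__":
    # Reads package metadata only; jax itself is never imported.
    versions = installed_versions(("jax", "jaxlib") + PLUGIN_PACKAGES)
    for name, version in find_version_mismatches(versions):
        print(f"warning: {name} {version} does not match jaxlib {versions['jaxlib']}")
```

With the package versions listed below, both plugin packages (0.4.33) would be flagged against jaxlib 0.4.28.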

System info (python version, jaxlib version, accelerator, etc.)

(import jax; jax.print_environment_info() crashes)

Linux
Python 3.10.12

jax 0.4.28
jaxlib 0.4.28
jax-cuda12-pjrt 0.4.33
jax-cuda12-plugin 0.4.33

dime10 added the bug label on Nov 14, 2024