cannot import name 'pocketfft' from 'jaxlib' #18892
-
Thanks for the question! jax 0.2.16 is not going to be compatible with jaxlib 0.4.x. This was before jax and jaxlib versions were kept in sync, but you can see that jax 0.2.16 was tested with jaxlib versions on or after 0.1.65: https://github.com/google/jax/blob/jax-v0.2.16/jax/version.py#L16

My number one suggestion would be to use a more recent jax version: a lot has changed since then, and I suspect it will be difficult to install such old versions on current systems (for example, jaxlib 0.1.65 will most definitely not be compatible with CUDA 12.2). But if for some reason you must use jax v0.2.16, use it with jaxlib v0.1.65.
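To see which pair you actually ended up with, a minimal sketch like the one below (assuming Python 3.8+ for `importlib.metadata`) reads the installed jax and jaxlib versions and compares them against the pairing recommended above. The `RECOMMENDED_JAXLIB` table and the helper names are hypothetical; the table records only the jax 0.2.16 / jaxlib 0.1.65 pairing from this answer, so the check returns `None` for any other jax version rather than guessing a rule.

```python
from importlib import metadata

# Hypothetical lookup table (illustration only): the single pairing stated
# in this thread. Other jax versions are deliberately not listed.
RECOMMENDED_JAXLIB = {"0.2.16": "0.1.65"}


def public_version(version: str) -> str:
    """Strip a local build tag such as '+cuda12.cudnn89' from a version."""
    return version.split("+", 1)[0]


def matches_recommendation(jax_version: str, jaxlib_version: str):
    """True/False if the jax version is in the table, None if it is not."""
    wanted = RECOMMENDED_JAXLIB.get(public_version(jax_version))
    if wanted is None:
        return None  # no pairing recorded for this jax version
    return public_version(jaxlib_version) == wanted


def installed_version(dist: str):
    """Installed distribution version, or None if the package is absent."""
    try:
        return metadata.version(dist)
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    jax_v, jaxlib_v = installed_version("jax"), installed_version("jaxlib")
    print(f"jax={jax_v} jaxlib={jaxlib_v}")
    if jax_v and jaxlib_v:
        print("matches recorded pairing:", matches_recommendation(jax_v, jaxlib_v))
```

For the install in the question (jax 0.2.16 with jaxlib 0.4.21+cuda12.cudnn89), the check reports `False`, which is the mismatch described above.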
-
I also encountered this issue. I am using:
pip install jaxlib==0.4.23
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
-
Hi all! I'm struggling to find a jaxlib version that supports importing 'pocketfft' from 'jaxlib'.
My device has CUDA 12.2. I successfully installed both jax and jaxlib:
pip install jax==0.2.16 jaxlib==0.4.21+cuda12.cudnn89 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
But when I run my script, which uses whisper-jax, it throws this error:
cannot import name 'pocketfft' from 'jaxlib'
The bad news for me is that, in the jax CUDA releases, aarch64 support starts with CUDA 12; there is no jaxlib build with CUDA 11 for aarch64 (see here).
I would like to know which jaxlib versions provide this:
from jaxlib import pocketfft
Maybe there are jaxlib versions that support
from jaxlib import pocketfft
only on x86_64 and not on aarch64. Does anyone have ideas on how to fix this? My device is an NVIDIA Jetson Orin Nano (arm64).
Thanks so much.
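To check on a given machine whether the installed jaxlib ships the module this import needs, a small probe like the following may help. It is a sketch: `has_submodule` is a hypothetical helper, not part of jax. It prints the CPU architecture and whether `jaxlib.pocketfft` can be located, without crashing when it is missing. Note that `importlib.util.find_spec` imports the parent package, so a jaxlib that is absent or fails to import is also reported as unavailable.

```python
import importlib.util
import platform


def has_submodule(name: str) -> bool:
    """True if a dotted module name such as 'jaxlib.pocketfft' can be located.

    find_spec imports the parent package (but not the submodule itself), so a
    missing or broken parent raises ImportError, treated here as unavailable.
    """
    try:
        return importlib.util.find_spec(name) is not None
    except ImportError:  # includes ModuleNotFoundError
        return False


if __name__ == "__main__":
    # 'aarch64' on a Jetson running Linux, 'x86_64' on most desktops.
    print("machine:", platform.machine())
    print("jaxlib.pocketfft available:", has_submodule("jaxlib.pocketfft"))
```

Running this against each candidate jaxlib install shows directly whether that build provides `jaxlib.pocketfft`, instead of finding out from the whisper-jax traceback.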