Replies: 1 comment
@tajtac Can you try installing the latest versions of jax, jaxlib, and flax using the cuda12-local option?

pip install -U "jax[cuda12-local]"
pip install -U flax

Can you also try nnx instead of linen? I assume that the segfault is due to incompatible CUDA versions and is not related to linen or nnx.
I am using jax 0.5.3, jaxlib 0.5.3, flax 0.10.6, and cuda/12.6.1.
Below is a minimal reproducible example. It always segfaults at params = model.init(rng, x), but only when using the GPU; if I run the script in CPU-only mode, no problem occurs.
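One way to force CPU-only mode, for anyone reproducing this, is the JAX_PLATFORMS environment variable. It must be set before jax initializes its backends, so either export it in the shell or set it at the very top of the script:

```python
import os
# Must be set before jax is imported/initialized, otherwise it has no effect
os.environ["JAX_PLATFORMS"] = "cpu"

import jax
print(jax.devices())  # only CpuDevice entries when the override took effect
```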
--------------------------
Minimal reproducible example
--------------------------
import jax
import jax.numpy as jnp
from flax import linen as nn

class SimpleSelfAttention(nn.Module):
    d_model: int = 128
    n_heads: int = 8

    @nn.compact
    def __call__(self, x):
        # Multi-head self-attention over the sequence dimension
        return nn.MultiHeadDotProductAttention(
            num_heads=self.n_heads, qkv_features=self.d_model)(x)

B = 1        # batch size
D = 16       # sequence length
d_model = 128

x = jnp.ones((B, D, d_model), dtype=jnp.float32)
model = SimpleSelfAttention(d_model=d_model, n_heads=8)
print('Created an instance of SimpleSelfAttention. Segfault occurs after this.')
rng = jax.random.PRNGKey(0)
params = model.init(rng, x)
print('Initialized the model')
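When chasing a version mismatch like this, it helps to print what jax actually loaded rather than what pip reports; jax ships a helper for exactly that:

```python
import jax
# Reports the jax and jaxlib versions, the Python build, and any detected
# GPUs; compare the CUDA version shown here against the locally installed
# toolkit (cuda/12.6.1 in this report) to spot a mismatch.
jax.print_environment_info()
```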
--------------------------
A couple of notes about the versions I am using, in case it matters:
--------------------------