-
If your input is a Python list, you have to use a for loop. To use jaxy things like `jax.vmap`, stack the modules and inputs into single pytrees first:

```python
import jax
import jax.numpy as jnp
from flax import nnx

linears = [nnx.Linear(8, 8, rngs=nnx.Rngs(i)) for i in range(10)]
inputs = [jax.random.normal(jax.random.key(i), (4, 8)) for i in range(10)]

# Stack the leaves of the 10 modules (and the 10 inputs) along a new
# leading axis, turning a list of pytrees into one pytree of stacked arrays.
linears = jax.tree.map(lambda *xs: jnp.stack(xs), *linears)
inputs = jax.tree.map(lambda *xs: jnp.stack(xs), *inputs)

@jax.vmap
def f(m, x):
    return m(x)

outs = f(linears, inputs)
```
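To see why stacking makes this vmappable, here is a dependency-free sketch of the same idea in plain NumPy. The weight matrices are hypothetical stand-ins for the parameters of ten separate `Linear(8, 8)` modules, and the batched `einsum` plays the role that `jax.vmap` plays above: one vectorized operation over the new leading "model" axis instead of a Python loop.

```python
import numpy as np

# Hypothetical stand-ins for 10 separate Linear(8, 8) modules:
# just their weight matrices, one per model.
weights = [np.random.default_rng(i).normal(size=(8, 8)) for i in range(10)]
inputs = [np.random.default_rng(100 + i).normal(size=(4, 8)) for i in range(10)]

# Loop version: one matmul per (model, input) pair.
loop_outs = [x @ w for w, x in zip(weights, inputs)]

# Stacked version: stack everything along a new leading axis, then do a
# single batched contraction over that axis -- the moral equivalent of
# what jax.vmap does for the stacked-module pattern.
W = np.stack(weights)  # shape (10, 8, 8)
X = np.stack(inputs)   # shape (10, 4, 8)
batched_outs = np.einsum('mni,mij->mnj', X, W)  # shape (10, 4, 8)
```

The two versions compute the same numbers; the stacked one simply exposes the per-model loop as an array axis that accelerators can parallelize over.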
-
I have two same-size lists, `models: List[nnx.Module]` and `inps: List[jax.Array]`. The `models` are separate instantiations of the same `nnx.Module` subclass. I calculate the outputs as:

```python
outs = [m(i) for m, i in zip(models, inps)]
```

How can I parallelize this without a for loop? From what I understand, `jax.vmap` and `jax.lax.map` aren't made to do this cleanly.