Hi, I want to create multiple instances of an `nnx.Module`, each initialized with a different key. With Linen this looks like:

```python
def make(rng):
    m = my_module.init(rng, dummy_input)
    return ...

rngs = jax.random.split(jax.random.PRNGKey(0), num=5)
models = jax.vmap(make)(rngs)
```

How can I achieve the same with NNX? I tried:

```python
def make_model(rngs):
    return nnx.Sequential(
        nnx.Linear(..., rngs=rngs),
        ...
    )

init_keys = jax.random.split(jax.random.PRNGKey(0), num=5)
rngs = nnx.Rngs(init_keys)
model = jax.vmap(task.make_model)(rngs)
```

But I get the error:
EDIT: Updating to use the new APIs.

Hey @JeyRunner! You can use `nnx.split_rngs` to automatically split the `Rngs` before going into `nnx.vmap`:

```python
from flax import nnx

@nnx.split_rngs(splits=5)
@nnx.vmap
def make_model(rngs):
    return nnx.Linear(2, 3, rngs=rngs)

model = make_model(nnx.Rngs(0))
print(model)
```

Output:

```
Linear(
  bias=Param(
    value=Array(shape=(5, 3), dtype=float32)
  ),
  bias_init=<function zeros at 0x11ee95f30>,
  dot_general=<function dot_general at 0x11e933910>,
  dtype=None,
  in_features=2,
  kernel=Param(
    value=Array(shape=(5, 2, 3), dtype=float32)
  ),
  kernel_init=<function variance_scaling.<locals>.init at 0x11fa8fe20>,
  out_features=3,
  param_dtype=<class 'jax.numpy.float32'>,
  precision=None,
  use_bias=True
)
```
Hi @cgarciae. In the example you gave above (`make_model(rngs)`, called with `nnx.Rngs(0)` and then `print(model)`), how would you do it if you wanted the dimensions of the `Linear` layer in `make_model` to be inputs?