System information
OS Platform and Distribution: Linux Ubuntu 22.04

Problem you have encountered:
Steps to reproduce:
I saw a similar discussion for
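Hey! I'm guessing you want to replicate the weights but have different RNGs. To do this, you can use the `nnx.split_rngs` decorator to split the RNGs before entering `pmap`, and use `StateAxes` to specify the parallelization axes for substates of your Module. In this case, map `RngState` to `0` and the rest (`...`) to `None`:

```python
import jax.numpy as jnp
from flax import nnx

# RNG state is split along axis 0; everything else (...) is replicated.
state_axes = nnx.StateAxes({nnx.RngState: 0, ...: None})

@nnx.split_rngs(splits=1)  # splits must match the size of the mapped axis (here 1)
@nnx.pmap(in_axes=(state_axes, 0))
def forward(model, x):
  return model(x)

out = forward(model, jnp.ones((1, 16, 2)))  # `model` is your nnx.Module
```

For more info, check out the Filters guide.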
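To make the idea concrete without depending on a particular Flax version, here is a minimal plain-JAX sketch of the same "replicated weights, per-device RNG" layout. It is an analogue, not what `nnx.pmap` does internally: `params` is broadcast with `in_axes=None` (like `...: None`), and one key per device is passed with `in_axes=0` (like `nnx.RngState: 0` after `split_rngs`). The function `dropout_forward`, the parameter shapes, and the dropout rate are hypothetical choices for illustration.

```python
import jax
import jax.numpy as jnp

# Hypothetical example: a linear layer followed by dropout (rate 0.5).
n_devices = jax.local_device_count()

def dropout_forward(params, key, x):
    # Same weights on every device; `key` differs per device.
    y = x @ params["w"] + params["b"]
    keep = jax.random.bernoulli(key, p=0.5, shape=y.shape)
    return jnp.where(keep, y / 0.5, 0.0)

# None: broadcast params to all devices; 0: map keys and inputs over axis 0.
forward = jax.pmap(dropout_forward, in_axes=(None, 0, 0))

params = {"w": jnp.ones((2, 3)), "b": jnp.zeros((3,))}
# One independent key per device, analogous to nnx.split_rngs(splits=n_devices).
keys = jax.random.split(jax.random.PRNGKey(0), n_devices)
x = jnp.ones((n_devices, 16, 2))
out = forward(params, keys, x)
print(out.shape)  # (n_devices, 16, 3), e.g. (1, 16, 3) on a single CPU
```

With more than one device, each slice `out[i]` uses the same weights but a different dropout mask, which is the behavior the `StateAxes` + `split_rngs` combination gives you in NNX.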
thank you, keep up the amazing work!