
nnx.vmap uses the same random key (Rngs) inside nnx.Module across vectorization #4195

Open
maxencefaldor opened this issue Sep 15, 2024 · 1 comment


@maxencefaldor

Hi,

If I run something like:

from flax import nnx

def inference_fn(model, x):
    y = model(x)  # call the model on the per-example input x
    return y

ys = nnx.vmap(inference_fn, in_axes=(None, 0))(model, xs)  # broadcast model, map xs over axis 0

the same random key is used for every element of the vectorized batch. This means stochastic functions (e.g. dropout) behave identically across the batch.
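A minimal plain-JAX sketch of the behavior described above (this is not NNX itself; noisy is a hypothetical stochastic function): broadcasting a single key through jax.vmap with in_axes=None gives every batch element the same noise, while mapping over keys from jax.random.split gives each element its own noise.

```python
import jax
import jax.numpy as jnp

def noisy(key, x):
    # Hypothetical stochastic function: adds Gaussian noise drawn from `key`.
    return x + jax.random.normal(key, x.shape)

key = jax.random.PRNGKey(0)
xs = jnp.zeros((4, 3))  # batch of 4 identical inputs

# Broadcast one key (in_axes=None): every batch element draws the same noise.
same = jax.vmap(noisy, in_axes=(None, 0))(key, xs)

# Map over split keys (in_axes=0): each element gets its own key and noise.
keys = jax.random.split(key, xs.shape[0])
diff = jax.vmap(noisy, in_axes=(0, 0))(keys, xs)

print(bool(jnp.all(same == same[0])))  # True: all rows identical
print(bool(jnp.all(diff == diff[0])))  # False: rows differ
```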

  • Is this the intended behavior for nnx.vmap when dealing with stochastic models?
  • If yes, is there a recommended way to ensure independent random keys are used for each batch element when using nnx.vmap?

Thanks!

@cgarciae
Collaborator

cgarciae commented Sep 16, 2024

Hey, in NNX the model's random state is just a regular type of state: Rngs internally holds RngKey and RngCount, which are subtypes of RngState (a Variable). Starting from flax>=0.9.0, NNX doesn't treat random state in a special way (see JAX-style NNX Transforms). To handle RNG state you can either split the rng keys passed to Rngs yourself or use the new nnx.split_rngs API (easier):

from flax import nnx

state_axes = nnx.StateAxes({nnx.RngState: 0, ...: None})  # vectorize RngState, broadcast the rest

@nnx.split_rngs(splits=<num_splits>)  # <num_splits> must match the batch size of xs
@nnx.vmap(in_axes=(state_axes, 0))
def inference_fn(model, x):
    y = model(x)
    return y

ys = inference_fn(model, xs)

split_rngs temporarily splits the RngState into one key per vectorized element for the duration of the call and restores it afterwards. Note that instead of passing None to broadcast the whole model state, you use StateAxes to vectorize the RngState on axis 0 while broadcasting all other state.
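As a rough, framework-free analogue of what split_rngs plus StateAxes achieve (plain JAX, not Flax's implementation), the idea is: split one key into one key per batch element, map those keys over axis 0 together with the inputs, and broadcast the parameters. The function dropout_apply, the params dict, and the shapes below are hypothetical and only for illustration.

```python
import jax
import jax.numpy as jnp

def dropout_apply(params, key, x, rate=0.5):
    # Hypothetical model: linear layer followed by dropout driven by `key`.
    y = x @ params["w"]
    keep = jax.random.bernoulli(key, 1.0 - rate, y.shape)
    return jnp.where(keep, y / (1.0 - rate), 0.0)

params = {"w": jnp.ones((3, 8))}
xs = jnp.ones((4, 3))  # batch of 4 identical inputs
root = jax.random.PRNGKey(42)

# Analogue of splitting the RNG state: one key per batch element.
keys = jax.random.split(root, xs.shape[0])

# Analogue of StateAxes({RngState: 0, ...: None}): keys mapped on axis 0,
# parameters broadcast (None), inputs mapped on axis 0.
ys = jax.vmap(dropout_apply, in_axes=(None, 0, 0))(params, keys, xs)
```

Because each row receives its own key, the dropout masks differ across the batch even though the inputs and parameters are identical. In this functional sketch the root key is never mutated, so nothing needs restoring; in NNX the key lives inside the model's state, which is why split_rngs has to put it back after the call.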
