nnx using vmap to create multiple models #4048

Answered by cgarciae
JeyRunner asked this question in Q&A
EDIT: Updating to use the new APIs.

Hey @JeyRunner! You can use nnx.split_rngs to automatically split the Rngs before they are passed into nnx.vmap, so each of the vmapped models is initialized from its own RNG stream.

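from flax import nnx

# split_rngs splits the Rngs into 5 streams; vmap then maps make_model
# over them, producing 5 Linear layers stacked along a leading axis.
@nnx.split_rngs(splits=5)
@nnx.vmap
def make_model(rngs):
  return nnx.Linear(2, 3, rngs=rngs)

model = make_model(nnx.Rngs(0))

print(model)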

Output:

Linear(
  bias=Param(
    value=Array(shape=(5, 3), dtype=float32)
  ),
  bias_init=<function zeros at 0x11ee95f30>,
  dot_general=<function dot_general at 0x11e933910>,
  dtype=None,
  in_features=2,
  kernel=Param(
    value=Array(shape=(5, 2, 3), dtype=float32)
  ),
  kernel_init=<function variance_scaling.<locals>.init at 0x11fa8fe20>,
  out_features=3,
  param_dtype=<class 'jax.numpy.float32'>,
  pre…

Answer selected by JeyRunner