
I created a sample that performs Model EMA with Flax NNX.

It is based on the MNIST sample of Flax NNX.

  1. Define a class derived from nnx.Optimizer (ModelEMA) whose update method writes the averaged values into the model's parameters.
  2. Clone the model and create an instance of the ModelEMA class (ema_optimizer) with optax.ema as the transformation.
  3. In the training loop, after train_step, update the EMA parameters by calling the update method of ema_optimizer inside ema_step, passing the freshly updated parameters (a sketch follows this list).

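Below is a minimal, self-contained sketch of the idea. The actual sample derives ModelEMA from nnx.Optimizer; here the EMA bookkeeping is written out explicitly with optax.ema so the mechanics are visible, and the names ModelEMA, decay, and ema_step are illustrative, not the sample's exact code:

```python
import optax
from flax import nnx


class ModelEMA:
  """Keeps an EMA copy of a model: the freshly trained parameters are fed
  through optax.ema and the averaged values overwrite a cloned model."""

  def __init__(self, model: nnx.Module, decay: float = 0.999, wrt=nnx.Param):
    self.model = nnx.clone(model)        # independent copy holding the EMA weights
    self.wrt = wrt                       # which variables to average (Params by default)
    self.tx = optax.ema(decay)           # EMA expressed as an optax transformation
    self.opt_state = self.tx.init(nnx.state(self.model, wrt))

  def update(self, trained_model: nnx.Module):
    # optax.ema averages whatever it receives as "updates", so we pass the
    # current parameter values of the trained model, not gradients.
    params = nnx.state(trained_model, self.wrt)
    ema_params, self.opt_state = self.tx.update(params, self.opt_state)
    # Overwrite the EMA model's variables with the averaged values
    # (instead of the additive optax.apply_updates used for normal training).
    nnx.update(self.model, ema_params)


# Usage sketch (a tiny stand-in model; in the sample this is the MNIST CNN):
model = nnx.Linear(4, 2, rngs=nnx.Rngs(0))
ema_optimizer = ModelEMA(model, decay=0.999)

# In the training loop, right after train_step has updated `model`:
ema_optimizer.update(model)          # ema_step: fold the new params into the EMA copy
ema_model = ema_optimizer.model      # evaluate with this copy
```

In the sample, deriving from nnx.Optimizer lets ema_step be wrapped in nnx.jit; this plain-Python version keeps the EMA state outside of NNX and is meant to be called un-jitted.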
Even in the case of a model that includes BatchNorm, this can be achieved by changing the wrt argument passed to the nnx.Optimizer constructor, and there is no ne…

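For the BatchNorm case, the wrt remark suggests widening the variable filter so the running statistics are averaged along with the parameters. With the sketch above that could look like the line below (treating a tuple of NNX filters as a union is an assumption about how the filter is passed; it only matters for models that actually contain nnx.BatchNorm layers):

```python
# Average BatchNorm running statistics together with the parameters.
ema_optimizer = ModelEMA(model, decay=0.999, wrt=(nnx.Param, nnx.BatchStat))
```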