What would be the best way to keep track of the EMA of your model parameters in NNX? #4528
-
Replies: 2 comments
-
I created a sample that performs Model EMA with Flax NNX. It is based on the MNIST example of Flax NNX.
Even for a model that includes BatchNorm, this can be achieved by changing the wrt argument passed to the nnx.Optimizer constructor; nothing special is needed in the training loop or in ema_step. If you want to change the EMA logic, you can define your own optax GradientTransformation (a sketch follows below).
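For illustration, here is a minimal sketch of such a custom transformation, assuming a fixed decay and no debiasing. The name simple_ema is hypothetical and the update rule simply reproduces plain EMA, so it is only a starting point for changing the logic:

```python
import jax
import jax.numpy as jnp
import optax


def simple_ema(decay: float) -> optax.GradientTransformation:
    """A plain EMA transformation: ema <- ema + (1 - decay) * (value - ema)."""

    def init_fn(params):
        # Start the EMA at the current values (no zero-init / debiasing here).
        return optax.EmaState(
            count=jnp.zeros([], jnp.int32),
            ema=jax.tree_util.tree_map(jnp.asarray, params),
        )

    def update_fn(updates, state, params=None):
        del params  # the "updates" passed in are the live parameter values
        new_ema = jax.tree_util.tree_map(
            lambda v, e: e + (1.0 - decay) * (v - e), updates, state.ema
        )
        return new_ema, optax.EmaState(count=state.count + 1, ema=new_ema)

    return optax.GradientTransformation(init_fn, update_fn)
```

Passing this (or optax.ema) as tx keeps the EMA bookkeeping inside the transformation, which is why the training loop itself does not need to change.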
-
Thanks for the very nice example.

```python
import optax
from flax import nnx


class ModelEMA(nnx.Optimizer):
    def __init__(
        self,
        model: nnx.Module,
        tx: optax.GradientTransformation,
    ):
        # Track both the parameters and the BatchNorm statistics.
        super().__init__(model, tx, wrt=[nnx.Param, nnx.BatchStat])

    def update(self, model: nnx.Module, model_original: nnx.Module):
        # `model` holds the EMA weights; `model_original` is the model being trained.
        params = nnx.state(model_original, self.wrt)
        ema_params = nnx.state(model, self.wrt)
        self.step.value += 1
        ema_state = optax.EmaState(count=self.step, ema=ema_params)
        _, new_ema_state = self.tx.update(params, ema_state)
        nnx.update(model, new_ema_state.ema)
```

An updated version of the notebook is available here: https://colab.research.google.com/gist/aurelio-amerio/afa5b4da0c3a2b881250e490c8688345/flax-nnx-model-ema.ipynb
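For completeness, a minimal usage sketch. The assumptions here are mine, not from the notebook: `Model` is a placeholder for the module you train, the EMA copy is created with `nnx.clone`, and the 0.999 decay is arbitrary.

```python
import optax
from flax import nnx

# `Model` is a placeholder for whatever nnx.Module you train
# (e.g. the CNN from the Flax NNX MNIST example); it is not defined here.
model = Model(rngs=nnx.Rngs(0))

# A separate copy of the model that will hold the EMA weights.
ema_model = nnx.clone(model)
ema = ModelEMA(ema_model, optax.ema(decay=0.999))

# In the training loop, after the regular optimizer step:
#   ema.update(ema_model, model)
# `ema_model` then holds the smoothed weights for evaluation or checkpointing.
```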