From dd822ef4c1c94ff9a50161ddfc5900c4ce1f92e5 Mon Sep 17 00:00:00 2001 From: Joan Puigcerver Date: Mon, 21 Aug 2023 03:04:48 -0700 Subject: [PATCH] Minor bug fix in train step. PiperOrigin-RevId: 558731486 --- vmoe/train/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vmoe/train/trainer.py b/vmoe/train/trainer.py index eb2375a..e16143a 100644 --- a/vmoe/train/trainer.py +++ b/vmoe/train/trainer.py @@ -642,7 +642,7 @@ def train_step( @functools.partial(jax.grad, has_aux=True) def compute_grads_and_metrics(params, images, labels, rngs): - rngs, next_rngs = utils.tree_rngs_split(state.rngs) + rngs, next_rngs = utils.tree_rngs_split(rngs) logits, metrics = state.apply_fn({'params': params}, images, rngs=rngs) metrics = dict(**metrics) metrics['main_loss'] = jnp.mean(loss_fn(logits, labels))