Skip to content

Commit

Permalink
Minor bug fix in train step.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 558731486
  • Loading branch information
jpuigcerver authored and copybara-github committed Aug 21, 2023
1 parent 760d672 commit dd822ef
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion vmoe/train/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,7 +642,7 @@ def train_step(

@functools.partial(jax.grad, has_aux=True)
def compute_grads_and_metrics(params, images, labels, rngs):
rngs, next_rngs = utils.tree_rngs_split(state.rngs)
rngs, next_rngs = utils.tree_rngs_split(rngs)
logits, metrics = state.apply_fn({'params': params}, images, rngs=rngs)
metrics = dict(**metrics)
metrics['main_loss'] = jnp.mean(loss_fn(logits, labels))
Expand Down

0 comments on commit dd822ef

Please sign in to comment.