diff --git a/lib/marin/src/marin/rl/rl_losses.py b/lib/marin/src/marin/rl/rl_losses.py index 18a7785745..292a551f80 100644 --- a/lib/marin/src/marin/rl/rl_losses.py +++ b/lib/marin/src/marin/rl/rl_losses.py @@ -235,9 +235,11 @@ def compute_dapo_loss( loss_objective: jax.Array, loss_masks: jax.Array, ) -> jax.Array: - """Compute DAPO-like loss (per-example normalization).""" - # Use per-example normalization (averaging the per-example means) - return -1 * jnp.mean(jnp.sum(loss_objective * loss_masks, axis=1) / jnp.sum(loss_masks, axis=1)) + """Compute DAPO-like loss (global token normalization). + + Divides by total tokens across all examples in the batch, not per-example. + """ + return -1 * jnp.mean(jnp.sum(loss_objective * loss_masks, axis=1) / jnp.sum(loss_masks)) def compute_grpo_loss(