
Conversation

@BabyChouSr (Collaborator) commented Nov 18, 2025

Description

  • In-flight weight updates
  • DAPO
      • Loss function
      • Zero-variance prompt filtering
      • Length penalty


Resolves multiple unit test failures introduced by recent API changes and fixes a threading/execution-context issue in the weight transfer tests. Changes:
* `tests/rl/test_weight_transfer.py`: Fix Arrow Flight tests by adding a shared `job_context` fixture so Server and Client actors run in the same execution context.
* `tests/rl/test_curriculum.py`: Update `CurriculumConfig` initialization to include `max_seq_len`; update `RolloutStats` initialization to include `temperature` and `top_k`.
* `tests/rl/test_replay_buffer.py`: Update `Rollout` initialization to include `top_k=None`; fix batch shape assertions.
* `tests/rl/test_train_batch.py`: Add missing `pad_to` argument to `convert_rollout_to_training_format` and `create_training_batch_from_rollouts`; update expected dictionary keys accordingly.
* `tests/rl/environments/*.py`: Update `DummyInferenceContext.batch_completions` signature to accept the `top_k` argument.
@@ -243,7 +324,7 @@ def rl_train(name: str, experiment_config: ExperimentConfig) -> ExecutorStep:
     mode=WeightTransferMode.ARROW_FLIGHT,
     sync_interval_steps=1,
     # We are running on-policy, so wait for new weights from the trainer after each episode.
-    max_weight_transfer_wait_time=120,
+    max_weight_transfer_wait_time=300,
Contributor: Could we make an rl_macro.py file or something explaining where all these magic numbers come from? Or just define the macros up top with a one- or two-sentence explainer.

Contributor: Good point, addressed with a refactoring.
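
For reference, a minimal sketch of what such a constants module could look like (the file name, constant names, and the stated rationales are assumptions for illustration; only the values come from the diff above):

# rl_constants.py (hypothetical module)
# Named constants for the RL experiment wiring, each with a short explanation of
# where the number comes from.

# We run on-policy, so the rollout worker syncs weights from the trainer after every step.
WEIGHT_SYNC_INTERVAL_STEPS = 1

# Maximum time (seconds) a rollout worker waits for fresh weights before proceeding.
# Set generously because a full checkpoint transfer over Arrow Flight can take minutes
# (this rationale is an assumption, not taken from the PR).
MAX_WEIGHT_TRANSFER_WAIT_TIME = 300

Call sites would then read max_weight_transfer_wait_time=MAX_WEIGHT_TRANSFER_WAIT_TIME, so the explanation lives next to the number.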


try:
    from vllm import LLM, SamplingParams
    from vllm.outputs import RequestOutput
Contributor: +1

@AlienKevin (Contributor) commented:

RL accuracy maintained after recent merges and refactorings

As expected, MATH-500 run started at 0.29 and quickly reached close to 0.5 (0.488) in 18 steps.

(Screenshot: 2026-01-10 at 9:21:46 AM)

actor wandb, trainer wandb

test commit
test command:

uv run lib/marin/src/marin/run/ray_run.py --env_vars WANDB_API_KEY ${WANDB_API_KEY} --env_vars WANDB_ENTITY marin-community --env_vars WANDB_PROJECT marin --env_vars TPU_CI true --env_vars HF_TOKEN ${HF_TOKEN} --cluster us-central1 --extra vllm,math --no_wait -- python experiments/exp2039_rl_math500.py --force_run_failed True

"""Parameters for sampling rollouts from an environment."""

temperature: float = 1.0
top_k: int | None = None
Contributor: Given that we had a top-k bug, do we want to add a warning or something here? I don't know when you'd ever want greedy decoding for RL training.

n_generations: Number of generations per example
temperature: Sampling temperature for generation
prng_key: JAX random key for sampling
mode: "train" or "eval" - which dataset to sample from
Contributor: Maybe the base env is a good place to add a warning.
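
A minimal sketch of the kind of check suggested in the two comments above (the class and method names are assumptions; only the temperature and top_k fields are taken from the snippet above):

import warnings

class BaseEnv:  # hypothetical base environment class
    def _validate_sampling_params(self, temperature: float, top_k: int | None) -> None:
        """Warn when the sampling configuration makes rollouts (nearly) deterministic."""
        if temperature == 0.0 or top_k == 1:
            warnings.warn(
                "Rollout sampling is effectively greedy (temperature=0 or top_k=1); "
                "RL training generally needs stochastic rollouts for exploration.",
                stacklevel=2,
            )

Putting the check in the base env means every environment gets it without re-implementing the guard.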


import equinox as eqx
import jax
import jax.numpy as jnp
Contributor (general comment): We should factorize this more. Instead of having commented-out lines in PPO, just make different methods for DAPO, PPO, GRPO, etc., so we can more easily test them against one another. Normally code reuse is good, but for RL it can bite us, so a bit of repetition is fine here.

Contributor: Should be addressed in #2327.
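
As a rough illustration of that factoring, a sketch with one function per variant and no commented-out lines (compute_dapo_loss is the function named later in this thread; compute_per_example_loss is a hypothetical name; the two formulas are the normalizations debated below, not the repo's final implementations):

import jax.numpy as jnp

def compute_dapo_loss(loss_objective, loss_masks):
    """DAPO-style loss: normalize by the total number of unmasked tokens in the batch."""
    return -1 * jnp.mean(jnp.sum(loss_objective * loss_masks, axis=1) / jnp.sum(loss_masks))

def compute_per_example_loss(loss_objective, loss_masks):
    """Per-example variant: each example is normalized by its own token count."""
    return -1 * jnp.mean(jnp.sum(loss_objective * loss_masks, axis=1) / jnp.sum(loss_masks, axis=1))

Testing DAPO against PPO/GRPO then just means swapping which function is called.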

# loss = -1 * jnp.mean(jnp.sum(loss_objective * loss_masks, axis=1) / max_output_tokens)

# more like DAPO loss
loss = -1 * jnp.mean(jnp.sum(loss_objective * loss_masks, axis=1) / jnp.sum(loss_masks))
Contributor: I think the PPO loss is normalized incorrectly.

We divide each example's objective by the total number of unmasked tokens across the whole batch, and then we also average across the batch. That effectively divides by batch size twice, so the loss/gradients get smaller as you increase batch size.

We should either normalize per-example by that example's token count and then average, or do a single global token-average over the batch.

Suggested change:
-loss = -1 * jnp.mean(jnp.sum(loss_objective * loss_masks, axis=1) / jnp.sum(loss_masks))
+loss = -1 * jnp.mean(jnp.sum(loss_objective * loss_masks, axis=1) / jnp.sum(loss_masks, axis=1))

Contributor: Thanks for spotting this! Could you check if 7d8f7fa addresses this issue?

Reverts changes to lib/fray/src/fray/job/context.py to match main.
Refactors downstream usage of .call() in rollout_worker.py, train_worker.py, and curriculum.py to use .remote() and get_default_job_ctx().get() instead.
@AlienKevin (Contributor) commented Jan 13, 2026

Recommended Review & Merge Order

To facilitate a smooth review and merge process, I recommend reviewing these sub-PRs in the following order (foundation first, then logic, then specific experiments/features):

  1. Upgrade tpu-inference in alignment with jax==0.8.0 (#2330): foundation & dependencies
  2. Refactor Inference Context & Fix vLLM TopK (#2329): inference layer updates
  3. Support inflight weight updates (#2325): core RL infrastructure
  4. RL Loss Improvements (#2327): loss functions & sampling warnings
  5. MATH-500 RL Environment and Experiment (#2326): main experiment environment
  6. Add GSM8K RL Environment (#2324): secondary environment
  7. Classification Processing (#2328): feature-specific processing
  8. Update MockEnv logic (#2334): environment utilities
  9. Remove old RL scripts and update architecture docs (#2333): final cleanup
  10. Ray Auth workaround to be able to submit jobs (#2335)

All changes from chris/exp-rl are now strictly accounted for across these 10 disjoint PRs. instruction_datasets.py is already merged in main and is thus excluded from the diffs.

@AlienKevin (Contributor) commented Jan 16, 2026

Rerunning MATH-500 after sub-PR merges:

actor wandb, trainer wandb
test commit: 9d3fc65

The above run got stuck after 60 steps because the rollout worker kept dying.

Rerunning again with some QoL updates:
actor wandb, trainer wandb
test commit: 2c56194

@AlienKevin (Contributor) commented Jan 17, 2026

MATH-500 RL Training Regression Analysis

Summary

A change to the DAPO loss normalization in rl_losses.py switched from global token normalization to per-example normalization. While both are valid approaches, the change caused a regression in MATH-500 training performance, likely because global normalization better suits tasks requiring detailed reasoning.

Context: Original Review Comment

This change was suggested in a PR review comment:

I think the PPO loss is normalized incorrectly. We divide each example's objective by the total number of unmasked tokens across the whole batch, and then we also average across the batch. That effectively divides by batch size twice, so the loss/gradients get smaller as you increase batch size.

The concern about "dividing by batch size twice" is valid mathematically, but under AdamW this doesn't matter — AdamW normalizes gradients per-parameter, so constant scaling factors cancel out. What does matter is the relative weighting between examples, which the change altered.

Evidence from WandB

Comparing runs:

| Metric | Before | After | Change |
| --- | --- | --- | --- |
| train/loss | -0.00049 | -0.040 | 82x larger |
| grad/norm/total | 0.003 | 0.407 | 130x larger |
| pass_at_1 (final) | 0.444 | 0.386 | -13% accuracy |
| train_correct_accuracy | 0.602 | 0.530 | -12% accuracy |

Root Cause

The change was in compute_dapo_loss():

# Per-example normalization (introduced regression)
loss = -1 * jnp.mean(jnp.sum(loss_objective * loss_masks, axis=1) / jnp.sum(loss_masks, axis=1))
#                                                                    ^^^^^^^^^^^^^^^^^^^^^^^^
#                                                                    Per-example token count

# Global normalization (original behavior)
loss = -1 * jnp.mean(jnp.sum(loss_objective * loss_masks, axis=1) / jnp.sum(loss_masks))
#                                                                    ^^^^^^^^^^^^^^^^^
#                                                                    Total batch tokens

Why This Matters

The two formulas represent different objectives:

| Normalization | Formula | What it rewards |
| --- | --- | --- |
| Global (/ N) | -mean(L_i / N) | Each token in the reasoning chain (process) |
| Per-example (/ n_i) | -mean(L_i / n_i) | Getting the right answer (outcome) |

Setup:

  • L_i = sum of loss_objective for example i
  • n_i = number of tokens in example i
  • N = total tokens across batch = Σn_i
  • B = batch size

Global normalization: loss = -mean(L_i / N)

∂loss/∂L_i = -1/(B × N)   # Same constant for all examples

→ Longer responses get proportionally more gradient signal

Per-example normalization: loss = -mean(L_i / n_i)

∂loss/∂L_i = -1/(B × n_i)   # Varies by example length

→ All responses contribute equally regardless of length

Concrete example:

| Example | Length (n_i) | Global weight | Per-example weight |
| --- | --- | --- | --- |
| A | 100 | 1/600 | 1/100 |
| B | 500 | 1/600 | 1/500 |

With global normalization (N=600), example B has 5x more gradient influence because it's 5x longer. With per-example normalization, both contribute equally.
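
A toy check of these weights (a sketch, not code from the repo; the per-token objective is set to 1 so the example weights are directly visible):

import jax
import jax.numpy as jnp

# Two examples of length 100 and 500, padded to 500 tokens.
loss_masks = jnp.stack([
    jnp.concatenate([jnp.ones(100), jnp.zeros(400)]),
    jnp.ones(500),
])
loss_objective = jnp.ones_like(loss_masks)  # pretend every unmasked token contributes 1

def global_norm_loss(obj, masks):
    return -1 * jnp.mean(jnp.sum(obj * masks, axis=1) / jnp.sum(masks))

def per_example_loss(obj, masks):
    return -1 * jnp.mean(jnp.sum(obj * masks, axis=1) / jnp.sum(masks, axis=1))

# Total gradient magnitude each example receives under the two normalizations:
g_global = jax.grad(lambda o: global_norm_loss(o, loss_masks))(loss_objective)
g_per_ex = jax.grad(lambda o: per_example_loss(o, loss_masks))(loss_objective)
print(jnp.sum(jnp.abs(g_global), axis=1))  # [1/12, 5/12]: example B gets 5x the weight
print(jnp.sum(jnp.abs(g_per_ex), axis=1))  # [0.5, 0.5]: both examples weighted equally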

Impact on Math Reasoning

For math reasoning, global normalization works better empirically because:

  • Correct solutions often require detailed step-by-step derivations → longer responses
  • We want to reinforce the entire reasoning process, not just the final answer
  • Per-example normalization treats a terse correct answer the same as a detailed proof

The WandB evidence shows this matters: -13% accuracy when switching to per-example normalization.

Fix

Restored original DAPO loss normalization in rl_losses.py:234-242:

def compute_dapo_loss(loss_objective, loss_masks):
    """Compute DAPO-like loss (global token normalization)."""
    return -1 * jnp.mean(jnp.sum(loss_objective * loss_masks, axis=1) / jnp.sum(loss_masks))

Why Global Normalization Works Better Here

The original reviewer's concern was that we "divide by batch size twice." Let's trace through:

loss = -1 * mean_i(L_i / N)          # N = total tokens across batch
     = -1 * (1/B) * Σ_i (L_i / N)    # B = batch size

This does produce smaller gradients with larger batches. However:

  1. AdamW cancels constant factors — The optimizer normalizes by running variance, so grad * c produces the same update as grad after a few steps.

  2. The key difference is example weighting — Global normalization weights all examples by the same factor (1/N). Per-example normalization weights each example by 1/n_i, giving shorter responses more influence.

  3. Batch size sensitivity is already handled — Learning rate schedules and hyperparameter tuning account for batch size effects. Changing the loss formula requires re-tuning hyperparameters.

Per-example normalization may be preferable for other tasks where response length shouldn't affect gradient weight.
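
To illustrate point 1 above (a constant scaling of the gradient washes out under Adam-style optimizers), a small optax check; this is a toy sketch, not code from the repo:

import jax.numpy as jnp
import optax

params = {"w": jnp.array([1.0, -2.0, 3.0])}
grads = {"w": jnp.array([0.1, -0.2, 0.3])}
scaled_grads = {"w": grads["w"] * 100.0}  # same gradient direction, scaled by a constant

opt = optax.adamw(learning_rate=1e-3, eps=1e-12)
state = opt.init(params)

update_a, _ = opt.update(grads, state, params)
update_b, _ = opt.update(scaled_grads, state, params)

# The two updates match up to eps effects: Adam divides the first moment by the
# square root of the second moment, so the constant factor cancels.
print(update_a["w"])
print(update_b["w"])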

Reference: How Tinker Handles This

For comparison, Tinker's importance_sampling loss uses sum reduction without explicit normalization:

loss = -(prob_ratio * advantages).sum()

From their documentation:

"If you would like to explore different aggregation schemes, you can include that in the advantage tensor computation."

Tinker delegates normalization to the advantage computation, which uses simple mean-centering within groups:

# tinker-cookbook/tinker_cookbook/rl/data_processing.py
def compute_advantages(trajectory_groups_P):
    for traj_group in trajectory_groups_P:
        rewards_G = torch.tensor(traj_group.get_total_rewards())
        advantages_G = rewards_G - rewards_G.mean()  # Center within group

This approach gives longer responses more gradient signal (similar to our global normalization), since the sum is not divided by sequence length.
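
A tiny JAX paraphrase of that behavior (toy numbers, not Tinker's code; it only shows how sum reduction scales signal with length, treating the per-token ratio as roughly 1):

import jax.numpy as jnp

# One prompt group with 4 rollouts; center rewards within the group as in the quoted snippet.
rewards = jnp.array([1.0, 0.0, 1.0, 0.0])
advantages = rewards - rewards.mean()               # [0.5, -0.5, 0.5, -0.5]

# With a sum-reduced per-token loss, a rollout's contribution scales with its length.
token_counts = jnp.array([100, 100, 500, 100])
per_rollout_signal = advantages * token_counts      # [50., -50., 250., -50.]
print(per_rollout_signal)                           # the 500-token rollout carries 5x the signal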

Verified to fix regression

Before the regression fix, pass@1 maxed out at 0.45 and was less stable (dipped to 0.35 at step 17):
(Screenshot: 2026-01-17 at 6:03:55 PM)

After applying the fix, which reverts the loss back to global token normalization, we reached 0.486 pass@1 on MATH-500 in 23 steps and stabilized around 0.47 for the next ~100 steps, closely matching the behavior observed before.

(Screenshot: 2026-01-17 at 5:58:46 PM)

actor wandb, trainer wandb
test commit: deb3ebd

AlienKevin added a commit that referenced this pull request Jan 19, 2026
…2376)

This PR fixes a regression in the DAPO loss computation by switching
from per-example normalization (/ n_i) back to global token
normalization (/ N). Per-example normalization gives shorter responses
disproportionately more gradient weight per token, which hurts math
reasoning tasks where correct answers often require detailed, longer
derivations. Global normalization weights every token equally, so longer,
more detailed responses contribute proportionally more gradient signal.

Check out
#2039 (comment)
for full context and experimental validation.

Co-authored-by: Claude Opus 4.5 <[email protected]>