[WIP] RL experimentation branch + synthetic data generation #2039
base: main
Conversation
Resolves multiple unit test failures introduced by recent API changes and fixes a threading/execution-context issue in the weight transfer tests.

Changes:

* `tests/rl/test_weight_transfer.py`: Fix Arrow Flight tests by adding a shared `job_context` fixture so Server and Client actors run in the same execution context (see the sketch below).
* `tests/rl/test_curriculum.py`: Update `CurriculumConfig` initialization to include `max_seq_len`; update `RolloutStats` initialization to include `temperature` and `top_k`.
* `tests/rl/test_replay_buffer.py`: Update `Rollout` initialization to include `top_k=None`; fix batch shape assertions.
* `tests/rl/test_train_batch.py`: Add the missing `pad_to` argument to `convert_rollout_to_training_format` and `create_training_batch_from_rollouts`; update expected dictionary keys accordingly.
* `tests/rl/environments/*.py`: Update the `DummyInferenceContext.batch_completions` signature to accept the `top_k` argument.
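The shared fixture could look roughly like the sketch below; the import path is inferred from lib/fray/src/fray/job/context.py mentioned later in this PR, and the fixture body is an assumption for illustration, not the actual test code.

```python
import pytest

# Assumption: get_default_job_ctx lives in fray.job.context (see the later
# note about lib/fray/src/fray/job/context.py); the body is illustrative only.
from fray.job.context import get_default_job_ctx


@pytest.fixture
def job_context():
    # Hand both the Flight server and client actors the same execution
    # context instead of letting each actor resolve its own.
    return get_default_job_ctx().get()
```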
experiments/exp2039_rl_math500.py
Outdated
```diff
@@ -243,7 +324,7 @@ def rl_train(name: str, experiment_config: ExperimentConfig) -> ExecutorStep:
             mode=WeightTransferMode.ARROW_FLIGHT,
             sync_interval_steps=1,
             # We are running on-policy, so wait for new weights from the trainer after each episode.
-            max_weight_transfer_wait_time=120,
+            max_weight_transfer_wait_time=300,
```
Could we make an `rl_macro.py` file or something explaining where all these magic numbers come from? Or just define the macros up top with a one- or two-sentence explainer.
Good point, addressed with a refactoring
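A hypothetical sketch of the kind of constants module being suggested; the names and values are taken from the diff above, but the file name, grouping, and comments are illustrative only.

```python
# experiments/rl_macros.py (hypothetical): central place for RL tuning constants.

# Push fresh weights from the trainer to rollout workers after every optimizer
# step, since the MATH-500 experiment runs on-policy.
WEIGHT_SYNC_INTERVAL_STEPS = 1

# Upper bound (seconds) on how long a rollout worker blocks waiting for the
# next weight transfer before giving up; raised from 120 to 300 in this PR.
MAX_WEIGHT_TRANSFER_WAIT_TIME_S = 300
```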
```python
try:
    from vllm import LLM, SamplingParams
    from vllm.outputs import RequestOutput
```
+1
… Llama 3.1 8B Instruct
RL accuracy maintained after recent merges and refactorings

As expected, the MATH-500 run started at 0.29 and quickly reached close to 0.5 (0.488) in 18 steps.
test commit
…und torchvision wheel issue astral-sh/uv#16386 (comment)
…env.py that indirectly depends on vllm to pass on CI/cpu
c8abda0 to e261bff
```python
    """Parameters for sampling rollouts from an environment."""

    temperature: float = 1.0
    top_k: int | None = None
```
Given that we had a top-k bug, do we want to set a warning or something here? I don't know when you'd ever want greedy decoding for RL training.
```python
        n_generations: Number of generations per example
        temperature: Sampling temperature for generation
        prng_key: JAX random key for sampling
        mode: "train" or "eval" - which dataset to sample from
```
Maybe the base env is a good place to add a warning.
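A minimal sketch of the kind of warning being discussed; the helper name and its placement (base env vs. the sampling-params dataclass shown above) are assumptions.

```python
import logging

logger = logging.getLogger(__name__)


def warn_if_greedy(temperature: float, top_k: int | None) -> None:
    """Hypothetical check: flag (near-)greedy sampling settings for RL rollouts."""
    # Greedy decoding collapses rollout diversity, which is rarely what you
    # want when training with RL, so warn rather than raise.
    if temperature == 0.0 or top_k == 1:
        logger.warning(
            "Sampling with temperature=%s, top_k=%s is (near-)greedy; "
            "RL rollouts usually need stochastic sampling.",
            temperature,
            top_k,
        )
```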
```python
import equinox as eqx
import jax
import jax.numpy as jnp
```
GENERAL comment:
We should factorize this more. Instead of having commented-out lines in PPO, just make different methods for DAPO, PPO, GRPO, etc. That way we can more easily test them against one another.
Normally code reuse is good, but for RL it can bite us, so a bit of repetition is fine here.
Should be addressed in #2327
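A rough sketch of the factoring being suggested, using the two normalizations that appear in the snippet below; the function names and signatures are illustrative, not the actual #2327 implementation.

```python
import jax.numpy as jnp


def dapo_loss(loss_objective, loss_masks):
    # Global token normalization (the active line in the snippet below).
    return -1 * jnp.mean(jnp.sum(loss_objective * loss_masks, axis=1) / jnp.sum(loss_masks))


def ppo_loss(loss_objective, loss_masks, max_output_tokens):
    # Fixed-length normalization (the commented-out line in the snippet below).
    return -1 * jnp.mean(jnp.sum(loss_objective * loss_masks, axis=1) / max_output_tokens)
```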
```python
# loss = -1 * jnp.mean(jnp.sum(loss_objective * loss_masks, axis=1) / max_output_tokens)

# more like DAPO loss
loss = -1 * jnp.mean(jnp.sum(loss_objective * loss_masks, axis=1) / jnp.sum(loss_masks))
```
I think the PPO loss is normalized incorrectly.
We're dividing each example's objective by the total number of unmasked tokens across the whole batch, and then we also average across the batch. That effectively divides by batch size twice, so the loss/gradients get smaller as you increase the batch size.
We should either normalize per-example by that example’s token count and then average, or do a single global token-average over the batch.
Suggested change:

```diff
-loss = -1 * jnp.mean(jnp.sum(loss_objective * loss_masks, axis=1) / jnp.sum(loss_masks))
+loss = -1 * jnp.mean(jnp.sum(loss_objective * loss_masks, axis=1) / jnp.sum(loss_masks, axis=1))
```
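To make the scaling difference concrete, here is a small illustrative JAX snippet (toy shapes and values, not from the codebase) comparing the current global-token normalization with the suggested per-example normalization:

```python
import jax.numpy as jnp

# Toy batch: two rollouts padded to length 5; 1s mark real (unmasked) tokens.
loss_objective = jnp.ones((2, 5))                   # stand-in per-token objective
loss_masks = jnp.array([[1, 0, 0, 0, 0],            # short response: 1 token
                        [1, 1, 1, 1, 1]],           # long response: 5 tokens
                       dtype=jnp.float32)

per_example_sums = jnp.sum(loss_objective * loss_masks, axis=1)   # [1., 5.]

# Global token normalization: every token is weighted by 1/N (N = 6 here),
# so the long rollout dominates the batch loss.
global_norm = -1 * jnp.mean(per_example_sums / jnp.sum(loss_masks))

# Per-example normalization: each rollout is weighted equally regardless of length.
per_example_norm = -1 * jnp.mean(per_example_sums / jnp.sum(loss_masks, axis=1))

print(global_norm, per_example_norm)   # -0.5 vs. -1.0 for this toy batch
```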
Thanks for spotting this! Could you check if 7d8f7fa addresses this issue?
Reverts changes to lib/fray/src/fray/job/context.py to match main. Refactors downstream usage of .call() in rollout_worker.py, train_worker.py, and curriculum.py to use .remote() and get_default_job_ctx().get() instead.
Recommended Review & Merge Order

To facilitate a smooth review and merge process, I recommend reviewing these sub-PRs in the following order (foundation first, then logic, then specific experiments/features):
All changes from
Rerunning MATH-500 after sub-PR merges: actor wandb, trainer wandb

The above run got stuck after 60 steps when the rollout worker kept dying. Rerunning again with some QoL updates:
MATH-500 RL Training Regression Analysis

Summary

A change to the DAPO loss normalization (from global token normalization to per-example normalization) caused an accuracy regression on the MATH-500 RL run.

Context: Original Review Comment

This change was suggested in a PR review comment:
The concern about "dividing by batch size twice" is valid mathematically, but under AdamW this doesn't matter: AdamW normalizes gradients per-parameter, so constant scaling factors cancel out. What does matter is the relative weighting between examples, which the change altered.

Evidence from WandB

Comparing runs:
Root Cause

The change was in the DAPO loss computation:

```python
# Per-example normalization (introduced regression):
# the divisor jnp.sum(loss_masks, axis=1) is each example's own token count.
loss = -1 * jnp.mean(jnp.sum(loss_objective * loss_masks, axis=1) / jnp.sum(loss_masks, axis=1))

# Global normalization (original behavior):
# the divisor jnp.sum(loss_masks) is the total number of unmasked tokens in the batch.
loss = -1 * jnp.mean(jnp.sum(loss_objective * loss_masks, axis=1) / jnp.sum(loss_masks))
```

Why This Matters

The two formulas represent different objectives:
Setup: a batch of B examples, where example i has n_i unmasked (response) tokens and N = Σ n_i is the total unmasked token count across the batch.
* Global normalization (divide by N) → longer responses get proportionally more gradient signal.
* Per-example normalization (divide by n_i) → all responses contribute equally regardless of length.

Concrete example:
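A worked sketch, assuming example A has 100 unmasked tokens and example B has 500 (values chosen to be consistent with N = 600 and the 5x figure quoted below):

```python
# Assumed token counts: n_A = 100 (short), n_B = 500 (long); N = 600 total.
n_A, n_B = 100, 500
N = n_A + n_B

# Global normalization: each example's gradient weight scales with its length.
w_A_global, w_B_global = n_A / N, n_B / N       # 1/6 vs. 5/6

# Per-example normalization: each example gets the same weight.
w_A_per, w_B_per = 0.5, 0.5

print(w_B_global / w_A_global)   # 5.0 -> example B has 5x the influence
```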
With global normalization (N = 600), example B has 5x more gradient influence because it is 5x longer. With per-example normalization, both contribute equally.

Impact on Math Reasoning

For math reasoning, global normalization works better empirically: correct answers often require detailed, longer derivations, and per-example normalization gives shorter responses disproportionately more gradient weight.
The WandB evidence shows this matters: -13% accuracy when switching to per-example normalization.

Fix

Restored the original DAPO loss normalization in `compute_dapo_loss`:

```python
def compute_dapo_loss(loss_objective, loss_masks):
    """Compute DAPO-like loss (global token normalization)."""
    return -1 * jnp.mean(jnp.sum(loss_objective * loss_masks, axis=1) / jnp.sum(loss_masks))
```

Why Global Normalization Works Better Here

The original reviewer's concern was that we "divide by batch size twice." Let's trace through:

```python
loss = -1 * mean(sum(L_i) / N)      # N = total tokens across batch
     = -1 * (1/B) * sum(L_i / N)    # B = batch size
```

This does produce smaller gradients with larger batches. However, as noted above, a constant scale on the loss is absorbed by AdamW's per-parameter normalization; what matters is the relative weighting between examples, and global normalization deliberately gives longer responses more weight.
Per-example normalization may be preferable for other tasks where response length shouldn't affect gradient weight.

Reference: How Tinker Handles This

For comparison, Tinker's importance_sampling loss uses sum reduction without explicit normalization:

```python
loss = -(prob_ratio * advantages).sum()
```
Per their documentation, Tinker delegates normalization to the advantage computation, which uses simple mean-centering within groups:

```python
# tinker-cookbook/tinker_cookbook/rl/data_processing.py
def compute_advantages(trajectory_groups_P):
    for traj_group in trajectory_groups_P:
        rewards_G = torch.tensor(traj_group.get_total_rewards())
        advantages_G = rewards_G - rewards_G.mean()  # Center within group
```

This approach gives longer responses more gradient signal (similar to our global normalization), since the sum is not divided by sequence length.

Verified to fix regression

Before the regression fix, pass@1 maxed out at 0.45 and was less stable (it dipped to 0.35 at step 17). After applying the fix, which reverts the loss back to global token normalization, we reached 0.486 pass@1 on MATH-500 in 23 steps, and it stabilized around 0.47 for the next ~100 steps, closely matching the behavior observed before the regression.
actor wandb, trainer wandb
…2376) This PR fixes a regression in the DAPO loss computation by switching from per-example normalization (/ n_i) back to global token normalization (/ N). Per-example normalization gives shorter responses disproportionately more gradient weight, which hurts math reasoning tasks where correct answers often require detailed, longer derivations. Global normalization weights all tokens equally regardless of which response they belong to, so longer responses receive proportionally more gradient signal. Check out #2039 (comment) for full context and experimental validation. Co-authored-by: Claude Opus 4.5 <[email protected]>
Description
In-flight weight updates
DAPO
Loss function
Zero-variance prompt filtering
Length penalty
Checklist
`uv run python infra/pre-commit.py --all-files` to lint/format your code