fix: fix checkpointing when val_period does not divide save_period #1229
        
Conversation
Signed-off-by: ashors1 <[email protected]>
          
**📝 Walkthrough**

Checkpointing behavior is updated across DPO, GRPO, and SFT training to stop mutating the configured checkpoint metric when it is missing and to simplify the warning messages. Checkpoint sorting now safely handles missing metrics via dictionary `get` with default sentinel values, removing `KeyError` handling and fallback paths.
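As a rough illustration of the sorting behavior described above, the following is a minimal sketch (the function name and data shapes are hypothetical, not taken from `checkpoint.py`): checkpoints missing the metric receive a sentinel via `dict.get`, so they sort last instead of raising `KeyError`.

```python
import math

def rank_checkpoints(infos, metric_name, higher_is_better=False):
    """Return checkpoint infos ordered best-first by metric_name.

    `infos` is a list of dicts like {"step": 3, "loss": 0.2}; entries
    without the metric get a sentinel default and sort last.
    """
    # Sentinel makes missing metrics rank worst in either direction.
    sentinel = -math.inf if higher_is_better else math.inf
    return sorted(
        infos,
        key=lambda info: info.get(metric_name, sentinel),
        reverse=higher_is_better,
    )

infos = [{"step": 1, "loss": 0.5}, {"step": 2}, {"step": 3, "loss": 0.2}]
ranked = rank_checkpoints(infos, "loss")
# best-first: step 3 (0.2), step 1 (0.5), step 2 (missing metric)
```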
**Sequence Diagram(s)**

```mermaid
sequenceDiagram
  autonumber
  participant Trainer
  participant Algo as DPO/GRPO/SFT
  participant Ckpt as CheckpointManager
  rect rgba(230,240,255,0.5)
  note over Trainer,Algo: New flow on missing metric
  Trainer->>Algo: train_step()
  Algo->>Ckpt: save_checkpoint(save_state, metric_name)
  alt metric_name set AND metric missing
    Ckpt-->>Algo: compute rank using get(metric, ±inf)
    Algo-->>Trainer: Warn "This checkpoint will not be saved as top-k."
    note right of Algo: metric_name is NOT mutated
  else metric present or metric_name None
    Ckpt-->>Algo: sort normally (by metric or step)
  end
  end
```

```mermaid
sequenceDiagram
  autonumber
  participant Trainer
  participant Algo as DPO/GRPO/SFT
  participant Ckpt as CheckpointManager
  rect rgba(255,240,230,0.5)
  note over Trainer,Algo: Previous flow on missing metric
  Trainer->>Algo: train_step()
  Algo->>Ckpt: save_checkpoint(save_state, metric_name)
  Ckpt--x Algo: KeyError during metric sort
  Algo-->>Trainer: Warn about fallback
  note right of Algo: metric_name set to None (mutated)
  Ckpt-->>Algo: fallback to recent-k by step
  end
```
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️  Outside diff range comments (1)
`nemo_rl/algorithms/grpo.py` (1)

**1669-1673: Inconsistent warning message in `async_grpo_train`.**

The warning message at line 1671 still states "Saving most recent k checkpoints instead," which differs from the updated message at line 877 in the synchronous `grpo_train` function ("This checkpoint will not be saved as top-k."). Additionally, line 1673 still mutates `master_config["checkpointing"]["metric_name"]` by setting it to `None`, which was removed in the synchronous version.

For consistency with the rest of this PR, update the async path to match the synchronous path:
- Change the warning message to match line 877
- Remove the `metric_name` mutation at line 1673

Apply this diff:

```diff
 warnings.warn(
     f"You asked to save checkpoints based on {master_config['checkpointing']['metric_name']} but the metric is not found in the save state. "
-    "Saving most recent k checkpoints instead."
+    "This checkpoint will not be saved as top-k."
 )
-master_config["checkpointing"]["metric_name"] = None
```
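Independent of the exact diff, the shared pattern both code paths should follow can be sketched as below (the helper name is hypothetical; the warning text mirrors the one quoted in this review):

```python
import warnings

def maybe_warn_missing_metric(save_state, metric_name):
    """Warn when the configured metric is absent from the save state.

    Crucially, the configured metric_name is left untouched, so later
    checkpoints that do have the metric are still ranked by it.
    """
    if metric_name is not None and metric_name not in save_state:
        warnings.warn(
            f"You asked to save checkpoints based on {metric_name} "
            "but the metric is not found in the save state. "
            "This checkpoint will not be saved as top-k."
        )
    # note: no `metric_name = None` mutation here
```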
📒 Files selected for processing (4)
- `nemo_rl/algorithms/dpo.py` (1 hunks)
- `nemo_rl/algorithms/grpo.py` (1 hunks)
- `nemo_rl/algorithms/sft.py` (1 hunks)
- `nemo_rl/utils/checkpoint.py` (1 hunks)
🔇 Additional comments (3)

`nemo_rl/algorithms/grpo.py` (1)

**875-878: LGTM: Warning message updated and metric_name mutation removed.**

The change correctly updates the warning message and removes the mutation of `metric_name`, which aligns with the updated checkpoint sorting logic in `checkpoint.py` that now handles missing metrics using safe `.get()` calls with sentinel defaults.

`nemo_rl/algorithms/dpo.py` (1)

**629-637: LGTM: Consistent warning and no metric_name mutation.**

The change correctly updates the warning message to match the other algorithm files and removes the mutation of `metric_name`. This is consistent with the updated checkpoint sorting logic that handles missing metrics safely.

`nemo_rl/algorithms/sft.py` (1)

**502-510: LGTM: Consistent warning and no metric_name mutation.**

The change correctly updates the warning message to match the other algorithm files and removes the mutation of `metric_name`. This is consistent with the updated checkpoint sorting logic that handles missing metrics safely.
Signed-off-by: ashors1 <[email protected]>
Force-pushed: `ef13231` → `810de60` (Compare)
@ashors1 can you add a test that tests the possibility that some steps don't have val metrics?
RL/tests/unit/utils/test_checkpoint.py
Lines 85 to 141 in 8003918
```python
def test_remove_old_checkpoints(checkpoint_manager, checkpoint_dir):
    # Create multiple checkpoints with different loss values
    steps = [1, 2, 3, 4, 5, 6]
    losses = [0.5, 0.3, 0.7, 0.2, 0.4, 0.8]
    for step, loss in zip(steps, losses):
        training_info = {"loss": loss}
        tmp_dir = checkpoint_manager.init_tmp_checkpoint(step, training_info)
        checkpoint_manager.finalize_checkpoint(tmp_dir)
    # Check if only top-k checkpoints are kept
    remaining_dirs = list(checkpoint_dir.glob("step_*"))
    assert (
        len(remaining_dirs) == checkpoint_manager.keep_top_k + 1
    )  # +1 because we exclude the latest
    # Verify the remaining checkpoints are the ones with lowest loss
    remaining_losses = []
    for dir_path in remaining_dirs:
        with open(dir_path / "training_info.json", "r") as f:
            metadata = json.load(f)
        remaining_losses.append(metadata["loss"])
    assert sorted(remaining_losses) == sorted(losses)[
        : checkpoint_manager.keep_top_k
    ] + [0.8]  # exclude latest


def test_remove_old_checkpoints_topk_bias_recent_if_equal(
    checkpoint_manager, checkpoint_dir
):
    # Create multiple checkpoints with the same loss value
    steps = [1, 2, 3, 4, 10, 12]
    losses = [0.5, 0.5, 0.5, 0.5, 0.5, 0.5]  # All checkpoints have the same loss
    for step, loss in zip(steps, losses):
        training_info = {"loss": loss}
        tmp_dir = checkpoint_manager.init_tmp_checkpoint(step, training_info)
        checkpoint_manager.finalize_checkpoint(tmp_dir)
    # Check if only top-k checkpoints are kept
    remaining_dirs = list(checkpoint_dir.glob("step_*"))
    assert (
        len(remaining_dirs) == checkpoint_manager.keep_top_k
    )  # +1 because we exclude the latest
    # When all losses are equal, the most recent checkpoints should be kept
    # (excluding the latest which is always kept)
    remaining_steps = []
    for dir_path in remaining_dirs:
        step_num = int(dir_path.name.split("_")[1])
        remaining_steps.append(step_num)
    # Should keep the most recent checkpoints (highest step numbers)
    expected_steps = sorted(steps)[-checkpoint_manager.keep_top_k :]
    assert sorted(remaining_steps) == sorted(expected_steps)
```
What happens if save_period and val_period are not divisible? Does that mean some checkpoints will get inf or -inf as their metric and be pruned out? Correct me if I've read that wrong.

I think regarding the top_k to save, we should have a test that guards that the latest checkpoint is kept even if it doesn't have a val metric.
          
 Will do 
If we have some checkpoints which don't have an associated val metric, all of those checkpoints should get the same default metric value (±inf), so they tie with each other rather than being pruned ahead of checkpoints that do have metrics.
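A small self-contained sketch of the point above, using hypothetical data (not the repository's fixtures): checkpoints without a val metric all share the same sentinel, so among themselves they tie, and a recency tie-break decides their order.

```python
import math

# Checkpoints from a run where val_period did not hit every save step.
infos = [
    {"step": 1, "val_loss": 0.4},
    {"step": 2},              # no val metric at this step
    {"step": 3, "val_loss": 0.3},
    {"step": 4},              # no val metric at this step
]

# Sort best-first by val_loss; missing metric -> +inf sentinel.
# The -step term breaks ties in favor of the newer checkpoint.
ranked = sorted(infos, key=lambda i: (i.get("val_loss", math.inf), -i["step"]))
# -> steps 3, 1, then 4 before 2 (both missing; newer wins)
```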
    
Signed-off-by: ashors1 <[email protected]>
What does this PR do?

Previously, if `save_period % val_period != 0`, we would simply save the most recent `k` checkpoints. With this change, if fewer than `k` checkpoints have validation metrics (say, `m`), we fall back to saving the most recent `k - m` checkpoints in addition to the `m` checkpoints with validation metrics.
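The selection rule described here can be sketched in a few lines (function name and data shapes are hypothetical; the real logic lives in `nemo_rl/utils/checkpoint.py`):

```python
def select_top_k(infos, k, metric="val_loss"):
    """Keep up to k checkpoints: best-by-metric first, recency as filler.

    The m checkpoints that have the metric are ranked best-first (lower
    is better here); the remaining k - m slots go to the most recent of
    the rest.
    """
    with_metric = sorted(
        (i for i in infos if metric in i), key=lambda i: i[metric]
    )[:k]
    kept_steps = {i["step"] for i in with_metric}
    rest = sorted(
        (i for i in infos if i["step"] not in kept_steps),
        key=lambda i: i["step"],
        reverse=True,  # most recent first
    )
    return with_metric + rest[: k - len(with_metric)]

infos = [
    {"step": 1, "val_loss": 0.5},
    {"step": 2},
    {"step": 3, "val_loss": 0.2},
    {"step": 4},
    {"step": 5},
]
kept = select_top_k(infos, k=4)
# steps 1 and 3 (have metrics) plus steps 5 and 4 (most recent without)
```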
Issues
closes #1214