diff --git a/nemo_rl/algorithms/dpo.py b/nemo_rl/algorithms/dpo.py
index 6beb91cfac..4b961251e5 100644
--- a/nemo_rl/algorithms/dpo.py
+++ b/nemo_rl/algorithms/dpo.py
@@ -633,9 +633,8 @@ def dpo_train(
             ):
                 warnings.warn(
                     f"You asked to save checkpoints based on {master_config['checkpointing']['metric_name']} but the metric is not found in the save state. "
-                    "Saving most recent k checkpoints instead."
+                    "This checkpoint will not be saved as top-k."
                 )
-                master_config["checkpointing"]["metric_name"] = None
 
             with timer.time("checkpointing"):
                 print(f"Saving checkpoint for step {total_steps + 1}...")
diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py
index 190f3c2921..2ef60a73e0 100644
--- a/nemo_rl/algorithms/grpo.py
+++ b/nemo_rl/algorithms/grpo.py
@@ -874,9 +874,8 @@ def grpo_train(
             ):
                 warnings.warn(
                     f"You asked to save checkpoints based on {master_config['checkpointing']['metric_name']} but the metric is not found in the save state. "
-                    "Saving most recent k checkpoints instead."
+                    "This checkpoint will not be saved as top-k."
                 )
-                master_config["checkpointing"]["metric_name"] = None
 
             with timer.time("checkpointing"):
                 print(
diff --git a/nemo_rl/algorithms/sft.py b/nemo_rl/algorithms/sft.py
index 39e593e304..2b87e606d0 100644
--- a/nemo_rl/algorithms/sft.py
+++ b/nemo_rl/algorithms/sft.py
@@ -506,9 +506,8 @@ def sft_train(
             ):
                 warnings.warn(
                     f"You asked to save checkpoints based on {master_config['checkpointing']['metric_name']} but the metric is not found in the save state. "
-                    "Saving most recent k checkpoints instead."
+                    "This checkpoint will not be saved as top-k."
                 )
-                master_config["checkpointing"]["metric_name"] = None
 
             with timer.time("checkpointing"):
                 print(f"Saving checkpoint for step {total_steps + 1}...")
diff --git a/nemo_rl/utils/checkpoint.py b/nemo_rl/utils/checkpoint.py
index 94d83821d8..ca2bab3940 100644
--- a/nemo_rl/utils/checkpoint.py
+++ b/nemo_rl/utils/checkpoint.py
@@ -202,25 +202,18 @@ def remove_old_checkpoints(self, exclude_latest: bool = True) -> None:
         if self.metric_name is None:
             checkpoint_history.sort(key=lambda x: x[0], reverse=True)
         else:
-            try:
-                # sort by metric value first, then by step number (for equal metrics, prefer more recent)
-                if self.higher_is_better:
-                    # For higher_is_better=True: higher metric values first, then higher step numbers
-                    checkpoint_history.sort(
-                        key=lambda x: (x[2][self.metric_name], x[0]), reverse=True
-                    )
-                else:
-                    # For higher_is_better=False: lower metric values first, then higher step numbers for equal values
-                    checkpoint_history.sort(
-                        key=lambda x: (x[2][self.metric_name], -x[0])
-                    )
-            except KeyError:
-                warnings.warn(
-                    f"Metric {self.metric_name} not found in checkpoint history. Keeping most recent k checkpoints."
+            # sort by metric value first, then by step number (for equal metrics, prefer more recent)
+            if self.higher_is_better:
+                # For higher_is_better=True: higher metric values first, then higher step numbers
+                checkpoint_history.sort(
+                    key=lambda x: (x[2].get(self.metric_name, -float("inf")), x[0]),
+                    reverse=True,
+                )
+            else:
+                # For higher_is_better=False: lower metric values first, then higher step numbers for equal values
+                checkpoint_history.sort(
+                    key=lambda x: (x[2].get(self.metric_name, float("inf")), -x[0])
                 )
-                checkpoint_history.sort(key=lambda x: x[0], reverse=True)
-
-            self.metric_name = None
 
         # remove checkpoints that are not in the top-k
         for checkpoint in checkpoint_history[self.keep_top_k :]:
diff --git a/tests/unit/utils/test_checkpoint.py b/tests/unit/utils/test_checkpoint.py
index a698daf853..002524cc71 100644
--- a/tests/unit/utils/test_checkpoint.py
+++ b/tests/unit/utils/test_checkpoint.py
@@ -141,6 +141,85 @@ def test_remove_old_checkpoints_topk_bias_recent_if_equal(
     assert sorted(remaining_steps) == sorted(expected_steps)
 
 
+def test_remove_old_checkpoints_topk_some_missing_val_metric(
+    checkpoint_manager, checkpoint_dir
+):
+    # Create checkpoints where some have validation metrics and others don't
+    steps = [1, 2, 3, 4, 10, 11, 12]
+    # Some checkpoints have loss metrics, others don't have any validation metrics
+    training_infos = [
+        {"loss": 0.5},  # step 1 - has loss
+        {"loss": 0.3},  # step 2 - has loss
+        {"other_metric": 0.8},  # step 3 - missing loss metric
+        {"loss": 0.2},  # step 4 - has loss
+        {},  # step 10 - missing loss metric
+        {"loss": 1.0},  # step 11 - has loss but not in top-k
+        {},  # step 12 - missing loss (latest)
+    ]
+
+    for step, training_info in zip(steps, training_infos):
+        tmp_dir = checkpoint_manager.init_tmp_checkpoint(step, training_info)
+        checkpoint_manager.finalize_checkpoint(tmp_dir)
+
+    # Check if only top-k checkpoints are kept
+    remaining_dirs = list(checkpoint_dir.glob("step_*"))
+    assert (
+        len(remaining_dirs) == checkpoint_manager.keep_top_k + 1
+    )  # +1 because we exclude the latest
+
+    # Checkpoints with missing validation metrics should be treated as having the worst possible value
+    # Since higher_is_better=False, missing metrics get float("inf") which is worst
+    # So checkpoints with actual loss values should be preferred over those without
+    remaining_steps = []
+    for dir_path in remaining_dirs:
+        step_num = int(dir_path.name.split("_")[1])
+        remaining_steps.append(step_num)
+
+    # Should keep the checkpoints with the best loss values (steps 1, 2, 4)
+    # and exclude those without loss metrics (steps 3, 10)
+    # The latest checkpoint (step 12) is always kept
+    expected_steps = [1, 2, 4, 12]  # Top-k steps by loss, plus latest
+    assert sorted(remaining_steps) == sorted(expected_steps)
+
+
+def test_remove_old_checkpoints_topk_most_missing_val_metric(
+    checkpoint_manager, checkpoint_dir
+):
+    # Create checkpoints where only one has a validation metric
+    steps = [1, 2, 3, 4, 10, 12]
+    # Only step 1 has a loss metric; the rest have no validation metrics
+    training_infos = [
+        {"loss": 0.2},  # step 1 - has loss
+        {},  # step 2 - missing loss metric
+        {"other_metric": 0.8},  # step 3 - missing loss metric
+        {},  # step 4 - missing loss metric
+        {},  # step 10 - missing loss metric
+        {},  # step 12 - missing loss (latest)
+    ]
+
+    for step, training_info in zip(steps, training_infos):
+        tmp_dir = checkpoint_manager.init_tmp_checkpoint(step, training_info)
+        checkpoint_manager.finalize_checkpoint(tmp_dir)
+
+    # Check if only top-k checkpoints are kept
+    remaining_dirs = list(checkpoint_dir.glob("step_*"))
+    assert len(remaining_dirs) == checkpoint_manager.keep_top_k
+
+    # Checkpoints with missing validation metrics should be treated as having the worst possible value
+    # Since higher_is_better=False, missing metrics get float("inf") which is worst
+    # So checkpoints with actual loss values should be preferred over those without
+    remaining_steps = []
+    for dir_path in remaining_dirs:
+        step_num = int(dir_path.name.split("_")[1])
+        remaining_steps.append(step_num)
+
+    # Should keep the checkpoint with an actual loss value (step 1)
+    # followed by the most recent steps
+    # The latest checkpoint (step 12) is always kept
+    expected_steps = [1, 10, 12]  # Step with a loss metric, then the most recent steps
+    assert sorted(remaining_steps) == sorted(expected_steps)
+
+
 def test_get_best_checkpoint_path(checkpoint_manager, checkpoint_dir):
     # Create multiple checkpoints with different loss values
     steps = [1, 2, 3]