nemo_rl/algorithms/dpo.py (3 changes: 1 addition & 2 deletions)

@@ -633,9 +633,8 @@ def dpo_train(
         ):
             warnings.warn(
                 f"You asked to save checkpoints based on {master_config['checkpointing']['metric_name']} but the metric is not found in the save state. "
-                "Saving most recent k checkpoints instead."
+                "This checkpoint will not be saved as top-k."
             )
-            master_config["checkpointing"]["metric_name"] = None
 
         with timer.time("checkpointing"):
             print(f"Saving checkpoint for step {total_steps + 1}...")
nemo_rl/algorithms/grpo.py (3 changes: 1 addition & 2 deletions)

@@ -874,9 +874,8 @@ def grpo_train(
         ):
             warnings.warn(
                 f"You asked to save checkpoints based on {master_config['checkpointing']['metric_name']} but the metric is not found in the save state. "
-                "Saving most recent k checkpoints instead."
+                "This checkpoint will not be saved as top-k."
             )
-            master_config["checkpointing"]["metric_name"] = None
 
         with timer.time("checkpointing"):
             print(
nemo_rl/algorithms/sft.py (3 changes: 1 addition & 2 deletions)

@@ -506,9 +506,8 @@ def sft_train(
         ):
             warnings.warn(
                 f"You asked to save checkpoints based on {master_config['checkpointing']['metric_name']} but the metric is not found in the save state. "
-                "Saving most recent k checkpoints instead."
+                "This checkpoint will not be saved as top-k."
             )
-            master_config["checkpointing"]["metric_name"] = None
 
         with timer.time("checkpointing"):
             print(f"Saving checkpoint for step {total_steps + 1}...")
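Across all three training loops the change is the same: the warning text now states only that the current checkpoint will not be ranked top-k, and master_config["checkpointing"]["metric_name"] is no longer reset to None, so later checkpoints that do report the metric can still be ranked. A minimal sketch of that control flow, using hypothetical names (maybe_warn_missing_metric, save_state) rather than the actual NeMo-RL call sites:

import warnings

def maybe_warn_missing_metric(save_state: dict, metric_name: str | None) -> None:
    # Hypothetical helper mirroring the new behavior in dpo/grpo/sft_train.
    if metric_name is not None and metric_name not in save_state:
        warnings.warn(
            f"You asked to save checkpoints based on {metric_name} but the metric is not "
            "found in the save state. This checkpoint will not be saved as top-k."
        )
        # metric_name is intentionally NOT cleared here (the old code set it to None),
        # so future checkpoints that contain the metric are still ranked top-k.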
nemo_rl/utils/checkpoint.py (29 changes: 11 additions & 18 deletions)

@@ -202,25 +202,18 @@ def remove_old_checkpoints(self, exclude_latest: bool = True) -> None:
         if self.metric_name is None:
             checkpoint_history.sort(key=lambda x: x[0], reverse=True)
         else:
-            try:
-                # sort by metric value first, then by step number (for equal metrics, prefer more recent)
-                if self.higher_is_better:
-                    # For higher_is_better=True: higher metric values first, then higher step numbers
-                    checkpoint_history.sort(
-                        key=lambda x: (x[2][self.metric_name], x[0]), reverse=True
-                    )
-                else:
-                    # For higher_is_better=False: lower metric values first, then higher step numbers for equal values
-                    checkpoint_history.sort(
-                        key=lambda x: (x[2][self.metric_name], -x[0])
-                    )
-            except KeyError:
-                warnings.warn(
-                    f"Metric {self.metric_name} not found in checkpoint history. Keeping most recent k checkpoints."
-                )
-                checkpoint_history.sort(key=lambda x: x[0], reverse=True)
-                self.metric_name = None
+            # sort by metric value first, then by step number (for equal metrics, prefer more recent)
+            if self.higher_is_better:
+                # For higher_is_better=True: higher metric values first, then higher step numbers
+                checkpoint_history.sort(
+                    key=lambda x: (x[2].get(self.metric_name, -float("inf")), x[0]),
+                    reverse=True,
+                )
+            else:
+                # For higher_is_better=False: lower metric values first, then higher step numbers for equal values
+                checkpoint_history.sort(
+                    key=lambda x: (x[2].get(self.metric_name, float("inf")), -x[0])
+                )
 
         # remove checkpoints that are not in the top-k
         for checkpoint in checkpoint_history[self.keep_top_k :]:
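The core of the fix is the sort key: a missing metric no longer raises KeyError but falls back to the worst possible value, so metric-less checkpoints lose to any checkpoint that reports the metric, with recency breaking ties among them. A standalone sketch of that ordering, assuming history entries of the form (step, path, training_info) as the lambdas above imply:

history = [
    (1, "step_1", {"loss": 0.5}),
    (2, "step_2", {"loss": 0.3}),
    (3, "step_3", {}),  # no validation metric recorded for this step
    (4, "step_4", {"loss": 0.2}),
]
metric_name, higher_is_better = "loss", False

if higher_is_better:
    # Higher metric first; for equal metrics, prefer the more recent step.
    history.sort(key=lambda x: (x[2].get(metric_name, -float("inf")), x[0]), reverse=True)
else:
    # Lower metric first; missing metrics sort last, newest first among them.
    history.sort(key=lambda x: (x[2].get(metric_name, float("inf")), -x[0]))

print([step for step, _, _ in history])  # [4, 2, 1, 3] -> step_3 is the first removal candidate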
tests/unit/utils/test_checkpoint.py (79 changes: 79 additions & 0 deletions)

@@ -141,6 +141,85 @@ def test_remove_old_checkpoints_topk_bias_recent_if_equal(
    assert sorted(remaining_steps) == sorted(expected_steps)


def test_remove_old_checkpoints_topk_some_missing_val_metric(
    checkpoint_manager, checkpoint_dir
):
    # Create checkpoints where some have validation metrics and others don't
    steps = [1, 2, 3, 4, 10, 11, 12]
    # Some checkpoints have loss metrics, others don't have any validation metrics
    training_infos = [
        {"loss": 0.5},  # step 1 - has loss
        {"loss": 0.3},  # step 2 - has loss
        {"other_metric": 0.8},  # step 3 - missing loss metric
        {"loss": 0.2},  # step 4 - has loss
        {},  # step 10 - missing loss metric
        {"loss": 1.0},  # step 11 - has loss but not in top-k
        {},  # step 12 - missing loss metric (latest)
    ]

    for step, training_info in zip(steps, training_infos):
        tmp_dir = checkpoint_manager.init_tmp_checkpoint(step, training_info)
        checkpoint_manager.finalize_checkpoint(tmp_dir)

    # Check that only top-k checkpoints are kept
    remaining_dirs = list(checkpoint_dir.glob("step_*"))
    assert (
        len(remaining_dirs) == checkpoint_manager.keep_top_k + 1
    )  # +1 because the latest checkpoint is always excluded from removal

    # Checkpoints with a missing validation metric are treated as having the worst
    # possible value: since higher_is_better=False, a missing metric falls back to
    # float("inf"), so checkpoints with actual loss values are preferred over those without
    remaining_steps = []
    for dir_path in remaining_dirs:
        step_num = int(dir_path.name.split("_")[1])
        remaining_steps.append(step_num)

    # Keeps the best-loss checkpoints (steps 4, 2, 1); step 11 has a loss but falls
    # outside the top-k, and steps 3 and 10 have no loss metric.
    # The latest checkpoint (step 12) is always kept.
    expected_steps = [1, 2, 4, 12]
    assert sorted(remaining_steps) == sorted(expected_steps)
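For reference, the ordering behind this expectation, re-derived with the same fallback key as remove_old_checkpoints (assuming the fixture's keep_top_k is 3, as the assertions imply):

infos = {1: 0.5, 2: 0.3, 3: None, 4: 0.2, 10: None, 11: 1.0, 12: None}
ranked = sorted(
    infos,
    key=lambda s: (infos[s] if infos[s] is not None else float("inf"), -s),
)
print(ranked)      # [4, 2, 1, 11, 12, 10, 3]
print(ranked[:3])  # top-3 by loss: [4, 2, 1]; step 12 also survives as the latest checkpoint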


def test_remove_old_checkpoints_topk_most_missing_val_metric(
    checkpoint_manager, checkpoint_dir
):
    # Create checkpoints where only one has the loss metric and the rest don't
    steps = [1, 2, 3, 4, 10, 12]
    training_infos = [
        {"loss": 0.2},  # step 1 - has loss
        {},  # step 2 - missing loss metric
        {"other_metric": 0.8},  # step 3 - missing loss metric
        {},  # step 4 - missing loss metric
        {},  # step 10 - missing loss metric
        {},  # step 12 - missing loss metric (latest)
    ]

    for step, training_info in zip(steps, training_infos):
        tmp_dir = checkpoint_manager.init_tmp_checkpoint(step, training_info)
        checkpoint_manager.finalize_checkpoint(tmp_dir)

    # Check that only top-k checkpoints are kept; the latest checkpoint (step 12)
    # already falls inside the top-k here, so no extra checkpoint survives
    remaining_dirs = list(checkpoint_dir.glob("step_*"))
    assert len(remaining_dirs) == checkpoint_manager.keep_top_k

    # Checkpoints with a missing validation metric are treated as having the worst
    # possible value: since higher_is_better=False, a missing metric falls back to
    # float("inf"), so the single checkpoint with an actual loss value ranks first
    remaining_steps = []
    for dir_path in remaining_dirs:
        step_num = int(dir_path.name.split("_")[1])
        remaining_steps.append(step_num)

    # Keeps the checkpoint with an actual loss value (step 1), followed by the most
    # recent of the metric-less checkpoints; the latest checkpoint (step 12) is always kept
    expected_steps = [1, 10, 12]
    assert sorted(remaining_steps) == sorted(expected_steps)


def test_get_best_checkpoint_path(checkpoint_manager, checkpoint_dir):
    # Create multiple checkpoints with different loss values
    steps = [1, 2, 3]