
Commit ffe9c94

makes requested changes

1 parent d96e395 commit ffe9c94

1 file changed: scripts/evaluate_best_checkpoint.py (+70 -76 lines)
@@ -10,24 +10,59 @@
 # Standard
 from pathlib import Path
 from typing import Optional
-from typing_extensions import Annotated
 import json
 
 # Third Party
 from rich import print
+from typing_extensions import Annotated
 import typer
 
 app = typer.Typer()
 
 
+def print_metrics(result: dict, checkpoint_name: str = None, prefix: str = ""):
+    """
+    Print formatted metrics for a checkpoint result.
+
+    Args:
+        result: The evaluation result dictionary
+        checkpoint_name: Optional checkpoint name to display
+        prefix: Optional prefix for each line
+    """
+    if checkpoint_name:
+        print(f"{prefix}[bold]Leaderboard results[/bold]: {checkpoint_name}")
+    print(f"{prefix}Overall: {result['overall_score'] * 100:.2f}%")
+    if "leaderboard_bbh" in result:
+        print(f"{prefix}BBH: {result['leaderboard_bbh']['score'] * 100:.2f}%")
+    if "leaderboard_gpqa" in result:
+        print(f"{prefix}GPQA: {result['leaderboard_gpqa']['score'] * 100:.2f}%")
+    if "leaderboard_ifeval" in result:
+        print(f"{prefix}IFEval: {result['leaderboard_ifeval']['score'] * 100:.2f}%")
+    if "leaderboard_math_hard" in result:
+        print(
+            f"{prefix}MATH-Hard: {result['leaderboard_math_hard']['score'] * 100:.2f}%"
+        )
+    if "leaderboard_mmlu_pro" in result:
+        print(f"{prefix}MMLU-Pro: {result['leaderboard_mmlu_pro']['score'] * 100:.2f}%")
+    if "leaderboard_musr" in result:
+        print(f"{prefix}MUSR: {result['leaderboard_musr']['score'] * 100:.2f}%")
+
+
 @app.command()
 def best_checkpoint(
     input_dir: Path = typer.Argument(..., help="Input directory to process"),
     output_file: Optional[Path] = typer.Option(None, help="Optional output file path"),
-    tasks: Annotated[Optional[list[str]], typer.Option()] = None,
+    tasks: Annotated[
+        Optional[list[str]],
+        typer.Option(
+            help="Specific tasks to evaluate (e.g., 'leaderboard_bbh', 'leaderboard_gpqa')"
+        ),
+    ] = None,
+    num_gpus: int = typer.Option(8, help="Number of GPUs to use for evaluation"),
 ):
     """
-    Process files in the input directory and optionally save results to an output file.
+    Find the best checkpoint by evaluating all checkpoints in the input directory.
+    Processes all checkpoint subdirectories and ranks them by overall score.
     """
     if not input_dir.exists():
         typer.echo(f"Error: Input directory '{input_dir}' does not exist")
@@ -55,7 +90,7 @@ def best_checkpoint(
         typer.echo(f"Processing checkpoint: {checkpoint}")
         ckpt_output_file = checkpoint / "leaderboard_results.json"
         evaluator = LeaderboardV2Evaluator(
-            model_path=str(checkpoint), output_file=ckpt_output_file, num_gpus=8
+            model_path=str(checkpoint), output_file=ckpt_output_file, num_gpus=num_gpus
         )
         if tasks:
             evaluator.tasks = tasks
@@ -72,28 +107,12 @@ def best_checkpoint(
         typer.echo(f"{'=' * 100}")
         # Add [BEST CHECKPOINT] label for the first checkpoint
         if i == 0:
-            typer.echo(
-                f"[bold]Leaderboard results[/bold]: {checkpoint_name} [bold green][BEST CHECKPOINT][/bold green]"
+            checkpoint_display = (
+                f"{checkpoint_name} [bold green][BEST CHECKPOINT][/bold green]"
             )
         else:
-            typer.echo(f"[bold]Leaderboard results[/bold]: {checkpoint_name}")
-        typer.echo(f"Overall: {result['overall_score'] * 100:.2f}%")
-        if "leaderboard_bbh" in result:
-            typer.echo(f"BBH: {result['leaderboard_bbh']['score'] * 100:.2f}%")
-        if "leaderboard_gpqa" in result:
-            typer.echo(f"GPQA: {result['leaderboard_gpqa']['score'] * 100:.2f}%")
-        if "leaderboard_ifeval" in result:
-            typer.echo(f"IFEval: {result['leaderboard_ifeval']['score'] * 100:.2f}%")
-        if "leaderboard_math_hard" in result:
-            typer.echo(
-                f"MATH-Hard: {result['leaderboard_math_hard']['score'] * 100:.2f}%"
-            )
-        if "leaderboard_mmlu_pro" in result:
-            typer.echo(
-                f"MMLU-Pro: {result['leaderboard_mmlu_pro']['score'] * 100:.2f}%"
-            )
-        if "leaderboard_musr" in result:
-            typer.echo(f"MUSR: {result['leaderboard_musr']['score'] * 100:.2f}%")
+            checkpoint_display = checkpoint_name
+        print_metrics(result, checkpoint_display)
 
         typer.echo(f"{'=' * 100}")
         typer.echo(
@@ -113,10 +132,20 @@ def best_checkpoint(
 @app.command()
 def evaluate(
     input_dir: Path = typer.Argument(..., help="Input directory to process"),
-    tasks: Annotated[Optional[list[str]], typer.Option()] = None,
+    tasks: Annotated[
+        Optional[list[str]],
+        typer.Option(
+            help="Specific tasks to evaluate (e.g., 'leaderboard_bbh', 'leaderboard_gpqa')"
+        ),
+    ] = None,
+    num_gpus: int = typer.Option(8, help="Number of GPUs to use for evaluation"),
+    output_file: Optional[Path] = typer.Option(
+        None,
+        help="Custom output file path (default: input_dir/leaderboard_results.json)",
+    ),
 ):
     """
-    Process files in the input directory and optionally save results to an output file.
+    Evaluate a single checkpoint directory and save results to JSON file.
     """
     if not input_dir.exists():
         typer.echo(f"Error: Input directory '{input_dir}' does not exist")
@@ -133,30 +162,27 @@ def evaluate(
     typer.echo("done")
 
     evaluator = LeaderboardV2Evaluator(
-        model_path=str(input_dir), num_gpus=8, eval_config={"batch_size": "auto"}
+        model_path=str(input_dir), num_gpus=num_gpus, eval_config={"batch_size": "auto"}
     )
     if tasks:
         evaluator.tasks = tasks
     result = evaluator.run()
 
     # now just print out the checkpoint results
-    print(f"[bold]Leaderboard results[/bold]: {input_dir}")
-    print(f"Overall: {result['overall_score'] * 100:.2f}%")
-    if "leaderboard_bbh" in result:
-        print(f"BBH: {result['leaderboard_bbh']['score'] * 100:.2f}%")
-    if "leaderboard_gpqa" in result:
-        print(f"GPQA: {result['leaderboard_gpqa']['score'] * 100:.2f}%")
-    if "leaderboard_ifeval" in result:
-        print(f"IFEval: {result['leaderboard_ifeval']['score'] * 100:.2f}%")
-    if "leaderboard_math_hard" in result:
-        print(f"MATH-Hard: {result['leaderboard_math_hard']['score'] * 100:.2f}%")
-    if "leaderboard_mmlu_pro" in result:
-        print(f"MMLU-Pro: {result['leaderboard_mmlu_pro']['score'] * 100:.2f}%")
-    if "leaderboard_musr" in result:
-        print(f"MUSR: {result['leaderboard_musr']['score'] * 100:.2f}%")
+    print_metrics(result, str(input_dir))
+
+    # Determine output file path
+    if output_file is None:
+        output_file = input_dir / "leaderboard_results.json"
+
+    # Check if file exists and warn user
+    if output_file.exists():
+        typer.echo(
+            f"Warning: Output file '{output_file}' already exists and will be overwritten"
+        )
 
-    output_file = input_dir / "leaderboard_results.json"
     output_file.write_text(json.dumps(result, indent=2))
+    typer.echo(f"Results saved to: {output_file}")
 
 
 @app.command()
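
Since results are now always persisted, a small follow-up sketch of reading the saved JSON back. The keys mirror what the command writes, but the checkpoint directory name is illustrative and assumes the default output location.

    import json
    from pathlib import Path

    # Default location when --output-file is not given (directory name is illustrative).
    saved = json.loads(
        Path("checkpoints/samples_1234/leaderboard_results.json").read_text()
    )
    print(f"Overall: {saved['overall_score'] * 100:.2f}%")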
@@ -218,43 +244,11 @@ def find_best(
         is_best = checkpoint == best_checkpoint
         prefix = "→ " if is_best else "  "
         print(f"\n{prefix}Checkpoint: {checkpoint}")
-        print(f"  Overall score: {score * 100:.2f}%")
-        if "leaderboard_bbh" in results:
-            print(f"  BBH: {results['leaderboard_bbh']['score'] * 100:.2f}%")
-        if "leaderboard_gpqa" in results:
-            print(f"  GPQA: {results['leaderboard_gpqa']['score'] * 100:.2f}%")
-        if "leaderboard_ifeval" in results:
-            print(f"  IFEval: {results['leaderboard_ifeval']['score'] * 100:.2f}%")
-        if "leaderboard_math_hard" in results:
-            print(
-                f"  MATH-Hard: {results['leaderboard_math_hard']['score'] * 100:.2f}%"
-            )
-        if "leaderboard_mmlu_pro" in results:
-            print(
-                f"  MMLU-Pro: {results['leaderboard_mmlu_pro']['score'] * 100:.2f}%"
-            )
-        if "leaderboard_musr" in results:
-            print(f"  MUSR: {results['leaderboard_musr']['score'] * 100:.2f}%")
+        print_metrics(results, prefix="  ")
     else:
         # Print only best results
         print(f"\n[bold]Best checkpoint found[/bold]: {best_checkpoint}")
-        print(f"Overall score: {best_score * 100:.2f}%")
-        if "leaderboard_bbh" in best_results:
-            print(f"BBH: {best_results['leaderboard_bbh']['score'] * 100:.2f}%")
-        if "leaderboard_gpqa" in best_results:
-            print(f"GPQA: {best_results['leaderboard_gpqa']['score'] * 100:.2f}%")
-        if "leaderboard_ifeval" in best_results:
-            print(f"IFEval: {best_results['leaderboard_ifeval']['score'] * 100:.2f}%")
-        if "leaderboard_math_hard" in best_results:
-            print(
-                f"MATH-Hard: {best_results['leaderboard_math_hard']['score'] * 100:.2f}%"
-            )
-        if "leaderboard_mmlu_pro" in best_results:
-            print(
-                f"MMLU-Pro: {best_results['leaderboard_mmlu_pro']['score'] * 100:.2f}%"
-            )
-        if "leaderboard_musr" in best_results:
-            print(f"MUSR: {best_results['leaderboard_musr']['score'] * 100:.2f}%")
+        print_metrics(best_results)
 
 
 if __name__ == "__main__":

0 commit comments

Comments
 (0)