Skip to content

Commit 38a18fa

Browse files
authored
Merge pull request #45 from ku-nlp/add/wandb
Upload model config to wandb
2 parents cded8e9 + 3c96aa6 commit 38a18fa

File tree

3 files changed

+51
-7
lines changed

3 files changed

+51
-7
lines changed

llm_judge/common.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import re
77
import time
88
from pathlib import Path
9-
from typing import Union, Optional
9+
from typing import Optional, Union
1010

1111
import openai
1212
import tiktoken
@@ -257,6 +257,16 @@ def load_model_answers(answer_dir: Union[str, Path]):
257257
return answers
258258

259259

260+
def load_model_config(answer_dir: Union[str, Path]):
261+
"""Load model config.
262+
263+
Args:
264+
answer_dir (Union[str, Path]): The answer directory.
265+
"""
266+
with open(Path(answer_dir) / "config.json", "r") as fin:
267+
return json.load(fin)
268+
269+
260270
def load_judgements(judgement_dir: Union[str, Path]):
261271
"""Load judgements.
262272

llm_judge/show_result.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@
55
import pandas as pd
66
from common import (
77
JUDGEMENT_DIR,
8-
load_judgements,
9-
filter_single_judgements,
108
filter_pairwise_judgements,
9+
filter_single_judgements,
10+
load_judgements,
1111
)
1212

1313
logger = logging.getLogger(__name__)

llm_judge/upload_result.py

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,48 @@
33
import os
44

55
import pandas as pd
6-
import wandb
76
from common import (
87
JUDGEMENT_DIR,
9-
load_judgements,
10-
filter_single_judgements,
8+
PREDICTION_DIR,
119
filter_pairwise_judgements,
10+
filter_single_judgements,
11+
load_judgements,
12+
load_model_config,
1213
)
1314
from show_result import calculate_win_rate
1415

16+
import wandb
17+
1518
logger = logging.getLogger(__name__)
1619

1720

21+
def get_run_config_from_result(mode: str, result: dict) -> dict:
22+
if mode == "single":
23+
model = result["model"]
24+
model_config = load_model_config(PREDICTION_DIR / model)
25+
judge_model = result["judge_model"]
26+
return {
27+
"mode": mode,
28+
"judge_model": judge_model,
29+
"model": model,
30+
"model_config": model_config,
31+
}
32+
else:
33+
model_1 = result["model_1"]
34+
model_1_config = load_model_config(PREDICTION_DIR / model_1)
35+
model_2 = result["model_2"]
36+
model_2_config = load_model_config(PREDICTION_DIR / model_2)
37+
judge_model = result["judge_model"]
38+
return {
39+
"mode": mode,
40+
"judge_model": judge_model,
41+
"model_1": model_1,
42+
"model_2": model_2,
43+
"model_1_config": model_1_config,
44+
"model_2_config": model_2_config,
45+
}
46+
47+
1848
def upload_results(
1949
mode: str,
2050
result_id: str,
@@ -30,7 +60,11 @@ def upload_results(
3060
baseline_model: Baseline model name. Only used in `pairwise-baseline` mode.
3161
"""
3262
project = os.getenv("WANDB_PROJECT", "ja-vicuna-qa-benchmark")
33-
run = wandb.init(project=project, name=result_id, reinit=True)
63+
if len(results) == 0:
64+
logger.warning(f"No results found for {result_id}")
65+
return
66+
config = get_run_config_from_result(mode, results[0])
67+
run = wandb.init(project=project, name=result_id, config=config, reinit=True)
3468

3569
table_prefix = mode
3670
if mode == "pairwise-baseline":

0 commit comments

Comments
 (0)