-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy path04_gen_preferences.py
65 lines (51 loc) · 1.61 KB
/
04_gen_preferences.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import json
import os
import sys
import uuid
def group_by_prompt(lines):
    """Group scored completion rows by their ``prompt_id``.

    Args:
        lines: An iterable of JSON strings, one scored row per line (e.g. an
            open JSONL file). Each row must contain a ``prompt_id`` key.
            Blank or whitespace-only lines are skipped instead of crashing
            ``json.loads``.

    Returns:
        A dict mapping each ``prompt_id`` to the list of its row dicts, in
        input order. Prompt ids appear in first-seen order.
    """
    groups = {}
    for line in lines:
        if not line.strip():
            continue
        row = json.loads(line)
        groups.setdefault(row['prompt_id'], []).append(row)
    return groups


def build_preference_pairs(groups):
    """Build one chosen/rejected preference pair per prompt.

    For each group, the highest-scoring row becomes ``chosen`` and the
    lowest-scoring row becomes ``rejected``. On ties, the first such row in
    input order wins. Groups where every row has the same score (including
    single-row groups) carry no preference signal and are skipped.

    Unlike fixed sentinel bounds, ``max``/``min`` work for any comparable
    score range, including negatives and values above 100.

    Args:
        groups: A dict mapping prompt_id to a list of row dicts. Each row
            must have ``prompt_id``, ``prompt``, ``completion`` and ``score``
            keys.

    Returns:
        A list of pair dicts with keys ``prompt_id``, ``prompt``,
        ``chosen``, ``rejected``, ``score_chosen`` and ``score_rejected``.
    """
    pairs = []
    for rows in groups.values():
        if not rows:
            continue
        best = max(rows, key=lambda r: r['score'])
        worst = min(rows, key=lambda r: r['score'])
        # Equal best/worst scores mean there is no preference to learn from.
        if best['score'] == worst['score']:
            continue
        pairs.append({
            "prompt_id": best['prompt_id'],
            "prompt": best['prompt'],
            "chosen": best['completion'],
            "rejected": worst['completion'],
            "score_chosen": best['score'],
            "score_rejected": worst['score'],
        })
    return pairs


def main(argv=None):
    """Convert a scores JSONL file into a preference-pairs JSONL file.

    Args:
        argv: Command-line arguments excluding the program name; defaults to
            ``sys.argv[1:]``. Expects exactly ``<scores.jsonl>
            <preferences.jsonl>``.

    Returns:
        The process exit status: 0 on success, 1 on a usage error.
    """
    argv = sys.argv[1:] if argv is None else argv
    if len(argv) != 2:
        print("Usage: python 04_gen_preferences.py <scores.jsonl> <preferences.jsonl>",
              file=sys.stderr)
        return 1
    scores_file, preferences_file = argv

    # JSON text is UTF-8 by spec; do not depend on the platform locale.
    with open(scores_file, "r", encoding="utf-8") as f:
        groups = group_by_prompt(f)

    pairs = build_preference_pairs(groups)

    with open(preferences_file, "w", encoding="utf-8") as f:
        for pair in pairs:
            f.write(json.dumps(pair) + "\n")
    return 0


if __name__ == "__main__":
    sys.exit(main())