-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy path01_gen_prompts.py
123 lines (95 loc) · 3.39 KB
/
01_gen_prompts.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
# Uses the IFT dataset and the model mistralai/Mistral-7B-Instruct-v0.2 to generate prompts.
# Uses an 8-shot prompt to generate them.
import torch
import pandas as pd
from transformers import TextStreamer
import sys
import uuid
from srlm.model import load_model
# Validate the command line before any expensive work is started.
if len(sys.argv) != 5:
    # A wrong invocation is a failure: sys.exit(<str>) writes the message to
    # stderr and exits with status 1, so shells and schedulers can detect it.
    # (The bare exit() helper is injected by the `site` module and exits 0.)
    sys.exit("Usage: python 01_gen_prompts.py <tokenizer_name> <model_name> <train.jsonl> <prompts.jsonl>")

# Positional arguments, in the order given by the usage string above.
base_name = sys.argv[1]  # <tokenizer_name>; first argument passed to load_model
model_name = sys.argv[2]  # <model_name>; second argument passed to load_model
ift_dataset_file = sys.argv[3]  # input JSONL of seed prompts
generated_prompts_file = sys.argv[4]  # output JSONL for the generated prompts

device = "cuda"  # the device to load the model onto
num_prompts_to_generate = 1000  # minimum number of unique prompts to collect
def read_jsonl_file(file_path):
    """Load a JSON Lines file (one JSON object per line) as a pandas DataFrame."""
    frame = pd.read_json(file_path, lines=True)
    return frame
def save_to_jsonl(df, file_path):
    """Write *df* to *file_path* as JSON Lines: one record (row) per line."""
    df.to_json(file_path, lines=True, orient="records")
def generate_prompt(examples):
    """Build the few-shot prompt used to elicit new tasks from the model.

    The prompt is a fixed instruction paragraph followed by one
    ``<task>...</task>`` line per example, each terminated by a newline.

    Args:
        examples: Iterable of example tasks; each item is formatted with
            ``str.format`` semantics (as in an f-string).

    Returns:
        The complete prompt string. With no examples, only the instruction
        paragraph is returned.
    """
    instructions = """
Come up with a series of tasks and questions. Only the task/question,
no further text/explanation, no additional information.
The task or question should be something a person would ask a chatbot.
"""
    # Assemble all example lines in a single join rather than repeated "+=".
    return instructions + "".join(f"<task>{item}</task>\n" for item in examples)
def do_sample(model, tokenizer, examples):
    """Sample one continuation of the few-shot task prompt from the model.

    Builds the prompt with ``generate_prompt(examples)``, echoes it between
    banner lines, runs sampling-based generation (streamed to stdout through
    ``TextStreamer``) and returns the decoded text of the newly generated
    tokens.

    Args:
        model: Object exposing ``generate(**inputs, ...)``; if it has a
            ``device`` attribute the inputs are moved to that device.
        tokenizer: Tokenizer callable with ``return_tensors="pt"`` that also
            provides ``eos_token_id`` and ``batch_decode``.
        examples: Example tasks used as the few-shot demonstrations.

    Returns:
        The decoded continuation only. The few-shot prompt is NOT included,
        so downstream ``<task>`` extraction does not pick the seed examples
        up again as if the model had generated them.
    """
    n_shot_prompt = generate_prompt(examples)
    print("<"*80)
    print(f"{n_shot_prompt}")
    print(">"*80)

    # Follow the model's own device instead of assuming CUDA; "cuda" remains
    # the fallback for model objects that do not expose a `device` attribute.
    target_device = getattr(model, "device", "cuda")
    model_inputs = tokenizer(n_shot_prompt, return_tensors="pt").to(target_device)
    streamer = TextStreamer(tokenizer)

    # Inference only: no gradients are needed for generation.
    with torch.no_grad():
        generated_ids = model.generate(
            **model_inputs,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
            num_return_sequences=1,
            streamer=streamer,
            top_p=0.9,
            temperature=0.6,
            max_new_tokens=256,
        )

    # NOTE(review): assumes a decoder-only (causal) model, whose generate()
    # output starts with the prompt tokens -- confirm for what load_model
    # returns. Drop that prefix so only the new tokens are decoded.
    prompt_length = model_inputs["input_ids"].shape[1]
    new_token_ids = generated_ids[:, prompt_length:]
    decoded = tokenizer.batch_decode(new_token_ids, skip_special_tokens=True)
    return decoded[0]
def get_random_prompts(df, num_selections=8):
    """Return a random selection of values from the ``prompt`` column of *df*.

    Rows are drawn with ``DataFrame.sample`` (without replacement).

    Args:
        df: DataFrame with a ``prompt`` column.
        num_selections: Number of prompts wanted. If *df* has fewer rows,
            every row is returned (in random order) instead of raising.

    Returns:
        A list with ``min(num_selections, len(df))`` prompts.
    """
    # DataFrame.sample raises ValueError when asked for more rows than exist,
    # so cap the request for small datasets.
    sample_size = min(num_selections, len(df))
    return df.sample(n=sample_size)["prompt"].tolist()
def extract_prompts(answer):
    """Extract every ``<task>...</task>`` payload from *answer*, in order.

    The text is scanned left to right. Each closing tag is searched for only
    after its opening tag, so a stray ``</task>`` that precedes a ``<task>``
    is ignored rather than producing a bogus empty prompt. Scanning stops at
    the first opening tag that has no closing tag (e.g. generation was cut
    off by the token limit).

    Args:
        answer: Text to scan.

    Returns:
        List of the strings found between the tags; the payloads are returned
        as-is (not stripped). Empty list when no complete pair exists.
    """
    print("="*80)
    print("Extracting prompts...")
    print(answer)
    print("="*80)

    open_tag = "<task>"
    close_tag = "</task>"
    prompts = []
    cursor = 0  # index from which the next opening tag is searched
    while True:
        start = answer.find(open_tag, cursor)
        if start == -1:
            break
        content_start = start + len(open_tag)
        # Look for the closing tag only after this opening tag.
        end = answer.find(close_tag, content_start)
        if end == -1:
            break
        prompts.append(answer[content_start:end])
        cursor = end + len(close_tag)

    print("Prompts extracted:")
    print(prompts)
    return prompts
# Load the model/tokenizer pair and the seed prompts used as few-shot examples.
model, tokenizer = load_model(base_name, model_name)
ift_df = read_jsonl_file(ift_dataset_file)

uniq_prompts = set()  # prompt texts already collected (for de-duplication)
new_prompts = []  # output records, in the order they were first seen

# Sample until enough distinct prompts have been collected. The check runs
# once per batch, so the final batch may push the total past the target.
while len(uniq_prompts) < num_prompts_to_generate:
    task_prompts = get_random_prompts(ift_df)
    answer = do_sample(model, tokenizer, task_prompts)
    prompts = extract_prompts(answer)
    for prompt in prompts:
        if prompt in uniq_prompts:
            continue
        uniq_prompts.add(prompt)
        prompt_id = str(uuid.uuid4())
        record = {"prompt_id": prompt_id, "prompt": prompt, "source": "generated"}
        new_prompts.append(record)

# Everything is written in one go after the loop finishes.
new_prompts_df = pd.DataFrame(new_prompts)
save_to_jsonl(new_prompts_df, generated_prompts_file)