# eval_mamba_with_context.py
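#
# A minimal usage sketch (the model name and file names here are illustrative,
# not fixed by the script; any Mamba checkpoint path works for argv[1]):
#   python eval_mamba_with_context.py state-spaces/mamba-2.8b questions.jsonl results.jsonl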
import json
import sys
import time

import pandas as pd
import torch
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from tqdm import tqdm
from transformers import AutoTokenizer

def run_mamba(model, question, context):
    """Prompt the model QA-style and return the generated answer."""
    text = f"{context}\n\nQ: {question}\nA:"
    input_ids = torch.LongTensor([tokenizer.encode(text)]).cuda()
    out = model.generate(
        input_ids=input_ids,
        # input_ids has shape (1, seq_len); len(input_ids) would be the batch
        # size, so take the sequence dimension when budgeting new tokens
        max_length=input_ids.shape[1] + 128,
        eos_token_id=tokenizer.eos_token_id,
    )
    # generate() returns the prompt plus the continuation, so strip the prompt
    decoded = tokenizer.batch_decode(out)[0]
    cleaned = decoded.replace(text, "")
    cleaned = cleaned.replace("<|endoftext|>", "")
    # the model keeps generating further Q/A pairs, so keep only the first block
    answer = cleaned.split("\n\n")[0].strip()
    return answer

def write_results(results, output_file):
    """Write the accumulated results to a JSONL file with a fixed column order."""
    df = pd.DataFrame(results)
    df = df[["idx", "context", "question", "answer", "guess", "is_correct", "time"]]
    print(f"Writing {output_file}")
    df.to_json(output_file, orient="records", lines=True)

model_path = sys.argv[1]
dataset = sys.argv[2]
output_file = sys.argv[3]

# Mamba checkpoints do not ship their own tokenizer; the GPT-NeoX tokenizer
# is the one the state-spaces models were trained with
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
tokenizer.eos_token = "<|endoftext|>"
tokenizer.pad_token = tokenizer.eos_token

model = MambaLMHeadModel.from_pretrained(model_path, device="cuda", dtype=torch.float16)
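
# Each dataset line is a JSON object; an illustrative record (field names
# match the keys read below, values hypothetical):
#   {"prompt": "When was X founded?", "context": "X was founded in 1999.", "response": "1999"}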

results = []
with open(dataset) as f:
    all_data = [json.loads(line) for line in tqdm(f)]
total_qs = len(all_data)

for i, data in enumerate(all_data):
    start_time = time.time()
    question = data["prompt"]
    context = data["context"]
    answer = data["response"]
    guess = run_mamba(model, question, context)
    end_time = time.time()

    # case-insensitive exact match against the reference answer
    is_correct = (answer.strip().lower() == guess.strip().lower())

    print(f"Question {i}/{total_qs}")
    print(f"Context: {context}")
    print(f"Q: {question}")
    print(f"A: {answer}")
    print(f"?: {guess}")
    print("✅" if is_correct else "❌")
    print("=" * 80)

    results.append({
        "idx": i,
        "question": question,
        "context": context,
        "answer": answer,
        "guess": guess,
        "is_correct": is_correct,
        "time": end_time - start_time,
    })

    # checkpoint every 20 questions so partial results survive an interruption
    if len(results) % 20 == 0:
        write_results(results, output_file)

write_results(results, output_file)
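
# A quick way to score a finished run afterwards (sketch, assuming pandas):
#   df = pd.read_json(output_file, orient="records", lines=True)
#   print(df["is_correct"].mean())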