Adaptation of the test script for Alpaca LoRA, plus 2 tests, Alpaca 7b and 13b #6

Open
wants to merge 1 commit into base: main
take_test.py (209 changes: 118 additions & 91 deletions)
import argparse
import json

import torch
import transformers
from peft import PeftModel
from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer


def load_trivia_questions(file_path):
    with open(file_path, "r") as file:
        trivia_data = json.load(file)
    return trivia_data

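# Illustrative only: judging by how this script reads the data, the trivia JSON
# is assumed to be a list of question objects shaped like the example below
# (keys "question", "answers", "choice", "text", "correct"); the real
# nota_trivia_questions.json / hq_trivia_questions.json files may differ in content.
#
# [
#   {
#     "question": "What is the capital of France?",
#     "answers": [
#       {"choice": "A", "text": "Paris", "correct": true},
#       {"choice": "B", "text": "Lyon", "correct": false},
#       {"choice": "C", "text": "Marseille", "correct": false},
#       {"choice": "D", "text": "I don't know", "correct": false}
#     ]
#   }
# ]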

def generate_question_string(question_data):
    question = question_data["question"]
    choices = [
        f" {answer['choice']}. {answer['text']}\n"
        if answer != question_data["answers"][-1]
        else f" {answer['choice']}. {answer['text']}"
        for answer in question_data["answers"]
    ]
    return f"{question}\n{''.join(choices)}"


def grade_answers(question_data, llm_answer):
    correct_answer = None
    for answer in question_data["answers"]:
        if answer["correct"]:
            correct_answer = answer
            break

    if correct_answer is None:
        return "No correct answer found"

    normalized_llm_answer = llm_answer.lower().strip()
    normalized_correct_answer = correct_answer["text"].lower().strip()

    # The lower-cased full text of the correct answer appears in the LLM's answer
    if normalized_correct_answer in normalized_llm_answer:
        return f"{correct_answer['choice']}. {correct_answer['text']} (correct)"

    # Upper case " A." or " B." or " C." or " D." or " E.", for instance
    if f" {correct_answer['choice']}." in llm_answer:
        return f"{correct_answer['choice']}. {correct_answer['text']} (correct)"

    # Upper case "(A)" or "(B)" or "(C)" or "(D)" or "(E)", for instance
    if f"({correct_answer['choice']})" in llm_answer:
        return f"{correct_answer['choice']}. {correct_answer['text']} (correct)"

    if (
        "i don't know" in normalized_llm_answer
        or normalized_llm_answer == "d"
        or normalized_llm_answer == "d."
    ):
        return f"{llm_answer} (uncertain)"

    return f"{llm_answer} (incorrect {correct_answer['choice']}.)"

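# Scoring applied in main() below: a "(correct)" grade is worth 2 points, an
# "(uncertain)" grade (e.g. "I don't know" style answers) is worth 1 point, and
# an "(incorrect ...)" grade is worth 0, so a quiz of N questions is scored out
# of 2 * N. For example, an answer of "I believe it is Paris." to the
# illustrative question above would match the full-text check and be graded
# "A. Paris (correct)".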
if torch.cuda.is_available():
    device_count = torch.cuda.device_count()
    # ... (device-selection lines collapsed in the diff view) ...
    print(f"Using device: {device}")
else:
    device = "cpu"
    print("No GPU available, quitting.")
    exit()


def query_model(
    prompt,
    model,
    tokenizer,
    temperature=0.1,
    max_new_tokens=50,
    **kwargs,
):
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs["input_ids"].to(device)
    # Default
    """
    generation_config = GenerationConfig(
        temperature=temperature,
        **kwargs,
    )
    """
    # beam3
    generation_config = GenerationConfig(
        temperature=0.1,
        top_p=0.75,
        top_k=40,
        num_beams=3
    )
    # Multinomial
    """
    generation_config = GenerationConfig(
        temperature=0.1,
        top_p=0.75,
        num_beams=1,
        do_sample=True
    )
    """
    with torch.no_grad():
        generation_output = model.generate(
            input_ids=input_ids,
            generation_config=generation_config,
            return_dict_in_generate=True,
            # output_scores=True,
            max_new_tokens=max_new_tokens,
        )

    s = generation_output.sequences[0]
    output = tokenizer.decode(s)
    response = output.split("### Response:")[1].strip()
    return response.split("### Instruction:")[0].strip()
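# The three GenerationConfig blocks above are alternative decoding strategies:
# the "Default" config and the "Multinomial" sampling config are disabled by
# being wrapped in string literals, so 3-beam beam search is the one actually
# used. Illustrative call, mirroring how main() uses this function:
#
#   answer = query_model(generate_prompt("What is 2 + 2?"), model, tokenizer)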

def main():
    # python3 take_test.py --test_name='alpaca7b_lora_r16' --base_model='decapoda-research/llama-7b-hf' --lora_weights='tloen/alpaca-lora-7b' --trivia=nota_trivia_questions.json
    # python3 take_test.py --test_name='alpaca7b_lora_r16' --base_model='decapoda-research/llama-7b-hf' --lora_weights='tloen/alpaca-lora-7b' --trivia=hq_trivia_questions.json
    # fake_trivia_questions.json
    parser = argparse.ArgumentParser(
        description="Run trivia quiz with a local Alpaca model."
    )
    parser.add_argument("--base_model", type=str, help="Base Llama model")
    parser.add_argument("--lora_weights", type=str, help="LoRA weights")
    parser.add_argument("--trivia", type=str, help="File path to trivia questions")
    parser.add_argument(
        "--test_name",
        type=str,
        help="Name for the model/LoRA pair (used in the results file name)",
    )
    args = parser.parse_args()

    tokenizer = LlamaTokenizer.from_pretrained(
        args.base_model, padding_side="left", device_map={"": 0}
    )
    tokenizer.pad_token_id = 0
    tokenizer.padding_side = "left"

    # Works for CUDA only.
    model = LlamaForCausalLM.from_pretrained(
        args.base_model,
        load_in_8bit=True,
        torch_dtype=torch.float16,
        device_map={"": 0},
    )
    model = PeftModel.from_pretrained(
        model, args.lora_weights, device_map={"": 0}, torch_dtype=torch.float16
    )
    model.eval()
    if torch.__version__ >= "2":
        model = torch.compile(model)

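    # Assumptions worth noting: load_in_8bit=True relies on the bitsandbytes
    # package being installed, PeftModel.from_pretrained() layers the LoRA
    # adapter weights on top of the 8-bit base model, and torch.compile() is
    # only attempted on PyTorch >= 2.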
    file_path = args.trivia
    trivia_data = load_trivia_questions(file_path)

    total_score = 0
    incorrect = []
    unknown = []

    for i, question_data in enumerate(trivia_data):
        question_string = generate_question_string(question_data)
        prompt = generate_prompt(question_string)

        print(f"Question {i+1}: {question_string}")
        llm_answer = query_model(prompt, model, tokenizer)

        answer_output = grade_answers(question_data, llm_answer)
        print(f"Answer: {answer_output}\n")

        if "(correct)" in answer_output:
            total_score += 2
        elif "(incorrect" in answer_output:
            incorrect.append((i + 1, question_string, answer_output))
        else:
            total_score += 1
            unknown.append((i + 1, question_string, answer_output))

    with open(f"test_results_{file_path}_{args.test_name}.txt", "w") as f:
        f.write(f"Total score: {total_score} of {len(trivia_data) * 2}\n")
        i = len(incorrect)
        if i:
            f.write(f"\nIncorrect: {i}\n")
            for question_num, question_string, answer_output in incorrect:
                f.write(
                    f"Question {question_num}: {question_string.strip()}\n{answer_output.strip()}\n\n"
                )
        u = len(unknown)
        if u:
            f.write(f"Unknown: {u}\n")
            for question_num, question_string, answer_output in unknown:
                f.write(
                    f"Question {question_num}: {question_string.strip()}\n{answer_output.strip()}\n\n"
                )

    print(f"Total score: {total_score} of {len(trivia_data) * 2}\n", end="")

def generate_prompt(instruction):
    return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request. Only answer the question. Keep token limit low.

### Instruction:
{instruction}

### Response:
"""

if __name__ == "__main__":
    main()
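# Illustrative example of the resulting file name: running the first command in
# the comments at the top of main() writes the report to
# "test_results_nota_trivia_questions.json_alpaca7b_lora_r16.txt".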