Add batch processing to MMLU and HumanEval evaluation scripts to prevent OOM errors #597

Open · wants to merge 3 commits into base: master
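
The core change is sketched below: rather than enqueuing every (problem, sample) pair at once, the script walks the problem set in fixed-size chunks, drains the generator for each chunk, and clears the CUDA allocator cache before starting the next one. This is a minimal, self-contained sketch of that pattern, not the PR code itself; the helper names are illustrative and `process_batch` stands in for the per-batch generation function in the diff.

```python
# Minimal sketch of the batching pattern this PR applies (helper names are
# illustrative; the real per-batch work is done by process_batch() in the diff).
import torch

def iter_batches(problems: dict, batch_size: int):
    """Yield successive slices of the problem dict, at most batch_size entries each."""
    items = list(problems.items())
    for i in range(0, len(items), batch_size):
        yield dict(items[i:i + batch_size])

def run_in_batches(problems: dict, batch_size: int, process_batch):
    all_samples = []
    for batch in iter_batches(problems, batch_size):
        # Generate completions for this chunk only, keeping the number of
        # concurrent jobs (and their cached pages) bounded.
        all_samples.extend(process_batch(batch))
        # Release cached allocator blocks before the next chunk to lower
        # peak VRAM between batches.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    return all_samples
```
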
194 changes: 92 additions & 102 deletions eval/humaneval.py
@@ -1,27 +1,33 @@
from __future__ import annotations
import sys, os
import sys
import os
import argparse
import subprocess
import torch

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from human_eval.data import write_jsonl, read_problems
from exllamav2 import model_init
from exllamav2 import ExLlamaV2Cache, ExLlamaV2Cache_Q4, ExLlamaV2Cache_Q6, ExLlamaV2Cache_Q8
from exllamav2.generator import ExLlamaV2DynamicGenerator, ExLlamaV2DynamicJob, ExLlamaV2Sampler
import argparse, contextlib, subprocess
import util

# Args

parser = argparse.ArgumentParser(description = "Run HumanEval evaluation on EXL2 model")
parser.add_argument("-o", "--output", type = str, help = "Output .jsonl filename", required = True)
parser.add_argument("-cs", "--cache_size", type = int, default = None)
parser.add_argument("-spt", "--samples_per_task", type = int, default = 200)
parser.add_argument("-cq4", "--cache_q4", action = "store_true", help = "Use Q4 cache")
parser.add_argument("-cq6", "--cache_q6", action = "store_true", help = "Use Q6 cache")
parser.add_argument("-cq8", "--cache_q8", action = "store_true", help = "Use Q8 cache")
parser.add_argument("--max_tokens", type = int, default = 768, help = "Max number of tokens for each completion")
parser.add_argument("-pf", "--prompt_format", type = str, help = "Instruct format to apply. Default is raw completion (for base models) ")
parser.add_argument("-v", "--verbose", action = "store_true", help = "Spam completions to console while generating")
parser.add_argument("-e", "--eval", action = "store_true", help = "Run evaluation script on output file after sampling")
parser.add_argument("-temp", "--temperature", type = float, help = "Sampling temperature (0 for greedy), default: 0.6")
parser = argparse.ArgumentParser(description="Run HumanEval evaluation on EXL2 model")
parser.add_argument("-o", "--output", type=str, help="Output .jsonl filename", required=True)
parser.add_argument("-cs", "--cache_size", type=int, default=None)
parser.add_argument("-spt", "--samples_per_task", type=int, default=200)
parser.add_argument("-cq4", "--cache_q4", action="store_true", help="Use Q4 cache")
parser.add_argument("-cq6", "--cache_q6", action="store_true", help="Use Q6 cache")
parser.add_argument("-cq8", "--cache_q8", action="store_true", help="Use Q8 cache")
parser.add_argument("--max_tokens", type=int, default=768, help="Max number of tokens for each completion")
parser.add_argument("-pf", "--prompt_format", type=str,
help="Instruct format to apply. Default is raw completion (for base models)")
parser.add_argument("-v", "--verbose", action="store_true", help="Spam completions to console while generating")
parser.add_argument("-e", "--eval", action="store_true", help="Run evaluation script on output file after sampling")
parser.add_argument("-temp", "--temperature", type=float, default=0.6, help="Sampling temperature (0 for greedy)")
parser.add_argument("-bs", "--batch_size", type=int, default=50, help="Number of problems to process in each batch")
model_init.add_args(parser)
args = parser.parse_args()

@@ -37,38 +43,27 @@
# Prompt formats

prompt_formats = {
"raw": (
"```python\n{{problem}} ",
" "
),
"raw": ("```python\n{{problem}} ", " "),
"granite": (
"Question:\nComplete the following Python function:\n\n{{problem}}\n\nAnswer:\n"
"Sure! Here is how you might implement the function:\n\n```python\n{{problem}}",
" "
),
"llama": (
"[INST] <<SYS>>\n"
"You are a helpful AI coding assistant.\n"
"<</SYS>>\n\n"
"Complete the following Python function:\n\n"
"{{problem}} [/INST] "
"[INST] <<SYS>>\nYou are a helpful AI coding assistant.\n<</SYS>>\n\n"
"Complete the following Python function:\n\n{{problem}} [/INST] "
"Sure! Here is how you might implement the function:\n\n```python\n{{problem}}",
" "
),
"llama3": (
"<|start_header_id|>system<|end_header_id|>\n\n"
"You are a helpful AI coding assistant.<|eot_id|>"
"<|start_header_id|>user<|end_header_id|>\n\n"
"Complete the following Python function:\n\n{{problem}}<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\n"
"Sure! Here is how you might implement the function:\n\n```python\n{{problem}}",
"<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful AI coding assistant.<|eot_id|>"
"<|start_header_id|>user<|end_header_id|>\n\nComplete the following Python function:\n\n{{problem}}<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\nSure! Here is how you might implement the function:\n\n```python\n{{problem}}",
" "
),
"gemma": (
"<bos><start_of_turn>user\n"
"Complete the following Python function:\n\n{{problem}}<|eot_id|>"
"<start_of_turn>model\n"
"```python\n{{problem}}",
"<bos><start_of_turn>user\nComplete the following Python function:\n\n{{problem}}<|eot_id|>"
"<start_of_turn>model\n```python\n{{problem}}",
" "
)
}
@@ -88,93 +83,71 @@
model_init.print_options(args)
model, tokenizer = model_init.init(
args,
allow_auto_split = True,
progress = True,
max_output_len = 4,
max_input_len = 2048
allow_auto_split=True,
progress=True,
max_output_len=4,
max_input_len=2048
)

if args.cache_q4: cache_type = ExLlamaV2Cache_Q4
elif args.cache_q6: cache_type = ExLlamaV2Cache_Q6
elif args.cache_q8: cache_type = ExLlamaV2Cache_Q8
else: cache_type = ExLlamaV2Cache
if args.cache_q4:
cache_type = ExLlamaV2Cache_Q4
elif args.cache_q6:
cache_type = ExLlamaV2Cache_Q6
elif args.cache_q8:
cache_type = ExLlamaV2Cache_Q8
else:
cache_type = ExLlamaV2Cache
cache = cache_type(
model,
lazy = not model.loaded,
max_seq_len = args.cache_size or model.config.max_seq_len
lazy=not model.loaded,
max_seq_len=args.cache_size or model.config.max_seq_len
)

if not model.loaded:
model.load_autosplit(cache, progress = True)
model.load_autosplit(cache, progress=True)

# Generator

generator = ExLlamaV2DynamicGenerator(
model = model,
cache = cache,
tokenizer = tokenizer,
max_batch_size = 256,
max_q_size = 4
model=model,
cache=cache,
tokenizer=tokenizer,
max_batch_size=256,
max_q_size=4
)

gen_settings = ExLlamaV2Sampler.Settings(
token_repetition_penalty = 1.0,
temperature = 0.6,
top_k = 50,
top_p = 0.6
token_repetition_penalty=1.0,
temperature=args.temperature,
top_k=50,
top_p=0.6
)

# Get problems

problems = read_problems()
num_samples_per_task = args.samples_per_task

# Create jobs

with util.get_progress() as progress:

task1 = progress.add_task("[red]Sample", total = len(problems) * num_samples_per_task, name = "Creating sample jobs")
for problem_id, problem in problems.items():
def process_batch(batch_problems, batch_size, progress, sample_task, generate_task):
samples = []

for problem_id, problem in batch_problems.items():
b_problem = problem["prompt"]
f_problem = prompt_format.replace("{{problem}}", b_problem)
input_ids = tokenizer.encode(f_problem, encode_special_tokens=True, add_bos=True)

for s in range(num_samples_per_task):

for s in range(batch_size):
job = ExLlamaV2DynamicJob(
input_ids = input_ids,
gen_settings = gen_settings,
max_new_tokens = args.max_tokens,
stop_conditions = [tokenizer.eos_token_id],
token_healing = True,
identifier = (problem_id, s),
min_new_tokens = 6
input_ids=input_ids,
gen_settings=gen_settings,
max_new_tokens=args.max_tokens,
stop_conditions=[tokenizer.eos_token_id],
token_healing=True,
identifier=(problem_id, s),
min_new_tokens=6
)

generator.enqueue(job)
progress.update(task1, advance = 1)

# Collect samples here

samples = []

# Work

total_jobs = generator.num_remaining_jobs()
cm = contextlib.nullcontext() if args.verbose else util.get_progress()
with cm as progress:

if not args.verbose:
task1 = progress.add_task("[red]Sample", total = total_jobs, name = "Generating samples")
progress.update(sample_task, advance=1)

while generator.num_remaining_jobs():

results = generator.iterate()
for result in results:

# End sample if generator says EOS or if there is a non-indented line at the end of the output

job = result["job"]
eos = False
completion = job.full_completion
@@ -186,32 +159,49 @@
eos = True
eos = eos or result["eos"]

# Collect completed sample

if eos:
identifier = result["identifier"]
sample = problems[identifier[0]]["prompt"] + prefix + completion.strip()
sample = batch_problems[identifier[0]]["prompt"] + prefix + completion.strip()
if not result["eos"]:
generator.cancel(job)

if args.verbose:
print("----------------------------------------------------------------------")
print(f" ** Problem {identifier[0]}, sample {identifier[1] + 1} / {num_samples_per_task}")
print(f" ** Problem {identifier[0]}, sample {identifier[1] + 1} / {batch_size}")
print("----------------------------------------------------------------------")
print(sample)
print()
else:
progress.update(task1, advance = 1)
progress.update(generate_task, advance=1)

samples.append(dict(task_id = identifier[0], completion = prefix + completion.strip()))
samples.append(dict(task_id=identifier[0], completion=prefix + completion.strip()))

# Save output
return samples


# Main execution
problems = read_problems()
all_samples = []
batch_size = args.batch_size
total_samples = len(problems) * args.samples_per_task

with util.get_progress() as progress:
sample_task = progress.add_task("[red]Sample", total=total_samples, name="Creating sample jobs")
generate_task = progress.add_task("[green]Sample", total=total_samples, name="Generating samples")

for i in range(0, len(problems), batch_size):
batch_problems = dict(list(problems.items())[i:i + batch_size])
batch_samples = process_batch(batch_problems, args.samples_per_task, progress, sample_task, generate_task)
all_samples.extend(batch_samples)

# Optional: Clear CUDA cache to free up memory
if torch.cuda.is_available():
torch.cuda.empty_cache()

# Save output
print(f" -- Saving: {args.output}")
write_jsonl(args.output, samples)
write_jsonl(args.output, all_samples)

# Optionally launch eval script

if args.eval:
subprocess.run(["evaluate_functional_correctness", args.output])

subprocess.run(["evaluate_functional_correctness", args.output])