diff --git a/eval/humaneval.py b/eval/humaneval.py
index ca3ca008..24fc952f 100644
--- a/eval/humaneval.py
+++ b/eval/humaneval.py
@@ -1,27 +1,33 @@
from __future__ import annotations
-import sys, os
+import sys
+import os
+import argparse
+import subprocess
+import torch
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from human_eval.data import write_jsonl, read_problems
from exllamav2 import model_init
from exllamav2 import ExLlamaV2Cache, ExLlamaV2Cache_Q4, ExLlamaV2Cache_Q6, ExLlamaV2Cache_Q8
from exllamav2.generator import ExLlamaV2DynamicGenerator, ExLlamaV2DynamicJob, ExLlamaV2Sampler
-import argparse, contextlib, subprocess
import util
# Args
-parser = argparse.ArgumentParser(description = "Run HumanEval evaluation on EXL2 model")
-parser.add_argument("-o", "--output", type = str, help = "Output .jsonl filename", required = True)
-parser.add_argument("-cs", "--cache_size", type = int, default = None)
-parser.add_argument("-spt", "--samples_per_task", type = int, default = 200)
-parser.add_argument("-cq4", "--cache_q4", action = "store_true", help = "Use Q4 cache")
-parser.add_argument("-cq6", "--cache_q6", action = "store_true", help = "Use Q6 cache")
-parser.add_argument("-cq8", "--cache_q8", action = "store_true", help = "Use Q8 cache")
-parser.add_argument("--max_tokens", type = int, default = 768, help = "Max number of tokens for each completion")
-parser.add_argument("-pf", "--prompt_format", type = str, help = "Instruct format to apply. Default is raw completion (for base models) ")
-parser.add_argument("-v", "--verbose", action = "store_true", help = "Spam completions to console while generating")
-parser.add_argument("-e", "--eval", action = "store_true", help = "Run evaluation script on output file after sampling")
-parser.add_argument("-temp", "--temperature", type = float, help = "Sampling temperature (0 for greedy), default: 0.6")
+parser = argparse.ArgumentParser(description="Run HumanEval evaluation on EXL2 model")
+parser.add_argument("-o", "--output", type=str, help="Output .jsonl filename", required=True)
+parser.add_argument("-cs", "--cache_size", type=int, default=None)
+parser.add_argument("-spt", "--samples_per_task", type=int, default=200)
+parser.add_argument("-cq4", "--cache_q4", action="store_true", help="Use Q4 cache")
+parser.add_argument("-cq6", "--cache_q6", action="store_true", help="Use Q6 cache")
+parser.add_argument("-cq8", "--cache_q8", action="store_true", help="Use Q8 cache")
+parser.add_argument("--max_tokens", type=int, default=768, help="Max number of tokens for each completion")
+parser.add_argument("-pf", "--prompt_format", type=str,
+ help="Instruct format to apply. Default is raw completion (for base models)")
+parser.add_argument("-v", "--verbose", action="store_true", help="Spam completions to console while generating")
+parser.add_argument("-e", "--eval", action="store_true", help="Run evaluation script on output file after sampling")
+parser.add_argument("-temp", "--temperature", type=float, default=0.6, help="Sampling temperature (0 for greedy)")
+parser.add_argument("-bs", "--batch_size", type=int, default=50, help="Number of problems to process in each batch")
model_init.add_args(parser)
args = parser.parse_args()
@@ -37,38 +43,27 @@
# Prompt formats
prompt_formats = {
- "raw": (
- "```python\n{{problem}} ",
- " "
- ),
+ "raw": ("```python\n{{problem}} ", " "),
"granite": (
"Question:\nComplete the following Python function:\n\n{{problem}}\n\nAnswer:\n"
"Sure! Here is how you might implement the function:\n\n```python\n{{problem}}",
" "
),
"llama": (
-        "[INST] <<SYS>>\n"
-        "You are a helpful AI coding assistant.\n"
-        "<</SYS>>\n\n"
- "Complete the following Python function:\n\n"
- "{{problem}} [/INST] "
+        "[INST] <<SYS>>\nYou are a helpful AI coding assistant.\n<</SYS>>\n\n"
+ "Complete the following Python function:\n\n{{problem}} [/INST] "
"Sure! Here is how you might implement the function:\n\n```python\n{{problem}}",
" "
),
"llama3": (
- "<|start_header_id|>system<|end_header_id|>\n\n"
- "You are a helpful AI coding assistant.<|eot_id|>"
- "<|start_header_id|>user<|end_header_id|>\n\n"
- "Complete the following Python function:\n\n{{problem}}<|eot_id|>"
- "<|start_header_id|>assistant<|end_header_id|>\n\n"
- "Sure! Here is how you might implement the function:\n\n```python\n{{problem}}",
+ "<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful AI coding assistant.<|eot_id|>"
+ "<|start_header_id|>user<|end_header_id|>\n\nComplete the following Python function:\n\n{{problem}}<|eot_id|>"
+ "<|start_header_id|>assistant<|end_header_id|>\n\nSure! Here is how you might implement the function:\n\n```python\n{{problem}}",
" "
),
"gemma": (
-        "<start_of_turn>user\n"
-        "Complete the following Python function:\n\n{{problem}}<|eot_id|>"
-        "<start_of_turn>model\n"
- "```python\n{{problem}}",
+        "<start_of_turn>user\nComplete the following Python function:\n\n{{problem}}<|eot_id|>"
+        "<start_of_turn>model\n```python\n{{problem}}",
" "
)
}
@@ -88,93 +83,71 @@
model_init.print_options(args)
model, tokenizer = model_init.init(
args,
- allow_auto_split = True,
- progress = True,
- max_output_len = 4,
- max_input_len = 2048
+ allow_auto_split=True,
+ progress=True,
+ max_output_len=4,
+ max_input_len=2048
)
-if args.cache_q4: cache_type = ExLlamaV2Cache_Q4
-elif args.cache_q6: cache_type = ExLlamaV2Cache_Q6
-elif args.cache_q8: cache_type = ExLlamaV2Cache_Q8
-else: cache_type = ExLlamaV2Cache
+if args.cache_q4:
+ cache_type = ExLlamaV2Cache_Q4
+elif args.cache_q6:
+ cache_type = ExLlamaV2Cache_Q6
+elif args.cache_q8:
+ cache_type = ExLlamaV2Cache_Q8
+else:
+ cache_type = ExLlamaV2Cache
cache = cache_type(
model,
- lazy = not model.loaded,
- max_seq_len = args.cache_size or model.config.max_seq_len
+ lazy=not model.loaded,
+ max_seq_len=args.cache_size or model.config.max_seq_len
)
if not model.loaded:
- model.load_autosplit(cache, progress = True)
+ model.load_autosplit(cache, progress=True)
# Generator
generator = ExLlamaV2DynamicGenerator(
- model = model,
- cache = cache,
- tokenizer = tokenizer,
- max_batch_size = 256,
- max_q_size = 4
+ model=model,
+ cache=cache,
+ tokenizer=tokenizer,
+ max_batch_size=256,
+ max_q_size=4
)
gen_settings = ExLlamaV2Sampler.Settings(
- token_repetition_penalty = 1.0,
- temperature = 0.6,
- top_k = 50,
- top_p = 0.6
+ token_repetition_penalty=1.0,
+ temperature=args.temperature,
+ top_k=50,
+ top_p=0.6
)
-# Get problems
-
-problems = read_problems()
-num_samples_per_task = args.samples_per_task
-
-# Create jobs
-with util.get_progress() as progress:
-
- task1 = progress.add_task("[red]Sample", total = len(problems) * num_samples_per_task, name = "Creating sample jobs")
- for problem_id, problem in problems.items():
+def process_batch(batch_problems, samples_per_task, progress, sample_task, generate_task):
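+    """
+    Enqueue samples_per_task completion jobs for each problem in batch_problems, then run the
+    generator until the batch is drained and return the collected samples.
+    """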
+ samples = []
+ for problem_id, problem in batch_problems.items():
b_problem = problem["prompt"]
f_problem = prompt_format.replace("{{problem}}", b_problem)
input_ids = tokenizer.encode(f_problem, encode_special_tokens=True, add_bos=True)
- for s in range(num_samples_per_task):
-
+        for s in range(samples_per_task):
job = ExLlamaV2DynamicJob(
- input_ids = input_ids,
- gen_settings = gen_settings,
- max_new_tokens = args.max_tokens,
- stop_conditions = [tokenizer.eos_token_id],
- token_healing = True,
- identifier = (problem_id, s),
- min_new_tokens = 6
+ input_ids=input_ids,
+ gen_settings=gen_settings,
+ max_new_tokens=args.max_tokens,
+ stop_conditions=[tokenizer.eos_token_id],
+ token_healing=True,
+ identifier=(problem_id, s),
+ min_new_tokens=6
)
-
generator.enqueue(job)
- progress.update(task1, advance = 1)
-
-# Collect samples here
-
-samples = []
-
-# Work
-
-total_jobs = generator.num_remaining_jobs()
-cm = contextlib.nullcontext() if args.verbose else util.get_progress()
-with cm as progress:
-
- if not args.verbose:
- task1 = progress.add_task("[red]Sample", total = total_jobs, name = "Generating samples")
+ progress.update(sample_task, advance=1)
while generator.num_remaining_jobs():
-
results = generator.iterate()
for result in results:
-
- # End sample if generator says EOS or if there is a non-indented line at the end of the output
-
job = result["job"]
eos = False
completion = job.full_completion
@@ -186,32 +159,49 @@
eos = True
eos = eos or result["eos"]
- # Collect completed sample
-
if eos:
identifier = result["identifier"]
- sample = problems[identifier[0]]["prompt"] + prefix + completion.strip()
+ sample = batch_problems[identifier[0]]["prompt"] + prefix + completion.strip()
if not result["eos"]:
generator.cancel(job)
if args.verbose:
print("----------------------------------------------------------------------")
- print(f" ** Problem {identifier[0]}, sample {identifier[1] + 1} / {num_samples_per_task}")
+                    print(f" ** Problem {identifier[0]}, sample {identifier[1] + 1} / {samples_per_task}")
print("----------------------------------------------------------------------")
print(sample)
print()
else:
- progress.update(task1, advance = 1)
+ progress.update(generate_task, advance=1)
- samples.append(dict(task_id = identifier[0], completion = prefix + completion.strip()))
+ samples.append(dict(task_id=identifier[0], completion=prefix + completion.strip()))
-# Save output
+ return samples
+
+
+# Main execution
+problems = read_problems()
+all_samples = []
+batch_size = args.batch_size
+total_samples = len(problems) * args.samples_per_task
+
+with util.get_progress() as progress:
+ sample_task = progress.add_task("[red]Sample", total=total_samples, name="Creating sample jobs")
+ generate_task = progress.add_task("[green]Sample", total=total_samples, name="Generating samples")
+
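+    # Process problems in batches to keep the number of queued jobs bounded;
+    # each batch enqueues samples_per_task jobs per problem.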
+ for i in range(0, len(problems), batch_size):
+ batch_problems = dict(list(problems.items())[i:i + batch_size])
+ batch_samples = process_batch(batch_problems, args.samples_per_task, progress, sample_task, generate_task)
+ all_samples.extend(batch_samples)
+ # Optional: Clear CUDA cache to free up memory
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+# Save output
print(f" -- Saving: {args.output}")
-write_jsonl(args.output, samples)
+write_jsonl(args.output, all_samples)
# Optionally launch eval script
-
if args.eval:
- subprocess.run(["evaluate_functional_correctness", args.output])
-
+ subprocess.run(["evaluate_functional_correctness", args.output])
\ No newline at end of file
diff --git a/eval/mmlu.py b/eval/mmlu.py
index 0d3f7d8a..0b256b85 100644
--- a/eval/mmlu.py
+++ b/eval/mmlu.py
@@ -1,197 +1,187 @@
from __future__ import annotations
-import sys, os
-sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
-from exllamav2 import model_init
-from exllamav2 import ExLlamaV2Cache, ExLlamaV2Cache_Q4, ExLlamaV2Cache_Q6, ExLlamaV2Cache_Q8
-from exllamav2.generator import ExLlamaV2DynamicGenerator, ExLlamaV2DynamicJob, ExLlamaV2Sampler
-import argparse, contextlib
-import torch
+import sys
+import os
+import argparse
+import random
+import torch
+
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import util
-import random
-
-# Args
-
-parser = argparse.ArgumentParser(description = "Run MMLU evaluation on EXL2 model")
-parser.add_argument("-cs", "--cache_size", type = int, default = None)
-parser.add_argument("-cq4", "--cache_q4", action = "store_true", help = "Use Q4 cache")
-parser.add_argument("-cq6", "--cache_q6", action = "store_true", help = "Use Q6 cache")
-parser.add_argument("-cq8", "--cache_q8", action = "store_true", help = "Use Q8 cache")
-parser.add_argument("-sub", "--subjects", type = str, default = "all", help = "Comma-separated list of categories to test, or 'all'")
-parser.add_argument("-fs", "--fewshot_examples", type = int, default = 5, help = "Number of examples for fewshot examples, max 5")
-parser.add_argument("-shf", "--shuffle", action = "store_true", help = "Shuffle choices randomly")
+from exllamav2 import model_init, ExLlamaV2Cache, ExLlamaV2Cache_Q4, ExLlamaV2Cache_Q6, ExLlamaV2Cache_Q8
+from exllamav2.generator import ExLlamaV2DynamicGenerator, ExLlamaV2DynamicJob, ExLlamaV2Sampler
+from rich.progress import Progress, TextColumn, BarColumn, TaskProgressColumn
+from collections import defaultdict
+
+# Argument Parsing
+parser = argparse.ArgumentParser(description="Run MMLU evaluation on EXL2 model")
+parser.add_argument("-cs", "--cache_size", type=int, default=None)
+parser.add_argument("-cq4", "--cache_q4", action="store_true", help="Use Q4 cache")
+parser.add_argument("-cq6", "--cache_q6", action="store_true", help="Use Q6 cache")
+parser.add_argument("-cq8", "--cache_q8", action="store_true", help="Use Q8 cache")
+parser.add_argument("-sub", "--subjects", type=str, default="all", help="Comma-separated list of categories to test, or 'all'")
+parser.add_argument("-fs", "--fewshot_examples", type=int, default=5, help="Number of examples for fewshot examples, max 5")
+parser.add_argument("-shf", "--shuffle", action="store_true", help="Shuffle choices randomly")
+parser.add_argument("-bs", "--batch_size", type=int, default=128, help="Number of problems to process in each batch. Decrease to prevent potential OOM errors")
model_init.add_args(parser)
args = parser.parse_args()
-# Init model and cache
-
-model_init.check_args(args)
-model_init.print_options(args)
-model, tokenizer = model_init.init(
- args,
- allow_auto_split = True,
- progress = True,
- max_output_len = 1,
- max_input_len = 2048
-)
-
-if args.cache_q4: cache_type = ExLlamaV2Cache_Q4
-elif args.cache_q6: cache_type = ExLlamaV2Cache_Q6
-elif args.cache_q8: cache_type = ExLlamaV2Cache_Q8
-else: cache_type = ExLlamaV2Cache
-cache = cache_type(
- model,
- lazy = not model.loaded,
- max_seq_len = args.cache_size or model.config.max_seq_len
-)
-
-if not model.loaded:
- model.load_autosplit(cache, progress = True)
-
-# Generator
-
-generator = ExLlamaV2DynamicGenerator(
- model = model,
- cache = cache,
- tokenizer = tokenizer,
- max_batch_size = 1024,
- max_q_size = 1
-)
-
-c_options = "ABCD"
-
-gen_settings = ExLlamaV2Sampler.Settings(
- token_repetition_penalty = 1.0,
- temperature = 1.0,
- top_k = 10,
- top_p = 1.0,
-)
-
-token_map = [tokenizer.single_id(piece) for piece in [" " + c for c in c_options]]
-token_rmap = { token_map[i]: i for i in range(len(c_options)) }
-gen_settings.allow_tokens(tokenizer, token_map)
-
-# Get dataset
-
-dataset_dev = util.get_dataset("cais/mmlu", "all", "dev")
-dataset_all = util.get_dataset("cais/mmlu", "all", "test")
-dataset_dev = sorted(dataset_dev, key = lambda q: q["subject"])
-dataset_all = sorted(dataset_all, key = lambda q: q["subject"])
-
-all_subjects = set([q["subject"] for q in dataset_dev])
-if args.subjects != "all":
- sel_subjects = args.subjects.split(",")
- for s in sel_subjects:
- if s not in all_subjects:
- print(f"Subject: {s} is not present in dataset")
- sys.exit()
- all_subjects = set(sel_subjects)
-
-# Optionally shuffle
-
-if args.shuffle:
- for problem in dataset_all:
- if problem["subject"] in all_subjects:
- perm = random.sample(range(4), k = 4)
- problem["choices"] = [problem["choices"][i] for i in perm]
- problem["answer"] = perm.index(problem["answer"])
-
-# Format
-
+# Model and Cache Initialization
+def initialize_model_and_cache(args):
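+    """Initialize the EXL2 model, tokenizer and KV cache from the parsed command-line args."""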
+ try:
+ model_init.check_args(args)
+ model_init.print_options(args)
+ model, tokenizer = model_init.init(args, allow_auto_split=True, progress=True, max_output_len=1, max_input_len=2048)
+ except Exception as e:
+ print(f"Error initializing model: {e}")
+ sys.exit(1)
+
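+    # Select the KV cache implementation from the quantization flags (FP16 cache by default)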
+    if args.cache_q4:
+        cache_type = ExLlamaV2Cache_Q4
+    elif args.cache_q6:
+        cache_type = ExLlamaV2Cache_Q6
+    elif args.cache_q8:
+        cache_type = ExLlamaV2Cache_Q8
+    else:
+        cache_type = ExLlamaV2Cache
+
+ cache = cache_type(model, lazy=not model.loaded, max_seq_len=args.cache_size or model.config.max_seq_len)
+ if not model.loaded:
+ model.load_autosplit(cache, progress=True)
+
+ return model, tokenizer, cache
+
+# Dataset Loading and Preparation
+def load_and_prepare_datasets(args):
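+    """Load the MMLU dev and test splits, validate the requested subjects and optionally shuffle answer choices."""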
+ try:
+ dataset_dev = sorted(util.get_dataset("cais/mmlu", "all", "dev"), key=lambda q: q["subject"])
+ dataset_all = sorted(util.get_dataset("cais/mmlu", "all", "test"), key=lambda q: q["subject"])
+ except Exception as e:
+ print(f"Error loading datasets: {e}")
+ sys.exit(1)
+
+ all_subjects = set(q["subject"] for q in dataset_dev)
+ if args.subjects != "all":
+ sel_subjects = set(args.subjects.split(","))
+ invalid_subjects = sel_subjects - all_subjects
+ if invalid_subjects:
+ print(f"Subjects not present in dataset: {', '.join(invalid_subjects)}")
+ sys.exit(1)
+ all_subjects = sel_subjects
+
+ if args.shuffle:
+ for problem in dataset_all:
+ if problem["subject"] in all_subjects:
+ perm = random.sample(range(4), k=4)
+ problem["choices"] = [problem["choices"][i] for i in perm]
+ problem["answer"] = perm.index(problem["answer"])
+
+ return dataset_dev, dataset_all, all_subjects
+
+# Question Formatting
def format_question(question: str, choices: list[str], answer: int | None):
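+    """Format a question with its A-D choices; when answer is given, the answer letter is appended (fewshot examples)."""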
- f = question + "\n"
- for i, c in enumerate(c_options):
- f += c + ". " + choices[i] + "\n"
- f += "Answer:"
- if answer is not None:
- f += " " + c_options[answer] + "\n\n"
+ f = question + "\n" + "\n".join(f"{c}. {choices[i]}" for i, c in enumerate("ABCD")) + "\nAnswer:"
+ if answer is not None: f += f" {'ABCD'[answer]}\n\n"
return f
-# Fewshot preprompts
-
-preprompt_ids = {}
-with util.get_progress() as progress:
- task1 = progress.add_task("[red]Preprompts", total = len(all_subjects), name = "Preparing preprompts")
+# Preprompt Preparation
+def prepare_preprompts(all_subjects, dataset_dev, tokenizer, args, progress, task_id):
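+    """Build and tokenize a fewshot preprompt per subject, using up to --fewshot_examples questions from the dev split."""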
+ preprompt_ids = {}
for subject in all_subjects:
-
preprompt = f"The following are multiple choice questions (with answers) about {subject.replace('_', ' ')}.\n\n"
fewshots = 0
for pq in dataset_dev:
if fewshots == args.fewshot_examples: break
if pq["subject"] != subject: continue
preprompt += format_question(pq["question"], pq["choices"], pq["answer"])
- preprompt_ids[subject] = tokenizer.encode(preprompt, add_bos = True)
- progress.update(task1, advance = 1)
-
-# Questions
-
-total_jobs = 0
-for q in dataset_all:
- if q["subject"] in all_subjects:
- total_jobs += 1
-
-with util.get_progress() as progress:
- task1 = progress.add_task("[red]Questions", total=total_jobs, name="Preparing questions")
+ preprompt_ids[subject] = tokenizer.encode(preprompt, add_bos=True)
+ progress.update(task_id, advance=1)
+ return preprompt_ids
+# Job Preparation
+def prepare_jobs(dataset_all, all_subjects, preprompt_ids, tokenizer, gen_settings, batch_size, progress, task_id):
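+    """Yield lists of at most batch_size ExLlamaV2DynamicJob objects, one job per selected test question."""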
+ jobs = []
for q in dataset_all:
- if q["subject"] not in all_subjects:
- continue
-
+ if q["subject"] not in all_subjects: continue
prompt = format_question(q["question"], q["choices"], None)
- prompt_ids = tokenizer.encode(prompt, add_bos = False)
-
+ prompt_ids = tokenizer.encode(prompt, add_bos=False)
job = ExLlamaV2DynamicJob(
- input_ids = torch.cat([preprompt_ids[q["subject"]], prompt_ids], dim = -1),
- gen_settings = gen_settings,
- max_new_tokens = 1,
- return_top_tokens = 4,
- identifier = q,
+ input_ids=torch.cat([preprompt_ids[q["subject"]], prompt_ids], dim=-1),
+ gen_settings=gen_settings,
+ max_new_tokens=1,
+ return_top_tokens=4,
+ identifier=q,
)
+ jobs.append(job)
+ progress.update(task_id, advance=1)
- generator.enqueue(job)
- progress.update(task1, advance = 1)
+ if len(jobs) == batch_size:
+ yield jobs
+ jobs = []
-# Work
+ if jobs:
+ yield jobs
-with util.get_progress() as progress:
- task1 = progress.add_task("[red]Sample", total = total_jobs, name = "Testing")
+# Batch Processing
+def process_batch(generator, job_batch, token_map, token_rmap, progress, task_id):
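+    """Enqueue one batch of jobs, then drain the generator, recording per-question correctness and confidence."""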
+ for job in job_batch:
+ generator.enqueue(job)
while generator.num_remaining_jobs():
-
results = generator.iterate()
for result in results:
-
- if not result["eos"]:
- continue
-
- # Ignore completion and use top-K tokens only
-
- top_tokens = result["top_k_tokens"]
- top_probs = result["top_k_probs"]
- q = result["identifier"]
-
+ if not result["eos"]: continue
+ top_tokens, top_probs, q = result["top_k_tokens"], result["top_k_probs"], result["identifier"]
correct_answer = q["answer"]
for i in range(top_tokens.shape[-1]):
if top_tokens[0, 0, i].item() == token_map[correct_answer]:
confidence = top_probs[0, 0, i].item()
-
q["correct_answer_confidence"] = confidence
q["answer_correct"] = token_rmap[top_tokens[0, 0, 0].item()] == correct_answer
+ progress.update(task_id, advance=1)
+
+# Result Summarization
+def summarize_results(dataset_all):
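+    """Aggregate per-subject and overall accuracy and average confidence, then print a summary table."""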
+ results = defaultdict(lambda: {"total": 0, "correct": 0, "confidence_sum": 0})
+
+ for q in dataset_all:
+ if "answer_correct" not in q:
+ continue
+ subject = q["subject"]
+ results[subject]["total"] += 1
+ results["overall"]["total"] += 1
+ if q["answer_correct"]:
+ results[subject]["correct"] += 1
+ results["overall"]["correct"] += 1
+ results[subject]["confidence_sum"] += q["correct_answer_confidence"]
+ results["overall"]["confidence_sum"] += q["correct_answer_confidence"]
+
+ print("\nResults:")
+ print(f"{'Subject':<30} {'Accuracy':<10} {'Confidence':<10}")
+ print("-" * 50)
+ for subject, data in sorted(results.items()):
+ if subject != "overall":
+ acc = data["correct"] / data["total"] * 100
+ conf = data["confidence_sum"] / data["total"] * 100
+ print(f"{subject:<30} {acc:.2f}% {conf:.2f}%")
+
+ print("-" * 50)
+ overall = results["overall"]
+ overall_acc = overall["correct"] / overall["total"] * 100
+ overall_conf = overall["confidence_sum"] / overall["total"] * 100
+ print(f"{'Overall':<30} {overall_acc:.2f}% {overall_conf:.2f}%")
+
+# Main Execution
+model, tokenizer, cache = initialize_model_and_cache(args)
+dataset_dev, dataset_all, all_subjects = load_and_prepare_datasets(args)
+
+generator = ExLlamaV2DynamicGenerator(model=model, cache=cache, tokenizer=tokenizer, max_batch_size=1024, max_q_size=1)
+gen_settings = ExLlamaV2Sampler.Settings(token_repetition_penalty=1.0, temperature=1.0, top_k=10, top_p=1.0)
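+# Restrict sampling to the four answer tokens " A".." D" and map sampled token ids back to choice indices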
+token_map = [tokenizer.single_id(" " + c) for c in "ABCD"]
+token_rmap = {token_map[i]: i for i in range(len("ABCD"))}
+gen_settings.allow_tokens(tokenizer, token_map)
- progress.update(task1, advance = 1)
+progress = Progress(TextColumn("[progress.description]{task.description}"), BarColumn(), TaskProgressColumn(), TextColumn("{task.completed}/{task.total}"))
-# Summarize
+with progress:
+ preprompt_task = progress.add_task("[red]Preparing preprompts", total=len(all_subjects))
+ preprompt_ids = prepare_preprompts(all_subjects, dataset_dev, tokenizer, args, progress, preprompt_task)
-total = 0
-correct = 0
-confidence_sum = 0.0
+ total_jobs = sum(1 for q in dataset_all if q["subject"] in all_subjects)
+ preparation_task = progress.add_task("[green]Preparing questions", total=total_jobs)
+ processing_task = progress.add_task("[blue]Processing questions", total=total_jobs)
-for q in dataset_all:
- if not "answer_correct" in q:
- continue
- total += 1
- if q["answer_correct"]:
- correct += 1
- confidence_sum += q["correct_answer_confidence"]
+ for job_batch in prepare_jobs(dataset_all, all_subjects, preprompt_ids, tokenizer, gen_settings, args.batch_size, progress, preparation_task):
+ process_batch(generator, job_batch, token_map, token_rmap, progress, processing_task)
-print(f"Correct answers: {correct}/{total} = {correct/total*100:.2f}%")
-print(f"Confidence: {confidence_sum/total*100:.2f}%")
\ No newline at end of file
+summarize_results(dataset_all)
\ No newline at end of file