-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathbatch_prediction.py
More file actions
347 lines (280 loc) · 13.3 KB
/
batch_prediction.py
File metadata and controls
347 lines (280 loc) · 13.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
#!/usr/bin/env python3
"""
Batch prediction orchestrator for running multiple models and prediction modes efficiently.
Usage examples:
python batch_prediction.py --mode direct --model claude --runs 5
python batch_prediction.py --mode narrative --models claude,gpt4,gemini --runs 3 --parallel
python batch_prediction.py --mode both --model deepseek --max-concurrent 2
"""
import argparse
import asyncio
import json
import logging
import os
import sys
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Optional, Tuple
# Import registry system
from models import (
MODEL_REGISTRY, get_prediction, get_baseline_prediction
)
# Registry key -> provider model-id string (used in output filenames and
# per-run result keys); derived once from the imported MODEL_REGISTRY.
MODEL_NAMES = {k: v['model_id'] for k, v in MODEL_REGISTRY.items()}
# Extract rate limits from registry (registry key -> max concurrent requests,
# consumed by run_batch_predictions when sizing each model's semaphore).
RATE_LIMITS = {k: v['rate_limit'] for k, v in MODEL_REGISTRY.items()}
def setup_logging(log_filename: str) -> None:
    """Configure root logging to write to both a file and stdout.

    Args:
        log_filename: Path of the log file. The parent directory is
            created if missing — logging.FileHandler raises
            FileNotFoundError otherwise (e.g. fresh checkout with no logs/).
    """
    log_dir = os.path.dirname(log_filename)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler(log_filename, mode='a', encoding='utf-8'),
            logging.StreamHandler(sys.stdout)
        ]
    )
def setup_question_logger(question_id: int, model_name: str, mode: str) -> logging.Logger:
    """Create (or reset) a dedicated file logger for one question/model/mode.

    Args:
        question_id: Question id, embedded in the log filename.
        model_name: Model identifier; path separators are replaced so it
            is safe to embed in a filename.
        mode: Prediction mode ('direct', 'narrative', or 'baseline').

    Returns:
        A logging.Logger appending to logs/<id>_<model>_<mode>.log.
    """
    # Sanitize model name for filename (model ids may contain '/' or '\')
    safe_model_name = model_name.replace('/', '_').replace('\\', '_')
    # FileHandler fails with FileNotFoundError if logs/ does not exist yet.
    os.makedirs('logs', exist_ok=True)
    log_filename = f"logs/{question_id}_{safe_model_name}_{mode}.log"
    logger = logging.getLogger(f"{question_id}_{model_name}_{mode}")
    logger.setLevel(logging.INFO)
    # Clear handlers from earlier calls so repeated setup does not duplicate
    # every record; close them too, or the old file descriptors leak.
    for handler in logger.handlers[:]:
        logger.removeHandler(handler)
        handler.close()
    file_handler = logging.FileHandler(log_filename, mode='a', encoding='utf-8')
    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    return logger
def log_question_reasoning(question_id: int, reasoning: str, question_title: str,
                           model_name: str, mode: str, run_number: int):
    """Append one run's model reasoning to the per-question log file."""
    question_logger = setup_question_logger(question_id, model_name, mode)
    question_logger.info(f"Question: {question_title}")
    question_logger.info(f"Run {run_number}:\n{reasoning}\n")
def list_questions(dataset_path: str) -> List[Dict]:
    """Load questions from a JSON dataset file.

    Each returned dict carries the fields the prediction prompts need;
    optional fields ('resolution_criteria', 'background', 'fine_print',
    'scheduled_resolve_time') default to '' when absent.
    """
    with open(dataset_path, 'r', encoding='utf-8') as handle:
        raw_items = json.load(handle)
    questions = []
    for item in raw_items:
        questions.append({
            'id': item['id'],
            'title': item['title'],
            'resolution_criteria': item.get('resolution_criteria', ''),
            'background': item.get('background', ''),
            'fine_print': item.get('fine_print', ''),
            'open_time': item['open_time'],
            'scheduled_resolve_time': item.get('scheduled_resolve_time', ''),
        })
    return questions
def get_news_for_question(question_id: int, news_set: str = 'aibq3') -> str:
    """Return the stored news text for *question_id* from the given news set.

    Unknown *news_set* values fall back to the aibq3 file; a missing file
    or an unmatched question id yields a placeholder string.
    """
    news_files = {
        'aibq3': 'data/aibq3_news.json',
        'aibq4': 'data/aibq4_RTnews.json'
    }
    news_file = news_files.get(news_set, 'data/aibq3_news.json')
    try:
        with open(news_file, 'r', encoding='utf-8') as handle:
            entries = json.load(handle)
    except FileNotFoundError:
        logging.warning(f"News file {news_file} not found")
        return "No news found for this question."
    for entry in entries:
        if entry['question_id'] == question_id:
            return entry['news']
    return "No news found for this question."
def get_output_filename(model_key: str, mode: str, dataset: str = None, custom_suffix: str = None) -> str:
    """Build the data/ output path for one model, mode and dataset combo.

    Path separators in the model id are replaced with '-' so the id can
    be embedded in a filename; an optional suffix is appended last.
    """
    safe_model_name = MODEL_NAMES[model_key].replace('/', '-').replace('\\', '-')
    dataset_prefix = 'aibq4_subset' if dataset == 'aibq4' else 'aibq3'
    name_parts = [dataset_prefix, 'predictions', mode, safe_model_name]
    if custom_suffix:
        name_parts.append(custom_suffix)
    return f"data/{'_'.join(name_parts)}.json"
def log_questions_json(questions_data: List[Dict], filename: str):
    """Merge question prediction entries into a JSON file on disk.

    Existing entries (matched on "question_id") are updated in place;
    new ones are appended. The file is created if it does not exist.

    Args:
        questions_data: List of dicts, each with at least "question_id".
        filename: Target JSON file path.
    """
    # BUG FIX: both log messages previously printed the literal text
    # "(unknown)" instead of the target filename.
    logging.info(f"Adding {len(questions_data)} items to {filename}")
    # Create the parent directory if needed; guard against a bare filename
    # whose dirname is '' (os.makedirs('', exist_ok=True) raises).
    parent_dir = os.path.dirname(filename)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    try:
        # Read existing data if file exists
        if os.path.exists(filename):
            with open(filename, 'r', encoding='utf-8') as json_file:
                existing_data = json.load(json_file)
        else:
            existing_data = []
        # Update existing entries or add new ones
        for new_entry in questions_data:
            existing_entry = next((item for item in existing_data
                                   if item["question_id"] == new_entry["question_id"]), None)
            if existing_entry:
                existing_entry.update(new_entry)
            else:
                existing_data.append(new_entry)
        # Write all questions back to the JSON file
        with open(filename, 'w', encoding='utf-8') as json_file:
            json.dump(existing_data, json_file, ensure_ascii=False, indent=2)
        logging.info(f"Successfully wrote {len(existing_data)} total items to {filename}")
    except Exception as e:
        logging.error(f"Error writing to {filename}: {str(e)}")
async def run_single_prediction(question: Dict, model_key: str, mode: str,
                                run_number: int, semaphore: asyncio.Semaphore,
                                news_set: str = 'aibq3') -> Tuple[str, int]:
    """Run one rate-limited prediction for a question/model/mode.

    Args:
        question: Question dict as produced by list_questions().
        model_key: Key into MODEL_REGISTRY.
        mode: 'direct', 'narrative', or 'baseline'.
        run_number: Index of this run, echoed back in the return value.
        semaphore: Concurrency limiter shared across a model's tasks.
        news_set: Which news file to pull articles from.

    Returns:
        Tuple of (model reasoning text, run_number).

    Raises:
        ValueError: If model_key is not in MODEL_REGISTRY.
    """
    async with semaphore:
        # Check if model exists in registry
        if model_key not in MODEL_REGISTRY:
            raise ValueError(f"Unknown model: {model_key}. Available models: {list(MODEL_REGISTRY.keys())}")
        formatted_articles = get_news_for_question(question['id'], news_set)
        # Run the blocking API call off the event loop. get_running_loop()
        # replaces get_event_loop(), which is deprecated inside coroutines;
        # passing executor=None reuses the loop's default thread pool rather
        # than building and tearing down a ThreadPoolExecutor per prediction.
        loop = asyncio.get_running_loop()
        if mode == 'baseline':
            result = await loop.run_in_executor(
                None, get_baseline_prediction, model_key, question
            )
        else:  # direct or narrative mode
            result = await loop.run_in_executor(
                None, get_prediction, model_key, question, formatted_articles, mode
            )
        # Log the result under the provider model name when known.
        model_name = MODEL_NAMES.get(model_key, model_key)
        log_question_reasoning(question['id'], result, question['title'],
                               model_name, mode, run_number)
        return result, run_number
async def process_question_runs(question: Dict, model_key: str, mode: str,
                                num_runs: int, semaphore: asyncio.Semaphore,
                                news_set: str = 'aibq3') -> Dict:
    """Fan out num_runs predictions for one question and collect the results.

    Failed runs are logged and skipped; successful runs land in the
    returned dict under '<model_name>_reasoning<run_number>' keys.
    """
    question_id = question['id']
    model_name = MODEL_NAMES[model_key]
    print(f"Processing question {question_id} with {model_key} ({num_runs} runs)")
    run_tasks = [
        run_single_prediction(question, model_key, mode, run, semaphore, news_set)
        for run in range(num_runs)
    ]
    outcomes = await asyncio.gather(*run_tasks, return_exceptions=True)
    question_data = {
        "question_id": question_id,
        "question_title": question['title'],
    }
    for result in outcomes:
        if isinstance(result, Exception):
            logging.error(f"Error in question {question_id}: {result}")
        else:
            reasoning, run_number = result
            question_data[f"{model_name}_reasoning{run_number}"] = reasoning
    return question_data
async def run_batch_predictions(questions: List[Dict], models: List[str],
                                mode: str, num_runs: int, max_concurrent: int,
                                output_suffix: str = None, news_set: str = 'aibq3'):
    """Run predictions for every model over every question, saving as it goes.

    Results are flushed to disk after each batch of 50 questions so a
    crash loses at most one batch of work. Returns a dict mapping each
    model key to its list of per-question result dicts.
    """
    all_results = {}
    batch_size = 50  # Process 50 questions at a time
    total_batches = (len(questions) + batch_size - 1) // batch_size
    for model_key in models:
        if model_key not in MODEL_REGISTRY:
            logging.warning(f"Skipping {model_key} - not found in MODEL_REGISTRY")
            continue
        # Semaphore caps in-flight requests at the model's own rate limit.
        rate_limit = min(max_concurrent, RATE_LIMITS.get(model_key, 10))
        semaphore = asyncio.Semaphore(rate_limit)
        logging.info(f"Starting {model_key} predictions with max {rate_limit} concurrent requests")
        model_results = []
        # The output path is invariant per model/mode, so compute it once.
        output_filename = get_output_filename(model_key, mode, news_set, output_suffix)
        for start in range(0, len(questions), batch_size):
            chunk = questions[start:start + batch_size]
            chunk_outcomes = await asyncio.gather(
                *(process_question_runs(q, model_key, mode, num_runs, semaphore, news_set)
                  for q in chunk),
                return_exceptions=True
            )
            for outcome in chunk_outcomes:
                if isinstance(outcome, Exception):
                    logging.error(f"Batch error: {outcome}")
                else:
                    model_results.append(outcome)
            # Persist progress after every batch.
            log_questions_json(model_results, output_filename)
            logging.info(f"Completed batch {start//batch_size + 1}/{total_batches} for {model_key}")
        all_results[model_key] = model_results
        logging.info(f"Completed all predictions for {model_key}")
    return all_results
def parse_arguments():
    """Parse command line arguments for the batch runner.

    Returns:
        argparse.Namespace carrying the run configuration.
    """
    parser = argparse.ArgumentParser(description='Batch prediction runner')
    parser.add_argument('--mode', choices=['direct', 'narrative', 'baseline', 'both'],
                        default='direct', help='Prediction mode')
    parser.add_argument('--model', help='Single model to run (claude, gpt4, gemini, deepseek)')
    parser.add_argument('--models', help='Comma-separated list of models to run')
    parser.add_argument('--runs', type=int, default=5, help='Number of runs per question')
    parser.add_argument('--max-concurrent', type=int, default=10,
                        help='Maximum concurrent requests per model')
    parser.add_argument('--output-suffix', help='Custom suffix for output files')
    parser.add_argument('--dataset', help='Dataset to use',
                        choices=['aibq3', 'aibq4'], default='aibq3')
    # BUG FIX: store_true defaults to False, so the old help text claiming
    # "(default: True)" was wrong. NOTE(review): this flag is not read
    # anywhere downstream in this file — processing is always async.
    parser.add_argument('--parallel', action='store_true',
                        help='Enable parallel processing (default: False)')
    return parser.parse_args()
def main():
    """Entry point: configure logging, resolve models/modes, run batches."""
    args = parse_arguments()
    # Set up logging with a timestamped file per invocation.
    log_filename = f"logs/batch_prediction_{int(time.time())}.log"
    setup_logging(log_filename)
    # Resolve the model list: --model wins, then --models, then a default.
    if args.model:
        models = [args.model]
    elif args.models:
        models = [m.strip() for m in args.models.split(',')]
    else:
        models = ['claude']  # default
    # Abort before any work if a model key is unknown.
    valid_models = set(MODEL_REGISTRY.keys())
    for model in models:
        if model not in valid_models:
            logging.error(f"Invalid model: {model}. Valid options: {', '.join(sorted(valid_models))}")
            sys.exit(1)
    # Map dataset choice to file path
    dataset_paths = {
        'aibq3': 'data_metaculus/metaculus_data_aibq3_wd.json',
        'aibq4': 'data_metaculus/metaculus_data_aibq4_subset_RT.json'
    }
    dataset_path = dataset_paths[args.dataset]
    questions = list_questions(dataset_path)
    logging.info(f"Loaded {len(questions)} questions from {dataset_path}")
    # 'both' expands to the two news-based prediction modes.
    modes = ['direct', 'narrative'] if args.mode == 'both' else [args.mode]
    for mode in modes:
        logging.info(f"Starting {mode} mode predictions")
        try:
            results = asyncio.run(run_batch_predictions(
                questions=questions,
                models=models,
                mode=mode,
                num_runs=args.runs,
                max_concurrent=args.max_concurrent,
                output_suffix=args.output_suffix,
                news_set=args.dataset
            ))
            logging.info(f"Completed processing {sum(len(r) for r in results.values())} total questions")
            logging.info(f"Completed {mode} mode predictions for {len(models)} models")
        except Exception as e:
            logging.error(f"Error in {mode} mode: {e}")
            raise
# Script entry point guard: run the batch orchestrator only when executed
# directly, not when imported as a module.
if __name__ == "__main__":
    main()