Commit 29192af: model(random guess) and eval(acc)

maysonma committed Sep 27, 2023
1 parent 1b6f89e

Showing 5 changed files with 113 additions and 0 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -0,0 +1,2 @@
.idea
data
8 changes: 8 additions & 0 deletions .idea/.gitignore


Binary file added models/__pycache__/utils.cpython-38.pyc
70 changes: 70 additions & 0 deletions models/random_chance.py
@@ -0,0 +1,70 @@
# Adapted from: https://github.com/lupantech/ScienceQA/blob/main/models/run_gpt3.py

import argparse
import json
import random
from typing import List

from utils import load_data, get_result_file, new_acc, compute_acc

parser = argparse.ArgumentParser()
parser.add_argument('--data_root', type=str, default='data/maplm_v0.1')
parser.add_argument('--output_dir', type=str, default='runs')
parser.add_argument('--test_split', type=str, default='test')
parser.add_argument('--test_number', type=int, default=-1, help='Number of test frames to run (default: -1, all)')
parser.add_argument('--exp_label', type=str, default='exp_random', help='Experiment label')
parser.add_argument('--random_seed', type=int, default=1, help='Random seed')
parser.add_argument('--debug', action='store_true', help='Debug mode')

args = parser.parse_args()

# Accuracy counters; per-question buckets are added lazily in the loop below.
results = dict(
    question_overall=new_acc(),
    frame_overall=new_acc(),
)

if __name__ == "__main__":
    print('===== Input Arguments =====')
    print(json.dumps(vars(args), indent=4, sort_keys=True))

    random.seed(args.random_seed)

    frames, frame_ids = load_data(args)
    result_file_name = get_result_file(args)  # output path; not written to in this commit

    for i, frame_id in enumerate(frame_ids):
        frame = frames[frame_id]

        image = frame['image']
        qas = frame['qa']

        corrects = []
        for j, qa in enumerate(qas):
            # Only closed-choice questions have a fixed answer set to guess from.
            if qa['task'] != 'closed choice':
                continue
            question = qa['question']
            choices: List[str] = qa['choices']
            true_answer: int = qa['answer']
            random_guess: int = random.randint(0, len(choices) - 1)

            if question not in results:
                results[question] = new_acc()

            correct = bool(random_guess == true_answer)
            corrects.append(correct)

            results[question]['total'] += 1
            results[question]['correct'] += int(correct)
            results['question_overall']['total'] += 1
            results['question_overall']['correct'] += int(correct)

        # A frame counts as correct only if all of its closed-choice questions are.
        results['frame_overall']['total'] += 1
        results['frame_overall']['correct'] += int(all(corrects))

    print('===== Results =====')
    acc_dict = compute_acc(results)
    print(json.dumps(acc_dict, indent=4, sort_keys=True))
    print(json.dumps(results, indent=4, sort_keys=True))
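
Sanity check for the numbers this baseline should print: a uniform guess over n choices is correct with probability 1/n, so per-question accuracy should land near 100/n percent. A minimal sketch of that expectation; the question counts below are hypothetical, chosen only for illustration:

# Expected accuracy of uniform random guessing is 1/n for an n-way question.
question_counts = {2: 100, 4: 300}  # n_choices -> number of questions (hypothetical)
expected_correct = sum(count / n for n, count in question_counts.items())
total_questions = sum(question_counts.values())
print(f'expected accuracy: {expected_correct / total_questions * 100:.1f}%')  # ~31.2%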


33 changes: 33 additions & 0 deletions models/utils.py
@@ -0,0 +1,33 @@
import json
import os
from typing import Dict


def load_data(args):
    # Load QA annotations and the split lists of frame ids from the data root.
    with open(os.path.join(args.data_root, 'problems.json'), 'r') as f:
        problems = json.load(f)
    with open(os.path.join(args.data_root, 'pid_splits.json'), 'r') as f:
        pid_splits = json.load(f)

    frame_ids = pid_splits[args.test_split]
    frame_ids = frame_ids[:args.test_number] if args.test_number > 0 else frame_ids
    print(f'Number of test frames: {len(frame_ids)}\n')

    return problems, frame_ids


def get_result_file(args):
    # NB: the path sits under data_root; the --output_dir flag is unused here.
    result_file = f"{args.data_root}/{args.exp_label}_seed_{args.random_seed}.json"
    return result_file


def new_acc():
    # Fresh counter for one accuracy bucket.
    return {
        'total': 0,
        'correct': 0
    }


def compute_acc(results: Dict[str, Dict[str, int]]):
    # Convert each correct/total counter into a percentage.
    acc_dict = {}
    for metric in results:
        acc_dict[metric] = results[metric]['correct'] / results[metric]['total'] * 100
    return acc_dict
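
load_data() expects two JSON files under --data_root. Their schema is not included in this commit, but the fields read by random_chance.py (image, qa, task, question, choices, answer) suggest a layout like the fixture below; every value in it is hypothetical and for illustration only:

# Hypothetical minimal fixture matching the fields random_chance.py reads;
# the real MAPLM data will differ in content and size.
import json
import os

root = 'data/maplm_v0.1'
os.makedirs(root, exist_ok=True)

problems = {
    'frame_000001': {
        'image': 'frame_000001.jpg',
        'qa': [
            {
                'task': 'closed choice',
                'question': 'How many lanes are present?',  # hypothetical
                'choices': ['one', 'two', 'three', 'four'],
                'answer': 1,  # index into choices
            },
        ],
    },
}
pid_splits = {'test': ['frame_000001']}

with open(os.path.join(root, 'problems.json'), 'w') as f:
    json.dump(problems, f)
with open(os.path.join(root, 'pid_splits.json'), 'w') as f:
    json.dump(pid_splits, f)

With files like these in place, the script should run end to end; over many seeds, the single four-way question is guessed correctly about 25% of the time.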
