From 29192af9f11c5adba6486488a094415e43336b1a Mon Sep 17 00:00:00 2001 From: maysonma Date: Wed, 27 Sep 2023 14:40:59 -0400 Subject: [PATCH] model(random guess) and eval(acc) --- .gitignore | 2 + .idea/.gitignore | 8 +++ models/__pycache__/utils.cpython-38.pyc | Bin 0 -> 1194 bytes models/random_chance.py | 70 ++++++++++++++++++++++++ models/utils.py | 33 +++++++++++ 5 files changed, 113 insertions(+) create mode 100644 .gitignore create mode 100644 .idea/.gitignore create mode 100644 models/__pycache__/utils.cpython-38.pyc create mode 100644 models/random_chance.py create mode 100644 models/utils.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..3d031da --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +.idea +data \ No newline at end of file diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..13566b8 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# Default ignored files +/shelf/ +/workspace.xml +# Editor-based HTTP Client requests +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/models/__pycache__/utils.cpython-38.pyc b/models/__pycache__/utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1bba9599f87dda3658842d19203624b7ffcd2edb GIT binary patch literal 1194 zcmZWpNsALf6t1PWBr|S|OK^lBUUD%Yco7lBIcgAs9vY| z#A{7s3O*}X$zx85;KDm1u_t^Hfc8ZwBG7^4GITSVu|B~mI>%IWMeoRqJAzKgaO(`d zO~ylNy}dj$8AbQ^@IZ$!pTpEk5Qdym!wf(0PW_e~1cH8Lf_-9N>5Bdr<9}nUWw1N! 
# Adapted from: https://github.com/lupantech/ScienceQA/blob/main/models/run_gpt3.py
#
# Random-chance baseline: for every closed-choice question, guess a choice
# uniformly at random and accumulate per-question-type, overall, and per-frame
# accuracy counts.

import argparse
import json
import random
from typing import List

from utils import load_data, get_result_file, new_acc, compute_acc

parser = argparse.ArgumentParser()
parser.add_argument('--data_root', type=str, default='data/maplm_v0.1')
parser.add_argument('--output_dir', type=str, default='runs')
parser.add_argument('--test_split', type=str, default='test')
parser.add_argument('--test_number', type=int, default=-1, help='Number of test frames to run (default: -1, all)')
parser.add_argument('--exp_label', type=str, default='exp_random', help='Experiment label')
parser.add_argument('--random_seed', type=int, default=1, help='Random seed')
parser.add_argument('--debug', action='store_true', help='Debug mode')

args = parser.parse_args()

# Accuracy accumulators; per-question-text buckets are added lazily below.
results = dict(
    question_overall=new_acc(),
    frame_overall=new_acc(),
)

if __name__ == "__main__":
    print('===== Input Arguments =====')
    print(json.dumps(vars(args), indent=4, sort_keys=True))

    random.seed(args.random_seed)

    # BUG FIX: load_data() was previously called twice, the first result
    # thrown away (duplicate disk reads of problems.json / pid_splits.json).
    frames, frame_ids = load_data(args)
    result_file_name = get_result_file(args)

    for frame_id in frame_ids:
        frame = frames[frame_id]
        qas = frame['qa']

        corrects = []
        for qa in qas:
            # Only closed-choice questions can be scored by a random guess.
            if qa['task'] != 'closed choice':
                continue
            question = qa['question']
            choices: List[str] = qa['choices']
            true_answer: int = qa['answer']
            random_guess: int = random.randint(0, len(choices) - 1)

            if question not in results:
                results[question] = new_acc()

            correct = bool(random_guess == true_answer)
            corrects.append(correct)

            results[question]['total'] += 1
            results[question]['correct'] += int(correct)
            results['question_overall']['total'] += 1
            results['question_overall']['correct'] += int(correct)

        # BUG FIX: all([]) is True, so frames with no closed-choice question
        # were previously counted as fully correct. Score a frame only when
        # it contributed at least one scored question.
        if corrects:
            results['frame_overall']['total'] += 1
            results['frame_overall']['correct'] += int(all(corrects))

    print('===== Results =====')
    acc_dict = compute_acc(results)
    print(json.dumps(acc_dict, indent=4, sort_keys=True))
    print(json.dumps(results, indent=4, sort_keys=True))

    # BUG FIX: result_file_name was computed but never used — persist the
    # accuracies and raw counts so the run can be inspected later.
    with open(result_file_name, 'w') as f:
        json.dump({'acc': acc_dict, 'counts': results}, f, indent=4, sort_keys=True)
import json
import os
from typing import Dict, List, Tuple


def load_data(args) -> Tuple[dict, List[str]]:
    """Load the problems dict and the frame-id list for the requested split.

    Reads ``problems.json`` and ``pid_splits.json`` from ``args.data_root``,
    selects the ``args.test_split`` split, and truncates it to
    ``args.test_number`` frames when that value is positive.
    """
    # BUG FIX: use context managers — the original open() handles were
    # never closed.
    with open(os.path.join(args.data_root, 'problems.json'), 'r') as f:
        problems = json.load(f)
    with open(os.path.join(args.data_root, 'pid_splits.json'), 'r') as f:
        pid_splits = json.load(f)

    frame_ids = pid_splits[args.test_split]
    if args.test_number > 0:
        frame_ids = frame_ids[:args.test_number]
    print(f'Number of test frames: {len(frame_ids)}\n')

    return problems, frame_ids


def get_result_file(args) -> str:
    """Return the path of the JSON file that results should be written to.

    BUG FIX: the path was previously built under ``args.data_root`` (the
    dataset directory) even though an ``--output_dir`` argument exists and
    was otherwise unused. Results now go to ``args.output_dir``, which is
    created on demand so the subsequent write cannot fail on a missing dir.
    """
    os.makedirs(args.output_dir, exist_ok=True)
    return os.path.join(args.output_dir, f"{args.exp_label}_seed_{args.random_seed}.json")


def new_acc() -> Dict[str, int]:
    """Return a fresh accuracy accumulator with zeroed counts."""
    return {
        'total': 0,
        'correct': 0
    }


def compute_acc(results: Dict[str, Dict[str, int]]) -> Dict[str, float]:
    """Convert raw correct/total counts into percentage accuracies.

    BUG FIX: a bucket with ``total == 0`` previously raised
    ``ZeroDivisionError``; such metrics now report 0.0.
    """
    acc_dict = {}
    for metric, counts in results.items():
        total = counts['total']
        acc_dict[metric] = counts['correct'] / total * 100 if total else 0.0
    return acc_dict