# evaluate_single_ckpt.py
# (Removed: web-page scrape artifacts — GitHub UI text and a line-number gutter
#  that preceded the actual source and would break the Python file.)
# Evaluate single ckpt.
# We recommend you to run this script after "evaluate_text2sql_ckpts.py", to evaluate a split of interest.
import argparse
import os
import json
import torch
import shutil
import logging
from text2sql import _test_spider, _test_mschema2qa
logger = logging.getLogger(__name__)
def parse_option():
    """Parse command-line arguments for evaluating a single text2sql checkpoint.

    Returns:
        argparse.Namespace: the parsed options (batch size, device, paths,
        beam-search settings, dataset selection, etc.).
    """
    parser = argparse.ArgumentParser("command line arguments for selecting the best ckpt.")
    parser.add_argument('--batch_size', type=int, default=8,
                        help='input batch size. Note that this is an effective batch size')
    parser.add_argument('--device', type=str, default="2",
                        help='the id of used GPU device.')
    parser.add_argument('--seed', type=int, default=42,
                        help='random seed.')
    parser.add_argument('--save_path', type=str, default="./models/mt5-large-16-text2sql/best_model/",
                        help='save path of fine-tuned text2sql model.')
    parser.add_argument('--model_name_or_path', type=str, default="mt5",
                        help="Type of model used for evaluation")  # TODO: the name might be confusing; rename to model_type later.
    parser.add_argument('--eval_results_path', type=str, default="./eval_results/text2sql",
                        help='the evaluation results of fine-tuned text2sql models.')
    parser.add_argument('--eval_file_name', type=str, default="eval_res.txt")
    parser.add_argument('--mode', type=str, default="eval",
                        help='eval.')
    parser.add_argument('--dev_filepath', type=str, default="./data/pre-processing/resdsql_test.json",
                        help='file path of text2sql dev set.')
    parser.add_argument('--original_dev_filepath', type=str, default="./data/spider/dev.json",
                        help='file path of the original dev set (for registering evaluator).')
    parser.add_argument('--db_path', type=str, default="./data/spider/database",
                        help='file path of database.')
    parser.add_argument('--preprocessed_dataset_path', type=str, default="/home/deokhk/research/ZX-seq2seq/data/multispider_preprocessed_data/preprocessed_dev_de_normalized.json",
                        help="Multispider evaluation only.")
    parser.add_argument('--num_beams', type=int, default=8,
                        help='beam size in model.generate() function.')
    parser.add_argument('--num_return_sequences', type=int, default=8,
                        help='the number of returned sequences in model.generate() function (num_return_sequences <= num_beams).')
    parser.add_argument("--output", type=str, default="predicted_sql.txt")
    parser.add_argument("--dataset_type", type=str, choices=["spider", "mschema2qa"], default="spider")
    parser.add_argument("--dataset_lang", type=str, default="en")
    parser.add_argument("--save_predictions", action="store_true", default=False)
    parser.add_argument("--save_predictions_path", type=str, default="./predictions/text2sql")

    opt = parser.parse_args()
    return opt
if __name__ == "__main__":
    opt = parse_option()

    # Without a handler configured, the module-level logger.info calls below
    # would be silently dropped; this script is the entry point, so configure here.
    logging.basicConfig(level=logging.INFO)

    os.makedirs(opt.eval_results_path, exist_ok=True)
    eval_file_path = os.path.join(opt.eval_results_path, opt.eval_file_name)

    logger.info("Start evaluating ckpt at: %s", opt.save_path)

    # Write a placeholder first so an interrupted run leaves a visible marker.
    with open(eval_file_path, "w") as f:
        f.write("Evaluating...")

    eval_result = {"ckpt": opt.save_path}
    if opt.dataset_type == "spider":
        # Spider reports both exact-match (EM) and execution (EXEC) accuracy.
        # Named exec_acc to avoid shadowing the exec builtin.
        em, exec_acc = _test_spider(opt)
        eval_result["EM"] = em
        eval_result["EXEC"] = exec_acc
    elif opt.dataset_type == "mschema2qa":
        # mschema2qa only reports exact match; there is no execution metric.
        em = _test_mschema2qa(opt)
        eval_result["EM"] = em
    else:
        raise NotImplementedError("Dataset type not implemented.")

    # Overwrite the placeholder with the final JSON result.
    with open(eval_file_path, "w") as f:
        f.write(json.dumps(eval_result, indent=2, ensure_ascii=False))

    logger.info("Finish evaluating ckpt at: %s", opt.save_path)
    logger.info("EM: %s", em)
    if opt.dataset_type != "mschema2qa":
        logger.info("EXEC: %s", exec_acc)