forked from HongChow/im2latex-1
-
Notifications
You must be signed in to change notification settings - Fork 3
/
evaluate.py
76 lines (59 loc) · 2.63 KB
/
evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
# Load the best model checkpoint and evaluate it on a dataset split.
import argparse
import os
from functools import partial
from os.path import join

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from build_vocab import Vocab, load_vocab
from data import Im2LatexDataset
from model.decoding import LatexProducer
from model.model import Im2LatexModel
from model.score import score_files
from utils import collate_fn
def main():
parser = argparse.ArgumentParser(description="Im2Latex Evaluating Program")
parser.add_argument('--model_path', required=True, help='path of the evaluated model')
parser.add_argument("--data_path", type=str, default="./data/sample_data/", help="The dataset's dir")
parser.add_argument("--cuda", action='store_true', default=True, help="Use cuda or not")
parser.add_argument("--batch_size", type=int, default=32)
parser.add_argument("--beam_size", type=int, default=5)
parser.add_argument("--result_path", type=str, default="./results/result.txt", help="The file to store result")
parser.add_argument("--ref_path", type=str, default="./results/ref.txt", help="The file to store reference")
parser.add_argument("--max_len", type=int, default=64, help="Max step of decoding")
parser.add_argument("--split", type=str, default="validate", help="The data split to decode")
args = parser.parse_args()
checkpoint = torch.load(join(args.model_path))
model_args = checkpoint['args']
vocab = load_vocab(args.data_path)
use_cuda = True if args.cuda and torch.cuda.is_available() else False
data_loader = DataLoader(
Im2LatexDataset(args.data_path, args.split, args.max_len),
batch_size=args.batch_size,
collate_fn=partial(collate_fn, vocab.sign2id),
pin_memory=True if use_cuda else False,
num_workers=4
)
model = Im2LatexModel(
len(vocab), model_args.emb_dim, model_args.enc_rnn_h, model_args.dec_rnn_h
)
model.load_state_dict(checkpoint['model_state_dict'])
model.train(False)
result_file = open(args.result_path, 'w')
ref_file = open(args.ref_path, 'w')
latex_producer = LatexProducer(
model, vocab)
for imgs, tgt4training, tgt4cal_loss in tqdm(data_loader):
try:
reference = latex_producer._idx2formulas(tgt4cal_loss)
results = latex_producer._greedy_decoding(imgs)
except RuntimeError:
break
result_file.write('\n'.join(results))
ref_file.write('\n'.join(reference))
result_file.close()
ref_file.close()
score = score_files(args.result_path, args.ref_path)
print("Result:", score)
# Run the evaluation only when executed as a script, not on import.
if __name__ == "__main__":
    main()