test_stretch.py
# -*- coding: utf-8 -*-
import argparse
from Configs.ConfigHandler import ConfigHandler
from DataHandler.DataHandler_stretch import DataHandler
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch as th
from NucleusSampling.NucleusSampling import top_k_top_p_filtering
import torch.nn.functional as F
import nltk
import json

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config_file", type=str, default='config_stretch.ini',
                        help="the .ini file containing all the model and program settings")
    parser.add_argument("--section", type=str, default='DEFAULT',
                        help="the section of the config file to use")
    args = parser.parse_args()
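    # assumed invocation (based on the argparse defaults above; adjust paths to your setup):
    #   python test_stretch.py --config_file config_stretch.ini --section DEFAULT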
    # load all the model configuration settings from the config file
    config = ConfigHandler.get_configs(filename=args.config_file, section=args.section)
    print(config)
    dataHandler = DataHandler()
    # load the GPT-2 tokenizer for the pretrained model named in the config
    tokenizer = GPT2Tokenizer.from_pretrained(config['model'])
    # add these tokens to the vocabulary, otherwise the tokenizer splits [ENT] into
    # three separate tokens ('[', 'ENT', ']')
    tokenizer.add_tokens(['[ENT]', '[SEP]'])
    # load the GPT-2 model from the transformers library
    model = GPT2LMHeadModel.from_pretrained(config['model'])
    # resize the token embeddings since two extra tokens were added to the tokenizer
    model.resize_token_embeddings(len(tokenizer))
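    # note: the embedding matrix must be resized before loading the checkpoint so that
    # the parameter shapes match the fine-tuned weights, which were presumably saved
    # from a model that already included the two extra tokens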
    device = th.device(config['device'])
    # load the fine-tuned checkpoint; map_location lets the weights be restored even if
    # they were saved on a different device than the one used for testing
    model.load_state_dict(th.load(config['checkpoint_dir'] + config['model_test_file'],
                                  map_location=device))
    # move the model to the gpu/cpu device specified in the config
    model.to(device)
    # set the model to evaluation mode
    model.eval()
    # load the gold references file of the model for testing
    gold_test = dataHandler.get_gold_test(config)
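    # gold_test presumably maps each table id to its gold entry (input table plus
    # reference sentences), as suggested by how it is indexed in the loop below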
    # sentence-level BLEU-1/2/3 scores, accumulated over all tables
    sent_bleus_1 = []
    sent_bleus_2 = []
    sent_bleus_3 = []
    # generated sentences, keyed by table id
    seq_list = {}
    with th.no_grad():
        for table_id in gold_test:
            input_tensor, output_tensor = dataHandler.get_test_embedding(gold_test, table_id, config, tokenizer)
            input_tensor = input_tensor.to(device)
            output_tensor = output_tensor.to(device)
            input_dim = input_tensor.shape[1]
            seq_list[table_id] = []
            # These two lists track when the required sequences have been fully decoded:
            # finished_template marks that the template portion of the output is done,
            # signalled by the model predicting the [SEP] token, and finished_sentence
            # marks that the full sequence is done, signalled by the EOS token.
            finished_template = [False for _ in range(len(input_tensor))]
            finished_sentence = [False for _ in range(len(input_tensor))]
            # if no EOS token is predicted, keep decoding for max_decoding_length iterations
            for tok in range(int(config['max_decoding_length'])):
                model_output = model(input_tensor)[0]
                # nucleus sampling operates on the logits at the last position (tail) of each sequence
                modeloutput_tail = model_output[:, -1, :]
                # apply nucleus filtering; this sets most components of each logit vector to -inf
                filtered_tail = top_k_top_p_filtering(modeloutput_tail,
                                                      top_k=int(config['top_k']),
                                                      top_p=float(config['top_p']))
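                # (standard top-k/top-p behaviour, assumed to match the local implementation:
                # top_k keeps only the k highest-scoring logits, top_p keeps the smallest set
                # of tokens whose cumulative probability exceeds p; everything else becomes -inf)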
                # compute softmax over the filtered tails; the -inf components get probability 0,
                # so a single token from the remaining vocabulary is sampled for each sequence
                predicted_tokens = th.multinomial(F.softmax(filtered_tail, dim=-1), num_samples=1)
                for token in range(len(predicted_tokens)):
                    if predicted_tokens[token].item() == tokenizer.convert_tokens_to_ids('[SEP]'):
                        finished_template[token] = True
                    if predicted_tokens[token].item() == tokenizer.eos_token_id:
                        finished_sentence[token] = True
                input_tensor = th.cat((input_tensor, predicted_tokens), dim=1)
                if all(finished_sentence):
                    break
            # extract the predicted portion of the input tensor
            predicted_tensor = input_tensor[:, input_dim:]
            for seq in predicted_tensor:
                decoded_seq = tokenizer.decode(seq, clean_up_tokenization_spaces=True)
                # keep only the text after the '[SEP]' marker and before the EOS token
                decoded_seq = decoded_seq[decoded_seq.find('[SEP]') + 6: decoded_seq.find(tokenizer.eos_token)].strip()
                seq_list[table_id].append(decoded_seq)
            # get the references from the table entry, as a list of reference token lists
            # (the format expected by nltk's sentence_bleu)
            references = dataHandler.get_references(gold_test[table_id])
            for seq in seq_list[table_id]:
                seq = seq.lower().split()
                sent_bleus_1.append(nltk.translate.bleu_score.sentence_bleu(references, seq, weights=(1, 0, 0)))
                sent_bleus_2.append(nltk.translate.bleu_score.sentence_bleu(references, seq, weights=(0.5, 0.5, 0)))
                sent_bleus_3.append(nltk.translate.bleu_score.sentence_bleu(references, seq, weights=(0.33, 0.33, 0.33)))
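            # note: the sent_bleus_* lists accumulate across tables, so each per-table print
            # below reports a running average over all sentences scored so far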
            bleu_1 = format((sum(sent_bleus_1) / len(sent_bleus_1) * 100), '.2f')
            bleu_2 = format((sum(sent_bleus_2) / len(sent_bleus_2) * 100), '.2f')
            bleu_3 = format((sum(sent_bleus_3) / len(sent_bleus_3) * 100), '.2f')
            print("table: {}, bleu-1: {}, bleu-2: {}, bleu-3: {}".format(table_id, bleu_1, bleu_2, bleu_3))
print("total corpus BLEU score = {}/{}/{}".format(bleu_1, bleu_2, bleu_3))
    with open(config['test_output_dir'] + 'GPT_C2F_output.json', 'w') as f:
        json.dump(seq_list, f)