train_twt_bert2bert.py
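
"""Fine-tune a TWT BERT2BERT encoder-decoder model on cached ToTTo-style model inputs.

The script builds a BERT encoder and a BERT decoder (with cross-attention and, optionally,
row/column embeddings, a copy mechanism, and a cross-attention mask), wraps them in a
TWTEncoderDecoderModel, and trains with a Seq2SeqTrainer. Validation is scored with
SacreBLEU and ROUGE-2.
"""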
import os
import argparse
from typing import Optional
from dataclasses import dataclass, field

import torch
import datasets
import sacrebleu
from transformers import (
    TrainingArguments,
    BertTokenizerFast,
    BertTokenizer
)

from model import (
    TWTBertGenerationEncoder,
    TWTBertGenerationDecoder,
    TWTEncoderDecoderModel,
    Seq2SeqTrainer
)
from utils.data_utils import inc_load_cache
from language.totto.totto_to_twt_utils import ADDITIONAL_SPECIAL_TOKENS

ENCODE_MAX_LENGTH = 512
DECODE_MAX_LENGTH = 128
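
# The cached model inputs are pre-tokenized and (presumably) padded to fixed lengths
# (see ENCODE_MAX_LENGTH / DECODE_MAX_LENGTH above), so the default collate function
# can simply stack the per-example tensors into batches.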
def collate_fn(batch):
    return torch.utils.data.dataloader.default_collate(batch)


class CachedDataset(torch.utils.data.Dataset):
    def __init__(self, model_inputs, args):
        self.model_inputs = model_inputs
        self.args = args

    def __getitem__(self, index):
        model_inputs = self.model_inputs[index]
        input_item = {}
        input_item['input_ids'] = torch.tensor(model_inputs['input_input_ids'])
        input_item['type_ids'] = torch.tensor(model_inputs['input_type_ids'])
        if not self.args.no_row_col_embeddings:
            input_item['row_ids'] = torch.tensor(model_inputs['input_row_ids'])
            input_item['col_ids'] = torch.tensor(model_inputs['input_col_ids'])
        else:
            input_item['row_ids'] = torch.zeros(len(model_inputs['input_row_ids']), dtype=torch.int64)
            input_item['col_ids'] = torch.zeros(len(model_inputs['input_col_ids']), dtype=torch.int64)
        input_item['attention_mask'] = torch.tensor(model_inputs['input_attention_mask'])
        input_item['decoder_input_ids'] = torch.tensor(model_inputs['output_input_ids'])
        input_item['decoder_attention_mask'] = torch.tensor(model_inputs['output_attention_mask'])
        input_item['decoder_copy_mask'] = torch.tensor(model_inputs['output_copy_mask'])
        input_item['cross_attention_mask'] = torch.tensor(model_inputs['cross_attention_mask'])
        input_item["labels"] = model_inputs['output_labels']
        # because BERT automatically shifts the labels, the labels correspond exactly to `decoder_input_ids`.
        # We have to make sure that the PAD token (id 0) is ignored by setting it to -100
        input_item["labels"] = [-100 if token == 0 else token for token in input_item["labels"]]
        input_item["labels"] = torch.tensor(input_item["labels"])
        return input_item

    def __len__(self):
        return len(self.model_inputs)
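
# Seq2seq-specific knobs layered on top of the standard HuggingFace TrainingArguments;
# only predict_with_generate is actually set in train() below, the rest keep their defaults.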
@dataclass
class Seq2SeqTrainingArguments(TrainingArguments):
    label_smoothing: Optional[float] = field(
        default=0.0, metadata={"help": "The label smoothing epsilon to apply (if not zero)."}
    )
    sortish_sampler: bool = field(default=False, metadata={"help": "Whether to use SortishSampler or not."})
    predict_with_generate: bool = field(
        default=False, metadata={"help": "Whether to use generate to calculate generative metrics (ROUGE, BLEU)."}
    )
    adafactor: bool = field(default=False, metadata={"help": "Whether to use the Adafactor optimizer."})
    encoder_layerdrop: Optional[float] = field(
        default=None, metadata={"help": "Encoder layer dropout probability. Goes into model.config."}
    )
    decoder_layerdrop: Optional[float] = field(
        default=None, metadata={"help": "Decoder layer dropout probability. Goes into model.config."}
    )
    dropout: Optional[float] = field(default=None, metadata={"help": "Dropout probability. Goes into model.config."})
    attention_dropout: Optional[float] = field(
        default=None, metadata={"help": "Attention dropout probability. Goes into model.config."}
    )
    lr_scheduler: Optional[str] = field(
        default="linear", metadata={"help": "Which lr scheduler to use."}
    )


class BertSeq2Seq(object):
    def __init__(self, args):
        self.args = args
        self.tokenizer = self._init_tokenizer()
        self.model = self._init_model()
        # Create output dir if it doesn't exist
        if not os.path.exists(self.args.output_dir):
            os.makedirs(self.args.output_dir, exist_ok=True)

    def _init_tokenizer(self):
        tokenizer = BertTokenizerFast.from_pretrained(f"bert-{self.args.model_size}-uncased")
        tokenizer.add_special_tokens({'additional_special_tokens': ADDITIONAL_SPECIAL_TOKENS})
        tokenizer.bos_token = tokenizer.cls_token
        tokenizer.eos_token = tokenizer.sep_token
        return tokenizer

    def _init_model(self):
        if not self.args.load_checkpoint:
            encoder_config = TWTBertGenerationEncoder.config_class.from_pretrained(
                f"bert-{self.args.model_size}-uncased",
                return_unused_kwargs=False,
                force_download=False,
                resume_download=False,
                local_files_only=False,
                bos_token_id=101,
                eos_token_id=102
            )
            encoder_config.type_vocab_size = 4
            encoder_config.max_row_embeddings = encoder_config.max_position_embeddings
            encoder_config.max_col_embeddings = encoder_config.max_position_embeddings
            encoder_config.no_row_col_embeddings = self.args.no_row_col_embeddings
            # use BERT's cls token as BOS token and sep token as EOS token
            encoder = TWTBertGenerationEncoder.from_pretrained(f"bert-{self.args.model_size}-uncased", config=encoder_config)
            # resize token embeddings to account for the added special tokens
            encoder.resize_token_embeddings(len(self.tokenizer))
            decoder_config = TWTBertGenerationDecoder.config_class.from_pretrained(
                f"bert-{self.args.model_size}-uncased",
                return_unused_kwargs=False,
                force_download=False,
                resume_download=False,
                local_files_only=False,
                add_cross_attention=True,
                is_decoder=True,
                output_attentions=True,
                output_hidden_states=True,
                bos_token_id=101,
                eos_token_id=102
            )
            decoder_config.no_copy = self.args.no_copy
            decoder_config.no_cross_attention_mask = self.args.no_cross_attention_mask
            decoder_config.gen_loss_rate = self.args.gen_loss_rate
            # add cross attention layers and use BERT's cls token as BOS token and sep token as EOS token
            decoder = TWTBertGenerationDecoder.from_pretrained(f"bert-{self.args.model_size}-uncased", config=decoder_config)
            # resize token embeddings to account for the added special tokens
            decoder.resize_token_embeddings(len(self.tokenizer))
            # tie encoder and decoder together into an encoder-decoder model
            bert2bert = TWTEncoderDecoderModel(encoder=encoder, decoder=decoder)
        else:
            # load model from checkpoint path
            bert2bert = TWTEncoderDecoderModel.from_pretrained(self.args.checkpoint_path)
        # set special tokens
        bert2bert.config.decoder_start_token_id = self.tokenizer.bos_token_id
        bert2bert.config.eos_token_id = self.tokenizer.eos_token_id
        bert2bert.config.pad_token_id = self.tokenizer.pad_token_id
        # sensible parameters for beam search
        bert2bert.config.vocab_size = bert2bert.config.decoder.vocab_size
        bert2bert.config.max_length = 50
        # bert2bert.config.min_length = 10
        # bert2bert.config.no_repeat_ngram_size = 3
        bert2bert.config.early_stopping = True
        # bert2bert.config.length_penalty = 2.0
        bert2bert.config.num_beams = 4
        return bert2bert

    def compute_metrics(self, pred):
        # load ROUGE for validation
        rouge = datasets.load_metric("rouge")
        labels_ids = pred.label_ids
        pred_ids = pred.predictions
        # decode predictions and labels; all special tokens are removed
        pred_str = self.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
        labels_ids[labels_ids == -100] = self.tokenizer.pad_token_id
        label_str = self.tokenizer.batch_decode(labels_ids, skip_special_tokens=True)
        sacrebleu_output = sacrebleu.corpus_bleu(pred_str, [label_str])
        rouge_output = rouge.compute(predictions=pred_str, references=label_str, rouge_types=["rouge2"])["rouge2"].mid
        # dump references and predictions for inspection, keyed by the BLEU score
        with open(self.args.output_dir + "/eval_label_" + str(round(sacrebleu_output.score, 4)) + ".txt", 'w') as f:
            f.write(os.linesep.join(label_str))
        with open(self.args.output_dir + "/eval_predict_" + str(round(sacrebleu_output.score, 4)) + ".txt", 'w') as f:
            f.write(os.linesep.join(pred_str))
        return {
            "sacrebleu_score": round(sacrebleu_output.score, 4),
            "rouge2_precision": round(rouge_output.precision, 4),
            "rouge2_recall": round(rouge_output.recall, 4),
            "rouge2_fmeasure": round(rouge_output.fmeasure, 4),
        }

    def train(self):
        checkpoints_dir = self.args.output_dir + "/checkpoints/"
        # Create checkpoint dir if it doesn't exist
        if not os.path.exists(checkpoints_dir):
            os.makedirs(checkpoints_dir, exist_ok=True)
        # set training arguments - these params are not really tuned, feel free to change them
        training_args = Seq2SeqTrainingArguments(
            output_dir=checkpoints_dir,
            learning_rate=self.args.learning_rate,
            per_device_train_batch_size=self.args.train_batch_size,
            gradient_accumulation_steps=self.args.gradient_accumulation_steps,
            per_device_eval_batch_size=self.args.eval_batch_size,
            # eval_accumulation_steps=4,
            predict_with_generate=True,
            evaluation_strategy=self.args.evaluation_strategy,
            do_train=self.args.do_train,
            do_eval=self.args.do_eval,
            logging_steps=self.args.logging_steps,  # set to 1000 for full training
            save_steps=self.args.save_steps,  # set to 500 for full training
            eval_steps=self.args.eval_steps,  # set to 8000 for full training
            warmup_steps=self.args.warmup_steps,  # set to 2000 for full training
            max_steps=self.args.max_steps,  # delete for full training
            num_train_epochs=self.args.num_train_epochs,
            overwrite_output_dir=True,
            # save_total_limit=3,
            fp16=False,
            # load_best_model_at_end=True,
            # metric_for_best_model="sacrebleu_score",
            # greater_is_better=True
        )
        # load cached, pre-tokenized data
        train_model_inputs = inc_load_cache(self.args.train_model_inputs_file)
        print(f"Train data size: {len(train_model_inputs)}")
        train_dataset = CachedDataset(train_model_inputs, self.args)
        val_model_inputs = inc_load_cache(self.args.val_model_inputs_file)
        print(f"Dev data size: {len(val_model_inputs)}")
        val_dataset = CachedDataset(val_model_inputs, self.args)
        # instantiate trainer
        trainer = Seq2SeqTrainer(
            model=self.model,
            args=training_args,
            data_collator=collate_fn,
            compute_metrics=self.compute_metrics,
            train_dataset=train_dataset,
            eval_dataset=val_dataset,
        )
        trainer.train()


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_size",
                        type=str,
                        default="base",
                        help="Model size: base or large")
    parser.add_argument("--train_model_inputs_file",
                        type=str,
                        default="./data/cache/twt/totto_random_prefix_bert2bert_base_train_model_inputs.pkl.gz",
                        help="Path of the train file")
    parser.add_argument("--val_model_inputs_file",
                        type=str,
                        default="./data/cache/twt/totto_random_prefix_bert2bert_base_dev_model_inputs.pkl.gz",
                        help="Path of the validation file")
    parser.add_argument("--output_dir",
                        type=str,
                        default="./output",
                        help="Output directory")
    parser.add_argument("--no_row_col_embeddings",
                        action="store_true",
                        default=False,
                        help="Disable row/column embeddings.")
    parser.add_argument("--no_copy",
                        action="store_true",
                        default=False,
                        help="Disable the copy mechanism.")
    parser.add_argument("--no_cross_attention_mask",
                        action="store_true",
                        default=False,
                        help="Disable the cross-attention mask.")
    parser.add_argument("--gen_loss_rate",
                        type=float,
                        default=0.4,
                        help="Rate to multiply the generation loss by.")
    parser.add_argument("--do_train",
                        action="store_true",
                        default=True,
                        help="Whether to run training")
    parser.add_argument("--learning_rate",
                        type=float,
                        default=5e-5,
                        help="Learning rate")
    parser.add_argument("--train_batch_size",
                        type=int,
                        default=8,
                        help="Train batch size")
    parser.add_argument("--gradient_accumulation_steps",
                        type=int,
                        default=1,
                        help="Gradient accumulation steps")
    parser.add_argument("--num_train_epochs",
                        type=int,
                        default=40,
                        help="Training epochs")
    parser.add_argument("--warmup_steps",
                        type=int,
                        default=2000,
                        help="Warmup steps")
    parser.add_argument("--logging_steps",
                        type=int,
                        default=1000,
                        help="Logging steps")
    parser.add_argument("--save_steps",
                        type=int,
                        default=2000,
                        help="Model save steps")
    parser.add_argument("--max_steps",
                        type=int,
                        default=100000,
                        help="Max steps to train")
    parser.add_argument("--do_eval",
                        action="store_true",
                        default=False,
                        help="Whether to run evaluation")
    parser.add_argument("--evaluation_strategy",
                        type=str,
                        default="no",
                        help="Evaluation strategy: steps or no")
    parser.add_argument("--eval_batch_size",
                        type=int,
                        default=2,
                        help="Evaluation batch size")
    parser.add_argument("--eval_steps",
                        type=int,
                        default=2000,
                        help="Evaluation steps")
    parser.add_argument("--load_checkpoint",
                        action="store_true",
                        default=False,
                        help="Whether to load an existing checkpoint.")
    parser.add_argument("--checkpoint_path",
                        type=str,
                        default="",
                        help="Path of the existing checkpoint")
    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    print(args)
    bert_seq2seq = BertSeq2Seq(args)
    bert_seq2seq.train()


if __name__ == '__main__':
    main()
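

# Example invocation (sketch; the cache paths are the argparse defaults above and may
# need to be adjusted to your setup):
#
#   python train_twt_bert2bert.py \
#       --model_size base \
#       --output_dir ./output \
#       --train_model_inputs_file ./data/cache/twt/totto_random_prefix_bert2bert_base_train_model_inputs.pkl.gz \
#       --val_model_inputs_file ./data/cache/twt/totto_random_prefix_bert2bert_base_dev_model_inputs.pkl.gz \
#       --do_eval --evaluation_strategy steps --eval_steps 2000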