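"""Preprocess source/target keyphrase files: build a vocabulary from the
tokenized training pairs, then serialize the vocab and the train/valid
splits to disk in either one2one or one2many format."""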
import argparse
import logging
import os
from collections import Counter

import torch

import config
import pykp.utils.io as io
from utils.functions import read_src_and_trg_files

def build_vocab(tokenized_src_trg_pairs):
    """Count token frequencies over all source and target word lists, then
    build word2idx/idx2word maps with the special tokens at the lowest
    indices and the remaining words ordered by descending frequency."""
    token_freq_counter = Counter()
    for src_word_list, trg_word_lists in tokenized_src_trg_pairs:
        token_freq_counter.update(src_word_list)
        for word_list in trg_word_lists:
            token_freq_counter.update(word_list)

    # Discard special tokens if already present in the corpus counts
    special_tokens = [io.PAD_WORD, io.UNK_WORD, io.BOS_WORD, io.EOS_WORD,
                      io.SEP_WORD, io.TITLE_ABS_SEP, io.PEOS_WORD]
    num_special_tokens = len(special_tokens)
    for s_t in special_tokens:
        if s_t in token_freq_counter:
            del token_freq_counter[s_t]

    # Special tokens occupy indices 0 .. num_special_tokens - 1
    word2idx = dict()
    idx2word = dict()
    for idx, word in enumerate(special_tokens):
        word2idx[word] = idx
        idx2word[idx] = word

    # Remaining words are indexed by descending frequency
    sorted_words = [word for word, _ in
                    sorted(token_freq_counter.items(), key=lambda x: x[1], reverse=True)]
    for idx, word in enumerate(sorted_words):
        word2idx[word] = idx + num_special_tokens
        idx2word[idx + num_special_tokens] = word

    vocab = {"word2idx": word2idx, "idx2word": idx2word, "counter": token_freq_counter}
    return vocab
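

# A toy illustration of the returned structure: the seven special tokens take
# indices 0-6, so the most frequent corpus word lands at index 7.
#   vocab = build_vocab([(["deep", "learning", "deep"], [["deep"]])])
#   vocab["word2idx"]["deep"]     == 7   # frequency 3 (twice in src, once in trg)
#   vocab["word2idx"]["learning"] == 8   # frequency 1
#   vocab["idx2word"][7]          == "deep"
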
def main(opt):
    # Tokenize train_src and train_trg, returning a list of tuples:
    # (src_word_list, [trg_1_word_list, trg_2_word_list, ...])
    tokenized_train_pairs = read_src_and_trg_files(opt.train_src, opt.train_trg, is_train=True)
    tokenized_valid_pairs = read_src_and_trg_files(opt.valid_src, opt.valid_trg, is_train=False)

    vocab = build_vocab(tokenized_train_pairs)
    opt.vocab = vocab

    retriever = None
    if opt.use_multidoc_graph:
        from retrievers.retriever import Retriever
        logging.info("Initializing retriever and loading reference documents.")
        retriever = Retriever(opt)
    opt.retriever = retriever

    logging.info("Dumping dict to disk: %s" % (opt.save_data_dir + '/vocab.pt'))
    torch.save(vocab, opt.save_data_dir + '/vocab.pt')

    if not opt.one2many:
        # Save one2one datasets: one (src, trg) pair per example
        train_one2one = io.build_dataset(tokenized_train_pairs, opt, mode='one2one')
        logging.info("Dumping train one2one to disk: %s" % (opt.save_data_dir + '/train.one2one.pt'))
        torch.save(train_one2one, opt.save_data_dir + '/train.one2one.pt')
        len_train_one2one = len(train_one2one)
        del train_one2one  # free memory before building the valid split

        valid_one2one = io.build_dataset(tokenized_valid_pairs, opt, mode='one2one')
        logging.info("Dumping valid one2one to disk: %s" % (opt.save_data_dir + '/valid.one2one.pt'))
        torch.save(valid_one2one, opt.save_data_dir + '/valid.one2one.pt')

        logging.info('#pairs of train_one2one = %d' % len_train_one2one)
        logging.info('#pairs of valid_one2one = %d' % len(valid_one2one))
    else:
        # Save one2many datasets: one src paired with all of its targets
        train_one2many = io.build_dataset(tokenized_train_pairs, opt, mode='one2many')
        logging.info("Dumping train one2many to disk: %s" % (opt.save_data_dir + '/train.one2many.pt'))
        torch.save(train_one2many, opt.save_data_dir + '/train.one2many.pt')
        len_train_one2many = len(train_one2many)
        del train_one2many  # free memory before building the valid split

        valid_one2many = io.build_dataset(tokenized_valid_pairs, opt, mode='one2many')
        logging.info("Dumping valid one2many to disk: %s" % (opt.save_data_dir + '/valid.one2many.pt'))
        torch.save(valid_one2many, opt.save_data_dir + '/valid.one2many.pt')

        logging.info('#pairs of train_one2many = %d' % len_train_one2many)
        logging.info('#pairs of valid_one2many = %d' % len(valid_one2many))

    logging.info('Done!')

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='preprocess.py',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    config.vocab_opts(parser)
    config.preprocess_opts(parser)
    config.retriever_opts(parser)
    opt = parser.parse_args()

    # Rebind the module-level name so main() logs through the configured logger
    logging = config.init_logging(log_file=opt.log_path + "/output.log", stdout=True)

    # Skip preprocessing entirely if the train dataset has already been built
    if not opt.one2many:
        test_exists = os.path.join(opt.save_data_dir, "train.one2one.pt")
    else:
        test_exists = os.path.join(opt.save_data_dir, "train.one2many.pt")
    if os.path.exists(test_exists):
        logging.info("file exists %s, exit!" % test_exists)
        exit()

    opt.train_src = opt.data_dir + '/train_src.txt'
    opt.train_trg = opt.data_dir + '/train_trg.txt'
    opt.valid_src = opt.data_dir + '/valid_src.txt'
    opt.valid_trg = opt.data_dir + '/valid_trg.txt'

    main(opt)
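
# Example invocation (a sketch: the exact flag spellings come from
# config.vocab_opts/preprocess_opts/retriever_opts, which are not shown here,
# but the destinations data_dir, save_data_dir, log_path, one2many, and
# use_multidoc_graph are all used above):
#   python preprocess.py -data_dir data/kp20k -save_data_dir processed/kp20k \
#       -log_path logs
# Expects data_dir to contain train_src.txt, train_trg.txt, valid_src.txt,
# and valid_trg.txt; writes vocab.pt plus the train/valid .pt datasets into
# save_data_dir.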