# data.py: Dataset and vocabulary utilities.
import os
import random

import numpy
import torch
from torch.utils.data import Dataset


# From https://pytorch.org/docs/stable/notes/randomness.html
def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    numpy.random.seed(worker_seed)
    random.seed(worker_seed)
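
# Usage sketch, following the randomness notes linked above: seed_worker is
# meant to be passed to a DataLoader together with a seeded generator, so that
# each worker process gets a reproducible RNG state, e.g.
#   g = torch.Generator()
#   g.manual_seed(0)
#   loader = DataLoader(dataset, num_workers=2,
#                       worker_init_fn=seed_worker, generator=g)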

class Vocabulary(object):
    def __init__(self, vocab_dict=None, vocab_file=None,
                 include_unk=False, unk_str='<unk>',
                 include_eos=False, eos_str='<eos>',
                 no_out_str='_',
                 pad_id=None, pad_str=None):
        # If both are provided, construction from the dict takes priority.
        self.str2idx = {}
        self.idx2str = []
        self.no_out_str = no_out_str
        if include_eos:
            self.add_str(eos_str)
            self.eos_str = eos_str
        if include_unk:
            self.add_str(unk_str)
            self.unk_str = unk_str
        if vocab_dict is not None:
            self.construct_from_dict(vocab_dict)
        elif vocab_file is not None:
            self.construct_from_file(vocab_file)
    def construct_from_file(self, vocab_file):
        # Expect each line to contain "token_str idx", space separated.
        print(f"Creating vocab from: {vocab_file}")
        tmp_idx2str_dict = {}
        with open(vocab_file, 'r') as text:
            for line in text:
                vocab_pair = line.split()
                assert len(vocab_pair) == 2, "Unexpected vocab format."
                token_str, token_idx = vocab_pair
                tmp_idx2str_dict[int(token_idx)] = token_str
        raise NotImplementedError("Not implemented yet.")

    def construct_from_dict(self, vocab_dict):
        self.str2idx = vocab_dict
        vocab_size = len(vocab_dict)
        raise NotImplementedError("Not implemented yet.")
    def get_idx(self, stg):
        return self.str2idx[stg]

    def get_str(self, idx):
        return self.idx2str[idx]

    # Increment the vocab size, giving the next index to the new token.
    def add_str(self, stg):
        if stg not in self.str2idx:
            self.idx2str.append(stg)
            self.str2idx[stg] = len(self.idx2str) - 1

    # Return the vocab size.
    def size(self):
        return len(self.idx2str)

    def get_no_op_id(self):
        return self.str2idx[self.no_out_str]

    def get_unk_str(self):
        return self.unk_str
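
# A short usage sketch of the vocabulary round trip (the token strings here
# are made up for illustration):
#   vocab = Vocabulary(include_unk=True)
#   vocab.add_str('a')
#   vocab.get_idx('a')  # -> 1 ('<unk>' occupies index 0)
#   vocab.get_str(1)    # -> 'a'
#   vocab.size()        # -> 2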

class LTEDataset(Dataset):
    def __init__(self, src_file, tgt_file, src_pad_idx, tgt_pad_idx,
                 src_vocab=None, tgt_vocab=None, device='cuda'):
        self.src_max_seq_length = None  # set by text_to_data
        self.tgt_max_seq_length = None
        build_src_vocab = False
        if src_vocab is None:
            build_src_vocab = True
            self.src_vocab = Vocabulary()
        else:
            self.src_vocab = src_vocab
        build_tgt_vocab = False
        if tgt_vocab is None:
            build_tgt_vocab = True
            self.tgt_vocab = Vocabulary()
        else:
            self.tgt_vocab = tgt_vocab
        self.data = self.text_to_data(
            src_file, tgt_file, src_pad_idx, tgt_pad_idx,
            build_src_vocab, build_tgt_vocab, device)
        self.data_size = len(self.data)

    def __len__(self):  # Used by the PyTorch DataLoader.
        return self.data_size

    def __getitem__(self, index):  # Used by the PyTorch DataLoader.
        return self.data[index]
    def text_to_data(self, src_file, tgt_file, src_pad_idx, tgt_pad_idx,
                     build_src_vocab=False, build_tgt_vocab=False,
                     device='cuda'):
        # Convert paired src/tgt text files into torch.tensor data.
        # All sequences are padded to the length of the longest sequence
        # in the respective file.
        assert os.path.exists(src_file)
        assert os.path.exists(tgt_file)
        data_list = []
        # Find the max sequence length and, if requested, build the vocab.
        src_max = 0
        with open(src_file, 'r') as text:
            for line in text:
                tokens = line.split()
                length = len(tokens)
                if src_max < length:
                    src_max = length
                if build_src_vocab:
                    for token in tokens:
                        self.src_vocab.add_str(token)
        self.src_max_seq_length = src_max
        tgt_max = 0
        with open(tgt_file, 'r') as text:
            for line in text:
                tokens = line.split()
                length = len(tokens)
                if tgt_max < length:
                    tgt_max = length
                if build_tgt_vocab:
                    for token in tokens:
                        self.tgt_vocab.add_str(token)
        self.tgt_max_seq_length = tgt_max
        # Construct the data tensors.
        src_list = []
        print(f"Loading source file from: {src_file}")
        with open(src_file, 'r') as text:
            for line in text:
                seq = []
                tokens = line.split()
                for token in tokens:
                    seq.append(self.src_vocab.get_idx(token))
                var_len = len(seq)
                var_seq = torch.tensor(seq, device=device, dtype=torch.int64)
                # Right-pad to the file-wide max length.
                new_seq = var_seq.new_full((src_max,), src_pad_idx)
                new_seq[:var_len] = var_seq
                src_list.append(new_seq)
        tgt_list = []
        print(f"Loading target file from: {tgt_file}")
        with open(tgt_file, 'r') as text:
            for line in text:
                seq = []
                tokens = line.split()
                for token in tokens:
                    seq.append(self.tgt_vocab.get_idx(token))
                var_len = len(seq)
                var_seq = torch.tensor(seq, device=device, dtype=torch.int64)
                # Right-pad to the file-wide max length.
                new_seq = var_seq.new_full((tgt_max,), tgt_pad_idx)
                new_seq[:var_len] = var_seq
                tgt_list.append(new_seq)
        # src_file and tgt_file are assumed to be line-aligned.
        assert len(src_list) == len(tgt_list)
        for i in range(len(src_list)):
            data_list.append((src_list[i], tgt_list[i]))
        return data_list
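
# Padding example (a sketch with made-up token ids): with src_max = 5 and
# src_pad_idx = 0, a source line whose token ids are [4, 7, 2] is stored as
# tensor([4, 7, 2, 0, 0]); __getitem__ then returns one such (src, tgt) pair.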

if __name__ == '__main__':
    import argparse
    from torch.utils.data import DataLoader

    torch.manual_seed(123)
    random.seed(123)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(123)

    parser = argparse.ArgumentParser(description='Learning to execute')
    parser.add_argument(
        '--data_dir', type=str,
        default='./data/',
        help='location of the data corpus')
    args = parser.parse_args()

    data_path = args.data_dir
    file_src = f"{data_path}/valid_3.src"
    file_tgt = f"{data_path}/valid_3.tgt"
    bsz = 3
    # Fall back to CPU so the demo also runs on machines without CUDA.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    dummy_data = LTEDataset(src_file=file_src, tgt_file=file_tgt,
                            src_pad_idx=0, tgt_pad_idx=-1,
                            src_vocab=None, tgt_vocab=None, device=device)
    data_loader = DataLoader(dataset=dummy_data, batch_size=bsz, shuffle=True)
    stop_ = 2  # Only inspect the first couple of batches.
    for idx, batch in enumerate(data_loader):
        if idx >= stop_:
            break
        src, tgt = batch
        print(src[:, 0:20])
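        # Each batch is a (src, tgt) pair of shape (batch_size, max_seq_length);
        # the slice above previews the first 20 (possibly padded) source columns.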