prepare_vocab.py
"""
Prepare vocabulary and initial word vectors.
"""
import json
import pickle
import argparse
import numpy as np
from collections import Counter
from utils import vocab, constant, helper, jsonl


def parse_args():
    parser = argparse.ArgumentParser(description='Prepare vocab for embeddings.')
    parser.add_argument('data_dir', help='Data directory')
    parser.add_argument('vocab_dir', help='Output vocab directory.')
    parser.add_argument('--random', action='store_true', help='Randomly initialize vectors.')
    parser.add_argument('--glove_dir', default='dataset/glove', help='GloVe directory.')
    parser.add_argument('--wv_file', default='radglove.800M.100d.txt', help='GloVe vector file.')
    parser.add_argument('--wv_dim', type=int, default=100, help='GloVe vector dimension.')
    parser.add_argument('--min_freq', type=int, default=0, help='If > 0, use min_freq as the cutoff.')
    parser.add_argument('--lower', action='store_true', help='If specified, lowercase all words.')
    args = parser.parse_args()
    return args
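# Example invocation (paths here are illustrative and should be adapted to your data layout):
#   python prepare_vocab.py dataset/data dataset/vocab --glove_dir dataset/glove --lower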
def main():
    args = parse_args()

    # input files
    wv_file = args.glove_dir + '/' + args.wv_file
    wv_dim = args.wv_dim

    # output files
    helper.ensure_dir(args.vocab_dir)
    vocab_file = args.vocab_dir + '/vocab.pkl'
    emb_file = args.vocab_dir + '/embedding.npy'

    # load files
    print("loading files...")
    train_file = args.data_dir + '/train.jsonl'
    dev_file = args.data_dir + '/dev.jsonl'
    test_file = args.data_dir + '/test.jsonl'
    train_tokens = load_tokens(train_file)
    dev_tokens = load_tokens(dev_file)
    test_tokens = load_tokens(test_file)
    if args.lower:
        train_tokens, dev_tokens, test_tokens = [[t.lower() for t in tokens] for tokens in
                                                 (train_tokens, dev_tokens, test_tokens)]

    # load glove
    print("loading glove...")
    glove_vocab = vocab.load_glove_vocab(wv_file, wv_dim)
    print("{} words loaded from glove.".format(len(glove_vocab)))

    print("building vocab...")
    all_tokens = train_tokens + dev_tokens + test_tokens
    v = build_vocab(all_tokens, glove_vocab, args.min_freq)

    print("calculating oov...")
    datasets = {'train': train_tokens, 'dev': dev_tokens, 'test': test_tokens}
    for dname, d in datasets.items():
        total, oov = count_oov(d, v)
        print("{} oov: {}/{} ({:.2f}%)".format(dname, oov, total, oov*100.0/total))

    print("building embeddings...")
    if args.random:
        print("using random initialization...")
        embedding = random_embedding(v, wv_dim)
    else:
        embedding = vocab.build_embedding(wv_file, v, wv_dim)
    print("embedding size: {} x {}".format(*embedding.shape))

    print("dumping to files...")
    with open(vocab_file, 'wb') as outfile:
        pickle.dump(v, outfile)
    np.save(emb_file, embedding)
    print("all done.")

def random_embedding(vocab, wv_dim):
    # sample each dimension uniformly from [-1, 1)
    embeddings = 2 * np.random.rand(len(vocab), wv_dim) - 1.0
    return embeddings

def load_tokens(filename):
    with open(filename) as infile:
        data = jsonl.load(infile)
        tokens = []
        for d in data:
            tokens += d['findings'] + d['impression'] + d['background']
        tokens = list(map(vocab.normalize_token, tokens))
    print("{} tokens from {} examples loaded from {}.".format(len(tokens), len(data), filename))
    return tokens
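# Each record in the .jsonl files is assumed to hold pre-tokenized report sections
# under 'findings', 'impression' and 'background'; for example (values illustrative):
#   {"findings": ["the", "lungs", "are", "clear"],
#    "impression": ["no", "acute", "cardiopulmonary", "abnormality"],
#    "background": ["chest", "radiograph", "pa", "and", "lateral"]}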
def build_vocab(tokens, glove_vocab, min_freq):
    """ build vocab from tokens and glove words. """
    counter = Counter(t for t in tokens)
    # if min_freq > 0, use min_freq, otherwise keep all glove words
    if min_freq > 0:
        v = sorted([t for t in counter if counter.get(t) >= min_freq], key=counter.get, reverse=True)
    else:
        v = sorted([t for t in counter if t in glove_vocab], key=counter.get, reverse=True)
    # add special tokens and entity mask tokens
    v = constant.VOCAB_PREFIX + v
    print("vocab built with {}/{} words.".format(len(v), len(counter)))
    return v

def count_oov(tokens, vocab):
    c = Counter(t for t in tokens)
    total = sum(c.values())
    matched = sum(c[t] for t in vocab)
    return total, total - matched

if __name__ == '__main__':
    main()
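# Downstream usage sketch (illustrative; the training code in this repo may load these differently):
#   with open('dataset/vocab/vocab.pkl', 'rb') as f:
#       word_list = pickle.load(f)                       # token list, VOCAB_PREFIX entries first
#   embedding = np.load('dataset/vocab/embedding.npy')   # shape (len(word_list), wv_dim)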