-
Notifications
You must be signed in to change notification settings - Fork 97
/
prepare_vocab.py
113 lines (94 loc) · 3.93 KB
/
prepare_vocab.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
"""
Prepare vocabulary and initial word vectors.
"""
import json
import pickle
import argparse
import numpy as np
from collections import Counter
from utils import vocab, constant, helper
def parse_args():
parser = argparse.ArgumentParser(description='Prepare vocab for relation extraction.')
parser.add_argument('data_dir', help='TACRED directory.')
parser.add_argument('vocab_dir', help='Output vocab directory.')
parser.add_argument('--glove_dir', default='dataset/glove', help='GloVe directory.')
parser.add_argument('--wv_file', default='glove.840B.300d.txt', help='GloVe vector file.')
parser.add_argument('--wv_dim', type=int, default=300, help='GloVe vector dimension.')
parser.add_argument('--min_freq', type=int, default=0, help='If > 0, use min_freq as the cutoff.')
parser.add_argument('--lower', action='store_true', help='If specified, lowercase all words.')
args = parser.parse_args()
return args
def main():
args = parse_args()
# input files
train_file = args.data_dir + '/train.json'
dev_file = args.data_dir + '/dev.json'
test_file = args.data_dir + '/test.json'
wv_file = args.glove_dir + '/' + args.wv_file
wv_dim = args.wv_dim
# output files
helper.ensure_dir(args.vocab_dir)
vocab_file = args.vocab_dir + '/vocab.pkl'
emb_file = args.vocab_dir + '/embedding.npy'
# load files
print("loading files...")
train_tokens = load_tokens(train_file)
dev_tokens = load_tokens(dev_file)
test_tokens = load_tokens(test_file)
if args.lower:
train_tokens, dev_tokens, test_tokens = [[t.lower() for t in tokens] for tokens in\
(train_tokens, dev_tokens, test_tokens)]
# load glove
print("loading glove...")
glove_vocab = vocab.load_glove_vocab(wv_file, wv_dim)
print("{} words loaded from glove.".format(len(glove_vocab)))
print("building vocab...")
v = build_vocab(train_tokens, glove_vocab, args.min_freq)
print("calculating oov...")
datasets = {'train': train_tokens, 'dev': dev_tokens, 'test': test_tokens}
for dname, d in datasets.items():
total, oov = count_oov(d, v)
print("{} oov: {}/{} ({:.2f}%)".format(dname, oov, total, oov*100.0/total))
print("building embeddings...")
embedding = vocab.build_embedding(wv_file, v, wv_dim)
print("embedding size: {} x {}".format(*embedding.shape))
print("dumping to files...")
with open(vocab_file, 'wb') as outfile:
pickle.dump(v, outfile)
np.save(emb_file, embedding)
print("all done.")
def load_tokens(filename):
with open(filename) as infile:
data = json.load(infile)
tokens = []
for d in data:
tokens += d['token']
print("{} tokens from {} examples loaded from {}.".format(len(tokens), len(data), filename))
return tokens
def build_vocab(tokens, glove_vocab, min_freq):
""" build vocab from tokens and glove words. """
counter = Counter(t for t in tokens)
# if min_freq > 0, use min_freq, otherwise keep all glove words
if min_freq > 0:
v = sorted([t for t in counter if counter.get(t) >= min_freq], key=counter.get, reverse=True)
else:
v = sorted([t for t in counter if t in glove_vocab], key=counter.get, reverse=True)
# add special tokens and entity mask tokens
v = constant.VOCAB_PREFIX + entity_masks() + v
print("vocab built with {}/{} words.".format(len(v), len(counter)))
return v
def count_oov(tokens, vocab):
c = Counter(t for t in tokens)
total = sum(c.values())
matched = sum(c[t] for t in vocab)
return total, total-matched
def entity_masks():
""" Get all entity mask tokens as a list. """
masks = []
subj_entities = list(constant.SUBJ_NER_TO_ID.keys())[2:]
obj_entities = list(constant.OBJ_NER_TO_ID.keys())[2:]
masks += ["SUBJ-" + e for e in subj_entities]
masks += ["OBJ-" + e for e in obj_entities]
return masks
if __name__ == '__main__':
main()