# emb2emb_autoencoder.py
import copy
import json
import os

import torch
from torch.nn.utils.rnn import pad_sequence
from tokenizers import CharBPETokenizer, SentencePieceBPETokenizer

from autoencoders.autoencoder import AutoEncoder
from autoencoders.rnn_decoder import RNNDecoder
from autoencoders.rnn_encoder import RNNEncoder
from emb2emb.encoding import Encoder, Decoder
from emb2emb.utils import Namespace

HUGGINGFACE_TOKENIZERS = ["CharBPETokenizer", "SentencePieceBPETokenizer"]


def tokenize(s):
    # TODO: more sophisticated tokenization
    return s.split()


def get_tokenizer(tokenizer, location='bert-base-uncased'):
    # TODO: do we need to pass more options to the file?
    # Resolve the tokenizer class by name instead of using eval().
    assert tokenizer in HUGGINGFACE_TOKENIZERS, "Unknown tokenizer: %s" % tokenizer
    tokenizer_classes = {"CharBPETokenizer": CharBPETokenizer,
                         "SentencePieceBPETokenizer": SentencePieceBPETokenizer}
    tok = tokenizer_classes[tokenizer](vocab_file=location + '-vocab.json',
                                       merges_file=location + '-merges.txt')
    tok.add_special_tokens(["[PAD]", "<unk>", "<SOS>", "<EOS>"])
    return tok


def get_autoencoder(config):
    # Start from the default config, if present, and override it with the
    # config that the model was trained with.
    if os.path.exists(config["default_config"]):
        with open(config["default_config"]) as f:
            model_config_dict = json.load(f)
    else:
        model_config_dict = {}
    with open(os.path.join(config["modeldir"], "config.json")) as f:
        orig_model_config = json.load(f)
    model_config_dict.update(orig_model_config)
    model_config = Namespace()
    model_config.__dict__.update(model_config_dict)

    # Derive vocabulary-dependent settings from the tokenizer.
    tokenizer = get_tokenizer(
        model_config.tokenizer, model_config.tokenizer_location)
    model_config.vocab_size = tokenizer.get_vocab_size()
    model_config.sos_idx = tokenizer.token_to_id("<SOS>")
    model_config.eos_idx = tokenizer.token_to_id("<EOS>")
    model_config.unk_idx = tokenizer.token_to_id("<unk>")
    model_config.device = config["device"]

    # The encoder/decoder sub-configs are stored in the model config under
    # the class name, e.g. model_config.RNNEncoder.
    encoder_config = copy.deepcopy(model_config)
    decoder_config = copy.deepcopy(model_config)
    encoder_config.__dict__.update(model_config.__dict__[model_config.encoder])
    encoder_config.tokenizer = tokenizer
    decoder_config.__dict__.update(model_config.__dict__[model_config.decoder])

    if model_config.encoder == "RNNEncoder":
        encoder = RNNEncoder(encoder_config)
    else:
        raise ValueError("Unknown encoder: %s" % model_config.encoder)
    if model_config.decoder == "RNNDecoder":
        decoder = RNNDecoder(decoder_config)
    else:
        raise ValueError("Unknown decoder: %s" % model_config.decoder)

    # Restore the trained weights.
    model = AutoEncoder(encoder, decoder, tokenizer, model_config)
    checkpoint = torch.load(
        os.path.join(config["modeldir"], model_config.model_file),
        map_location=config["device"])
    model.load_state_dict(checkpoint["model_state_dict"])
    return model
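

# Note: the `config` consumed by get_autoencoder and the wrapper classes below
# is the plain dict passed in by the surrounding pipeline; only the keys
# "default_config", "modeldir", and "device" are read from it in this file.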


class AEEncoder(Encoder):

    def __init__(self, config):
        super(AEEncoder, self).__init__(config)
        self.device = config["device"]
        self.model = get_autoencoder(config)
        self.use_lookup = self.model.encoder.variational

    def _prepare_batch(self, indexed, lengths):
        # Pad to a rectangular batch and sort by descending length, as
        # packed RNN sequences require.
        X = pad_sequence([torch.tensor(index_list, device=self.device)
                          for index_list in indexed],
                         batch_first=True, padding_value=0)
        lengths, idx = torch.sort(torch.tensor(
            lengths, device=self.device).long(), descending=True)
        return X[idx], lengths, idx

    def _undo_batch(self, encoded, sort_idx):
        # Restore the original (pre-sort) order of the batch.
        ret = [None] * encoded.shape[0]
        for i, c in zip(sort_idx, range(encoded.shape[0])):
            ret[i] = encoded[c]
        return torch.stack(ret)

    def encode(self, S_list):
        indexed = [self.model.tokenizer.encode(
            "<SOS>" + s + "<EOS>").ids for s in S_list]
        lengths = [len(i) for i in indexed]
        X, X_lens, sort_idx = self._prepare_batch(indexed, lengths)
        encoded = self.model.encode(X, X_lens)
        # Since _prepare_batch sorts by length, we need to undo this here.
        return self._undo_batch(encoded, sort_idx)


class AEDecoder(Decoder):

    def __init__(self, config):
        super(AEDecoder, self).__init__()
        self.device = config["device"]
        self.model = get_autoencoder(config)

    def _prepare_batch(self, indexed, lengths):
        # Unlike AEEncoder._prepare_batch, the decoder does not sort the
        # batch by length, so no sort index needs to be returned.
        X = pad_sequence([torch.tensor(index_list, device=self.device)
                          for index_list in indexed],
                         batch_first=True, padding_value=0)
        lengths = torch.tensor(lengths, device=self.device).long()
        return X, lengths

    def _encode(self, S_list):
        indexed = [self.model.tokenizer.encode(
            "<SOS>" + s + "<EOS>").ids for s in S_list]
        lengths = [len(i) for i in indexed]
        X, X_lens = self._prepare_batch(indexed, lengths)
        return X, X_lens

    def predict(self, S_batch, target_batch=None):
        if self.training:
            # Teacher forcing: decode against the gold target sequence.
            target_batch, target_length = self._encode(target_batch)
            out = self.model.decode_training(
                S_batch, target_batch, target_length)
            return out, target_batch
        else:
            return self.model.decode(S_batch, beam_width=15)

    def prediction_to_text(self, predictions):
        predictions = [self.model.tokenizer.decode(
            p, skip_special_tokens=True) for p in predictions]
        return predictions
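

if __name__ == "__main__":
    # A minimal round-trip sketch, not part of the original pipeline. It
    # assumes a trained checkpoint exists under the hypothetical paths below,
    # and that Decoder subclasses torch.nn.Module, so that eval() switches
    # predict() from teacher forcing to beam-search decoding.
    config = {
        "default_config": "configs/default.json",  # hypothetical path
        "modeldir": "checkpoints/my_model",        # hypothetical path
        "device": "cpu",
    }
    encoder = AEEncoder(config)
    decoder = AEDecoder(config)
    decoder.eval()
    # Map sentences to fixed-size embeddings and back to text.
    embeddings = encoder.encode(["a first sentence .", "a second sentence ."])
    predictions = decoder.predict(embeddings)
    print(decoder.prediction_to_text(predictions))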