import numpy as np
from faker import Faker
import random
from tqdm import tqdm
from babel.dates import format_date
from tensorflow.keras.utils import to_categorical
import tensorflow.keras.backend as K
from tensorflow.keras.models import Model
import matplotlib.pyplot as plt

fake = Faker()
Faker.seed(12345)
random.seed(12345)

# Define the formats of the data we would like to generate.
# 'short', 'medium', 'long' and 'full' are babel's built-in, locale-dependent
# styles; the remaining entries are explicit date patterns. 'full' is repeated
# so that random.choice samples it more often than the other formats.
FORMATS = ['short',
           'medium',
           'long',
           'full',
           'full',
           'full',
           'full',
           'full',
           'full',
           'full',
           'full',
           'full',
           'full',
           'd MMM YYY',
           'd MMMM YYY',
           'dd MMM YYY',
           'd MMM, YYY',
           'd MMMM, YYY',
           'dd, MMM YYY',
           'd MM YY',
           'd MMMM YYY',
           'MMMM d YYY',
           'MMMM d, YYY',
           'dd.MM.YY']

# change this if you want it to work with another language
LOCALES = ['en_US']

def load_date():
    """
    Generates one fake date.
    :returns: tuple containing human readable string, machine readable string (ISO 8601), and date object
    """
    dt = fake.date_object()

    try:
        human_readable = format_date(dt, format=random.choice(FORMATS), locale='en_US') # locale=random.choice(LOCALES))
        human_readable = human_readable.lower()
        human_readable = human_readable.replace(',','')
        machine_readable = dt.isoformat()

    except AttributeError as e:
        return None, None, None

    return human_readable, machine_readable, dt

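# Illustrative output (a sketch only; the exact strings depend on the Faker seed
# and on which format was randomly chosen):
#   load_date()
#   -> ('9 may 1998', '1998-05-09', datetime.date(1998, 5, 9))
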
def load_dataset(m):
    """
    Loads a dataset with m examples and builds the vocabularies
    :m: the number of examples to generate
    """

    human_vocab = set()
    machine_vocab = set()
    dataset = []
    Tx = 30


    for i in tqdm(range(m)):
        h, m, _ = load_date()  # note: m is reused for the machine-readable string; range(m) above was already evaluated
        if h is not None:
            dataset.append((h, m))
            human_vocab.update(tuple(h))
            machine_vocab.update(tuple(m))

    human = dict(zip(sorted(human_vocab) + ['<unk>', '<pad>'],
                     list(range(len(human_vocab) + 2))))
    inv_machine = dict(enumerate(sorted(machine_vocab)))
    machine = {v:k for k,v in inv_machine.items()}

    return dataset, human, machine, inv_machine

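# Illustrative usage (a sketch; the values below are typical for this generator
# but not guaranteed):
#   dataset, human_vocab, machine_vocab, inv_machine_vocab = load_dataset(10000)
#   dataset[0]             -> e.g. ('9 may 1998', '1998-05-09')
#   len(machine_vocab)     -> 11   (digits 0-9 plus '-')
#   '<pad>' in human_vocab -> True
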
def preprocess_data(dataset, human_vocab, machine_vocab, Tx, Ty):
    """
    Converts the (human, machine) string pairs in dataset into integer arrays
    X, Y (padded / truncated to Tx and Ty) and their one-hot encodings Xoh, Yoh.
    """

    X, Y = zip(*dataset)

    X = np.array([string_to_int(i, Tx, human_vocab) for i in X])
    Y = [string_to_int(t, Ty, machine_vocab) for t in Y]

    Xoh = np.array(list(map(lambda x: to_categorical(x, num_classes=len(human_vocab)), X)))
    Yoh = np.array(list(map(lambda x: to_categorical(x, num_classes=len(machine_vocab)), Y)))


    return X, np.array(Y), Xoh, Yoh

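# Shape sketch (assuming Tx = 30, Ty = 10 and the vocabulary sizes typically
# produced by load_dataset; these numbers are illustrative, not guaranteed):
#   X:   (m, 30)        integer-encoded, padded human-readable dates
#   Y:   (m, 10)        integer-encoded machine-readable dates
#   Xoh: (m, 30, 37)    one-hot version of X
#   Yoh: (m, 10, 11)    one-hot version of Y
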
def string_to_int(string, length, vocab):
    """
    Converts a string into a list of integers representing the positions of its
    characters in the "vocab"

    Arguments:
    string -- input string, e.g. 'Wed 10 Jul 2007'
    length -- the number of time steps you'd like, determines if the output will be padded or cut
    vocab -- vocabulary, dictionary used to index every character of your "string"

    Returns:
    rep -- list of integers (or '<unk>') (size = length) representing the positions of the string's characters in the vocabulary
    """

    # make lower to standardize
    string = string.lower()
    string = string.replace(',','')

    if len(string) > length:
        string = string[:length]

    rep = list(map(lambda x: vocab.get(x, '<unk>'), string))

    if len(string) < length:
        rep += [vocab['<pad>']] * (length - len(string))

    #print (rep)
    return rep

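# Illustrative behaviour (the actual index values depend on the vocabulary that
# load_dataset happened to build):
#   string_to_int('3 may 1979', 30, human_vocab)
#   -> 10 character indices followed by 20 copies of human_vocab['<pad>']
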

def int_to_string(ints, inv_vocab):
    """
    Output a machine readable list of characters based on a list of indexes in the machine's vocabulary

    Arguments:
    ints -- list of integers representing indexes in the machine's vocabulary
    inv_vocab -- dictionary mapping machine readable indexes to machine readable characters

    Returns:
    l -- list of characters corresponding to the indexes of ints thanks to the inv_vocab mapping
    """

    l = [inv_vocab[i] for i in ints]
    return l


EXAMPLES = ['3 May 1979', '5 Apr 09', '20th February 2016', 'Wed 10 Jul 2007']

# number of input time steps used when encoding an example; assumed to match
# the Tx = 30 used elsewhere in this module
TIME_STEPS = 30

def run_example(model, input_vocabulary, inv_output_vocabulary, text):
    encoded = string_to_int(text, TIME_STEPS, input_vocabulary)
    prediction = model.predict(np.array([encoded]))
    prediction = np.argmax(prediction[0], axis=-1)
    return int_to_string(prediction, inv_output_vocabulary)

def run_examples(model, input_vocabulary, inv_output_vocabulary, examples=EXAMPLES):
    predicted = []
    for example in examples:
        predicted.append(''.join(run_example(model, input_vocabulary, inv_output_vocabulary, example)))
        print('input:', example)
        print('output:', predicted[-1])
    return predicted

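# Expected behaviour (illustrative; the exact outputs depend on the trained model
# passed in, which is assumed to accept integer-encoded input of length TIME_STEPS):
#   run_examples(model, human_vocab, inv_machine_vocab)
#   input: 3 May 1979
#   output: 1979-05-03
#   ...
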

def softmax(x, axis=1):
    """Softmax activation function.
    # Arguments
        x : Tensor.
        axis: Integer, axis along which the softmax normalization is applied.
    # Returns
        Tensor, output of softmax transformation.
    # Raises
        ValueError: In case `dim(x) == 1`.
    """
    ndim = K.ndim(x)
    if ndim == 2:
        return K.softmax(x)
    elif ndim > 2:
        e = K.exp(x - K.max(x, axis=axis, keepdims=True))
        s = K.sum(e, axis=axis, keepdims=True)
        return e / s
    else:
        raise ValueError('Cannot apply softmax to a tensor that is 1D')

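# Quick sanity check (a sketch; runs eagerly with the TensorFlow backend):
#   x = K.constant([[1.0, 2.0, 3.0]])
#   K.eval(softmax(x))   # -> approximately [[0.090, 0.245, 0.665]]; each row sums to 1
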

def plot_attention_map(modelx, input_vocabulary, inv_output_vocabulary, text, n_s = 128, num = 7):
    """
    Plot the attention map.

    """
    attention_map = np.zeros((10, 30))
    layer = modelx.get_layer('attention_weights')

    Ty, Tx = attention_map.shape

    human_vocab_size = 37

    # Well, this is cumbersome, but this version of tensorflow-keras has a bug that affects the
    # reuse of layers in a model built with the functional API.
    # So, I have to recreate the model from its functional
    # components and connect them one by one.
    # Ideally it could be done simply like this:
    # layer = modelx.layers[num]
    # f = Model(modelx.inputs, [layer.get_output_at(t) for t in range(Ty)])
    #

    X = modelx.inputs[0]
    s0 = modelx.inputs[1]
    c0 = modelx.inputs[2]
    s = s0
    c = c0

    a = modelx.layers[2](X)
    outputs = []

    for t in range(Ty):
        s_prev = s
        s_prev = modelx.layers[3](s_prev)
        concat = modelx.layers[4]([a, s_prev])
        e = modelx.layers[5](concat)
        energies = modelx.layers[6](e)
        alphas = modelx.layers[7](energies)
        context = modelx.layers[8]([alphas, a])
        # pass initial_state = [hidden state, cell state] to the post-attention LSTM
        s, _, c = modelx.layers[10](context, initial_state = [s, c])
        outputs.append(energies)

    f = Model(inputs=[X, s0, c0], outputs = outputs)


    s0 = np.zeros((1, n_s))
    c0 = np.zeros((1, n_s))
    encoded = np.array(string_to_int(text, Tx, input_vocabulary)).reshape((1, 30))
    encoded = np.array(list(map(lambda x: to_categorical(x, num_classes=len(input_vocabulary)), encoded)))


    r = f([encoded, s0, c0])

    for t in range(Ty):
        for t_prime in range(Tx):
            attention_map[t][t_prime] = r[t][0, t_prime]

    # Normalize attention map
    row_max = attention_map.max(axis=1)
    attention_map = attention_map / row_max[:, None]

    prediction = modelx.predict([encoded, s0, c0])

    predicted_text = []
    for i in range(len(prediction)):
        predicted_text.append(int(np.argmax(prediction[i], axis=1)))

    predicted_text = list(predicted_text)
    predicted_text = int_to_string(predicted_text, inv_output_vocabulary)
    text_ = list(text)

    # get the lengths of the string
    input_length = len(text)
    output_length = Ty

    # Plot the attention_map
    plt.clf()
    f = plt.figure(figsize=(8, 8.5))
    ax = f.add_subplot(1, 1, 1)

    # add image
    i = ax.imshow(attention_map, interpolation='nearest', cmap='Blues')

    # add colorbar
    cbaxes = f.add_axes([0.2, 0, 0.6, 0.03])
    cbar = f.colorbar(i, cax=cbaxes, orientation='horizontal')
    cbar.ax.set_xlabel('Alpha value (Probability output of the "softmax")', labelpad=2)

    # add labels
    ax.set_yticks(range(output_length))
    ax.set_yticklabels(predicted_text[:output_length])

    ax.set_xticks(range(input_length))
    ax.set_xticklabels(text_[:input_length], rotation=45)

    ax.set_xlabel('Input Sequence')
    ax.set_ylabel('Output Sequence')

    # add grid and legend
    ax.grid()

    #f.show()

    return attention_map
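
# Illustrative call (a sketch; `model` is assumed to be the trained attention model
# with inputs [X, s0, c0], and n_s must match the hidden size of its
# post-attention LSTM):
#   attention_map = plot_attention_map(model, human_vocab, inv_machine_vocab,
#                                      "Tuesday 09 Oct 1993", n_s=64)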