|
| 1 | +""" |
| 2 | +#Trains a TCN on the IMDB sentiment classification task. |
| 3 | +Output after 1 epochs on CPU: ~0.8611 |
| 4 | +Time per epoch on CPU (Core i7): ~64s. |
| 5 | +Based on: https://github.com/keras-team/keras/blob/master/examples/imdb_bidirectional_lstm.py |
| 6 | +""" |
| 7 | +import numpy as np |
| 8 | +from keras import Model, Input |
| 9 | +from keras.datasets import imdb |
| 10 | +from keras.layers import Dense, Dropout, Embedding |
| 11 | +from keras.preprocessing import sequence |
| 12 | + |
| 13 | +from tcn import TCN |
| 14 | + |
| 15 | +max_features = 20000 |
| 16 | +# cut texts after this number of words |
| 17 | +# (among top max_features most common words) |
| 18 | +maxlen = 100 |
| 19 | +batch_size = 32 |
| 20 | + |
| 21 | +print('Loading data...') |
| 22 | +(x_train, y_train), (x_test, y_test) = imdb.load_data(num_words=max_features) |
| 23 | +print(len(x_train), 'train sequences') |
| 24 | +print(len(x_test), 'test sequences') |
| 25 | + |
| 26 | +print('Pad sequences (samples x time)') |
| 27 | +x_train = sequence.pad_sequences(x_train, maxlen=maxlen) |
| 28 | +x_test = sequence.pad_sequences(x_test, maxlen=maxlen) |
| 29 | +print('x_train shape:', x_train.shape) |
| 30 | +print('x_test shape:', x_test.shape) |
| 31 | +y_train = np.array(y_train) |
| 32 | +y_test = np.array(y_test) |
| 33 | + |
| 34 | +i = Input(shape=(maxlen,)) |
| 35 | +x = Embedding(max_features, 128)(i) |
| 36 | +x = TCN(nb_filters=64, |
| 37 | + kernel_size=6, |
| 38 | + dilations=[1, 2, 4, 8, 16, 32, 64])(x) |
| 39 | +x = Dropout(0.5)(x) |
| 40 | +x = Dense(1, activation='sigmoid')(x) |
| 41 | + |
| 42 | +model = Model(inputs=[i], outputs=[x]) |
| 43 | + |
| 44 | +model.summary() |
| 45 | + |
| 46 | +# try using different optimizers and different optimizer configs |
| 47 | +model.compile('adam', 'binary_crossentropy', metrics=['accuracy']) |
| 48 | + |
| 49 | +print('Train...') |
| 50 | +model.fit(x_train, y_train, |
| 51 | + batch_size=batch_size, |
| 52 | + epochs=1, |
| 53 | + validation_data=[x_test, y_test]) |
0 commit comments