
Commit a2f48ad

Merge pull request #174 from philipperemy/weight-norm

Weight norm

2 parents: 4404f7c + b47ede9

File tree

8 files changed: +116 −74 lines changed

README.md

Lines changed: 30 additions & 26 deletions

@@ -227,21 +227,19 @@ The task consists of feeding a large array of decimal numbers to the network, al
 
 #### Implementation results
 
-The model takes time to learn this task. It's symbolized by a very long plateau (could take ~8 epochs on some runs).
-
 ```
-200000/200000 [==============================] - 293s 1ms/step - loss: 0.1731 - val_loss: 0.1662
-200000/200000 [==============================] - 289s 1ms/step - loss: 0.1675 - val_loss: 0.1665
-200000/200000 [==============================] - 287s 1ms/step - loss: 0.1670 - val_loss: 0.1665
-200000/200000 [==============================] - 288s 1ms/step - loss: 0.1668 - val_loss: 0.1669
-200000/200000 [==============================] - 285s 1ms/step - loss: 0.1085 - val_loss: 0.0019
-200000/200000 [==============================] - 285s 1ms/step - loss: 0.0011 - val_loss: 4.1667e-04
-200000/200000 [==============================] - 282s 1ms/step - loss: 6.0470e-04 - val_loss: 6.7708e-04
-200000/200000 [==============================] - 282s 1ms/step - loss: 4.3099e-04 - val_loss: 7.3898e-04
-200000/200000 [==============================] - 282s 1ms/step - loss: 3.9102e-04 - val_loss: 1.8727e-04
-200000/200000 [==============================] - 280s 1ms/step - loss: 3.1040e-04 - val_loss: 0.0010
-200000/200000 [==============================] - 281s 1ms/step - loss: 3.1166e-04 - val_loss: 2.2333e-04
-200000/200000 [==============================] - 281s 1ms/step - loss: 2.8046e-04 - val_loss: 1.5194e-04
+782/782 [==============================] - 154s 197ms/step - loss: 0.8437 - val_loss: 0.1883
+782/782 [==============================] - 154s 196ms/step - loss: 0.0702 - val_loss: 0.0111
+782/782 [==============================] - 153s 195ms/step - loss: 0.0053 - val_loss: 0.0038
+782/782 [==============================] - 154s 196ms/step - loss: 0.0035 - val_loss: 0.0027
+782/782 [==============================] - 153s 196ms/step - loss: 0.0030 - val_loss: 0.0065
+782/782 [==============================] - 151s 193ms/step - loss: 0.0027 - val_loss: 0.0018
+782/782 [==============================] - 152s 194ms/step - loss: 0.0025 - val_loss: 0.0036
+782/782 [==============================] - 153s 196ms/step - loss: 0.0024 - val_loss: 0.0018
+782/782 [==============================] - 152s 194ms/step - loss: 0.0023 - val_loss: 0.0016
+782/782 [==============================] - 152s 194ms/step - loss: 0.0014 - val_loss: 3.7456e-04
+782/782 [==============================] - 153s 196ms/step - loss: 9.4740e-04 - val_loss: 7.0205e-04
+782/782 [==============================] - 152s 194ms/step - loss: 6.9630e-04 - val_loss: 3.7180e-04
 ```
 
 ### Copy Memory Task
@@ -263,13 +261,14 @@ The idea is to copy the content of the vector x to the end of the large array. T
 #### Implementation results (first epochs)
 
 ```
-30000/30000 [==============================] - 30s 1ms/step - loss: 0.1174 - acc: 0.9586 - val_loss: 0.0370 - val_acc: 0.9859
-30000/30000 [==============================] - 26s 874us/step - loss: 0.0367 - acc: 0.9859 - val_loss: 0.0363 - val_acc: 0.9859
-30000/30000 [==============================] - 26s 852us/step - loss: 0.0361 - acc: 0.9859 - val_loss: 0.0358 - val_acc: 0.9859
-30000/30000 [==============================] - 26s 872us/step - loss: 0.0355 - acc: 0.9859 - val_loss: 0.0349 - val_acc: 0.9859
-30000/30000 [==============================] - 25s 850us/step - loss: 0.0339 - acc: 0.9864 - val_loss: 0.0291 - val_acc: 0.9881
-30000/30000 [==============================] - 26s 856us/step - loss: 0.0235 - acc: 0.9896 - val_loss: 0.0159 - val_acc: 0.9944
-30000/30000 [==============================] - 26s 872us/step - loss: 0.0169 - acc: 0.9929 - val_loss: 0.0125 - val_acc: 0.9966
+118/118 [==============================] - 17s 143ms/step - loss: 1.1732 - accuracy: 0.6725 - val_loss: 0.1119 - val_accuracy: 0.9796
+118/118 [==============================] - 15s 125ms/step - loss: 0.0645 - accuracy: 0.9831 - val_loss: 0.0402 - val_accuracy: 0.9853
+118/118 [==============================] - 15s 125ms/step - loss: 0.0393 - accuracy: 0.9856 - val_loss: 0.0372 - val_accuracy: 0.9857
+118/118 [==============================] - 15s 125ms/step - loss: 0.0361 - accuracy: 0.9858 - val_loss: 0.0344 - val_accuracy: 0.9860
+118/118 [==============================] - 15s 125ms/step - loss: 0.0345 - accuracy: 0.9860 - val_loss: 0.0335 - val_accuracy: 0.9864
+118/118 [==============================] - 15s 125ms/step - loss: 0.0325 - accuracy: 0.9867 - val_loss: 0.0268 - val_accuracy: 0.9886
+118/118 [==============================] - 15s 125ms/step - loss: 0.0268 - accuracy: 0.9885 - val_loss: 0.0206 - val_accuracy: 0.9908
+118/118 [==============================] - 15s 125ms/step - loss: 0.0228 - accuracy: 0.9900 - val_loss: 0.0169 - val_accuracy: 0.9933
 ```
 
 ### Sequential MNIST
@@ -286,11 +285,16 @@ The idea here is to consider MNIST images as 1-D sequences and feed them to the
 #### Implementation results
 
 ```
-60000/60000 [==============================] - 118s 2ms/step - loss: 0.2348 - acc: 0.9265 - val_loss: 0.1308 - val_acc: 0.9579
-60000/60000 [==============================] - 116s 2ms/step - loss: 0.0973 - acc: 0.9698 - val_loss: 0.0645 - val_acc: 0.9798
-[...]
-60000/60000 [==============================] - 112s 2ms/step - loss: 0.0075 - acc: 0.9978 - val_loss: 0.0547 - val_acc: 0.9894
-60000/60000 [==============================] - 111s 2ms/step - loss: 0.0093 - acc: 0.9968 - val_loss: 0.0585 - val_acc: 0.9895
+1875/1875 [==============================] - 46s 25ms/step - loss: 0.0949 - accuracy: 0.9706 - val_loss: 0.0763 - val_accuracy: 0.9756
+1875/1875 [==============================] - 46s 25ms/step - loss: 0.0831 - accuracy: 0.9743 - val_loss: 0.0656 - val_accuracy: 0.9807
+1875/1875 [==============================] - 46s 25ms/step - loss: 0.0752 - accuracy: 0.9763 - val_loss: 0.0604 - val_accuracy: 0.9802
+1875/1875 [==============================] - 46s 25ms/step - loss: 0.0685 - accuracy: 0.9785 - val_loss: 0.0588 - val_accuracy: 0.9813
+1875/1875 [==============================] - 46s 25ms/step - loss: 0.0624 - accuracy: 0.9801 - val_loss: 0.0545 - val_accuracy: 0.9822
+1875/1875 [==============================] - 46s 25ms/step - loss: 0.0603 - accuracy: 0.9812 - val_loss: 0.0478 - val_accuracy: 0.9835
+1875/1875 [==============================] - 46s 25ms/step - loss: 0.0566 - accuracy: 0.9821 - val_loss: 0.0546 - val_accuracy: 0.9826
+1875/1875 [==============================] - 46s 25ms/step - loss: 0.0503 - accuracy: 0.9843 - val_loss: 0.0441 - val_accuracy: 0.9853
+1875/1875 [==============================] - 46s 25ms/step - loss: 0.0486 - accuracy: 0.9840 - val_loss: 0.0572 - val_accuracy: 0.9832
+1875/1875 [==============================] - 46s 25ms/step - loss: 0.0453 - accuracy: 0.9858 - val_loss: 0.0424 - val_accuracy: 0.9862
 ```
 
 ## Testing

setup.py

Lines changed: 1 addition & 1 deletion

@@ -10,6 +10,6 @@
       long_description=open('README.md').read(),
       packages=['tcn'],
       install_requires=[
-          'numpy', 'tensorflow'
+          'numpy', 'tensorflow', 'tensorflow_addons'
       ]
       )
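The new `tensorflow_addons` dependency supplies the `tfa.layers.WeightNormalization` wrapper, presumably backing the `use_weight_norm` flag this PR introduces. For context, weight normalization reparameterizes each weight vector as w = g · v / ‖v‖, so the magnitude g and the direction v / ‖v‖ are learned independently. A minimal pure-Python sketch of the idea (illustrative only, not the library implementation):

```python
import math

def weight_norm(v, g):
    # Weight normalization (Salimans & Kingma, 2016): reparameterize a
    # weight vector w as g * v / ||v||, decoupling the magnitude g from
    # the direction v / ||v||.
    norm = math.sqrt(sum(x * x for x in v))
    return [g * x / norm for x in v]

# Whatever v is, the reparameterized weight always has magnitude g.
w = weight_norm([3.0, 4.0], g=2.0)  # ||[3, 4]|| = 5
print(w)                            # [1.2, 1.6]
print(math.hypot(*w))               # the norm is g, up to float rounding
```

In the Keras setting, `tfa.layers.WeightNormalization` applies this reparameterization to a wrapped layer's kernel, with g trained by gradient descent like any other variable.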

tasks/adding_problem/main.py

Lines changed: 21 additions & 26 deletions

@@ -1,44 +1,39 @@
-import keras
 import numpy as np
-
-from tcn import compiled_tcn
+from tensorflow.keras.callbacks import Callback
 from utils import data_generator
 
+from tcn import compiled_tcn, tcn_full_summary
+
 x_train, y_train = data_generator(n=200000, seq_length=600)
 x_test, y_test = data_generator(n=40000, seq_length=600)
 
 
-class PrintSomeValues(keras.callbacks.Callback):
+class PrintSomeValues(Callback):
 
     def on_epoch_begin(self, epoch, logs={}):
         print('y_true, y_pred')
         print(np.hstack([y_test[:5], self.model.predict(x_test[:5])]))
 
 
 def run_task():
-    model = compiled_tcn(return_sequences=False,
-                         num_feat=x_train.shape[2],
-                         num_classes=0,
-                         nb_filters=24,
-                         kernel_size=8,
-                         dilations=[2 ** i for i in range(9)],
-                         nb_stacks=1,
-                         max_len=x_train.shape[1],
-                         use_skip_connections=False,
-                         regression=True,
-                         dropout_rate=0)
-
-    print(f'x_train.shape = {x_train.shape}')
-    print(f'y_train.shape = {y_train.shape}')
-
-    psv = PrintSomeValues()
-
-    # Using sparse softmax.
-    # http://chappers.github.io/web%20micro%20log/2017/01/26/quick-models-in-keras/
-    model.summary()
-
+    model = compiled_tcn(
+        return_sequences=False,
+        num_feat=x_train.shape[2],
+        num_classes=0,
+        nb_filters=24,
+        kernel_size=8,
+        dilations=[2 ** i for i in range(9)],
+        nb_stacks=1,
+        max_len=x_train.shape[1],
+        use_skip_connections=False,
+        use_weight_norm=True,
+        regression=True,
+        dropout_rate=0
+    )
+
+    tcn_full_summary(model)
     model.fit(x_train, y_train, validation_data=(x_test, y_test), epochs=15,
-              batch_size=256, callbacks=[psv])
+              batch_size=256, callbacks=[PrintSomeValues()])
 
 
 if __name__ == '__main__':

tasks/copy_memory/main.py

Lines changed: 3 additions & 2 deletions

@@ -1,7 +1,7 @@
 from uuid import uuid4
 
-import keras
 import numpy as np
+from tensorflow.keras.callbacks import Callback
 
 from tcn import compiled_tcn
 from utils import data_generator
@@ -10,7 +10,7 @@
 x_test, y_test = data_generator(601, 10, 6000)
 
 
-class PrintSomeValues(keras.callbacks.Callback):
+class PrintSomeValues(Callback):
 
     def on_epoch_begin(self, epoch, logs={}):
         print('y_true')
@@ -30,6 +30,7 @@ def run_task():
                          use_skip_connections=True,
                          opt='rmsprop',
                          lr=5e-4,
+                         use_weight_norm=True,
                          return_sequences=True)
 
     print(f'x_train.shape = {x_train.shape}')

tasks/mnist_pixel/main.py

Lines changed: 1 addition & 0 deletions

@@ -14,6 +14,7 @@ def run_task():
                          dilations=[2 ** i for i in range(9)],
                          nb_stacks=1,
                          max_len=x_train[0:1].shape[1],
+                         use_weight_norm=True,
                          use_skip_connections=True)
 
     print(f'x_train.shape = {x_train.shape}')

tasks/mnist_pixel/utils.py

Lines changed: 2 additions & 2 deletions

@@ -1,6 +1,6 @@
 import numpy as np
-from keras.datasets import mnist
-from keras.utils import to_categorical
+from tensorflow.keras.datasets import mnist
+from tensorflow.keras.utils import to_categorical
 
 
 def data_generator():

tasks/tcn_call_test.py

Lines changed: 22 additions & 0 deletions

@@ -3,6 +3,7 @@
 import numpy as np
 from tensorflow.keras import Input
 from tensorflow.keras import Model
+from tensorflow.keras.models import Sequential
 
 from tcn import TCN
 
@@ -99,6 +100,27 @@ def test_non_causal_time_dim_unknown_return_no_sequences(self):
         r = predict_with_tcn(time_steps=None, padding='same', return_sequences=False)
         self.assertListEqual([list(b.shape) for b in r], [[1, NB_FILTERS], [1, NB_FILTERS], [1, NB_FILTERS]])
 
+    def test_norms(self):
+        Sequential(layers=[TCN(input_shape=(20, 2), use_weight_norm=True)]).compile(optimizer='adam', loss='mse')
+        Sequential(layers=[TCN(input_shape=(20, 2), use_weight_norm=False)]).compile(optimizer='adam', loss='mse')
+        Sequential(layers=[TCN(input_shape=(20, 2), use_layer_norm=True)]).compile(optimizer='adam', loss='mse')
+        Sequential(layers=[TCN(input_shape=(20, 2), use_layer_norm=False)]).compile(optimizer='adam', loss='mse')
+        Sequential(layers=[TCN(input_shape=(20, 2), use_batch_norm=True)]).compile(optimizer='adam', loss='mse')
+        Sequential(layers=[TCN(input_shape=(20, 2), use_batch_norm=False)]).compile(optimizer='adam', loss='mse')
+        try:
+            Sequential(layers=[TCN(input_shape=(20, 2), use_batch_norm=True, use_weight_norm=True)]).compile(
+                optimizer='adam', loss='mse')
+            raise AssertionError('test failed.')
+        except ValueError:
+            pass
+        try:
+            Sequential(layers=[TCN(input_shape=(20, 2), use_batch_norm=True,
+                                   use_weight_norm=True, use_layer_norm=True)]).compile(
+                optimizer='adam', loss='mse')
+            raise AssertionError('test failed.')
+        except ValueError:
+            pass
+
 
 if __name__ == '__main__':
     unittest.main()
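The try/except cases in `test_norms` expect construction to raise a `ValueError` when more than one normalization flag is enabled at once. A minimal sketch of that kind of mutual-exclusion guard (`check_norm_flags` is a hypothetical stand-in, not the actual TCN layer code):

```python
def check_norm_flags(use_batch_norm=False, use_layer_norm=False, use_weight_norm=False):
    # Hypothetical stand-in for the TCN layer's argument validation:
    # the three normalization schemes are mutually exclusive, so enabling
    # more than one at construction time should fail fast.
    if sum([use_batch_norm, use_layer_norm, use_weight_norm]) > 1:
        raise ValueError('Only one normalization can be specified at once.')

check_norm_flags(use_weight_norm=True)  # a single scheme is accepted
try:
    check_norm_flags(use_batch_norm=True, use_weight_norm=True)
except ValueError as e:
    print(e)  # conflicting flags are rejected
```

Failing at construction rather than at fit time is what lets the test catch the conflict with a plain `compile()` call.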
