Skip to content

Commit

Permalink
updated MDN RNN TD example
Browse files Browse the repository at this point in the history
  • Loading branch information
cpmpercussion committed Oct 26, 2018
1 parent ec9955d commit b5d7468
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions notebooks/MDN-RNN-time-distributed-MDN-training.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@
"inputs = keras.layers.Input(shape=(SEQ_LEN,OUTPUT_DIMENSION), name='inputs')\n",
"lstm1_out = keras.layers.LSTM(HIDDEN_UNITS, name='lstm1', return_sequences=True)(inputs)\n",
"lstm2_out = keras.layers.LSTM(HIDDEN_UNITS, name='lstm2', return_sequences=True)(lstm1_out)\n",
"mdn_out = mdn.MDN(OUTPUT_DIMENSION, NUMBER_MIXTURES, name='mdn_outputs')(lstm2_out)\n",
"mdn_out = keras.layers.TimeDistributed(mdn.MDN(OUTPUT_DIMENSION, NUMBER_MIXTURES, name='mdn_outputs'), name='td_mdn')(lstm2_out)\n",
"\n",
"model = keras.models.Model(inputs=inputs, outputs=mdn_out)\n",
"model.compile(loss=mdn.get_mixture_loss_func(OUTPUT_DIMENSION,NUMBER_MIXTURES), optimizer='adam')\n",
Expand Down Expand Up @@ -204,7 +204,7 @@
"source": [
"# Fit the model\n",
"filepath=\"kanji_mdnrnn-{epoch:02d}-{val_acc:.2f}.hdf5\"\n",
"checkpoint = keras.callbacks.ModelCheckpoint(filepath, monitor='val_acc', verbose=1, save_best_only=True, mode='max')\n",
"checkpoint = keras.callbacks.ModelCheckpoint(filepath, monitor='val_loss', verbose=1, save_best_only=True, mode='min')\n",
"callbacks = [keras.callbacks.TerminateOnNaN(), checkpoint]\n",
"\n",
"history = model.fit(X, y, batch_size=BATCH_SIZE, epochs=EPOCHS, callbacks=callbacks, validation_data=(Xval,yval))\n",
Expand Down

0 comments on commit b5d7468

Please sign in to comment.