diff --git a/examples/1-sineprediction.py b/examples/1-sineprediction.py
new file mode 100644
index 0000000..e0bd25d
--- /dev/null
+++ b/examples/1-sineprediction.py
@@ -0,0 +1,46 @@
+# Normal imports for everybody
+from tensorflow import keras
+import mdn
+import numpy as np
+
+
+# Synthetic data: x is computed from uniformly drawn y values plus Gaussian noise.
+NSAMPLE = 3000
+
+y_data = np.random.uniform(-10.5, 10.5, NSAMPLE).astype(np.float32)
+r_data = np.random.normal(size=NSAMPLE)
+x_data = 7.0 * np.sin(0.75 * y_data) + 0.5 * y_data + 1.0 * r_data
+x_data = x_data.reshape((NSAMPLE, 1))
+
+N_HIDDEN = 15
+N_MIXES = 10
+
+model = keras.Sequential([
+    keras.layers.Dense(N_HIDDEN, batch_input_shape=(None, 1), activation='relu'),
+    keras.layers.Dense(N_HIDDEN, activation='relu'),
+    mdn.MDN(1, N_MIXES)])
+model.compile(optimizer=keras.optimizers.Adam(), loss=mdn.get_mixture_loss_func(1, N_MIXES))
+model.summary()
+
+history = model.fit(x_data, y_data, batch_size=128, epochs=500, validation_split=0.15)
+
+# Inputs to evaluate the trained model on:
+x_test = np.arange(-15, 15, 0.01).astype(np.float32)
+NTEST = x_test.size
+print("Testing:", NTEST, "samples.")
+x_test = x_test.reshape((NTEST, 1))  # the model takes a matrix, not a vector
+
+# Run the trained model on the test inputs
+y_test = model.predict(x_test)
+# Each row of y_test holds distribution parameters, not actual points on the graph.
+# Points on the graph are obtained by sampling from each predicted distribution.
+
+# Sample from the predicted distributions
+y_samples = np.apply_along_axis(mdn.sample_from_output, 1, y_test, 1, N_MIXES, temp=1.0)
+
+# Split up the mixture parameters (for future fun): each row is [mus | sigmas | pi logits].
+mus = y_test[:, :N_MIXES]
+sigs = y_test[:, N_MIXES:2*N_MIXES]
+pis = np.apply_along_axis((lambda a: mdn.softmax(a[2*N_MIXES:])), 1, y_test)
+
+print("Done.")
\ No newline at end of file
diff --git a/examples/4-robojam-touch-generation.py b/examples/4-robojam-touch-generation.py
new file mode 100644
index 0000000..33aafdf
--- /dev/null
+++ b/examples/4-robojam-touch-generation.py
@@ -0,0 +1,323 @@
+from tensorflow.compat.v1 import keras
+from tensorflow.compat.v1.keras import backend as K
+from tensorflow.compat.v1.keras.layers import Dense, Input
+import numpy as np
+import tensorflow.compat.v1 as tf
+import math
+import h5py
+import random
+import time
+import pandas as pd
+import mdn
+#import matplotlib.pyplot as plt
+
+
+#input_colour = 'darkblue'
+#gen_colour = 'firebrick'
+#plt.style.use('seaborn-talk')
+import os
+os.environ.setdefault("CUDA_VISIBLE_DEVICES", "1")  # defaults to GPU 1; an existing setting is respected
+
+config = tf.ConfigProto()
+config.gpu_options.allow_growth = True
+sess = tf.Session(config=config)
+K.set_session(sess)
+
+# Download microjam performance data if needed (skipped when the file is already present).
+import urllib.request
+url = 'http://folk.uio.no/charlepm/datasets/TinyPerformanceCorpus.h5'
+if not os.path.exists('./TinyPerformanceCorpus.h5'):
+    urllib.request.urlretrieve(url, './TinyPerformanceCorpus.h5')
+
+# ## Helper functions for touchscreen performances
+#
+# We need a few helper functions for managing performances:
+#
+# - Convert performances to and from pandas dataframes.
+# - Generate random touches.
+# - Sample whole performances from scratch and from a priming performance.
+# - Plot performances including dividing into swipes.
+
+SCALE_FACTOR = 1
+
+def perf_df_to_array(perf_df, include_moving=False):
+    """Convert a performance dataframe into an array of (x, y, dt[, moving]) rows.
+
+    The dataframe is modified in place: 'dt' is derived from the 'time' column,
+    then 'dt' is limited to [0, 5] and 'x' and 'y' are limited to [0, 1].
+    """
+    perf_df['dt'] = perf_df.time.diff().fillna(0.0)
+    # Clean performance data
+    # Tiny Performance bounds defined to be in [[0,1],[0,1]], edit to fix this.
+    # Whole columns are clipped here; .at only addresses a single scalar cell.
+    perf_df['dt'] = perf_df.dt.clip(lower=0.0, upper=5.0)
+    perf_df['x'] = perf_df.x.clip(lower=0.0, upper=1.0)
+    perf_df['y'] = perf_df.y.clip(lower=0.0, upper=1.0)
+    columns = ['x', 'y', 'dt', 'moving'] if include_moving else ['x', 'y', 'dt']
+    return np.array(perf_df[columns])
+
+
+def perf_array_to_df(perf_array):
+    """Convert an array of (x, y, dt[, moving]) rows into a dataframe indexed by time.
+
+    Without a fourth 'moving' column, touches with dt > 0.1 are marked as taps
+    (moving = 0) and the rest as moving touches (moving = 1). 'z' is always 38.0.
+    """
+    perf_array = perf_array.T
+    perf_df = pd.DataFrame({'x': perf_array[0], 'y': perf_array[1], 'dt': perf_array[2]})
+    if len(perf_array) == 4:
+        perf_df['moving'] = perf_array[3]
+    else:
+        # As a rule of thumb, classify touches with dt>0.1 as taps, dt<0.1 as moving touches.
+        perf_df['moving'] = 1
+        perf_df.loc[perf_df.dt > 0.1, 'moving'] = 0
+    perf_df['time'] = perf_df.dt.cumsum()
+    perf_df['z'] = 38.0
+    perf_df = perf_df.set_index(['time'])
+    return perf_df[['x', 'y', 'z', 'moving']]
+
+
+def random_touch(with_moving=False):
+    """Generate a random tiny performance touch (random x and y, dt of 0.01)."""
+    touch = [np.random.rand(), np.random.rand(), 0.01]
+    if with_moving:
+        touch.append(0)
+    return np.array(touch)
+
+
+def constrain_touch(touch, with_moving=False):
+    """Constrain touch values from the MDRNN; touch is modified in place and returned."""
+    touch[0] = min(max(touch[0], 0.0), 1.0)  # x in [0,1]
+    touch[1] = min(max(touch[1], 0.0), 1.0)  # y in [0,1]
+    touch[2] = max(touch[2], 0.001)  # dt: enforce a minimum time step
+    if with_moving:
+        touch[3] = np.greater(touch[3], 0.5) * 1.0  # threshold to 0.0 or 1.0
+    return touch
+
+
+def _sample_next_touch(model, touch, out_dim, n_mixtures, temp, sigma_temp):
+    """Feed one touch to the model and sample the next touch from its output."""
+    params = model.predict(touch.reshape(1, 1, out_dim) * SCALE_FACTOR)
+    return mdn.sample_from_output(params[0], out_dim, n_mixtures, temp=temp, sigma_temp=sigma_temp) / SCALE_FACTOR
+
+
+def _generate_touches(model, previous_touch, touches, out_dim, n_mixtures, time_limit, steps_limit, temp, sigma_temp, predict_moving):
+    """Append sampled, constrained touches to touches until a limit is reached; return them as an array."""
+    elapsed = 0  # summed dt; not called `time`, which would shadow the imported module
+    steps = 0
+    while steps < steps_limit and elapsed < time_limit:
+        previous_touch = _sample_next_touch(model, previous_touch, out_dim, n_mixtures, temp, sigma_temp)
+        # The reshape normally yields a view, so constraining it also constrains the next model input.
+        output_touch = constrain_touch(previous_touch.reshape(out_dim,), with_moving=predict_moving)
+        touches.append(output_touch.reshape((out_dim,)))
+        steps += 1
+        elapsed += output_touch[2]
+    return np.array(touches)
+
+
+def generate_random_tiny_performance(model, n_mixtures, first_touch, time_limit=5.0, steps_limit=1000, temp=1.0, sigma_temp=0.0, predict_moving=False):
+    """Generate a tiny performance from first_touch until time_limit or steps_limit is reached."""
+    out_dim = 4 if predict_moving else 3
+    performance = [first_touch.reshape((out_dim,))]
+    return _generate_touches(model, first_touch, performance, out_dim, n_mixtures, time_limit, steps_limit, temp, sigma_temp, predict_moving)
+
+
+def condition_and_generate(model, perf, n_mixtures, time_limit=5.0, steps_limit=1000, temp=1.0, sigma_temp=0.0, predict_moving=False):
+    """Condition the network on an existing tiny performance, then generate a new one.
+
+    The sample drawn after the last touch of perf is the first row of the result.
+    Raises ValueError when perf is empty, as there is then nothing to continue from.
+    """
+    out_dim = 4 if predict_moving else 3
+    if len(perf) == 0:
+        raise ValueError("perf must contain at least one touch to condition on")
+    # condition
+    for touch in perf:
+        previous_touch = _sample_next_touch(model, touch, out_dim, n_mixtures, temp, sigma_temp)
+    output = [previous_touch.reshape((out_dim,))]
+    # generate
+    return _generate_touches(model, previous_touch, output, out_dim, n_mixtures, time_limit, steps_limit, temp, sigma_temp, predict_moving)
+
+
+def divide_performance_into_swipes(perf_df):
+    """Divide a performance into a list of swipe dataframes for plotting."""
+    # The frame is split in front of every row whose 'moving' value is 0.
+    touch_starts = perf_df[perf_df.moving == 0].index
+    performance_swipes = []
+    remainder = perf_df
+    for att in touch_starts:
+        swipe = remainder.iloc[remainder.index < att]
+        performance_swipes.append(swipe)
+        remainder = remainder.iloc[remainder.index >= att]
+    performance_swipes.append(remainder)
+    return performance_swipes
+
+
+input_colour = "#4388ff"
+gen_colour = "#ec0205"
+
+def plot_perf_on_ax(perf_df, ax, color="#ec0205", linewidth=3, alpha=0.5):
+    """Plot a 2D representation of a performance onto the given axes, one line per swipe."""
+    for swipe in divide_performance_into_swipes(perf_df):
+        ax.plot(swipe.x, swipe.y, 'o-', alpha=alpha, markersize=linewidth, color=color, linewidth=linewidth)
+    ax.set_ylim([1.0,0])
+    ax.set_xlim([0,1.0])
+    ax.set_xticks([])
+    ax.set_yticks([])
+
+def plot_2D(perf_df, name="foo", saving=False, figsize=(5, 5)):
+    """Plot a 2D representation of a performance; save it as name + '.png' when saving is set."""
+    import matplotlib.pyplot as plt  # local import: the module-level one is commented out
+    fig, ax = plt.subplots(figsize=figsize)
+    plot_perf_on_ax(perf_df, ax, color=gen_colour, linewidth=5, alpha=0.7)
+    if saving:
+        fig.savefig(name+".png", bbox_inches='tight')
+
+def plot_double_2d(perf1, perf2, name="foo", saving=False, figsize=(8, 8)):
+    """Plot two performances in 2D on the same axes; save as name + '.png' when saving is set."""
+    import matplotlib.pyplot as plt  # local import: the module-level one is commented out
+    fig, ax = plt.subplots(figsize=figsize)
+    plot_perf_on_ax(perf1, ax, color=input_colour, linewidth=5, alpha=0.7)
+    plot_perf_on_ax(perf2, ax, color=gen_colour, linewidth=5, alpha=0.7)
+    if saving:
+        fig.savefig(name+".png", bbox_inches='tight')
+
+# # Load up the Dataset:
+#
+# The dataset consists of around 1000 5-second performances from the MicroJam app.
+# This is in a sequence of points consisting of an x-location, a y-location, and a time-delta from the previous point.
+# When the user swipes, the time-delta is very small; if they tap, it's quite large.
+# Let's have a look at some of the data:
+
+# Load Data
+microjam_data_file_name = "./TinyPerformanceCorpus.h5"
+
+with h5py.File(microjam_data_file_name, 'r') as data_file:
+    microjam_corpus = data_file['total_performances'][:]
+
+print("Corpus data points between 100 and 120:")
+print(perf_array_to_df(microjam_corpus[100:120]))
+
+print("Some statistics about the dataset:")
+print(pd.DataFrame(microjam_corpus).describe())
+
+# - This time, the X and Y locations are *not* differences, but the exact value, but the time is a delta value.
+# - The data doesn't have a "pen up" value, but we can just call taps with dt>0.1 as taps, dt<0.1 as moving touches.
+
+# Plot a bit of the data to have a look:
+#plot_2D(perf_array_to_df(microjam_corpus[100:200]))
+
+# ## MDN RNN
+#
+# - Now we're going to build an MDN-RNN to predict MicroJam data.
+# - The architecture will be:
+#     - 3 inputs (x, y, dt)
+#     - 2 layers of 256 LSTM cells each
+#     - MDN Layer with 3 dimensions and 5 mixtures.
+#     - Training model will have a sequence length of 30 (prediction model: 1 in, 1 out)
+#
+# ![RoboJam MDN RNN Model](https://preview.ibb.co/cKZk9T/robojam_mdn_diagram.png)
+#
+# - Here are the model parameters and training data preparation.
+# - We end up with 172K training examples.
+
+# In[ ]:
+
+
+# Training Hyperparameters:
+SEQ_LEN = 30
+BATCH_SIZE = 256
+HIDDEN_UNITS = 256
+EPOCHS = 100
+VAL_SPLIT = 0.15
+
+# Set random seeds for reproducibility (Python, NumPy and the TensorFlow graph-level seed)
+SEED = 2345
+random.seed(SEED)
+np.random.seed(SEED)
+tf.set_random_seed(SEED)
+
+def slice_sequence_examples(sequence, num_steps):
+    """Return the windows sequence[i:i + num_steps] for i in range(len(sequence) - num_steps - 1)."""
+    return [sequence[i: i + num_steps] for i in range(len(sequence) - num_steps - 1)]
+
+def seq_to_singleton_format(examples):
+    """Split each example into its leading items (x) and its final item (y); returns (xs, ys)."""
+    xs = []
+    ys = []
+    for ex in examples:
+        xs.append(ex[:-1])
+        ys.append(ex[-1])
+    return (xs, ys)
+
+sequences = slice_sequence_examples(microjam_corpus, SEQ_LEN+1)
+print("Total training examples:", len(sequences))
+X, y = seq_to_singleton_format(sequences)
+X = np.array(X)
+y = np.array(y)
+print("X:", X.shape, "y:", y.shape)
+
+OUTPUT_DIMENSION = 3
+NUMBER_MIXTURES = 5
+
+# Training model: two stacked LSTM layers followed by the MDN output layer.
+model = keras.Sequential()
+model.add(keras.layers.LSTM(HIDDEN_UNITS, batch_input_shape=(None,SEQ_LEN,OUTPUT_DIMENSION), return_sequences=True))
+model.add(keras.layers.LSTM(HIDDEN_UNITS))
+model.add(mdn.MDN(OUTPUT_DIMENSION, NUMBER_MIXTURES))
+model.compile(loss=mdn.get_mixture_loss_func(OUTPUT_DIMENSION,NUMBER_MIXTURES), optimizer=keras.optimizers.Adam())
+model.summary()
+
+# Define callbacks (the checkpoint and early stopping both watch the validation loss)
+filepath="robojam_mdrnn-E{epoch:02d}-VL{val_loss:.2f}.h5"
+checkpoint = keras.callbacks.ModelCheckpoint(filepath, monitor='val_loss', save_weights_only=True, verbose=1, save_best_only=True, mode='min')
+early_stopping = keras.callbacks.EarlyStopping(monitor='val_loss', mode='min', verbose=1, patience=10)
+callbacks = [keras.callbacks.TerminateOnNaN(), checkpoint, early_stopping]
+
+history = model.fit(X, y, batch_size=BATCH_SIZE, epochs=EPOCHS, callbacks=callbacks, validation_split=VAL_SPLIT)
+
+# Save the Model
+model.save('robojam-mdrnn.h5') # creates a HDF5 file of the model
+
+# Plot the loss
+#plt.figure(figsize=(10, 5))
+#plt.plot(history.history['loss'])
+#plt.plot(history.history['val_loss'])
+#plt.xlabel("epochs")
+#plt.ylabel("loss")
+#plt.show()
+
+
+# # Try out the model
+#
+# - Let's try out the model
+# - First we will load up a decoding model with a sequence length of 1.
+# - The weights are loaded from the trained model file.
+
+# Decoding Model: stateful, with a batch size of 1 and a single time step per call.
+decoder = keras.Sequential([
+    keras.layers.LSTM(HIDDEN_UNITS, batch_input_shape=(1, 1, OUTPUT_DIMENSION), return_sequences=True, stateful=True),
+    keras.layers.LSTM(HIDDEN_UNITS, stateful=True),
+    mdn.MDN(OUTPUT_DIMENSION, NUMBER_MIXTURES)])
+decoder.compile(optimizer=keras.optimizers.Adam(), loss=mdn.get_mixture_loss_func(OUTPUT_DIMENSION, NUMBER_MIXTURES))
+decoder.summary()
+
+# decoder.set_weights(model.get_weights())
+decoder.load_weights("robojam-mdrnn.h5")
+
+
+# Generate a performance conditioned on a random 100-touch excerpt of the corpus.
+length = 100
+t = random.randint(0, len(microjam_corpus) - length)
+ex = microjam_corpus[t:t + length]  # sequences[600]
+
+decoder.reset_states()
+p = condition_and_generate(decoder, ex, NUMBER_MIXTURES, temp=1.5, sigma_temp=0.05)
+#plot_double_2d(perf_array_to_df(ex), perf_array_to_df(p), figsize=(4,4))
+
+# We can also generate unconditioned performances from a random starting point.
+ +decoder.reset_states() +t = random_touch() +p = generate_random_tiny_performance(decoder, NUMBER_MIXTURES, t, temp=1.2, sigma_temp=0.01) +#plot_2D(perf_array_to_df(p), figsize=(4,4)) diff --git a/mdn/__init__.py b/mdn/__init__.py index 2c4f94f..b568856 100644 --- a/mdn/__init__.py +++ b/mdn/__init__.py @@ -9,9 +9,11 @@ Provided under MIT License """ from .version import __version__ -import keras +from tensorflow.compat.v1 import keras +from tensorflow.compat.v1.keras import backend as K +from tensorflow.compat.v1.keras import layers import numpy as np -import tensorflow as tf +import tensorflow.compat.v1 as tf from tensorflow_probability import distributions as tfd @@ -21,7 +23,7 @@ def elu_plus_one_plus_epsilon(x): return keras.backend.elu(x) + 1 + keras.backend.epsilon() -class MDN(keras.layers.Layer): +class MDN(layers.Layer): """A Mixture Density Network Layer for Keras. This layer has a few tricks to avoid NaNs in the loss function when training: - Activation for variances is ELU + 1 + 1e-8 (to avoid very small values) @@ -35,15 +37,18 @@ def __init__(self, output_dimension, num_mixtures, **kwargs): self.output_dim = output_dimension self.num_mix = num_mixtures with tf.name_scope('MDN'): - self.mdn_mus = keras.layers.Dense(self.num_mix * self.output_dim, name='mdn_mus') # mix*output vals, no activation - self.mdn_sigmas = keras.layers.Dense(self.num_mix * self.output_dim, activation=elu_plus_one_plus_epsilon, name='mdn_sigmas') # mix*output vals exp activation - self.mdn_pi = keras.layers.Dense(self.num_mix, name='mdn_pi') # mix vals, logits + self.mdn_mus = layers.Dense(self.num_mix * self.output_dim, name='mdn_mus') # mix*output vals, no activation + self.mdn_sigmas = layers.Dense(self.num_mix * self.output_dim, activation=elu_plus_one_plus_epsilon, name='mdn_sigmas') # mix*output vals exp activation + self.mdn_pi = layers.Dense(self.num_mix, name='mdn_pi') # mix vals, logits super(MDN, self).__init__(**kwargs) def build(self, input_shape): - 
self.mdn_mus.build(input_shape) - self.mdn_sigmas.build(input_shape) - self.mdn_pi.build(input_shape) + with tf.name_scope('mus'): + self.mdn_mus.build(input_shape) + with tf.name_scope('sigmas'): + self.mdn_sigmas.build(input_shape) + with tf.name_scope('pis'): + self.mdn_pi.build(input_shape) super(MDN, self).build(input_shape) @property @@ -56,10 +61,10 @@ def non_trainable_weights(self): def call(self, x, mask=None): with tf.name_scope('MDN'): - mdn_out = keras.layers.concatenate([self.mdn_mus(x), - self.mdn_sigmas(x), - self.mdn_pi(x)], - name='mdn_outputs') + mdn_out = layers.concatenate([self.mdn_mus(x), + self.mdn_sigmas(x), + self.mdn_pi(x)], + name='mdn_outputs') return mdn_out def compute_output_shape(self, input_shape): @@ -239,7 +244,9 @@ def sample_from_output(params, output_dim, num_mixes, temp=1.0, sigma_temp=1.0): # Alternative way to sample from categorical: # m = np.random.choice(range(len(pis)), p=pis) mus_vector = mus[m * output_dim:(m + 1) * output_dim] - sig_vector = sigs[m * output_dim:(m + 1) * output_dim] * sigma_temp # adjust for temperature - cov_matrix = np.identity(output_dim) * sig_vector + sig_vector = sigs[m * output_dim:(m + 1) * output_dim] + scale_matrix = np.identity(output_dim) * sig_vector # scale matrix from diag + cov_matrix = np.matmul(scale_matrix, scale_matrix.T) # cov is scale squared. 
+ cov_matrix = cov_matrix * sigma_temp # adjust for sigma temperature sample = np.random.multivariate_normal(mus_vector, cov_matrix, 1) return sample diff --git a/mdn/tests/test_mdn.py b/mdn/tests/test_mdn.py index 512154e..33c8103 100644 --- a/mdn/tests/test_mdn.py +++ b/mdn/tests/test_mdn.py @@ -1,4 +1,4 @@ -import keras +from tensorflow.compat.v1 import keras import mdn import numpy as np @@ -12,7 +12,7 @@ def test_build_mdn(): model.add(keras.layers.Dense(N_HIDDEN, activation='relu')) model.add(mdn.MDN(1, N_MIXES)) model.compile(loss=mdn.get_mixture_loss_func(1, N_MIXES), optimizer=keras.optimizers.Adam()) - assert isinstance(model, keras.engine.sequential.Sequential) + assert isinstance(model, keras.Sequential) def test_number_of_weights(): @@ -37,6 +37,6 @@ def test_save_mdn(): model.add(keras.layers.Dense(N_HIDDEN, batch_input_shape=(None, 1), activation='relu')) model.add(mdn.MDN(1, N_MIXES)) model.compile(loss=mdn.get_mixture_loss_func(1, N_MIXES), optimizer=keras.optimizers.Adam()) - model.save('test_save.h5') + model.save('test_save.h5', overwrite=True, save_format="h5") m_2 = keras.models.load_model('test_save.h5', custom_objects={'MDN': mdn.MDN, 'mdn_loss_func': mdn.get_mixture_loss_func(1, N_MIXES)}) - assert isinstance(m_2, keras.engine.sequential.Sequential) + assert isinstance(m_2, keras.Sequential) diff --git a/mdn/version.py b/mdn/version.py index d93b5b2..0404d81 100644 --- a/mdn/version.py +++ b/mdn/version.py @@ -1 +1 @@ -__version__ = '0.2.3' +__version__ = '0.3.0' diff --git a/notebooks/MDN-1D-sine-prediction.ipynb b/notebooks/MDN-1D-sine-prediction.ipynb index e9f9261..7a50d4e 100644 --- a/notebooks/MDN-1D-sine-prediction.ipynb +++ b/notebooks/MDN-1D-sine-prediction.ipynb @@ -21,7 +21,6 @@ "outputs": [], "source": [ "# Normal imports for everybody\n", - "import keras\n", "from context import * # imports the MDN layer \n", "import numpy as np\n", "import random\n", diff --git a/notebooks/MDN-RNN-RoboJam-touch-generation.ipynb 
b/notebooks/MDN-RNN-RoboJam-touch-generation.ipynb index dac65cd..4738d32 100644 --- a/notebooks/MDN-RNN-RoboJam-touch-generation.ipynb +++ b/notebooks/MDN-RNN-RoboJam-touch-generation.ipynb @@ -35,11 +35,11 @@ "metadata": {}, "outputs": [], "source": [ - "import keras\n", - "from keras import backend as K\n", - "from keras.layers import Dense, Input\n", + "from tensorflow.compat.v1 import keras\n", + "from tensorflow.compat.v1.keras import backend as K\n", + "from tensorflow.compat.v1.keras.layers import Dense, Input\n", "import numpy as np\n", - "import tensorflow as tf\n", + "import tensorflow.compat.v1 as tf\n", "import math\n", "import h5py\n", "import random\n", @@ -65,11 +65,9 @@ "import os\n", "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"1\"\n", "\n", - "import tensorflow as tf\n", "config = tf.ConfigProto()\n", "config.gpu_options.allow_growth = True\n", "sess = tf.Session(config=config)\n", - "from keras import backend as K\n", "K.set_session(sess)" ] }, @@ -552,7 +550,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.6.7" + "version": "3.7.4" } }, "nbformat": 4, diff --git a/notebooks/MDN-RNN-kanji-generation-example.ipynb b/notebooks/MDN-RNN-kanji-generation-example.ipynb index e898c99..8e95708 100644 --- a/notebooks/MDN-RNN-kanji-generation-example.ipynb +++ b/notebooks/MDN-RNN-kanji-generation-example.ipynb @@ -34,23 +34,23 @@ "metadata": {}, "outputs": [], "source": [ - "import keras\n", "from context import * # imports the MDN layer \n", "import numpy as np\n", "import random\n", "import matplotlib.pyplot as plt\n", "from mpl_toolkits.mplot3d import Axes3D \n", "%matplotlib inline\n", + "import pandas as pd\n", "\n", "# Only for GPU use:\n", "#import os\n", "#os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"1\"\n", "\n", - "import tensorflow as tf\n", + "import tensorflow.compat.v1 as tf\n", + "from tensorflow.compat.v1.keras import backend as K\n", "config = tf.ConfigProto()\n", "config.gpu_options.allow_growth = 
True\n", "sess = tf.Session(config=config)\n", - "from keras import backend as K\n", "K.set_session(sess)" ] }, @@ -371,10 +371,6 @@ "metadata": {}, "outputs": [], "source": [ - "import pandas as pd\n", - "import matplotlib.pyplot as plt\n", - "%matplotlib inline\n", - "\n", "def zero_start_position():\n", " \"\"\"A zeroed out start position with pen down\"\"\"\n", " out = np.zeros((1, 1, 3), dtype=np.float32)\n", @@ -491,14 +487,15 @@ "outputs": [], "source": [ "# Predict a character and plot the result.\n", - "temperature = 0.1 # seems to work well with rather high temperature (2.5)\n", + "pi_temperature = 2.5 # seems to work well with rather high temperature (2.5)\n", + "sigma_temp = 0.1 # seems to work well with low temp\n", "\n", "p = zero_start_position()\n", "sketch = [p.reshape(3,)]\n", "\n", "for i in range(400):\n", " params = decoder.predict(p.reshape(1,1,3))\n", - " p = mdn.sample_from_output(params[0], OUTPUT_DIMENSION, NUMBER_MIXTURES, temp=temperature)\n", + " p = mdn.sample_from_output(params[0], OUTPUT_DIMENSION, NUMBER_MIXTURES, temp=pi_temperature, sigma_temp=sigma_temp)\n", " sketch.append(p.reshape((3,)))\n", "\n", "sketch = np.array(sketch)\n", diff --git a/notebooks/MDN-RNN-time-distributed-MDN-training.ipynb b/notebooks/MDN-RNN-time-distributed-MDN-training.ipynb index b079e8b..11735ad 100644 --- a/notebooks/MDN-RNN-time-distributed-MDN-training.ipynb +++ b/notebooks/MDN-RNN-time-distributed-MDN-training.ipynb @@ -12,19 +12,13 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Using TensorFlow backend.\n" - ] - } - ], + "outputs": [], "source": [ - "import keras\n", + "from tensorflow.compat.v1 import keras\n", + "import tensorflow.compat.v1 as tf\n", + "from tensorflow.compat.v1.keras import backend as K\n", "from context import * # imports the MDN layer \n", "import numpy as np\n", "import random\n", @@ -35,12 +29,9 
@@ "# Only for GPU use:\n", "import os\n", "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"1\"\n", - "\n", - "import tensorflow as tf\n", "config = tf.ConfigProto()\n", "config.gpu_options.allow_growth = True\n", "sess = tf.Session(config=config)\n", - "from keras import backend as K\n", "K.set_session(sess)" ] }, @@ -53,20 +44,9 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "('./kanji.rdp25.npz', )" - ] - }, - "execution_count": 2, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "# Train from David Ha's Kanji dataset from Sketch-RNN: https://github.com/hardmaru/sketch-rnn-datasets\n", "# Other datasets in \"Sketch 3\" format should also work.\n", @@ -88,21 +68,11 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Training kanji: 10358\n", - "Validation kanji: 600\n", - "Testing kanji: 500\n" - ] - } - ], + "outputs": [], "source": [ - "with np.load('./kanji.rdp25.npz') as data:\n", + "with np.load('./kanji.rdp25.npz', allow_pickle=True) as data:\n", " train_set = data['train']\n", " valid_set = data['valid']\n", " test_set = data['test']\n", @@ -121,31 +91,9 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "_________________________________________________________________\n", - "Layer (type) Output Shape Param # \n", - "=================================================================\n", - "inputs (InputLayer) (None, 50, 3) 0 \n", - "_________________________________________________________________\n", - "lstm1 (LSTM) (None, 50, 256) 266240 \n", - "_________________________________________________________________\n", - "lstm2 (LSTM) (None, 50, 256) 525312 \n", - 
"_________________________________________________________________\n", - "td_mdn (TimeDistributed) (None, 50, 70) 17990 \n", - "=================================================================\n", - "Total params: 809,542\n", - "Trainable params: 809,542\n", - "Non-trainable params: 0\n", - "_________________________________________________________________\n" - ] - } - ], + "outputs": [], "source": [ "# Training Hyperparameters:\n", "SEQ_LEN = 50\n", @@ -181,19 +129,9 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Number of training examples:\n", - "X: (154279, 50, 3)\n", - "y: (154279, 50, 3)\n" - ] - } - ], + "outputs": [], "source": [ "# Functions for slicing up data\n", "def slice_sequence_examples(sequence, num_steps):\n", @@ -229,19 +167,9 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Number of training examples:\n", - "X: (8928, 50, 3)\n", - "y: (8928, 50, 3)\n" - ] - } - ], + "outputs": [], "source": [ "# Prepare validation data as X and Y.\n", "slices = []\n", @@ -315,29 +243,9 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "_________________________________________________________________\n", - "Layer (type) Output Shape Param # \n", - "=================================================================\n", - "lstm_1 (LSTM) (1, 1, 256) 266240 \n", - "_________________________________________________________________\n", - "lstm_2 (LSTM) (1, 256) 525312 \n", - "_________________________________________________________________\n", - "mdn_1 (MDN) (1, 70) 17990 \n", - "=================================================================\n", - "Total params: 809,542\n", - "Trainable 
params: 809,542\n", - "Non-trainable params: 0\n", - "_________________________________________________________________\n" - ] - } - ], + "outputs": [], "source": [ "# Decoding Model\n", "# Same as training model except for dimension and mixtures.\n", @@ -365,7 +273,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -412,7 +320,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -482,24 +390,11 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "metadata": { "scrolled": true }, - "outputs": [ - { - "data": { - "image/svg+xml": [ - "" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ "# Predict a character and plot the result.\n", "temperature = 1.5 # seems to work well with rather high temperature (2.5)\n", @@ -538,7 +433,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.6.7" + "version": "3.7.4" } }, "nbformat": 4, diff --git a/notebooks/context.py b/notebooks/context.py index e431f70..e6bbcda 100644 --- a/notebooks/context.py +++ b/notebooks/context.py @@ -4,4 +4,5 @@ import sys import os sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) -import mdn \ No newline at end of file +import mdn +from tensorflow import keras diff --git a/requirements.txt b/requirements.txt index 485e040..de94db3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,35 @@ -Keras==2.2.4 -numpy==1.16.1 -tensorflow>=1.12.1 -tensorflow-probability>=0.5.0 +absl-py==0.8.1 +astor==0.8.0 +cachetools==3.1.1 +certifi==2019.9.11 +chardet==3.0.4 +cloudpickle==1.1.1 +decorator==4.4.1 +gast==0.2.2 +google-auth==1.6.3 +google-auth-oauthlib==0.4.1 +google-pasta==0.1.7 +grpcio==1.24.3 +h5py==2.10.0 +idna==2.8 +Keras-Applications==1.0.8 
+Keras-Preprocessing==1.1.0 +Markdown==3.1.1 +numpy==1.17.3 +oauthlib==3.1.0 +opt-einsum==3.1.0 +protobuf==3.10.0 +pyasn1==0.4.7 +pyasn1-modules==0.2.7 +requests==2.22.0 +requests-oauthlib==1.2.0 +rsa==4.0 +six==1.12.0 +tensorboard==2.0.1 +tensorflow==2.0.0 +tensorflow-estimator==2.0.1 +tensorflow-probability==0.8.0 +termcolor==1.1.0 +urllib3==1.25.6 +Werkzeug==0.16.0 +wrapt==1.11.2