Skip to content

Commit

Permalink
cleaned up some imports
Browse files Browse the repository at this point in the history
  • Loading branch information
cpmpercussion committed Sep 5, 2019
1 parent d7d3179 commit b53163f
Showing 1 changed file with 6 additions and 10 deletions.
16 changes: 6 additions & 10 deletions mdn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,22 +10,18 @@
"""
from .version import __version__
import keras
from keras import backend as K
from keras.layers import Dense
from keras.engine.topology import Layer
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions
from tensorflow_probability import distributions as tfd


def elu_plus_one_plus_epsilon(x):
"""ELU activation with a very small addition to help prevent
NaN in loss."""
return K.elu(x) + 1 + K.epsilon()
return keras.backend.elu(x) + 1 + keras.backend.epsilon()


class MDN(Layer):
class MDN(keras.layers.Layer):
"""A Mixture Density Network Layer for Keras.
This layer has a few tricks to avoid NaNs in the loss function when training:
- Activation for variances is ELU + 1 + 1e-8 (to avoid very small values)
Expand All @@ -39,9 +35,9 @@ def __init__(self, output_dimension, num_mixtures, **kwargs):
self.output_dim = output_dimension
self.num_mix = num_mixtures
with tf.name_scope('MDN'):
self.mdn_mus = Dense(self.num_mix * self.output_dim, name='mdn_mus') # mix*output vals, no activation
self.mdn_sigmas = Dense(self.num_mix * self.output_dim, activation=elu_plus_one_plus_epsilon, name='mdn_sigmas') # mix*output vals exp activation
self.mdn_pi = Dense(self.num_mix, name='mdn_pi') # mix vals, logits
self.mdn_mus = keras.layers.Dense(self.num_mix * self.output_dim, name='mdn_mus') # mix*output vals, no activation
self.mdn_sigmas = keras.layers.Dense(self.num_mix * self.output_dim, activation=elu_plus_one_plus_epsilon, name='mdn_sigmas') # mix*output vals exp activation
self.mdn_pi = keras.layers.Dense(self.num_mix, name='mdn_pi') # mix vals, logits
super(MDN, self).__init__(**kwargs)

def build(self, input_shape):
Expand Down

0 comments on commit b53163f

Please sign in to comment.