diff --git a/bvae/sample_layer.py b/bvae/sample_layer.py index 42eb6c3..d432dd7 100644 --- a/bvae/sample_layer.py +++ b/bvae/sample_layer.py @@ -84,11 +84,11 @@ def call(self, x, training=None): def reparameterization_trick(): epsilon = K.random_normal(shape=logvar.shape, - mean=0., logvar=1.) + mean=0., stddev=1.) stddev = K.exp(logvar*0.5) - return mean + stddev * epsilon * inf + return mean + stddev * epsilon - return K.in_train_phase(reparameterization_trick, mean + 0*logvar, training=training) + return K.in_train_phase(reparameterization_trick, mean + 0*logvar, training=training) # TODO figure out why this is not working in the specified tf version??? def compute_output_shape(self, input_shape): - return input_shape[0] \ No newline at end of file + return input_shape[0]