# vae.py
import os
import sys

import tensorflow as tf

# Keras insists on printing to stderr while it loads; silence it for the
# duration of the import. See:
# https://github.com/keras-team/keras/issues/1406
stderr = sys.stderr
sys.stderr = open(os.devnull, "w")
from tensorflow import keras
from tensorflow.keras.layers import Input, Dense, Layer
from tensorflow.keras.models import Model
from tensorflow.keras.metrics import Mean
sys.stderr = stderr
class Sampling(Layer):
    """Uses (z_mean, z_log_var) to sample z, the vector encoding a digit."""

    def call(self, inputs):
        z_mean, z_log_var = inputs
        batch = tf.shape(z_mean)[0]
        dim = tf.shape(z_mean)[1]
        # Reparameterization trick: z = mean + stddev * epsilon with
        # epsilon ~ N(0, I), which keeps the sample differentiable
        # w.r.t. z_mean and z_log_var.
        epsilon = tf.keras.backend.random_normal(shape=(batch, dim))
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon
class VAE(Model):
    def __init__(self, encoder, decoder, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.total_loss_tracker = Mean(name="total_loss")
        self.reconstruction_loss_tracker = Mean(name="reconstruction_loss")
        self.kl_loss_tracker = Mean(name="kl_loss")

    @property
    def metrics(self):
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]

    def train_step(self, data):
        with tf.GradientTape() as tape:
            z_mean, z_log_var, z = self.encoder(data)
            reconstruction = self.decoder(z)
            # Per-sample binary cross-entropy summed over features, then
            # averaged over the batch. expand_dims adds a trailing axis so
            # binary_crossentropy returns one value per feature rather than
            # averaging across them.
            reconstruction_loss = tf.reduce_mean(
                tf.reduce_sum(
                    keras.losses.binary_crossentropy(
                        tf.expand_dims(data, -1),
                        tf.expand_dims(reconstruction, -1),
                    ),
                    axis=1,
                )
            )
            # Closed-form KL divergence to the standard-normal prior,
            # summed over latent dimensions, averaged over the batch.
            kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
            kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
            total_loss = reconstruction_loss + kl_loss
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }
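# Objective note: train_step minimizes the negative ELBO,
#     loss = E[-log p(x | z)] + KL(q(z | x) || N(0, I)),
# where the reconstruction term is the summed binary cross-entropy above and
# the KL term uses the closed form
#     -0.5 * sum(1 + z_log_var - z_mean**2 - exp(z_log_var)).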
def create(num_markers):
    """Build and compile a VAE with a 2-D latent space for num_markers inputs."""
    # Encoder: num_markers -> 1024 -> 256 -> 16 -> (z_mean, z_log_var, z)
    encoder_input = Input(shape=(num_markers,), name='encoder_input')
    x = Dense(1024, activation='relu', name='dense1')(encoder_input)
    x = Dense(256, activation='relu', name='dense2')(x)
    x = Dense(16, activation='relu', name='dense3')(x)

    # Latent space
    z_mean = Dense(2, name='z_mean')(x)
    z_log_var = Dense(2, name='z_log_var')(x)
    z = Sampling(name='z')([z_mean, z_log_var])

    # Decoder: mirrors the encoder back up to num_markers sigmoid outputs.
    latent_input = Input(shape=(2,), name='latent_input')
    x = Dense(16, activation='relu', name='undense3')(latent_input)
    x = Dense(256, activation='relu', name='undense2')(x)
    x = Dense(1024, activation='relu', name='undense1')(x)
    decoded = Dense(num_markers, activation='sigmoid', name='output')(x)

    encoder = Model(encoder_input, [z_mean, z_log_var, z], name='encoder')
    decoder = Model(latent_input, decoded, name='decoder')

    vae = VAE(encoder, decoder)
    vae.compile(optimizer=keras.optimizers.Adam())
    return vae
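
# A minimal smoke test, assuming binary marker data in {0, 1}; the random
# matrix below is a hypothetical stand-in for real input, and
# num_markers = 128 is an arbitrary choice, not part of the original module.
if __name__ == "__main__":
    import numpy as np

    num_markers = 128
    x_train = np.random.randint(0, 2, size=(512, num_markers)).astype("float32")

    vae = create(num_markers)
    vae.fit(x_train, epochs=2, batch_size=64)

    # Encode a few samples into the 2-D latent space and decode them back.
    z_mean, z_log_var, z = vae.encoder.predict(x_train[:8])
    reconstruction = vae.decoder.predict(z)
    print("latent:", z.shape, "reconstruction:", reconstruction.shape)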