train.py
import time

import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa
from IPython import display
from tensorflow.keras import layers

# Length of each input array; TODO: confirm the intended input shape.
array_shape = 9000
# Size of the latent space; TODO: figure out the values to use.
latent_dim = 2
epochs = 20
class Resnet1DBlock(tf.keras.Model):
    """Residual block of two 1-D convolutions with a skip connection."""

    def __init__(self, filters, kernel_size, block_type='encode'):
        # Argument order (filters, kernel_size) matches the Keras Conv1D
        # calls used in the encoder/decoder below.
        super(Resnet1DBlock, self).__init__(name='')
        if block_type == 'encode':
            # Strided convolutions with instance norm on the encoder side.
            self.conv1a = layers.Conv1D(filters, kernel_size, 2, padding="same")
            self.conv1b = layers.Conv1D(filters, kernel_size, 1, padding="same")
            self.norm1a = tfa.layers.InstanceNormalization()
            self.norm1b = tfa.layers.InstanceNormalization()
        elif block_type == 'decode':
            # Transposed convolutions with batch norm on the decoder side.
            self.conv1a = layers.Conv1DTranspose(filters, kernel_size, 1, padding="same")
            self.conv1b = layers.Conv1DTranspose(filters, kernel_size, 1, padding="same")
            self.norm1a = layers.BatchNormalization()
            self.norm1b = layers.BatchNormalization()
        else:
            raise ValueError("block_type must be 'encode' or 'decode'")

    def call(self, input_tensor):
        x = tf.nn.relu(input_tensor)
        x = self.conv1a(x)
        x = self.norm1a(x)
        x = tf.nn.leaky_relu(x, alpha=0.4)
        x = self.conv1b(x)
        x = self.norm1b(x)
        x = tf.nn.leaky_relu(x, alpha=0.4)
        # Residual connection; `filters` must match the input channel count.
        x += input_tensor
        return tf.nn.relu(x)
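
# A minimal shape-check sketch for Resnet1DBlock (illustrative only: the
# batch size and channel count below are arbitrary assumptions, not values
# taken from the training configuration).
def _demo_resnet_block():
    block = Resnet1DBlock(64, 1)  # 64 filters, kernel size 1
    # The residual add requires the input channel count to equal `filters`.
    y = block(tf.random.normal([4, 1, 64]))
    assert y.shape == (4, 1, 64)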
class VAE_model(tf.keras.Model):
    """Convolutional variational autoencoder over fixed-length 1-D arrays."""

    def __init__(self, latent_dim):
        super(VAE_model, self).__init__()
        self.latent_dim = latent_dim
        self.encoder = tf.keras.Sequential(
            [
                tf.keras.layers.InputLayer(input_shape=(1, array_shape)),
                layers.Conv1D(64, 1, 2),
                Resnet1DBlock(64, 1),
                layers.Conv1D(128, 1, 2),
                Resnet1DBlock(128, 1),
                layers.Conv1D(128, 1, 2),
                Resnet1DBlock(128, 1),
                layers.Conv1D(256, 1, 2),
                Resnet1DBlock(256, 1),
                layers.Flatten(),
                # No activation: the dense layer emits mean and log-variance.
                layers.Dense(latent_dim + latent_dim),
            ]
        )
        self.decoder = tf.keras.Sequential(
            [
                tf.keras.layers.InputLayer(input_shape=(latent_dim,)),
                layers.Reshape(target_shape=(1, latent_dim)),
                # Each transposed convolution sets the channel count before its
                # residual block, so the block's skip connection sees matching
                # shapes.
                layers.Conv1DTranspose(512, 1, 1),
                Resnet1DBlock(512, 1, 'decode'),
                layers.Conv1DTranspose(256, 1, 1),
                Resnet1DBlock(256, 1, 'decode'),
                layers.Conv1DTranspose(128, 1, 1),
                Resnet1DBlock(128, 1, 'decode'),
                layers.Conv1DTranspose(64, 1, 1),
                Resnet1DBlock(64, 1, 'decode'),
                # No activation: reconstruction logits.
                layers.Conv1DTranspose(array_shape, 1, 1),
            ]
        )

    def sample(self, eps=None):
        if eps is None:
            eps = tf.random.normal(shape=(200, self.latent_dim))
        return self.decode(eps, apply_sigmoid=True)

    def encode(self, x):
        mean, logvar = tf.split(self.encoder(x), num_or_size_splits=2, axis=1)
        return mean, logvar

    def reparameterize(self, mean, logvar):
        # Reparameterization trick: z = mean + sigma * eps, eps ~ N(0, I).
        eps = tf.random.normal(shape=tf.shape(mean))
        return eps * tf.exp(logvar * .5) + mean

    def decode(self, z, apply_sigmoid=False):
        logits = self.decoder(z)
        if apply_sigmoid:
            probs = tf.sigmoid(logits)
            return probs
        return logits
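
# Illustrative sketch: push random data once through encode -> reparameterize
# -> decode to check end-to-end shapes. The batch size of 2 is an arbitrary
# assumption; values are drawn in [0, 1] to match the sigmoid reconstruction.
def _demo_vae_shapes():
    vae = VAE_model(latent_dim)
    x = tf.random.uniform([2, 1, array_shape])
    mean, logvar = vae.encode(x)
    z = vae.reparameterize(mean, logvar)   # (2, latent_dim)
    x_logit = vae.decode(z)                # (2, 1, array_shape)
    print(mean.shape, logvar.shape, z.shape, x_logit.shape)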
model = VAE_model(latent_dim)
optimizer = tf.keras.optimizers.Adam(0.0003, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
def log_normal_pdf(sample, mean, logvar, raxis=1):
    """Log density of a diagonal Gaussian, summed over `raxis`."""
    log2pi = tf.math.log(2. * np.pi)
    return tf.reduce_sum(
        -.5 * ((sample - mean) ** 2. * tf.exp(-logvar) + logvar + log2pi),
        axis=raxis)


def compute_loss(model, x):
    """Single-sample Monte-Carlo estimate of the negative ELBO:
    -E[log p(x|z) + log p(z) - log q(z|x)].
    """
    mean, logvar = model.encode(x)
    z = model.reparameterize(mean, logvar)
    x_logit = model.decode(z)
    # Bernoulli reconstruction term; assumes inputs are scaled to [0, 1].
    cross_ent = tf.nn.sigmoid_cross_entropy_with_logits(logits=x_logit, labels=x)
    logpx_z = -tf.reduce_sum(cross_ent, axis=[1, 2])
    logpz = log_normal_pdf(z, 0., 0.)
    logqz_x = log_normal_pdf(z, mean, logvar)
    return -tf.reduce_mean(logpx_z + logpz - logqz_x)
def train_step(model, x, optimizer):
    """Executes one training step and returns the loss.

    `compute_loss` already combines the reconstruction and KL terms
    (the negative ELBO), so it is the full training objective here.
    """
    with tf.GradientTape() as tape:
        loss = compute_loss(model, x)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return loss
def train_VAE(epochs, train_dataset, test_dataset, model, optimizer):
    for epoch in range(1, epochs + 1):
        start_time = time.time()
        for train_x in train_dataset:
            # Dataset elements are assumed to wrap the input batch as their
            # first component.
            train_x = np.asarray(train_x)[0]
            train_step(model, train_x, optimizer)
        end_time = time.time()
        loss = tf.keras.metrics.Mean()
        for test_x in test_dataset:
            test_x = np.asarray(test_x)[0]
            loss(compute_loss(model, test_x))
        display.clear_output(wait=False)
        elbo = -loss.result()
        print('Epoch: {}, Test set ELBO: {}, time elapsed for current epoch: {}'
              .format(epoch, elbo, end_time - start_time))
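
# Illustrative sketch of calling train_VAE on synthetic data. The dataset
# sizes, batch size, and the use of random arrays are all assumptions for
# demonstration; real usage would supply actual preprocessed batches. Elements
# are wrapped in 1-tuples because train_VAE unwraps component [0].
def _demo_train():
    train_data = tf.random.uniform([32, 1, array_shape])
    test_data = tf.random.uniform([8, 1, array_shape])
    train_dataset = tf.data.Dataset.from_tensor_slices((train_data,)).batch(8)
    test_dataset = tf.data.Dataset.from_tensor_slices((test_data,)).batch(8)
    train_VAE(epochs, train_dataset, test_dataset, model, optimizer)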
class wavenet_model:
    def __init__(self, inputs, outputs):
        # define some params
        pass

    def forward(self):
        # run a forward pass
        pass

    def optimize(self):
        # calculate loss and add an optimizer
        pass

    def accuracy(self):
        # calculate accuracy
        pass