1
1
import tensorflow as tf
2
2
from keras .layers import (Dense , Reshape , BatchNormalization , Conv2DTranspose ,
3
- Dropout , LayerNormalization , Embedding , Input , Conv2D , LeakyReLU , Flatten )
4
- from keras .models import Sequential
3
+ Dropout , LayerNormalization , Embedding , Input , Conv2D , LeakyReLU , Flatten ,
4
+ Concatenate , concatenate , Lambda , ReLU )
5
+ from utils import *
5
6
6
7
def scaled_dot_product (q , k , v ):
7
8
dk = tf .cast (tf .shape (k )[- 1 ], tf .float32 )
@@ -88,11 +89,20 @@ def build_generator(noise_dim,
88
89
projection_dim ,
89
90
num_heads ,
90
91
mlp_dim ):
92
+ # Input layer
93
+ embed_input = Input (shape = (1024 ,))
94
+ x = Dense (256 )(embed_input )
95
+ mean_logsigma = LeakyReLU (alpha = 0.2 )(x )
96
+
97
+ c = Lambda (generate_c )(mean_logsigma )
91
98
92
99
noise_input = Input (shape = (noise_dim ,))
93
100
94
- x = Dense (8 * 8 * projection_dim )(noise_input )
101
+ gen_input = Concatenate (axis = 1 )([c , noise_input ])
102
+
103
+ x = Dense (8 * 8 * projection_dim )(gen_input )
95
104
x = Reshape ((8 * 8 , projection_dim ))(x )
105
+ # x = layers.BatchNormalization()(x)
96
106
97
107
positional_embeddings = PositionalEmbedding (64 , projection_dim )
98
108
x = positional_embeddings (x )
@@ -109,29 +119,39 @@ def build_generator(noise_dim,
109
119
110
120
outputs = Conv2DTranspose (3 , kernel_size = 3 , strides = 2 , padding = "SAME" ,activation = "tanh" )(x )
111
121
112
- return tf .keras .Model (inputs = noise_input , outputs = outputs , name = 'generator' )
122
+ return tf .keras .Model (inputs = [ embed_input , noise_input ] , outputs = outputs , name = 'generator' )
113
123
114
124
def build_discriminator():
    """Build the conditional WGAN critic.

    Inputs (functional Keras model with two inputs):
      * image:     (64, 64, 3) tensor — presumably scaled to [-1, 1] to match
        the generator's tanh output; TODO confirm real-image preprocessing.
      * embedding: (1024,) text-embedding vector conditioning the critic.

    Returns:
        An uncompiled ``tf.keras.Model`` mapping ``[image, embedding]`` to a
        single unbounded score (linear output — no sigmoid — as required by
        the Wasserstein loss).

    LayerNormalization is used rather than BatchNormalization: batch norm
    couples samples within a batch, which invalidates the per-sample
    gradient penalty used in WGAN-GP.
    """
    image_input = Input(shape=(64, 64, 3))

    # Downsampling stack: 64x64 -> 32 -> 16 -> 8 -> 4 spatial resolution.
    x = Conv2D(64, kernel_size=4, strides=2, padding="SAME",
               activation=LeakyReLU(0.2))(image_input)
    x = LayerNormalization()(x)
    x = Conv2D(128, kernel_size=4, strides=2, padding="SAME",
               activation=LeakyReLU(0.2))(x)
    x = LayerNormalization()(x)
    x = Conv2D(256, kernel_size=4, strides=2, padding="SAME",
               activation=LeakyReLU(0.2))(x)
    x = LayerNormalization()(x)
    x = Conv2D(512, kernel_size=4, strides=2, padding="SAME",
               activation=LeakyReLU(0.2))(x)

    x = Dropout(0.4)(x)

    # Compress the text embedding and broadcast it over the 4x4 feature map
    # so it can be concatenated channel-wise with the image features.
    embedding_input = Input(shape=(1024,))
    compressed_embedding = Dense(128)(embedding_input)
    compressed_embedding = ReLU()(compressed_embedding)

    # Fix: use Keras layers instead of raw tf.reshape/tf.tile on symbolic
    # tensors — raw TF ops on KerasTensors are not tracked as layers and
    # break functional-model construction on newer Keras versions.
    compressed_embedding = Reshape((1, 1, 128))(compressed_embedding)
    compressed_embedding = Lambda(
        lambda t: tf.tile(t, (1, 4, 4, 1)))(compressed_embedding)

    concat_input = concatenate([x, compressed_embedding])

    # 1x1 conv fuses image features with the broadcast text condition.
    x = Conv2D(64 * 8, kernel_size=1, strides=1, padding="SAME",
               activation=LeakyReLU(0.2))(concat_input)
    x = LayerNormalization()(x)

    x = Dropout(0.4)(x)
    x = Flatten()(x)

    outputs = Dense(1)(x)  # linear critic score (no activation)

    return tf.keras.Model(inputs=[image_input, embedding_input],
                          outputs=outputs, name='discriminator')
135
155
136
156
class WGAN (tf .keras .Model ):
137
157
def __init__ (
@@ -156,7 +176,7 @@ def compile(self, d_optimizer, g_optimizer, d_loss_fn, g_loss_fn):
156
176
self .d_loss_fn = d_loss_fn
157
177
self .g_loss_fn = g_loss_fn
158
178
159
- def gradient_penalty (self , batch_size , real_images , fake_images ):
179
+ def gradient_penalty (self , batch_size , real_images , fake_images , text_embeddings ):
160
180
""" Calculates the gradient penalty.
161
181
162
182
This loss is calculated on an interpolated image
@@ -170,7 +190,7 @@ def gradient_penalty(self, batch_size, real_images, fake_images):
170
190
with tf .GradientTape () as gp_tape :
171
191
gp_tape .watch (interpolated )
172
192
# 1. Get the discriminator output for this interpolated image.
173
- pred = self .discriminator (interpolated , training = True )
193
+ pred = self .discriminator ([ interpolated , text_embeddings ] , training = True )
174
194
175
195
# 2. Calculate the gradients w.r.t to this interpolated image.
176
196
grads = gp_tape .gradient (pred , [interpolated ])[0 ]
@@ -179,9 +199,14 @@ def gradient_penalty(self, batch_size, real_images, fake_images):
179
199
gp = tf .reduce_mean ((norm - 1.0 ) ** 2 )
180
200
return gp
181
201
182
- def train_step (self , real_images ):
202
+ def train_step (self , dataset ):
203
+
204
+ real_images , text_embeddings = dataset
205
+
183
206
if isinstance (real_images , tuple ):
184
207
real_images = real_images [0 ]
208
+ if isinstance (text_embeddings , tuple ):
209
+ text_embeddings = text_embeddings [0 ]
185
210
186
211
batch_size = tf .shape (real_images )[0 ]
187
212
@@ -201,36 +226,24 @@ def train_step(self, real_images):
201
226
shape = (batch_size , self .latent_dim )
202
227
)
203
228
with tf .GradientTape () as tape :
204
- fake_images = self .generator (random_latent_vectors , training = True )
205
- fake_logits = self .discriminator (fake_images , training = True )
206
- real_logits = self .discriminator (real_images , training = True )
207
-
229
+ fake_images = self .generator ([text_embeddings ,random_latent_vectors ], training = True )
230
+ fake_logits = self .discriminator ([fake_images , text_embeddings ], training = True )
231
+ real_logits = self .discriminator ([real_images , text_embeddings ], training = True )
208
232
d_cost = self .d_loss_fn (real_img = real_logits , fake_img = fake_logits )
209
- gp = self .gradient_penalty (batch_size , real_images , fake_images )
233
+ gp = self .gradient_penalty (batch_size , real_images , fake_images , text_embeddings )
210
234
d_loss = d_cost + gp * self .gp_weight
211
-
212
235
d_gradient = tape .gradient (d_loss , self .discriminator .trainable_variables )
213
236
self .d_optimizer .apply_gradients (
214
237
zip (d_gradient , self .discriminator .trainable_variables )
215
238
)
216
-
217
239
random_latent_vectors = tf .random .normal (shape = (batch_size , self .latent_dim ))
218
240
with tf .GradientTape () as tape :
219
- generated_images = self .generator (random_latent_vectors , training = True )
220
- gen_img_logits = self .discriminator (generated_images , training = True )
241
+ generated_images = self .generator ([ text_embeddings , random_latent_vectors ] , training = True )
242
+ gen_img_logits = self .discriminator ([ generated_images , text_embeddings ] , training = True )
221
243
g_loss = self .g_loss_fn (gen_img_logits )
222
244
223
245
gen_gradient = tape .gradient (g_loss , self .generator .trainable_variables )
224
246
self .g_optimizer .apply_gradients (
225
247
zip (gen_gradient , self .generator .trainable_variables )
226
248
)
227
- return {"d_loss" : d_loss , "g_loss" : g_loss }
228
-
229
def discriminator_loss(real_img, fake_img):
    """Wasserstein critic loss.

    Args:
        real_img: critic scores for real images.
        fake_img: critic scores for generated images.

    Returns:
        mean(fake scores) - mean(real scores); the critic minimizes this,
        pushing real scores up and fake scores down.
    """
    return tf.reduce_mean(fake_img) - tf.reduce_mean(real_img)
233
-
234
-
235
def generator_loss(fake_img):
    """Wasserstein generator loss.

    Args:
        fake_img: critic scores for generated images.

    Returns:
        The negated mean critic score; minimizing it drives the generator
        to raise the critic's score on its samples.
    """
    mean_fake_score = tf.reduce_mean(fake_img)
    return -mean_fake_score
249
+ return {"d_loss" : d_loss , "g_loss" : g_loss }
0 commit comments