Skip to content

Commit cb2d4b9

Browse files
feat: text-to-image-model
1 parent ec4cf05 commit cb2d4b9

File tree

3 files changed

+93
-83
lines changed

3 files changed

+93
-83
lines changed

main.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,16 @@ def train(n_epochs, batch_size, codings_size, d_steps, gp_w):
3232
d_loss_fn=discriminator_loss,
3333
)
3434

35-
visualization_callback = tf.keras.callbacks.LambdaCallback(on_epoch_end=lambda epoch, logs: visualize_generated_images(epoch, generator))
35+
# Create a visualization callback
36+
visualization_callback = tf.keras.callbacks.LambdaCallback(on_epoch_end=lambda epoch, logs: visualize_generated_images(epoch, generator, dataset))
37+
# Create a ModelCheckpoint callback
38+
model_checkpoint_path_weights = 'ckpts/CUB-WGAN-GP-weights-{epoch:02d}.keras'
39+
model_checkpoint_callback_weights = ModelCheckpoint(
40+
filepath=model_checkpoint_path_weights,
41+
save_freq='epoch', # Save every epoch
42+
save_weights_only=True, # Save only the weights
43+
)
44+
3645
history = gan.fit(dataset, epochs=n_epochs, verbose=1, callbacks=[visualization_callback, model_checkpoint_callback_weights])
3746

3847
fig, ax = plt.subplots(figsize=(20, 6))

models.py

+54-41
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import tensorflow as tf
22
from keras.layers import (Dense, Reshape, BatchNormalization, Conv2DTranspose,
3-
Dropout, LayerNormalization, Embedding, Input, Conv2D, LeakyReLU, Flatten)
4-
from keras.models import Sequential
3+
Dropout, LayerNormalization, Embedding, Input, Conv2D, LeakyReLU, Flatten,
4+
Concatenate, concatenate, Lambda, ReLU)
5+
from utils import *
56

67
def scaled_dot_product(q, k, v):
78
dk = tf.cast(tf.shape(k)[-1], tf.float32)
@@ -88,11 +89,20 @@ def build_generator(noise_dim,
8889
projection_dim,
8990
num_heads,
9091
mlp_dim):
92+
# Input layer
93+
embed_input = Input(shape=(1024,))
94+
x = Dense(256)(embed_input)
95+
mean_logsigma = LeakyReLU(alpha=0.2)(x)
96+
97+
c = Lambda(generate_c)(mean_logsigma)
9198

9299
noise_input = Input(shape=(noise_dim,))
93100

94-
x = Dense(8 * 8 * projection_dim)(noise_input)
101+
gen_input = Concatenate(axis=1)([c, noise_input])
102+
103+
x = Dense(8 * 8 * projection_dim)(gen_input)
95104
x = Reshape((8 *8, projection_dim))(x)
105+
# x = layers.BatchNormalization()(x)
96106

97107
positional_embeddings = PositionalEmbedding(64, projection_dim)
98108
x = positional_embeddings(x)
@@ -109,29 +119,39 @@ def build_generator(noise_dim,
109119

110120
outputs = Conv2DTranspose(3, kernel_size=3, strides=2, padding="SAME",activation="tanh")(x)
111121

112-
return tf.keras.Model(inputs=noise_input, outputs=outputs, name='generator')
122+
return tf.keras.Model(inputs=[embed_input,noise_input], outputs=outputs, name='generator')
113123

114124
def build_discriminator():
    """Build the text-conditioned critic.

    The model takes two inputs, a 64x64x3 image and a 1024-dimensional text
    embedding, and returns a single unbounded score per sample.

    Returns:
        A ``tf.keras.Model`` named ``'discriminator'`` whose inputs are
        ``[image_input, embedding_input]``.
    """
    image_input = Input(shape=(64, 64, 3))

    # Three stride-2 convolutions, each followed by layer normalization.
    features = image_input
    for n_filters in (64, 128, 256):
        features = Conv2D(n_filters, kernel_size=4, strides=2, padding="SAME",
                          activation=LeakyReLU(0.2))(features)
        features = LayerNormalization()(features)

    # Fourth stride-2 convolution (not normalized): 64 / 2**4 = 4, so the
    # feature map is 4x4 here, matching the tiling of the text code below.
    features = Conv2D(512, kernel_size=4, strides=2, padding="SAME",
                      activation=LeakyReLU(0.2))(features)
    features = Dropout(0.4)(features)

    # Compress the text embedding to 128 values.
    embedding_input = Input(shape=(1024,))
    text_code = Dense(128)(embedding_input)
    text_code = ReLU()(text_code)

    # Replicate the text code over every spatial position of the 4x4 map.
    text_code = tf.tile(tf.reshape(text_code, (-1, 1, 1, 128)), (1, 4, 4, 1))

    # Fuse image features and text code along the channel axis.
    joint = concatenate([features, text_code])

    # 1x1 convolution mixes the image and text channels.
    joint = Conv2D(512, kernel_size=1, strides=1, padding="SAME",
                   activation=LeakyReLU(0.2))(joint)
    joint = LayerNormalization()(joint)

    joint = Dropout(0.4)(joint)
    joint = Flatten()(joint)

    outputs = Dense(1)(joint)

    return tf.keras.Model(inputs=[image_input, embedding_input],
                          outputs=outputs, name='discriminator')
135155

136156
class WGAN(tf.keras.Model):
137157
def __init__(
@@ -156,7 +176,7 @@ def compile(self, d_optimizer, g_optimizer, d_loss_fn, g_loss_fn):
156176
self.d_loss_fn = d_loss_fn
157177
self.g_loss_fn = g_loss_fn
158178

159-
def gradient_penalty(self, batch_size, real_images, fake_images):
179+
def gradient_penalty(self, batch_size, real_images, fake_images, text_embeddings):
160180
""" Calculates the gradient penalty.
161181
162182
This loss is calculated on an interpolated image
@@ -170,7 +190,7 @@ def gradient_penalty(self, batch_size, real_images, fake_images):
170190
with tf.GradientTape() as gp_tape:
171191
gp_tape.watch(interpolated)
172192
# 1. Get the discriminator output for this interpolated image.
173-
pred = self.discriminator(interpolated, training=True)
193+
pred = self.discriminator([interpolated, text_embeddings], training=True)
174194

175195
# 2. Calculate the gradients w.r.t to this interpolated image.
176196
grads = gp_tape.gradient(pred, [interpolated])[0]
@@ -179,9 +199,14 @@ def gradient_penalty(self, batch_size, real_images, fake_images):
179199
gp = tf.reduce_mean((norm - 1.0) ** 2)
180200
return gp
181201

182-
def train_step(self, real_images):
202+
def train_step(self, dataset):
203+
204+
real_images, text_embeddings = dataset
205+
183206
if isinstance(real_images, tuple):
184207
real_images = real_images[0]
208+
if isinstance(text_embeddings, tuple):
209+
text_embeddings = text_embeddings[0]
185210

186211
batch_size = tf.shape(real_images)[0]
187212

@@ -201,36 +226,24 @@ def train_step(self, real_images):
201226
shape=(batch_size, self.latent_dim)
202227
)
203228
with tf.GradientTape() as tape:
204-
fake_images = self.generator(random_latent_vectors, training=True)
205-
fake_logits = self.discriminator(fake_images, training=True)
206-
real_logits = self.discriminator(real_images, training=True)
207-
229+
fake_images = self.generator([text_embeddings,random_latent_vectors], training=True)
230+
fake_logits = self.discriminator([fake_images, text_embeddings], training=True)
231+
real_logits = self.discriminator([real_images, text_embeddings], training=True)
208232
d_cost = self.d_loss_fn(real_img=real_logits, fake_img=fake_logits)
209-
gp = self.gradient_penalty(batch_size, real_images, fake_images)
233+
gp = self.gradient_penalty(batch_size, real_images, fake_images, text_embeddings)
210234
d_loss = d_cost + gp * self.gp_weight
211-
212235
d_gradient = tape.gradient(d_loss, self.discriminator.trainable_variables)
213236
self.d_optimizer.apply_gradients(
214237
zip(d_gradient, self.discriminator.trainable_variables)
215238
)
216-
217239
random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
218240
with tf.GradientTape() as tape:
219-
generated_images = self.generator(random_latent_vectors, training=True)
220-
gen_img_logits = self.discriminator(generated_images, training=True)
241+
generated_images = self.generator([text_embeddings, random_latent_vectors], training=True)
242+
gen_img_logits = self.discriminator([generated_images, text_embeddings], training=True)
221243
g_loss = self.g_loss_fn(gen_img_logits)
222244

223245
gen_gradient = tape.gradient(g_loss, self.generator.trainable_variables)
224246
self.g_optimizer.apply_gradients(
225247
zip(gen_gradient, self.generator.trainable_variables)
226248
)
227-
return {"d_loss": d_loss, "g_loss": g_loss}
228-
229-
def discriminator_loss(real_img, fake_img):
230-
real_loss = tf.reduce_mean(real_img)
231-
fake_loss = tf.reduce_mean(fake_img)
232-
return fake_loss - real_loss
233-
234-
235-
def generator_loss(fake_img):
236-
return -tf.reduce_mean(fake_img)
249+
return {"d_loss": d_loss, "g_loss": g_loss}

utils.py

+29-41
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,18 @@
11
import numpy as np
22
import tensorflow as tf
3-
import pickle
3+
from keras import backend as K
44
from keras.callbacks import ModelCheckpoint
55
import matplotlib.pyplot as plt
66

7-
def plot_results(images, n_cols=None, title=None):
7+
def generate_c(x):
    """Sample a conditioning vector with the reparameterization trick.

    The second axis of ``x`` is split at column 128: the first 128 columns
    are used as the mean and the remaining columns as the log standard
    deviation. The result is ``mean + exp(log_sigma) * epsilon`` with
    ``epsilon`` drawn from a standard normal distribution.

    Args:
        x: 2-D tensor indexed as ``x[:, :128]`` / ``x[:, 128:]``.
           NOTE(review): the caller in models.py feeds a Dense(256) output,
           so both halves are 128 wide -- confirm if that layer changes.

    Returns:
        Tensor with the same shape as the mean half of ``x``.
    """
    mean = x[:, :128]
    log_sigma = x[:, 128:]

    stddev = K.exp(log_sigma)
    # Draw noise with the full dynamic shape of ``mean`` so that every sample
    # in the batch gets its own epsilon. A shape of ``(features,)`` would be
    # broadcast, applying one shared noise vector to the whole batch.
    epsilon = K.random_normal(shape=K.shape(mean))

    return stddev * epsilon + mean
2316

2417
def show_dataset_images(images):
2518
fig, axes = plt.subplots(4,4, figsize=(8,8))
@@ -29,9 +22,11 @@ def show_dataset_images(images):
2922
axes[i].axis('off')
3023
plt.show()
3124

32-
def visualize_generated_images(epoch, generator, latent_dim=100, num_samples=5):
25+
# Visualization callback
26+
def visualize_generated_images(epoch, generator, dataset, latent_dim=100, num_samples=5):
27+
real_images, text_embeddings = next(iter(dataset.take(1)))
3328
random_latent_vectors = np.random.normal(size=(num_samples, latent_dim))
34-
generated_images = generator.predict(random_latent_vectors)
29+
generated_images = generator.predict([text_embeddings[:num_samples], random_latent_vectors])
3530
generated_images += 1
3631
generated_images *= 127.5
3732

@@ -43,41 +38,34 @@ def visualize_generated_images(epoch, generator, latent_dim=100, num_samples=5):
4338
plt.suptitle(f'Generated Images - Epoch {epoch}')
4439
plt.show()
4540

46-
# Define the file path for saving the model
47-
model_checkpoint_path_weights = 'ckpts/CUB-WGAN-GP-weights-{epoch:02d}.keras'
48-
49-
# Create a ModelCheckpoint callback
50-
model_checkpoint_callback_weights = ModelCheckpoint(
51-
filepath=model_checkpoint_path_weights,
52-
save_freq='epoch', # Save every epoch
53-
save_weights_only=True, # Save only the weights
54-
)
55-
56-
def load_images(path):
57-
#Loading images from pickle file
58-
with open(path, 'rb') as f_in:
59-
images = pickle.load(f_in)
60-
return images
41+
def prepare_data(batch_size, data_path):
    """Load the CUB images and text embeddings and build the training dataset.

    Reads ``X_train_CUB.npy`` and ``embeddings_train_CUB.npy`` from
    ``data_path``, prints their shapes, shows a preview of the images, and
    returns a shuffled, batched ``tf.data.Dataset`` of
    ``(image, embedding)`` pairs.

    Args:
        batch_size: Number of pairs per batch; the final partial batch is
            dropped.
        data_path: Directory containing the two ``.npy`` files.

    Returns:
        A ``tf.data.Dataset`` yielding ``(images, embeddings)`` batches.
    """
    images = np.load(data_path + "/X_train_CUB.npy")
    embeddings = np.load(data_path + "/embeddings_train_CUB.npy")

    print(f'Dataset images shape: {images.shape}\n')
    print(f'Text embeddings shape: {embeddings.shape}\n')

    print(f'Dataset images: \n')
    show_dataset_images(images)

    # Rescale pixels as x / 127.5 - 1 (maps [0, 255] to [-1, 1], assuming
    # 8-bit image data -- TODO confirm against the .npy contents).
    normalized = images.astype(np.float32) / 127.5 - 1

    return (
        tf.data.Dataset.from_tensor_slices((normalized, embeddings))
        .shuffle(1024)
        .batch(batch_size, drop_remainder=True)
        .prefetch(1)
    )
63+
64+
def discriminator_loss(real_img, fake_img):
    """Return the critic loss: mean score on fakes minus mean score on reals.

    Args:
        real_img: Critic outputs for real images.
        fake_img: Critic outputs for generated images.

    Returns:
        Scalar tensor ``mean(fake_img) - mean(real_img)``.
    """
    return tf.reduce_mean(fake_img) - tf.reduce_mean(real_img)
68+
69+
70+
def generator_loss(fake_img):
    """Return the generator loss: the negated mean critic score on fakes.

    Args:
        fake_img: Critic outputs for generated images.

    Returns:
        Scalar tensor ``-mean(fake_img)``.
    """
    mean_score = tf.reduce_mean(fake_img)
    return -mean_score

0 commit comments

Comments
 (0)