-
Notifications
You must be signed in to change notification settings - Fork 8
/
train.py
78 lines (69 loc) · 2.61 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os

import numpy as np
from keras.datasets import mnist
from keras.models import Sequential
from keras.optimizers import SGD, Adam
from PIL import Image

from model import discriminator, generator
from visualizer import *
# --- Training hyper-parameters -------------------------------------------
BATCH_SIZE = 32   # samples per discriminator / generator update step
NUM_EPOCH = 50    # number of passes over the MNIST training images
LR = 2e-4         # Adam learning rate (held constant; no schedule is applied)
B1 = 0.5          # Adam beta_1, the first-moment ("momentum") decay rate

# --- Output locations (train() creates them when missing) -----------------
GENERATED_IMAGE_PATH = 'images/'
GENERATED_MODEL_PATH = 'models/'
# Length of each latent vector fed to the generator (matches the 100-wide
# noise vectors the original script sampled).
_LATENT_DIM = 100


def _sample_noise(num_samples):
    """Return a (num_samples, _LATENT_DIM) batch of generator input noise.

    Both the training loop and the per-epoch preview use this helper, so the
    preview images come from the same latent distribution the generator is
    trained on.
    """
    return np.random.normal(0, 0.5, (num_samples, _LATENT_DIM))


def _build_models():
    """Create and compile the generator, discriminator and stacked DCGAN.

    Returns:
        (g, d, dcgan): the generator, the discriminator, and
        ``Sequential([g, d])``, which is used to train ``g`` through ``d``.
    """
    g = generator()
    d = discriminator()

    # NOTE(review): ``lr=`` is the spelling older Keras accepts; newer Keras
    # versions use ``learning_rate=`` instead. Confirm against the pinned version.
    # Compile d while it is trainable so that d.train_on_batch updates it...
    d.trainable = True
    d.compile(loss='binary_crossentropy',
              metrics=['accuracy'],
              optimizer=Adam(lr=LR, beta_1=B1))

    # ...then freeze it before compiling the stacked model. Keras fixes the
    # set of trainable weights at compile time, so dcgan.train_on_batch only
    # updates the generator while d keeps learning through its own compile.
    d.trainable = False
    dcgan = Sequential([g, d])
    # Each model gets its own optimizer instance (and its own Adam state).
    dcgan.compile(loss='binary_crossentropy',
                  metrics=['accuracy'],
                  optimizer=Adam(lr=LR, beta_1=B1))
    return g, d, dcgan


def _save_preview(g, z_pred, epoch):
    """Render ``g(z_pred)`` as one grid image and save it as NNNepoch.png."""
    image = combine_images(g.predict(z_pred))
    # Undo the [-1, 1] scaling applied to the training images. Clip so any
    # output slightly outside that range cannot wrap around in the uint8 cast.
    image = np.clip(image * 127.5 + 127.5, 0, 255)
    path = os.path.join(GENERATED_IMAGE_PATH, "%03depoch.png" % epoch)
    Image.fromarray(image.astype(np.uint8)).save(path)


def train():
    """Train a DCGAN on the MNIST training images.

    After every epoch, this writes a preview grid generated from a fixed
    latent batch to GENERATED_IMAGE_PATH. It also saves the generator and
    discriminator to GENERATED_MODEL_PATH as HDF5 files, overwriting the
    previous epoch's checkpoint.
    """
    (X_train, _), (_, _) = mnist.load_data()
    # Scale pixels from [0, 255] to [-1, 1] and append a single channel axis.
    X_train = (X_train.astype(np.float32) - 127.5) / 127.5
    X_train = X_train.reshape(X_train.shape[0], X_train.shape[1],
                              X_train.shape[2], 1)

    g, d, dcgan = _build_models()

    # Any trailing partial batch is dropped.
    num_batches = X_train.shape[0] // BATCH_SIZE

    os.makedirs(GENERATED_IMAGE_PATH, exist_ok=True)
    os.makedirs(GENERATED_MODEL_PATH, exist_ok=True)

    print("-------------------")
    print("Total epoch:", NUM_EPOCH, "Number of batches:", num_batches)
    print("-------------------")

    # Fixed latent batch reused every epoch so that previews are comparable.
    z_pred = _sample_noise(49)
    y_real = np.ones(BATCH_SIZE)
    y_fake = np.zeros(BATCH_SIZE)

    for epoch in range(1, NUM_EPOCH + 1):
        for index in range(num_batches):
            X_d_true = X_train[index * BATCH_SIZE:(index + 1) * BATCH_SIZE]
            X_g = _sample_noise(BATCH_SIZE)
            X_d_gen = g.predict(X_g, verbose=0)

            # Train the discriminator on one real and one generated batch.
            # Report the mean of both [loss, accuracy] results instead of only
            # the generated-batch result.
            d_loss_real = d.train_on_batch(X_d_true, y_real)
            d_loss_gen = d.train_on_batch(X_d_gen, y_fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_gen)

            # Train the generator (through the frozen discriminator) so that
            # its samples are classified as real.
            g_loss = dcgan.train_on_batch(X_g, y_real)
            show_progress(epoch, index, g_loss[0], d_loss[0],
                          g_loss[1], d_loss[1])

        _save_preview(g, z_pred, epoch)
        print()
        g.save(os.path.join(GENERATED_MODEL_PATH, 'dcgan_generator.h5'))
        d.save(os.path.join(GENERATED_MODEL_PATH, 'dcgan_discriminator.h5'))


if __name__ == '__main__':
    train()