-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathTrain.py
57 lines (51 loc) · 2.4 KB
/
Train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import os

import numpy as np
import tensorflow as tf
from PIL import Image

from networks import Generator, Discriminator
from ops import Hinge_Loss, MSE
from utils import read_crop_data
# Papers this implementation follows:
#   - "cGANs with Projection Discriminator"
#   - "Residual Dense Network for Image Super-Resolution"

# --- Optimisation settings -------------------------------------------------
BATCH_SIZE = 16          # batch size passed to read_crop_data each step
MAX_ITERATION = 600000   # number of training iterations (generator steps)
LAMBDA = 100             # weight of the MSE(SR, HR) term added to the generator loss

# --- Filesystem locations --------------------------------------------------
# Paths are joined by plain string concatenation, so keep the trailing "/".
TRAINING_SET_PATH = "./TrainingSet/"
SAVE_MODEL = "./save_para/"
RESULTS = "./results/"
def train():
    """Adversarially train the RDN super-resolution generator.

    Graph: ``Generator("RDN")`` maps 24x24x3 LR patches to 96x96x3 SR
    patches, and ``Discriminator("discriminator")`` scores (image, LR) pairs
    for both SR and HR. Losses come from ``ops.Hinge_Loss``; the generator
    loss additionally includes ``LAMBDA * MSE(SR, HR)``.

    Schedule: each iteration runs one discriminator step, then one generator
    step, for exactly MAX_ITERATION iterations. The learning rate stays at
    its initial value for the first half of training. It then decays
    linearly to 0 over the second half.

    Output: every 10 iterations an HR | SR side-by-side preview is written to
    RESULTS and the losses are printed. Every 500 iterations, including the
    final one, a checkpoint is saved to SAVE_MODEL. Both directories are
    created if they do not exist.
    """
    lr0 = 2e-4  # initial learning rate; single source for variable init and decay

    # Ensure output directories exist before the first image/checkpoint save.
    for directory in (RESULTS, SAVE_MODEL):
        if not os.path.isdir(directory):
            os.makedirs(directory)

    RDN = Generator("RDN")
    D = Discriminator("discriminator")
    HR = tf.placeholder(tf.float32, [None, 96, 96, 3])
    LR = tf.placeholder(tf.float32, [None, 24, 24, 3])
    SR = RDN(LR)
    fake_logits = D(SR, LR)
    real_logits = D(HR, LR)
    D_loss, G_loss = Hinge_Loss(fake_logits, real_logits)
    G_loss += MSE(SR, HR) * LAMBDA

    # Countdown of remaining generator steps; decremented by op_sub.
    itr = tf.Variable(MAX_ITERATION, dtype=tf.int32, trainable=False)
    learning_rate = tf.Variable(lr0, trainable=False)
    # Build the learning-rate update op ONCE and feed the new value.
    # Calling tf.assign inside the training loop would add a new node to the
    # graph on every iteration, growing memory and step time without bound.
    lr_value = tf.placeholder(tf.float32, [])
    set_lr = tf.assign(learning_rate, lr_value)
    op_sub = tf.assign_sub(itr, 1)
    D_opt = tf.train.AdamOptimizer(learning_rate, beta1=0., beta2=0.9).minimize(
        D_loss, var_list=D.var_list())
    # Every generator step also decrements the countdown.
    with tf.control_dependencies([op_sub]):
        G_opt = tf.train.AdamOptimizer(learning_rate, beta1=0., beta2=0.9).minimize(
            G_loss, var_list=RDN.var_list())
    saver = tf.train.Saver()

    with tf.Session() as sess:  # session is closed even if training raises
        sess.run(tf.global_variables_initializer())
        iteration = 0
        # Bounded loop: beyond MAX_ITERATION the countdown would go negative.
        # The decay formula below would then produce a negative learning rate.
        while iteration < MAX_ITERATION:
            HR_data, LR_data = read_crop_data(TRAINING_SET_PATH, BATCH_SIZE, [96, 96, 3], 4)
            feed = {HR: HR_data, LR: LR_data}
            sess.run(D_opt, feed_dict=feed)
            # Fetch op_sub's output (the post-decrement value) rather than the
            # raw variable, whose read is not ordered w.r.t. the assign.
            _, remaining = sess.run([G_opt, op_sub], feed_dict=feed)
            iteration = MAX_ITERATION - remaining
            if iteration > MAX_ITERATION // 2:
                # Linear decay: ~lr0 just after the halfway point, 0 at the end.
                sess.run(set_lr, feed_dict={lr_value: lr0 * (remaining * 2.0 / MAX_ITERATION)})
            if iteration % 10 == 0:
                [D_LOSS, G_LOSS, LEARNING_RATE, img] = sess.run(
                    [D_loss, G_loss, learning_rate, SR], feed_dict=feed)
                # (x + 1) * 127.5 assumes pixels in [-1, 1]. Clip so that SR
                # values outside that range saturate instead of wrapping in
                # the uint8 cast.
                output = (np.concatenate((HR_data[0, :, :, :], img[0, :, :, :]), axis=1) + 1) * 127.5
                Image.fromarray(np.uint8(np.clip(output, 0, 255))).save(RESULTS + str(iteration) + ".jpg")
                print("Iteration: %d, D_loss: %f, G_loss: %f, LearningRate: %f"%(iteration, D_LOSS, G_LOSS, LEARNING_RATE))
            if iteration % 500 == 0:
                saver.save(sess, SAVE_MODEL + "model.ckpt")
# Only start training when executed as a script, not when imported.
if __name__ == "__main__":
    train()