-
Notifications
You must be signed in to change notification settings - Fork 28
/
Copy pathtrainer.py
328 lines (258 loc) · 14.5 KB
/
trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
import datetime
import numpy as np
import os
import random
import sys
import time
import torch
import torch.nn as nn
import torchvision.utils as vutils
from torch.backends import cudnn
import utils
from sagan_models import Generator, Discriminator
class Trainer(object):
    """Trains a class-conditional generator/discriminator pair from ``sagan_models``.

    Both networks are conditioned on class labels drawn from the dataset:
    ``G(z, labels)`` and ``D(images, labels)``. Supported adversarial losses
    (``config.adv_loss``) are ``'hinge'``, ``'wgan_gp'`` and ``'dcgan'``.
    """

    # Values of ``config.adv_loss`` understood by the loss helpers below.
    _ADV_LOSSES = ('hinge', 'wgan_gp', 'dcgan')

    def __init__(self, config):
        """Prepare output folders, the dataloader, the models and (for 'dcgan') the criterion.

        Args:
            config: experiment configuration; every hyper-parameter and path is read
                from its attributes.
        """
        self.config = config
        # Unless using pre-trained model
        # NOTE(review): presumably utils.load_pretrained_model (called from build_models) overrides this.
        self.start = 0

        # Create directories if not exist
        utils.make_folder(self.config.save_path)
        utils.make_folder(self.config.model_weights_path)
        utils.make_folder(self.config.sample_images_path)

        # Copy the config and the scripts into save_path
        utils.write_config_to_file(self.config, self.config.save_path)
        utils.copy_scripts(self.config.save_path)

        # NOTE(review): presumably sets self.device, which build_models() and train() read.
        utils.check_for_CUDA(self)

        # Make dataloader
        self.dataloader, self.num_of_classes = utils.make_dataloader(batch_size=self.config.batch_size_in_gpu,
                                                                     dataset_type=self.config.dataset,
                                                                     data_path=self.config.data_path,
                                                                     shuffle=self.config.shuffle,
                                                                     drop_last=self.config.drop_last,
                                                                     dataloader_args=self.config.dataloader_args,
                                                                     resize=self.config.resize,
                                                                     imsize=self.config.imsize,
                                                                     centercrop=self.config.centercrop,
                                                                     centercrop_size=self.config.centercrop_size,
                                                                     normalize=self.config.normalize,
                                                                     )
        # Data iterator, restarted by get_real_samples() when exhausted
        self.data_iter = iter(self.dataloader)

        # Build G and D (and their optimizers)
        self.build_models()

        if self.config.adv_loss == 'dcgan':
            # NOTE(review): BCELoss expects probabilities, i.e. assumes D ends in a sigmoid
            # for 'dcgan' — not visible from this file.
            self.criterion = nn.BCELoss()

    def train(self):
        """Run the training loop for steps ``self.start`` .. ``config.total_step - 1``.

        Each step performs ``config.d_steps_per_iter`` discriminator updates followed by
        ``config.g_steps_per_iter`` generator updates. Every update accumulates gradients
        over ``batch_size // batch_size_in_gpu`` micro-batches of ``batch_size_in_gpu``
        samples, so the effective batch is ``batch_size`` while each forward pass only
        processes ``batch_size_in_gpu`` samples. Every ``log_step`` steps the losses of
        the most recent updates (averaged over their micro-batches) are printed, appended
        to ``save_path/log.txt`` and plotted; every ``sample_step`` steps samples from G
        are saved; every ``model_save_step`` steps a checkpoint is written.

        Raises:
            ValueError: if ``config.adv_loss`` is not one of 'hinge', 'wgan_gp', 'dcgan'.
        """
        if self.config.adv_loss not in self._ADV_LOSSES:
            raise ValueError('Unknown adv_loss {!r}; expected one of {}'.format(
                self.config.adv_loss, self._ADV_LOSSES))

        # Seed
        np.random.seed(self.config.manual_seed)
        random.seed(self.config.manual_seed)
        torch.manual_seed(self.config.manual_seed)

        # For fast training (cuDNN autotuner)
        cudnn.benchmark = True

        # Training mode (e.g. for BatchNorm)
        self.G.train()
        self.D.train()

        # Fixed noise for sampling from G, with labels cycling through the classes
        fixed_noise = torch.randn(self.config.batch_size_in_gpu, self.config.z_dim, device=self.device)
        fixed_labels = torch.arange(self.config.batch_size_in_gpu, device=self.device) % self.num_of_classes

        # Micro-batches accumulated per update. At least 1, so a batch_size smaller than
        # batch_size_in_gpu still trains instead of silently skipping every update.
        self.gpu_batches = max(1, self.config.batch_size // self.config.batch_size_in_gpu)

        # Loss / output histories for logging and plots
        g_losses, d_losses, d_losses_real, d_losses_fake = [], [], [], []
        d_xs, d_gz_train_ds, d_gz_train_gs = [], [], []
        start_time = time.time()

        log_file_name = os.path.join(self.config.save_path, 'log.txt')
        with open(log_file_name, "wt") as log_file:
            for self.step in range(self.start, self.config.total_step):
                inst_noise_sigma = self._inst_noise_sigma()

                # ================== TRAIN D ================== #
                d_stats = self._run_updates(self._d_micro_batch, self.config.d_steps_per_iter,
                                            self.D_optimizer, inst_noise_sigma)

                # ================== TRAIN G ================== #
                g_stats = self._run_updates(self._g_micro_batch, self.config.g_steps_per_iter,
                                            self.G_optimizer, inst_noise_sigma)

                # Print out log info
                if self.step % self.config.log_step == 0:
                    g_losses.append(self._stat(g_stats, 'g_loss'))
                    d_losses_real.append(self._stat(d_stats, 'd_loss_real'))
                    d_losses_fake.append(self._stat(d_stats, 'd_loss_fake'))
                    # Total D loss = real term + fake term + gradient penalty (counted once)
                    d_losses.append(d_losses_real[-1] + d_losses_fake[-1]
                                    + self._stat(d_stats, 'd_loss_gp', default=0.0))
                    d_xs.append(self._stat(d_stats, 'd_x'))
                    d_gz_train_ds.append(self._stat(d_stats, 'd_gz_train_d'))
                    d_gz_train_gs.append(self._stat(g_stats, 'd_gz_train_g'))

                    curr_time = time.time()
                    curr_time_str = datetime.datetime.fromtimestamp(curr_time).strftime('%Y-%m-%d %H:%M:%S')
                    elapsed = str(datetime.timedelta(seconds=(curr_time - start_time)))
                    log = ("[{}] : Elapsed [{}], Iter [{} / {}], G_loss: {:.4f}, D_loss: {:.4f}, D_loss_real: {:.4f}, D_loss_fake: {:.4f}, D(x): {:.4f}, D(G(z))_trainD: {:.4f}, D(G(z))_trainG: {:.4f}\n".
                           format(curr_time_str, elapsed, self.step, self.config.total_step,
                                  g_losses[-1], d_losses[-1], d_losses_real[-1], d_losses_fake[-1],
                                  d_xs[-1], d_gz_train_ds[-1], d_gz_train_gs[-1]))
                    print('\n' + log)
                    log_file.write(log)
                    log_file.flush()
                    utils.make_plots(g_losses, d_losses, d_losses_real, d_losses_fake, d_xs, d_gz_train_ds, d_gz_train_gs,
                                     self.config.log_step, self.config.save_path)

                # Sample images
                if self.step % self.config.sample_step == 0:
                    self._save_samples(fixed_noise, fixed_labels)

                # Save model
                if self.step % self.config.model_save_step == 0:
                    utils.save_ckpt(self)

    def _run_updates(self, micro_step, num_updates, optimizer, inst_noise_sigma):
        """Apply ``num_updates`` optimizer steps, each accumulating ``self.gpu_batches`` micro-batches.

        Args:
            micro_step: callable ``(inst_noise_sigma) -> dict`` that runs forward/backward
                for one micro-batch and returns detached scalar statistics.
            num_updates: number of optimizer steps to take.
            optimizer: optimizer stepped after each accumulation.
            inst_noise_sigma: instance-noise std forwarded to ``micro_step``.

        Returns:
            dict of the last update's statistics, each averaged over its micro-batches
            (empty when ``num_updates`` is 0).
        """
        stats = {}
        for _ in range(num_updates):
            self.reset_grad()
            stats = {}
            for _ in range(self.gpu_batches):
                for key, value in micro_step(inst_noise_sigma).items():
                    stats[key] = stats.get(key, 0.0) + value / self.gpu_batches
            optimizer.step()
        return stats

    def _d_micro_batch(self, inst_noise_sigma):
        """Forward/backward one discriminator micro-batch, with gradients scaled by 1/gpu_batches.

        Returns:
            dict of detached scalar tensors: 'd_loss_real', 'd_loss_fake' (adversarial
            terms only), 'd_x' (mean D output on real images), 'd_gz_train_d' (mean D
            output on fakes) and, for 'wgan_gp', 'd_loss_gp' (already times lambda_gp).
        """
        # TRAIN with REAL
        real_images, real_labels = self.get_real_samples()
        d_out_real = self.D(self._add_instance_noise(real_images, inst_noise_sigma), real_labels)
        d_loss_real = self._d_loss_real(d_out_real)
        (d_loss_real / self.gpu_batches).backward()

        # TRAIN with FAKE, generated for the same labels; detached so no gradient reaches G.
        # z is sized from the actual batch so a smaller final batch also works.
        z = torch.randn(real_labels.size(0), self.config.z_dim, device=self.device)
        fake_images = self.G(z, real_labels).detach()
        d_out_fake = self.D(self._add_instance_noise(fake_images, inst_noise_sigma), real_labels)
        d_loss_fake = self._d_loss_fake(d_out_fake)

        stats = {
            'd_loss_real': d_loss_real.detach(),
            'd_loss_fake': d_loss_fake.detach(),
            'd_x': d_out_real.detach().mean(),
            'd_gz_train_d': d_out_fake.detach().mean(),
        }
        # If WGAN_GP, compute GP and add it to the fake-side loss before backward
        if self.config.adv_loss == 'wgan_gp':
            d_loss_gp = self.config.lambda_gp * self.compute_gradient_penalty(real_images, real_labels, fake_images)
            stats['d_loss_gp'] = d_loss_gp.detach()
            d_loss_fake = d_loss_fake + d_loss_gp
        (d_loss_fake / self.gpu_batches).backward()
        return stats

    def _g_micro_batch(self, inst_noise_sigma):
        """Forward/backward one generator micro-batch, with gradients scaled by 1/gpu_batches.

        Returns:
            dict of detached scalar tensors: 'g_loss' and 'd_gz_train_g' (mean D output
            on the fakes).
        """
        # Only the real labels are used (to condition G); the real images are discarded.
        _, real_labels = self.get_real_samples()
        z = torch.randn(real_labels.size(0), self.config.z_dim, device=self.device)
        fake_images = self.G(z, real_labels)
        g_out_fake = self.D(self._add_instance_noise(fake_images, inst_noise_sigma), real_labels)
        g_loss = self._g_loss(g_out_fake)
        (g_loss / self.gpu_batches).backward()
        return {'g_loss': g_loss.detach(), 'd_gz_train_g': g_out_fake.detach().mean()}

    def _d_loss_real(self, d_out_real):
        """Discriminator loss on real samples for ``config.adv_loss``."""
        if self.config.adv_loss == 'hinge':
            return torch.relu(1.0 - d_out_real).mean()
        if self.config.adv_loss == 'wgan_gp':
            return -d_out_real.mean()
        # dcgan: float target of ones with exactly the output's shape/dtype/device
        return self.criterion(d_out_real, torch.ones_like(d_out_real))

    def _d_loss_fake(self, d_out_fake):
        """Discriminator loss on generated samples for ``config.adv_loss``."""
        if self.config.adv_loss == 'hinge':
            return torch.relu(1.0 + d_out_fake).mean()
        if self.config.adv_loss == 'dcgan':
            return self.criterion(d_out_fake, torch.zeros_like(d_out_fake))
        # wgan_gp (the gradient penalty is added by the caller)
        return d_out_fake.mean()

    def _g_loss(self, g_out_fake):
        """Generator loss for ``config.adv_loss`` given D's output on generated samples."""
        if self.config.adv_loss == 'dcgan':
            return self.criterion(g_out_fake, torch.ones_like(g_out_fake))
        # hinge and wgan_gp
        return -g_out_fake.mean()

    def _inst_noise_sigma(self):
        """Instance-noise std at ``self.step``.

        Linearly annealed from ``config.inst_noise_sigma`` at step 0 to 0 at step
        ``config.inst_noise_sigma_iters`` and kept at 0 afterwards (0 immediately when
        ``inst_noise_sigma_iters`` is 0).
        """
        anneal_iters = self.config.inst_noise_sigma_iters
        if self.step >= anneal_iters:
            return 0
        return (1 - self.step / anneal_iters) * self.config.inst_noise_sigma

    @staticmethod
    def _add_instance_noise(images, sigma):
        """Return ``images`` plus N(0, sigma^2) noise of the same shape; ``images`` itself when sigma <= 0."""
        if sigma <= 0:
            return images
        return images + sigma * torch.randn_like(images)

    @staticmethod
    def _stat(stats, key, default=float('nan')):
        """Return ``stats[key]`` as a Python float, or ``default`` when the key is absent."""
        value = stats.get(key)
        return default if value is None else float(value)

    def _save_samples(self, fixed_noise, fixed_labels):
        """Save a grid of G samples for the fixed noise/labels and add the first one to the gif."""
        print("Saving image samples..")
        self.G.eval()
        with torch.no_grad():  # sampling only; no autograd graph needed
            fake_images = self.G(fixed_noise, fixed_labels)
        self.G.train()
        sample_images = utils.denorm(fake_images[:self.config.save_n_images])
        # Save batch images
        vutils.save_image(sample_images, os.path.join(self.config.sample_images_path, 'fake_{:05d}.png'.format(self.step)), nrow=self.config.nrow)
        # Save gif
        utils.make_gif(sample_images[0].cpu().numpy().transpose(1, 2, 0)*255, self.step,
                       self.config.sample_images_path, self.config.name, max_frames_per_gif=self.config.max_frames_per_gif)

    def build_models(self):
        """Create G, D and their Adam optimizers, optionally load weights, and wrap for multi-GPU."""
        self.G = Generator(self.config.z_dim, self.config.g_conv_dim, self.num_of_classes).to(self.device)
        self.D = Discriminator(self.config.d_conv_dim, self.num_of_classes).to(self.device)

        # Optimizers over trainable parameters only
        self.G_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.G.parameters()), self.config.g_lr, [self.config.beta1, self.config.beta2])
        self.D_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.D.parameters()), self.config.d_lr, [self.config.beta1, self.config.beta2])

        # Start with pretrained model (if it exists); done before DataParallel wrapping
        if self.config.pretrained_model != '':
            utils.load_pretrained_model(self)

        if 'cuda' in self.device.type and self.config.parallel and torch.cuda.device_count() > 1:
            self.G = nn.DataParallel(self.G)
            self.D = nn.DataParallel(self.D)

        # print networks
        print(self.G)
        print(self.D)

    def reset_grad(self):
        """Zero the gradients of both optimizers.

        Both are cleared because the G update back-propagates through D, leaving
        gradients on D's parameters that must not leak into the next D update.
        """
        self.G_optimizer.zero_grad()
        self.D_optimizer.zero_grad()

    def get_real_samples(self):
        """Return the next ``(images, labels)`` batch moved to ``self.device``.

        The data iterator is restarted from ``self.dataloader`` when it is exhausted.

        Raises:
            RuntimeError: if the dataloader yields no batches even after a restart.
        """
        try:
            real_images, real_labels = next(self.data_iter)
        except StopIteration:
            # End of the current pass over the data: start a new one
            self.data_iter = iter(self.dataloader)
            try:
                real_images, real_labels = next(self.data_iter)
            except StopIteration:
                raise RuntimeError('dataloader yielded no batches') from None
        return real_images.to(self.device), real_labels.to(self.device)

    def compute_gradient_penalty(self, real_images, real_labels, fake_images):
        """WGAN-GP penalty ``mean((||grad_x D(x_hat, labels)||_2 - 1)^2)``.

        ``x_hat`` are per-sample random interpolations between ``real_images`` and
        ``fake_images``. The graph is kept (``create_graph=True``) so the penalty can
        itself be back-propagated into D.

        Args:
            real_images: real batch; the per-sample alpha of shape (N, 1, 1, 1) assumes
                4-D image tensors.
            real_labels: labels passed to D.
            fake_images: generated batch of the same shape as ``real_images``.

        Returns:
            Scalar tensor with the (un-weighted) gradient penalty.
        """
        batch_size = real_images.size(0)
        alpha = torch.rand(batch_size, 1, 1, 1, device=real_images.device, dtype=real_images.dtype)
        # Fresh leaf tensor so gradients are taken w.r.t. the interpolates themselves
        interpolated = (alpha * real_images + (1 - alpha) * fake_images).detach().requires_grad_(True)
        out = self.D(interpolated, real_labels)
        grad = torch.autograd.grad(outputs=out,
                                   inputs=interpolated,
                                   grad_outputs=torch.ones_like(out),
                                   retain_graph=True,
                                   create_graph=True)[0]
        grad = grad.reshape(batch_size, -1)
        grad_l2norm = torch.sqrt(torch.sum(grad ** 2, dim=1))
        d_loss_gp = torch.mean((grad_l2norm - 1) ** 2)
        return d_loss_gp