-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
146 lines (127 loc) · 7.45 KB
/
train.py
File metadata and controls
146 lines (127 loc) · 7.45 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import argparse
import torch
import torch.nn as nn
import torchvision.datasets as dset
import torchvision.transforms as transforms
from generator import CondGenerator
from discriminator import CondDiscriminator
import utils
"""
Code is based on the tutorial at:
https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html
authored by Nathan Inkawhich (https://github.com/inkawhich)
"""
# Constants
image_size = 64
batch_size = 128
workers = 2
latent_dims = 100
ngf = 64 # Parameter for number of feature maps in generator
ndf = 64 # Parameter for number of feature maps in discriminator
def get_data_loader(root=None):
    """Build a DataLoader over the full CelebA dataset.

    Images are resized and center-cropped to ``image_size`` and normalized
    per channel from [0, 1] to [-1, 1], the standard DCGAN input range
    matching a tanh generator output.

    Args:
        root: Directory of the CelebA dataset. Defaults to the module-level
            ``args.celeba_loc`` (parsed CLI arguments), preserving the
            original zero-argument call sites. Passing it explicitly removes
            the implicit dependency on the script-only global ``args``.

    Returns:
        A ``torch.utils.data.DataLoader`` yielding ``(images, attributes)``
        batches of size ``batch_size``, shuffled, with ``workers`` workers.
    """
    if root is None:
        # Backward-compatible fallback: the original implementation read the
        # CLI namespace as a global. This still fails if the module is
        # imported and called before argparse runs — pass `root` explicitly.
        root = args.celeba_loc
    dataset = dset.CelebA(root=root,
                          split="all",
                          download=False,
                          transform=transforms.Compose([
                              transforms.Resize(image_size),
                              transforms.CenterCrop(image_size),
                              transforms.ToTensor(),
                              # Map [0,1] -> [-1,1] per channel
                              transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                          ]))
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                       shuffle=True, num_workers=workers)
def train(args):
    """Run the conditional DCGAN training loop.

    Standard alternating DCGAN scheme per batch: the discriminator D is
    updated on real images (target label 1) and on detached generator
    outputs (target label 0); the generator G is then updated to push D's
    prediction on those same fakes toward 1. Both networks are conditioned
    on the 40 CelebA attribute annotations. A checkpoint is saved after
    every epoch via ``utils.save_checkpoint``.

    Args:
        args: Parsed CLI namespace; this function reads ``device``,
            ``resume_from``, and ``num_epochs``.
    """
    # Get device and dataloader
    device = torch.device(args.device)
    dataloader = get_data_loader()
    # BCE over D's scalar real/fake probability output
    criterion = nn.BCELoss()
    # Load checkpoint if needed; epoch numbering is 1-based
    start_epoch = 1
    if args.resume_from is not None:
        start_epoch = args.resume_from + 1
        epoch, D, G, optimizerD, optimizerG = utils.load_checkpoint(f"checkpoints/checkpoint{args.resume_from}.pt", device,ndf=ndf,ngf=ngf,latent_dims=latent_dims)
        print(f"Resuming from checkpoint after epoch: {args.resume_from}")
    else: # Get new networks and optimizers
        G = CondGenerator(ngf, latent_dims).to(device)
        D = CondDiscriminator(ndf).to(device)
        optimizerG = G.get_optimizer()
        optimizerD = D.get_optimizer()
    print("Starting training loop...")
    for epoch in range(start_epoch, start_epoch+args.num_epochs):
        for i, data in enumerate(dataloader):
            # Format batch
            real_imgs = data[0].to(device) # Batch of training images
            real_annots = data[1].type(torch.float).to(device) # 40 annotations for each image
            curr_batch_size = real_imgs.size(0) # Last batch may be smaller than batch_size
            labels = torch.full((curr_batch_size,), 1, dtype=torch.float, device=device) # Indicate real as 1 label
            # --- Discriminator update: gradient of D for real images ---
            D.zero_grad() # Initialize the gradient of D to 0
            output = D(real_imgs, real_annots).view(-1) # Predict real or fake for real images
            D_x = output.mean().item() # Mean D prediction on reals (ideally near 1)
            errD_real = criterion(output, labels)
            errD_real.backward() # Backpropagate loss for D over real images
            # --- Discriminator update: gradient of D for fake images ---
            latent_vector = torch.randn(curr_batch_size, latent_dims, 1, 1, device=device) # Generate noise
            fake_imgs = G(latent_vector, real_annots) # Use real annotations, randomly generated annotations aren't realistic
            # detach() blocks gradients from flowing into G during D's update
            output = D(fake_imgs.detach(), real_annots).view(-1) # Predict real or fake for fake images
            D_G_z1 = output.mean().item() # Mean D prediction on fakes before D steps (ideally near 0)
            labels.fill_(0) # Indicate false as 0 label
            errD_fake = criterion(output, labels)
            errD_fake.backward() # Backpropagate loss for D over fake images
            # Single step on the accumulated (real + fake) gradients
            optimizerD.step()
            # --- Generator update ---
            G.zero_grad() # Initialize the gradient of G to 0
            labels.fill_(1) # Generator wants discriminator to yield 1 for fake images
            # Re-run D (now updated) on the non-detached fakes so gradients reach G
            output = D(fake_imgs, real_annots).view(-1) # Predict real or fake for fake images
            D_G_z2 = output.mean().item() # Mean D prediction on fakes after D's step
            errG = criterion(output, labels)
            errG.backward() # Backpropagate loss for G over fake images
            optimizerG.step() # Make gradient step
            # Output training stats every 50 batches
            if i % 50 == 0:
                print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                      % (epoch, start_epoch+args.num_epochs-1, i, len(dataloader),
                         (errD_fake+errD_real).item(), errG.item(), D_x, D_G_z1, D_G_z2))
        # Save checkpoint for each epoch
        utils.save_checkpoint(epoch, D, G, optimizerD, optimizerG)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Conditional DCGAN')
parser.add_argument("--celeba_loc", default="./images", type=str, help="Directory of CelebA dataset.")
parser.add_argument("--display", default=-1, type=int, help="Epoch of checkpoint to show batch of images for.")
parser.add_argument("--device", default="cuda:0" if torch.cuda.is_available() else "cpu", type=str, help="'cuda:index' or 'cpu'")
parser.add_argument("--resume_from", type=int, help="Epoch of checkpoint to resume training from.")
parser.add_argument("--num_epochs", default=5, type=int, help="Number of epochs to continue training.")
parser.add_argument("--write_images", type=int, help="Write images to the directory ./generated/, using given checkpoint index")
parser.add_argument("--num_to_write", type=int, default=batch_size, help="Number of images to write for write_images")
args = parser.parse_args() # Get command-line arguments
if args.display >= 0:
# Display a batch of images
device = torch.device(args.device)
epoch, D, G, optimizerD, optimizerG = utils.load_checkpoint(f"checkpoints/checkpoint{args.display}.pt", device,ndf=ndf,ngf=ngf,latent_dims=latent_dims)
dataloader = get_data_loader()
with torch.no_grad():
real_annot = next(iter(dataloader))[1].type(torch.float).to(device) # Retrieve some real annotions
latent_vector = torch.randn(batch_size, latent_dims, 1, 1, device=device) # Generate noise
fake_imgs = G(latent_vector, real_annot).detach().cpu() # Evaluate generated images
utils.show_images(fake_imgs, title=f"Batch of Fake Images After {epoch} Epochs")
elif args.write_images is not None:
# Write args.num_to_write images to files in ./generated/
device = torch.device(args.device)
epoch, D, G, optimizerD, optimizerG = utils.load_checkpoint(f"checkpoints/checkpoint{args.write_images}.pt", device,ndf=ndf,ngf=ngf,latent_dims=latent_dims)
dataloader = get_data_loader()
with torch.no_grad():
images_generated = 0
while images_generated < args.num_to_write:
for i, data in enumerate(dataloader):
real_annots = data[1].type(torch.float).to(device) # 40 annotations for each image
latent_vector = torch.randn(batch_size, latent_dims, 1, 1, device=device) # Generate noise
fake_imgs = G(latent_vector, real_annots).detach().cpu() # Evaluate generated images
utils.write_images(fake_imgs, images_generated)
images_generated += batch_size
if images_generated >= args.num_to_write: # Write num_to_write images in total
break
else:
print(f"Command-line args: {args}")
train(args) # Start training loop