forked from Adversarial-Deep-Learning/code-soup
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmnist_gan.py
148 lines (132 loc) · 4.66 KB
/
mnist_gan.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import argparse
import torch
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.utils as vutils
from code_soup.ch5.datasets import MnistDataset
from code_soup.ch5.models import Discriminator, Generator
# Command-line interface for the trainer. The parsed values are exposed as
# module-level names (dataloader_batch_size, latent_dims, lr, epochs) that
# train_mnist_gan() reads.
parser = argparse.ArgumentParser(
    prog="mnist_gan.py", description="Train an MNIST GAN model"
)
parser.add_argument(
    "--batch_size",
    type=int,
    help="Specifies batch size of the GAN Trainer",
    default=64,
)
parser.add_argument(
    "--latent_dims",
    type=int,
    help="Specifies size of latent vectors for generating noise",
    default=128,
)
parser.add_argument(
    "--learning_rate",
    # The rate is fractional (default 0.0002). Declared as int, every value
    # supplied on the command line (e.g. "0.001") was rejected as an invalid
    # int, so only the default could ever be used.
    type=float,
    help="Specifies learning rate for training",
    default=0.0002,
)
parser.add_argument(
    "--epochs",
    type=int,
    help="Specifies number of epochs for training",
    default=200,
)

if __name__ == "__main__":
    # Run as a script: parse strictly so that typos in flags are reported.
    args = parser.parse_args()
else:
    # Imported by another program (e.g. a test runner): sys.argv belongs to
    # that program, so tolerate its unrelated arguments instead of exiting.
    # Any flags of ours that are present are still honoured.
    args, _ = parser.parse_known_args()

dataloader_batch_size = args.batch_size
latent_dims = args.latent_dims
lr = args.learning_rate
epochs = args.epochs
def train_mnist_gan():
    """Train a GAN on MNIST and write both networks' weights to disk.

    Takes no arguments: batch size, latent size, learning rate and epoch
    count are read from the module-level ``dataloader_batch_size``,
    ``latent_dims``, ``lr`` and ``epochs``. Every step first updates the
    discriminator on one real batch and one generated batch, then updates the
    generator against the just-updated discriminator. Statistics are printed
    every 50 steps, and the state dicts are saved to ``./discriminator.pth``
    and ``./generator.pth``.

    NOTE(review): the nesting of this function was reconstructed from a
    listing that had lost its indentation. Saving the weights once, after the
    final epoch, is assumed -- confirm against the upstream file.
    """
    # Input pipeline: convert to tensors, then normalise with mean/std 0.5.
    preprocessing = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ]
    )
    loader = torch.utils.data.DataLoader(
        MnistDataset(transform=preprocessing),
        batch_size=dataloader_batch_size,
        shuffle=True,
    )

    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    # Networks (generator is built first, then the discriminator).
    net_g = Generator(latent_dims).to(device)
    net_d = Discriminator(784).to(device)

    # One Adam optimiser per network, sharing the same learning rate.
    opt_d = optim.Adam(net_d.parameters(), lr=lr)
    opt_g = optim.Adam(net_g.parameters(), lr=lr)

    bce = torch.nn.BCELoss()

    # Batch of latent vectors intended for visualisation; it is not read
    # again anywhere in this function.
    fixed_noise = torch.randn(64, latent_dims, device=device)

    # Target values the discriminator is trained towards.
    real_label = 1.0
    fake_label = 0.0

    for epoch in range(epochs):
        for step, (real_batch, _) in enumerate(loader):
            # ---- Discriminator update ----
            net_d.zero_grad()
            real_batch = real_batch.to(device)
            n_samples = real_batch.size(0)
            real_targets = torch.full(
                (n_samples,), real_label, dtype=torch.float, device=device
            )
            fake_targets = torch.full(
                (n_samples,), fake_label, dtype=torch.float, device=device
            )

            # Real batch: gradients accumulate in the discriminator.
            real_scores = net_d(real_batch).view(-1)
            loss_d_real = bce(real_scores, real_targets)
            loss_d_real.backward()
            mean_real_score = real_scores.mean().item()

            # Generated batch: detach so this backward pass stops at the
            # discriminator and does not reach the generator.
            latent = torch.randn(n_samples, latent_dims, device=device)
            generated = net_g(latent)
            fake_scores = net_d(generated.detach()).view(-1)
            loss_d_fake = bce(fake_scores, fake_targets)
            loss_d_fake.backward()
            mean_fake_before = fake_scores.mean().item()

            # Total discriminator loss (only reported, not back-propagated).
            loss_d = loss_d_real + loss_d_fake
            opt_d.step()

            # ---- Generator update ----
            net_g.zero_grad()
            # Score the same generated batch again with the updated
            # discriminator; the generator is rewarded when its samples are
            # scored as real, hence the real targets.
            rescored = net_d(generated).view(-1)
            loss_g = bce(rescored, real_targets)
            loss_g.backward()
            mean_fake_after = rescored.mean().item()
            opt_g.step()

            # ---- Progress report ----
            if step % 50 == 0:
                stats = (
                    epoch + 1,
                    epochs,
                    step,
                    len(loader),
                    loss_d.item(),
                    loss_g.item(),
                    mean_real_score,
                    mean_fake_before,
                    mean_fake_after,
                )
                print(
                    "[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f"
                    % stats
                )

    # Persist the trained weights.
    torch.save(net_d.state_dict(), "./discriminator.pth")
    torch.save(net_g.state_dict(), "./generator.pth")
if __name__ == "__main__":
    # Start training only when the file is executed as a script.
    train_mnist_gan()