-
Notifications
You must be signed in to change notification settings - Fork 21
/
cgan.py
164 lines (131 loc) · 5.49 KB
/
cgan.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
#coding:utf-8
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision.utils import save_image
import numpy as np
import os
# Hyperparameters and runtime configuration.
# Set gpu_id to a GPU index (e.g. '0' or 0) to train on that GPU; None trains
# on the CPU.
gpu_id = None
if gpu_id is not None:
    # os.environ only accepts str values, so convert in case an int is given.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id)
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
# Output directory for the per-epoch sample grids. exist_ok=True replaces the
# separate exists-check and avoids a check-then-create race.
os.makedirs('cgan_images', exist_ok=True)
z_dim = 100  # size of the generator's noise input
batch_size = 64
learning_rate = 0.0002
total_epochs = 200
class Discriminator(nn.Module):
    """Fully connected conditional discriminator for 1x28x28 MNIST images.

    Takes an image batch together with its one-hot class conditions and
    returns, per sample, a Sigmoid score in (0, 1) (trained towards 1 for
    real images and 0 for generated ones).
    """

    def __init__(self):
        super().__init__()
        # Input width: flattened 28*28 image plus a 10-way one-hot label.
        self.model = nn.Sequential(
            nn.Linear(28 * 28 + 10, 512, bias=True),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256, bias=True),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1, bias=True),
            nn.Sigmoid(),
        )

    def forward(self, x, c):
        """Score images ``x`` under conditions ``c``; returns shape (N, 1)."""
        flat_images = x.view(x.size(0), -1)
        joint_input = torch.cat((flat_images, c), dim=-1)
        return self.model(joint_input)
class Generator(nn.Module):
    """Fully connected conditional generator for 1x28x28 MNIST images.

    The input is a noise batch ``z`` (width ``z_dim``) concatenated with 10-way
    one-hot class conditions ``c``. The output is an image batch of shape
    (N, 1, 28, 28) with values in [-1, 1] from the final Tanh.
    """

    def __init__(self, z_dim):
        super().__init__()
        # NOTE(review): BatchNorm1d's second positional parameter is eps, so the
        # original `BatchNorm1d(n, 0.8)` means eps=0.8, not momentum=0.8. It is
        # written explicitly here to preserve behavior; confirm whether
        # momentum was intended.
        self.model = nn.Sequential(
            nn.Linear(z_dim + 10, 128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(128, 256),
            nn.BatchNorm1d(256, eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 512),
            nn.BatchNorm1d(512, eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 28 * 28),
            nn.Tanh(),
        )

    def forward(self, z, c):
        """Generate images conditioned on ``c``; returns shape (N, 1, 28, 28)."""
        flat_pixels = self.model(torch.cat((z, c), dim=1))
        return flat_pixels.view(-1, 1, 28, 28)
def one_hot(labels, class_num):
    """Convert integer class labels into one-hot float vectors.

    Args:
        labels: integer tensor of ``labels.size(0)`` class indices. It is
            flattened with ``reshape(-1, 1)``, so shape (N,) or (N, 1) works.
        class_num: number of classes, i.e. the width of each one-hot row.

    Returns:
        A tensor of the default float dtype with shape (N, class_num), on the
        same device as ``labels``, holding a single 1 per row at the label's
        index.
    """
    # Allocate on labels' device and cast with .long(). The legacy
    # torch.LongTensor(tensor) constructor only accepts CPU int64 data.
    encoded = torch.zeros(labels.size(0), class_num, device=labels.device)
    return encoded.scatter_(dim=1, index=labels.reshape(-1, 1).long(), value=1)
# Build the discriminator and generator on the selected device.
discriminator = Discriminator().to(device)
generator = Generator(z_dim=z_dim).to(device)
# Binary cross-entropy loss plus constant real (1) / fake (0) targets.
# The discriminator outputs shape (batch_size, 1), and BCELoss requires the
# target to match the input shape, so the targets are (batch_size, 1) as well.
bce = torch.nn.BCELoss()
ones = torch.ones(batch_size, 1, device=device)
zeros = torch.zeros(batch_size, 1, device=device)
# Adam optimizers with beta1 = 0.5 for both networks.
g_optimizer = optim.Adam(generator.parameters(), lr=learning_rate, betas=(0.5, 0.999))
d_optimizer = optim.Adam(discriminator.parameters(), lr=learning_rate, betas=(0.5, 0.999))
# MNIST training set, mapped from [0, 1] to [-1, 1] to match the generator's
# Tanh range. MNIST images have one channel, so Normalize gets one-element
# mean/std tuples (three-element tuples fail to broadcast onto 1x28x28).
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
dataset = torchvision.datasets.MNIST(root='data/', train=True, transform=transform, download=True)
# drop_last=True keeps every batch at exactly batch_size, matching ones/zeros.
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)
# Fixed inputs reused every epoch, so the saved sample grids are comparable.
# 100 one-hot labels cycle through classes 0..9 (row k has class k % 10), i.e.
# ten samples per class.
fixed_c = torch.eye(10).repeat(10, 1).to(device)
# 100 fixed Gaussian noise vectors.
fixed_z = torch.randn(100, z_dim, device=device)
# Train for total_epochs epochs. After each epoch, save a grid of samples
# generated from the fixed noise/label inputs.
for epoch in range(total_epochs):
    # Training phase: the generator's BatchNorm layers use batch statistics.
    generator.train()
    for i, (real_images, real_labels) in enumerate(dataloader):
        # Real batch, with its labels encoded as one-hot condition vectors.
        real_images = real_images.to(device)
        real_labels = one_hot(real_labels, 10).to(device)

        # Fake batch: Gaussian noise plus uniformly sampled class conditions.
        z = torch.randn(batch_size, z_dim, device=device)
        c = one_hot(torch.randint(0, 10, (batch_size,)), 10).to(device)
        fake_images = generator(z, c)

        # Discriminator step: real -> 1, fake -> 0. detach() stops this loss
        # from back-propagating into the generator.
        real_loss = bce(discriminator(real_images, real_labels), ones)
        fake_loss = bce(discriminator(fake_images.detach(), c), zeros)
        d_loss = real_loss + fake_loss
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        # Generator step: push the just-updated discriminator towards
        # scoring the fakes as real.
        g_loss = bce(discriminator(fake_images, c), ones)
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

        print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]" % (epoch, total_epochs, i, len(dataloader), d_loss.item(), g_loss.item()))

    # Evaluation phase: BatchNorm switches to running statistics, and
    # no_grad() avoids building an autograd graph just to render samples.
    generator.eval()
    with torch.no_grad():
        fixed_fake_images = generator(fixed_z, fixed_c)
    save_image(fixed_fake_images, 'cgan_images/{}.png'.format(epoch), nrow=10, normalize=True)