Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pytorch-MNIST-GAN and pytorch-MNIST-DCGAN updates #6

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
247 changes: 132 additions & 115 deletions pytorch_MNIST_DCGAN.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import os, time
import os
import matplotlib.pyplot as plt
import itertools
import pickle
Expand All @@ -10,7 +10,6 @@
from torchvision import datasets, transforms
from torch.autograd import Variable

# G(z)
class generator(nn.Module):
# initializers
def __init__(self, d=128):
Expand All @@ -33,6 +32,7 @@ def weight_init(self, mean, std):
# forward method
def forward(self, input):
# x = F.relu(self.deconv1(input))
input = input.view(-1,100,1,1)
x = F.relu(self.deconv1_bn(self.deconv1(input)))
x = F.relu(self.deconv2_bn(self.deconv2(x)))
x = F.relu(self.deconv3_bn(self.deconv3(x)))
Expand Down Expand Up @@ -61,11 +61,14 @@ def weight_init(self, mean, std):

# forward method
def forward(self, input):
# input = input.view(-1,28,28)
# print (input.shape)
x = F.leaky_relu(self.conv1(input), 0.2)
x = F.leaky_relu(self.conv2_bn(self.conv2(x)), 0.2)
x = F.leaky_relu(self.conv3_bn(self.conv3(x)), 0.2)
x = F.leaky_relu(self.conv4_bn(self.conv4(x)), 0.2)
x = F.sigmoid(self.conv5(x))
temp = self.conv5(x)
x = F.sigmoid(temp)

return x

Expand All @@ -74,17 +77,23 @@ def normal_init(m, mean, std):
m.weight.data.normal_(mean, std)
m.bias.data.zero_()

fixed_z_ = torch.randn((5 * 5, 100)).view(-1, 100, 1, 1) # fixed noise
fixed_z_ = Variable(fixed_z_.cuda(), volatile=True)
# training parameters
batch_size = 128
lr = 0.0002
train_epoch = 100
PATH = './MNIST_DCGAN_results'
NOISE_SIZE = 100

fixed_z_ = torch.randn((5 * 5, NOISE_SIZE)) # fixed noise
fixed_z_ = Variable(fixed_z_.cuda())

def show_result(num_epoch, show = False, save = False, path = 'result.png', isFix=False):
z_ = torch.randn((5*5, 100)).view(-1, 100, 1, 1)
z_ = Variable(z_.cuda(), volatile=True)
z_ = torch.randn((5*5, NOISE_SIZE))
z_ = Variable(z_.cuda())

G.eval()
if isFix:
test_images = G(fixed_z_)
else:
test_images = G(z_)
test_images = G(fixed_z_) if isFix else G(z_)

G.train()

size_figure_grid = 5
Expand All @@ -97,10 +106,11 @@ def show_result(num_epoch, show = False, save = False, path = 'result.png', isFi
i = k // 5
j = k % 5
ax[i, j].cla()
ax[i, j].imshow(test_images[k, 0].cpu().data.numpy(), cmap='gray')
ax[i, j].imshow(test_images[k, :].cpu().data.view(64, 64).numpy(), cmap='gray')

label = 'Epoch {0}'.format(num_epoch)
fig.text(0.5, 0.04, label, ha='center')
print ("Saved to :",path)
plt.savefig(path)

if show:
Expand All @@ -117,7 +127,7 @@ def show_train_hist(hist, show = False, save = False, path = 'Train_hist.png'):
plt.plot(x, y1, label='D_loss')
plt.plot(x, y2, label='G_loss')

plt.xlabel('Iter')
plt.xlabel('Epoch')
plt.ylabel('Loss')

plt.legend(loc=4)
Expand All @@ -132,21 +142,19 @@ def show_train_hist(hist, show = False, save = False, path = 'Train_hist.png'):
else:
plt.close()

# training parameters
batch_size = 128
lr = 0.0002
train_epoch = 20

# data_loader
img_size = 64
transform = transforms.Compose([
transforms.Scale(img_size),
transforms.ToTensor(),
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])

print ("Loading data")
train_loader = torch.utils.data.DataLoader(
datasets.MNIST('data', train=True, download=True, transform=transform),
batch_size=batch_size, shuffle=True)
print ("Loaded data")

# network
G = generator(128)
Expand All @@ -156,6 +164,8 @@ def show_train_hist(hist, show = False, save = False, path = 'Train_hist.png'):
G.cuda()
D.cuda()


print ("Created models")
# Binary Cross Entropy loss
BCE_loss = nn.BCELoss()

Expand All @@ -164,100 +174,107 @@ def show_train_hist(hist, show = False, save = False, path = 'Train_hist.png'):
D_optimizer = optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))

# results save folder
if not os.path.isdir('MNIST_DCGAN_results'):
os.mkdir('MNIST_DCGAN_results')
if not os.path.isdir('MNIST_DCGAN_results/Random_results'):
os.mkdir('MNIST_DCGAN_results/Random_results')
if not os.path.isdir('MNIST_DCGAN_results/Fixed_results'):
os.mkdir('MNIST_DCGAN_results/Fixed_results')

train_hist = {}
train_hist['D_losses'] = []
train_hist['G_losses'] = []
train_hist['per_epoch_ptimes'] = []
train_hist['total_ptime'] = []
num_iter = 0

print('training start!')
start_time = time.time()
for epoch in range(train_epoch):
D_losses = []
G_losses = []
epoch_start_time = time.time()
for x_, _ in train_loader:
# train discriminator D
D.zero_grad()

mini_batch = x_.size()[0]

y_real_ = torch.ones(mini_batch)
y_fake_ = torch.zeros(mini_batch)

x_, y_real_, y_fake_ = Variable(x_.cuda()), Variable(y_real_.cuda()), Variable(y_fake_.cuda())
D_result = D(x_).squeeze()
D_real_loss = BCE_loss(D_result, y_real_)

z_ = torch.randn((mini_batch, 100)).view(-1, 100, 1, 1)
z_ = Variable(z_.cuda())
G_result = G(z_)

D_result = D(G_result).squeeze()
D_fake_loss = BCE_loss(D_result, y_fake_)
D_fake_score = D_result.data.mean()

D_train_loss = D_real_loss + D_fake_loss

D_train_loss.backward()
D_optimizer.step()

# D_losses.append(D_train_loss.data[0])
D_losses.append(D_train_loss.data[0])

# train generator G
G.zero_grad()

z_ = torch.randn((mini_batch, 100)).view(-1, 100, 1, 1)
z_ = Variable(z_.cuda())

G_result = G(z_)
D_result = D(G_result).squeeze()
G_train_loss = BCE_loss(D_result, y_real_)
G_train_loss.backward()
G_optimizer.step()

G_losses.append(G_train_loss.data[0])

num_iter += 1

epoch_end_time = time.time()
per_epoch_ptime = epoch_end_time - epoch_start_time


print('[%d/%d] - ptime: %.2f, loss_d: %.3f, loss_g: %.3f' % ((epoch + 1), train_epoch, per_epoch_ptime, torch.mean(torch.FloatTensor(D_losses)),
torch.mean(torch.FloatTensor(G_losses))))
p = 'MNIST_DCGAN_results/Random_results/MNIST_DCGAN_' + str(epoch + 1) + '.png'
fixed_p = 'MNIST_DCGAN_results/Fixed_results/MNIST_DCGAN_' + str(epoch + 1) + '.png'
show_result((epoch+1), save=True, path=p, isFix=False)
show_result((epoch+1), save=True, path=fixed_p, isFix=True)
train_hist['D_losses'].append(torch.mean(torch.FloatTensor(D_losses)))
train_hist['G_losses'].append(torch.mean(torch.FloatTensor(G_losses)))
train_hist['per_epoch_ptimes'].append(per_epoch_ptime)

end_time = time.time()
total_ptime = end_time - start_time
train_hist['total_ptime'].append(total_ptime)

print("Avg per epoch ptime: %.2f, total %d epochs ptime: %.2f" % (torch.mean(torch.FloatTensor(train_hist['per_epoch_ptimes'])), train_epoch, total_ptime))
print("Training finish!... save training results")
torch.save(G.state_dict(), "MNIST_DCGAN_results/generator_param.pkl")
torch.save(D.state_dict(), "MNIST_DCGAN_results/discriminator_param.pkl")
with open('MNIST_DCGAN_results/train_hist.pkl', 'wb') as f:
pickle.dump(train_hist, f)

show_train_hist(train_hist, save=True, path='MNIST_DCGAN_results/MNIST_DCGAN_train_hist.png')

images = []
for e in range(train_epoch):
img_name = 'MNIST_DCGAN_results/Fixed_results/MNIST_DCGAN_' + str(e + 1) + '.png'
images.append(imageio.imread(img_name))
imageio.mimsave('MNIST_DCGAN_results/generation_animation.gif', images, fps=5)
if not os.path.isdir(PATH):
os.mkdir(PATH)
if not os.path.isdir(PATH+'/Random_results'):
os.mkdir(PATH+'/Random_results')
if not os.path.isdir(PATH+'/Fixed_results'):
os.mkdir(PATH+'/Fixed_results')

def train_discriminator(input,mini_batch_size):
D.zero_grad()

y_real = torch.ones(mini_batch_size) # D(real) = 1
y_fake = torch.zeros(mini_batch_size) # D(fake) = 0
y_real, y_fake = Variable(y_real.cuda()), Variable(y_fake.cuda())

# Calculate loss for real sample
# x = input.view(-1, 28, 28)
x = Variable(input.cuda())

D_real_result = D(x).view((-1))
D_real_loss = BCE_loss(D_real_result, y_real)

# Calculate loss for generated sample
z = torch.randn((mini_batch_size, NOISE_SIZE))
z = Variable(z.cuda())
G_result = G(z) # Generator's result

D_fake_result = D(G_result)
D_fake_loss = BCE_loss(D_fake_result, y_fake)

# Calculating total loss
D_train_loss = D_real_loss + D_fake_loss

# Propogate loss backwards and return loss
D_train_loss.backward()
D_optimizer.step()
return D_train_loss.item()

def train_generator(mini_batch_size):
G.zero_grad()

# Generate z with random values
z = torch.randn((mini_batch_size, NOISE_SIZE))
y = torch.ones(mini_batch_size) # Attempting to be real
z, y = Variable(z.cuda()), Variable(y.cuda())

# Calculate loss for generator
# Comparing discriminator's prediction with ones (ie, real)
G_result = G(z)
D_result = D(G_result).view((-1))
G_train_loss = BCE_loss(D_result, y)

# Propogate loss backwards and return loss
G_train_loss.backward()
G_optimizer.step()
return G_train_loss.item()

def save_models(train_hist):
torch.save(G.state_dict(), PATH+"/generator_param_"+str(train_epoch)+".pkl")
torch.save(D.state_dict(), PATH +"/discriminator_param_"+str(train_epoch)+".pkl")

with open(PATH+'/train_hist.pkl', 'wb') as f:
pickle.dump(train_hist, f)

def save_gif():
images = []
for e in range(train_epoch):
img_name = PATH+'/Fixed_results/MNIST_GAN_' + str(e + 1) + '.png'
images.append(imageio.imread(img_name))
imageio.mimsave(PATH+'/generation_animation.gif', images, fps=5)

def train():
train_hist = {'D_losses':[],'G_losses':[]}
for epoch in range(train_epoch):
D_losses = []
G_losses = []
iter_lim = 0
for x_, _ in train_loader:
iter_lim+=1
if iter_lim == 20:
break
mini_batch_size = x_.size()[0]
D_loss = train_discriminator(x_,mini_batch_size)
D_losses.append(D_loss)

G_loss = train_generator(mini_batch_size)
G_losses.append(G_loss)
print('[%d/%d]: loss_d: %.3f, loss_g: %.3f' % (
(epoch + 1), train_epoch, torch.mean(torch.FloatTensor(D_losses)), torch.mean(torch.FloatTensor(G_losses))))
p = PATH+'/Random_results/MNIST_DCGAN_' + str(epoch + 1) + '.png'
fixed_p = PATH+'/Fixed_results/MNIST_DCGAN_' + str(epoch + 1) + '.png'
show_result((epoch+1), save=True, path=p, isFix=False)
show_result((epoch+1), save=True, path=fixed_p, isFix=True)

train_hist['D_losses'].append(torch.mean(torch.FloatTensor(D_losses)))
train_hist['G_losses'].append(torch.mean(torch.FloatTensor(G_losses)))


print("Training complete. Saving.")
save_models(train_hist)
show_train_hist(train_hist, save=True, path=PATH+'/MNIST_GAN_train_hist.png')
save_gif()

if __name__ == '__main__':
train()
Loading