Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Tensorboard logging for scalars and images, add FastImageFolder for large folders #55

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions FastImageFolder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import os
import torch
import numpy as np
from PIL import Image
from skimage.io import imread
import torch.utils.data as data

def gray_2_rgb(im):
w, h = im.shape
ret = np.empty((w, h, 3), dtype=np.uint8)
ret[:, :, 0] = im
ret[:, :, 1] = im
ret[:, :, 2] = im
return ret

class FastImageFolder(data.Dataset):
def __init__(self,
root = '',
img_list = [],
transform = None):

self.root = root
self.img_list = img_list
self.transforms = transform

def __getitem__(self, index):
_ = 1
try:
im = imread(os.path.join(self.root,self.img_list[index]))
# control for grayscale images
if len(im.shape)==2:
im = gray_2_rgb(im)
# control for images with alpha channel or other weird stuff
if im.shape[2]>3:
print('Alpha channel image {}'.format(index))
im = im[:,:,0:3]

# transform image to PIL format for PyTorch transforms
im = Image.fromarray(im)

if self.transforms is not None:
im = self.transforms(im)
except:
print('Exception triggered with image {}'.format(index))
im = imread('random.jpg')

if im.shape[2]>3:
print('Alpha channel image {}'.format(index))
im = im[:,:,0:3]

# transform image to PIL format for PyTorch transforms
im = Image.fromarray(im)

if self.transforms is not None:
im = self.transforms(im)
return im,_

def __len__(self):
return len(self.img_list)
71 changes: 71 additions & 0 deletions TbLogger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# Code referenced from https://gist.github.com/gyglim/1f8dfb1b5c82627ae3efcfbbadb9f514
import tensorflow as tf
import numpy as np
import scipy.misc
try:
from StringIO import StringIO # Python 2.7
except ImportError:
from io import BytesIO # Python 3.x


class Logger(object):

def __init__(self, log_dir):
"""Create a summary writer logging to log_dir."""
self.writer = tf.summary.FileWriter(log_dir)

def scalar_summary(self, tag, value, step):
"""Log a scalar variable."""
summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)])
self.writer.add_summary(summary, step)

def image_summary(self, tag, images, step):
"""Log a list of images."""

img_summaries = []
for i, img in enumerate(images):
# Write the image to a string
try:
s = StringIO()
except:
s = BytesIO()
scipy.misc.toimage(img).save(s, format="png")

# Create an Image object
img_sum = tf.Summary.Image(encoded_image_string=s.getvalue(),
height=img.shape[0],
width=img.shape[1])
# Create a Summary value
img_summaries.append(tf.Summary.Value(tag='%s/%d' % (tag, i), image=img_sum))

# Create and write Summary
summary = tf.Summary(value=img_summaries)
self.writer.add_summary(summary, step)

def histo_summary(self, tag, values, step, bins=1000):
"""Log a histogram of the tensor of values."""

# Create a histogram using numpy
counts, bin_edges = np.histogram(values, bins=bins)

# Fill the fields of the histogram proto
hist = tf.HistogramProto()
hist.min = float(np.min(values))
hist.max = float(np.max(values))
hist.num = int(np.prod(values.shape))
hist.sum = float(np.sum(values))
hist.sum_squares = float(np.sum(values**2))

# Drop the start of the first bin
bin_edges = bin_edges[1:]

# Add bin edges and counts
for edge in bin_edges:
hist.bucket_limit.append(edge)
for c in counts:
hist.bucket.append(c)

# Create and write Summary
summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=hist)])
self.writer.add_summary(summary, step)
self.writer.flush()
73 changes: 73 additions & 0 deletions main.py
100644 → 100755
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from __future__ import print_function
# Custom tensorboard logging
from TbLogger import Logger

import argparse
import random
import torch
Expand All @@ -15,6 +18,8 @@

import models.dcgan as dcgan
import models.mlp as mlp
from FastImageFolder import FastImageFolder
import pickle

parser = argparse.ArgumentParser()
parser.add_argument('--dataset', required=True, help='cifar10 | lsun | imagenet | folder | lfw ')
Expand Down Expand Up @@ -43,9 +48,31 @@
parser.add_argument('--n_extra_layers', type=int, default=0, help='Number of extra layers on gen and disc')
parser.add_argument('--experiment', default=None, help='Where to store samples and models')
parser.add_argument('--adam', action='store_true', help='Whether to use adam (default is rmsprop)')
parser.add_argument('--tensorboard', action='store_true', help='Whether to use Tensorboard')
parser.add_argument('--tensorboard_images', action='store_true', help='Whether to use Tensorboard to diplay images')
parser.add_argument('--imgList', help='path to pre-processed image list for fast folder')
opt = parser.parse_args()
print(opt)


def to_np(x):
x = x.cpu().numpy()
if len(x.shape)>3:
return x[:,0:3,:,:]
else:
return x

# remove the log file if it exists if we run the script in the training mode
print('Folder {} delete triggered'.format(opt.experiment))
try:
shutil.rmtree('tb_logs/{}/'.format(opt.experiment))
except:
pass

# Set the Tensorboard logger
if opt.tensorboard or opt.tensorboard_images:
logger = Logger('./tb_logs/{}'.format(opt.experiment))

if opt.experiment is None:
opt.experiment = 'samples'
os.system('mkdir {0}'.format(opt.experiment))
Expand Down Expand Up @@ -85,6 +112,18 @@
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
)
elif opt.dataset == 'fastfolder':
# load the pre-processed list of images
with open(opt.imgList, 'rb') as handle:
img_list = pickle.load(handle)
# fast folder dataset
dataset = FastImageFolder(root=opt.dataroot,img_list = img_list,
transform=transforms.Compose([
transforms.Scale(opt.imageSize),
transforms.CenterCrop(opt.imageSize),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]))
assert dataset
dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batchSize,
shuffle=True, num_workers=int(opt.workers))
Expand Down Expand Up @@ -113,6 +152,7 @@ def weights_init(m):
netG = dcgan.DCGAN_G(opt.imageSize, nz, nc, ngf, ngpu, n_extra_layers)

netG.apply(weights_init)

if opt.netG != '': # load checkpoint if needed
netG.load_state_dict(torch.load(opt.netG))
print(netG)
Expand Down Expand Up @@ -149,6 +189,7 @@ def weights_init(m):
optimizerG = optim.RMSprop(netG.parameters(), lr = opt.lrG)

gen_iterations = 0
tb_steps = 0
for epoch in range(opt.niter):
data_iter = iter(dataloader)
i = 0
Expand Down Expand Up @@ -217,12 +258,44 @@ def weights_init(m):
print('[%d/%d][%d/%d][%d] Loss_D: %f Loss_G: %f Loss_D_real: %f Loss_D_fake %f'
% (epoch, opt.niter, i, len(dataloader), gen_iterations,
errD.data[0], errG.data[0], errD_real.data[0], errD_fake.data[0]))

#============ TensorBoard logging ============#
# Log the scalar values
if opt.tensorboard:
info = {
'Loss_D': errD.data[0],
'Loss_G': errG.data[0],
'Loss_D_real': errD_real.data[0],
'Loss_D_fake': errD_fake.data[0],
}
for tag, value in info.items():
logger.scalar_summary(tag, value, tb_steps)

tb_steps += 1

if gen_iterations % 500 == 0:
real_cpu = real_cpu.mul(0.5).add(0.5)
vutils.save_image(real_cpu, '{0}/real_samples.png'.format(opt.experiment))
fake = netG(Variable(fixed_noise, volatile=True))
fake.data = fake.data.mul(0.5).add(0.5)
vutils.save_image(fake.data, '{0}/fake_samples_{1}.png'.format(opt.experiment, gen_iterations))

#============ TensorBoard logging ============#
# Show real samples
if opt.tensorboard_images:
info = {
'real_samples': to_np(real_cpu.view(-1,opt.nc,opt.imageSize, opt.imageSize)[:50])
}
for tag, images in info.items():
logger.image_summary(tag, images, tb_steps)
# Show fake samples
if opt.tensorboard_images:
info = {
'fake_samples': to_np(fake.data.view(-1,opt.nc,opt.imageSize, opt.imageSize)[:50])
}
for tag, images in info.items():
logger.image_summary(tag, images, tb_steps)


# do checkpointing
torch.save(netG.state_dict(), '{0}/netG_epoch_{1}.pth'.format(opt.experiment, epoch))
Expand Down
Binary file added random.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.