Adding Tensorboard #53

Open · wants to merge 3 commits into master
12 changes: 12 additions & 0 deletions .gitignore
@@ -0,0 +1,12 @@
__pycache__

../preprocessing_mask_celebA_hq.py
img_align_celeba
kaggle.json
test.ipynb
models/*
samples/*
attn/*
logs/*
data/*
env
18 changes: 13 additions & 5 deletions README.md
@@ -10,7 +10,7 @@ Self-attentions are applied to later two layers of both discriminator and genera

## Current update status
* [ ] Supervised setting
* [x] Tensorboard loggings
* [x] **[20180608] updated the self-attention module. Thanks to my colleague [Cheonbok Park](https://github.com/cheonbok94)! see 'sagan_models.py' for the update. Should be efficient, and run on large sized images**
* [x] Attention visualization (LSUN Church-outdoor)
* [x] Unsupervised setting (use no label yet)
@@ -35,6 +35,8 @@ Per-pixel attention result of SAGAN on LSUN church-outdoor dataset. It shows tha
## Prerequisites
* [Python 3.5+](https://www.continuum.io/downloads)
* [PyTorch 0.3.0](http://pytorch.org/)
#### For Tensorboard
* [Tensorboard 2.0+](https://www.tensorflow.org/tensorboard)

 

@@ -47,11 +49,10 @@ $ cd Self-Attention-GAN
```

#### 2. Install datasets (CelebA or LSUN)
```bash
$ bash download.sh CelebA
or
$ bash download.sh LSUN
```
Save all datasets in the `data` directory.
* [CelebA](https://drive.google.com/drive/folders/0B7EVK8r0v71pTUZsaXdaSnZBZzg)


#### 3. Train
@@ -61,6 +62,13 @@ $ python main.py --batch_size 64 --imsize 64 --dataset celeb --adv_loss hinge --version sagan_celeb
or
$ python main.py --batch_size 64 --imsize 64 --dataset lsun --adv_loss hinge --version sagan_lsun
```
##### (ii) With Tensorboard
```bash
$ python main.py --batch_size 64 --imsize 64 --dataset celeb --adv_loss hinge --version sagan_celeb --use_tensorboard
or
$ python main.py --batch_size 64 --imsize 64 --dataset lsun --adv_loss hinge --version sagan_lsun --use_tensorboard
```
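
To inspect the logged runs, point TensorBoard at the log directory. A minimal sketch, assuming logs are written under `./logs` (adjust the path to your configured log directory):

```bash
$ tensorboard --logdir ./logs --port 6006
```

Then open http://localhost:6006 in a browser to browse the scalar and image tabs.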

#### 4. Enjoy the results
```bash
$ cd samples/sagan_celeb
58 changes: 54 additions & 4 deletions data_loader.py
@@ -1,8 +1,40 @@
import os
import glob
import torch
from PIL import Image
from os.path import join
from os.path import basename
from torchvision import transforms
from torch.utils.data import Dataset
import torchvision.datasets as dsets


class DataLoaderSegmentation(Dataset):
    def __init__(self, path, img_transform=None, mask_transform=None):
        super(DataLoaderSegmentation, self).__init__()
        self.path = path
        self.img_transform = img_transform
        self.mask_transform = mask_transform
        self.img_files = glob.glob(join(self.path, 'CelebA-HQ-img', '*.jpg'))
        self.mask_files = []
        for img_path in self.img_files:
            # Derive the mask filename from the numeric image id (e.g. 123.jpg -> mask/123.png)
            img_id = int(os.path.splitext(basename(img_path))[0])
            self.mask_files.append(join(self.path, 'mask', '{}.png'.format(img_id)))

    def __getitem__(self, index):
        img_path = self.img_files[index]
        mask_path = self.mask_files[index]
        image = Image.open(img_path)
        mask = Image.open(mask_path)
        if self.img_transform:
            image = self.img_transform(image)
        if self.mask_transform:
            mask = self.mask_transform(mask)
        return image, mask

    def __len__(self):
        return len(self.mask_files)

class Data_Loader():
    def __init__(self, train, dataset, image_path, image_size, batch_size, shuf=True):
        self.dataset = dataset
@@ -31,17 +63,35 @@ def load_lsun(self, classes='church_outdoor_train'):
        return dataset

    def load_celeb(self):
        path = join(self.path, 'CelebA')
        transforms = self.transform(True, True, True, True)
        dataset = dsets.ImageFolder(root=path, transform=transforms)
        return dataset

    def load_imagenet(self):
        if self.train:
            path = join(self.path, 'train')
        else:
            path = self.path
        transforms = self.transform(True, True, True, False)
        dataset = dsets.ImageFolder(root=path, transform=transforms)
        return dataset

    def load_celebA_semantic(self):
        img_transform = self.transform(True, True, True, False)
        mask_transform = self.transform(True, True, False, False)
        dataset = DataLoaderSegmentation(self.path, img_transform, mask_transform)  # was an undefined `path`
        return dataset


    def loader(self):
        if self.dataset == 'lsun':
            dataset = self.load_lsun()
        elif self.dataset == 'celeb':
            dataset = self.load_celeb()
        elif self.dataset == 'imagenet':
            dataset = self.load_imagenet()
        elif self.dataset == 'semantic':
            dataset = self.load_celebA_semantic()  # was self.load_custom(), which does not exist
        loader = torch.utils.data.DataLoader(dataset=dataset,
                                             batch_size=self.batch,
                                             shuffle=self.shuf,
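
For reviewers, a minimal usage sketch of the loader above. This is illustrative, not part of the PR; it assumes the CelebA images live under `./data/CelebA`:

```python
# Hypothetical usage of Data_Loader (paths are assumptions)
from data_loader import Data_Loader

dl = Data_Loader(train=True, dataset='celeb', image_path='./data',
                 image_size=64, batch_size=64, shuf=True)
loader = dl.loader()
images, labels = next(iter(loader))  # images: a (64, 3, 64, 64) batch
```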
47 changes: 47 additions & 0 deletions imagenet_script.py
@@ -0,0 +1,47 @@
import os
import pickle
from PIL import Image
import numpy as np


def unpickle(file):
    # Note: depending on how the downsampled-ImageNet batches were pickled,
    # pickle.load(fo, encoding='latin1') may be required under Python 3.
    with open(file, 'rb') as fo:
        data = pickle.load(fo)
    return data


def numpytoimage(data_path='data/imagenet/train/train_data_batch_1', output_dir='data/imagenet/train'):
    d = unpickle(data_path)
    x = d['data']
    y = d['labels']
    mean_image = d['mean']  # kept for reference; not applied here
    img_size = int((x.shape[1] / 3) ** 0.5)
    data_size = x.shape[0]
    img_size2 = img_size ** 2
    # Channels are stored contiguously (R, G, B); stack them into HWC images.
    x = np.dstack((x[:, :img_size2], x[:, img_size2:2 * img_size2], x[:, 2 * img_size2:]))
    x = x.reshape((x.shape[0], img_size, img_size, 3))
    # Save each image into a per-class folder.
    for i in range(data_size):
        class_count = y[i]
        folder_path = output_dir + '/{}'.format(class_count)
        if not os.path.isdir(folder_path):
            os.mkdir(folder_path)
        file_count = len(os.listdir(folder_path))
        file_name = output_dir + '/{}/{}.png'.format(class_count, file_count + 1)
        image = x[i, :, :, :]
        image = (255.0 / image.max() * (image - image.min())).astype(np.uint8)
        im = Image.fromarray(image)
        im.save(file_name)
        if i % 10000 == 0:
            print(file_name)
            print(data_path)


def main():
    for i in range(1, 11):
        data_path = 'data/imagenet/train/train_data_batch_{}'.format(i)
        final_dir = 'data/imagenet/train'
        numpytoimage(data_path, final_dir)


if __name__ == "__main__":
    main()
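
A sketch of converting a single batch, assuming the downsampled-ImageNet pickle batches are already in place (paths are the script's defaults):

```python
# Hypothetical one-batch run of imagenet_script.py
from imagenet_script import numpytoimage

numpytoimage('data/imagenet/train/train_data_batch_1', 'data/imagenet/train')
# Writes data/imagenet/train/<label>/<n>.png, one folder per class, which is
# the ImageFolder layout that data_loader.load_imagenet() expects.
```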
20 changes: 20 additions & 0 deletions logger.py
@@ -0,0 +1,20 @@
import torch
import numpy as np
# import matplotlib.pyplot as plt
from torchvision.utils import make_grid
from torch.utils.tensorboard import SummaryWriter


def Logger(path: str):
    logger = SummaryWriter(path)
    return logger


def show_img(img):
    img = img / 2 + 0.5  # unnormalize from [-1, 1] to [0, 1]
    npimg = img.cpu().numpy()
    # plt.imshow(np.transpose(npimg, (1, 2, 0)))
    return npimg


def image_grid_writer(writer, data, name, step, nrow=8):
    image_grid = make_grid(data, nrow=nrow)
    image_grid = show_img(image_grid)
    writer.add_image(name, image_grid, step)
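
A minimal sketch of how these helpers combine; the run name, step, and tensor below are illustrative:

```python
# Hypothetical usage of logger.py
import torch
from logger import Logger, image_grid_writer

writer = Logger('./logs/sagan_celeb')               # a SummaryWriter under the hood
writer.add_scalar('d_loss', 0.42, 100)              # scalar logged at step 100
fake = torch.randn(16, 3, 64, 64).clamp(-1, 1)      # stand-in for a generated batch
image_grid_writer(writer, fake, 'fake_image', 100)  # 8-per-row image grid
writer.close()
```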
4 changes: 3 additions & 1 deletion main.py
@@ -1,4 +1,4 @@

import os
from parameter import *
from trainer import Trainer
# from tester import Tester
@@ -33,6 +33,8 @@ def main(config):
        tester.test()

if __name__ == '__main__':
    os.environ['CUDA_LAUNCH_BLOCKING'] = '1'   # synchronous CUDA errors for easier debugging
    os.environ['CUDA_VISIBLE_DEVICES'] = '6'   # hard-pins the run to GPU 6; adjust per machine
    config = get_parameters()
    print(config)
    main(config)
251 changes: 251 additions & 0 deletions newGenrator.py
@@ -0,0 +1,251 @@
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from spectral import SpectralNorm
from torch.autograd import Variable


class Self_Attn(nn.Module):
    """ Self attention Layer"""
    def __init__(self, in_dim, activation):
        super(Self_Attn, self).__init__()
        self.chanel_in = in_dim
        self.activation = activation

        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """
        inputs :
            x : input feature maps (B X C X W X H)
        returns :
            out : self attention value + input feature
            attention: B X N X N (N is Width*Height)
        """
        m_batchsize, C, width, height = x.size()
        proj_query = self.query_conv(x).view(m_batchsize, -1, width * height).permute(0, 2, 1)  # B X N X C'
        proj_key = self.key_conv(x).view(m_batchsize, -1, width * height)  # B X C' X N
        energy = torch.bmm(proj_query, proj_key)  # B X N X N
        attention = self.softmax(energy)
        proj_value = self.value_conv(x).view(m_batchsize, -1, width * height)  # B X C X N

        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(m_batchsize, C, width, height)

        out = self.gamma * out + x
        return out, attention
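
For orientation while reviewing, a quick shape sketch of `Self_Attn` (illustrative only, not part of the file):

```python
# Hypothetical shape check for Self_Attn
import torch
from newGenrator import Self_Attn

attn = Self_Attn(in_dim=64, activation='relu')
x = torch.randn(2, 64, 16, 16)   # B x C x W x H
out, attention = attn(x)
print(out.shape)                 # torch.Size([2, 64, 16, 16])
print(attention.shape)           # torch.Size([2, 256, 256]); N = 16 * 16
```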


class baseGenBlock(nn.Module):
    """Base block for the generator: spectral-norm ConvTranspose2d + BatchNorm + ReLU."""
    def __init__(self, in_channel=64, out_channel=64, size=4, stride=1, padding=0):
        super(baseGenBlock, self).__init__()
        layer = []
        layer.append(SpectralNorm(nn.ConvTranspose2d(in_channel, out_channel, size, stride, padding)))
        layer.append(nn.BatchNorm2d(out_channel))
        layer.append(nn.ReLU())
        self.layer = nn.Sequential(*layer)

    def forward(self, x):
        out = self.layer(x)
        return out

class hqGenerator(nn.Module):
    '''hqGenerator.'''
    def __init__(self, batch_size, image_size=64, z_dim=128, conv_dim=64):
        super(hqGenerator, self).__init__()
        self.image_size = image_size
        self.batch_size = batch_size
        conv_block = []
        last_layer = []
        out_shape = conv_dim * 16
        conv_block.append(baseGenBlock(z_dim, out_shape, 4))
        while out_shape > 128:  # stop at 128 channels so attn1's channel count matches
            conv_block.append(baseGenBlock(out_shape, int(out_shape / 2), 4, 2, 1))
            out_shape = int(out_shape / 2)
        self.conv_block = nn.Sequential(*conv_block)
        self.middle_layer = baseGenBlock(128, 64, 4, 2)
        last_layer.append(nn.ConvTranspose2d(64, 3, 4, 2, 1))
        last_layer.append(nn.Tanh())
        self.final_layer = nn.Sequential(*last_layer)
        self.attn1 = Self_Attn(128, 'relu')
        self.attn2 = Self_Attn(64, 'relu')

    def forward(self, x):
        out = self.conv_block(x)
        out, p1 = self.attn1(out)  # was self.atten1: fixed typo
        out = self.middle_layer(out)
        out, p2 = self.attn2(out)
        out = self.final_layer(out)  # result was previously discarded
        # interpolate takes the spatial size only, not the batch dimension
        out = nn.functional.interpolate(out, (self.image_size, self.image_size))
        return out, p1, p2

class baseDisBlock(nn.Module):
    """Base block for the discriminator: spectral-norm Conv2d + LeakyReLU."""
    def __init__(self, in_channel=64, out_channel=64, size=4, stride=1, padding=1):
        super(baseDisBlock, self).__init__()
        layer = []  # was missing, causing a NameError
        layer.append(SpectralNorm(nn.Conv2d(in_channel, out_channel, size, stride, padding)))
        layer.append(nn.LeakyReLU(0.1))
        self.layer = nn.Sequential(*layer)

    def forward(self, x):
        out = self.layer(x)
        return out

class hqDiscriminator(nn.Module):
    def __init__(self, batch_size=64, image_size=64, conv_dim=64):
        super(hqDiscriminator, self).__init__()
        self.imsize = image_size

        conv_block = []
        loop_count = int(np.log2(self.imsize) - 4)
        curr_dim = 2 ** int(np.log2(256) - loop_count)
        conv_block.append(baseDisBlock(3, curr_dim, 4, 2, 1))
        for i in range(loop_count):
            curr_dim *= 2
            conv_block.append(baseDisBlock(curr_dim // 2, curr_dim, 4, 2, 1))
        curr_dim *= 2
        self.conv_block = nn.Sequential(*conv_block)
        self.middle_layer = baseDisBlock(curr_dim // 2, curr_dim, 4, 2, 1)
        self.final_layer = nn.Conv2d(curr_dim, 1, 4)  # middle_layer outputs curr_dim channels
        self.attn1 = Self_Attn(256, 'relu')
        self.attn2 = Self_Attn(512, 'relu')

    def forward(self, x):
        out = self.conv_block(x)  # was conv_block(x): missing self
        out, p1 = self.attn1(out)
        out = self.middle_layer(out)
        out, p2 = self.attn2(out)
        out = self.final_layer(out)
        return out.squeeze(), p1, p2

class Generator(nn.Module):
    """Generator."""

    def __init__(self, batch_size, image_size=64, z_dim=100, conv_dim=64):
        super(Generator, self).__init__()
        self.imsize = image_size
        layer1 = []
        layer2 = []
        layer3 = []
        last = []

        repeat_num = int(np.log2(self.imsize)) - 3
        mult = 2 ** repeat_num  # 8
        layer1.append(SpectralNorm(nn.ConvTranspose2d(z_dim, conv_dim * mult, 4)))
        layer1.append(nn.BatchNorm2d(conv_dim * mult))
        layer1.append(nn.ReLU())

        curr_dim = conv_dim * mult

        layer2.append(SpectralNorm(nn.ConvTranspose2d(curr_dim, int(curr_dim / 2), 4, 2, 1)))
        layer2.append(nn.BatchNorm2d(int(curr_dim / 2)))
        layer2.append(nn.ReLU())

        curr_dim = int(curr_dim / 2)

        layer3.append(SpectralNorm(nn.ConvTranspose2d(curr_dim, int(curr_dim / 2), 4, 2, 1)))
        layer3.append(nn.BatchNorm2d(int(curr_dim / 2)))
        layer3.append(nn.ReLU())

        if self.imsize == 64:
            layer4 = []
            curr_dim = int(curr_dim / 2)
            layer4.append(SpectralNorm(nn.ConvTranspose2d(curr_dim, int(curr_dim / 2), 4, 2, 1)))
            layer4.append(nn.BatchNorm2d(int(curr_dim / 2)))
            layer4.append(nn.ReLU())
            self.l4 = nn.Sequential(*layer4)
            curr_dim = int(curr_dim / 2)

        self.l1 = nn.Sequential(*layer1)
        self.l2 = nn.Sequential(*layer2)
        self.l3 = nn.Sequential(*layer3)

        last.append(nn.ConvTranspose2d(curr_dim, 3, 4, 2, 1))
        last.append(nn.Tanh())
        self.last = nn.Sequential(*last)

        self.attn1 = Self_Attn(128, 'relu')
        self.attn2 = Self_Attn(64, 'relu')

    def forward(self, z):
        z = z.view(z.size(0), z.size(1), 1, 1)
        out = self.l1(z)
        out = self.l2(out)
        out = self.l3(out)
        out, p1 = self.attn1(out)
        out = self.l4(out)
        out, p2 = self.attn2(out)
        out = self.last(out)

        return out, p1, p2

4 changes: 2 additions & 2 deletions parameter.py
@@ -35,8 +35,8 @@ def get_parameters():
    # Misc
    parser.add_argument('--train', type=str2bool, default=True)
    parser.add_argument('--parallel', type=str2bool, default=False)
    parser.add_argument('--dataset', type=str, default='celeb', choices=['semantic', 'imagenet', 'lsun', 'celeb'])
    parser.add_argument('--use_tensorboard', action='store_true', default=False)

    # Path
    parser.add_argument('--image_path', type=str, default='./data')
38 changes: 38 additions & 0 deletions preprocessing_mask_celebA_hq.py
@@ -0,0 +1,38 @@
#!/usr/bin/python
# -*- encoding: utf-8 -*-

import os.path as osp
import os
import cv2
import numpy as np
from transform import *
from PIL import Image

base_path = '/home/ugrads/a/avashist/Self-Attention-GAN/data'
face_data = osp.join(base_path, 'CelebAMask-HQ/CelebA-HQ-img')
face_sep_mask = osp.join(base_path, 'CelebAMask-HQ/CelebAMask-HQ-mask-anno')
mask_path = osp.join(base_path, 'CelebAMask-HQ/mask')
counter = 0
total = 0
for i in range(15):

    atts = ['skin', 'l_brow', 'r_brow', 'l_eye', 'r_eye', 'eye_g', 'l_ear', 'r_ear', 'ear_r',
            'nose', 'mouth', 'u_lip', 'l_lip', 'neck', 'neck_l', 'cloth', 'hair', 'hat']

    for j in range(i * 2000, (i + 1) * 2000):

        mask = np.zeros((512, 512))

        for l, att in enumerate(atts, 1):
            total += 1
            file_name = ''.join([str(j).rjust(5, '0'), '_', att, '.png'])
            path = osp.join(face_sep_mask, str(i), file_name)

            if os.path.exists(path):
                counter += 1
                sep_mask = np.array(Image.open(path).convert('P'))
                # threshold value kept from the source script; per-part masks are typically 0/255
                mask[sep_mask == 225] = l

        # Write the combined label map once per image (not once per attribute).
        cv2.imwrite('{}/{}.png'.format(mask_path, j), mask)
        print('{}/{}.png'.format(mask_path, j))

print(counter, total)
165 changes: 165 additions & 0 deletions sagan_models.py
@@ -39,6 +39,48 @@ def forward(self,x):
        out = self.gamma * out + x
        return out, attention

class lightSelfAtten(nn.Module):
    ''' channel attention block'''
    def __init__(self, in_channels, activation):
        super(lightSelfAtten, self).__init__()
        self.gamma = nn.Parameter(torch.zeros(1))
        self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, groups=in_channels)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        bh, ch, hi, wh = x.shape
        first = x.view(bh, ch, -1)
        second = x.view(bh, ch, -1).permute(0, 2, 1)
        third = x.view(bh, ch, -1)
        atten = torch.bmm(first, second)  # B X C X C channel-affinity map
        # Max-subtraction trick: emphasize the least-similar channels.
        atten = torch.max(atten, dim=-1, keepdim=True)[0].expand_as(atten) - atten
        atten = torch.bmm(atten, third).view(bh, ch, hi, wh)
        atten = self.softmax(self.depthwise(atten))
        out = self.gamma * atten + x
        return out, atten


class lightSelfAtten2(nn.Module):
    def __init__(self, in_channels, activation):
        super(lightSelfAtten2, self).__init__()

        self.activation = activation
        self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, groups=in_channels)
        self.pointwise = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        bh, ch, hi, wh = x.shape
        x = x.view([bh, ch, hi * wh])
        first = torch.matmul(torch.transpose(x, 2, 1), x)  # N X N spatial affinity
        atten = torch.matmul(x, first).view([bh, ch, hi, wh])
        atten = self.depthwise(atten)
        atten = self.pointwise(atten)
        x = x.view([bh, ch, hi, wh])
        out = self.gamma * atten + x
        return out, atten


class Generator(nn.Module):
    """Generator."""

@@ -151,3 +193,126 @@ def forward(self, x):
        out = self.last(out)

        return out.squeeze(), p1, p2


class baseGenBlock(nn.Module):
    """base block for generator"""
    def __init__(self, in_channel=64, out_channel=64, size: int = 4, stride=1, padding=0):
        super(baseGenBlock, self).__init__()
        layer = []
        layer.append(SpectralNorm(nn.ConvTranspose2d(in_channel, out_channel, size, stride, padding)))
        layer.append(nn.BatchNorm2d(out_channel))
        layer.append(nn.ReLU())
        self.layer = nn.Sequential(*layer)

    def forward(self, x):
        out = self.layer(x)
        return out

class hqGenerator(nn.Module):
    '''hqGenerator.'''
    def __init__(self, batch_size, image_size=64, z_dim=128, conv_dim=64):
        super(hqGenerator, self).__init__()
        self.image_size = image_size
        self.batch_size = batch_size
        self.layer_count = int(np.log2(self.image_size)) - 7
        conv_block = []
        last_layer = []
        curr_shape = self.image_size
        conv_block.append(baseGenBlock(z_dim, self.image_size, 4))
        for i in range(2):
            conv_block.append(baseGenBlock(curr_shape, int(curr_shape // 2), 4, 2, 1))
            curr_shape = int(curr_shape // 2)
        conv_block.append(baseGenBlock(curr_shape, 128, 4, 2, 1))
        self.conv_block = nn.Sequential(*conv_block)
        self.middle_layer = baseGenBlock(128, 64, 4, 2, 1)
        curr_shape = 64
        for i in range(self.layer_count):
            last_layer.append(baseGenBlock(curr_shape, int(curr_shape // 2), 4, 2, 1))
            curr_shape = int(curr_shape // 2)
        last_layer.append(nn.ConvTranspose2d(curr_shape, 3, 4, 2, 1))
        last_layer.append(nn.Tanh())
        self.final_layer = nn.Sequential(*last_layer)
        self.attn1 = lightSelfAtten(128, 'relu')
        self.attn2 = lightSelfAtten(64, 'relu')
        # self.attn1 = Self_Attn(128, 'relu')
        # self.attn2 = Self_Attn(64, 'relu')

    def forward(self, x):
        x = x.view(x.size(0), x.size(1), 1, 1)
        out = self.conv_block(x)
        out, p1 = self.attn1(out)
        out = self.middle_layer(out)
        out, p2 = self.attn2(out)
        out = self.final_layer(out)
        return out, p1, p2

class baseEncBlock(nn.Module):
    """base block for encoding"""
    def __init__(self, in_channel=64, out_channel=64, size=4, stride=1, padding=1):
        super(baseEncBlock, self).__init__()
        layer = []
        layer.append(SpectralNorm(nn.Conv2d(in_channel, out_channel, size, stride, padding)))
        layer.append(nn.LeakyReLU(0.1))
        self.layer = nn.Sequential(*layer)

    def forward(self, x):
        out = self.layer(x)
        return out

class hqDiscriminator(nn.Module):
    def __init__(self, batch_size=64, image_size=64, conv_dim=64):
        super(hqDiscriminator, self).__init__()
        conv_block = []
        self.imsize = image_size
        loop_count = int(np.log2(image_size) - 4)
        curr_dim = 2 ** int(np.log2(256) - loop_count)
        conv_block.append(baseEncBlock(3, curr_dim, 4, 2, 1))
        for i in range(loop_count):
            curr_dim *= 2
            conv_block.append(baseEncBlock(curr_dim // 2, curr_dim, 4, 2, 1))
        curr_dim *= 2
        self.conv_block = nn.Sequential(*conv_block)
        self.middle_layer = baseEncBlock(curr_dim // 2, curr_dim, 4, 2, 1)
        self.final_layer = nn.Conv2d(curr_dim, 1, 4)
        # self.attn1 = Self_Attn(256, 'relu')
        # self.attn2 = Self_Attn(512, 'relu')
        self.attn1 = lightSelfAtten(256, 'relu')
        self.attn2 = lightSelfAtten(512, 'relu')

    def forward(self, x):
        out = self.conv_block(x)
        out, p1 = self.attn1(out)
        out = self.middle_layer(out)
        out, p2 = self.attn2(out)
        out = self.final_layer(out)
        return out.squeeze(), p1, p2

class mapEncoder(nn.Module):
    def __init__(self, in_channel, im_size, z_dim):
        super(mapEncoder, self).__init__()
        self.z_dim = z_dim
        self.im_size = im_size
        self.in_channel = in_channel
        self.count = int(np.log2(self.im_size) - 4)
        layer = []
        for i in range(self.count):
            layer.append(baseEncBlock(1, 1, 4, 2, 1))
        self.inital = nn.Sequential(*layer)
        # After `count` stride-2 blocks a 1-channel map is 16x16, i.e. 256 features.
        self.final = nn.Linear(256, z_dim)

    def forward(self, x):
        out = self.inital(x)
        out = out.view(out.size(0), -1)
        out = self.final(out)
        return out
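
The hq blocks appear sized for 256x256 outputs (with the default `image_size=64`, `layer_count` is negative and the spatial sizes do not line up). A hedged sanity sketch under that 256x256 assumption:

```python
# Hypothetical shape check for the hq models in sagan_models.py
import torch
from sagan_models import hqGenerator, hqDiscriminator

G = hqGenerator(batch_size=2, image_size=256, z_dim=128)
D = hqDiscriminator(batch_size=2, image_size=256)
z = torch.randn(2, 128)          # hqGenerator reshapes z to (B, z_dim, 1, 1) itself
fake, g_p1, g_p2 = G(z)
scores, d_p1, d_p2 = D(fake)
print(fake.shape, scores.shape)  # expect torch.Size([2, 3, 256, 256]) and torch.Size([2])
```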
42 changes: 28 additions & 14 deletions trainer.py
@@ -7,8 +7,8 @@
import torch.nn as nn
from torch.autograd import Variable
from torchvision.utils import save_image

from logger import image_grid_writer
from sagan_models import Generator, Discriminator, hqGenerator, hqDiscriminator
from utils import *

class Trainer(object):
@@ -76,6 +76,8 @@ def train(self):
        model_save_step = int(self.model_save_step * step_per_epoch)

        # Fixed input for debugging
        fixed_z = tensor2var(torch.randn(self.batch_size, self.z_dim))

        # Start with trained model
@@ -165,33 +167,45 @@ def train(self):
            g_loss_fake.backward()
            self.g_optimizer.step()

            iters = step + 1
            # Print out log info
            if iters % self.log_step == 0:
                elapsed = time.time() - start_time
                elapsed = str(datetime.timedelta(seconds=elapsed))
                if self.use_tensorboard:
                    self.logger.add_scalar("d_loss_real", d_loss_real.item(), iters)
                    self.logger.add_scalar("d_loss_fake", d_loss_fake.item(), iters)
                    self.logger.add_scalar("d_loss", d_loss.item(), iters)
                    self.logger.add_scalar("g_loss_fake", g_loss_fake.item(), iters)
                    self.logger.add_scalar("ave_gamma_l3", self.G.attn1.gamma.mean().item(), iters)
                    self.logger.add_scalar("ave_gamma_l4", self.G.attn2.gamma.mean().item(), iters)
                print("Elapsed [{}], G_step [{}/{}], D_step[{}/{}], d_out_real: {:.4f}, "
                      " ave_gamma_l3: {:.4f}, ave_gamma_l4: {:.4f}".
                      format(elapsed, iters, self.total_step, iters,
                             self.total_step, d_loss_real.item(),
                             self.G.attn1.gamma.mean().item(), self.G.attn2.gamma.mean().item()))

            # Sample images
            if iters % self.sample_step == 0:
                fake_images, _, _ = self.G(fixed_z)
                save_image(denorm(fake_images.data),
                           os.path.join(self.sample_path, '{}_fake.png'.format(iters)))
                if self.use_tensorboard:
                    image_grid_writer(self.logger, fake_images.data.clone(), "fake_image", iters)
                    image_grid_writer(self.logger, real_images.data.clone(), "real_image", iters)

            if iters % model_save_step == 0:
                torch.save(self.G.state_dict(),
                           os.path.join(self.model_save_path, '{}_G.pth'.format(iters)))
                torch.save(self.D.state_dict(),
                           os.path.join(self.model_save_path, '{}_D.pth'.format(iters)))

    def build_model(self):

        # self.G = Generator(self.batch_size, self.imsize, self.z_dim, self.g_conv_dim).cuda()
        # self.D = Discriminator(self.batch_size, self.imsize, self.d_conv_dim).cuda()
        self.G = hqGenerator(self.batch_size, self.imsize, self.z_dim, self.g_conv_dim).cuda()
        self.D = hqDiscriminator(self.batch_size, self.imsize, self.d_conv_dim).cuda()
        if self.parallel:
            self.G = nn.DataParallel(self.G)
            self.D = nn.DataParallel(self.D)
129 changes: 129 additions & 0 deletions transform.py
@@ -0,0 +1,129 @@
#!/usr/bin/python
# -*- encoding: utf-8 -*-


from PIL import Image
import PIL.ImageEnhance as ImageEnhance
import random
import numpy as np


class RandomCrop(object):
    def __init__(self, size, *args, **kwargs):
        self.size = size

    def __call__(self, im_lb):
        im = im_lb['im']
        lb = im_lb['lb']
        assert im.size == lb.size
        W, H = self.size
        w, h = im.size

        if (W, H) == (w, h): return dict(im=im, lb=lb)
        if w < W or h < H:
            scale = float(W) / w if w < h else float(H) / h
            w, h = int(scale * w + 1), int(scale * h + 1)
            im = im.resize((w, h), Image.BILINEAR)
            lb = lb.resize((w, h), Image.NEAREST)
        sw, sh = random.random() * (w - W), random.random() * (h - H)
        crop = int(sw), int(sh), int(sw) + W, int(sh) + H
        return dict(
                im = im.crop(crop),
                lb = lb.crop(crop)
                )


class HorizontalFlip(object):
    def __init__(self, p=0.5, *args, **kwargs):
        self.p = p

    def __call__(self, im_lb):
        if random.random() > self.p:
            return im_lb
        else:
            im = im_lb['im']
            lb = im_lb['lb']

            # atts = [1 'skin', 2 'l_brow', 3 'r_brow', 4 'l_eye', 5 'r_eye', 6 'eye_g', 7 'l_ear', 8 'r_ear', 9 'ear_r',
            #         10 'nose', 11 'mouth', 12 'u_lip', 13 'l_lip', 14 'neck', 15 'neck_l', 16 'cloth', 17 'hair', 18 'hat']

            # Swap left/right labels so a flipped image keeps consistent annotations.
            # Compare against the numpy array: comparing a PIL Image to an int is always False.
            lb_arr = np.array(lb)
            flip_lb = lb_arr.copy()
            flip_lb[lb_arr == 2] = 3
            flip_lb[lb_arr == 3] = 2
            flip_lb[lb_arr == 4] = 5
            flip_lb[lb_arr == 5] = 4
            flip_lb[lb_arr == 7] = 8
            flip_lb[lb_arr == 8] = 7
            flip_lb = Image.fromarray(flip_lb)
            return dict(im = im.transpose(Image.FLIP_LEFT_RIGHT),
                        lb = flip_lb.transpose(Image.FLIP_LEFT_RIGHT),
                    )


class RandomScale(object):
    def __init__(self, scales=(1, ), *args, **kwargs):
        self.scales = scales

    def __call__(self, im_lb):
        im = im_lb['im']
        lb = im_lb['lb']
        W, H = im.size
        scale = random.choice(self.scales)
        w, h = int(W * scale), int(H * scale)
        return dict(im = im.resize((w, h), Image.BILINEAR),
                    lb = lb.resize((w, h), Image.NEAREST),
                )


class ColorJitter(object):
    def __init__(self, brightness=None, contrast=None, saturation=None, *args, **kwargs):
        if not brightness is None and brightness > 0:
            self.brightness = [max(1 - brightness, 0), 1 + brightness]
        if not contrast is None and contrast > 0:
            self.contrast = [max(1 - contrast, 0), 1 + contrast]
        if not saturation is None and saturation > 0:
            self.saturation = [max(1 - saturation, 0), 1 + saturation]

    def __call__(self, im_lb):
        im = im_lb['im']
        lb = im_lb['lb']
        r_brightness = random.uniform(self.brightness[0], self.brightness[1])
        r_contrast = random.uniform(self.contrast[0], self.contrast[1])
        r_saturation = random.uniform(self.saturation[0], self.saturation[1])
        im = ImageEnhance.Brightness(im).enhance(r_brightness)
        im = ImageEnhance.Contrast(im).enhance(r_contrast)
        im = ImageEnhance.Color(im).enhance(r_saturation)
        return dict(im = im,
                    lb = lb,
                )


class MultiScale(object):
    def __init__(self, scales):
        self.scales = scales

    def __call__(self, img):
        W, H = img.size
        sizes = [(int(W * ratio), int(H * ratio)) for ratio in self.scales]
        # Return one resized copy per requested scale.
        imgs = [img.resize(size, Image.BILINEAR) for size in sizes]
        return imgs


class Compose(object):
    def __init__(self, do_list):
        self.do_list = do_list

    def __call__(self, im_lb):
        for comp in self.do_list:
            im_lb = comp(im_lb)
        return im_lb


if __name__ == '__main__':
    flip = HorizontalFlip(p=1)
    crop = RandomCrop((321, 321))
    rscales = RandomScale((0.75, 1.0, 1.5, 1.75, 2.0))
    img = Image.open('data/img.jpg')
    lb = Image.open('data/label.png')
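
A short sketch of composing these paired transforms, mirroring the demo above (file paths are the demo's own):

```python
# Illustrative composition of transform.py's paired ops (not part of the PR)
from PIL import Image
from transform import Compose, HorizontalFlip, RandomCrop, RandomScale

img = Image.open('data/img.jpg')
lb = Image.open('data/label.png')
trans = Compose([
    RandomScale((0.75, 1.0, 1.25)),
    HorizontalFlip(p=0.5),
    RandomCrop((321, 321)),
])
im_lb = trans(dict(im=img, lb=lb))   # image and label get identical geometry
print(im_lb['im'].size, im_lb['lb'].size)
```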