Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add GAN with selftaughtlearning for unseen task processing #378

Open
wants to merge 1 commit into
base: feature-lifelong-n
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions README_ospp.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# README
------
I modified sedna source code to get my algorithm integrated into lifelong learning. I would display the framework and illustrute what I modified.
## Framework
![](framework_with_gan_selftaughtlearning.png)

## What I Modified

1. Delete redundant annotations.
2. Replce `print` with `logger`.
3. Integrate my algorithm into `sedna.lib.sedna.algorithms.unseen_task_processing.unseen_task_processing.py`.
4. Replace absolute path with relative path.
5. Remove redundant code.
6. Provide link for developers to download trained model.

## What I Refer to
I refer to [FastGAN-pytorch](https://github.com/odegeasslbc/FastGAN-pytorch) to implement my GAN module.

## Model to Download
In `GAN.lpips.weights`, developers may need the pre-trained model.
Also, the trained GAN model and trained encoder model is also provided.
Click here [link](https://drive.google.com/drive/folders/1IOQCQ3sntxrbt7RtJIsSlBo0PFrR7Ets?usp=share_link) to download.
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from . import train
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# Differentiable Augmentation for Data-Efficient GAN Training
# Shengyu Zhao, Zhijian Liu, Ji Lin, Jun-Yan Zhu, and Song Han
# https://arxiv.org/pdf/2006.10738

import torch
import torch.nn.functional as F


def DiffAugment(x, policy='', channels_first=True):
if policy:
if not channels_first:
x = x.permute(0, 3, 1, 2)
for p in policy.split(','):
for f in AUGMENT_FNS[p]:
x = f(x)
if not channels_first:
x = x.permute(0, 2, 3, 1)
x = x.contiguous()
return x


def rand_brightness(x):
x = x + (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5)
return x


def rand_saturation(x):
x_mean = x.mean(dim=1, keepdim=True)
x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2) + x_mean
return x


def rand_contrast(x):
x_mean = x.mean(dim=[1, 2, 3], keepdim=True)
x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5) + x_mean
return x


def rand_translation(x, ratio=0.125):
shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device)
translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device)
grid_batch, grid_x, grid_y = torch.meshgrid(
torch.arange(x.size(0), dtype=torch.long, device=x.device),
torch.arange(x.size(2), dtype=torch.long, device=x.device),
torch.arange(x.size(3), dtype=torch.long, device=x.device),
)
grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1)
grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1)
x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])
x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2)
return x


def rand_cutout(x, ratio=0.5):
cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device)
offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device)
grid_batch, grid_x, grid_y = torch.meshgrid(
torch.arange(x.size(0), dtype=torch.long, device=x.device),
torch.arange(cutout_size[0], dtype=torch.long, device=x.device),
torch.arange(cutout_size[1], dtype=torch.long, device=x.device),
)
grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1)
grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1)
mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device)
mask[grid_batch, grid_x, grid_y] = 0
x = x * mask.unsqueeze(1)
return x


AUGMENT_FNS = {
'color': [rand_brightness, rand_saturation, rand_contrast],
'translation': [rand_translation],
'cutout': [rand_cutout],
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import torch

from models import Generator, weights_init

import matplotlib.pyplot as plt

import os

from collections import OrderedDict

import numpy as np

from skimage import io


device = 'cuda'

ngf = 64
nz = 256
im_size = 1024
netG = Generator(ngf=ngf, nz=nz, im_size=im_size).to(device)
weights_init(netG)
weights = torch.load(os.getcwd() + '/train_results/test1/models/50000.pth')
netG_weights = OrderedDict()
for name, weight in weights['g'].items():
name = name.split('.')[1:]
name = '.'.join(name)
netG_weights[name] = weight
netG.load_state_dict(netG_weights)
current_batch_size = 1


index = 1
while index <= 3000:
noise = torch.Tensor(current_batch_size, nz).normal_(0, 1).to(device)
fake_images = netG(noise)[0]
for fake_image in fake_images:
fake_image = fake_image.detach().cpu().numpy().transpose(1, 2, 0)
fake_image = fake_image * np.array([0.5, 0.5, 0.5])
fake_image = fake_image + np.array([0.5, 0.5, 0.5])
fake_image = (fake_image * 255).astype(np.uint8)
io.imsave('../data/fake_imgs1/' + str(index) + '.png', fake_image)
index += 1
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import skimage
import torch
from torch.autograd import Variable

from lpips import dist_model


from skimage.metrics import structural_similarity as compare_ssim


class PerceptualLoss(torch.nn.Module):
# VGG using our perceptually-learned weights (LPIPS metric)
def __init__(self, model='net-lin', net='alex', colorspace='rgb', spatial=False, use_gpu=True, gpu_ids=[0]):
# def __init__(self, model='net', net='vgg', use_gpu=True): # "default" way of using VGG as a perceptual loss
super(PerceptualLoss, self).__init__()
print('Setting up Perceptual loss...')
self.use_gpu = use_gpu
self.spatial = spatial
self.gpu_ids = gpu_ids
self.model = dist_model.DistModel()
self.model.initialize(model=model, net=net, use_gpu=use_gpu,
colorspace=colorspace, spatial=self.spatial, gpu_ids=gpu_ids)
print('...[%s] initialized' % self.model.name())
print('...Done')

def forward(self, pred, target, normalize=False):
"""
Pred and target are Variables.
If normalize is True, assumes the images are between [0,1] and then scales them between [-1,+1]
If normalize is False, assumes the images are already between [-1,+1]

Inputs pred and target are Nx3xHxW
Output pytorch Variable N long
"""

if normalize:
target = 2 * target - 1
pred = 2 * pred - 1

return self.model.forward(target, pred)


def normalize_tensor(in_feat, eps=1e-10):
norm_factor = torch.sqrt(torch.sum(in_feat**2, dim=1, keepdim=True))
return in_feat/(norm_factor+eps)


def l2(p0, p1, range=255.):
return .5*np.mean((p0 / range - p1 / range)**2)


def psnr(p0, p1, peak=255.):
return 10*np.log10(peak**2/np.mean((1.*p0-1.*p1)**2))


def dssim(p0, p1, range=255.):
return (1 - compare_ssim(p0, p1, data_range=range, multichannel=True)) / 2.


def rgb2lab(in_img, mean_cent=False):
from skimage import color
img_lab = color.rgb2lab(in_img)
if(mean_cent):
img_lab[:, :, 0] = img_lab[:, :, 0]-50
return img_lab


def tensor2np(tensor_obj):
# change dimension of a tensor object into a numpy array
return tensor_obj[0].cpu().float().numpy().transpose((1, 2, 0))


def np2tensor(np_obj):
# change dimenion of np array into tensor array
return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1)))


def tensor2tensorlab(image_tensor, to_norm=True, mc_only=False):
# image tensor to lab tensor
from skimage import color

img = tensor2im(image_tensor)
img_lab = color.rgb2lab(img)
if(mc_only):
img_lab[:, :, 0] = img_lab[:, :, 0]-50
if(to_norm and not mc_only):
img_lab[:, :, 0] = img_lab[:, :, 0]-50
img_lab = img_lab/100.

return np2tensor(img_lab)


def tensorlab2tensor(lab_tensor, return_inbnd=False):
from skimage import color
import warnings
warnings.filterwarnings("ignore")

lab = tensor2np(lab_tensor)*100.
lab[:, :, 0] = lab[:, :, 0]+50

rgb_back = 255.*np.clip(color.lab2rgb(lab.astype('float')), 0, 1)
if(return_inbnd):
# convert back to lab, see if we match
lab_back = color.rgb2lab(rgb_back.astype('uint8'))
mask = 1.*np.isclose(lab_back, lab, atol=2.)
mask = np2tensor(np.prod(mask, axis=2)[:, :, np.newaxis])
return (im2tensor(rgb_back), mask)
else:
return im2tensor(rgb_back)


def rgb2lab(input):
from skimage import color
return color.rgb2lab(input / 255.)


def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.):
image_numpy = image_tensor[0].cpu().float().numpy()
image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor
return image_numpy.astype(imtype)


def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.):
return torch.Tensor((image / factor - cent)
[:, :, :, np.newaxis].transpose((3, 2, 0, 1)))


def tensor2vec(vector_tensor):
return vector_tensor.data.cpu().numpy()[:, :, 0, 0]


def voc_ap(rec, prec, use_07_metric=False):
""" ap = voc_ap(rec, prec, [use_07_metric])
Compute VOC AP given precision and recall.
If use_07_metric is true, uses the
VOC 07 11 point method (default:False).
"""
if use_07_metric:
# 11 point metric
ap = 0.
for t in np.arange(0., 1.1, 0.1):
if np.sum(rec >= t) == 0:
p = 0
else:
p = np.max(prec[rec >= t])
ap = ap + p / 11.
else:
# correct AP calculation
# first append sentinel values at the end
mrec = np.concatenate(([0.], rec, [1.]))
mpre = np.concatenate(([0.], prec, [0.]))

# compute the precision envelope
for i in range(mpre.size - 1, 0, -1):
mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])

# to calculate area under PR curve, look for points
# where X axis (recall) changes value
i = np.where(mrec[1:] != mrec[:-1])[0]

# and sum (\Delta recall) * prec
ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
return ap


def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.):
# def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=1.):
image_numpy = image_tensor[0].cpu().float().numpy()
image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor
return image_numpy.astype(imtype)


def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.):
# def im2tensor(image, imtype=np.uint8, cent=1., factor=1.):
return torch.Tensor((image / factor - cent)
[:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import os
import torch
from torch.autograd import Variable
from pdb import set_trace as st
from IPython import embed

class BaseModel():
def __init__(self):
pass;

def name(self):
return 'BaseModel'

def initialize(self, use_gpu=True, gpu_ids=[0]):
self.use_gpu = use_gpu
self.gpu_ids = gpu_ids

def forward(self):
pass

def get_image_paths(self):
pass

def optimize_parameters(self):
pass

def get_current_visuals(self):
return self.input

def get_current_errors(self):
return {}

def save(self, label):
pass

# helper saving function that can be used by subclasses
def save_network(self, network, path, network_label, epoch_label):
save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
save_path = os.path.join(path, save_filename)
torch.save(network.state_dict(), save_path)

# helper loading function that can be used by subclasses
def load_network(self, network, network_label, epoch_label):
save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
save_path = os.path.join(self.save_dir, save_filename)
print('Loading network from %s'%save_path)
network.load_state_dict(torch.load(save_path))

def update_learning_rate():
pass

def get_image_paths(self):
return self.image_paths

def save_done(self, flag=False):
np.save(os.path.join(self.save_dir, 'done_flag'),flag)
np.savetxt(os.path.join(self.save_dir, 'done_flag'),[flag,],fmt='%i')

Loading