
Commit 1ce3d25

committed Aug 30, 2020
add the training script
1 parent 6c43621 commit 1ce3d25

15 files changed (+733 −135 lines)
 

‎InstColorization.ipynb

+87-130
Large diffs are not rendered by default.

‎README.md

+3
@@ -70,6 +70,9 @@ All the colorized results will be saved in the `results` folder.
 
 * Note: all input images are converted to the L channel for colorization in [test_fusion.py's L51](test_fusion.py#L51)
 
+## Training the Model
+Please follow this [tutorial](README_TRAIN.md) to train the colorization model.
+
 ## License
 This work is licensed under the MIT License. See [LICENSE](LICENSE) for details.
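The note above simply means the network only sees the lightness channel at test time; conceptually this is the Lab conversion sketched below. The image path is only an illustrative placeholder, and `test_fusion.py` performs the equivalent step on its own inputs.

```python
import numpy as np
from PIL import Image
from skimage import color

# Illustrative only: take the L (lightness) channel of an RGB image,
# which is what the colorization model receives as input.
img = np.asarray(Image.open('example/your_image.jpg').convert('RGB'))  # hypothetical path
l_channel = color.rgb2lab(img)[:, :, 0]
print(l_channel.shape, l_channel.min(), l_channel.max())  # HxW, values roughly in [0, 100]
```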

‎README_TRAIN.md

+112
@@ -0,0 +1,112 @@
# [CVPR 2020] Instance-aware Image Colorization
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ericsujw/InstColorization/blob/master/InstColorization.ipynb)

### [[Paper](https://arxiv.org/abs/2005.10825)] [[Project Website](https://ericsujw.github.io/InstColorization/)] [[Google Colab](https://colab.research.google.com/github/ericsujw/InstColorization/blob/master/InstColorization.ipynb)]

<p align='center'>
<img src='imgs/teaser.png' width=1000>
</p>

Image colorization is inherently an ill-posed problem with multi-modal uncertainty. Previous methods leverage deep neural networks to map input grayscale images directly to plausible color outputs. Although these learning-based methods have shown impressive performance, they usually fail on input images that contain multiple objects. The leading cause is that existing models perform learning and colorization on the entire image. In the absence of a clear figure-ground separation, these models cannot effectively locate and learn meaningful object-level semantics. In this paper, we propose a method for achieving instance-aware colorization. Our network architecture leverages an off-the-shelf object detector to obtain cropped object images and uses an instance colorization network to extract object-level features. We use a similar network to extract the full-image features and apply a fusion module to fuse object-level and image-level features to predict the final colors. Both colorization networks and fusion modules are learned from a large-scale dataset. Experimental results show that our work outperforms existing methods on different quality metrics and achieves state-of-the-art performance on image colorization.

**Instance-aware Image Colorization**
<br/>
[Jheng-Wei Su](https://github.com/ericsujw),
[Hung-Kuo Chu](https://cgv.cs.nthu.edu.tw/hkchu/), and
[Jia-Bin Huang](https://filebox.ece.vt.edu/~jbhuang/)
<br/>
In IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2020.

## Prerequisites
* [CUDA 10.1](https://developer.nvidia.com/cuda-10.1-download-archive-update2)
* Python3
* Pytorch >= 1.5
* Detectron2
* OpenCV-Python
* Pillow/scikit-image
* Please refer to [env.yml](env.yml) for the detailed dependencies.

## Getting Started
1. Clone this repo:
```sh
git clone https://github.com/ericsujw/InstColorization
cd InstColorization
```
2. Install [conda](https://www.anaconda.com/).
3. Install all the dependencies:
```sh
conda env create --file env.yml
```
4. Switch to the conda environment:
```sh
conda activate instacolorization
```
5. Install the remaining dependencies:
```sh
sh scripts/install.sh
```

## Dataset Preparation
### COCOStuff
1. Download and unzip the COCOStuff training set:
```sh
sh scripts/prepare_cocostuff.sh
```
2. The COCOStuff training set will now be placed in [train_data](train_data).

### Your Own Dataset
1. If you want to train on your own dataset, change the dataset path in [scripts/prepare_train_box.sh's L1](scripts/prepare_train_box.sh#L1) and in [scripts/train.sh's L1](scripts/train.sh#L1).

## Pretrained Model
1. Download the pretrained model from [Google Drive](https://drive.google.com/open?id=1Xb-DKAA9ibCVLqm8teKd1MWk6imjwTBh):
```sh
sh scripts/download_model.sh
```
2. The pretrained models will now be placed in [checkpoints](checkpoints).

## Instance Prediction
Run the command below to predict the bounding boxes of all the images in the `${DATASET_DIR}` folder.
```sh
sh scripts/prepare_train_box.sh
```
All the prediction results will be saved in the `${DATASET_DIR}_bbox` folder, one `.npz` file of boxes and scores per image.
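As a quick sanity check you can load one of the saved predictions back; the paths below are only illustrative examples of the `<image id>.npz` files that `inference_bbox.py` writes.

```python
import numpy as np
from os.path import join

# Hypothetical file name; inference_bbox.py writes one <image id>.npz per image into ${DATASET_DIR}_bbox.
pred = np.load(join('train_data/train2017_bbox', '000000000009.npz'))
boxes, scores = pred['bbox'], pred['scores']
print(boxes.shape, scores.shape)  # (N, 4) boxes in (x1, y1, x2, y2) order and N detection scores
```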
## Training the Instance-aware Image Colorization Model
Simply run the following command to start the training pipeline:
```sh
sh scripts/train.sh
```
To view training results and loss plots, run `visdom -port 8098` and click the URL http://localhost:8098.

This is a three-stage training process:
1. We first train the full-image colorization branch, starting from the [siggraph_retrained pretrained weights](https://github.com/richzhang/colorization-pytorch).
2. We then use the full-image branch's weights to initialize the instance colorization branch.
3. Finally, we train the fusion module.
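Under the hood, the `--stage` flag set by `scripts/train.sh` selects which dataset class is used for each of these stages; condensed from `train.py` added in this commit:

```python
from options.train_options import TrainOptions
from fusion_dataset import (Training_Full_Dataset,
                            Training_Instance_Dataset,
                            Training_Fusion_Dataset)

opt = TrainOptions().parse()
if opt.stage == 'full':
    dataset = Training_Full_Dataset(opt)        # stage 1: whole-image colorization branch
elif opt.stage == 'instance':
    dataset = Training_Instance_Dataset(opt)    # stage 2: crops taken from the predicted boxes
elif opt.stage == 'fusion':
    dataset = Training_Fusion_Dataset(opt)      # stage 3: full image plus its cropped instances
```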
## Testing the Instance-aware Image Colorization Model
1. The trained model's weights will be placed in [checkpoints/coco_mask](checkpoints/coco_mask).
2. Change the checkpoint path in [test_fusion.py's L38](test_fusion.py#L38) from `coco_finetuned_mask_256_ffs` to `coco_mask`.
3. Run the command below to colorize all the images in the `example` folder based on the weights placed in `coco_mask`.

```
python test_fusion.py --name test_fusion --sample_p 1.0 --model fusion --fineSize 256 --test_img_dir example --results_img_dir results
```
All the colorized results will be saved in the `results` folder.

## License
This work is licensed under the MIT License. See [LICENSE](LICENSE) for details.

## Citation
If you find our code/models useful, please consider citing our paper:
```
@inproceedings{Su-CVPR-2020,
  author = {Su, Jheng-Wei and Chu, Hung-Kuo and Huang, Jia-Bin},
  title = {Instance-aware Image Colorization},
  booktitle = {IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
  year = {2020}
}
```

## Acknowledgments
Our code borrows heavily from the amazing [colorization-pytorch](https://github.com/richzhang/colorization-pytorch) repository.

‎download.py

+24-3
@@ -1,5 +1,8 @@
 #taken from this StackOverflow answer: https://stackoverflow.com/a/39225039
 import requests
+from os.path import join, isdir
+import os
+from argparse import ArgumentParser
 
 def download_file_from_google_drive(id, destination):
     URL = "https://docs.google.com/uc?export=download"
@@ -30,6 +33,24 @@ def save_response_content(response, destination):
             if chunk: # filter out keep-alive new chunks
                 f.write(chunk)
 
-file_id = '1Xb-DKAA9ibCVLqm8teKd1MWk6imjwTBh'
-destination = 'checkpoints.zip'
-download_file_from_google_drive(file_id, destination)
+
+parser = ArgumentParser()
+parser.add_argument("--mode", type=str, default='pretrained-weight', help='pretrained-weight / cocostuff')
+parser.add_argument("--dataset_dir", type=str, default='data', help='training dataset path')
+args = parser.parse_args()
+
+if args.mode == 'pretrained-weight':
+    file_id = '1Xb-DKAA9ibCVLqm8teKd1MWk6imjwTBh'
+    destination = 'checkpoints.zip'
+    download_file_from_google_drive(file_id, destination)
+
+elif args.mode == 'cocostuff':
+    print('download cocostuff training dataset')
+    url = "http://images.cocodataset.org/zips/train2017.zip"
+    response = requests.get(url, stream = True)
+    if isdir(join(args.dataset_dir, "cocostuff")) is False:
+        os.makedirs(join(args.dataset_dir, "cocostuff"))
+    save_response_content(response, join(args.dataset_dir, "cocostuff", "train.zip"))
+else:
+    print('Error Mode!')

‎fusion_dataset.py

+122
@@ -1,5 +1,6 @@
 from os import listdir
 from os.path import isfile, join
+from random import sample
 
 import numpy as np
 import torch
@@ -54,5 +55,126 @@ def __getitem__(self, index):
             output['empty_box'] = True
         return output
 
+    def __len__(self):
+        return len(self.IMAGE_ID_LIST)
+
+
+class Training_Full_Dataset(Data.Dataset):
+    '''
+    Training on COCOStuff dataset. [train2017.zip]
+
+    Download the training set from https://github.com/nightrome/cocostuff
+    '''
+    def __init__(self, opt):
+        self.IMAGE_DIR = opt.train_img_dir
+        self.transforms = transforms.Compose([transforms.Resize((opt.fineSize, opt.fineSize), interpolation=2),
+                                              transforms.ToTensor()])
+        self.IMAGE_ID_LIST = [f for f in listdir(self.IMAGE_DIR) if isfile(join(self.IMAGE_DIR, f))]
+
+    def __getitem__(self, index):
+        output_image_path = join(self.IMAGE_DIR, self.IMAGE_ID_LIST[index])
+        rgb_img, gray_img = gen_gray_color_pil(output_image_path)
+        output = {}
+        output['rgb_img'] = self.transforms(rgb_img)
+        output['gray_img'] = self.transforms(gray_img)
+        return output
+
+    def __len__(self):
+        return len(self.IMAGE_ID_LIST)
+
+
+class Training_Instance_Dataset(Data.Dataset):
+    '''
+    Training on COCOStuff dataset. [train2017.zip]
+
+    Download the training set from https://github.com/nightrome/cocostuff
+
+    Make sure you've predicted all the images' bounding boxes using inference_bbox.py
+
+    It would be better if you can filter out the images which don't have any box.
+    '''
+    def __init__(self, opt):
+        self.PRED_BBOX_DIR = '{0}_bbox'.format(opt.train_img_dir)
+        self.IMAGE_DIR = opt.train_img_dir
+        self.IMAGE_ID_LIST = [f for f in listdir(self.IMAGE_DIR) if isfile(join(self.IMAGE_DIR, f))]
+        self.transforms = transforms.Compose([
+            transforms.Resize((opt.fineSize, opt.fineSize), interpolation=2),
+            transforms.ToTensor()
+        ])
+
+    def __getitem__(self, index):
+        pred_info_path = join(self.PRED_BBOX_DIR, self.IMAGE_ID_LIST[index].split('.')[0] + '.npz')
+        output_image_path = join(self.IMAGE_DIR, self.IMAGE_ID_LIST[index])
+        pred_bbox = gen_maskrcnn_bbox_fromPred(pred_info_path)
+
+        rgb_img, gray_img = gen_gray_color_pil(output_image_path)
+
+        index_list = range(len(pred_bbox))
+        index_list = sample(index_list, 1)
+        startx, starty, endx, endy = pred_bbox[index_list[0]]
+        output = {}
+        output['rgb_img'] = self.transforms(rgb_img.crop((startx, starty, endx, endy)))
+        output['gray_img'] = self.transforms(gray_img.crop((startx, starty, endx, endy)))
+        return output
+
+    def __len__(self):
+        return len(self.IMAGE_ID_LIST)
+
+
+class Training_Fusion_Dataset(Data.Dataset):
+    '''
+    Training on COCOStuff dataset. [train2017.zip]
+
+    Download the training set from https://github.com/nightrome/cocostuff
+
+    Make sure you've predicted all the images' bounding boxes using inference_bbox.py
+
+    It would be better if you can filter out the images which don't have any box.
+    '''
+    def __init__(self, opt, box_num=8):
+        self.PRED_BBOX_DIR = '{0}_bbox'.format(opt.train_img_dir)
+        self.IMAGE_DIR = opt.train_img_dir
+        self.IMAGE_ID_LIST = [f for f in listdir(self.IMAGE_DIR) if isfile(join(self.IMAGE_DIR, f))]
+
+        self.transforms = transforms.Compose([transforms.Resize((opt.fineSize, opt.fineSize), interpolation=2),
+                                              transforms.ToTensor()])
+        self.final_size = opt.fineSize
+        self.box_num = box_num
+
+    def __getitem__(self, index):
+        pred_info_path = join(self.PRED_BBOX_DIR, self.IMAGE_ID_LIST[index].split('.')[0] + '.npz')
+        output_image_path = join(self.IMAGE_DIR, self.IMAGE_ID_LIST[index])
+        pred_bbox = gen_maskrcnn_bbox_fromPred(pred_info_path, self.box_num)
+
+        full_rgb_list = []
+        full_gray_list = []
+        rgb_img, gray_image = gen_gray_color_pil(output_image_path)
+        full_rgb_list.append(self.transforms(rgb_img))
+        full_gray_list.append(self.transforms(gray_image))
+
+        cropped_rgb_list = []
+        cropped_gray_list = []
+        index_list = range(len(pred_bbox))
+        box_info, box_info_2x, box_info_4x, box_info_8x = np.zeros((4, len(index_list), 6))
+        for i in range(len(index_list)):
+            startx, starty, endx, endy = pred_bbox[i]
+            box_info[i] = np.array(get_box_info(pred_bbox[i], rgb_img.size, self.final_size))
+            box_info_2x[i] = np.array(get_box_info(pred_bbox[i], rgb_img.size, self.final_size // 2))
+            box_info_4x[i] = np.array(get_box_info(pred_bbox[i], rgb_img.size, self.final_size // 4))
+            box_info_8x[i] = np.array(get_box_info(pred_bbox[i], rgb_img.size, self.final_size // 8))
+            cropped_rgb_list.append(self.transforms(rgb_img.crop((startx, starty, endx, endy))))
+            cropped_gray_list.append(self.transforms(gray_image.crop((startx, starty, endx, endy))))
+        output = {}
+        output['cropped_rgb'] = torch.stack(cropped_rgb_list)
+        output['cropped_gray'] = torch.stack(cropped_gray_list)
+        output['full_rgb'] = torch.stack(full_rgb_list)
+        output['full_gray'] = torch.stack(full_gray_list)
+        output['box_info'] = torch.from_numpy(box_info).type(torch.long)
+        output['box_info_2x'] = torch.from_numpy(box_info_2x).type(torch.long)
+        output['box_info_4x'] = torch.from_numpy(box_info_4x).type(torch.long)
+        output['box_info_8x'] = torch.from_numpy(box_info_8x).type(torch.long)
+        output['file_id'] = self.IMAGE_ID_LIST[index]
+        return output
+
     def __len__(self):
         return len(self.IMAGE_ID_LIST)
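For reference, a minimal sketch of how the fusion-stage dataset above is consumed, mirroring `train.py` (which uses a batch size of 1 for this stage). The exact tensor shapes depend on `--fineSize` and on how many boxes `gen_maskrcnn_bbox_fromPred` returns for the image.

```python
import torch
from options.train_options import TrainOptions
from fusion_dataset import Training_Fusion_Dataset

opt = TrainOptions().parse()  # e.g. --stage fusion --train_img_dir train_data/train2017
dataset = Training_Fusion_Dataset(opt)
loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=8)

sample = next(iter(loader))
print(sample['full_rgb'].shape)     # [1, 1, 3, fineSize, fineSize]
print(sample['cropped_rgb'].shape)  # [1, num_boxes, 3, fineSize, fineSize], num_boxes <= box_num (8)
print(sample['box_info'].shape)     # [1, num_boxes, 6]
```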

‎image_util.py

+13
@@ -3,6 +3,19 @@
 from skimage import color
 import torch
 
+def gen_gray_color_pil(color_img_path):
+    '''
+    return: RGB and GRAY pillow image object
+    '''
+    rgb_img = Image.open(color_img_path)
+    if len(np.asarray(rgb_img).shape) == 2:
+        rgb_img = np.stack([np.asarray(rgb_img), np.asarray(rgb_img), np.asarray(rgb_img)], 2)
+        rgb_img = Image.fromarray(rgb_img)
+    gray_img = np.round(color.rgb2gray(np.asarray(rgb_img)) * 255.0).astype(np.uint8)
+    gray_img = np.stack([gray_img, gray_img, gray_img], -1)
+    gray_img = Image.fromarray(gray_img)
+    return rgb_img, gray_img
+
 def read_to_pil(img_path):
     '''
     return: pillow image object HxWx3
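A quick sketch of what the new helper returns; the image path is only an illustrative placeholder.

```python
import numpy as np
from image_util import gen_gray_color_pil

# Hypothetical path; grayscale inputs are stacked to 3 channels before conversion.
rgb_img, gray_img = gen_gray_color_pil('example/your_image.jpg')
print(np.asarray(rgb_img).shape, np.asarray(gray_img).shape)  # both (H, W, 3) PIL images
```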

‎inference_bbox.py

+7-1
@@ -17,6 +17,7 @@
 from detectron2.config import get_cfg
 
 import torch
+from tqdm import tqdm
 
 cfg = get_cfg()
 cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml"))
@@ -26,6 +27,7 @@
 
 parser = ArgumentParser()
 parser.add_argument("--test_img_dir", type=str, default='example', help='testing images folder')
+parser.add_argument('--filter_no_obj', action='store_true')
 args = parser.parse_args()
 
 input_dir = args.test_img_dir
@@ -35,7 +37,7 @@
     print('Create path: {0}'.format(output_npz_dir))
     os.makedirs(output_npz_dir)
 
-for image_path in image_list:
+for image_path in tqdm(image_list):
     img = cv2.imread(join(input_dir, image_path))
     lab_image = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)
     l_channel, a_channel, b_channel = cv2.split(lab_image)
@@ -44,4 +46,8 @@
     save_path = join(output_npz_dir, image_path.split('.')[0])
     pred_bbox = outputs["instances"].pred_boxes.to(torch.device('cpu')).tensor.numpy()
     pred_scores = outputs["instances"].scores.cpu().data.numpy()
+    if args.filter_no_obj is True and pred_bbox.shape[0] == 0:
+        print('delete {0}'.format(image_path))
+        os.remove(join(input_dir, image_path))
+        continue
     np.savez(save_path, bbox = pred_bbox, scores = pred_scores)

‎models/base_model.py

-1
@@ -41,7 +41,6 @@ def setup(self, opt, parser=None):
 
         if not self.isTrain or opt.load_model:
             self.load_networks(opt.which_epoch)
-            # self.print_networks(opt.verbose)
 
     # make models eval mode during test time
     def eval(self):

‎models/networks.py

+40
@@ -3,6 +3,7 @@
 from torch.nn import init
 import functools
 import torch.nn.functional as F
+from torch.optim import lr_scheduler
 
 
 def get_norm_layer(norm_type='instance'):
@@ -17,6 +18,21 @@ def get_norm_layer(norm_type='instance'):
     return norm_layer
 
 
+def get_scheduler(optimizer, opt):
+    if opt.lr_policy == 'lambda':
+        def lambda_rule(epoch):
+            lr_l = 1.0 - max(0, epoch + 1 + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1)
+            return lr_l
+        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
+    elif opt.lr_policy == 'step':
+        scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
+    elif opt.lr_policy == 'plateau':
+        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
+    else:
+        return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy)
+    return scheduler
+
+
 def init_weights(net, init_type='xavier', gain=0.02):
     def init_func(m):
         classname = m.__class__.__name__
@@ -65,6 +81,30 @@ def define_G(input_nc, output_nc, ngf, which_model_netG, norm='batch', use_dropo
     return init_net(netG, init_type, gpu_ids)
 
 
+class HuberLoss(nn.Module):
+    def __init__(self, delta=.01):
+        super(HuberLoss, self).__init__()
+        self.delta = delta
+
+    def __call__(self, in0, in1):
+        mask = torch.zeros_like(in0)
+        mann = torch.abs(in0 - in1)
+        eucl = .5 * (mann**2)
+        mask[...] = mann < self.delta
+
+        # loss = eucl*mask + self.delta*(mann-.5*self.delta)*(1-mask)
+        loss = eucl * mask / self.delta + (mann - .5 * self.delta) * (1 - mask)
+        return torch.sum(loss, dim=1, keepdim=True)
+
+
+class L1Loss(nn.Module):
+    def __init__(self):
+        super(L1Loss, self).__init__()
+
+    def __call__(self, in0, in1):
+        return torch.sum(torch.abs(in0 - in1), dim=1, keepdim=True)
+
+
 class SIGGRAPHGenerator(nn.Module):
     def __init__(self, input_nc, output_nc, norm_layer=nn.BatchNorm2d, use_tanh=True, classification=True):
         super(SIGGRAPHGenerator, self).__init__()
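As a reference for the new `get_scheduler` helper, a minimal sketch with a hand-built options namespace; the field values below are only placeholders, since in practice `train.py` supplies them through the training options.

```python
import torch
from argparse import Namespace
from models.networks import get_scheduler

net = torch.nn.Linear(2, 2)
optimizer = torch.optim.Adam(net.parameters(), lr=0.0005)

# Placeholder option values for illustration only.
opt = Namespace(lr_policy='lambda', epoch_count=1, niter=100, niter_decay=50)
scheduler = get_scheduler(optimizer, opt)

for epoch in range(opt.niter + opt.niter_decay):
    optimizer.step()
    scheduler.step()  # keeps the lr constant for niter epochs, then decays it linearly over niter_decay epochs
```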

‎models/train_model.py

+182
@@ -0,0 +1,182 @@
import os

import torch
from collections import OrderedDict
from util.image_pool import ImagePool
from util import util
from .base_model import BaseModel
from . import networks
import numpy as np
from skimage import io
from skimage import img_as_ubyte

import matplotlib.pyplot as plt
import math
from matplotlib import colors


class TrainModel(BaseModel):
    def name(self):
        return 'TrainModel'

    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        return parser

    def initialize(self, opt):
        BaseModel.initialize(self, opt)
        self.loss_names = ['G', 'L1']
        # load/define networks
        num_in = opt.input_nc + opt.output_nc + 1
        self.optimizers = []
        if opt.stage == 'full' or opt.stage == 'instance':
            self.model_names = ['G']
            self.netG = networks.define_G(num_in, opt.output_nc, opt.ngf,
                                          'siggraph', opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids,
                                          use_tanh=True, classification=opt.classification)
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizers.append(self.optimizer_G)
        elif opt.stage == 'fusion':
            self.model_names = ['G', 'GF', 'GComp']
            self.netG = networks.define_G(num_in, opt.output_nc, opt.ngf,
                                          'instance', opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids,
                                          use_tanh=True, classification=False)
            self.netG.eval()

            self.netGF = networks.define_G(num_in, opt.output_nc, opt.ngf,
                                           'fusion', opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids,
                                           use_tanh=True, classification=False)
            self.netGF.eval()

            self.netGComp = networks.define_G(num_in, opt.output_nc, opt.ngf,
                                              'siggraph', opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids,
                                              use_tanh=True, classification=opt.classification)
            self.netGComp.eval()
            self.optimizer_G = torch.optim.Adam(list(self.netGF.module.weight_layer.parameters()) +
                                                list(self.netGF.module.weight_layer2.parameters()) +
                                                list(self.netGF.module.weight_layer3.parameters()) +
                                                list(self.netGF.module.weight_layer4.parameters()) +
                                                list(self.netGF.module.weight_layer5.parameters()) +
                                                list(self.netGF.module.weight_layer6.parameters()) +
                                                list(self.netGF.module.weight_layer7.parameters()) +
                                                list(self.netGF.module.weight_layer8_1.parameters()) +
                                                list(self.netGF.module.weight_layer8_2.parameters()) +
                                                list(self.netGF.module.weight_layer9_1.parameters()) +
                                                list(self.netGF.module.weight_layer9_2.parameters()) +
                                                list(self.netGF.module.weight_layer10_1.parameters()) +
                                                list(self.netGF.module.weight_layer10_2.parameters()) +
                                                list(self.netGF.module.model10.parameters()) +
                                                list(self.netGF.module.model_out.parameters()),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizers.append(self.optimizer_G)
        else:
            print('Error Stage!')
            exit()
        self.criterionL1 = networks.HuberLoss(delta=1. / opt.ab_norm)
        # self.criterionL1 = networks.L1Loss()

        # initialize average loss values
        self.avg_losses = OrderedDict()
        self.avg_loss_alpha = opt.avg_loss_alpha
        self.error_cnt = 0
        for loss_name in self.loss_names:
            self.avg_losses[loss_name] = 0

    def set_input(self, input):
        AtoB = self.opt.which_direction == 'AtoB'
        self.real_A = input['A' if AtoB else 'B'].to(self.device)
        self.real_B = input['B' if AtoB else 'A'].to(self.device)
        self.hint_B = input['hint_B'].to(self.device)

        self.mask_B = input['mask_B'].to(self.device)
        self.mask_B_nc = self.mask_B + self.opt.mask_cent

        self.real_B_enc = util.encode_ab_ind(self.real_B[:, :, ::4, ::4], self.opt)

    def set_fusion_input(self, input, box_info):
        AtoB = self.opt.which_direction == 'AtoB'
        self.full_real_A = input['A' if AtoB else 'B'].to(self.device)
        self.full_real_B = input['B' if AtoB else 'A'].to(self.device)

        self.full_hint_B = input['hint_B'].to(self.device)
        self.full_mask_B = input['mask_B'].to(self.device)

        self.full_mask_B_nc = self.full_mask_B + self.opt.mask_cent
        self.full_real_B_enc = util.encode_ab_ind(self.full_real_B[:, :, ::4, ::4], self.opt)
        self.box_info_list = box_info

    def forward(self):
        if self.opt.stage == 'full' or self.opt.stage == 'instance':
            (_, self.fake_B_reg) = self.netG(self.real_A, self.hint_B, self.mask_B)
        elif self.opt.stage == 'fusion':
            (_, self.comp_B_reg) = self.netGComp(self.full_real_A, self.full_hint_B, self.full_mask_B)
            (_, feature_map) = self.netG(self.real_A, self.hint_B, self.mask_B)
            self.fake_B_reg = self.netGF(self.full_real_A, self.full_hint_B, self.full_mask_B, feature_map, self.box_info_list)
        else:
            print('Error! Wrong stage selection!')
            exit()

    def optimize_parameters(self):
        self.forward()
        self.optimizer_G.zero_grad()
        if self.opt.stage == 'full' or self.opt.stage == 'instance':
            self.loss_L1 = torch.mean(self.criterionL1(self.fake_B_reg.type(torch.cuda.FloatTensor),
                                                       self.real_B.type(torch.cuda.FloatTensor)))
            self.loss_G = 10 * torch.mean(self.criterionL1(self.fake_B_reg.type(torch.cuda.FloatTensor),
                                                           self.real_B.type(torch.cuda.FloatTensor)))
        elif self.opt.stage == 'fusion':
            self.loss_L1 = torch.mean(self.criterionL1(self.fake_B_reg.type(torch.cuda.FloatTensor),
                                                       self.full_real_B.type(torch.cuda.FloatTensor)))
            self.loss_G = 10 * torch.mean(self.criterionL1(self.fake_B_reg.type(torch.cuda.FloatTensor),
                                                           self.full_real_B.type(torch.cuda.FloatTensor)))
        else:
            print('Error! Wrong stage selection!')
            exit()
        self.loss_G.backward()
        self.optimizer_G.step()

    def get_current_visuals(self):
        from collections import OrderedDict
        visual_ret = OrderedDict()
        if self.opt.stage == 'full' or self.opt.stage == 'instance':
            visual_ret['gray'] = util.lab2rgb(torch.cat((self.real_A.type(torch.cuda.FloatTensor), torch.zeros_like(self.real_B).type(torch.cuda.FloatTensor)), dim=1), self.opt)
            visual_ret['real'] = util.lab2rgb(torch.cat((self.real_A.type(torch.cuda.FloatTensor), self.real_B.type(torch.cuda.FloatTensor)), dim=1), self.opt)
            visual_ret['fake_reg'] = util.lab2rgb(torch.cat((self.real_A.type(torch.cuda.FloatTensor), self.fake_B_reg.type(torch.cuda.FloatTensor)), dim=1), self.opt)

            visual_ret['hint'] = util.lab2rgb(torch.cat((self.real_A.type(torch.cuda.FloatTensor), self.hint_B.type(torch.cuda.FloatTensor)), dim=1), self.opt)
            visual_ret['real_ab'] = util.lab2rgb(torch.cat((torch.zeros_like(self.real_A.type(torch.cuda.FloatTensor)), self.real_B.type(torch.cuda.FloatTensor)), dim=1), self.opt)
            visual_ret['fake_ab_reg'] = util.lab2rgb(torch.cat((torch.zeros_like(self.real_A.type(torch.cuda.FloatTensor)), self.fake_B_reg.type(torch.cuda.FloatTensor)), dim=1), self.opt)

        elif self.opt.stage == 'fusion':
            visual_ret['gray'] = util.lab2rgb(torch.cat((self.full_real_A.type(torch.cuda.FloatTensor), torch.zeros_like(self.full_real_B).type(torch.cuda.FloatTensor)), dim=1), self.opt)
            visual_ret['real'] = util.lab2rgb(torch.cat((self.full_real_A.type(torch.cuda.FloatTensor), self.full_real_B.type(torch.cuda.FloatTensor)), dim=1), self.opt)
            visual_ret['comp_reg'] = util.lab2rgb(torch.cat((self.full_real_A.type(torch.cuda.FloatTensor), self.comp_B_reg.type(torch.cuda.FloatTensor)), dim=1), self.opt)
            visual_ret['fake_reg'] = util.lab2rgb(torch.cat((self.full_real_A.type(torch.cuda.FloatTensor), self.fake_B_reg.type(torch.cuda.FloatTensor)), dim=1), self.opt)

            self.instance_mask = torch.nn.functional.interpolate(torch.zeros([1, 1, 176, 176]), size=visual_ret['gray'].shape[2:], mode='bilinear').type(torch.cuda.FloatTensor)
            visual_ret['box_mask'] = torch.cat((self.instance_mask, self.instance_mask, self.instance_mask), 1)
            visual_ret['real_ab'] = util.lab2rgb(torch.cat((torch.zeros_like(self.full_real_A.type(torch.cuda.FloatTensor)), self.full_real_B.type(torch.cuda.FloatTensor)), dim=1), self.opt)
            visual_ret['comp_ab_reg'] = util.lab2rgb(torch.cat((torch.zeros_like(self.full_real_A.type(torch.cuda.FloatTensor)), self.comp_B_reg.type(torch.cuda.FloatTensor)), dim=1), self.opt)
            visual_ret['fake_ab_reg'] = util.lab2rgb(torch.cat((torch.zeros_like(self.full_real_A.type(torch.cuda.FloatTensor)), self.fake_B_reg.type(torch.cuda.FloatTensor)), dim=1), self.opt)
        else:
            print('Error! Wrong stage selection!')
            exit()
        return visual_ret

    # return training losses/errors. train.py will print out these errors as debugging information
    def get_current_losses(self):
        self.error_cnt += 1
        errors_ret = OrderedDict()
        for name in self.loss_names:
            if isinstance(name, str):
                # float(...) works for both scalar tensor and float number
                self.avg_losses[name] = float(getattr(self, 'loss_' + name)) + self.avg_loss_alpha * self.avg_losses[name]
                errors_ret[name] = (1 - self.avg_loss_alpha) / (1 - self.avg_loss_alpha**self.error_cnt) * self.avg_losses[name]
        return errors_ret

    def save_fusion_epoch(self, epoch):
        path = '{0}/{1}_net_GF.pth'.format(os.path.join(self.opt.checkpoints_dir, self.opt.name), epoch)
        latest_path = '{0}/latest_net_GF.pth'.format(os.path.join(self.opt.checkpoints_dir, self.opt.name))
        torch.save(self.netGF.state_dict(), path)
        torch.save(self.netGF.state_dict(), latest_path)

‎options/train_options.py

+4
@@ -4,6 +4,10 @@
 class TrainOptions(BaseOptions):
     def initialize(self, parser):
         BaseOptions.initialize(self, parser)
+        parser.add_argument('--stage', type=str, default='full', help='only full, instance or fusion')
+        parser.add_argument('--train_img_dir', type=str, default='train_data/train2017', help='training images folder')
+        parser.add_argument('--model', type=str, default='train', help='only train_model need to be used')
+        parser.add_argument('--name', type=str, default='coco_mask', help='name of the experiment. It decides where to store samples and models')
         parser.add_argument('--display_freq', type=int, default=2000, help='frequency of showing training results on screen')
         parser.add_argument('--display_ncols', type=int, default=5, help='if positive, display all images in a single visdom web panel with certain number of images per row.')
         parser.add_argument('--update_html_freq', type=int, default=10000, help='frequency of saving training results to html')

‎scripts/prepare_cocostuff.sh

+5
@@ -0,0 +1,5 @@
DATASET_DIR="train_data"

python download.py --mode cocostuff --dataset_dir $DATASET_DIR
echo "Finish download."
unzip "$DATASET_DIR/cocostuff/train.zip" -d "$DATASET_DIR"

‎scripts/prepare_train_box.sh

+3
@@ -0,0 +1,3 @@
DATASET_DIR=train_data/train2017

python inference_bbox.py --test_img_dir $DATASET_DIR --filter_no_obj

‎scripts/train.sh

+18
@@ -0,0 +1,18 @@
DATASET_DIR=train_data/train2017

# Stage 1: Training Full Image Colorization
mkdir ./checkpoints/coco_full
cp ./checkpoints/siggraph_retrained/latest_net_G.pth ./checkpoints/coco_full/
python train.py --stage full --name coco_full --sample_p 1.0 --niter 100 --niter_decay 50 --load_model --lr 0.0005 --model train --fineSize 256 --batch_size 16 --display_ncols 3 --display_freq 1600 --print_freq 1600 --train_img_dir $DATASET_DIR

# Stage 2: Training Instance Image Colorization
mkdir ./checkpoints/coco_instance
cp ./checkpoints/coco_full/latest_net_G.pth ./checkpoints/coco_instance/
python train.py --stage instance --name coco_instance --sample_p 1.0 --niter 100 --niter_decay 50 --load_model --lr 0.0005 --model train --fineSize 256 --batch_size 16 --display_ncols 3 --display_freq 1600 --print_freq 1600 --train_img_dir $DATASET_DIR

# Stage 3: Training Fusion Module
mkdir ./checkpoints/coco_mask
cp ./checkpoints/coco_full/latest_net_G.pth ./checkpoints/coco_mask/latest_net_GF.pth
cp ./checkpoints/coco_instance/latest_net_G.pth ./checkpoints/coco_mask/latest_net_G.pth
cp ./checkpoints/coco_full/latest_net_G.pth ./checkpoints/coco_mask/latest_net_GComp.pth
python train.py --stage fusion --name coco_mask --sample_p 1.0 --niter 10 --niter_decay 20 --lr 0.00005 --model train --load_model --display_ncols 4 --fineSize 256 --batch_size 1 --display_freq 500 --print_freq 500 --train_img_dir $DATASET_DIR

‎train.py

+113
@@ -0,0 +1,113 @@
import time
from options.train_options import TrainOptions
from models import create_model
from util.visualizer import Visualizer

import torch
import torchvision
import torchvision.transforms as transforms
from tqdm import trange, tqdm

from fusion_dataset import *
from util import util
import os

if __name__ == '__main__':
    opt = TrainOptions().parse()
    if opt.stage == 'full':
        dataset = Training_Full_Dataset(opt)
    elif opt.stage == 'instance':
        dataset = Training_Instance_Dataset(opt)
    elif opt.stage == 'fusion':
        dataset = Training_Fusion_Dataset(opt)
    else:
        print('Error! Wrong stage selection!')
        exit()
    dataset_loader = torch.utils.data.DataLoader(dataset, batch_size=opt.batch_size, shuffle=True, num_workers=8)

    dataset_size = len(dataset)
    print('#training images = %d' % dataset_size)

    model = create_model(opt)
    model.setup(opt)

    opt.display_port = 8098
    visualizer = Visualizer(opt)
    total_steps = 0

    if opt.stage == 'full' or opt.stage == 'instance':
        for epoch in trange(opt.epoch_count, opt.niter + opt.niter_decay, desc='epoch', dynamic_ncols=True):
            epoch_iter = 0

            for data_raw in tqdm(dataset_loader, desc='batch', dynamic_ncols=True, leave=False):
                total_steps += opt.batch_size
                epoch_iter += opt.batch_size

                data_raw['rgb_img'] = [data_raw['rgb_img']]
                data_raw['gray_img'] = [data_raw['gray_img']]

                input_data = util.get_colorization_data(data_raw['gray_img'], opt, p=1.0, ab_thresh=0)
                gt_data = util.get_colorization_data(data_raw['rgb_img'], opt, p=1.0, ab_thresh=10.0)
                if gt_data is None:
                    continue
                if(gt_data['B'].shape[0] < opt.batch_size):
                    continue
                input_data['B'] = gt_data['B']
                input_data['hint_B'] = gt_data['hint_B']
                input_data['mask_B'] = gt_data['mask_B']

                visualizer.reset()
                model.set_input(input_data)
                model.optimize_parameters()

                if total_steps % opt.display_freq == 0:
                    save_result = total_steps % opt.update_html_freq == 0
                    visualizer.display_current_results(model.get_current_visuals(), epoch, save_result)

                if total_steps % opt.print_freq == 0:
                    losses = model.get_current_losses()
                    if opt.display_id > 0:
                        visualizer.plot_current_losses(epoch, float(epoch_iter) / dataset_size, opt, losses)

            if epoch % opt.save_epoch_freq == 0:
                model.save_networks('latest')
                model.save_networks(epoch)
            model.update_learning_rate()
    elif opt.stage == 'fusion':
        for epoch in trange(opt.epoch_count, opt.niter + opt.niter_decay, desc='epoch', dynamic_ncols=True):
            epoch_iter = 0

            for data_raw in tqdm(dataset_loader, desc='batch', dynamic_ncols=True, leave=False):
                total_steps += opt.batch_size
                epoch_iter += opt.batch_size
                box_info = data_raw['box_info'][0]
                box_info_2x = data_raw['box_info_2x'][0]
                box_info_4x = data_raw['box_info_4x'][0]
                box_info_8x = data_raw['box_info_8x'][0]
                cropped_input_data = util.get_colorization_data(data_raw['cropped_gray'], opt, p=1.0, ab_thresh=0)
                cropped_gt_data = util.get_colorization_data(data_raw['cropped_rgb'], opt, p=1.0, ab_thresh=10.0)
                full_input_data = util.get_colorization_data(data_raw['full_gray'], opt, p=1.0, ab_thresh=0)
                full_gt_data = util.get_colorization_data(data_raw['full_rgb'], opt, p=1.0, ab_thresh=10.0)
                if cropped_gt_data is None or full_gt_data is None:
                    continue
                cropped_input_data['B'] = cropped_gt_data['B']
                full_input_data['B'] = full_gt_data['B']
                visualizer.reset()
                model.set_input(cropped_input_data)
                model.set_fusion_input(full_input_data, [box_info, box_info_2x, box_info_4x, box_info_8x])
                model.optimize_parameters()

                if total_steps % opt.display_freq == 0:
                    save_result = total_steps % opt.update_html_freq == 0
                    visualizer.display_current_results(model.get_current_visuals(), epoch, save_result)

                if total_steps % opt.print_freq == 0:
                    losses = model.get_current_losses()
                    if opt.display_id > 0:
                        visualizer.plot_current_losses(epoch, float(epoch_iter) / dataset_size, opt, losses)
            if epoch % opt.save_epoch_freq == 0:
                model.save_fusion_epoch(epoch)
            model.update_learning_rate()
    else:
        print('Error! Wrong stage selection!')
        exit()
