Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1 +1,18 @@
# abc
# Image Colorization图片上色

完善中

## 项目文件说明

``` text
|-- README.md
|-- requirements.txt # 相关依赖库
|-- config
| `-- places10.yaml # 训练配置
|-- loader.py # 读取YAML训练配置并启动训练
`-- src
|-- colnet.py # colnet模型
|-- dataset.py
|-- trainer.py
`-- utils.py
```
19 changes: 19 additions & 0 deletions config/places10.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
---
epochs: 45
batch_size: 32

# divisor of net output sizes. DEFAULT: 1.
net_divisor: 1
learning_rate: 0.001

# number of workers in trainloader. DEFAULT: 4
num_workers: 12

img_dir_train: ./data/places10/train
img_dir_val: ./data/places10/val
img_dir_test: ./data/places10/test

# a directory where colorized images are saved.
img_out_dir: ./out/places10
# a directory to which models are saved.
models_dir: ./model/places10
97 changes: 97 additions & 0 deletions loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
"""This module is responsible for reading configuration from YAML file."""


help = """
Structure of YAML file:
model_checkpoint: (optional) path to a checkpoint of a net state.
If given, training resume on based on rest of parameters.
models_dir: a directory to which models are saved. DEFAULT: ../model
img_out_dir: a directory where colorized images are saved. DEFAULT: ../out
epochs: total number of epoches for model yo run.
batch_size: batch size for train, test and dev sets.
net_divisor: (optional) divisor of net optput sizes. DEFAULT: 1.
learning_rate: (optional) learning rate of a net. DEFAULT: 0.0001.

img_dir_train: name of directory containing images for TRAINING.
img_dir_val: name of directory containing images for VALIDATING.
img_dir_test: name of directory containing images for TESTING.

num_workers: number of workers in trainloader. DEFAULT: 4
"""

import yaml
import argparse
from src.trainer import Training
from src.colnet import ColNet


def load_config(config_file, model_checkpoint=None):
    """Loads training configuration from a YAML file.

    Args:
        config_file: path to the .yaml config file.
        model_checkpoint: optional path to a checkpoint of a net state;
            a 'model_checkpoint' key in the YAML file takes precedence.

    Returns:
        Instance of Training environment built from the config.

    Raises:
        KeyError: if a required key (batch_size, epochs, img_dir_train,
            img_dir_val, img_dir_test) is missing from the YAML file.
    """

    with open(config_file, 'r') as conf:
        # safe_load: yaml.load() without an explicit Loader is unsafe
        # (arbitrary object construction) and is an error in PyYAML >= 6.
        y = yaml.safe_load(conf)

    # Optional keys fall back to their documented defaults.
    net_divisor = y.get('net_divisor', 1)
    learning_rate = y.get('learning_rate', 0.0001)
    num_workers = y.get('num_workers', 4)
    models_dir = y.get('models_dir', './model/')
    img_out_dir = y.get('img_out_dir', './out/')
    # YAML overrides the caller-supplied checkpoint when present.
    model_checkpoint = y.get('model_checkpoint', model_checkpoint)

    train = Training(batch_size=y['batch_size'],
                     epochs=y['epochs'],
                     img_dir_train=y['img_dir_train'],
                     img_dir_val=y['img_dir_val'],
                     img_dir_test=y['img_dir_test'],
                     net_divisor=net_divisor,
                     learning_rate=learning_rate,
                     model_checkpoint=model_checkpoint,
                     num_workers=num_workers,
                     models_dir=models_dir,
                     img_out_dir=img_out_dir)
    return train



if __name__ == "__main__":
# 从YAML中导入网络模型的配置
short_desc = 'Loads network configuration from YAML file.\n'

parser = argparse.ArgumentParser(description=short_desc + help,
# 将description以输入格式输出,不会合并为一行
formatter_class=argparse.RawDescriptionHelpFormatter)
# config是positional argument
parser.add_argument('config', metavar='config', help='Path to .yaml config file')
# --model是optional argument
parser.add_argument('--model', help='Path to pretrained .pt model')
args = parser.parse_args()

t = load_config(args.config, args.model)
t.info()
t.run()
t.test()

211 changes: 211 additions & 0 deletions src/colnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


def Conv2d(in_ch, out_ch, stride, kernel_size=3, padding=1):
    """Build a 2-D convolution layer.

    Thin wrapper around ``nn.Conv2d`` with the argument order and the
    defaults (3x3 kernel, padding 1) used throughout this network.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
        stride: convolution stride.
        kernel_size: square kernel size (default 3).
        padding: zero padding on each side (default 1).

    Returns:
        A freshly constructed ``nn.Conv2d`` module.
    """
    return nn.Conv2d(
        in_channels=in_ch,
        out_channels=out_ch,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
    )


class LowLevelFeatures(nn.Module):
    """Low-Level Features Network.

    Six ReLU-activated 3x3 convolutions; the three stride-2 layers each
    halve the spatial size, so an HxW single-channel input produces an
    H/8 x W/8 map with 512 // net_divisor channels.
    """

    def __init__(self, net_divisor=1):
        super(LowLevelFeatures, self).__init__()

        # Channel widths, optionally shrunk for debugging. The input is
        # always a single (luminance) channel regardless of the divisor.
        widths = np.array([1, 64, 128, 128, 256, 256, 512]) // net_divisor
        widths[0] = 1

        # conv1..conv6 alternate stride 2 / stride 1.
        strides = [2, 1, 2, 1, 2, 1]
        for idx, stride in enumerate(strides, start=1):
            layer = nn.Conv2d(in_channels=widths[idx - 1],
                              out_channels=widths[idx],
                              kernel_size=3, stride=stride, padding=1)
            setattr(self, 'conv{}'.format(idx), layer)

    def forward(self, x):
        out = x
        for idx in range(1, 7):
            out = F.relu(getattr(self, 'conv{}'.format(idx))(out))
        return out


class MidLevelFeatures(nn.Module):
    """Mid-Level Features Network.

    Two stride-1 3x3 convolutions mapping 512 // net_divisor input
    channels down to 256 // net_divisor, preserving spatial size.
    """

    def __init__(self, net_divisor=1):
        super(MidLevelFeatures, self).__init__()

        c_in, c_mid, c_out = np.array([512, 512, 256]) // net_divisor

        self.conv7 = nn.Conv2d(in_channels=c_in, out_channels=c_mid,
                               kernel_size=3, stride=1, padding=1)
        self.conv8 = nn.Conv2d(in_channels=c_mid, out_channels=c_out,
                               kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        features = F.relu(self.conv7(x))
        features = F.relu(self.conv8(features))
        return features


class GlobalFeatures(nn.Module):
    """Global Features Network.

    Four 3x3 convolutions (stride 2, 1, 2, 1) followed by three fully
    connected layers. Expects a 28x28 input feature map so that the
    flattened size is 7 * 7 * channels. Returns both the final global
    feature vector and the intermediate fc2 activation, which feeds the
    classification branch.
    """

    def __init__(self, net_divisor=1):
        super(GlobalFeatures, self).__init__()

        widths = np.array([512, 1024, 512, 256]) // net_divisor
        # Remembered for flattening in forward().
        self.ksize0 = widths[0]

        conv_kwargs = dict(in_channels=widths[0], out_channels=widths[0],
                           kernel_size=3, padding=1)
        self.conv1 = nn.Conv2d(stride=2, **conv_kwargs)
        self.conv2 = nn.Conv2d(stride=1, **conv_kwargs)
        self.conv3 = nn.Conv2d(stride=2, **conv_kwargs)
        self.conv4 = nn.Conv2d(stride=1, **conv_kwargs)

        self.fc1 = nn.Linear(7 * 7 * widths[0], widths[1])
        self.fc2 = nn.Linear(widths[1], widths[2])
        self.fc3 = nn.Linear(widths[2], widths[3])

    def forward(self, x):
        """Return ``(global_features, classification_input)``."""
        feats = F.relu(self.conv1(x))
        feats = F.relu(self.conv2(feats))
        feats = F.relu(self.conv3(feats))
        feats = F.relu(self.conv4(feats))

        # Flatten to (batch, 7*7*C) for the fully connected stack.
        flat = feats.view(-1, 7 * 7 * self.ksize0)
        hidden = F.relu(self.fc1(flat))
        hidden = F.relu(self.fc2(hidden))

        # fc2's activation doubles as the classifier branch's input.
        return F.relu(self.fc3(hidden)), hidden


class ColorizationNetwork(nn.Module):
    """Colorization Network.

    Five stride-1 3x3 convolutions interleaved with three nearest
    neighbour 2x upsamplings, ending in a sigmoid over 2 chrominance
    channels. The spatial size grows by a factor of 8 overall.
    """

    def __init__(self, net_divisor=1):
        super(ColorizationNetwork, self).__init__()

        widths = np.array([256, 128, 64, 64, 32]) // net_divisor

        def conv3x3(c_in, c_out):
            # All colorization convs share the same 3x3 / stride-1 shape.
            return nn.Conv2d(in_channels=c_in, out_channels=c_out,
                             kernel_size=3, stride=1, padding=1)

        self.conv9 = conv3x3(widths[0], widths[1])
        # 2x upsample happens between conv9 and conv10 (see forward()).
        self.conv10 = conv3x3(widths[1], widths[2])
        self.conv11 = conv3x3(widths[2], widths[3])
        # 2x upsample happens between conv11 and conv12.
        self.conv12 = conv3x3(widths[3], widths[4])
        self.conv13 = conv3x3(widths[4], 2)

    def forward(self, x):
        out = F.relu(self.conv9(x))
        out = F.interpolate(input=out, scale_factor=2)    # upsample #1

        out = F.relu(self.conv10(out))
        out = F.relu(self.conv11(out))
        out = F.interpolate(input=out, scale_factor=2)    # upsample #2

        out = F.relu(self.conv12(out))
        # Sigmoid bounds the chrominance output to (0, 1).
        out = torch.sigmoid(self.conv13(out))
        return F.interpolate(input=out, scale_factor=2)   # upsample #3


class ClassNet(nn.Module):
    """Classification head.

    Two fully connected layers mapping the 512 // net_divisor global
    feature vector to raw per-class scores (no softmax applied).
    """

    def __init__(self, num_classes, net_divisor=1):
        super(ClassNet, self).__init__()

        # Kept as an attribute for callers that inspect the head size.
        self.num_classes = num_classes
        dim_in, dim_hidden = np.array([512, 256]) // net_divisor

        self.fc1 = nn.Linear(dim_in, dim_hidden)
        self.fc2 = nn.Linear(dim_hidden, num_classes)

    def forward(self, x):
        hidden = F.relu(self.fc1(x))
        return self.fc2(hidden)


class ColNet(nn.Module):
    """Full colorization network.

    Wires together the low-, mid- and global-feature subnetworks, fuses
    the mid-level feature map with the broadcast global feature vector,
    and decodes the fused volume into 2 chrominance channels. Also
    exposes a classification head fed from the global branch.
    """

    def __init__(self, num_classes, net_divisor=1):
        """Initializes the network.

        Args:
            num_classes: number of classes for the classifier head.
            net_divisor: divisor of net output sizes. Useful for debugging.
        """
        super(ColNet, self).__init__()

        self.net_divisor = net_divisor

        # 1x1 convolution that mixes the concatenated mid-level
        # (256 // d) and broadcast global (256 // d) features back down
        # to 256 // d channels.
        self.conv_fuse = Conv2d(512 // net_divisor, 256 // net_divisor,
                                1, kernel_size=1, padding=0)

        self.low = LowLevelFeatures(net_divisor)
        self.mid = MidLevelFeatures(net_divisor)
        self.glob = GlobalFeatures(net_divisor)
        self.col = ColorizationNetwork(net_divisor)
        self.classifier = ClassNet(num_classes, net_divisor)

    def fusion_layer(self, mid_out, glob_out):
        """Fuse a (N, C, H, W) mid-level map with a (N, C) global vector."""
        height = mid_out.shape[2]
        width = mid_out.shape[3]

        # Broadcast the global vector across every spatial position,
        # giving a (N, C, H, W) volume aligned with mid_out.
        glob_volume = glob_out[:, :, None, None].expand(-1, -1, height, width)

        # Concatenate along channels and mix with the 1x1 fusion conv.
        fused = torch.cat((mid_out, glob_volume), 1)
        return F.relu(self.conv_fuse(fused))

    def forward(self, x):
        # Shared low-level features feed both branches.
        low_out = self.low(x)

        # Mid-level branch.
        mid_out = self.mid(low_out)

        # Global branch also yields the classifier's input.
        glob_out, classification_in = self.glob(low_out)
        classification_out = self.classifier(classification_in)

        # Fuse the two branches and decode to chrominance channels.
        out = self.col(self.fusion_layer(mid_out, glob_out))

        return out, classification_out
Loading