-
Notifications
You must be signed in to change notification settings - Fork 5
/
denoising.py
180 lines (137 loc) · 5.47 KB
/
denoising.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
import argparse
import warnings
warnings.filterwarnings("ignore")
import torch
import torch.optim as optim
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = True
import utils.funcs as fn
import utils.basic_utils as bu
import utils.image_utils as imu
import utils.model_utils as mu
import utils.denoising_utils as du
from utils.common_utils import get_image_grid
from utils.gpu_utils import gpu_filter
from utils.paths import ROOT, IMG_EXT
from utils.common_types import *
from utils.keywords import *
def parse_args():
    """Build the NAS-DIP denoising command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace with the attributes ``gpu_index``, ``num_gpu``,
        ``cpu``, ``check``, the positional ``img_stem`` (image file name
        without extension), and the hyper-parameters ``sigma``,
        ``exp_weight``, ``lr``, ``reg_noise_std``, ``num_iter``,
        ``atleast`` and ``show_every``.
    """
    cli = argparse.ArgumentParser(description='NAS-DIP Denoising')
    # (name-or-flag, keyword options) in the order they are registered;
    # registration order determines the layout of ``--help``.
    argument_specs = (
        ('--gpu_index', {'default': None, 'type': int}),
        ('--num_gpu', {'type': int, 'default': 12}),
        ('--cpu', {'action': 'store_true'}),
        ('--check', {'action': 'store_true'}),
        ('img_stem', {'type': str}),
        ('--sigma', {'default': 25, 'type': int}),
        ('--exp_weight', {'default': 0.99, 'type': float}),
        ('--lr', {'default': 0.01, 'type': float}),
        ('--reg_noise_std', {'default': 1. / 30., 'type': float}),
        ('--num_iter', {'default': 4000, 'type': int}),
        ('--atleast', {'type': int, 'default': 500}),
        ('--show_every', {'default': 1, 'type': int}),
    )
    for name_or_flag, options in argument_specs:
        cli.add_argument(name_or_flag, **options)
    return cli.parse_args()
def _results_exist(modeldir):
    """Return True if every output file that ``main`` saves for a model already exists."""
    # Keep in sync with the ``modeldir[...].save(...)`` calls at the end of the
    # per-model loop in ``main``. ``all`` checks these in order and stops at
    # the first missing file, like the original chained ``and``.
    result_files = (
        'htr.pkl',
        'grid.png',
        'img_noisy.npy',
        'input_noise.npy',
        'psnr_noisy.pkl',
    )
    return all(modeldir[fname].exists() for fname in result_files)


def main():
    """Denoise one image with Deep Image Prior for every selected model and save the results.

    Steps:
      * Load the clean image and its noisy version (noise level ``--sigma``).
      * Draw one uniform input-noise tensor that is shared by all models.
      * For each model name returned by ``gpu_filter(--gpu_index, --num_gpu)``,
        build the model and run ``du.denoising`` with Adam.
      * Save, under ``ROOT[BENCHMARK][DENOISING][SIGMA][IMG_STEM][DATA][model_name]``,
        the training history, the noisy image, the input noise, the PSNR of
        the noisy image, and an image grid of 10 panels (``nrow=5``).

    With ``--check``, a model is skipped when all of its output files already
    exist.
    """
    args = parse_args()
    GPU_INDEX = args.gpu_index
    NUM_GPU = args.num_gpu
    CPU = args.cpu
    DTYPE = torch.FloatTensor if CPU else torch.cuda.FloatTensor
    CHECK = args.check
    IMG_STEM = args.img_stem
    SIGMA: int = args.sigma  # this is for images with pixel values in the range [0, 255]
    EXP_WEIGHT = args.exp_weight
    LR = args.lr
    REG_NOISE_STD = args.reg_noise_std
    NUM_ITER = args.num_iter
    ATLEAST = args.atleast
    SHOW_EVERY = args.show_every

    # the stem is the name of a file without its extension
    img_name = IMG_STEM + IMG_EXT

    # load the clean and the noisy image
    img_true_np = bu.read_true_image(DENOISING, IMG_STEM)
    img_noisy_np, _ = bu.read_noisy_image(IMG_STEM, sigma=SIGMA)
    # these do not depend on the model, so they are computed once, outside the loop
    img_true_np_psd_db_norm = fn.psd_db_norm(img_true_np)
    img_noisy_np_psd_db_norm = fn.psd_db_norm(img_noisy_np)
    img_true_torch = imu.np_to_torch(img_true_np).type(DTYPE)
    img_noisy_torch = imu.np_to_torch(img_noisy_np).type(DTYPE)
    psnr_noisy = fn.psnr(img_true_np, img_noisy_np)
    # NOTE(review): assumes channel-first arrays, so shape[0] is the channel count
    out_channels = img_noisy_np.shape[0]
    print(f'Image {img_name} is loaded.')
    print(f'Shape: {img_true_np.shape}.')
    print(f'PSNR of the noisy image: {psnr_noisy:.2f} dB.')
    print()

    # the same input noise is used for all models
    input_noise = du.get_noise_like(img_true_torch, 1/10, 'uniform').detach()
    input_noise_np = imu.torch_to_np(input_noise)
    input_noise_np_psd_db_norm = fn.psd_db_norm(input_noise_np)
    in_channels = input_noise_np.shape[0]
    print(f'input noise shape: {input_noise.shape}.')
    print()

    # select the models to be processed by this run
    model_names = gpu_filter(GPU_INDEX, NUM_GPU)
    num_models = len(model_names)
    print(f'{num_models} models will be processed.\n')

    # the results are saved below this directory
    datadir = ROOT[BENCHMARK][DENOISING][SIGMA][IMG_STEM]

    # train the models one after another
    print('Starting the DIP process...')
    for i, model_name in enumerate(model_names, start=1):
        print(f'{i:03}/{num_models:03}: {model_name}')
        modeldir = datadir[DATA][model_name]

        # with --check, do not recompute models whose results already exist
        if CHECK and _results_exist(modeldir):
            print('Necessary files already exists - skipped.\n')
            continue

        # create the model
        model = mu.create_model(
            model_name, in_channels=in_channels, out_channels=out_channels
        ).type(DTYPE)
        print('Model is created.')

        print('Starting optimization with ADAM.')
        optimizer = optim.Adam(model.parameters(), lr=LR)

        # denoising
        htr = du.denoising(
            model=model,
            optimizer=optimizer,
            img_true_np=img_true_np,
            img_noisy_torch=img_noisy_torch,
            input_noise=input_noise,
            num_iter=NUM_ITER,
            atleast=ATLEAST,
            exp_weight=EXP_WEIGHT,
            reg_noise_std=REG_NOISE_STD,
            show_every=SHOW_EVERY,
        )

        # first row: the images; second row: fn.psd_db_norm of each image
        grid = get_image_grid(
            [
                input_noise_np,
                img_true_np,
                img_noisy_np,
                htr['best_out'],
                htr['best_out_sm'],
                input_noise_np_psd_db_norm,
                img_true_np_psd_db_norm,
                img_noisy_np_psd_db_norm,
                fn.psd_db_norm(htr['best_out']),
                fn.psd_db_norm(htr['best_out_sm']),
            ],
            nrow=5,
        )

        # save the results (the file names must match _results_exist)
        modeldir['htr.pkl'].save(htr)
        modeldir['img_noisy.npy'].save(img_noisy_np)
        modeldir['input_noise.npy'].save(input_noise_np)
        modeldir['psnr_noisy.pkl'].save(psnr_noisy)
        modeldir['grid.png'].save(grid)
        print('Results are saved.\n')


if __name__ == '__main__':
    main()