-
Notifications
You must be signed in to change notification settings - Fork 33
/
Copy pathtest_denoising_SIDD.py
62 lines (51 loc) · 1.78 KB
/
test_denoising_SIDD.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Power by Zongsheng Yue 2020-07-10 14:38:39
'''
In this demo, we only test the model on one image of the SIDD validation dataset.
The full validation dataset can be downloaded from the following website:
https://www.eecs.yorku.ca/~kamel/sidd/benchmark.php
'''
import argparse
import torch
from networks import UNetD
from scipy.io import loadmat
from skimage import img_as_float32, img_as_ubyte
from matplotlib import pyplot as plt
from utils import PadUNet
# Command-line interface: a single option selecting which pretrained weights to use.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--model',
    default='DANet+',
    type=str,
    help="Model selection: DANet or DANet+, (default:DANet+)",
)
# Parsed eagerly at module level; this file is meant to be run as a script.
args = parser.parse_args()
# Build the network, preferring the GPU but falling back to the CPU when CUDA
# is unavailable so the model can still be constructed on machines without one.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = UNetD(3, wf=32, depth=5).to(device)

# Load the pretrained weights. Checkpoints are deserialized onto the CPU first;
# load_state_dict then copies the values into the parameters on `device`.
if args.model.lower() == 'danet':
    # The DANet checkpoint is indexed with key 'D' to obtain the state dict.
    net.load_state_dict(torch.load('./model_states/DANet.pt', map_location='cpu')['D'])
else:
    # Any other --model value selects DANet+, whose file is used as the state dict directly.
    net.load_state_dict(torch.load('./model_states/DANetPlus.pt', map_location='cpu'))

# Inference only: put layers such as dropout / batch norm (if the network has
# any) into evaluation mode.
net.eval()
# Read the demo image pair from the .mat files under ./test_data/SIDD:
# the noisy input and its ground-truth counterpart.
noisy_mat = loadmat('./test_data/SIDD/noisy.mat')
im_noisy = noisy_mat['im_noisy']
gt_mat = loadmat('./test_data/SIDD/gt.mat')
im_gt = gt_mat['im_gt']
# Denoising.
# Convert the noisy image to float32, move its last axis to the front and add a
# leading batch axis (presumably H x W x C -> 1 x C x H x W; confirm against the
# .mat contents). The tensor is placed on whatever device the network lives on
# rather than assuming CUDA.
inputs = torch.from_numpy(img_as_float32(im_noisy).transpose([2, 0, 1])).unsqueeze(0)
inputs = inputs.to(next(net.parameters()).device)
with torch.no_grad():
    # NOTE(review): PadUNet presumably pads the input to a spatial size a
    # depth-5 U-Net can process and pad_inverse crops it back -- confirm in
    # utils.PadUNet.
    padunet = PadUNet(inputs, dep_U=5)
    inputs_pad = padunet.pad()
    # The network output is subtracted from the input, i.e. it is treated as
    # the estimated noise.
    outputs_pad = inputs_pad - net(inputs_pad)
    outputs = padunet.pad_inverse(outputs_pad)
    # Keep the result in [0, 1] so the uint8 conversion below is valid.
    outputs.clamp_(0.0, 1.0)
# Back to a uint8 image with the channel axis last, ready for display.
im_denoise = img_as_ubyte(outputs.cpu().numpy()[0,].transpose([1, 2, 0]))
# Show the noisy input, the ground truth and the denoised result side by side
# in a single row, each with a title and without axes.
panels = (
    (im_noisy, 'Noisy Image'),
    (im_gt, 'Gt Image'),
    (im_denoise, 'Denoised Image'),
)
for position, (image, caption) in enumerate(panels, start=1):
    plt.subplot(1, 3, position)
    plt.imshow(image)
    plt.title(caption)
    plt.axis('off')
plt.show()