forked from TomTomTommi/HiNet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest.py
99 lines (75 loc) · 3.08 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import math
import torch
import torch.nn
import torch.optim
import torchvision
import numpy as np
from model import *
import config as c
import datasets
import modules.Unet_common as common
# Pick the first CUDA GPU when one is visible, otherwise fall back to the CPU.
# NOTE(review): other code in this file calls .cuda() directly, so a CPU-only
# run is still expected to fail elsewhere — confirm before relying on this.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
def load(name):
    """Restore the module-level ``net`` and, best-effort, ``optim`` from a checkpoint.

    Args:
        name: Path to a file written with ``torch.save`` containing a dict with a
            ``'net'`` state dict and, optionally, an ``'opt'`` optimizer state dict.

    Entries whose key contains ``'tmp_var'`` are dropped from the network state
    before loading. Failing to restore the optimizer is reported and ignored,
    because this script only uses the network for inference.
    """
    # Deserialize onto the CPU: load_state_dict copies each tensor into the
    # already-placed parameters / optimizer state, so the checkpoint loads no
    # matter which device it was saved from.
    state_dicts = torch.load(name, map_location="cpu")
    network_state_dict = {k: v for k, v in state_dicts['net'].items() if 'tmp_var' not in k}
    net.load_state_dict(network_state_dict)
    try:
        optim.load_state_dict(state_dicts['opt'])
    except Exception as err:  # best effort; unlike a bare except, lets Ctrl-C / SystemExit through
        print(f'Cannot load optimizer for some reason or other ({type(err).__name__}: {err})')
def gauss_noise(shape, device="cuda"):
    """Return a tensor of i.i.d. standard-normal samples.

    Args:
        shape: Target shape (a tuple or ``torch.Size``).
        device: Device to allocate on. Defaults to ``"cuda"`` (the current CUDA
            device), matching the previous hard-coded ``.cuda()`` behaviour.

    Returns:
        A float tensor of the requested shape on ``device``.
    """
    # One batched draw straight on the target device replaces the former
    # per-sample loop of CPU draws followed by a host-to-device copy each.
    return torch.randn(shape, device=device)
def computePSNR(origin, pred, max_val=255.0):
    """Return the peak signal-to-noise ratio (in dB) between two images.

    Args:
        origin: Reference image; anything ``np.asarray`` accepts.
        pred: Image to compare; must have the same shape as ``origin``.
        max_val: Peak possible pixel value. Defaults to 255.0 (8-bit range);
            pass 1.0 for images scaled to [0, 1].

    Returns:
        ``10 * log10(max_val**2 / MSE)`` as a float, or ``100`` when the
        images are numerically identical (MSE < 1e-10).

    Raises:
        ValueError: If the inputs have different shapes or are empty.
    """
    # float64 avoids unsigned-integer wraparound and keeps the mean accurate.
    origin = np.asarray(origin, dtype=np.float64)
    pred = np.asarray(pred, dtype=np.float64)
    # Broadcasting would silently compare mismatched images and yield a meaningless score.
    if origin.shape != pred.shape:
        raise ValueError(f"shape mismatch: origin {origin.shape} vs pred {pred.shape}")
    if origin.size == 0:
        raise ValueError("cannot compute PSNR of empty images")
    mse = np.mean((origin - pred) ** 2)
    if mse < 1.0e-10:
        return 100
    return 10 * math.log10(max_val ** 2 / mse)
# ---------------------------------------------------------------------------
# Build the network, restore the checkpoint, and run the test set through it.
# ---------------------------------------------------------------------------
net = Model()
net.cuda()
init_model(net)
net = torch.nn.DataParallel(net, device_ids=c.device_ids)
# Within this script the optimizer is only used by load(), which restores the
# checkpoint's 'opt' entry alongside the network weights.
params_trainable = list(filter(lambda p: p.requires_grad, net.parameters()))
optim = torch.optim.Adam(params_trainable, lr=c.lr, betas=c.betas, eps=1e-6, weight_decay=c.weight_decay)
load(c.MODEL_PATH + c.suffix)
net.eval()

dwt = common.DWT()
iwt = common.IWT()

with torch.no_grad():
    for i, data in enumerate(datasets.testloader):
        data = data.to(device)
        # The first half of the batch are secrets, the second half covers.
        # With an odd batch size the trailing image is dropped so both halves
        # match for the channel-wise concatenation below; a batch with fewer
        # than two images cannot form a pair and is skipped.
        half = data.shape[0] // 2
        if half == 0:
            continue
        secret = data[:half]
        cover = data[half:2 * half]
        cover_input = dwt(cover)
        secret_input = dwt(secret)
        input_img = torch.cat((cover_input, secret_input), 1)

        #################
        #   forward:    #
        #################
        # The first 4 * channels_in output channels become the stego image
        # after the inverse DWT; the remaining channels are z.
        output = net(input_img)
        output_steg = output.narrow(1, 0, 4 * c.channels_in)
        output_z = output.narrow(1, 4 * c.channels_in, output.shape[1] - 4 * c.channels_in)
        steg_img = iwt(output_steg)

        #################
        #   backward:   #
        #################
        # Replace z with fresh Gaussian noise and run the network in reverse
        # to recover the secret from the stego image alone.
        backward_z = gauss_noise(output_z.shape)
        output_rev = torch.cat((output_steg, backward_z), 1)
        backward_img = net(output_rev, rev=True)
        secret_rev = backward_img.narrow(1, 4 * c.channels_in, backward_img.shape[1] - 4 * c.channels_in)
        secret_rev = iwt(secret_rev)

        torchvision.utils.save_image(cover, c.IMAGE_PATH_cover + '%.5d.png' % i)
        torchvision.utils.save_image(secret, c.IMAGE_PATH_secret + '%.5d.png' % i)
        torchvision.utils.save_image(steg_img, c.IMAGE_PATH_steg + '%.5d.png' % i)
        torchvision.utils.save_image(secret_rev, c.IMAGE_PATH_secret_rev + '%.5d.png' % i)