forked from TomTomTommi/DeepMIH
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
ebb3dab
commit cec74b5
Showing
2 changed files
with
701 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,257 @@ | ||
import warnings | ||
import sys | ||
import math | ||
import os | ||
import torch | ||
import torch.nn | ||
import torch.optim | ||
import torchvision | ||
import numpy as np | ||
import tqdm | ||
# import cv2 | ||
from model import * | ||
from imp_subnet import * | ||
import config as c | ||
from os.path import join | ||
import datasets | ||
import modules.module_util as mutil | ||
import modules.Unet_common as common | ||
|
||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | ||
|
||
|
||
def load(name, net, optim): | ||
state_dicts = torch.load(name) | ||
network_state_dict = {k: v for k, v in state_dicts['net'].items() if 'tmp_var' not in k} | ||
net.load_state_dict(network_state_dict) | ||
try: | ||
optim.load_state_dict(state_dicts['opt']) | ||
except: | ||
print('Cannot load optimizer for some reason or other') | ||
|
||
|
||
def gauss_noise(shape): | ||
noise = torch.zeros(shape).cuda() | ||
for i in range(noise.shape[0]): | ||
noise[i] = torch.randn(noise[i].shape).cuda() | ||
|
||
return noise | ||
|
||
|
||
def computePSNR(origin, pred): | ||
origin = np.array(origin) | ||
origin = origin.astype(np.float32) | ||
pred = np.array(pred) | ||
pred = pred.astype(np.float32) | ||
mse = np.mean((origin / 1.0 - pred / 1.0) ** 2) | ||
if mse < 1.0e-10: | ||
return 100 | ||
return 10 * math.log10(255.0 ** 2 / mse) | ||
|
||
|
||
net1 = Model_1() | ||
net2 = Model_2() | ||
net3 = ImpMapBlock() | ||
net1.cuda() | ||
net2.cuda() | ||
net3.cuda() | ||
init_model(net1) | ||
init_model(net2) | ||
net1 = torch.nn.DataParallel(net1, device_ids=c.device_ids) | ||
net2 = torch.nn.DataParallel(net2, device_ids=c.device_ids) | ||
net3 = torch.nn.DataParallel(net3, device_ids=c.device_ids) | ||
params_trainable1 = (list(filter(lambda p: p.requires_grad, net1.parameters()))) | ||
params_trainable2 = (list(filter(lambda p: p.requires_grad, net2.parameters()))) | ||
params_trainable3 = (list(filter(lambda p: p.requires_grad, net3.parameters()))) | ||
optim1 = torch.optim.Adam(params_trainable1, lr=c.lr, betas=c.betas, eps=1e-6, weight_decay=c.weight_decay) | ||
optim2 = torch.optim.Adam(params_trainable2, lr=c.lr, betas=c.betas, eps=1e-6, weight_decay=c.weight_decay) | ||
optim3 = torch.optim.Adam(params_trainable3, lr=c.lr3, betas=c.betas, eps=1e-6, weight_decay=c.weight_decay) | ||
weight_scheduler1 = torch.optim.lr_scheduler.StepLR(optim1, c.weight_step, gamma=c.gamma) | ||
weight_scheduler2 = torch.optim.lr_scheduler.StepLR(optim2, c.weight_step, gamma=c.gamma) | ||
weight_scheduler3 = torch.optim.lr_scheduler.StepLR(optim3, c.weight_step, gamma=c.gamma) | ||
dwt = common.DWT() | ||
iwt = common.IWT() | ||
|
||
# if c.tain_next: | ||
# load(c.MODEL_PATH + c.suffix_load + '_1.pt', net1, optim1) | ||
# load(c.MODEL_PATH + c.suffix_load + '_2.pt', net2, optim2) | ||
# load(c.MODEL_PATH + c.suffix_load + '_3.pt', net3, optim3) | ||
|
||
if c.pretrain: | ||
load(c.PRETRAIN_PATH + c.suffix_pretrain + '_1.pt', net1, optim1) | ||
load(c.PRETRAIN_PATH + c.suffix_pretrain + '_2.pt', net2, optim2) | ||
if c.PRETRAIN_PATH_3 is not None: | ||
load(c.PRETRAIN_PATH_3 + c.suffix_pretrain_3 + '_3.pt', net3, optim3) | ||
|
||
|
||
with torch.no_grad(): | ||
net1.eval() | ||
net2.eval() | ||
net3.eval() | ||
for i, x in enumerate(datasets.testloader): | ||
# for x in datasets.testloader: | ||
x = x.to(device) | ||
cover = x[:x.shape[0] // 3] # channels = 3 | ||
secret_1 = x[x.shape[0] // 3: 2 * (x.shape[0] // 3)] | ||
secret_2 = x[2 * (x.shape[0] // 3): 3 * (x.shape[0] // 3)] | ||
|
||
cover_dwt = dwt(cover) # channels = 12 | ||
secret_dwt_1 = dwt(secret_1) | ||
secret_dwt_2 = dwt(secret_2) | ||
|
||
input_dwt_1 = torch.cat((cover_dwt, secret_dwt_1), 1) # channels = 24 | ||
|
||
################# | ||
# forward1: # | ||
################# | ||
output_dwt_1 = net1(input_dwt_1) # channels = 24 | ||
output_steg_dwt_1 = output_dwt_1.narrow(1, 0, 4 * c.channels_in) # channels = 12 | ||
output_z_dwt_1 = output_dwt_1.narrow(1, 4 * c.channels_in, | ||
output_dwt_1.shape[1] - 4 * c.channels_in) # channels = 12 | ||
|
||
# get steg1 | ||
output_steg_1 = iwt(output_steg_dwt_1) # channels = 3 | ||
|
||
################# | ||
# forward2: # | ||
################# | ||
if c.use_imp_map: | ||
imp_map = net3(cover, secret_1, output_steg_1) # channels = 3 | ||
# for name, parameters in net3.named_parameters(): # 打印出每一层的参数的大小 | ||
# print(name, ':', parameters) | ||
# print(imp_map) | ||
# print(imp_map.min()) | ||
# print(imp_map.max()) | ||
else: | ||
imp_map = torch.zeros(cover.shape).cuda() | ||
|
||
imp_map_dwt = dwt(imp_map) # channels = 12 | ||
input_dwt_2 = torch.cat((output_steg_dwt_1, imp_map_dwt), 1) # 24, without secret2 | ||
input_dwt_2 = torch.cat((input_dwt_2, secret_dwt_2), 1) # 36 | ||
|
||
output_dwt_2 = net2(input_dwt_2) # channels = 36 | ||
output_steg_dwt_2 = output_dwt_2.narrow(1, 0, 4 * c.channels_in) # channels = 12 | ||
output_z_dwt_2 = output_dwt_2.narrow(1, 4 * c.channels_in, | ||
output_dwt_2.shape[1] - 4 * c.channels_in) # channels = 24 | ||
|
||
# get steg2 | ||
output_steg_2 = iwt(output_steg_dwt_2) # channels = 3 | ||
|
||
################# | ||
# backward2: # | ||
################# | ||
|
||
output_z_guass_1 = gauss_noise(output_z_dwt_1.shape) # channels = 12 | ||
output_z_guass_2 = gauss_noise(output_z_dwt_2.shape) # channels = 24 | ||
|
||
output_rev_dwt_2 = torch.cat((output_steg_dwt_2, output_z_guass_2), 1) # channels = 36 | ||
|
||
rev_dwt_2 = net2(output_rev_dwt_2, rev=True) # channels = 36 | ||
|
||
rev_steg_dwt_1 = rev_dwt_2.narrow(1, 0, 4 * c.channels_in) # channels = 12 | ||
rev_secret_dwt_2 = rev_dwt_2.narrow(1, rev_dwt_2.shape[1] - 4 * c.channels_in, 4 * c.channels_in) # channels = 12 | ||
|
||
rev_steg_1 = iwt(rev_steg_dwt_1) # channels = 3 | ||
rev_secret_2 = iwt(rev_secret_dwt_2) # channels = 3 | ||
|
||
################# | ||
# backward1: # | ||
################# | ||
output_rev_dwt_1 = torch.cat((rev_steg_dwt_1, output_z_guass_1), 1) # channels = 24 | ||
|
||
rev_dwt_1 = net1(output_rev_dwt_1, rev=True) # channels = 36 | ||
|
||
rev_secret_dwt = rev_dwt_1.narrow(1, 4 * c.channels_in, 4 * c.channels_in) # channels = 12 | ||
rev_secret_1 = iwt(rev_secret_dwt) | ||
|
||
# imp_map = imp_map.cpu().numpy().squeeze() * 255 | ||
# cover = cover.cpu().numpy().squeeze() * 255 | ||
# output_steg_1 = output_steg_1.cpu().numpy().squeeze() * 255 | ||
# print(imp_map) | ||
# imp_map = imp_map * 1 | ||
# resi_cover_1 = cover - output_steg_1 | ||
# resi_cover_2 = cover - output_steg_2 | ||
# resi_secret_1 = secret_1 - rev_secret_1 | ||
# resi_secret_2 = secret_2 - rev_secret_2 | ||
# resi_cover_1 = resi_cover_1 * 7 | ||
# resi_cover_2 = resi_cover_2 * 7 | ||
# resi_secret_1 = resi_secret_1 * 7 | ||
# resi_secret_2 = resi_secret_2 * 7 | ||
|
||
# torchvision.utils.save_image(imp_map, c.TEST_PATH_imp_map + '%.5d.png' % i) | ||
# torchvision.utils.save_image(resi_cover_1, c.TEST_PATH_resi_cover_1 + '%.5d.png' % i) | ||
# torchvision.utils.save_image(resi_cover_2, c.TEST_PATH_resi_cover_2 + '%.5d.png' % i) | ||
# torchvision.utils.save_image(resi_secret_1, c.TEST_PATH_resi_secret_1 + '%.5d.png' % i) | ||
# torchvision.utils.save_image(resi_secret_2, c.TEST_PATH_resi_secret_2 + '%.5d.png' % i) | ||
|
||
# torchvision.utils.save_image(cover, c.TEST_PATH_cover + '%.5d.png' % i) | ||
# torchvision.utils.save_image(secret_1, c.TEST_PATH_secret_1 + '%.5d.png' % i) | ||
# torchvision.utils.save_image(secret_2, c.TEST_PATH_secret_2 + '%.5d.png' % i) | ||
|
||
# torchvision.utils.save_image(output_steg_1, c.TEST_PATH_steg_1 + '%.5d.png' % i) | ||
# torchvision.utils.save_image(rev_secret_1, c.TEST_PATH_secret_rev_1 + '%.5d.png' % i) | ||
|
||
# torchvision.utils.save_image(output_steg_2, c.TEST_PATH_steg_2 + '%.5d.png' % i) | ||
# torchvision.utils.save_image(rev_secret_2, c.TEST_PATH_secret_rev_2 + '%.5d.png' % i) | ||
|
||
torchvision.utils.save_image(rev_secret_dwt.narrow(1, 0, c.channels_in), '/home/jjp/cascaded_Hinet/test-image-attention-div2k-DWT/LL/secret-rev_1/' + '%.5d.png' % i) | ||
torchvision.utils.save_image(rev_secret_dwt.narrow(1, c.channels_in, c.channels_in), | ||
'/home/jjp/cascaded_Hinet/test-image-attention-div2k-DWT/HL/secret-rev_1/' + '%.5d.png' % i) | ||
torchvision.utils.save_image(rev_secret_dwt.narrow(1, 2 * c.channels_in, c.channels_in), | ||
'/home/jjp/cascaded_Hinet/test-image-attention-div2k-DWT/LH/secret-rev_1/' + '%.5d.png' % i) | ||
torchvision.utils.save_image(rev_secret_dwt.narrow(1, 3 * c.channels_in, c.channels_in), | ||
'/home/jjp/cascaded_Hinet/test-image-attention-div2k-DWT/HH/secret-rev_1/' + '%.5d.png' % i) | ||
|
||
torchvision.utils.save_image(output_steg_dwt_2.narrow(1, 0, c.channels_in), | ||
'/home/jjp/cascaded_Hinet/test-image-attention-div2k-DWT/LL/steg_2/' + '%.5d.png' % i) | ||
torchvision.utils.save_image(output_steg_dwt_2.narrow(1, c.channels_in, c.channels_in), | ||
'/home/jjp/cascaded_Hinet/test-image-attention-div2k-DWT/HL/steg_2/' + '%.5d.png' % i) | ||
torchvision.utils.save_image(output_steg_dwt_2.narrow(1, 2 * c.channels_in, c.channels_in), | ||
'/home/jjp/cascaded_Hinet/test-image-attention-div2k-DWT/LH/steg_2/' + '%.5d.png' % i) | ||
torchvision.utils.save_image(output_steg_dwt_2.narrow(1, 3 * c.channels_in, c.channels_in), | ||
'/home/jjp/cascaded_Hinet/test-image-attention-div2k-DWT/HH/steg_2/' + '%.5d.png' % i) | ||
|
||
torchvision.utils.save_image(output_steg_dwt_1.narrow(1, 0, c.channels_in), | ||
'/home/jjp/cascaded_Hinet/test-image-attention-div2k-DWT/LL/steg_1/' + '%.5d.png' % i) | ||
torchvision.utils.save_image(output_steg_dwt_1.narrow(1, c.channels_in, c.channels_in), | ||
'/home/jjp/cascaded_Hinet/test-image-attention-div2k-DWT/HL/steg_1/' + '%.5d.png' % i) | ||
torchvision.utils.save_image(output_steg_dwt_1.narrow(1, 2 * c.channels_in, c.channels_in), | ||
'/home/jjp/cascaded_Hinet/test-image-attention-div2k-DWT/LH/steg_1/' + '%.5d.png' % i) | ||
torchvision.utils.save_image(output_steg_dwt_1.narrow(1, 3 * c.channels_in, c.channels_in), | ||
'/home/jjp/cascaded_Hinet/test-image-attention-div2k-DWT/HH/steg_1/' + '%.5d.png' % i) | ||
|
||
torchvision.utils.save_image(rev_secret_dwt_2.narrow(1, 0, c.channels_in), | ||
'/home/jjp/cascaded_Hinet/test-image-attention-div2k-DWT/LL/secret-rev_2/' + '%.5d.png' % i) | ||
torchvision.utils.save_image(rev_secret_dwt_2.narrow(1, c.channels_in, c.channels_in), | ||
'/home/jjp/cascaded_Hinet/test-image-attention-div2k-DWT/HL/secret-rev_2/' + '%.5d.png' % i) | ||
torchvision.utils.save_image(rev_secret_dwt_2.narrow(1, 2 * c.channels_in, c.channels_in), | ||
'/home/jjp/cascaded_Hinet/test-image-attention-div2k-DWT/LH/secret-rev_2/' + '%.5d.png' % i) | ||
torchvision.utils.save_image(rev_secret_dwt_2.narrow(1, 3 * c.channels_in, c.channels_in), | ||
'/home/jjp/cascaded_Hinet/test-image-attention-div2k-DWT/HH/secret-rev_2/' + '%.5d.png' % i) | ||
|
||
torchvision.utils.save_image(cover_dwt.narrow(1, 0, c.channels_in), | ||
'/home/jjp/cascaded_Hinet/test-image-attention-div2k-DWT/LL/cover/' + '%.5d.png' % i) | ||
torchvision.utils.save_image(cover_dwt.narrow(1, c.channels_in, c.channels_in), | ||
'/home/jjp/cascaded_Hinet/test-image-attention-div2k-DWT/HL/cover/' + '%.5d.png' % i) | ||
torchvision.utils.save_image(cover_dwt.narrow(1, 2 * c.channels_in, c.channels_in), | ||
'/home/jjp/cascaded_Hinet/test-image-attention-div2k-DWT/LH/cover/' + '%.5d.png' % i) | ||
torchvision.utils.save_image(cover_dwt.narrow(1, 3 * c.channels_in, c.channels_in), | ||
'/home/jjp/cascaded_Hinet/test-image-attention-div2k-DWT/HH/cover/' + '%.5d.png' % i) | ||
|
||
torchvision.utils.save_image(secret_dwt_1.narrow(1, 0, c.channels_in), | ||
'/home/jjp/cascaded_Hinet/test-image-attention-div2k-DWT/LL/secret_1/' + '%.5d.png' % i) | ||
torchvision.utils.save_image(secret_dwt_1.narrow(1, c.channels_in, c.channels_in), | ||
'/home/jjp/cascaded_Hinet/test-image-attention-div2k-DWT/HL/secret_1/' + '%.5d.png' % i) | ||
torchvision.utils.save_image(secret_dwt_1.narrow(1, 2 * c.channels_in, c.channels_in), | ||
'/home/jjp/cascaded_Hinet/test-image-attention-div2k-DWT/LH/secret_1/' + '%.5d.png' % i) | ||
torchvision.utils.save_image(secret_dwt_1.narrow(1, 3 * c.channels_in, c.channels_in), | ||
'/home/jjp/cascaded_Hinet/test-image-attention-div2k-DWT/HH/secret_1/' + '%.5d.png' % i) | ||
|
||
torchvision.utils.save_image(secret_dwt_2.narrow(1, 0, c.channels_in), | ||
'/home/jjp/cascaded_Hinet/test-image-attention-div2k-DWT/LL/secret_2/' + '%.5d.png' % i) | ||
torchvision.utils.save_image(secret_dwt_2.narrow(1, c.channels_in, c.channels_in), | ||
'/home/jjp/cascaded_Hinet/test-image-attention-div2k-DWT/HL/secret_2/' + '%.5d.png' % i) | ||
torchvision.utils.save_image(secret_dwt_2.narrow(1, 2 * c.channels_in, c.channels_in), | ||
'/home/jjp/cascaded_Hinet/test-image-attention-div2k-DWT/LH/secret_2/' + '%.5d.png' % i) | ||
torchvision.utils.save_image(secret_dwt_2.narrow(1, 3 * c.channels_in, c.channels_in), | ||
'/home/jjp/cascaded_Hinet/test-image-attention-div2k-DWT/HH/secret_2/' + '%.5d.png' % i) |
Oops, something went wrong.