Skip to content

Commit

Permalink
Add files via upload
Browse files Browse the repository at this point in the history
  • Loading branch information
TomTomTommi authored Jan 7, 2022
1 parent ebb3dab commit cec74b5
Show file tree
Hide file tree
Showing 2 changed files with 701 additions and 0 deletions.
257 changes: 257 additions & 0 deletions test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,257 @@
import warnings
import sys
import math
import os
import torch
import torch.nn
import torch.optim
import torchvision
import numpy as np
import tqdm
# import cv2
from model import *
from imp_subnet import *
import config as c
from os.path import join
import datasets
import modules.module_util as mutil
import modules.Unet_common as common

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def load(name, net, optim):
state_dicts = torch.load(name)
network_state_dict = {k: v for k, v in state_dicts['net'].items() if 'tmp_var' not in k}
net.load_state_dict(network_state_dict)
try:
optim.load_state_dict(state_dicts['opt'])
except:
print('Cannot load optimizer for some reason or other')


def gauss_noise(shape):
noise = torch.zeros(shape).cuda()
for i in range(noise.shape[0]):
noise[i] = torch.randn(noise[i].shape).cuda()

return noise


def computePSNR(origin, pred):
origin = np.array(origin)
origin = origin.astype(np.float32)
pred = np.array(pred)
pred = pred.astype(np.float32)
mse = np.mean((origin / 1.0 - pred / 1.0) ** 2)
if mse < 1.0e-10:
return 100
return 10 * math.log10(255.0 ** 2 / mse)


net1 = Model_1()
net2 = Model_2()
net3 = ImpMapBlock()
net1.cuda()
net2.cuda()
net3.cuda()
init_model(net1)
init_model(net2)
net1 = torch.nn.DataParallel(net1, device_ids=c.device_ids)
net2 = torch.nn.DataParallel(net2, device_ids=c.device_ids)
net3 = torch.nn.DataParallel(net3, device_ids=c.device_ids)
params_trainable1 = (list(filter(lambda p: p.requires_grad, net1.parameters())))
params_trainable2 = (list(filter(lambda p: p.requires_grad, net2.parameters())))
params_trainable3 = (list(filter(lambda p: p.requires_grad, net3.parameters())))
optim1 = torch.optim.Adam(params_trainable1, lr=c.lr, betas=c.betas, eps=1e-6, weight_decay=c.weight_decay)
optim2 = torch.optim.Adam(params_trainable2, lr=c.lr, betas=c.betas, eps=1e-6, weight_decay=c.weight_decay)
optim3 = torch.optim.Adam(params_trainable3, lr=c.lr3, betas=c.betas, eps=1e-6, weight_decay=c.weight_decay)
weight_scheduler1 = torch.optim.lr_scheduler.StepLR(optim1, c.weight_step, gamma=c.gamma)
weight_scheduler2 = torch.optim.lr_scheduler.StepLR(optim2, c.weight_step, gamma=c.gamma)
weight_scheduler3 = torch.optim.lr_scheduler.StepLR(optim3, c.weight_step, gamma=c.gamma)
dwt = common.DWT()
iwt = common.IWT()

# if c.tain_next:
# load(c.MODEL_PATH + c.suffix_load + '_1.pt', net1, optim1)
# load(c.MODEL_PATH + c.suffix_load + '_2.pt', net2, optim2)
# load(c.MODEL_PATH + c.suffix_load + '_3.pt', net3, optim3)

if c.pretrain:
load(c.PRETRAIN_PATH + c.suffix_pretrain + '_1.pt', net1, optim1)
load(c.PRETRAIN_PATH + c.suffix_pretrain + '_2.pt', net2, optim2)
if c.PRETRAIN_PATH_3 is not None:
load(c.PRETRAIN_PATH_3 + c.suffix_pretrain_3 + '_3.pt', net3, optim3)


with torch.no_grad():
net1.eval()
net2.eval()
net3.eval()
for i, x in enumerate(datasets.testloader):
# for x in datasets.testloader:
x = x.to(device)
cover = x[:x.shape[0] // 3] # channels = 3
secret_1 = x[x.shape[0] // 3: 2 * (x.shape[0] // 3)]
secret_2 = x[2 * (x.shape[0] // 3): 3 * (x.shape[0] // 3)]

cover_dwt = dwt(cover) # channels = 12
secret_dwt_1 = dwt(secret_1)
secret_dwt_2 = dwt(secret_2)

input_dwt_1 = torch.cat((cover_dwt, secret_dwt_1), 1) # channels = 24

#################
# forward1: #
#################
output_dwt_1 = net1(input_dwt_1) # channels = 24
output_steg_dwt_1 = output_dwt_1.narrow(1, 0, 4 * c.channels_in) # channels = 12
output_z_dwt_1 = output_dwt_1.narrow(1, 4 * c.channels_in,
output_dwt_1.shape[1] - 4 * c.channels_in) # channels = 12

# get steg1
output_steg_1 = iwt(output_steg_dwt_1) # channels = 3

#################
# forward2: #
#################
if c.use_imp_map:
imp_map = net3(cover, secret_1, output_steg_1) # channels = 3
# for name, parameters in net3.named_parameters(): # 打印出每一层的参数的大小
# print(name, ':', parameters)
# print(imp_map)
# print(imp_map.min())
# print(imp_map.max())
else:
imp_map = torch.zeros(cover.shape).cuda()

imp_map_dwt = dwt(imp_map) # channels = 12
input_dwt_2 = torch.cat((output_steg_dwt_1, imp_map_dwt), 1) # 24, without secret2
input_dwt_2 = torch.cat((input_dwt_2, secret_dwt_2), 1) # 36

output_dwt_2 = net2(input_dwt_2) # channels = 36
output_steg_dwt_2 = output_dwt_2.narrow(1, 0, 4 * c.channels_in) # channels = 12
output_z_dwt_2 = output_dwt_2.narrow(1, 4 * c.channels_in,
output_dwt_2.shape[1] - 4 * c.channels_in) # channels = 24

# get steg2
output_steg_2 = iwt(output_steg_dwt_2) # channels = 3

#################
# backward2: #
#################

output_z_guass_1 = gauss_noise(output_z_dwt_1.shape) # channels = 12
output_z_guass_2 = gauss_noise(output_z_dwt_2.shape) # channels = 24

output_rev_dwt_2 = torch.cat((output_steg_dwt_2, output_z_guass_2), 1) # channels = 36

rev_dwt_2 = net2(output_rev_dwt_2, rev=True) # channels = 36

rev_steg_dwt_1 = rev_dwt_2.narrow(1, 0, 4 * c.channels_in) # channels = 12
rev_secret_dwt_2 = rev_dwt_2.narrow(1, rev_dwt_2.shape[1] - 4 * c.channels_in, 4 * c.channels_in) # channels = 12

rev_steg_1 = iwt(rev_steg_dwt_1) # channels = 3
rev_secret_2 = iwt(rev_secret_dwt_2) # channels = 3

#################
# backward1: #
#################
output_rev_dwt_1 = torch.cat((rev_steg_dwt_1, output_z_guass_1), 1) # channels = 24

rev_dwt_1 = net1(output_rev_dwt_1, rev=True) # channels = 36

rev_secret_dwt = rev_dwt_1.narrow(1, 4 * c.channels_in, 4 * c.channels_in) # channels = 12
rev_secret_1 = iwt(rev_secret_dwt)

# imp_map = imp_map.cpu().numpy().squeeze() * 255
# cover = cover.cpu().numpy().squeeze() * 255
# output_steg_1 = output_steg_1.cpu().numpy().squeeze() * 255
# print(imp_map)
# imp_map = imp_map * 1
# resi_cover_1 = cover - output_steg_1
# resi_cover_2 = cover - output_steg_2
# resi_secret_1 = secret_1 - rev_secret_1
# resi_secret_2 = secret_2 - rev_secret_2
# resi_cover_1 = resi_cover_1 * 7
# resi_cover_2 = resi_cover_2 * 7
# resi_secret_1 = resi_secret_1 * 7
# resi_secret_2 = resi_secret_2 * 7

# torchvision.utils.save_image(imp_map, c.TEST_PATH_imp_map + '%.5d.png' % i)
# torchvision.utils.save_image(resi_cover_1, c.TEST_PATH_resi_cover_1 + '%.5d.png' % i)
# torchvision.utils.save_image(resi_cover_2, c.TEST_PATH_resi_cover_2 + '%.5d.png' % i)
# torchvision.utils.save_image(resi_secret_1, c.TEST_PATH_resi_secret_1 + '%.5d.png' % i)
# torchvision.utils.save_image(resi_secret_2, c.TEST_PATH_resi_secret_2 + '%.5d.png' % i)

# torchvision.utils.save_image(cover, c.TEST_PATH_cover + '%.5d.png' % i)
# torchvision.utils.save_image(secret_1, c.TEST_PATH_secret_1 + '%.5d.png' % i)
# torchvision.utils.save_image(secret_2, c.TEST_PATH_secret_2 + '%.5d.png' % i)

# torchvision.utils.save_image(output_steg_1, c.TEST_PATH_steg_1 + '%.5d.png' % i)
# torchvision.utils.save_image(rev_secret_1, c.TEST_PATH_secret_rev_1 + '%.5d.png' % i)

# torchvision.utils.save_image(output_steg_2, c.TEST_PATH_steg_2 + '%.5d.png' % i)
# torchvision.utils.save_image(rev_secret_2, c.TEST_PATH_secret_rev_2 + '%.5d.png' % i)

torchvision.utils.save_image(rev_secret_dwt.narrow(1, 0, c.channels_in), '/home/jjp/cascaded_Hinet/test-image-attention-div2k-DWT/LL/secret-rev_1/' + '%.5d.png' % i)
torchvision.utils.save_image(rev_secret_dwt.narrow(1, c.channels_in, c.channels_in),
'/home/jjp/cascaded_Hinet/test-image-attention-div2k-DWT/HL/secret-rev_1/' + '%.5d.png' % i)
torchvision.utils.save_image(rev_secret_dwt.narrow(1, 2 * c.channels_in, c.channels_in),
'/home/jjp/cascaded_Hinet/test-image-attention-div2k-DWT/LH/secret-rev_1/' + '%.5d.png' % i)
torchvision.utils.save_image(rev_secret_dwt.narrow(1, 3 * c.channels_in, c.channels_in),
'/home/jjp/cascaded_Hinet/test-image-attention-div2k-DWT/HH/secret-rev_1/' + '%.5d.png' % i)

torchvision.utils.save_image(output_steg_dwt_2.narrow(1, 0, c.channels_in),
'/home/jjp/cascaded_Hinet/test-image-attention-div2k-DWT/LL/steg_2/' + '%.5d.png' % i)
torchvision.utils.save_image(output_steg_dwt_2.narrow(1, c.channels_in, c.channels_in),
'/home/jjp/cascaded_Hinet/test-image-attention-div2k-DWT/HL/steg_2/' + '%.5d.png' % i)
torchvision.utils.save_image(output_steg_dwt_2.narrow(1, 2 * c.channels_in, c.channels_in),
'/home/jjp/cascaded_Hinet/test-image-attention-div2k-DWT/LH/steg_2/' + '%.5d.png' % i)
torchvision.utils.save_image(output_steg_dwt_2.narrow(1, 3 * c.channels_in, c.channels_in),
'/home/jjp/cascaded_Hinet/test-image-attention-div2k-DWT/HH/steg_2/' + '%.5d.png' % i)

torchvision.utils.save_image(output_steg_dwt_1.narrow(1, 0, c.channels_in),
'/home/jjp/cascaded_Hinet/test-image-attention-div2k-DWT/LL/steg_1/' + '%.5d.png' % i)
torchvision.utils.save_image(output_steg_dwt_1.narrow(1, c.channels_in, c.channels_in),
'/home/jjp/cascaded_Hinet/test-image-attention-div2k-DWT/HL/steg_1/' + '%.5d.png' % i)
torchvision.utils.save_image(output_steg_dwt_1.narrow(1, 2 * c.channels_in, c.channels_in),
'/home/jjp/cascaded_Hinet/test-image-attention-div2k-DWT/LH/steg_1/' + '%.5d.png' % i)
torchvision.utils.save_image(output_steg_dwt_1.narrow(1, 3 * c.channels_in, c.channels_in),
'/home/jjp/cascaded_Hinet/test-image-attention-div2k-DWT/HH/steg_1/' + '%.5d.png' % i)

torchvision.utils.save_image(rev_secret_dwt_2.narrow(1, 0, c.channels_in),
'/home/jjp/cascaded_Hinet/test-image-attention-div2k-DWT/LL/secret-rev_2/' + '%.5d.png' % i)
torchvision.utils.save_image(rev_secret_dwt_2.narrow(1, c.channels_in, c.channels_in),
'/home/jjp/cascaded_Hinet/test-image-attention-div2k-DWT/HL/secret-rev_2/' + '%.5d.png' % i)
torchvision.utils.save_image(rev_secret_dwt_2.narrow(1, 2 * c.channels_in, c.channels_in),
'/home/jjp/cascaded_Hinet/test-image-attention-div2k-DWT/LH/secret-rev_2/' + '%.5d.png' % i)
torchvision.utils.save_image(rev_secret_dwt_2.narrow(1, 3 * c.channels_in, c.channels_in),
'/home/jjp/cascaded_Hinet/test-image-attention-div2k-DWT/HH/secret-rev_2/' + '%.5d.png' % i)

torchvision.utils.save_image(cover_dwt.narrow(1, 0, c.channels_in),
'/home/jjp/cascaded_Hinet/test-image-attention-div2k-DWT/LL/cover/' + '%.5d.png' % i)
torchvision.utils.save_image(cover_dwt.narrow(1, c.channels_in, c.channels_in),
'/home/jjp/cascaded_Hinet/test-image-attention-div2k-DWT/HL/cover/' + '%.5d.png' % i)
torchvision.utils.save_image(cover_dwt.narrow(1, 2 * c.channels_in, c.channels_in),
'/home/jjp/cascaded_Hinet/test-image-attention-div2k-DWT/LH/cover/' + '%.5d.png' % i)
torchvision.utils.save_image(cover_dwt.narrow(1, 3 * c.channels_in, c.channels_in),
'/home/jjp/cascaded_Hinet/test-image-attention-div2k-DWT/HH/cover/' + '%.5d.png' % i)

torchvision.utils.save_image(secret_dwt_1.narrow(1, 0, c.channels_in),
'/home/jjp/cascaded_Hinet/test-image-attention-div2k-DWT/LL/secret_1/' + '%.5d.png' % i)
torchvision.utils.save_image(secret_dwt_1.narrow(1, c.channels_in, c.channels_in),
'/home/jjp/cascaded_Hinet/test-image-attention-div2k-DWT/HL/secret_1/' + '%.5d.png' % i)
torchvision.utils.save_image(secret_dwt_1.narrow(1, 2 * c.channels_in, c.channels_in),
'/home/jjp/cascaded_Hinet/test-image-attention-div2k-DWT/LH/secret_1/' + '%.5d.png' % i)
torchvision.utils.save_image(secret_dwt_1.narrow(1, 3 * c.channels_in, c.channels_in),
'/home/jjp/cascaded_Hinet/test-image-attention-div2k-DWT/HH/secret_1/' + '%.5d.png' % i)

torchvision.utils.save_image(secret_dwt_2.narrow(1, 0, c.channels_in),
'/home/jjp/cascaded_Hinet/test-image-attention-div2k-DWT/LL/secret_2/' + '%.5d.png' % i)
torchvision.utils.save_image(secret_dwt_2.narrow(1, c.channels_in, c.channels_in),
'/home/jjp/cascaded_Hinet/test-image-attention-div2k-DWT/HL/secret_2/' + '%.5d.png' % i)
torchvision.utils.save_image(secret_dwt_2.narrow(1, 2 * c.channels_in, c.channels_in),
'/home/jjp/cascaded_Hinet/test-image-attention-div2k-DWT/LH/secret_2/' + '%.5d.png' % i)
torchvision.utils.save_image(secret_dwt_2.narrow(1, 3 * c.channels_in, c.channels_in),
'/home/jjp/cascaded_Hinet/test-image-attention-div2k-DWT/HH/secret_2/' + '%.5d.png' % i)
Loading

0 comments on commit cec74b5

Please sign in to comment.