-
Notifications
You must be signed in to change notification settings - Fork 3
/
FRIST_imagedenoising.m
138 lines (136 loc) · 5.86 KB
/
FRIST_imagedenoising.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
function [Xr, Dlist, outputParam] = FRIST_imagedenoising(data, param)
%FRIST_IMAGEDENOISING Denoise a gray-scale image with the FRIST-based
%denoising algorithm.
%
%Note that all input parameters need to be set prior to simulation. Example
%settings are filled in by the function FRIST_imagedenoise_param, which is
%called on param at the start of this function. However, the user is
%advised to carefully choose optimal values for the parameters depending on
%the specific data or task at hand.
%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Inputs -
%   1. data : Image data. The fields are as follows -
%       - noisy:  a*b size gray-scale matrix for denoising
%       - oracle (optional): a*b size gray-scale matrix used as ground
%                 truth to calculate the output PSNR
%
%   2. param: Structure that contains the parameters of the
%       FRIST_imagedenoising algorithm. The various fields are as follows -
%       - isFlipping:  whether flipping is included
%       - sig:         standard deviation of the additive Gaussian
%                      noise (Example: 20)
%       - n:           patch size (Example: 64); this function itself takes
%                      n from the number of columns of param.transform
%       - stride:      stride of overlapping patches
%       - isUnitary:   set to 1, if unitary TL is used
%       - isMultipass: set to 1, if multipass is applied
%     After FRIST_imagedenoise_param(param) the structure must also carry
%     the fields read below: mergeK, C1, sig2, la, T0, transform, iter,
%     iterMultipass, l0 and roundLearning.
%
% Outputs -
%   1. Xr    - Image reconstructed with the FRIST_imagedenoising algorithm,
%              clipped to the range [0, 255].
%   2. Dlist - Learned FRIST; Dlist(:, :, p) is the transform at the end of
%              pass p.
%   3. outputParam: Structure that contains the parameters of the
%       algorithm output for analysis as follows -
%       - time:    run time of the denoising algorithm (seconds)
%       - psnrOut: PSNR of Xr if data.oracle is provided, NaN otherwise
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

    %%%%%%%%%%%%%%%%%%%%%%%%%%% Initialization %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    param = FRIST_imagedenoise_param(param);
    mergeK          = param.mergeK;         % number of clusters kept after merging
    isUnitary       = param.isUnitary;
    C1              = param.C1;             % thresholding coefficient
    sig2            = param.sig2;           % multi-pass noise level estimates
    la              = param.la;             % fidelity term coefficient
    T0              = param.T0;             % initial sparsity level
    Xr              = data.noisy;           % noisy image
    transform       = param.transform;      % initial transform
    iter            = param.iter;           % number of iterations in first pass denoising
    iterMultipass   = param.iterMultipass;  % number of iterations in multipass denoising
    l0              = param.l0;             % regularizer coefficient
    roundLearning   = param.roundLearning;  % number of learning rounds per iteration
    stride          = param.stride;         % stride of overlapping patches
    isFlipping      = param.isFlipping;

    stp = 1;                                % sparsity increase stepsize
    SP  = 1:stp:round(9*T0);                % maximum sparsity level allowed in algorithm is 9*T0 here
    len = length(sig2);                     % one denoising pass per noise level estimate
    [nK, n] = size(transform);
    patchSide = sqrt(n);                    % patches are patchSide-by-patchSide
    Dlist = zeros(nK, n, len);
    permutation = generatePermutation(patchSide, isFlipping);
    L      = permutation.L;
    K      = permutation.K;
    revInd = permutation.revInd;

    % Cluster labels run over 1..K, plus K+1..2K when flipping is enabled.
    if isFlipping
        numClusters = 2 * K;
    else
        numClusters = K;
    end

    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%% MAIN CODE %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    % Use a timer handle so that tic calls inside helper functions cannot
    % reset the measurement.
    tStart = tic;
    for pr = 1 : len
        % Passes after the first use the multipass iteration count. This is
        % decided before numCl is allocated so the preallocation matches
        % the number of iterations actually run.
        if pr > 1
            iter = iterMultipass;
        end
        numCl       = zeros(numClusters, iter); % per-iteration cluster sizes
        selectIndex = 1 : numClusters;          % clusters currently in use
        reduceK     = numClusters;              % current number of clusters in use

        sig = sig2(pr);
        threshold = C1 * sig * patchSide;       % threshold in variable sparsity update

        [TE, idx] = my_im2col(Xr, [patchSide, patchSide], stride);
        NTE = size(TE, 2);
        mu  = mean(TE, 1);                      % column-wise mean, also when TE has a single row
        TE  = TE - ones(n, 1) * mu;             % mean subtraction
        STY = ones(1, NTE) .* T0;               % initial sparsity vector
        D   = transform;
        if ~isUnitary
            l2 = l0 * norm(TE, 'fro')^2;
        end

        for j = 1 : iter
            [YH, IDX] = rotateFRIST_merge(TE, D, permutation, STY, isFlipping, selectIndex);

            % roundLearning is doubled during the first third of the iterations
            if j < iter/3
                currentRoundLearning = roundLearning * 2;
            else
                currentRoundLearning = roundLearning;
            end

            if isUnitary
                % unitary transform learning
                D = forwardORTHOTRANSb(D, YH, currentRoundLearning, STY);
            else
                % well-conditioned transform learning
                [D, ~] = transformLearning(YH, D, l2, l2, STY, currentRoundLearning);
            end

            [STY, reconstruction] = sparsityUpdate(YH, D, la, threshold, SP); % sparsity update for YH

            %%%%%%%%% select the largest K cluster %%%%%%%%%%%%%
            if reduceK > mergeK
                % IDX is fixed for this iteration, so the cluster sizes and
                % their ranking are computed once rather than on every
                % halving step.
                for k = 1 : numClusters
                    numCl(k, j) = nnz(IDX == k);
                end
                [~, sortweight] = sort(numCl(:, j), 'descend');
                % Halve the number of retained clusters until mergeK is
                % reached, keeping the most populated ones.
                % NOTE(review): because this loop runs to completion inside
                % a single iteration, the reduction to mergeK happens all at
                % once; confirm a gradual per-iteration reduction was not
                % intended.
                while reduceK > mergeK
                    reduceK = max(round(reduceK / 2), mergeK);
                    selectIndex = sortweight(1 : reduceK);
                end
            end
        end

        % Patches Recovery
        for k = 1 : K
            % rev is the ordering that sorts L(:, k), i.e. the inverse of
            % L(:, k) when it is a permutation of the row indices.
            [~, rev] = sort(L(:, k));
            inCluster = (IDX == k);
            reconstruction(:, inCluster) = reconstruction(rev, inCluster);
            if isFlipping
                % Patches labelled k + K get the same row reordering and
                % then the additional reordering given by revInd.
                inFlipped = (IDX == k + K);
                reconstruction(:, inFlipped) = reconstruction(rev, inFlipped);
                reconstruction(:, inFlipped) = reconstruction(revInd, inFlipped);
            end
        end
        reconstruction = reconstruction + ones(n, 1) * mu;  % add the patch means back

        Xr = my_col2image(Xr, n, reconstruction, idx, 0);
        % Clip to the 8-bit intensity range
        Xr(Xr > 255) = 255;
        Xr(Xr < 0)   = 0;
        Dlist(:, :, pr) = D;
        % display(pr);
    end

    outputParam.time = toc(tStart);
    % The oracle image is optional; only compute the PSNR when it is given.
    if isfield(data, 'oracle') && ~isempty(data.oracle)
        outputParam.psnrOut = PSNR(Xr - data.oracle);
    else
        outputParam.psnrOut = NaN;
    end
end