-
Notifications
You must be signed in to change notification settings - Fork 14
/
algorithm_inv_prob.py
86 lines (53 loc) · 1.97 KB
/
algorithm_inv_prob.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import numpy as np
import torch
import time
import os
### Takes the measurements x_c = M^T(x) and returns a reconstructed tensor
### of size (n_ch, im_d1, im_d2), i.e. the size of M(x_c)
def univ_inv_sol(model, x_c, task, sig_0=1, sig_L=.01, h0=.01, beta=.01, freq=5, max_iter=None):
    '''
    Reconstruct an image from linear measurements by repeatedly stepping along
    a denoiser's residual, combined with a measurement-consistency term, until
    the effective noise level sigma falls to sig_L.

    @model: callable taking y of shape (1, n_ch, im_d1, im_d2) and returning a
            tensor of the same shape that is used as the residual
            y - denoised(y) (the final output is y - model(y)).
    @x_c: the measurements, M^T(x).
    @task: the specific linear inverse problem; must provide callables
           `task.M_T` (measurement operator) and `task.M` (maps measurements
           back to image space; M(x_c) has size (n_ch, im_d1, im_d2)).
    @sig_0: initial sigma (largest); std of the Gaussian noise in the initial y.
    @sig_L: final sigma (smallest); iterations stop once sigma <= sig_L.
    @h0: 1st step size; the step size is h = h0*t / (1 + h0*(t-1)), which
         equals h0 at t=1 and tends to 1 as t grows.
    @beta: controls added noise in each iteration, in (0,1]. If 1, no noise
           is added. As it decreases, more noise is added.
    @freq: every `freq` iterations print progress and store the iterate;
           when freq <= 0 nothing is stored (not even the initial y).
    @max_iter: optional upper bound on the number of iterations. None (the
               default) iterates until sigma <= sig_L, which never ends if
               sigma stalls above sig_L.

    Returns:
        (denoised_y, intermed_Ys): denoised_y has size (n_ch, im_d1, im_d2)
        and carries no autograd graph; intermed_Ys is the list of stored
        iterates (the initial y first, stored before y is moved to the GPU).
    '''
    M_T = task.M_T  # low rank measurement matrix - in function form
    M = task.M  # maps measurements back to image space (original note: inverse of M_T)

    # M(x_c) does not change across iterations: compute it once.
    M_x_c = M(x_c)
    n_ch, im_d1, im_d2 = M_x_c.size()
    N = n_ch * im_d1 * im_d2
    intermed_Ys = []

    # Initialize y: mean is M(x_c) plus 0.5 times the part of an all-ones
    # image that M(M_T(.)) does not reproduce; std is sig_0.
    e = torch.ones_like(M_x_c, requires_grad=False)
    y = torch.normal((e - M(M_T(e))) * .5 + M_x_c, sig_0)
    y = y.unsqueeze(0)  # add a leading batch dimension for the model
    y.requires_grad = False
    if freq > 0:
        intermed_Ys.append(y.squeeze(0))
    if torch.cuda.is_available():
        y = y.cuda()
    # Keep the measurement term on y's device; otherwise a CPU x_c would be
    # mixed with a GPU y inside the loop.
    M_x_c = M_x_c.to(y.device)

    # No gradients are needed: outside no_grad the initial and final model
    # calls would build autograd graphs and the result would require grad.
    with torch.no_grad():
        sigma = torch.norm(model(y)) / np.sqrt(N)

        t = 1
        start_time_total = time.time()
        while sigma > sig_L and (max_iter is None or t <= max_iter):
            h = h0 * t / (1 + (h0 * (t - 1)))
            f_y = model(y)
            # Denoiser residual minus its M(M_T(.)) component, plus the gap
            # between M(M_T(y)) and M(x_c) (presumably: a prior step on the
            # unmeasured part and a data-consistency step on the measured one).
            d = f_y - M(M_T(f_y[0])) + (M(M_T(y[0])) - M_x_c)
            sigma = torch.norm(d) / np.sqrt(N)
            # For h in [0, 1]: beta == 1 gives gamma == 0 (no injected noise);
            # smaller beta gives larger gamma.
            gamma = sigma * np.sqrt(((1 - (beta * h)) ** 2 - (1 - h) ** 2))
            # Sample on the CPU (same RNG stream as before), then move to y's device.
            noise = torch.randn(n_ch, im_d1, im_d2).to(y.device)
            y = y - h * d + gamma * noise
            if freq > 0 and t % freq == 0:
                print('-----------------------------', t)
                print('sigma ', sigma.item())
                intermed_Ys.append(y.squeeze(0))
            t += 1

        # t starts at 1 and is incremented after every iteration.
        n_iter = t - 1
        print("-------- total number of iterations, ", n_iter)
        # max(..., 1) avoids a ZeroDivisionError when no iteration ran.
        print("-------- average time per iteration (s), ",
              np.round((time.time() - start_time_total) / max(n_iter, 1), 4))

        denoised_y = y - model(y)

    return denoised_y.squeeze(0), intermed_Ys