Skip to content

Commit

Permalink
Add one-shot method for ADMM and GD.
Browse files Browse the repository at this point in the history
  • Loading branch information
ebezzam committed May 13, 2024
1 parent 3e07250 commit 173329d
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 2 deletions.
26 changes: 25 additions & 1 deletion lensless/recon/admm.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import numpy as np
from lensless.recon.recon import ReconstructionAlgorithm
from scipy import fft
from lensless.utils.io import load_data
import time

try:
import torch
Expand Down Expand Up @@ -45,7 +47,7 @@ def __init__(
norm="backward",
# PnP
denoiser=None,
**kwargs
**kwargs,
):
"""
Expand Down Expand Up @@ -393,3 +395,25 @@ def finite_diff_gram(shape, dtype=None, is_torch=False):
return torch.fft.rfft2(gram, dim=(-3, -2))
else:
return fft.rfft2(gram, axes=(-3, -2))


def apply_admm(psf_fp, data_fp, n_iter, verbose=False, **kwargs):

# load data
psf, data = load_data(psf_fp=psf_fp, data_fp=data_fp, plot=False, **kwargs)

# create reconstruction object
recon = ADMM(psf, n_iter=n_iter)

# set data
recon.set_data(data)

# perform reconstruction
start_time = time.time()
res = recon.apply(plot=False)
proc_time = time.time() - start_time

if verbose:
print(f"Reconstruction time : {proc_time} s")
print(f"Reconstruction shape: {res.shape}")
return res
25 changes: 24 additions & 1 deletion lensless/recon/gd.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import numpy as np
from lensless.recon.recon import ReconstructionAlgorithm
import inspect
from lensless.utils.io import load_data
import time

try:
import torch
Expand Down Expand Up @@ -229,9 +231,30 @@ def reset(self, tk=None):

def _update(self, iter):
self._image_est -= self._alpha * self._grad()
# xk = self._proj(self._image_est)
xk = self._form_image()
tk = (1 + np.sqrt(1 + 4 * self._tk**2)) / 2
self._image_est = xk + (self._tk - 1) / tk * (xk - self._xk)
self._tk = tk
self._xk = xk


def apply_gradient_descent(psf_fp, data_fp, n_iter, verbose=False, proj=non_neg, **kwargs):

# load data
psf, data = load_data(psf_fp=psf_fp, data_fp=data_fp, plot=False, **kwargs)

# create reconstruction object
recon = GradientDescent(psf, n_iter=n_iter, proj=proj)

# set data
recon.set_data(data)

# perform reconstruction
start_time = time.time()
res = recon.apply(plot=False)
proc_time = time.time() - start_time

if verbose:
print(f"Reconstruction time : {proc_time} s")
print(f"Reconstruction shape: {res.shape}")
return res

0 comments on commit 173329d

Please sign in to comment.