Skip to content

Commit

Permalink
added negative indexes to rfft_convolve
Browse files Browse the repository at this point in the history
  • Loading branch information
Julien-Sahli committed May 11, 2023
1 parent 9462e5f commit 5f53e8d
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 22 deletions.
11 changes: 4 additions & 7 deletions lensless/recon.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,7 @@ class :py:class:`~lensless.ReconstructionAlgorithm`. The three reconstruction
The input data for 3D reconstructions is still a 2D image, as collected by the camera. The reconstruction will be able to separate which part of the lensless data corresponds to which 2D PSF,
and therefore to which depth, effectively generating a 3D reconstruction, which will be output in the form of an .npy file. A 2D projection on the depth axis is also displayed to the user.
As for the 2D ADMM reconstruction, scripts for 3D reconstruction can be found in ``scripts/recon/gradient_descent.py`` and ``scripts/recon/apgd_pycsou.py``.
Outside of the input data and PSF, no special argument has to be given to the script for it to perform a 3D reconstruction, since the 2D reconstruction is internally
treated as a 3D reconstruction with a single depth level. This also applies to ADMM, although for now the reconstructions are wrong when more than one depth level is used.
The same scripts for 2D reconstruction can be used for 3D reconstruction, namely ``scripts/recon/gradient_descent.py`` and ``scripts/recon/apgd_pycsou.py``.
3D data is provided in LenslessPiCam, but it is simulated. Real example data can be obtained from `Waller Lab <https://github.com/Waller-Lab/DiffuserCam/tree/master/example_data>`_.
For both the simulated data and the data from Waller Lab, it is best to set ``downsample=1``:
Expand Down Expand Up @@ -216,8 +214,8 @@ def __init__(self, psf, dtype=None, pad=True, n_iter=100, **kwargs):
if torch_available:
self.is_torch = isinstance(psf, torch.Tensor)

assert len(psf.shape) == 4 # depth, width, height, channel
assert psf.shape[3] == 3 or psf.shape[3] == 1 # either rgb or grayscale
assert len(psf.shape) == 4, "PSF must be 4D: [depth, width, height, channel]."
assert psf.shape[3] == 3 or psf.shape[3] == 1, "PSF must either be rgb (3) or grayscale (1)"
self._psf = psf
self._n_iter = n_iter

Expand Down Expand Up @@ -307,8 +305,7 @@ def set_data(self, data):
else:
assert isinstance(data, np.ndarray)

assert len(data.shape) == 4
assert len(self._psf_shape) == 4
assert len(data.shape) == 4, "Data must be 4D: [depth, width, height, channel]."

# assert same shapes
assert np.all(
Expand Down
32 changes: 17 additions & 15 deletions lensless/rfft_convolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __init__(self, psf, dtype=None, pad=True, norm="ortho", **kwargs):

# prepare shapes for reconstruction

assert len(psf.shape) == 4
assert len(psf.shape) == 4, "Expected 4D PSF of shape (depth, width, height, channels)"
self._use_3d = psf.shape[0] != 1
self._is_rgb = psf.shape[3] == 3
assert self._is_rgb or psf.shape[3] == 1
Expand All @@ -53,23 +53,25 @@ def __init__(self, psf, dtype=None, pad=True, norm="ortho", **kwargs):
self._psf_shape = np.array(self._psf.shape)

# cropping / padding indexes
self._padded_shape = 2 * self._psf_shape[1:3] - 1
self._padded_shape = 2 * self._psf_shape[-3:-1] - 1
self._padded_shape = np.array([next_fast_len(i) for i in self._padded_shape])
self._padded_shape = list(np.r_[self._psf_shape[0], self._padded_shape, self._psf_shape[3]])
self._start_idx = (self._padded_shape[1:3] - self._psf_shape[1:3]) // 2
self._end_idx = self._start_idx + self._psf_shape[1:3]
self._padded_shape = list(
np.r_[self._psf_shape[-4], self._padded_shape, self._psf_shape[-1]]
)
self._start_idx = (self._padded_shape[-3:-1] - self._psf_shape[-3:-1]) // 2
self._end_idx = self._start_idx + self._psf_shape[-3:-1]
self.pad = pad # Whether necessary to pad provided data

# precompute filter in frequency domain
if self.is_torch:
self._H = torch.fft.rfft2(
self._pad(self._psf), norm=norm, dim=(1, 2), s=self._padded_shape[1:3]
self._pad(self._psf), norm=norm, dim=(-3, -2), s=self._padded_shape[-3:-1]
)
self._Hadj = torch.conj(self._H)
self._padded_data = torch.zeros(size=self._padded_shape, dtype=dtype, device=psf.device)

else:
self._H = fft.rfft2(self._pad(self._psf), axes=(1, 2), norm=norm)
self._H = fft.rfft2(self._pad(self._psf), axes=(-3, -2), norm=norm)
self._Hadj = np.conj(self._H)
self._padded_data = np.zeros(self._padded_shape).astype(dtype)

Expand Down Expand Up @@ -98,15 +100,15 @@ def convolve(self, x):
if self.is_torch:
conv_output = torch.fft.ifftshift(
torch.fft.irfft2(
torch.fft.rfft2(self._padded_data, dim=(1, 2)) * self._H, dim=(1, 2)
torch.fft.rfft2(self._padded_data, dim=(-3, -2)) * self._H, dim=(-3, -2)
),
dim=(1, 2),
dim=(-3, -2),
)

else:
conv_output = fft.ifftshift(
fft.irfft2(fft.rfft2(self._padded_data, axes=(1, 2)) * self._H, axes=(1, 2)),
axes=(1, 2),
fft.irfft2(fft.rfft2(self._padded_data, axes=(-3, -2)) * self._H, axes=(-3, -2)),
axes=(-3, -2),
)
if self.pad:
return self._crop(conv_output)
Expand All @@ -127,15 +129,15 @@ def deconvolve(self, y):
if self.is_torch:
deconv_output = torch.fft.ifftshift(
torch.fft.irfft2(
torch.fft.rfft2(self._padded_data, dim=(1, 2)) * self._Hadj, dim=(1, 2)
torch.fft.rfft2(self._padded_data, dim=(-3, -2)) * self._Hadj, dim=(-3, -2)
),
dim=(1, 2),
dim=(-3, -2),
)

else:
deconv_output = fft.ifftshift(
fft.irfft2(fft.rfft2(self._padded_data, axes=(1, 2)) * self._Hadj, axes=(1, 2)),
axes=(1, 2),
fft.irfft2(fft.rfft2(self._padded_data, axes=(-3, -2)) * self._Hadj, axes=(-3, -2)),
axes=(-3, -2),
)

if self.pad:
Expand Down

0 comments on commit 5f53e8d

Please sign in to comment.