3D support #41
Conversation
Update utilities for loading and visualizing.
Add dataset evaluation
Add reconstruction template.
Update readme and setup
Update metric library.
Update capture script and add display script.
Add MIR Flicker scripts and update README.
Add support for original DiffuserCam dataset.
Update readme.
lensless/io.py
Outdated
if data.shape[3] == 1 and psf.shape[3] > 1:
print("Warning : loaded a RGB PSF with grayscale data. Repeating data across channels.")
print("This may be an error as the PSF and the data are likely from different datasets.")
combine both messages and use warnings.warn
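The suggested change could look like the following sketch. The wrapper name `check_channels` is hypothetical (in the diff the check is inline); the message text is taken from the diff above.

```python
import warnings

def check_channels(data_shape, psf_shape):
    """Hypothetical helper: warn once when grayscale data is paired with an RGB PSF."""
    if data_shape[3] == 1 and psf_shape[3] > 1:
        # single warnings.warn call combining both of the original print messages
        warnings.warn(
            "Loaded a RGB PSF with grayscale data; repeating data across channels. "
            "This may be an error as the PSF and the data are likely from different datasets."
        )

check_channels((100, 32, 32, 1), (5, 32, 32, 3))  # emits one UserWarning
```

Unlike `print`, `warnings.warn` can be filtered or turned into an error by the caller.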
lensless/recon.py
Outdated
3D example
----------

It is also possible to reconstruct 3D scenes using Gradient Descent or APGD. ADMM doesn't supports 3D reconstruction yet.
render algorithms with hyperlinks as
:py:class:`~lensless.GradientDescent`
:py:class:`~lensless.ADMM`
:py:class:`~lensless.APGD`
as in here
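The suggested rendering could look like the sketch below; `recon_doc` is a placeholder name, and the class paths follow the reviewer's suggestion.

```python
def recon_doc():
    """Hypothetical stand-in for the docstring under review.

    It is also possible to reconstruct 3D scenes using
    :py:class:`~lensless.GradientDescent` or :py:class:`~lensless.APGD`;
    :py:class:`~lensless.ADMM` does not support 3D reconstruction yet.
    """
```

When built with Sphinx, each `:py:class:` role becomes a hyperlink to the class documentation, and the `~` prefix shows only the class name (e.g. ``ADMM``) instead of the full path.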
lensless/recon.py
Outdated
----------

It is also possible to reconstruct 3D scenes using Gradient Descent or APGD. ADMM doesn't supports 3D reconstruction yet.
This requires to use a 3D PSF as an input in the form of a .npy file, which actually is a set of 2D PSFs corresponding to the same diffuser sampeled with light sources from different depths.
"of an .npy
file..."
"sampled"
lensless/recon.py
Outdated
It is also possible to reconstruct 3D scenes using Gradient Descent or APGD. ADMM doesn't supports 3D reconstruction yet.
This requires to use a 3D PSF as an input in the form of a .npy file, which actually is a set of 2D PSFs corresponding to the same diffuser sampeled with light sources from different depths.
The input data for 3D reconstructions is still a 2D image, as collected by the camera. The reconstruction will be able to separate which part of the lensless data corresponds to which 2D PSF,
remove extra space "The input data for..."
lensless/recon.py
Outdated
It is also possible to reconstruct 3D scenes using Gradient Descent or APGD. ADMM doesn't supports 3D reconstruction yet.
This requires to use a 3D PSF as an input in the form of a .npy file, which actually is a set of 2D PSFs corresponding to the same diffuser sampeled with light sources from different depths.
The input data for 3D reconstructions is still a 2D image, as collected by the camera. The reconstruction will be able to separate which part of the lensless data corresponds to which 2D PSF,
and therefore to which depth, effectively generating a 3D reconstruction, which will be outputed in the form of a .npy file as well as a 2D projection on the depth axis to be displayed to the
"an .npy
file. A 2D projection on the depth axis is also displayed to the user."
lensless/recon.py
Outdated
Outside of the input data and PSF, no special argument has to be given to the script for it to operate a 3D reconstruction, as actually, the 2D reconstuction is internally
viewed as a 3D reconstruction which has only one depth level. It is also the case for ADMM although for now, the reconstructions are wrong when more than one depth level is used.

3D data is not directly provided in the LenslessPiCam, but some can be :doc:`imported <data>` from the Waller Lab dataset. For this data, it is best to set the downsample to 1 :
"provided in LenslessPiCam"
", but example data can be obtained from Waller Lab."
"... best to set downsample=1
:
lensless/recon.py
Outdated
and therefore to which depth, effectively generating a 3D reconstruction, which will be outputed in the form of a .npy file as well as a 2D projection on the depth axis to be displayed to the
user as an image.

As for the 2D ADMM reconstuction, scripts for 3D reconstruction can be found in ``scripts/recon/gradient_descent.py`` and ``scripts/recon/apgd_pycsou.py``.
"As for the 2D ADMM....one depth level is used"
All this can be replaced with:
The same scripts for 2D reconstruction can be used for 3D reconstruction, namely scripts/recon/gradient_descent.py
and scripts/recon/apgd_pycsou.py
.
lensless/recon.py
Outdated
@@ -193,16 +213,11 @@ def __init__(self, psf, dtype=None, pad=True, n_iter=100, **kwargs):
if torch_available:
self.is_torch = isinstance(psf, torch.Tensor)

assert len(psf.shape) == 4  # depth, width, height, channel
can add an error message: "PSF must be 4D: [depth, width, height, channel]."
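The suggested error message can be attached directly to the assert; a minimal sketch with a toy PSF array (shapes chosen for illustration):

```python
import numpy as np

# toy 4D PSF: depth, width, height, channel
psf = np.zeros((2, 8, 8, 3))
assert len(psf.shape) == 4, "PSF must be 4D: [depth, width, height, channel]."

# a 3D array (missing the depth axis) now fails with a descriptive message
bad_psf = np.zeros((8, 8, 3))
try:
    assert len(bad_psf.shape) == 4, "PSF must be 4D: [depth, width, height, channel]."
except AssertionError as e:
    print(e)  # -> PSF must be 4D: [depth, width, height, channel].
```

The message tells the user which axis convention is expected instead of a bare `AssertionError`.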
lensless/recon.py
Outdated
data = data[:, :, None]
assert len(self._psf_shape) == len(data.shape)
assert len(data.shape) == 4
assert len(self._psf_shape) == 4
you already have a check for the PSF shape in the constructor
lensless/recon.py
Outdated
assert len(data.shape) == 2
data = data[:, :, None]
assert len(self._psf_shape) == len(data.shape)
assert len(data.shape) == 4
Can add similar error message as for PSF
lensless/rfft_convolve.py
Outdated
self._n_channels = self._psf.shape[2]
self._psf_shape = np.array(self._psf.shape)

assert len(psf.shape) == 4
Can give error message with expected shape
lensless/rfft_convolve.py
Outdated
self._start_idx = (self._padded_shape[:2] - self._psf_shape[:2]) // 2
self._end_idx = self._start_idx + self._psf_shape[:2]
self._padded_shape = list(np.r_[self._psf_shape[0], self._padded_shape, self._psf_shape[3]])
self._start_idx = (self._padded_shape[1:3] - self._psf_shape[1:3]) // 2
can you replace with [-3:-1]?
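The point of the suggestion: on a 4D shape, `[1:3]` and `[-3:-1]` pick out the same two spatial axes, but negative indices keep working if leading axes change. A sketch with illustrative shape values (480x640 is an assumed PSF size):

```python
import numpy as np

psf_shape = np.array([2, 480, 640, 3])      # depth, width, height, channel
padded_shape = np.array([2, 959, 1279, 3])  # spatial dims padded to 2*N - 1

# [1:3] and [-3:-1] select the same (width, height) pair of a 4D shape
assert np.array_equal(psf_shape[1:3], psf_shape[-3:-1])

# crop indices computed with negative slices, independent of leading axes
start_idx = (padded_shape[-3:-1] - psf_shape[-3:-1]) // 2
end_idx = start_idx + psf_shape[-3:-1]
```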
lensless/rfft_convolve.py
Outdated
self._padded_shape = list(np.r_[self._padded_shape, [self._n_channels]])
self._start_idx = (self._padded_shape[:2] - self._psf_shape[:2]) // 2
self._end_idx = self._start_idx + self._psf_shape[:2]
self._padded_shape = list(np.r_[self._psf_shape[0], self._padded_shape, self._psf_shape[3]])
can you replace PSF shape indices with negative values? e.g. here
self._padded_shape = list(np.r_[self._psf_shape[-4], self._padded_shape, self._psf_shape[-1]])
lensless/rfft_convolve.py
Outdated
self._end_idx = self._start_idx + self._psf_shape[:2]
self._padded_shape = list(np.r_[self._psf_shape[0], self._padded_shape, self._psf_shape[3]])
self._start_idx = (self._padded_shape[1:3] - self._psf_shape[1:3]) // 2
self._end_idx = self._start_idx + self._psf_shape[1:3]
[-3:-1]
lensless/rfft_convolve.py
Outdated
self.pad = pad  # Whether necessary to pad provided data

# precompute filter in frequency domain
if self.is_torch:
self._H = torch.fft.rfft2(
self._pad(self._psf), norm=norm, dim=(0, 1), s=self._padded_shape[:2]
self._pad(self._psf), norm=norm, dim=(1, 2), s=self._padded_shape[1:3]
self._pad(self._psf), norm=norm, dim=(-3, -2), s=self._padded_shape[-3:-1]
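The equivalence behind this suggestion also holds for the NumPy branch: on a 4D array, axes `(1, 2)` and `(-3, -2)` address the same spatial dimensions. A small sketch (the 2x16x16x3 shape is illustrative):

```python
import numpy as np
from numpy import fft

x = np.random.rand(2, 16, 16, 3)  # depth, width, height, channel

# for a 4D array, axes (1, 2) and (-3, -2) are the same spatial axes,
# so both calls produce identical 2D real FFTs along width/height
a = fft.rfft2(x, axes=(1, 2))
b = fft.rfft2(x, axes=(-3, -2))
assert np.allclose(a, b)
```

The negative form has the advantage of still pointing at the spatial axes if a batch dimension is later prepended.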
lensless/rfft_convolve.py
Outdated
)
self._Hadj = torch.conj(self._H)
self._padded_data = torch.zeros(size=self._padded_shape, dtype=dtype, device=psf.device)

else:
self._H = fft.rfft2(self._pad(self._psf), axes=(0, 1), norm=norm)
self._H = fft.rfft2(self._pad(self._psf), axes=(1, 2), norm=norm)
same as above
lensless/rfft_convolve.py
Outdated
conv_output = torch.fft.ifftshift(
torch.fft.irfft2(
torch.fft.rfft2(self._padded_data, dim=(0, 1)) * self._H, dim=(0, 1)
torch.fft.rfft2(self._padded_data, dim=(1, 2)) * self._H, dim=(1, 2)
same as above
lensless/rfft_convolve.py
Outdated
),
dim=(0, 1),
dim=(1, 2),
same as above
lensless/rfft_convolve.py
Outdated
conv_output = fft.ifftshift(
fft.irfft2(fft.rfft2(self._padded_data, axes=(0, 1)) * self._H, axes=(0, 1)),
axes=(0, 1),
fft.irfft2(fft.rfft2(self._padded_data, axes=(1, 2)) * self._H, axes=(1, 2)),
same as above
lensless/rfft_convolve.py
Outdated
fft.irfft2(fft.rfft2(self._padded_data, axes=(0, 1)) * self._H, axes=(0, 1)),
axes=(0, 1),
fft.irfft2(fft.rfft2(self._padded_data, axes=(1, 2)) * self._H, axes=(1, 2)),
axes=(1, 2),
same as above
lensless/rfft_convolve.py
Outdated
deconv_output = torch.fft.ifftshift(
torch.fft.irfft2(
torch.fft.rfft2(self._padded_data, dim=(0, 1)) * self._Hadj, dim=(0, 1)
torch.fft.rfft2(self._padded_data, dim=(1, 2)) * self._Hadj, dim=(1, 2)
same as above
lensless/rfft_convolve.py
Outdated
),
dim=(0, 1),
dim=(1, 2),
same as above
lensless/rfft_convolve.py
Outdated
deconv_output = fft.ifftshift(
fft.irfft2(fft.rfft2(self._padded_data, axes=(0, 1)) * self._Hadj, axes=(0, 1)),
axes=(0, 1),
fft.irfft2(fft.rfft2(self._padded_data, axes=(1, 2)) * self._Hadj, axes=(1, 2)),
same as above
lensless/rfft_convolve.py
Outdated
fft.irfft2(fft.rfft2(self._padded_data, axes=(0, 1)) * self._Hadj, axes=(0, 1)),
axes=(0, 1),
fft.irfft2(fft.rfft2(self._padded_data, axes=(1, 2)) * self._Hadj, axes=(1, 2)),
axes=(1, 2),
same as above
3D support for Gradient Descent (and APGD soon), also with PyTorch
Benchmarks done with an AMD Ryzen 7 5800H CPU and an NVIDIA GeForce RTX 3050 Laptop GPU
Data used: see data.rst
Gradient Descent, 2D
Default: ~140 s
python scripts/recon/gradient_descent.py
Torch, CPU: ~70 s
python scripts/recon/gradient_descent.py torch=True
Torch, cuda:0: ~22 s
python scripts/recon/gradient_descent.py -cn pytorch
Gradient Descent, 3D
Default: ~135 s
python scripts/recon/gradient_descent.py input.psf="data/psf/diffuser_cam.npy" input.data="data/raw_data/diffuser_cam.tiff" preprocess.downsample=1
Torch, CPU: ~105 s
python scripts/recon/gradient_descent.py torch=True input.psf="data/psf/diffuser_cam.npy" input.data="data/raw_data/diffuser_cam.tiff" preprocess.downsample=1
Torch, cuda:0: ~27 s
python scripts/recon/gradient_descent.py -cn pytorch input.psf="data/psf/diffuser_cam.npy" input.data="data/raw_data/diffuser_cam.tiff" preprocess.downsample=1