Skip to content

Commit

Permalink
3D plot bugfix (#54)
Browse files Browse the repository at this point in the history
* fixed normalization bug for 3D plotting and added 3 projections view

* Updated changelog

* Minor fixes.

* Correct changelog.

* Improve plotting for 3D.

* Improve docs on available data.

---------

Co-authored-by: Eric Bezzam <[email protected]>
  • Loading branch information
Julien-Sahli and ebezzam authored Jul 31, 2023
1 parent 09e96c1 commit c73141b
Show file tree
Hide file tree
Showing 4 changed files with 114 additions and 53 deletions.
4 changes: 3 additions & 1 deletion CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,13 @@ Changed
- Simpler remote capture and display scripts with Hydra.
- Group source code into four modules: ``hardware``, ``recon``, ``utils``, ``eval``.
- Split scripts into subfolders.
- Displaying 3D reconstructions now shows projections on all three axis.


Bugfix
~~~~~~

-
- Displaying 3D reconstructions by summing values along axis would produce un-normalized values.

1.0.4 - (2023-06-14)
--------------------
Expand Down Expand Up @@ -85,6 +86,7 @@ Bugfix

- Loading grayscale PSFs would cause an dimension error when removing the background pixels.


1.0.2 - (2022-05-31)
--------------------

Expand Down
50 changes: 36 additions & 14 deletions docs/source/data.rst
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Already available data
======================
Measured data
=============

You can download example PSFs and raw data that we've measured
You can download PSFs and raw data that we've measured
`here <https://drive.switch.ch/index.php/s/NdgHlcDeHVDH5ww>`__. We
recommend placing this content in the ``data`` folder.

Expand Down Expand Up @@ -37,15 +37,10 @@ use the correct PSF file for the data you're using!
python scripts/recon/gradient_descent.py -cn in_the_wild \
input.data=data/raw_data/thumbs_up_rgb.png \
input.psf=data/psf/tape_rgb.png
# 3D LCAV logo
python scripts/recon/gradient_descent.py \
input.data=data/raw_data/3d_sim.png \
input.psf=data/psf/3d_sim.npz
Dataset collected by other people
---------------------------------
DiffuserCam Lensless Mirflickr Dataset (DLMD)
---------------------------------------------

You can download a subset for the `DiffuserCam Lensless Mirflickr
Dataset <https://waller-lab.github.io/LenslessLearning/dataset.html>`__
Expand All @@ -61,14 +56,41 @@ dataset (200 files, 725 MB). It was prepared with the following script:
--data ~/Documents/DiffuserCam/DiffuserCam_Mirflickr_Dataset
3D data
-------

You can download example 3D PSF and raw data from Prof. Laura Waller's lab
`here <https://github.com/Waller-Lab/DiffuserCam/tree/master/example_data>`__.
The PSF has to be converted from ``.mat`` to ``.npy`` in order to be usable:
`here <https://github.com/Waller-Lab/DiffuserCam/tree/master/example_data>`__,
or by running the commands at the beginning of this page to download all
the example data.

Their PSF has to be converted from ``.mat`` to ``.npy`` in order to be usable:

.. code:: bash
python scripts/data/3d/mat_to_npy.py ~/path/to/example_psfs.mat
# replace path to .mat file if different
python scripts/data/3d/mat_to_npy.py data/psf/waller_3d_psfs.mat
The following command can be used to run a reconstruction on the 3D data:

.. code:: bash
python scripts/recon/gradient_descent.py \
input.data=data/raw_data/waller_3d_raw.png \
input.psf=psf.npy preprocess.downsample=1 \
-cn pytorch # if pytorch is available with GPU
You can also perform a 3D reconstruction on data we have simulated:

.. code:: bash
# 3D LCAV logo
python scripts/recon/gradient_descent.py \
input.data=data/raw_data/3d_sim.png \
input.psf=data/psf/3d_sim.npz \
-cn pytorch # if pytorch is available with GPU
Once you have run a reconstruction, you may want to convert the
resulting ``.npy`` files in separate ``.tiff`` images for each depth.
Expand Down
16 changes: 11 additions & 5 deletions lensless/recon/recon.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,9 @@ class :py:class:`~lensless.ReconstructionAlgorithm`. The three reconstruction
----------
It is also possible to reconstruct 3D scenes using :py:class:`~lensless.GradientDescent` or :py:class:`~lensless.APGD`. :py:class:`~lensless.ADMM` does not support 3D reconstruction yet.
This requires to use a 3D PSF as an input in the form of an .npy or .npz file, which actually is a set of 2D PSFs corresponding to the same diffuser sampled with light sources from different depths.
This requires to use a 3D PSF as an input in the form of an ``.npy`` or ``.npz`` file, which is a set of 2D PSFs corresponding to the same diffuser sampled with light sources at different depths.
The input data for 3D reconstructions is still a 2D image, as collected by the camera. The reconstruction will be able to separate which part of the lensless data corresponds to which 2D PSF,
and therefore to which depth, effectively generating a 3D reconstruction, which will be outputed in the form of an .npy file. A 2D projection on the depth axis is also displayed to the user.
and therefore to which depth, effectively generating a 3D reconstruction, which will be outputed in the form of an ``.npy`` file. A 2D projection on the depth axis is also displayed to the user.
The same scripts for 2D reconstruction can be used for 3D reconstruction, namely ``scripts/recon/gradient_descent.py`` and ``scripts/recon/apgd_pycsou.py``.
Expand Down Expand Up @@ -491,7 +491,7 @@ def apply(

if (plot or save) and disp_iter is not None:
if ax is None:
ax = plot_image(self._get_numpy_data(self._data[0]), gamma=gamma)
ax = plot_image(self._get_numpy_data(self._image_est[0]), gamma=gamma)
else:
ax = None
disp_iter = n_iter + 1
Expand All @@ -503,7 +503,10 @@ def apply(
self._progress()
img = self._form_image()
ax = plot_image(self._get_numpy_data(img[0]), ax=ax, gamma=gamma)
ax.set_title("Reconstruction after iteration {}".format(i + 1))
if hasattr(ax, "__len__"):
ax[0, 0].set_title("Reconstruction after iteration {}".format(i + 1))
else:
ax.set_title("Reconstruction after iteration {}".format(i + 1))
if save:
plt.savefig(plib.Path(save) / f"{i + 1}.png")
if plot:
Expand All @@ -513,7 +516,10 @@ def apply(
final_im = self._form_image()[0]
if plot:
ax = plot_image(self._get_numpy_data(final_im), ax=ax, gamma=gamma)
ax.set_title("Final reconstruction after {} iterations".format(n_iter))
if hasattr(ax, "__len__"):
ax[0, 0].set_title("Final reconstruction after {} iterations".format(n_iter))
else:
ax.set_title("Final reconstruction after {} iterations".format(n_iter))
if save:
plt.savefig(plib.Path(save) / f"{n_iter}.png")
return final_im, ax
Expand Down
97 changes: 64 additions & 33 deletions lensless/utils/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from lensless.utils.image import FLOAT_DTYPES, get_max_val, gamma_correction, autocorr2d


def plot_image(img, ax=None, gamma=None, normalize=True, axis=0):
def plot_image(img, ax=None, gamma=None, normalize=True):
"""
Plot image data.
Expand All @@ -26,61 +26,92 @@ def plot_image(img, ax=None, gamma=None, normalize=True, axis=0):
`Axes` object to fill for plotting/saving, default is to create one.
gamma : float, optional
Gamma correction factor to apply for plots. Default is None.
normalize : bool
normalize : bool, optional
Whether to normalize data to maximum range. Default is True.
axis : int
For 3D data, the axis on which to project the data
Returns
-------
ax : :py:class:`~matplotlib.axes.Axes`
Axes on which image is plot.
"""

if ax is None:
_, ax = plt.subplots()
# if we have only 1 depth, remove the axis
if img.shape[0] == 1:
img = img[0]

max_val = img.max()
if not normalize:
if img.dtype not in FLOAT_DTYPES:
max_val = get_max_val(img)
else:
max_val = 1
# if we have only 1 color channel, remove the axis
if img.shape[-1] == 1:
img = img[..., 0]

# need float image for gamma correction and plotting
img_norm = img / max_val
if gamma and gamma > 1:
img_norm = gamma_correction(img_norm, gamma=gamma)
disp_img = None
cmap = None

# full data format : [depth, width, height, color]
# full 3D RGB format : [depth, width, height, color]
is_3d = False
if len(img.shape) == 4:
if img.shape[3] == 3: # 3d rgb
sum_img = np.sum(img_norm, axis=axis)
ax.imshow(sum_img)

else:
assert img.shape[3] == 1 # 3d grayscale with color channel extended
sum_img = np.sum(img_norm[:, :, :, 0], axis=axis)
ax.imshow(sum_img, cmap="gray")
disp_img = [np.sum(img, axis=axis) for axis in range(3)]
cmap = None
is_3d = True

# data of length 3 means we have to infer whichever depth or color is missing, based on shape.
elif len(img.shape) == 3:
if img.shape[2] == 3: # 2D rgb
ax.imshow(img_norm)

elif img.shape[2] == 1: # 2D grayscale with color channel extended
ax.imshow(img_norm[:, :, 0], cmap="gray")
disp_img = [img]
cmap = None

else: # 3D grayscale
sum_img = np.sum(img_norm, axis=axis)
ax.imshow(sum_img, cmap="gray")
disp_img = [np.sum(img, axis=axis) for axis in range(3)]
cmap = "gray"
is_3d = True

# data of length 2 means we have only width and height
elif len(img.shape) == 2: # 2D grayscale
ax.imshow(img_norm, cmap="gray")
disp_img = [img]
cmap = "gray"

else:
raise ValueError(f"Unexpected data shape : {img.shape}")

max_val = [d.max() for d in disp_img]

if not normalize:
for i in range(len(max_val)):
if disp_img[i].dtype not in FLOAT_DTYPES:
max_val[i] = get_max_val(disp_img[i])
else:
max_val[i] = 1

assert len(disp_img) == 1 or len(disp_img) == 3

# need float image for gamma correction and plotting
img_norm = disp_img.copy()
for i in range(len(img_norm)):
img_norm[i] = disp_img[i] / max_val[i]
if gamma and gamma > 1:
img_norm[i] = gamma_correction(img_norm[i], gamma=gamma)

if ax is None:
if not is_3d:
_, ax = plt.subplots()
else:
_, ax = plt.subplots(2, 2)
else:
raise ValueError(f"Unexpected data shape : {img_norm.shape}")
if is_3d:
assert len(ax) == 2
assert len(ax[0]) == 2

if len(img_norm) == 1:
ax.imshow(img_norm[0], cmap=cmap)

else:

# plot each projection separately
ax[0, 0].imshow(img_norm[0], cmap=cmap)
ax[0, 1].imshow(np.swapaxes(img_norm[2], 0, 1), cmap=cmap)
ax[1, 0].imshow(img_norm[1], cmap=cmap)
ax[1, 1].axis("off")
ax[0, 1].set_xlabel("Depth")
ax[1, 0].set_ylabel("Depth")

return ax

Expand Down

0 comments on commit c73141b

Please sign in to comment.