Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improved dataset #68

Merged
merged 30 commits into from
Aug 29, 2023
Merged
Show file tree
Hide file tree
Changes from 24 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
f765a23
New simulated dataset (moved old dataset)
YohannPerron Jul 20, 2023
1b91fab
Move dataset to utils
YohannPerron Jul 20, 2023
8894801
Added parent class DualDataset
YohannPerron Jul 20, 2023
1834c66
Use new dataset structure for training
YohannPerron Jul 20, 2023
ea14909
Fix doc and bugs
YohannPerron Jul 20, 2023
037a192
New dataset for lensless only
YohannPerron Jul 20, 2023
c81e95a
Fixes for downscaling
YohannPerron Jul 20, 2023
6bb042d
Update change
YohannPerron Jul 21, 2023
c383386
Disclaimer for LenslessDataset
YohannPerron Jul 21, 2023
4b8750a
Added header
YohannPerron Jul 21, 2023
1a8283f
Updated documentation
YohannPerron Jul 21, 2023
7d25fe7
Merge.
ebezzam Jul 31, 2023
a637dc4
Fix typos and wording.
ebezzam Jul 31, 2023
9b02ff1
Move dataset docs to data section.
ebezzam Jul 31, 2023
9b87f75
Fixed docstring
YohannPerron Aug 14, 2023
d58d5ef
Merge branch 'main' into Improved-Dataset
YohannPerron Aug 14, 2023
f668813
Fix for flip in simulated dataset
YohannPerron Aug 16, 2023
b9b3d78
Add wrapper arounf FarFieldSimulator
YohannPerron Aug 22, 2023
82e8270
Fix import error
YohannPerron Aug 22, 2023
e763f7d
Fix docstrings
YohannPerron Aug 22, 2023
1c08b6d
FIx typos.
ebezzam Aug 28, 2023
32e5f9b
Fix doc rendering of FarFieldSimulator.
ebezzam Aug 28, 2023
dc7afe6
Refactor.
ebezzam Aug 28, 2023
5e9737b
Refactor.
ebezzam Aug 29, 2023
585ed44
Fix import.
ebezzam Aug 29, 2023
fe0467d
Refactor and rephrase for clearer dataset diff
ebezzam Aug 29, 2023
b37cdcf
Fixed no attribute psf
YohannPerron Aug 29, 2023
d80f340
add new simulation to training script
YohannPerron Aug 29, 2023
c6ec313
Remove print.
ebezzam Aug 29, 2023
a821c82
Update changelog.
ebezzam Aug 29, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ Added
- Script for measuring arbitrary dataset (from Raspberry Pi).
- Support for preprocessing and postprocessing, such as denoising, in ``TrainableReconstructionAlgorithm``. Both trainable and fix postprocessing can be used.
- Utilities to load a trained DruNet model for use as postprocessing in ``TrainableReconstructionAlgorithm``.
- Unified interface for dataset. See ``utils.dataset.DualDataset``.
- New simulated dataset compatible with new data format ([(batch_size), depth, width, height, color]). See ``utils.dataset.SimulatedDataset``.
- New dataset for pair of original image and thair measurement from a screen. See ``utils.dataset.LenslessDataset``.
- Support for unrolled loading and inference in the script ``admm.py``.
- Tikhonov reconstruction for coded aperture measurements (MLS / MURA).

Expand Down
3 changes: 2 additions & 1 deletion docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@ docutils==0.16 # >0.17 doesn't render bullets
numpy>=1.22 # so that default dtype are correctly rendered
torch>=1.10
torchvision>=0.15.2
torchmetrics>=0.11.4
torchmetrics>=0.11.4
waveprop>=0.0.5
8 changes: 4 additions & 4 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,14 @@
"pycsou.util",
"pycsou.util.ptype",
"PIL",
"PIL.Image",
"tqdm",
"paramiko",
"paramiko.ssh_exception",
"perlin_numpy",
"waveprop",
"waveprop.fresnel",
"waveprop.rs",
"waveprop.noise",
"scipy.special",
"matplotlib.cm",
"pyffs",
]
for mod_name in MOCK_MODULES:
sys.modules[mod_name] = mock.Mock()
Expand Down
23 changes: 23 additions & 0 deletions docs/source/dataset.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
Dataset objects (for training and testing)
==========================================

The software below provides functionality (with PyTorch) to load
datasets for training and testing.

.. automodule:: lensless.utils.dataset

.. autoclass:: lensless.utils.dataset.DualDataset
:members: _get_images_pair
:special-members: __init__, __len__

.. autoclass:: lensless.utils.dataset.LenslessDataset
:members:
:special-members: __init__

.. autoclass:: lensless.utils.dataset.SimulatedFarFieldDataset
:members:
:special-members: __init__

.. autoclass:: lensless.utils.dataset.DiffuserCamTestDataset
:members:
:special-members: __init__
4 changes: 0 additions & 4 deletions docs/source/evaluation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,4 @@

.. automodule:: lensless.eval.benchmark

.. autoclass:: lensless.eval.benchmark.ParallelDataset
:members:
:special-members: __init__

.. autofunction:: lensless.eval.benchmark.benchmark
1 change: 1 addition & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ Contents

simulation
data
dataset

.. toctree::
:hidden:
Expand Down
12 changes: 12 additions & 0 deletions docs/source/simulation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,18 @@ library is used with the following simulation steps:

PyTorch support is available to speed up simulation on GPU, and to create Dataset and DataLoader objects for training and testing!

FarFieldSimulator
------------------

A wrapper around `waveprop.simulation.FarFieldSimulator <https://github.com/ebezzam/waveprop/blob/82dfb08b4db11c0c07ef00bdb59b5a769a49f0b3/waveprop/simulation.py#L11C11-L11C11>`__
is implemented as :py:class:`lensless.utils.simulation.FarFieldSimulator`.
It handles the conversion between the HWC and CHW dimension orderings so that the convention of LenslessPiCam can be maintained (namely HWC).

.. autoclass:: lensless.utils.simulation.FarFieldSimulator
:members:
:special-members: __init__


Simulating 3D data
------------------

Expand Down
209 changes: 2 additions & 207 deletions lensless/eval/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,14 @@
# #############################################################################


import glob
import os
from lensless.utils.io import load_psf
from lensless.utils.image import resize
import numpy as np
from lensless.utils.dataset import DiffuserCamTestDataset
from tqdm import tqdm

from lensless.utils.io import load_image

try:
import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import DataLoader
from torch.nn import MSELoss, L1Loss
from torchmetrics import StructuralSimilarityIndexMeasure
from torchmetrics.image import lpip, psnr
Expand All @@ -28,207 +24,6 @@
)


class ParallelDataset(Dataset):
"""
Dataset consisting of lensless and corresponding lensed image.

It can be used with a PyTorch DataLoader to load a batch of lensless and corresponding lensed images.

"""

def __init__(
self,
root_dir,
n_files=False,
background=None,
downsample=4,
flip=False,
transform_lensless=None,
transform_lensed=None,
lensless_fn="diffuser",
lensed_fn="lensed",
image_ext="npy",
**kwargs,
):
"""
Dataset consisting of lensless and corresponding lensed image. Default parameters are for the DiffuserCam
Lensless Mirflickr Dataset (DLMD).

Parameters
----------

root_dir : str
Path to the test dataset. It is expected to contain two folders: ones of lensless images and one of lensed images.
n_files : int or None, optional
Metrics will be computed only on the first ``n_files`` images. If None, all images are used, by default False
background : :py:class:`~torch.Tensor` or None, optional
If not ``None``, background is removed from lensless images, by default ``None``.
downsample : int, optional
Downsample factor of the lensless images, by default 4.
flip : bool, optional
If ``True``, lensless images are flipped, by default ``False``.
transform_lensless : PyTorch Transform or None, optional
Transform to apply to the lensless images, by default None
transform_lensed : PyTorch Transform or None, optional
Transform to apply to the lensed images, by default None
lensless_fn : str, optional
Name of the folder containing the lensless images, by default "diffuser".
lensed_fn : str, optional
Name of the folder containing the lensed images, by default "lensed".
image_ext : str, optional
Extension of the images, by default "npy".
"""

self.root_dir = root_dir
self.lensless_dir = os.path.join(root_dir, lensless_fn)
self.lensed_dir = os.path.join(root_dir, lensed_fn)
self.image_ext = image_ext.lower()

files = glob.glob(os.path.join(self.lensless_dir, "*." + image_ext))
if n_files:
files = files[:n_files]
self.files = [os.path.basename(fn) for fn in files]

if len(self.files) == 0:
raise FileNotFoundError(
f"No files found in {self.lensless_dir} with extension {image_ext}"
)

self.background = background
self.downsample = downsample / 4
self.flip = flip
self.transform_lensless = transform_lensless
self.transform_lensed = transform_lensed

def __len__(self):
return len(self.files)

def __getitem__(self, idx):
if torch.is_tensor(idx):
idx = idx.tolist()

if self.image_ext == "npy":
lensless_fp = os.path.join(self.lensless_dir, self.files[idx])
lensed_fp = os.path.join(self.lensed_dir, self.files[idx])
lensless = np.load(lensless_fp)
lensed = np.load(lensed_fp)
else:
# more standard image formats: png, jpg, tiff, etc.
lensless_fp = os.path.join(self.lensless_dir, self.files[idx])
lensed_fp = os.path.join(self.lensed_dir, self.files[idx])
lensless = load_image(lensless_fp)
lensed = load_image(lensed_fp)

# convert to float
if lensless.dtype == np.uint8:
lensless = lensless.astype(np.float32) / 255
lensed = lensed.astype(np.float32) / 255
else:
# 16 bit
lensless = lensless.astype(np.float32) / 65535
lensed = lensed.astype(np.float32) / 65535

if self.downsample != 1.0:
lensless = resize(lensless, factor=1 / self.downsample)
lensed = resize(lensed, factor=1 / self.downsample)

lensless = torch.from_numpy(lensless)
lensed = torch.from_numpy(lensed)

# If [H, W, C] -> [D, H, W, C]
if len(lensless.shape) == 3:
lensless = lensless.unsqueeze(0)
if len(lensed.shape) == 3:
lensed = lensed.unsqueeze(0)

if self.background is not None:
lensless = lensless - self.background

# flip image x and y if needed
if self.flip:
lensless = torch.rot90(lensless, dims=(-3, -2))
lensed = torch.rot90(lensed, dims=(-3, -2))
if self.transform_lensless:
lensless = self.transform_lensless(lensless)

if self.transform_lensed:
lensed = self.transform_lensed(lensed)

return lensless, lensed


class DiffuserCamTestDataset(ParallelDataset):
"""
Dataset consisting of lensless and corresponding lensed image. This is the standard dataset used for benchmarking.
"""

def __init__(
self,
data_dir="data",
n_files=200,
downsample=8,
):
"""
Dataset consisting of lensless and corresponding lensed image. Default parameters are for the test set of DiffuserCam
Lensless Mirflickr Dataset (DLMD).

Parameters
----------
data_dir : str, optional
The path to the folder containing the DiffuserCam_Test dataset, by default "data"
n_files : int, optional
Number of image pair to load in the dataset , by default 200
downsample : int, optional
Downsample factor of the lensless images, by default 8
"""
# download dataset if necessary
main_dir = data_dir
data_dir = os.path.join(data_dir, "DiffuserCam_Test")
if not os.path.isdir(data_dir):
print("No dataset found for benchmarking.")
try:
from torchvision.datasets.utils import download_and_extract_archive
except ImportError:
exit()
msg = "Do you want to download the sample dataset (3.5GB)?"

# default to yes if no input is given
valid = input("%s (Y/n) " % msg).lower() != "n"
if valid:
url = "https://drive.switch.ch/index.php/s/D3eRJ6PRljfHoH8/download"
filename = "DiffuserCam_Test.zip"
download_and_extract_archive(url, main_dir, filename=filename, remove_finished=True)

psf_fp = os.path.join(data_dir, "psf.tiff")
psf, background = load_psf(
psf_fp,
downsample=downsample,
return_float=True,
return_bg=True,
bg_pix=(0, 15),
)

# transform from BGR to RGB
from torchvision import transforms

transform_BRG2RGB = transforms.Lambda(lambda x: x[..., [2, 1, 0]])

self.psf = transform_BRG2RGB(torch.from_numpy(psf))

super().__init__(
data_dir,
n_files,
background,
downsample,
flip=False,
transform_lensless=transform_BRG2RGB,
transform_lensed=transform_BRG2RGB,
lensless_fn="diffuser",
lensed_fn="lensed",
image_ext="npy",
)


def benchmark(model, dataset, batchsize=1, metrics=None, **kwargs):
"""
Compute multiple metrics for a reconstruction algorithm.
Expand Down
3 changes: 1 addition & 2 deletions lensless/hardware/mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@
from perlin_numpy import generate_perlin_noise_2d
from sympy.ntheory import quadratic_residues
from scipy.signal import max_len_seq
from scipy.linalg import circulant
from numpy.linalg import multi_dot
from scipy.linalg import circulant, multi_dot
from waveprop.fresnel import fresnel_conv
from waveprop.rs import angular_spectrum
from waveprop.noise import add_shot_noise
Expand Down
Loading
Loading