Skip to content

Commit

Permalink
Make sure examples work.
Browse files Browse the repository at this point in the history
  • Loading branch information
ebezzam committed Apr 30, 2024
1 parent 0ccec8e commit 73a4b6b
Show file tree
Hide file tree
Showing 11 changed files with 138 additions and 49 deletions.
7 changes: 4 additions & 3 deletions configs/train_digicam_celeba.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# python scripts/recon/train_learning_based.py -cn train_digicam_singlemask
# python scripts/recon/train_learning_based.py -cn train_digicam_celeba
defaults:
- train_unrolledADMM
- _self_
Expand All @@ -13,6 +13,7 @@ files:
huggingface_psf: "psf_simulated.png"
huggingface_dataset: True
split_seed: 0
test_size: 0.15
downsample: 2
rotate: True # if measurement is upside-down
save_psf: False
Expand All @@ -34,14 +35,14 @@ alignment:
random_vflip: False
random_hflip: False
quantize: False
# shifting when there is no files.downsample
# shifting when there is no files to downsample
vertical_shift: -117
horizontal_shift: -25

training:
batch_size: 4
epoch: 25
eval_batch_size: 4
eval_batch_size: 16
crop_preloss: True

reconstruction:
Expand Down
41 changes: 40 additions & 1 deletion configs/train_digicam_multimask.yaml
Original file line number Diff line number Diff line change
@@ -1,15 +1,23 @@
# python scripts/recon/train_learning_based.py -cn train_digicam_multimask
defaults:
- train_digicam_singlemask
- train_unrolledADMM
- _self_

torch_device: 'cuda:0'
device_ids: [0, 1, 2, 3]
eval_disp_idx: [1, 2, 4, 5, 9]


# Dataset
files:
dataset: bezzam/DigiCam-Mirflickr-MultiMask-25K
huggingface_dataset: True
huggingface_psf: null
downsample: 1
# TODO: these parameters should be in the dataset?
image_res: [900, 1200] # used during measurement
rotate: True # if measurement is upside-down
save_psf: False

extra_eval:
singlemask:
Expand All @@ -19,3 +27,34 @@ files:
alignment:
topright: [80, 100] # height, width
height: 200

# TODO: these parameters should be in the dataset?
alignment:
# when there is no downsampling
topright: [80, 100] # height, width
height: 200

training:
batch_size: 4
epoch: 25
eval_batch_size: 4

reconstruction:
method: unrolled_admm
unrolled_admm:
# Number of iterations
n_iter: 10
# Hyperparameters
mu1: 1e-4
mu2: 1e-4
mu3: 1e-4
tau: 2e-4
pre_process:
network : UnetRes # UnetRes or DruNet or null
depth : 4 # depth of each up/downsampling layer. Ignore if network is DruNet
nc: [32,64,116,128]
post_process:
network : UnetRes # UnetRes or DruNet or null
depth : 4 # depth of each up/downsampling layer. Ignore if network is DruNet
nc: [32,64,116,128]

1 change: 1 addition & 0 deletions configs/train_digicam_singlemask.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ eval_disp_idx: [1, 2, 4, 5, 9]
files:
dataset: bezzam/DigiCam-Mirflickr-SingleMask-25K
huggingface_dataset: True
huggingface_psf: null
downsample: 1
# TODO: these parameters should be in the dataset?
image_res: [900, 1200] # used during measurement
Expand Down
2 changes: 2 additions & 0 deletions configs/train_unrolledADMM.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@ start_delay: null

# Dataset
files:
# -- using local dataset
# dataset: /scratch/bezzam/DiffuserCam_mirflickr/dataset # Simulated : "mnist", "fashion_mnist", "cifar10", "CelebA". Measure :"DiffuserCam"
# celeba_root: null # path to parent directory of CelebA: https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html
# psf: data/psf/diffusercam_psf.tiff
# diffusercam_psf: True

# -- using huggingface dataset
dataset: bezzam/DiffuserCam-Lensless-Mirflickr-Dataset-NORM
huggingface_dataset: True
huggingface_psf: psf.tiff
Expand Down
36 changes: 20 additions & 16 deletions docs/source/dataset.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,26 @@ or measured).
:special-members: __init__, __len__


Measured dataset objects
------------------------

.. autoclass:: lensless.utils.dataset.HFDataset
:members:
:special-members: __init__

.. autoclass:: lensless.utils.dataset.MeasuredDataset
:members:
:special-members: __init__

.. autoclass:: lensless.utils.dataset.MeasuredDatasetSimulatedOriginal
:members:
:special-members: __init__

.. autoclass:: lensless.utils.dataset.DiffuserCamTestDataset
:members:
:special-members: __init__


Simulated dataset objects
-------------------------

Expand All @@ -43,19 +63,3 @@ mask / PSF.
.. autoclass:: lensless.utils.dataset.SimulatedDatasetTrainableMask
:members:
:special-members: __init__


Measured dataset objects
------------------------

.. autoclass:: lensless.utils.dataset.MeasuredDataset
:members:
:special-members: __init__

.. autoclass:: lensless.utils.dataset.MeasuredDatasetSimulatedOriginal
:members:
:special-members: __init__

.. autoclass:: lensless.utils.dataset.DiffuserCamTestDataset
:members:
:special-members: __init__
54 changes: 44 additions & 10 deletions lensless/utils/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1016,25 +1016,59 @@ def _get_images_pair(self, idx):
return lensless, lensed


class DigiCam(DualDataset):
class HFDataset(DualDataset):
def __init__(
self,
huggingface_repo,
split,
n_files=None,
psf=None,
display_res=None,
sensor="rpi_hq",
slm="adafruit",
rotate=False, # just the lensless image
downsample=1,
downsample_lensed=1,
display_res=None,
sensor="rpi_hq",
slm="adafruit",
alignment=None,
save_psf=False,
simulation_config=None,
return_mask_label=False,
n_files=None,
save_psf=False,
**kwargs,
):
"""
Wrapper for lensless datasets on Hugging Face.
Parameters
----------
huggingface_repo : str
Hugging Face repository ID.
split : str or :py:class:`torch.utils.data.Dataset`
Split of the dataset to use: 'train', 'test', or 'all'. If a Dataset object is given, it is used directly.
n_files : int, optional
Number of files to load from the dataset, by default None, namely all.
psf : str, optional
File name of the PSF at the repository. If None, it is assumed that there is a mask pattern from which the PSF is simulated, by default None.
rotate : bool, optional
If True, lensless images and PSF are rotated 180 degrees. Lensed/original image is not rotated! By default False.
downsample : float, optional
Downsample factor of the lensless images, by default 1.
downsample_lensed : float, optional
Downsample factor of the lensed images, by default 1.
display_res : tuple, optional
Resolution of images when displayed on screen during measurement.
sensor : str, optional
If `psf` not provided, the sensor to use for the PSF simulation, by default "rpi_hq".
slm : str, optional
If `psf` not provided, the SLM to use for the PSF simulation, by default "adafruit".
alignment : dict, optional
Alignment parameters between lensless and lensed data.
If "topright", "height", and "width" are provided, the region-of-interest from the reconstruction of ``lensless`` is extracted and ``lensed`` is reshaped to match.
If "crop" is provided, the region-of-interest is extracted from the simulated lensed image, namely a ``simulation`` configuration should be provided within ``alignment``.
return_mask_label : bool, optional
If multimask dataset, return the mask label (True) or the corresponding PSF (False).
save_psf : bool, optional
If multimask dataset, save the simulated PSFs.
"""

if isinstance(split, str):
if n_files is not None:
Expand Down Expand Up @@ -1080,6 +1114,7 @@ def __init__(

# preparing ground-truth as simulated measurement of original
elif "crop" in alignment:
assert "simulation" in alignment, "Simulation config should be provided"
self.crop = dict(alignment["crop"].copy())
self.crop["vertical"][0] = int(self.crop["vertical"][0] / downsample)
self.crop["vertical"][1] = int(self.crop["vertical"][1] / downsample)
Expand Down Expand Up @@ -1170,7 +1205,7 @@ def __init__(
if "horizontal_shift" in simulation_config:
self.horizontal_shift = int(simulation_config["horizontal_shift"] / downsample)

super(DigiCam, self).__init__(**kwargs)
super(HFDataset, self).__init__(**kwargs)

def __len__(self):
return len(self.dataset)
Expand All @@ -1196,7 +1231,6 @@ def _get_images_pair(self, idx):
lensless_np, factor=1 / self.downsample_lensless, interpolation=cv2.INTER_NEAREST
)


lensless = lensless_np
lensed = lensed_np
if self.simulator is not None:
Expand Down Expand Up @@ -1226,7 +1260,7 @@ def _get_images_pair(self, idx):
elif self.downsample_lensed != 1.0:
lensed = resize(
lensed_np,
factor=self.downsample_lensed,
factor=1 / self.downsample_lensed,
interpolation=cv2.INTER_NEAREST,
)

Expand Down
8 changes: 4 additions & 4 deletions scripts/data/authenticate.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
"""


from lensless.utils.dataset import DigiCam
from lensless.utils.dataset import HFDataset
import torch
from lensless import ADMM
from lensless.utils.image import rgb2gray
Expand Down Expand Up @@ -67,14 +67,14 @@ def authen(config):

# load multimask dataset
if split == "all":
train_set = DigiCam(
train_set = HFDataset(
huggingface_repo=huggingface_repo,
split="train",
rotate=rotate,
downsample=downsample,
return_mask_label=True,
)
test_set = DigiCam(
test_set = HFDataset(
huggingface_repo=huggingface_repo,
split="test",
rotate=rotate,
Expand Down Expand Up @@ -114,7 +114,7 @@ def authen(config):
file_idx += list(np.arange(n_train_psf) + i * n_train_psf + test_files_offet)

else:
all_set = DigiCam(
all_set = HFDataset(
huggingface_repo=huggingface_repo,
split=split,
rotate=rotate,
Expand Down
4 changes: 2 additions & 2 deletions scripts/eval/benchmark_recon.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from lensless.eval.benchmark import benchmark
import matplotlib.pyplot as plt
from lensless import ADMM, FISTA, GradientDescent, NesterovGradientDescent
from lensless.utils.dataset import DiffuserCamTestDataset, DigiCamCelebA, DigiCam
from lensless.utils.dataset import DiffuserCamTestDataset, DigiCamCelebA, HFDataset
from lensless.utils.io import save_image

import torch
Expand Down Expand Up @@ -85,7 +85,7 @@ def benchmark_recon(config):
dataset, [train_size, test_size], generator=generator
)
elif dataset == "DigiCamHF":
benchmark_dataset = DigiCam(
benchmark_dataset = HFDataset(
huggingface_repo=config.huggingface.repo,
split="test",
display_res=config.huggingface.image_res,
Expand Down
4 changes: 2 additions & 2 deletions scripts/recon/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from tqdm import tqdm
from joblib import Parallel, delayed
import numpy as np
from lensless.utils.dataset import DiffuserCamMirflickrHF, DigiCam
from lensless.utils.dataset import DiffuserCamMirflickrHF, HFDataset
from lensless.eval.metric import psnr, lpips
from lensless.utils.image import resize

Expand All @@ -47,7 +47,7 @@ def recon_dataset(config):
if config.dataset == "diffusercam":
dataset = DiffuserCamMirflickrHF(split=config.split, downsample=config.downsample)
else:
dataset = DigiCam(
dataset = HFDataset(
huggingface_repo=config.dataset,
split=config.split,
downsample=config.downsample,
Expand Down
4 changes: 2 additions & 2 deletions scripts/recon/digicam_mirflickr.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch
from lensless import ADMM
from lensless.utils.plot import plot_image
from lensless.utils.dataset import DigiCam
from lensless.utils.dataset import HFDataset
import os
from lensless.utils.io import save_image
import time
Expand Down Expand Up @@ -35,7 +35,7 @@ def apply_pretrained(config):
model_config = yaml.safe_load(stream)

# load dataset
test_set = DigiCam(
test_set = HFDataset(
huggingface_repo=model_config["files"]["dataset"],
psf=model_config["files"]["huggingface_psf"]
if "huggingface_psf" in model_config["files"]
Expand Down
Loading

0 comments on commit 73a4b6b

Please sign in to comment.