From 538a721c8ef0b37b8d2f6b0b5b679fa886dbe990 Mon Sep 17 00:00:00 2001 From: Eric Bezzam Date: Mon, 29 Apr 2024 15:23:03 +0200 Subject: [PATCH 01/10] Add wandb support. --- configs/train_unrolledADMM.yaml | 1 + lensless/eval/benchmark.py | 32 +++++++++++++++++++------------- lensless/recon/utils.py | 17 ++++++++++++++++- lensless/utils/dataset.py | 4 ++++ recon_requirements.txt | 3 ++- scripts/recon/train_unrolled.py | 14 ++++++++++++++ 6 files changed, 56 insertions(+), 15 deletions(-) diff --git a/configs/train_unrolledADMM.yaml b/configs/train_unrolledADMM.yaml index d4998e11..0392d9d4 100644 --- a/configs/train_unrolledADMM.yaml +++ b/configs/train_unrolledADMM.yaml @@ -4,6 +4,7 @@ hydra: chdir: True # change to output folder +wandb_project: lensless seed: 0 start_delay: null diff --git a/lensless/eval/benchmark.py b/lensless/eval/benchmark.py index 3bb7e25b..d521f610 100644 --- a/lensless/eval/benchmark.py +++ b/lensless/eval/benchmark.py @@ -116,6 +116,9 @@ def benchmark( if dataset.multimask: lensless, lensed, psfs = batch psfs = psfs.to(device) + else: + lensless, lensed = batch + psfs = None else: lensless, lensed = batch psfs = None @@ -198,24 +201,27 @@ def benchmark( .item() ) else: - if "LPIPS" in metric: - if prediction.shape[1] == 1: - # LPIPS needs 3 channels - metrics_values[metric].append( - metrics[metric]( - prediction.repeat(1, 3, 1, 1), lensed.repeat(1, 3, 1, 1) + try: + if "LPIPS" in metric: + if prediction.shape[1] == 1: + # LPIPS needs 3 channels + metrics_values[metric].append( + metrics[metric]( + prediction.repeat(1, 3, 1, 1), lensed.repeat(1, 3, 1, 1) + ) + .cpu() + .item() + ) + else: + metrics_values[metric].append( + metrics[metric](prediction, lensed).cpu().item() ) - .cpu() - .item() - ) else: metrics_values[metric].append( metrics[metric](prediction, lensed).cpu().item() ) - else: - metrics_values[metric].append( - metrics[metric](prediction, lensed).cpu().item() - ) + except Exception as e: + print(f"Error in metric {metric}: {e}") # compute metrics for unrolled output if unrolled_output_factor: diff --git a/lensless/recon/utils.py b/lensless/recon/utils.py index 7c6a60fb..a7fa1ed2 100644 --- a/lensless/recon/utils.py +++ b/lensless/recon/utils.py @@ -6,7 +6,7 @@ # Eric BEZZAM [ebezzam@gmail.com] # ############################################################################# - +import wandb import json import math import numpy as np @@ -302,6 +302,7 @@ def __init__( clip_grad=1.0, unrolled_output_factor=False, extra_eval_sets=None, + use_wandb=False, # for adding components during training pre_process=None, pre_process_delay=None, @@ -490,6 +491,8 @@ def __init__( self.optimizer_config = optimizer self.set_optimizer() + # metrics + self.use_wandb = use_wandb self.metrics = { "LOSS": [], # train loss "LOSS_TEST": [], # test loss @@ -802,6 +805,8 @@ def evaluate(self, mean_loss, epoch, disp=None): # update metrics with current metrics self.metrics["LOSS"].append(mean_loss) + if self.use_wandb: + wandb.log({"LOSS": mean_loss}, step=epoch) for key in current_metrics: self.metrics[key].append(current_metrics[key]) @@ -824,8 +829,11 @@ def evaluate(self, mean_loss, epoch, disp=None): eval_loss = current_metrics[self.metrics["metric_for_best_model"]] self.metrics["LOSS_TEST"].append(eval_loss) + if self.use_wandb: + wandb.log({"LOSS_TEST": eval_loss}, step=epoch) # add extra evaluation sets + extra_metrics_epoch = {} if self.extra_eval_sets is not None: for eval_set in self.extra_eval_sets: @@ -860,12 +868,19 @@ def evaluate(self, mean_loss, epoch, disp=None): self.metrics[eval_set][key] = [extra_metrics[key]] else: self.metrics[eval_set][key].append(extra_metrics[key]) + extra_metrics_epoch[f"{eval_set}_{key}"] = extra_metrics[key] # set back PSF to original in case changed # TODO: cleaner way? if not self.train_multimask: self.recon._set_psf(self.train_dataset.psf.to(self.device)) + # log metrics to wandb + if self.use_wandb: + wandb.log(current_metrics, step=epoch) + if self.extra_eval_sets is not None: + wandb.log(extra_metrics_epoch, step=epoch) + return eval_loss def on_epoch_end(self, mean_loss, save_pt, epoch, disp=None): diff --git a/lensless/utils/dataset.py b/lensless/utils/dataset.py index 803d4166..9308151c 100644 --- a/lensless/utils/dataset.py +++ b/lensless/utils/dataset.py @@ -1031,15 +1031,19 @@ def __init__( save_psf=False, simulation_config=None, return_mask_label=False, + n_files=None, **kwargs, ): if isinstance(split, str): + if n_files is not None: + split = f"{split}[0:{n_files}]" self.dataset = load_dataset(huggingface_repo, split=split) elif isinstance(split, Dataset): self.dataset = split else: raise ValueError("split should be a string or a Dataset object") + self.rotate = rotate self.display_res = display_res self.return_mask_label = return_mask_label diff --git a/recon_requirements.txt b/recon_requirements.txt index 0a2ff942..309d71fe 100644 --- a/recon_requirements.txt +++ b/recon_requirements.txt @@ -9,4 +9,5 @@ waveprop>=0.0.10 # for simulation torch >= 2.0.0 torchvision torchmetrics -lpips \ No newline at end of file +lpips +wandb \ No newline at end of file diff --git a/scripts/recon/train_unrolled.py b/scripts/recon/train_unrolled.py index 95057aca..73ee1760 100644 --- a/scripts/recon/train_unrolled.py +++ b/scripts/recon/train_unrolled.py @@ -32,6 +32,7 @@ """ +import wandb import logging import hydra from hydra.utils import get_original_cwd @@ -62,6 +63,15 @@ @hydra.main(version_base=None, config_path="../../configs", config_name="train_unrolledADMM") def train_unrolled(config): + if config.wandb_project is not None: + # start a new wandb run to track this script + wandb.init( + # set the wandb project where this run will be logged + project=config.wandb_project, + # track hyperparameters and run metadata + config=dict(config), + ) + # set seed seed = config.seed torch.manual_seed(seed) @@ -220,6 +230,7 @@ def train_unrolled(config): downsample=config.files.downsample, alignment=config.alignment, save_psf=config.files.save_psf, + n_files=config.files.n_files, ) test_set = DigiCam( huggingface_repo=config.files.dataset, @@ -230,6 +241,7 @@ def train_unrolled(config): downsample=config.files.downsample, alignment=config.alignment, save_psf=config.files.save_psf, + n_files=config.files.n_files, ) if train_set.multimask: # get first PSF for initialization @@ -278,6 +290,7 @@ def train_unrolled(config): extra_eval_sets[eval_set] = DigiCam( split="test", downsample=config.files.downsample, # needs to be same size + n_files=config.files.n_files, **config.files.extra_eval[eval_set], ) @@ -492,6 +505,7 @@ def train_unrolled(config): clip_grad=config.training.clip_grad, unrolled_output_factor=config.unrolled_output_factor, extra_eval_sets=extra_eval_sets if config.files.extra_eval is not None else None, + use_wandb=True if config.wandb_project is not None else False, ) trainer.train(n_epoch=config.training.epoch, save_pt=save, disp=config.eval_disp_idx) From 8ee5ce19bc4b41feced115e1952404b0e3933597 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 29 Apr 2024 15:47:06 +0000 Subject: [PATCH 02/10] Update setup. --- recon_requirements.txt | 3 ++- setup.py | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/recon_requirements.txt b/recon_requirements.txt index 309d71fe..78dd418d 100644 --- a/recon_requirements.txt +++ b/recon_requirements.txt @@ -10,4 +10,5 @@ torch >= 2.0.0 torchvision torchmetrics lpips -wandb \ No newline at end of file +wandb +datasets \ No newline at end of file diff --git a/setup.py b/setup.py index d1ab6d68..c3d0fc2c 100644 --- a/setup.py +++ b/setup.py @@ -34,6 +34,7 @@ "rawpy>=0.16.0", # less than python 3.12 "paramiko>=3.2.0", "hydra-core", + "slm_controller @ git+https://github.com/ebezzam/slm-controller.git" ], extra_requires={"dev": ["pudb", "black"]}, ) From 0bb90d59cf795630cdbdbb44a191ebbe23db39ee Mon Sep 17 00:00:00 2001 From: Eric Bezzam Date: Mon, 29 Apr 2024 15:49:45 +0000 Subject: [PATCH 03/10] Clean up configs. --- configs/train_digicam_multimask.yaml | 39 +-------------------------- configs/train_digicam_singlemask.yaml | 3 ++- configs/train_unrolledADMM.yaml | 1 + 3 files changed, 4 insertions(+), 39 deletions(-) diff --git a/configs/train_digicam_multimask.yaml b/configs/train_digicam_multimask.yaml index 6011f5f0..ffe3f3ce 100644 --- a/configs/train_digicam_multimask.yaml +++ b/configs/train_digicam_multimask.yaml @@ -1,9 +1,8 @@ # python scripts/recon/train_unrolled.py -cn train_digicam_multimask defaults: - - train_unrolledADMM + - train_digicam_singlemask - _self_ - torch_device: 'cuda:0' device_ids: [0, 1, 2, 3] eval_disp_idx: [1, 2, 4, 5, 9] @@ -11,12 +10,6 @@ eval_disp_idx: [1, 2, 4, 5, 9] # Dataset files: dataset: bezzam/DigiCam-Mirflickr-MultiMask-25K - huggingface_dataset: True - downsample: 1 - # TODO: these parameters should be in the dataset? - image_res: [900, 1200] # used during measurement - rotate: True # if measurement is upside-down - save_psf: False extra_eval: singlemask: @@ -26,33 +19,3 @@ files: alignment: topright: [80, 100] # height, width height: 200 - -# TODO: these parameters should be in the dataset? -alignment: - # when there is no downsampling - topright: [80, 100] # height, width - height: 200 - -training: - batch_size: 4 - epoch: 25 - eval_batch_size: 4 - -reconstruction: - method: unrolled_admm - unrolled_admm: - # Number of iterations - n_iter: 10 - # Hyperparameters - mu1: 1e-4 - mu2: 1e-4 - mu3: 1e-4 - tau: 2e-4 - pre_process: - network : UnetRes # UnetRes or DruNet or null - depth : 4 # depth of each up/downsampling layer. Ignore if network is DruNet - nc: [32,64,116,128] - post_process: - network : UnetRes # UnetRes or DruNet or null - depth : 4 # depth of each up/downsampling layer. Ignore if network is DruNet - nc: [32,64,116,128] \ No newline at end of file diff --git a/configs/train_digicam_singlemask.yaml b/configs/train_digicam_singlemask.yaml index 69b4e3a2..64ddcb66 100644 --- a/configs/train_digicam_singlemask.yaml +++ b/configs/train_digicam_singlemask.yaml @@ -12,11 +12,11 @@ files: dataset: bezzam/DigiCam-Mirflickr-SingleMask-25K huggingface_dataset: True downsample: 1 + # TODO: these parameters should be in the dataset? image_res: [900, 1200] # used during measurement rotate: True # if measurement is upside-down save_psf: False - # extra_eval: null extra_eval: multimask: huggingface_repo: bezzam/DigiCam-Mirflickr-MultiMask-25K @@ -26,6 +26,7 @@ files: topright: [80, 100] # height, width height: 200 +# TODO: these parameters should be in the dataset? alignment: # when there is no downsampling topright: [80, 100] # height, width diff --git a/configs/train_unrolledADMM.yaml b/configs/train_unrolledADMM.yaml index 0392d9d4..70e58b40 100644 --- a/configs/train_unrolledADMM.yaml +++ b/configs/train_unrolledADMM.yaml @@ -35,6 +35,7 @@ files: torch: True torch_device: 'cuda' +device_ids: null # for multi-gpu set list, e.g. [0, 1, 2, 3] measure: null # if measuring data on-the-fly # test set example to visualize at the end of every epoch From 3baab47d0093f37c8321802d24c9cbc07488673b Mon Sep 17 00:00:00 2001 From: Eric Bezzam Date: Mon, 29 Apr 2024 15:50:34 +0000 Subject: [PATCH 04/10] Log reconstruction images. --- lensless/eval/benchmark.py | 12 +++++++++++- lensless/recon/utils.py | 7 ++++++- scripts/recon/train_unrolled.py | 4 +++- 3 files changed, 20 insertions(+), 3 deletions(-) diff --git a/lensless/eval/benchmark.py b/lensless/eval/benchmark.py index d521f610..f03ba57f 100644 --- a/lensless/eval/benchmark.py +++ b/lensless/eval/benchmark.py @@ -13,6 +13,7 @@ from tqdm import tqdm import os import numpy as np +import wandb try: import torch @@ -37,6 +38,9 @@ def benchmark( unrolled_output_factor=False, return_average=True, snr=None, + use_wandb=False, + label=None, + epoch=None, **kwargs, ): """ @@ -179,7 +183,13 @@ def benchmark( prediction_np = prediction.cpu().numpy()[i] # switch to [H, W, C] for saving prediction_np = np.moveaxis(prediction_np, 0, -1) - save_image(prediction_np, fp=os.path.join(output_dir, f"{_batch_idx}.png")) + fp = os.path.join(output_dir, f"{_batch_idx}.png") + save_image(prediction_np, fp=fp) + + if use_wandb: + assert epoch is not None, "epoch must be provided for wandb logging" + log_key = f"{_batch_idx}_{label}" if label is not None else f"{_batch_idx}" + wandb.log({log_key: wandb.Image(fp)}, step=epoch) # normalization prediction_max = torch.amax(prediction, dim=(-1, -2, -3), keepdim=True) diff --git a/lensless/recon/utils.py b/lensless/recon/utils.py index a7fa1ed2..282380c2 100644 --- a/lensless/recon/utils.py +++ b/lensless/recon/utils.py @@ -801,6 +801,8 @@ def evaluate(self, mean_loss, epoch, disp=None): output_dir=output_dir, crop=self.crop, unrolled_output_factor=self.unrolled_output_factor, + use_wandb=self.use_wandb, + epoch=epoch, ) # update metrics with current metrics @@ -860,6 +862,9 @@ def evaluate(self, mean_loss, epoch, disp=None): output_dir=output_dir, crop=self.crop, unrolled_output_factor=self.unrolled_output_factor, + use_wandb=self.use_wandb, + label=eval_set, + epoch=epoch, ) # add metrics to dictionary @@ -944,7 +949,7 @@ def train(self, n_epoch=1, save_pt=None, disp=None): start_time = time.time() - self.evaluate(-1, epoch=0, disp=disp) + self.evaluate(mean_loss=1, epoch=0, disp=disp) for epoch in range(n_epoch): # add extra components (if specified) diff --git a/scripts/recon/train_unrolled.py b/scripts/recon/train_unrolled.py index 73ee1760..d2111340 100644 --- a/scripts/recon/train_unrolled.py +++ b/scripts/recon/train_unrolled.py @@ -93,13 +93,15 @@ def train_unrolled(config): use_cuda = False if "cuda" in config.torch_device and torch.cuda.is_available(): # if config.torch_device == "cuda" and torch.cuda.is_available(): - log.info("Using GPU for training.") + log.info(f"Using GPU for training. Main device : {config.torch_device}") device = config.torch_device use_cuda = True else: log.info("Using CPU for training.") device = "cpu" device_ids = config.device_ids + if device_ids is not None: + log.info(f"Using multiple GPUs : {device_ids}") # load dataset and create dataloader train_set = None From 0fc2aebd4ec6d3ec20b0f55e916966da191d740a Mon Sep 17 00:00:00 2001 From: Eric Bezzam Date: Mon, 29 Apr 2024 16:00:31 +0000 Subject: [PATCH 05/10] Rename training script. --- CHANGELOG.rst | 2 +- README.rst | 2 +- configs/fine-tune_PSF.yaml | 2 +- configs/train_celeba_digicam_hitl.yaml | 2 +- configs/train_celeba_digicam_mask.yaml | 2 +- configs/train_coded_aperture.yaml | 2 +- configs/train_digicam_celeba.yaml | 2 +- configs/train_digicam_multimask.yaml | 2 +- configs/train_digicam_singlemask.yaml | 2 +- configs/train_pre-post-processing.yaml | 2 +- configs/train_psf_from_scratch.yaml | 2 +- configs/train_unrolledADMM.yaml | 2 +- lensless/utils/dataset.py | 2 +- .../{train_unrolled.py => train_learning_based.py} | 12 ++++++------ 14 files changed, 19 insertions(+), 19 deletions(-) rename scripts/recon/{train_unrolled.py => train_learning_based.py} (98%) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 029c6d22..cd1c9754 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -151,7 +151,7 @@ Added - Option to warm-start reconstruction algorithm with ``initial_est``. - TrainableReconstructionAlgorithm class inherited from ReconstructionAlgorithm and torch.module for use with pytorch autograd and optimizers. - Unrolled version of FISTA and ADMM as TrainableReconstructionAlgorithm with learnable parameters. -- ``train_unrolled.py`` script for training unrolled algorithms. +- ``train_learning_based.py`` script for training unrolled algorithms. - ``benchmark_recon.py`` script for benchmarking and comparing reconstruction algorithms. - Added ``reconstruction_error`` to ``ReconstructionAlgorithm`` . - Added support for npy/npz image in load_image. diff --git a/README.rst b/README.rst index 419295ad..532e0940 100644 --- a/README.rst +++ b/README.rst @@ -45,7 +45,7 @@ The toolkit includes: * Measurement scripts (`link `__). * Dataset preparation and loading tools, with `Hugging Face `__ integration (`slides `__ on uploading a dataset to Hugging Face with `this script `__). * `Reconstruction algorithms `__ (e.g. FISTA, ADMM, unrolled algorithms, trainable inversion, pre- and post-processors). -* `Training script `__ for learning-based reconstruction. +* `Training script `__ for learning-based reconstruction. * `Pre-trained models `__ that can be loaded from `Hugging Face `__, for example in `this script `__. * Mask `design `__ and `fabrication `__ tools. * `Simulation tools `__. diff --git a/configs/fine-tune_PSF.yaml b/configs/fine-tune_PSF.yaml index c7ff09c9..3c9f50e3 100644 --- a/configs/fine-tune_PSF.yaml +++ b/configs/fine-tune_PSF.yaml @@ -1,4 +1,4 @@ -# python scripts/recon/train_unrolled.py -cn fine-tune_PSF +# python scripts/recon/train_learning_based.py -cn fine-tune_PSF defaults: - train_unrolledADMM - _self_ diff --git a/configs/train_celeba_digicam_hitl.yaml b/configs/train_celeba_digicam_hitl.yaml index 8129973b..046c5962 100644 --- a/configs/train_celeba_digicam_hitl.yaml +++ b/configs/train_celeba_digicam_hitl.yaml @@ -1,7 +1,7 @@ # Learn mask with HITL training by setting measure configuration (set to null for learning in simulation) # # EXAMPLE COMMAND: -# python scripts/recon/train_unrolled.py -cn train_celeba_digicam_hitl measure.rpi_username=USERNAME measure.rpi_hostname=HOSTNAME files.vertical_shift=SHIFT +# python scripts/recon/train_learning_based.py -cn train_celeba_digicam_hitl measure.rpi_username=USERNAME measure.rpi_hostname=HOSTNAME files.vertical_shift=SHIFT defaults: - train_celeba_digicam diff --git a/configs/train_celeba_digicam_mask.yaml b/configs/train_celeba_digicam_mask.yaml index 8dfd7f73..ba34ed46 100644 --- a/configs/train_celeba_digicam_mask.yaml +++ b/configs/train_celeba_digicam_mask.yaml @@ -1,5 +1,5 @@ # fine-tune mask for PSF, but don't re-simulate -# python scripts/recon/train_unrolled.py -cn train_celeba_digicam_mask +# python scripts/recon/train_learning_based.py -cn train_celeba_digicam_mask defaults: - train_celeba_digicam - _self_ diff --git a/configs/train_coded_aperture.yaml b/configs/train_coded_aperture.yaml index ea39b6ab..a0889435 100644 --- a/configs/train_coded_aperture.yaml +++ b/configs/train_coded_aperture.yaml @@ -1,4 +1,4 @@ -# python scripts/recon/train_unrolled.py -cn train_coded_aperture +# python scripts/recon/train_learning_based.py -cn train_coded_aperture defaults: - train_unrolledADMM - _self_ diff --git a/configs/train_digicam_celeba.yaml b/configs/train_digicam_celeba.yaml index 4a7d5028..973d13f5 100644 --- a/configs/train_digicam_celeba.yaml +++ b/configs/train_digicam_celeba.yaml @@ -1,4 +1,4 @@ -# python scripts/recon/train_unrolled.py -cn train_digicam_singlemask +# python scripts/recon/train_learning_based.py -cn train_digicam_singlemask defaults: - train_unrolledADMM - _self_ diff --git a/configs/train_digicam_multimask.yaml b/configs/train_digicam_multimask.yaml index ffe3f3ce..4ce73215 100644 --- a/configs/train_digicam_multimask.yaml +++ b/configs/train_digicam_multimask.yaml @@ -1,4 +1,4 @@ -# python scripts/recon/train_unrolled.py -cn train_digicam_multimask +# python scripts/recon/train_learning_based.py -cn train_digicam_multimask defaults: - train_digicam_singlemask - _self_ diff --git a/configs/train_digicam_singlemask.yaml b/configs/train_digicam_singlemask.yaml index 64ddcb66..f284385d 100644 --- a/configs/train_digicam_singlemask.yaml +++ b/configs/train_digicam_singlemask.yaml @@ -1,4 +1,4 @@ -# python scripts/recon/train_unrolled.py -cn train_digicam_singlemask +# python scripts/recon/train_learning_based.py -cn train_digicam_singlemask defaults: - train_unrolledADMM - _self_ diff --git a/configs/train_pre-post-processing.yaml b/configs/train_pre-post-processing.yaml index 86b95e86..9c0da345 100644 --- a/configs/train_pre-post-processing.yaml +++ b/configs/train_pre-post-processing.yaml @@ -1,4 +1,4 @@ -# python scripts/recon/train_unrolled.py -cn train_pre-post-processing +# python scripts/recon/train_learning_based.py -cn train_pre-post-processing defaults: - train_unrolledADMM - _self_ diff --git a/configs/train_psf_from_scratch.yaml b/configs/train_psf_from_scratch.yaml index 82586751..2def76c4 100644 --- a/configs/train_psf_from_scratch.yaml +++ b/configs/train_psf_from_scratch.yaml @@ -1,4 +1,4 @@ -# python scripts/recon/train_unrolled.py -cn train_psf_from_scratch +# python scripts/recon/train_learning_based.py -cn train_psf_from_scratch defaults: - train_unrolledADMM - _self_ diff --git a/configs/train_unrolledADMM.yaml b/configs/train_unrolledADMM.yaml index 70e58b40..ff0e7a3f 100644 --- a/configs/train_unrolledADMM.yaml +++ b/configs/train_unrolledADMM.yaml @@ -1,4 +1,4 @@ -# python scripts/recon/train_unrolled.py +# python scripts/recon/train_learning_based.py hydra: job: chdir: True # change to output folder diff --git a/lensless/utils/dataset.py b/lensless/utils/dataset.py index 9308151c..10414f2a 100644 --- a/lensless/utils/dataset.py +++ b/lensless/utils/dataset.py @@ -1269,7 +1269,7 @@ def simulate_dataset(config, generator=None): Parameters ---------- config : omegaconf.DictConfig - Configuration, e.g. from Hydra. See ``scripts/recon/train_unrolled.py`` for an example that uses this function. + Configuration, e.g. from Hydra. See ``scripts/recon/train_learning_based.py`` for an example that uses this function. generator : torch.Generator, optional Random number generator, by default ``None``. """ diff --git a/scripts/recon/train_unrolled.py b/scripts/recon/train_learning_based.py similarity index 98% rename from scripts/recon/train_unrolled.py rename to scripts/recon/train_learning_based.py index d2111340..d4db0837 100644 --- a/scripts/recon/train_unrolled.py +++ b/scripts/recon/train_learning_based.py @@ -1,6 +1,6 @@ # ############################################################################# -# train_unrolled.py -# ================= +# train_learning_based.py +# ======================= # Authors : # Yohann PERRON [yohann.perron@gmail.com] # Eric BEZZAM [ebezzam@gmail.com] @@ -10,24 +10,24 @@ Train unrolled version of reconstruction algorithm. ``` -python scripts/recon/train_unrolled.py +python scripts/recon/train_learning_based.py ``` By default it uses the configuration from the file `configs/train_unrolledADMM.yaml`. To train pre- and post-processing networks, use the following command: ``` -python scripts/recon/train_unrolled.py -cn train_pre-post-processing +python scripts/recon/train_learning_based.py -cn train_pre-post-processing ``` To fine-tune the DiffuserCam PSF, use the following command: ``` -python scripts/recon/train_unrolled.py -cn fine-tune_PSF +python scripts/recon/train_learning_based.py -cn fine-tune_PSF ``` To train a PSF from scratch with a simulated dataset, use the following command: ``` -python scripts/recon/train_unrolled.py -cn train_psf_from_scratch +python scripts/recon/train_learning_based.py -cn train_psf_from_scratch ``` """ From 0ccec8ecb67b863ac2e2c41658273bce2ce1ebeb Mon Sep 17 00:00:00 2001 From: Eric Bezzam Date: Tue, 30 Apr 2024 13:39:35 +0000 Subject: [PATCH 06/10] Clean up configs and log PSF to wandb. --- configs/fine-tune_PSF.yaml | 20 +----------- configs/train_pre-post-processing.yaml | 24 -------------- configs/train_psf_from_scratch.yaml | 8 +++-- configs/train_unrolledADMM.yaml | 32 ++++++++++++------- configs/train_unrolled_pre_post.yaml | 14 ++++++++ lensless/recon/utils.py | 28 ++++++++++++---- lensless/utils/dataset.py | 22 ++++++++++--- lensless/utils/io.py | 2 +- scripts/recon/train_learning_based.py | 44 ++++++-------------------- 9 files changed, 90 insertions(+), 104 deletions(-) delete mode 100644 configs/train_pre-post-processing.yaml create mode 100644 configs/train_unrolled_pre_post.yaml diff --git a/configs/fine-tune_PSF.yaml b/configs/fine-tune_PSF.yaml index 3c9f50e3..d0835cba 100644 --- a/configs/fine-tune_PSF.yaml +++ b/configs/fine-tune_PSF.yaml @@ -12,25 +12,7 @@ trainable_mask: #Training training: - save_every: 10 - epoch: 50 - crop_preloss: False + save_every: 1 # to see how PSF evolves display: gamma: 2.2 - -reconstruction: - method: unrolled_admm - - pre_process: - network: UnetRes - depth: 2 - post_process: - network: DruNet - depth: 4 - -optimizer: - slow_start: 0.01 - -loss: l2 -lpips: 1.0 diff --git a/configs/train_pre-post-processing.yaml b/configs/train_pre-post-processing.yaml deleted file mode 100644 index 9c0da345..00000000 --- a/configs/train_pre-post-processing.yaml +++ /dev/null @@ -1,24 +0,0 @@ -# python scripts/recon/train_learning_based.py -cn train_pre-post-processing -defaults: - - train_unrolledADMM - - _self_ - -reconstruction: - method: unrolled_admm - - pre_process: - network: UnetRes - depth: 2 - post_process: - network: DruNet - depth: 4 - -training: - epoch: 50 - crop_preloss: False - -optimizer: - slow_start: 0.01 - -loss: l2 -lpips: 1.0 diff --git a/configs/train_psf_from_scratch.yaml b/configs/train_psf_from_scratch.yaml index 2def76c4..8e1b0543 100644 --- a/configs/train_psf_from_scratch.yaml +++ b/configs/train_psf_from_scratch.yaml @@ -6,6 +6,10 @@ defaults: # Train Dataset files: dataset: mnist # Simulated : "mnist", "fashion_mnist", "cifar10", "CelebA". Measure :"DiffuserCam" + huggingface_dataset: False + n_files: 1000 + test_size: 0.15 + celeba_root: /scratch/bezzam downsample: 8 @@ -24,8 +28,6 @@ simulation: object_height: 0.30 training: - crop_preloss: False # crop region for computing loss - batch_size: 8 - epoch: 25 + batch_size: 2 eval_batch_size: 16 save_every: 5 diff --git a/configs/train_unrolledADMM.yaml b/configs/train_unrolledADMM.yaml index ff0e7a3f..e3e70cd7 100644 --- a/configs/train_unrolledADMM.yaml +++ b/configs/train_unrolledADMM.yaml @@ -10,29 +10,38 @@ start_delay: null # Dataset files: - dataset: /scratch/bezzam/DiffuserCam_mirflickr/dataset # Simulated : "mnist", "fashion_mnist", "cifar10", "CelebA". Measure :"DiffuserCam" - celeba_root: null # path to parent directory of CelebA: https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html - psf: data/psf/diffusercam_psf.tiff - diffusercam_psf: True + # dataset: /scratch/bezzam/DiffuserCam_mirflickr/dataset # Simulated : "mnist", "fashion_mnist", "cifar10", "CelebA". Measure :"DiffuserCam" + # celeba_root: null # path to parent directory of CelebA: https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html + # psf: data/psf/diffusercam_psf.tiff + # diffusercam_psf: True - huggingface_dataset: null - huggingface_psf: null + dataset: bezzam/DiffuserCam-Lensless-Mirflickr-Dataset-NORM + huggingface_dataset: True + huggingface_psf: psf.tiff + + # -- train/test split split_seed: null # if null use train/test split from dataset - n_files: null # null to use all for both train/test - downsample: 2 # factor by which to downsample the PSF, note that for DiffuserCam the PSF has 4x the resolution - test_size: 0.15 + test_size: null + # -- processing parameters + downsample: 2 # factor by which to downsample the PSF, note that for DiffuserCam the PSF has 4x the resolution + downsample_lensed: 2 input_snr: null # adding shot noise at input (for measured dataset) at this SNR in dB vertical_shift: null horizontal_shift: null + rotate: False + save_psf: False crop: null # vertical: null # horizontal: null image_res: null # for measured data, what resolution used at screen - extra_eval: null # dict of extra datasets to evaluate on +alignment: null +# topright: null # height, width + # height: null + torch: True torch_device: 'cuda' device_ids: null # for multi-gpu set list, e.g. [0, 1, 2, 3] @@ -132,14 +141,13 @@ simulation: training: batch_size: 8 - epoch: 50 + epoch: 25 eval_batch_size: 10 metric_for_best_model: null # e.g. LPIPS_Vgg, null does test loss save_every: null #In case of instable training skip_NAN: True clip_grad: 1.0 - crop_preloss: False # crop region for computing loss, files.crop should be set optimizer: diff --git a/configs/train_unrolled_pre_post.yaml b/configs/train_unrolled_pre_post.yaml new file mode 100644 index 00000000..82a8b794 --- /dev/null +++ b/configs/train_unrolled_pre_post.yaml @@ -0,0 +1,14 @@ +# python scripts/recon/train_learning_based.py -cn train_unrolled_pre_post +defaults: + - train_unrolledADMM + - _self_ + +reconstruction: + method: unrolled_admm + + pre_process: + network: UnetRes + depth: 2 + post_process: + network: UnetRes + depth: 2 diff --git a/lensless/recon/utils.py b/lensless/recon/utils.py index 282380c2..1bc44369 100644 --- a/lensless/recon/utils.py +++ b/lensless/recon/utils.py @@ -383,6 +383,8 @@ def __init__( """ global print + self.use_wandb = use_wandb + self.device = recon._psf.device self.logger = logger if self.logger is not None: @@ -441,6 +443,7 @@ def __init__( self.simulated_dataset_trainable_mask = True self.mask = mask + self.gamma = gamma if mask is not None: assert isinstance(mask, TrainableMask) self.use_mask = True @@ -450,11 +453,18 @@ def __init__( # save original PSF psf_np = self.mask.get_psf().detach().cpu().numpy()[0, ...] psf_np = psf_np.squeeze() # remove (potential) singleton color channel - np.save(os.path.join("psf_original.npy"), psf_np) - save_image(psf_np, os.path.join("psf_original.png")) + np.save("psf_original.npy", psf_np) + fp = "psf_original.png" + save_image(psf_np, fp) + plot_image(psf_np, gamma=self.gamma) + fp_plot = "psf_original_plot.png" + plt.savefig(fp_plot) + + if self.use_wandb: + wandb.log({"psf": wandb.Image(fp)}, step=0) + wandb.log({"psf_plot": wandb.Image(fp_plot)}, step=0) self.l1_mask = l1_mask - self.gamma = gamma # loss if loss == "l2": @@ -492,7 +502,6 @@ def __init__( self.set_optimizer() # metrics - self.use_wandb = use_wandb self.metrics = { "LOSS": [], # train loss "LOSS_TEST": [], # test loss @@ -1017,9 +1026,16 @@ def save(self, epoch, path="recon", include_optimizer=False): psf_np = self.mask.get_psf().detach().cpu().numpy()[0, ...] psf_np = psf_np.squeeze() # remove (potential) singleton color channel np.save(os.path.join(path, f"psf_epoch{epoch}.npy"), psf_np) - save_image(psf_np, os.path.join(path, f"psf_epoch{epoch}.png")) + fp = os.path.join(path, f"psf_epoch{epoch}.png") + save_image(psf_np, fp) plot_image(psf_np, gamma=self.gamma) - plt.savefig(os.path.join(path, f"psf_epoch{epoch}_plot.png")) + fp_plot = os.path.join(path, f"psf_epoch{epoch}_plot.png") + plt.savefig(fp_plot) + + if self.use_wandb and epoch!="BEST": + wandb.log({"psf": wandb.Image(fp)}, step=epoch) + wandb.log({"psf_plot": wandb.Image(fp_plot)}, step=epoch) + if epoch == "BEST": # save difference with original PSF psf_original = np.load("psf_original.npy") diff --git a/lensless/utils/dataset.py b/lensless/utils/dataset.py index 10414f2a..5729484c 100644 --- a/lensless/utils/dataset.py +++ b/lensless/utils/dataset.py @@ -1025,8 +1025,9 @@ def __init__( display_res=None, sensor="rpi_hq", slm="adafruit", - rotate=False, + rotate=False, # just the lensless image downsample=1, + downsample_lensed=1, alignment=None, save_psf=False, simulation_config=None, @@ -1048,14 +1049,18 @@ def __init__( self.display_res = display_res self.return_mask_label = return_mask_label - # deduce downsampling factor from measurement + # deduce downsampling factor from the first image data_0 = self.dataset[0] self.downsample_lensless = downsample + self.downsample_lensed = downsample_lensed lensless = np.array(data_0["lensless"]) if self.downsample_lensless != 1.0: lensless = resize(lensless, factor=1 / self.downsample_lensless) - sensor_res = sensor_dict[sensor][SensorParam.RESOLUTION] - downsample_fact = min(sensor_res / lensless.shape[:2]) + if psf is None: + sensor_res = sensor_dict[sensor][SensorParam.RESOLUTION] + downsample_fact = min(sensor_res / lensless.shape[:2]) + else: + downsample_fact = 1 # deduce recon shape from original image self.alignment = None @@ -1089,7 +1094,7 @@ def __init__( psf_fp = hf_hub_download(repo_id=huggingface_repo, filename=psf, repo_type="dataset") psf, _ = load_psf( psf_fp, - downsample=downsample_fact, + shape=lensless.shape, return_float=True, return_bg=True, flip=rotate, @@ -1191,6 +1196,7 @@ def _get_images_pair(self, idx): lensless_np, factor=1 / self.downsample_lensless, interpolation=cv2.INTER_NEAREST ) + lensless = lensless_np lensed = lensed_np if self.simulator is not None: @@ -1217,6 +1223,12 @@ def _get_images_pair(self, idx): lensed = resize( lensed_np, shape=(*self.display_res, 3), interpolation=cv2.INTER_NEAREST ) + elif self.downsample_lensed != 1.0: + lensed = resize( + lensed_np, + factor=self.downsample_lensed, + interpolation=cv2.INTER_NEAREST, + ) return lensless, lensed diff --git a/lensless/utils/io.py b/lensless/utils/io.py index 373fbc01..a51feff5 100644 --- a/lensless/utils/io.py +++ b/lensless/utils/io.py @@ -345,7 +345,7 @@ def load_psf( bg = np.array(bg) # resize - if downsample != 1: + if downsample != 1 or shape is not None: psf = resize(psf, shape=shape, factor=1 / downsample) if single_psf: diff --git a/scripts/recon/train_learning_based.py b/scripts/recon/train_learning_based.py index d4db0837..3f99049c 100644 --- a/scripts/recon/train_learning_based.py +++ b/scripts/recon/train_learning_based.py @@ -17,7 +17,7 @@ To train pre- and post-processing networks, use the following command: ``` -python scripts/recon/train_learning_based.py -cn train_pre-post-processing +python scripts/recon/train_learning_based.py -cn train_unrolled_pre_post ``` To fine-tune the DiffuserCam PSF, use the following command: @@ -25,10 +25,6 @@ python scripts/recon/train_learning_based.py -cn fine-tune_PSF ``` -To train a PSF from scratch with a simulated dataset, use the following command: -``` -python scripts/recon/train_learning_based.py -cn train_psf_from_scratch -``` """ @@ -110,7 +106,7 @@ def train_unrolled(config): crop = None alignment = None # very similar to crop, TODO: should switch to this approach mask = None - if "DiffuserCam" in config.files.dataset: + if "DiffuserCam" in config.files.dataset and config.files.huggingface_dataset is False: original_path = os.path.join(get_original_cwd(), config.files.dataset) psf_path = os.path.join(get_original_cwd(), config.files.psf) @@ -134,15 +130,6 @@ def train_unrolled(config): # -- if learning mask mask = prep_trainable_mask(config, dataset.psf) - if mask is not None: - # plot initial PSF - psf_np = mask.get_psf().detach().cpu().numpy()[0, ...] - if config.trainable_mask.grayscale: - psf_np = psf_np[:, :, -1] - - save_image(psf_np, os.path.join(save, "psf_initial.png")) - plot_image(psf_np, gamma=config.display.gamma) - plt.savefig(os.path.join(save, "psf_initial_plot.png")) psf = dataset.psf @@ -179,17 +166,7 @@ def train_unrolled(config): mask = prep_trainable_mask(config, dataset.psf, downsample=downsample) if mask is not None: - # plot initial PSF - with torch.no_grad(): - psf_np = mask.get_psf().detach().cpu().numpy()[0, ...] - if config.trainable_mask.grayscale: - psf_np = psf_np[:, :, -1] - - save_image(psf_np, os.path.join(save, "psf_initial.png")) - plot_image(psf_np, gamma=config.display.gamma) - plt.savefig(os.path.join(save, "psf_initial_plot.png")) - - # save original PSF as well + # save original PSF psf_meas = dataset.psf.detach().cpu().numpy()[0, ...] plot_image(psf_meas, gamma=config.display.gamma) plt.savefig(os.path.join(save, "psf_meas_plot.png")) @@ -230,6 +207,7 @@ def train_unrolled(config): display_res=config.files.image_res, rotate=config.files.rotate, downsample=config.files.downsample, + downsample_lensed=config.files.downsample_lensed, alignment=config.alignment, save_psf=config.files.save_psf, n_files=config.files.n_files, @@ -241,6 +219,7 @@ def train_unrolled(config): display_res=config.files.image_res, rotate=config.files.rotate, downsample=config.files.downsample, + downsample_lensed=config.files.downsample_lensed, alignment=config.alignment, save_psf=config.files.save_psf, n_files=config.files.n_files, @@ -258,14 +237,6 @@ def train_unrolled(config): mask = prep_trainable_mask(config, psf) if mask is not None: assert not train_set.multimask - # plot initial PSF - psf_np = mask.get_psf().detach().cpu().numpy()[0, ...] - if config.trainable_mask.grayscale: - psf_np = psf_np[:, :, -1] - - save_image(psf_np, os.path.join(save, "psf_initial.png")) - plot_image(psf_np, gamma=config.display.gamma) - plt.savefig(os.path.join(save, "psf_initial_plot.png")) else: @@ -273,6 +244,11 @@ def train_unrolled(config): psf = train_set.psf crop = train_set.crop + if not hasattr(train_set, "multimask"): + train_set.multimask = False + if not hasattr(test_set, "multimask"): + test_set.multimask = False + assert train_set is not None # if not hasattr(test_set, "psfs"): # assert psf is not None From 73a4b6b72c606474c368dfb37860fcca99dd1cd9 Mon Sep 17 00:00:00 2001 From: Eric Bezzam Date: Tue, 30 Apr 2024 23:09:14 +0000 Subject: [PATCH 07/10] Make sure examples work. --- configs/train_digicam_celeba.yaml | 7 ++-- configs/train_digicam_multimask.yaml | 41 +++++++++++++++++++- configs/train_digicam_singlemask.yaml | 1 + configs/train_unrolledADMM.yaml | 2 + docs/source/dataset.rst | 36 ++++++++++-------- lensless/utils/dataset.py | 54 ++++++++++++++++++++++----- scripts/data/authenticate.py | 8 ++-- scripts/eval/benchmark_recon.py | 4 +- scripts/recon/dataset.py | 4 +- scripts/recon/digicam_mirflickr.py | 4 +- scripts/recon/train_learning_based.py | 26 ++++++++----- 11 files changed, 138 insertions(+), 49 deletions(-) diff --git a/configs/train_digicam_celeba.yaml b/configs/train_digicam_celeba.yaml index 973d13f5..b2724dc9 100644 --- a/configs/train_digicam_celeba.yaml +++ b/configs/train_digicam_celeba.yaml @@ -1,4 +1,4 @@ -# python scripts/recon/train_learning_based.py -cn train_digicam_singlemask +# python scripts/recon/train_learning_based.py -cn train_digicam_celeba defaults: - train_unrolledADMM - _self_ @@ -13,6 +13,7 @@ files: huggingface_psf: "psf_simulated.png" huggingface_dataset: True split_seed: 0 + test_size: 0.15 downsample: 2 rotate: True # if measurement is upside-down save_psf: False @@ -34,14 +35,14 @@ alignment: random_vflip: False random_hflip: False quantize: False - # shifting when there is no files.downsample + # shifting when there is no files to downsample vertical_shift: -117 horizontal_shift: -25 training: batch_size: 4 epoch: 25 - eval_batch_size: 4 + eval_batch_size: 16 crop_preloss: True reconstruction: diff --git a/configs/train_digicam_multimask.yaml b/configs/train_digicam_multimask.yaml index 4ce73215..e05dda06 100644 --- a/configs/train_digicam_multimask.yaml +++ b/configs/train_digicam_multimask.yaml @@ -1,15 +1,23 @@ # python scripts/recon/train_learning_based.py -cn train_digicam_multimask defaults: - - train_digicam_singlemask + - train_unrolledADMM - _self_ torch_device: 'cuda:0' device_ids: [0, 1, 2, 3] eval_disp_idx: [1, 2, 4, 5, 9] + # Dataset files: dataset: bezzam/DigiCam-Mirflickr-MultiMask-25K + huggingface_dataset: True + huggingface_psf: null + downsample: 1 + # TODO: these parameters should be in the dataset? + image_res: [900, 1200] # used during measurement + rotate: True # if measurement is upside-down + save_psf: False extra_eval: singlemask: @@ -19,3 +27,34 @@ files: alignment: topright: [80, 100] # height, width height: 200 + +# TODO: these parameters should be in the dataset? +alignment: + # when there is no downsampling + topright: [80, 100] # height, width + height: 200 + +training: + batch_size: 4 + epoch: 25 + eval_batch_size: 4 + +reconstruction: + method: unrolled_admm + unrolled_admm: + # Number of iterations + n_iter: 10 + # Hyperparameters + mu1: 1e-4 + mu2: 1e-4 + mu3: 1e-4 + tau: 2e-4 + pre_process: + network : UnetRes # UnetRes or DruNet or null + depth : 4 # depth of each up/downsampling layer. Ignore if network is DruNet + nc: [32,64,116,128] + post_process: + network : UnetRes # UnetRes or DruNet or null + depth : 4 # depth of each up/downsampling layer. Ignore if network is DruNet + nc: [32,64,116,128] + diff --git a/configs/train_digicam_singlemask.yaml b/configs/train_digicam_singlemask.yaml index f284385d..932d68a8 100644 --- a/configs/train_digicam_singlemask.yaml +++ b/configs/train_digicam_singlemask.yaml @@ -11,6 +11,7 @@ eval_disp_idx: [1, 2, 4, 5, 9] files: dataset: bezzam/DigiCam-Mirflickr-SingleMask-25K huggingface_dataset: True + huggingface_psf: null downsample: 1 # TODO: these parameters should be in the dataset? image_res: [900, 1200] # used during measurement diff --git a/configs/train_unrolledADMM.yaml b/configs/train_unrolledADMM.yaml index e3e70cd7..47fba326 100644 --- a/configs/train_unrolledADMM.yaml +++ b/configs/train_unrolledADMM.yaml @@ -10,11 +10,13 @@ start_delay: null # Dataset files: + # -- using local dataset # dataset: /scratch/bezzam/DiffuserCam_mirflickr/dataset # Simulated : "mnist", "fashion_mnist", "cifar10", "CelebA". Measure :"DiffuserCam" # celeba_root: null # path to parent directory of CelebA: https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html # psf: data/psf/diffusercam_psf.tiff # diffusercam_psf: True + # -- using huggingface dataset dataset: bezzam/DiffuserCam-Lensless-Mirflickr-Dataset-NORM huggingface_dataset: True huggingface_psf: psf.tiff diff --git a/docs/source/dataset.rst b/docs/source/dataset.rst index ad21defb..0a8c503b 100644 --- a/docs/source/dataset.rst +++ b/docs/source/dataset.rst @@ -19,6 +19,26 @@ or measured). :special-members: __init__, __len__ +Measured dataset objects +------------------------ + +.. autoclass:: lensless.utils.dataset.HFDataset + :members: + :special-members: __init__ + +.. autoclass:: lensless.utils.dataset.MeasuredDataset + :members: + :special-members: __init__ + +.. autoclass:: lensless.utils.dataset.MeasuredDatasetSimulatedOriginal + :members: + :special-members: __init__ + +.. autoclass:: lensless.utils.dataset.DiffuserCamTestDataset + :members: + :special-members: __init__ + + Simulated dataset objects ------------------------- @@ -43,19 +63,3 @@ mask / PSF. .. autoclass:: lensless.utils.dataset.SimulatedDatasetTrainableMask :members: :special-members: __init__ - - -Measured dataset objects ------------------------- - -.. autoclass:: lensless.utils.dataset.MeasuredDataset - :members: - :special-members: __init__ - -.. autoclass:: lensless.utils.dataset.MeasuredDatasetSimulatedOriginal - :members: - :special-members: __init__ - -.. autoclass:: lensless.utils.dataset.DiffuserCamTestDataset - :members: - :special-members: __init__ diff --git a/lensless/utils/dataset.py b/lensless/utils/dataset.py index 5729484c..b758337a 100644 --- a/lensless/utils/dataset.py +++ b/lensless/utils/dataset.py @@ -1016,25 +1016,59 @@ def _get_images_pair(self, idx): return lensless, lensed -class DigiCam(DualDataset): +class HFDataset(DualDataset): def __init__( self, huggingface_repo, split, + n_files=None, psf=None, - display_res=None, - sensor="rpi_hq", - slm="adafruit", rotate=False, # just the lensless image downsample=1, downsample_lensed=1, + display_res=None, + sensor="rpi_hq", + slm="adafruit", alignment=None, - save_psf=False, - simulation_config=None, return_mask_label=False, - n_files=None, + save_psf=False, **kwargs, ): + """ + Wrapper for lensless datasets on Hugging Face. + + Parameters + ---------- + huggingface_repo : str + Hugging Face repository ID. + split : str or :py:class:`torch.utils.data.Dataset` + Split of the dataset to use: 'train', 'test', or 'all'. If a Dataset object is given, it is used directly. + n_files : int, optional + Number of files to load from the dataset, by default None, namely all. + psf : str, optional + File name of the PSF at the repository. If None, it is assumed that there is a mask pattern from which the PSF is simulated, by default None. + rotate : bool, optional + If True, lensless images and PSF are rotated 180 degrees. Lensed/original image is not rotated! By default False. + downsample : float, optional + Downsample factor of the lensless images, by default 1. + downsample_lensed : float, optional + Downsample factor of the lensed images, by default 1. + display_res : tuple, optional + Resolution of images when displayed on screen during measurement. + sensor : str, optional + If `psf` not provided, the sensor to use for the PSF simulation, by default "rpi_hq". + slm : str, optional + If `psf` not provided, the SLM to use for the PSF simulation, by default "adafruit". + alignment : dict, optional + Alignment parameters between lensless and lensed data. + If "topright", "height", and "width" are provided, the region-of-interest from the reconstruction of ``lensless`` is extracted and ``lensed`` is reshaped to match. + If "crop" is provided, the region-of-interest is extracted from the simulated lensed image, namely a ``simulation`` configuration should be provided within ``alignment``. + return_mask_label : bool, optional + If multimask dataset, return the mask label (True) or the corresponding PSF (False). + save_psf : bool, optional + If multimask dataset, save the simulated PSFs. + + """ if isinstance(split, str): if n_files is not None: @@ -1080,6 +1114,7 @@ def __init__( # preparing ground-truth as simulated measurement of original elif "crop" in alignment: + assert "simulation" in alignment, "Simulation config should be provided" self.crop = dict(alignment["crop"].copy()) self.crop["vertical"][0] = int(self.crop["vertical"][0] / downsample) self.crop["vertical"][1] = int(self.crop["vertical"][1] / downsample) @@ -1170,7 +1205,7 @@ def __init__( if "horizontal_shift" in simulation_config: self.horizontal_shift = int(simulation_config["horizontal_shift"] / downsample) - super(DigiCam, self).__init__(**kwargs) + super(HFDataset, self).__init__(**kwargs) def __len__(self): return len(self.dataset) @@ -1196,7 +1231,6 @@ def _get_images_pair(self, idx): lensless_np, factor=1 / self.downsample_lensless, interpolation=cv2.INTER_NEAREST ) - lensless = lensless_np lensed = lensed_np if self.simulator is not None: @@ -1226,7 +1260,7 @@ def _get_images_pair(self, idx): elif self.downsample_lensed != 1.0: lensed = resize( lensed_np, - factor=self.downsample_lensed, + factor=1 / self.downsample_lensed, interpolation=cv2.INTER_NEAREST, ) diff --git a/scripts/data/authenticate.py b/scripts/data/authenticate.py index 14f1d97b..9f71819c 100644 --- a/scripts/data/authenticate.py +++ b/scripts/data/authenticate.py @@ -29,7 +29,7 @@ """ -from lensless.utils.dataset import DigiCam +from lensless.utils.dataset import HFDataset import torch from lensless import ADMM from lensless.utils.image import rgb2gray @@ -67,14 +67,14 @@ def authen(config): # load multimask dataset if split == "all": - train_set = DigiCam( + train_set = HFDataset( huggingface_repo=huggingface_repo, split="train", rotate=rotate, downsample=downsample, return_mask_label=True, ) - test_set = DigiCam( + test_set = HFDataset( huggingface_repo=huggingface_repo, split="test", rotate=rotate, @@ -114,7 +114,7 @@ def authen(config): file_idx += list(np.arange(n_train_psf) + i * n_train_psf + test_files_offet) else: - all_set = DigiCam( + all_set = HFDataset( huggingface_repo=huggingface_repo, split=split, rotate=rotate, diff --git a/scripts/eval/benchmark_recon.py b/scripts/eval/benchmark_recon.py index 1e45971d..ece0bcfa 100644 --- a/scripts/eval/benchmark_recon.py +++ b/scripts/eval/benchmark_recon.py @@ -26,7 +26,7 @@ from lensless.eval.benchmark import benchmark import matplotlib.pyplot as plt from lensless import ADMM, FISTA, GradientDescent, NesterovGradientDescent -from lensless.utils.dataset import DiffuserCamTestDataset, DigiCamCelebA, DigiCam +from lensless.utils.dataset import DiffuserCamTestDataset, DigiCamCelebA, HFDataset from lensless.utils.io import save_image import torch @@ -85,7 +85,7 @@ def benchmark_recon(config): dataset, [train_size, test_size], generator=generator ) elif dataset == "DigiCamHF": - benchmark_dataset = DigiCam( + benchmark_dataset = HFDataset( huggingface_repo=config.huggingface.repo, split="test", display_res=config.huggingface.image_res, diff --git a/scripts/recon/dataset.py b/scripts/recon/dataset.py index e14f4ecd..906508db 100644 --- a/scripts/recon/dataset.py +++ b/scripts/recon/dataset.py @@ -35,7 +35,7 @@ from tqdm import tqdm from joblib import Parallel, delayed import numpy as np -from lensless.utils.dataset import DiffuserCamMirflickrHF, DigiCam +from lensless.utils.dataset import DiffuserCamMirflickrHF, HFDataset from lensless.eval.metric import psnr, lpips from lensless.utils.image import resize @@ -47,7 +47,7 @@ def recon_dataset(config): if config.dataset == "diffusercam": dataset = DiffuserCamMirflickrHF(split=config.split, downsample=config.downsample) else: - dataset = DigiCam( + dataset = HFDataset( huggingface_repo=config.dataset, split=config.split, downsample=config.downsample, diff --git a/scripts/recon/digicam_mirflickr.py b/scripts/recon/digicam_mirflickr.py index 88a6a036..60411fd0 100644 --- a/scripts/recon/digicam_mirflickr.py +++ b/scripts/recon/digicam_mirflickr.py @@ -3,7 +3,7 @@ import torch from lensless import ADMM from lensless.utils.plot import plot_image -from lensless.utils.dataset import DigiCam +from lensless.utils.dataset import HFDataset import os from lensless.utils.io import save_image import time @@ -35,7 +35,7 @@ def apply_pretrained(config): model_config = yaml.safe_load(stream) # load dataset - test_set = DigiCam( + test_set = HFDataset( huggingface_repo=model_config["files"]["dataset"], psf=model_config["files"]["huggingface_psf"] if "huggingface_psf" in model_config["files"] diff --git a/scripts/recon/train_learning_based.py b/scripts/recon/train_learning_based.py index 3f99049c..9ad7a016 100644 --- a/scripts/recon/train_learning_based.py +++ b/scripts/recon/train_learning_based.py @@ -40,7 +40,7 @@ from lensless.utils.dataset import ( DiffuserCamMirflickr, DigiCamCelebA, - DigiCam, + HFDataset, MyDataParallel, simulate_dataset, ) @@ -57,7 +57,7 @@ @hydra.main(version_base=None, config_path="../../configs", config_name="train_unrolledADMM") -def train_unrolled(config): +def train_learned(config): if config.wandb_project is not None: # start a new wandb run to track this script @@ -189,8 +189,13 @@ def train_unrolled(config): generator = torch.Generator().manual_seed(seed) # - combine train and test into single dataset - train_dataset = load_dataset(config.files.dataset, split="train") - test_dataset = load_dataset(config.files.dataset, split="test") + train_split = "train" + test_split = "test" + if config.files.n_files is not None: + train_split = f"train[:{config.files.n_files}]" + test_split = f"test[:{config.files.n_files}]" + train_dataset = load_dataset(config.files.dataset, split=train_split) + test_dataset = load_dataset(config.files.dataset, split=test_split) dataset = concatenate_datasets([test_dataset, train_dataset]) # - split into train and test @@ -200,7 +205,7 @@ def train_unrolled(config): dataset, [train_size, test_size], generator=generator ) - train_set = DigiCam( + train_set = HFDataset( huggingface_repo=config.files.dataset, psf=config.files.huggingface_psf, split=split_train, @@ -212,7 +217,7 @@ def train_unrolled(config): save_psf=config.files.save_psf, n_files=config.files.n_files, ) - test_set = DigiCam( + test_set = HFDataset( huggingface_repo=config.files.dataset, psf=config.files.huggingface_psf, split=split_test, @@ -226,7 +231,10 @@ def train_unrolled(config): ) if train_set.multimask: # get first PSF for initialization - first_psf_key = list(train_set.psf.keys())[device_ids[0]] + if device_ids is not None: + first_psf_key = list(train_set.psf.keys())[device_ids[0]] + else: + first_psf_key = list(train_set.psf.keys())[0] psf = train_set.psf[first_psf_key].to(device) else: psf = train_set.psf.to(device) @@ -265,7 +273,7 @@ def train_unrolled(config): extra_eval_sets = dict() for eval_set in config.files.extra_eval: - extra_eval_sets[eval_set] = DigiCam( + extra_eval_sets[eval_set] = HFDataset( split="test", downsample=config.files.downsample, # needs to be same size n_files=config.files.n_files, @@ -492,4 +500,4 @@ def train_unrolled(config): if __name__ == "__main__": - train_unrolled() + train_learned() From a1a94c12051c32cd9697577e4eeeec2db4c8ad3a Mon Sep 17 00:00:00 2001 From: Eric Bezzam Date: Tue, 30 Apr 2024 23:21:50 +0000 Subject: [PATCH 08/10] Fix formatting. --- lensless/recon/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lensless/recon/utils.py b/lensless/recon/utils.py index 1bc44369..c7096b83 100644 --- a/lensless/recon/utils.py +++ b/lensless/recon/utils.py @@ -1032,7 +1032,7 @@ def save(self, epoch, path="recon", include_optimizer=False): fp_plot = os.path.join(path, f"psf_epoch{epoch}_plot.png") plt.savefig(fp_plot) - if self.use_wandb and epoch!="BEST": + if self.use_wandb and epoch != "BEST": wandb.log({"psf": wandb.Image(fp)}, step=epoch) wandb.log({"psf_plot": wandb.Image(fp_plot)}, step=epoch) From 8cdb1ebc4791e1ca7b83077abeda47b7767cc685 Mon Sep 17 00:00:00 2001 From: Eric Bezzam Date: Tue, 30 Apr 2024 23:28:37 +0000 Subject: [PATCH 09/10] Formatting. --- lensless/eval/benchmark.py | 4 +++- lensless/recon/utils.py | 2 +- lensless/utils/dataset.py | 2 +- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/lensless/eval/benchmark.py b/lensless/eval/benchmark.py index f03ba57f..8df388c1 100644 --- a/lensless/eval/benchmark.py +++ b/lensless/eval/benchmark.py @@ -188,7 +188,9 @@ def benchmark( if use_wandb: assert epoch is not None, "epoch must be provided for wandb logging" - log_key = f"{_batch_idx}_{label}" if label is not None else f"{_batch_idx}" + log_key = ( + f"{_batch_idx}_{label}" if label is not None else f"{_batch_idx}" + ) wandb.log({log_key: wandb.Image(fp)}, step=epoch) # normalization diff --git a/lensless/recon/utils.py b/lensless/recon/utils.py index c7096b83..80349001 100644 --- a/lensless/recon/utils.py +++ b/lensless/recon/utils.py @@ -459,7 +459,7 @@ def __init__( plot_image(psf_np, gamma=self.gamma) fp_plot = "psf_original_plot.png" plt.savefig(fp_plot) - + if self.use_wandb: wandb.log({"psf": wandb.Image(fp)}, step=0) wandb.log({"psf_plot": wandb.Image(fp_plot)}, step=0) diff --git a/lensless/utils/dataset.py b/lensless/utils/dataset.py index b758337a..5a01e770 100644 --- a/lensless/utils/dataset.py +++ b/lensless/utils/dataset.py @@ -1023,7 +1023,7 @@ def __init__( split, n_files=None, psf=None, - rotate=False, # just the lensless image + rotate=False, # just the lensless image downsample=1, downsample_lensed=1, display_res=None, From 54050db268b92165700c7205e4d28d360e9d2182 Mon Sep 17 00:00:00 2001 From: Eric Bezzam Date: Wed, 1 May 2024 09:23:41 +0000 Subject: [PATCH 10/10] Update CHANGELOG. --- CHANGELOG.rst | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index cd1c9754..7dd4d036 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -21,12 +21,13 @@ Added - ``lensless.utils.dataset.simulate_dataset`` for simulating a dataset given a mask/PSF. - Support for training/testing with multiple mask patterns in the dataset. - Multi-GPU support for training. -- DigiCam dataset which interfaces with Hugging Face. +- Dataset which interfaces with Hugging Face (``lensless.utils.dataset.HFDataset``). - Scripts for authentication. - DigiCam support for Telegram demo. - DiffuserCamMirflickr Hugging Face API. - Fallback for normalization if data not in 8bit range (``lensless.utils.io.save_image``). - Add utilities for fabricating masks with 3D printing (``lensless.hardware.fabrication``). +- WandB support. Changed ~~~~~~~