From c0124a411543d9452e39ab50bdf3b5fc5441ec9f Mon Sep 17 00:00:00 2001 From: Eric Bezzam Date: Tue, 9 Jul 2024 08:15:21 +0000 Subject: [PATCH] Add models, optional background subtraction. --- .../benchmark_digicam_mirflickr_single.yaml | 55 +++++++++++++------ lensless/recon/model_dict.py | 4 ++ lensless/utils/io.py | 9 ++- scripts/data/rename_mirflickr25k.py | 2 +- 4 files changed, 51 insertions(+), 19 deletions(-) diff --git a/configs/benchmark_digicam_mirflickr_single.yaml b/configs/benchmark_digicam_mirflickr_single.yaml index 94562045..e690e5c4 100644 --- a/configs/benchmark_digicam_mirflickr_single.yaml +++ b/configs/benchmark_digicam_mirflickr_single.yaml @@ -6,10 +6,11 @@ defaults: dataset: HFDataset batchsize: 4 -device: "cuda:0" +device: "cuda:3" huggingface: repo: "bezzam/DigiCam-Mirflickr-SingleMask-25K" + cache_dir: /dev/shm psf: null # null for simulating PSF image_res: [900, 1200] # used during measurement rotate: True # if measurement is upside-down @@ -20,26 +21,48 @@ huggingface: height: 200 downsample: 1 + algorithms: [ - ## - simulated PSF (simulation.use_waveprop=True, simulation.deadspace=True) - "hf:digicam:mirflickr_single_25k:U10_wave", - "hf:digicam:mirflickr_single_25k:Unet8M_wave", + "ADMM", + "hf:digicam:mirflickr_single_25k:U5+Unet8M_wave", "hf:digicam:mirflickr_single_25k:TrainInv+Unet8M_wave", - "hf:digicam:mirflickr_single_25k:U10+Unet8M_wave", + "hf:digicam:mirflickr_single_25k:MMCN4M+Unet4M_wave", + "hf:digicam:mirflickr_single_25k:MWDN8M_wave", "hf:digicam:mirflickr_single_25k:Unet4M+TrainInv+Unet4M_wave", - "hf:digicam:mirflickr_single_25k:Unet4M+U10+Unet4M_wave" - - ## -- below models need to set correct PSF simulation - # ## - measured PSF (huggingface.psf=psf_measured.png) - # "hf:digicam:mirflickr_single_25k:Unet4M+U10+Unet4M_measured", - # ## - simulated PSF (simulation.use_waveprop=True, simulation.deadspace=False) - # "hf:digicam:mirflickr_single_25k:Unet4M+U10+Unet4M_wave_nodead", - # ## - simulated PSF (simulation.use_waveprop=False, simulation.deadspace=True) - # "hf:digicam:mirflickr_single_25k:Unet4M+U10+Unet4M", - # ## - simulated PSF (simulation.use_waveprop=False, simulation.deadspace=False) - # "hf:digicam:mirflickr_single_25k:Unet4M+U10+Unet4M_nodead" + "hf:digicam:mirflickr_single_25k:Unet4M+U5+Unet4M_wave", + "hf:digicam:mirflickr_single_25k:Unet2M+MMCN+Unet2M_wave", + "hf:digicam:mirflickr_single_25k:Unet2M+MWDN6M_wave", + "hf:digicam:mirflickr_single_25k:Unet4M+U10+Unet4M_wave", + + # ## -- reconstructions trained on other datasets/systems + # "hf:diffusercam:mirflickr:Unet4M+U10+Unet4M", + # "hf:tapecam:mirflickr:Unet4M+U10+Unet4M", + # "hf:digicam:mirflickr_single_25k:Unet4M+U5+Unet4M_wave", + # "hf:digicam:celeba_26k:Unet4M+U5+Unet4M_wave", ] + + +# algorithms: [ +# # ## - simulated PSF (simulation.use_waveprop=True, simulation.deadspace=True) +# # "hf:digicam:mirflickr_single_25k:U10_wave", +# # "hf:digicam:mirflickr_single_25k:Unet8M_wave", +# # "hf:digicam:mirflickr_single_25k:TrainInv+Unet8M_wave", +# # "hf:digicam:mirflickr_single_25k:U10+Unet8M_wave", +# # "hf:digicam:mirflickr_single_25k:Unet4M+TrainInv+Unet4M_wave", +# # "hf:digicam:mirflickr_single_25k:Unet4M+U10+Unet4M_wave" + +# ## -- below models need to set correct PSF simulation +# # ## - measured PSF (huggingface.psf=psf_measured.png) +# # "hf:digicam:mirflickr_single_25k:Unet4M+U10+Unet4M_measured", +# # ## - simulated PSF (simulation.use_waveprop=True, simulation.deadspace=False) +# # "hf:digicam:mirflickr_single_25k:Unet4M+U10+Unet4M_wave_nodead", +# # ## - simulated PSF (simulation.use_waveprop=False, simulation.deadspace=True) +# # "hf:digicam:mirflickr_single_25k:Unet4M+U10+Unet4M", +# # ## - simulated PSF (simulation.use_waveprop=False, simulation.deadspace=False) +# # "hf:digicam:mirflickr_single_25k:Unet4M+U10+Unet4M_nodead" +# ] save_idx: [1, 2, 4, 5, 9] +n_iter_range: [100] # for ADMM # simulating PSF simulation: diff --git a/lensless/recon/model_dict.py b/lensless/recon/model_dict.py index 7056c897..8a0a21d9 100644 --- a/lensless/recon/model_dict.py +++ b/lensless/recon/model_dict.py @@ -90,6 +90,7 @@ "Unet4M+TrainInv+Unet4M_wave": "bezzam/digicam-celeba-unet4M-trainable-inv-unet4M_wave", "Unet2M+MMCN+Unet2M_wave": "bezzam/digicam-celeba-unet2M-mmcn-unet2M", "Unet4M+U5+Unet4M_wave": "bezzam/digicam-celeba-unet4M-unrolled-admm5-unet4M", + "Unet4M+U10+Unet4M_wave": "bezzam/digicam-celeba-unet4M-unrolled-admm10-unet4M", }, "mirflickr_single_25k": { # simulated PSF (without waveprop, with deadspace) @@ -112,6 +113,8 @@ "Unet2M+MMCN+Unet2M_wave": "bezzam/digicam-mirflickr-single-25k-unet2M-mmcn-unet2M-wave", "Unet4M+TrainInv+Unet4M_wave": "bezzam/digicam-mirflickr-single-25k-unet4M-trainable-inv-unet4M-wave", "Unet2M+MWDN6M_wave": "bezzam/digicam-mirflickr-single-25k-unet2M-mwdn-6M", + "Unet4M+U5+Unet4M_wave_aux1": "bezzam/digicam-mirflickr-single-25k-unet4M-unrolled-admm5-unet4M-wave-aux1", + "Unet4M+U5+Unet4M_wave_flips": "bezzam/digicam-mirflickr-single-25k-unet4M-unrolled-admm5-unet4M-wave-flips", # measured PSF "Unet4M+U10+Unet4M_measured": "bezzam/digicam-mirflickr-single-25k-unet4M-unrolled-admm10-unet4M-measured", # simulated PSF (with waveprop, no deadspace) @@ -125,6 +128,7 @@ "Unet4M+U10+Unet4M": "bezzam/digicam-mirflickr-multi-25k-unet4M-unrolled-admm10-unet4M", # simulated PSF (with waveprop, with deadspace) "Unet4M+U10+Unet4M_wave": "bezzam/digicam-mirflickr-multi-25k-unet4M-unrolled-admm10-unet4M-wave", + "Unet4M+U5+Unet4M_wave": "bezzam/digicam-mirflickr-multi-25k-unet4M-unrolled-admm5-unet4M-wave", }, }, "tapecam": { diff --git a/lensless/utils/io.py b/lensless/utils/io.py index 67586406..62fd7f2b 100644 --- a/lensless/utils/io.py +++ b/lensless/utils/io.py @@ -471,12 +471,13 @@ def load_data( use_3d = psf_fp.endswith(".npy") or psf_fp.endswith(".npz") # load and process PSF data - psf, bg = load_psf( + bg = None + res = load_psf( psf_fp, downsample=downsample, return_float=return_float, bg_pix=bg_pix, - return_bg=True, + return_bg=True if bg_pix is not None else False, flip=flip, flip_ud=flip_ud, flip_lr=flip_lr, @@ -489,6 +490,10 @@ def load_data( use_3d=use_3d, bgr_input=bgr_input, ) + if bg_pix is not None: + psf, bg = res + else: + psf = res # load and process raw measurement data = load_image( diff --git a/scripts/data/rename_mirflickr25k.py b/scripts/data/rename_mirflickr25k.py index 58b9b080..fb4e125e 100644 --- a/scripts/data/rename_mirflickr25k.py +++ b/scripts/data/rename_mirflickr25k.py @@ -9,7 +9,7 @@ from lensless.utils.dataset import natural_sort -dir_path = "/root/mirflickr/mirflickr25k" +dir_path = "/dev/shm/mirflickr" # get all jpg files files = natural_sort(glob.glob(os.path.join(dir_path, "*.jpg")))