From 26e860d26ab77aef1459c8bf0f3c19aaab63c5ff Mon Sep 17 00:00:00 2001 From: Eric Bezzam Date: Tue, 18 Jun 2024 15:17:56 +0000 Subject: [PATCH] Add preproc support for multi-wiener, add single psf option for benchmarking. --- configs/benchmark.yaml | 3 ++ lensless/recon/model_dict.py | 9 ++++- lensless/recon/multi_wiener.py | 57 ++++++++++++++++++++++++++- lensless/recon/utils.py | 2 + lensless/utils/dataset.py | 3 ++ scripts/eval/benchmark_recon.py | 26 ++++++++++-- scripts/recon/train_learning_based.py | 1 + 7 files changed, 94 insertions(+), 7 deletions(-) diff --git a/configs/benchmark.yaml b/configs/benchmark.yaml index ef689a83..11f3c9ca 100644 --- a/configs/benchmark.yaml +++ b/configs/benchmark.yaml @@ -25,6 +25,7 @@ huggingface: downsample: 1 downsample_lensed: 1 split_seed: null + single_channel_psf: False device: "cuda" # numbers of iterations to benchmark @@ -40,6 +41,8 @@ algorithms: ["ADMM", "ADMM_Monakhova2019", "FISTA"] #["ADMM", "ADMM_Monakhova201 baseline: "MONAKHOVA 100iter" save_idx: [0, 1, 2, 3, 4] # provide index of files to save e.g. [1, 5, 10] +gamma_psf: 1.5 # gamma factor for PSF + # Hyperparameters nesterov: diff --git a/lensless/recon/model_dict.py b/lensless/recon/model_dict.py index b3d4536a..0491196e 100644 --- a/lensless/recon/model_dict.py +++ b/lensless/recon/model_dict.py @@ -81,11 +81,14 @@ "Unet4M+U10+Unet4M": "bezzam/digicam-mirflickr-single-25k-unet4M-unrolled-admm10-unet4M", # simulated PSF (with waveprop, with deadspace) "U10_wave": "bezzam/digicam-mirflickr-single-25k-unrolled-admm10-wave", + "U10+Unet8M_wave": "bezzam/digicam-mirflickr-single-25k-unrolled-admm10-unet8M-wave", "Unet8M_wave": "bezzam/digicam-mirflickr-single-25k-unet8M-wave", + "Unet4M+U10+Unet4M_wave": "bezzam/digicam-mirflickr-single-25k-unet4M-unrolled-admm10-unet4M-wave", "TrainInv+Unet8M_wave": "bezzam/digicam-mirflickr-single-25k-trainable-inv-unet8M-wave", - "U10+Unet8M_wave": "bezzam/digicam-mirflickr-single-25k-unrolled-admm10-unet8M-wave", + "U5+Unet8M_wave": "bezzam/digicam-mirflickr-single-25k-unrolled-admm5-unet8M-wave", + "MWDN8M_wave": "bezzam/digicam-mirflickr-single-25k-mwdn-8M", + "MMCN4M+Unet4M_wave": "bezzam/digicam-mirflickr-single-25k-mmcn-unet4M", "Unet4M+TrainInv+Unet4M_wave": "bezzam/digicam-mirflickr-single-25k-unet4M-trainable-inv-unet4M-wave", - "Unet4M+U10+Unet4M_wave": "bezzam/digicam-mirflickr-single-25k-unet4M-unrolled-admm10-unet4M-wave", # measured PSF "Unet4M+U10+Unet4M_measured": "bezzam/digicam-mirflickr-single-25k-unet4M-unrolled-admm10-unet4M-measured", # simulated PSF (with waveprop, no deadspace) @@ -110,6 +113,7 @@ "Unet4M+TrainInv+Unet4M": "bezzam/tapecam-mirflickr-unet4M-trainable-inv-unet4M", "Unet4M+U5+Unet4M": "bezzam/tapecam-mirflickr-unet4M-unrolled-admm5-unet4M", "Unet2M+MMCN+Unet2M": "bezzam/tapecam-mirflickr-unet2M-mmcn-unet2M", + "Unet2M+MWDN6M": "bezzam/tapecam-mirflickr-unet2M-mwdn-6M", }, }, } @@ -272,6 +276,7 @@ def load_model( psf=psf, psf_channels=3, nc=config["reconstruction"]["multi_wiener"]["nc"], + pre_process=pre_process, ) recon.to(device) diff --git a/lensless/recon/multi_wiener.py b/lensless/recon/multi_wiener.py index 29cd6c0e..5824c938 100644 --- a/lensless/recon/multi_wiener.py +++ b/lensless/recon/multi_wiener.py @@ -99,7 +99,16 @@ def WieNer(blur, psf, delta): class MultiWiener(nn.Module): - def __init__(self, in_channels, out_channels, psf, psf_channels=1, nc=None): + def __init__( + self, + in_channels, + out_channels, + psf, + psf_channels=1, + nc=None, + pre_process=None, + skip_pre=False, + ): """ Parameters ---------- @@ -165,6 +174,46 @@ def __init__(self, in_channels, out_channels, psf, psf_channels=1, nc=None): self._n_iter = 1 self._convolver = RealFFTConvolve2D(psf, pad=True) + self.set_pre_process(pre_process) + self.skip_pre = skip_pre + + def _prepare_process_block(self, process): + """ + Method for preparing the pre or post process block. + Parameters + ---------- + process : :py:class:`function` or :py:class:`~torch.nn.Module`, optional + Pre or post process block to prepare. + """ + if isinstance(process, torch.nn.Module): + # If the post_process is a torch module, we assume it is a DruNet like network. + from lensless.recon.utils import get_drunet_function_v2 + + process_model = process + process_function = get_drunet_function_v2(process_model, mode="train") + elif process is not None: + # Otherwise, we assume it is a function. + assert callable(process), "pre_process must be a callable function" + process_function = process + process_model = None + else: + process_function = None + process_model = None + + if process_function is not None: + process_param = torch.nn.Parameter(torch.tensor([1.0], device=self._psf.device)) + else: + process_param = None + + return process_function, process_model, process_param + + def set_pre_process(self, pre_process): + ( + self.pre_process, + self.pre_process_model, + self.pre_process_param, + ) = self._prepare_process_block(pre_process) + def forward(self, batch, psfs=None): if psfs is None: @@ -178,6 +227,12 @@ def forward(self, batch, psfs=None): if n_depth > 1: raise NotImplementedError("3D not implemented yet.") + # pre process data + if self.pre_process is not None and not self.skip_pre: + device_before = batch.device + batch = self.pre_process(batch, self.pre_process_param) + batch = batch.to(device_before) + # pad to multiple of 8 batch = convert_to_NCHW(batch) batch = torch.nn.functional.pad( diff --git a/lensless/recon/utils.py b/lensless/recon/utils.py index 930c4bea..e9d79542 100644 --- a/lensless/recon/utils.py +++ b/lensless/recon/utils.py @@ -740,6 +740,7 @@ def train_epoch(self, data_loader): mean_loss = 0.0 i = 1.0 pbar = tqdm(data_loader) + self.recon.train() for batch in pbar: # get batch @@ -957,6 +958,7 @@ def evaluate(self, mean_loss, epoch, disp=None): output_dir = os.path.join(output_dir, str(epoch)) # benchmarking + self.recon.eval() current_metrics = benchmark( self.recon, self.test_dataset, diff --git a/lensless/utils/dataset.py b/lensless/utils/dataset.py index 638216c7..762ab3cb 100644 --- a/lensless/utils/dataset.py +++ b/lensless/utils/dataset.py @@ -1164,6 +1164,9 @@ def __init__( single_psf=single_channel_psf, ) self.psf = torch.from_numpy(psf) + if single_channel_psf: + # replicate across three channels + self.psf = self.psf.repeat(1, 1, 1, 3) elif "mask_label" in data_0: self.multimask = True diff --git a/scripts/eval/benchmark_recon.py b/scripts/eval/benchmark_recon.py index c04e6ebe..3cf163b6 100644 --- a/scripts/eval/benchmark_recon.py +++ b/scripts/eval/benchmark_recon.py @@ -28,6 +28,7 @@ from lensless import ADMM, FISTA, GradientDescent, NesterovGradientDescent from lensless.utils.dataset import DiffuserCamTestDataset, DigiCamCelebA, HFDataset from lensless.utils.io import save_image +from lensless.utils.image import gamma_correction from lensless.recon.model_dict import download_model, load_model import torch @@ -129,6 +130,7 @@ def benchmark_recon(config): downsample_lensed=config.huggingface.downsample_lensed, alignment=config.huggingface.alignment, simulation_config=config.simulation, + single_channel_psf=config.huggingface.single_channel_psf, ) if benchmark_dataset.multimask: # get first PSF for initialization @@ -199,6 +201,9 @@ def benchmark_recon(config): ) if "hf" in algo: param = algo.split(":") + assert ( + len(param) == 4 + ), "hf model requires following format: hf:camera:dataset:model_name" camera = param[1] dataset = param[2] model_name = param[3] @@ -211,13 +216,20 @@ def benchmark_recon(config): skip_post = False model_path = download_model(camera=camera, dataset=dataset, model=model_name) - model_list.append( - (algo, load_model(model_path, psf, device, skip_pre=skip_pre, skip_post=skip_post)) - ) + model = load_model(model_path, psf, device, skip_pre=skip_pre, skip_post=skip_post) + model.eval() + model_list.append((algo, model)) results = {} output_dir = None + # save PSF + psf_np = psf.cpu().numpy()[0] + psf_np = psf_np / np.max(psf_np) + psf_np = gamma_correction(psf_np, gamma=config.gamma_psf) + save_image(psf_np, fp="psf.png") + + # save ground truth and lensless images if config.save_idx is not None: assert np.max(config.save_idx) < len( @@ -225,9 +237,11 @@ def benchmark_recon(config): ), "save_idx values must be smaller than dataset size" os.mkdir("GROUND_TRUTH") + os.mkdir("LENSLESS") for idx in config.save_idx: - ground_truth = benchmark_dataset[idx][1] + lensless, ground_truth = benchmark_dataset[idx] ground_truth_np = ground_truth.cpu().numpy()[0] + lensless_np = lensless.cpu().numpy()[0] if crop is not None: ground_truth_np = ground_truth_np[ @@ -239,6 +253,10 @@ def benchmark_recon(config): ground_truth_np, fp=os.path.join("GROUND_TRUTH", f"{idx}.png"), ) + save_image( + lensless_np, + fp=os.path.join("LENSLESS", f"{idx}.png"), + ) # benchmark each model for different number of iteration and append result to results # -- batchsize has to equal 1 as baseline models don't support batch processing start_time = time.time() diff --git a/scripts/recon/train_learning_based.py b/scripts/recon/train_learning_based.py index a3ce60f1..129358dc 100644 --- a/scripts/recon/train_learning_based.py +++ b/scripts/recon/train_learning_based.py @@ -484,6 +484,7 @@ def train_learned(config): psf=psf, psf_channels=3, nc=config.reconstruction.multi_wiener.nc, + pre_process=pre_process if pre_proc_delay is None else None, ) else: