diff --git a/configs/benchmark.yaml b/configs/benchmark.yaml
index 8a852498..ef689a83 100644
--- a/configs/benchmark.yaml
+++ b/configs/benchmark.yaml
@@ -13,6 +13,7 @@ batchsize: 1 # must be 1 for iterative approaches
 
 huggingface:
   repo: "bezzam/DigiCam-Mirflickr-MultiMask-25K"
+  cache_dir: null # where to read/write dataset. Defaults to `"~/.cache/huggingface/datasets"`.
   psf: null # null for simulating PSF
   image_res: [900, 1200] # used during measurement
   rotate: True # if measurement is upside-down
diff --git a/configs/train_unrolledADMM.yaml b/configs/train_unrolledADMM.yaml
index 825597d8..06f555b7 100644
--- a/configs/train_unrolledADMM.yaml
+++ b/configs/train_unrolledADMM.yaml
@@ -102,6 +102,9 @@ reconstruction:
     freeze: null
     unfreeze: null
     train_last_layer: False
+  # number of channels for each compensation layer; the list's length should equal the number of unrolled iterations (n_iter),
+  # and its last element should equal the last element of post_process.nc
+  compensation: null
 
 #Trainable Mask
 trainable_mask:
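To make the new `compensation` constraint concrete, here is a minimal sketch with purely illustrative values (the exact `n_iter` key path is an assumption, not shown in this diff):

```python
# Illustrative values only: `compensation` needs one entry per unrolled
# iteration (n_iter), and its last entry must equal post_process.nc[-1].
reconstruction = {
    "unrolled_admm": {"n_iter": 5},  # assumed key path, for illustration
    "post_process": {"nc": [32, 64, 116, 128]},
    "compensation": [24, 64, 128, 256, 128],  # len == 5, last == 128
}
assert len(reconstruction["compensation"]) == reconstruction["unrolled_admm"]["n_iter"]
assert reconstruction["compensation"][-1] == reconstruction["post_process"]["nc"][-1]
```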
diff --git a/lensless/recon/drunet/basicblock.py b/lensless/recon/drunet/basicblock.py
index ed17a10b..ae1108c1 100644
--- a/lensless/recon/drunet/basicblock.py
+++ b/lensless/recon/drunet/basicblock.py
@@ -6,9 +6,7 @@
 # #############################################################################
 
 from collections import OrderedDict
-import torch
 import torch.nn as nn
-import torch.nn.functional as F
 
 
 """
diff --git a/lensless/recon/drunet/network_unet.py b/lensless/recon/drunet/network_unet.py
index 8d51d00b..c6476368 100644
--- a/lensless/recon/drunet/network_unet.py
+++ b/lensless/recon/drunet/network_unet.py
@@ -5,10 +5,11 @@
 # https://github.com/cszn/DPIR/blob/15bca3fcc1f3cc51a1f99ccf027691e278c19354/models/network_unet.py
 # #############################################################################
 
+import torch
 import torch.nn as nn
 import lensless.recon.drunet.basicblock as B
-import numpy as np
+from torchvision.transforms.functional import resize
 
 """
 # ====================
@@ -109,11 +110,13 @@ def __init__(
         act_mode="R",
         downsample_mode="strideconv",
         upsample_mode="convtranspose",
+        concatenate_compensation=False,
     ):
         super(UNetRes, self).__init__()
         assert len(nc) == 4, "nc's length should be 4."
+        self.nc = nc
         self.m_head = B.conv(in_nc, nc[0], bias=False, mode="C")
 
         # downsample
@@ -139,9 +142,21 @@ def __init__(
             downsample_block(nc[2], nc[3], bias=False, mode="2")
         )
 
-        self.m_body = B.sequential(
-            *[B.ResBlock(nc[3], nc[3], bias=False, mode="C" + act_mode + "C") for _ in range(nb)]
-        )
+        self.concatenate_compensation = concatenate_compensation
+        if concatenate_compensation:
+            self.m_body = B.sequential(
+                *[
+                    B.ResBlock(nc[3] * 2, nc[3] * 2, bias=False, mode="C" + act_mode + "C")
+                    for _ in range(nb)
+                ]
+            )
+        else:
+            self.m_body = B.sequential(
+                *[
+                    B.ResBlock(nc[3], nc[3], bias=False, mode="C" + act_mode + "C")
+                    for _ in range(nb)
+                ]
+            )
 
         # upsample
         if upsample_mode == "upconv":
@@ -154,7 +169,9 @@
             raise NotImplementedError("upsample mode [{:s}] is not found".format(upsample_mode))
 
         self.m_up3 = B.sequential(
-            upsample_block(nc[3], nc[2], bias=False, mode="2"),
+            upsample_block(
+                nc[3] * 2 if concatenate_compensation else nc[3], nc[2], bias=False, mode="2"
+            ),
             *[B.ResBlock(nc[2], nc[2], bias=False, mode="C" + act_mode + "C") for _ in range(nb)]
         )
         self.m_up2 = B.sequential(
@@ -168,13 +185,22 @@
 
         self.m_tail = B.conv(nc[0], out_nc, bias=False, mode="C")
 
-    def forward(self, x0):
+    def forward(self, x0, compensation_output=None):
+        if self.concatenate_compensation:
+            assert compensation_output is not None, "compensation_output should not be None."
         x1 = self.m_head(x0)
         x2 = self.m_down1(x1)
         x3 = self.m_down2(x2)
         x4 = self.m_down3(x3)
-        x = self.m_body(x4)
-        x = self.m_up3(x + x4)
+
+        if compensation_output is not None:
+            compensation_output_re = resize(compensation_output, tuple(x4.shape[-2:]))
+            latent = torch.cat([x4, compensation_output_re], dim=1)
+        else:
+            latent = x4
+
+        x = self.m_body(latent)
+        x = self.m_up3(x + latent)
         x = self.m_up2(x + x3)
         x = self.m_up1(x + x2)
         x = self.m_tail(x + x1)
diff --git a/lensless/recon/model_dict.py b/lensless/recon/model_dict.py
index 244282be..2903a0dd 100644
--- a/lensless/recon/model_dict.py
+++ b/lensless/recon/model_dict.py
@@ -93,8 +93,11 @@
         "Unet4M+U10+Unet4M_nodead": "bezzam/digicam-mirflickr-single-25k-unet4M-unrolled-admm10-unet4M-nodead",
     },
     "mirflickr_multi_25k": {
+        # simulated PSFs (without waveprop, with deadspace)
        "Unet8M": "bezzam/digicam-mirflickr-multi-25k-unet8M",
        "Unet4M+U10+Unet4M": "bezzam/digicam-mirflickr-multi-25k-unet4M-unrolled-admm10-unet4M",
+        # simulated PSF (with waveprop, with deadspace)
+        "Unet4M+U10+Unet4M_wave": "bezzam/digicam-mirflickr-multi-25k-unet4M-unrolled-admm10-unet4M-wave",
     },
 },
 }
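For context, the registry nests dataset → model name → Hugging Face repo ID, so the new wave model would be looked up roughly as follows (a sketch: the `model_dict` variable name and the top-level `"digicam"` key are assumptions, as neither appears in this diff):

```python
# Hypothetical lookup; `model_dict` and the "digicam" key are assumed.
from lensless.recon.model_dict import model_dict

repo_id = model_dict["digicam"]["mirflickr_multi_25k"]["Unet4M+U10+Unet4M_wave"]
print(repo_id)  # bezzam/digicam-mirflickr-multi-25k-unet4M-unrolled-admm10-unet4M-wave
```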
diff --git a/lensless/recon/trainable_recon.py b/lensless/recon/trainable_recon.py
index 33a9b9ec..0a5efc19 100644
--- a/lensless/recon/trainable_recon.py
+++ b/lensless/recon/trainable_recon.py
@@ -54,6 +54,7 @@ def __init__(
         skip_unrolled=False,
         return_unrolled_output=False,
         legacy_denoiser=False,
+        compensation=None,
         **kwargs,
     ):
         """
@@ -63,25 +64,28 @@
 
         Parameters
         ----------
-            psf : :py:class:`~torch.Tensor`
-                Point spread function (PSF) that models forward propagation.
-                Must be of shape (depth, height, width, channels) even if
-                depth = 1 and channels = 1. You can use :py:func:`~lensless.io.load_psf`
-                to load a PSF from a file such that it is in the correct format.
-            dtype : float32 or float64
-                Data type to use for optimization.
-            n_iter : int
-                Number of iterations for unrolled algorithm.
-            pre_process : :py:class:`function` or :py:class:`~torch.nn.Module`, optional
-                If :py:class:`function` : Function to apply to the image estimate before algorithm. Its input most be (image to process, noise_level), where noise_level is a learnable parameter. If it include aditional learnable parameters, they will not be added to the parameter list of the algorithm. To allow for traning, the function must be autograd compatible.
-                If :py:class:`~torch.nn.Module` : A DruNet compatible network to apply to the image estimate before algorithm. See ``utils.image.apply_denoiser`` for more details. The network will be included as a submodule of the algorithm and its parameters will be added to the parameter list of the algorithm. If this isn't intended behavior, set requires_grad=False.
-            post_process : :py:class:`function` or :py:class:`~torch.nn.Module`, optional
-                If :py:class:`function` : Function to apply to the image estimate after the whole algorithm. Its input most be (image to process, noise_level), where noise_level is a learnable parameter. If it include aditional learnable parameters, they will not be added to the parameter list of the algorithm. To allow for traning, the function must be autograd compatible.
-                If :py:class:`~torch.nn.Module` : A DruNet compatible network to apply to the image estimate after the whole algorithm. See ``utils.image.apply_denoiser`` for more details. The network will be included as a submodule of the algorithm and its parameters will be added to the parameter list of the algorithm. If this isn't intended behavior, set requires_grad=False.
-            skip_unrolled : bool, optional
-                Whether to skip the unrolled algorithm and only apply the pre- or post-processor block (e.g. to just use a U-Net for reconstruction).
-            return_unrolled_output : bool, optional
-                Whether to return the output of the unrolled algorithm if also using a post-processor block.
+        psf : :py:class:`~torch.Tensor`
+            Point spread function (PSF) that models forward propagation.
+            Must be of shape (depth, height, width, channels) even if
+            depth = 1 and channels = 1. You can use :py:func:`~lensless.io.load_psf`
+            to load a PSF from a file such that it is in the correct format.
+        dtype : float32 or float64
+            Data type to use for optimization.
+        n_iter : int
+            Number of iterations for unrolled algorithm.
+        pre_process : :py:class:`function` or :py:class:`~torch.nn.Module`, optional
+            If :py:class:`function` : Function to apply to the image estimate before the algorithm. Its input must be (image to process, noise_level), where noise_level is a learnable parameter. If it includes additional learnable parameters, they will not be added to the parameter list of the algorithm. To allow for training, the function must be autograd compatible.
+            If :py:class:`~torch.nn.Module` : A DruNet compatible network to apply to the image estimate before the algorithm. See ``utils.image.apply_denoiser`` for more details. The network will be included as a submodule of the algorithm and its parameters will be added to the parameter list of the algorithm. If this isn't intended behavior, set requires_grad=False.
+        post_process : :py:class:`function` or :py:class:`~torch.nn.Module`, optional
+            If :py:class:`function` : Function to apply to the image estimate after the whole algorithm. Its input must be (image to process, noise_level), where noise_level is a learnable parameter. If it includes additional learnable parameters, they will not be added to the parameter list of the algorithm. To allow for training, the function must be autograd compatible.
+            If :py:class:`~torch.nn.Module` : A DruNet compatible network to apply to the image estimate after the whole algorithm. See ``utils.image.apply_denoiser`` for more details. The network will be included as a submodule of the algorithm and its parameters will be added to the parameter list of the algorithm. If this isn't intended behavior, set requires_grad=False.
+        skip_unrolled : bool, optional
+            Whether to skip the unrolled algorithm and only apply the pre- or post-processor block (e.g. to just use a U-Net for reconstruction).
+        return_unrolled_output : bool, optional
+            Whether to return the output of the unrolled algorithm if also using a post-processor block.
+        compensation : list, optional
+            Number of channels for each intermediate output of the compensation branch, as in "Robust Reconstruction With Deep Learning to Handle Model Mismatch in Lensless Imaging" (2021).
+            A post-processor must be defined if ``compensation`` is provided.
         """
         assert isinstance(psf, torch.Tensor), "PSF must be a torch.Tensor"
         super(TrainableReconstructionAlgorithm, self).__init__(
@@ -93,6 +97,18 @@
         self.set_post_process(post_process)
         self.skip_unrolled = skip_unrolled
         self.return_unrolled_output = return_unrolled_output
+        self.compensation_branch = compensation
+        if compensation is not None:
+            from lensless.recon.utils import CompensationBranch
+
+            assert (
+                post_process is not None
+            ), "If compensation is provided, post_process must be defined."
+            assert (
+                len(compensation) == n_iter
+            ), "compensation must have length equal to n_iter"
+            self.compensation_branch = CompensationBranch(compensation)
+
         if self.return_unrolled_output:
             assert (
                 post_process is not None
@@ -231,15 +247,25 @@ def forward(self, batch, psfs=None):
 
         # unrolled algorithm
         if not self.skip_unrolled:
+            if self.compensation_branch is not None:
+                compensation_branch_inputs = [self._data]
+
             for i in range(self._n_iter):
                 self._update(i)
+                if self.compensation_branch is not None and i < self._n_iter - 1:
+                    compensation_branch_inputs.append(self._form_image())
+
             image_est = self._form_image()
         else:
             image_est = self._data
 
         # post process data
         if self.post_process is not None:
-            final_est = self.post_process(image_est, self.post_process_param)
+            compensation_output = None
+            if self.compensation_branch is not None:
+                compensation_output = self.compensation_branch(compensation_branch_inputs)
+
+            final_est = self.post_process(image_est, self.post_process_param, compensation_output)
         else:
             final_est = image_est
diff --git a/lensless/recon/unrolled_admm.py b/lensless/recon/unrolled_admm.py
index 2447aa4b..d30eb6b9 100644
--- a/lensless/recon/unrolled_admm.py
+++ b/lensless/recon/unrolled_admm.py
@@ -235,5 +235,5 @@ def _update(self, iter):
 
     def _form_image(self):
         image = self._convolver._crop(self._image_est)
-        image[image < 0] = 0
+        image = torch.clamp(image, min=0)
         return image
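The switch from in-place masking to `torch.clamp` in `_form_image` is autograd-friendly: assigning into a tensor that is still needed for the backward pass can trigger PyTorch's in-place modification error, whereas `clamp` returns a new tensor. A minimal sketch of the out-of-place behavior:

```python
import torch

# clamp is out-of-place: gradient flows (1) where x > 0 and is zeroed
# elsewhere, without mutating a tensor that autograd may still need.
x = torch.tensor([-1.0, 2.0], requires_grad=True)
y = torch.clamp(x, min=0)
y.sum().backward()
print(x.grad)  # tensor([0., 1.])
```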
diff --git a/lensless/recon/utils.py b/lensless/recon/utils.py
index 0d4b44f4..e78961c4 100644
--- a/lensless/recon/utils.py
+++ b/lensless/recon/utils.py
@@ -14,6 +14,7 @@
 import time
 import os
 import torch
+from torch import nn
 from lensless.eval.benchmark import benchmark
 from lensless.hardware.trainable_mask import TrainableMask
 from tqdm import tqdm
@@ -23,6 +24,115 @@
 from lensless.utils.dataset import SimulatedDatasetTrainableMask
 
 
+def double_cnn_max_pool(c_in, c_out, cnn_kernel=3, max_pool=2):
+    return nn.Sequential(
+        nn.Conv2d(
+            in_channels=c_in,
+            out_channels=c_out,
+            kernel_size=cnn_kernel,
+            padding="same",
+            bias=False,
+        ),
+        nn.BatchNorm2d(c_out),
+        nn.ReLU(),
+        nn.Conv2d(
+            in_channels=c_out,
+            out_channels=c_out,
+            kernel_size=cnn_kernel,
+            padding="same",
+            bias=False,
+        ),
+        nn.BatchNorm2d(c_out),
+        nn.ReLU(),
+        nn.MaxPool2d(kernel_size=max_pool),
+    )
+
+
+class CompensationBranch(nn.Module):
+    """
+    Compensation branch for the unrolled algorithm, as in "Robust Reconstruction With Deep Learning to Handle Model Mismatch in Lensless Imaging" (2021).
+    """
+
+    def __init__(self, nc, cnn_kernel=3, max_pool=2, in_channel=3):
+        """
+
+        Parameters
+        ----------
+        nc : list
+            Number of channels for each layer of the compensation branch.
+        cnn_kernel : int, optional
+            Kernel size for convolutional layers, by default 3.
+        max_pool : int, optional
+            Kernel size for max pooling layers, by default 2.
+        in_channel : int, optional
+            Number of input channels, by default 3 for RGB.
+        """
+        super(CompensationBranch, self).__init__()
+
+        self.n_iter = len(nc)
+
+        # layers along the compensation branch, f^C in paper
+        branch_layers = [
+            double_cnn_max_pool(in_channel, nc[0], cnn_kernel=cnn_kernel, max_pool=max_pool)
+        ]
+        self.branch_layers = nn.ModuleList(
+            branch_layers
+            + [
+                double_cnn_max_pool(
+                    nc[i] * 2,  # due to concatenation with intermediate layer
+                    nc[i + 1],
+                    cnn_kernel=cnn_kernel,
+                    max_pool=max_pool,
+                )
+                for i in range(self.n_iter - 1)
+            ]
+        )
+
+        # residual layers for intermediate output, \tilde{f}^C in paper
+        # -- not mentioned in the paper, but more max pooling is added for later residual layers, otherwise dimensions don't match
+        self.residual_layers = nn.ModuleList(
+            [
+                double_cnn_max_pool(
+                    in_channel, nc[i], cnn_kernel=cnn_kernel, max_pool=max_pool ** (i + 1)
+                )
+                for i in range(self.n_iter - 1)
+            ]
+        )
+
+    def forward(self, x, return_NCHW=True):
+        """
+        Input must be the original input and the intermediate outputs: (b, s^1, s^2, ..., s^{K-1}), where K is the number of iterations.
+
+        See p. 1085 of "Robust Reconstruction With Deep Learning to Handle Model Mismatch in Lensless Imaging" (2021) for more details.
+        """
+        assert len(x) == self.n_iter, "Input must have the same length as the number of iterations."
+        n_depth = x[0].shape[-4]
+        h_apo_k = self.branch_layers[0](convert_to_NCHW(x[0]))  # h^{'}_k
+        for k in range(self.n_iter - 1):  # eq. 18-21
+            # \tilde{h}_k
+            h_k = torch.cat([h_apo_k, self.residual_layers[k](convert_to_NCHW(x[k + 1]))], axis=1)
+            h_apo_k = self.branch_layers[k + 1](h_k)  # h^{'}_k
+
+        if return_NCHW:
+            return h_apo_k
+        else:
+            return convert_to_NDCHW(h_apo_k, n_depth)
+
+
+# convert from NDHWC to NCHW
+def convert_to_NCHW(image):
+    image = image.movedim(-1, -3)
+    image = image.reshape(-1, *image.shape[-3:])
+    return image
+
+
+# convert back to NDHWC
+def convert_to_NDCHW(image, depth):
+    image = image.movedim(-3, -1)
+    image = image.reshape(-1, depth, *image.shape[-3:])
+    return image
+
+
 def load_drunet(model_path=None, n_channels=3, requires_grad=False):
     """
     Load a pre-trained Drunet model.
@@ -79,7 +189,7 @@ def load_drunet(model_path=None, n_channels=3, requires_grad=False):
     return model
 
 
-def apply_denoiser(model, image, noise_level=10, mode="inference"):
+def apply_denoiser(model, image, noise_level=10, mode="inference", compensation_output=None):
     """
     Apply a pre-trained denoising model with input in the format Channel, Height, Width.
     An additionnal channel is added for the noise level as done in Drunet.
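A quick smoke test of the new branch with dummy NDHWC tensors (shapes chosen arbitrarily for illustration): with `nc` of length K = 3, the branch expects 3 inputs, and each `double_cnn_max_pool` stage halves the spatial size.

```python
import torch
from lensless.recon.utils import CompensationBranch

# K = 3: the branch expects the initial estimate plus 2 intermediate
# reconstructions, all in NDHWC format (batch, depth, H, W, C).
branch = CompensationBranch(nc=[24, 64, 128])
s = [torch.rand(2, 1, 64, 64, 3) for _ in range(3)]
out = branch(s)
print(out.shape)  # torch.Size([2, 128, 8, 8]) -- 64 halved by 3 max-pool stages
```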
@@ -132,9 +242,9 @@ def apply_denoiser(model, image, noise_level=10, mode="inference"):
     # apply model
     if mode == "inference":
         with torch.no_grad():
-            image = model(image)
+            image = model(image, compensation_output)
     elif mode == "train":
-        image = model(image)
+        image = model(image, compensation_output)
     else:
         raise ValueError("mode must be 'inference' or 'train'")
 
@@ -187,13 +297,14 @@ def get_drunet_function_v2(model, mode="inference"):
         Mode to use for model. Can be "inference" or "train".
     """
 
-    def process(image, noise_level):
+    def process(image, noise_level, compensation_output=None):
         x_max = torch.amax(image, dim=(-1, -2, -3, -4), keepdim=True) + 1e-6
         image = apply_denoiser(
             model,
             image / x_max,
             noise_level=noise_level,
             mode=mode,
+            compensation_output=compensation_output,
         )
         image = torch.clip(image, min=0.0) * x_max.to(image.device)
         return image
@@ -223,7 +334,9 @@ def measure_gradient(model):
     return total_norm
 
 
-def create_process_network(network, depth=4, device="cpu", nc=None, device_ids=None):
+def create_process_network(
+    network, depth=4, device="cpu", nc=None, device_ids=None, concatenate_compensation=False
+):
     """
     Helper function to create a process network.
 
@@ -248,6 +361,9 @@ def create_process_network(network, depth=4, device="cpu", nc=None, device_ids=N
         assert len(nc) == 4
 
     if network == "DruNet":
+        assert (
+            concatenate_compensation is False
+        ), "DruNet does not support concatenation of compensation branch."
         from lensless.recon.utils import load_drunet
 
         process = load_drunet(requires_grad=True)
@@ -264,6 +380,7 @@ def create_process_network(
             act_mode="R",
             downsample_mode="strideconv",
             upsample_mode="convtranspose",
+            concatenate_compensation=concatenate_compensation,
         )
         process_name = "UnetRes_d" + str(depth)
     else:
@@ -629,6 +746,7 @@ def train_epoch(self, data_loader):
                     self.recon._set_psf(self.mask.get_psf().to(self.device))
 
                 # forward pass
+                # torch.autograd.set_detect_anomaly(True)  # for debugging
                 y_pred = self.recon.forward(batch=X, psfs=psfs)
                 if self.unrolled_output_factor:
                     unrolled_out = y_pred[1]
diff --git a/lensless/utils/dataset.py b/lensless/utils/dataset.py
index daf8c194..e8311f26 100644
--- a/lensless/utils/dataset.py
+++ b/lensless/utils/dataset.py
@@ -1040,6 +1040,7 @@ def __init__(
         simulation_config=dict(),
         simulate_lensless=False,
         force_rgb=False,
+        cache_dir=None,
         **kwargs,
     ):
         """
@@ -1083,7 +1084,7 @@ def __init__(
         if isinstance(split, str):
             if n_files is not None:
                 split = f"{split}[0:{n_files}]"
-            self.dataset = load_dataset(huggingface_repo, split=split)
+            self.dataset = load_dataset(huggingface_repo, split=split, cache_dir=cache_dir)
         elif isinstance(split, Dataset):
             self.dataset = split
         else:
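The new `cache_dir` option is a plain pass-through to `datasets.load_dataset`, which otherwise caches under `~/.cache/huggingface/datasets`; for example (path purely illustrative):

```python
from datasets import load_dataset

# Illustrative: redirect the Hugging Face datasets cache to scratch storage.
ds = load_dataset(
    "bezzam/DigiCam-Mirflickr-MultiMask-25K",
    split="test",
    cache_dir="/scratch/hf_datasets",  # assumed path
)
```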
diff --git a/scripts/eval/benchmark_recon.py b/scripts/eval/benchmark_recon.py
index 500b6203..c51aa48a 100644
--- a/scripts/eval/benchmark_recon.py
+++ b/scripts/eval/benchmark_recon.py
@@ -100,8 +100,12 @@ def benchmark_recon(config):
     if config.n_files is not None:
         train_split = f"train[:{config.n_files}]"
         test_split = f"test[:{config.n_files}]"
-    train_dataset = load_dataset(config.huggingface.repo, split=train_split)
-    test_dataset = load_dataset(config.huggingface.repo, split=test_split)
+    train_dataset = load_dataset(
+        config.huggingface.repo, split=train_split, cache_dir=config.huggingface.cache_dir
+    )
+    test_dataset = load_dataset(
+        config.huggingface.repo, split=test_split, cache_dir=config.huggingface.cache_dir
+    )
     dataset = concatenate_datasets([test_dataset, train_dataset])
 
     # - split into train and test
@@ -113,6 +117,7 @@ def benchmark_recon(config):
 
         benchmark_dataset = HFDataset(
             huggingface_repo=config.huggingface.repo,
+            cache_dir=config.huggingface.cache_dir,
             psf=config.huggingface.psf,
             n_files=n_files,
             split=split_test,
diff --git a/scripts/recon/train_learning_based.py b/scripts/recon/train_learning_based.py
index 60d89582..9cb436c1 100644
--- a/scripts/recon/train_learning_based.py
+++ b/scripts/recon/train_learning_based.py
@@ -373,8 +373,14 @@ def train_learned(config):
             nc=config.reconstruction.post_process.nc,
             device=device,
             device_ids=device_ids,
+            concatenate_compensation=config.reconstruction.compensation is not None,
         )
         post_proc_delay = config.reconstruction.post_process.delay
+        if config.reconstruction.post_process.network is not None:
+            if config.reconstruction.compensation is not None:
+                assert (
+                    config.reconstruction.compensation[-1] == config.reconstruction.post_process.nc[-1]
+                ), "the last element of compensation must equal post_process.nc[-1]"
 
         if config.reconstruction.post_process.train_last_layer:
             for name, param in post_process.named_parameters():
@@ -424,6 +430,7 @@ def train_learned(config):
             post_process=post_process if post_proc_delay is None else None,
             skip_unrolled=config.reconstruction.skip_unrolled,
             return_unrolled_output=True if config.unrolled_output_factor > 0 else False,
+            compensation=config.reconstruction.compensation,
         )
     elif config.reconstruction.method == "unrolled_admm":
         recon = UnrolledADMM(
@@ -437,6 +444,7 @@ def train_learned(config):
             post_process=post_process if post_proc_delay is None else None,
             skip_unrolled=config.reconstruction.skip_unrolled,
             return_unrolled_output=True if config.unrolled_output_factor > 0 else False,
+            compensation=config.reconstruction.compensation,
         )
     elif config.reconstruction.method == "trainable_inv":
         recon = TrainableInversion(