From 9afc86408ef33f6f2227936b65681ccad3e1a0aa Mon Sep 17 00:00:00 2001
From: Eric Bezzam
Date: Wed, 19 Jun 2024 09:58:11 +0000
Subject: [PATCH] Add support for single channel and flipping.

---
 configs/train_mirflickr_diffuser.yaml | 37 +++++++++++++++++++++++++++
 configs/train_unrolledADMM.yaml       |  2 ++
 lensless/recon/multi_wiener.py        | 10 ++++----
 lensless/recon/rfft_convolve.py       |  9 ++++---
 scripts/recon/train_learning_based.py | 14 +++++++++-
 5 files changed, 63 insertions(+), 9 deletions(-)
 create mode 100644 configs/train_mirflickr_diffuser.yaml

diff --git a/configs/train_mirflickr_diffuser.yaml b/configs/train_mirflickr_diffuser.yaml
new file mode 100644
index 00000000..06bfdd9b
--- /dev/null
+++ b/configs/train_mirflickr_diffuser.yaml
@@ -0,0 +1,37 @@
+# python scripts/recon/train_learning_based.py -cn train_mirflickr_diffuser
+defaults:
+  - train_unrolledADMM
+  - _self_
+
+torch_device: 'cuda:0'
+device_ids: [0, 1, 2, 3]
+eval_disp_idx: [0, 1, 3, 4, 8]
+
+# Dataset
+files:
+  dataset: bezzam/DiffuserCam-Lensless-Mirflickr-Dataset-NORM
+  huggingface_dataset: True
+  huggingface_psf: psf.tiff
+  single_channel_psf: True
+  downsample: 2   # factor by which to downsample the PSF, note that for DiffuserCam the PSF has 4x the resolution
+  downsample_lensed: 2   # only used if lensed is measured
+  flipud: True
+  flip_lensed: True
+
+training:
+  batch_size: 4
+  epoch: 25
+  eval_batch_size: 4
+
+reconstruction:
+  method: unrolled_admm
+  unrolled_admm:
+    n_iter: 5
+  pre_process:
+    network : UnetRes  # UnetRes or DruNet or null
+    depth : 4  # depth of each up/downsampling layer. Ignore if network is DruNet
+    nc: [32,64,116,128]
+  post_process:
+    network : UnetRes  # UnetRes or DruNet or null
+    depth : 4  # depth of each up/downsampling layer. Ignore if network is DruNet
+    nc: [32,64,116,128]
diff --git a/configs/train_unrolledADMM.yaml b/configs/train_unrolledADMM.yaml
index 2a55a834..5e9fa940 100644
--- a/configs/train_unrolledADMM.yaml
+++ b/configs/train_unrolledADMM.yaml
@@ -34,6 +34,8 @@ files:
   vertical_shift: null
   horizontal_shift: null
   rotate: False
+  flipud: False
+  flip_lensed: False
   save_psf: False
   crop: null
   #   vertical: null
diff --git a/lensless/recon/multi_wiener.py b/lensless/recon/multi_wiener.py
index 5824c938..2ada3203 100644
--- a/lensless/recon/multi_wiener.py
+++ b/lensless/recon/multi_wiener.py
@@ -41,14 +41,14 @@ def forward(self, x):
 class Down(nn.Module):
     def __init__(self, in_channels, out_channels):
         super().__init__()
-        self.avgpool_conv = nn.Sequential(
+        self.pool_conv = nn.Sequential(
             # nn.AvgPool2d(2),
-            nn.MaxPool2d(2),
+            nn.MaxPool2d(2),  # original paper says max-pooling
             DoubleConv(in_channels, out_channels),
         )
 
     def forward(self, x):
-        return self.avgpool_conv(x)
+        return self.pool_conv(x)
 
 
 class Up(nn.Module):
@@ -120,9 +120,9 @@ def __init__(
         """
         assert in_channels == 1 or in_channels == 3, "in_channels must be 1 or 3"
         assert out_channels == 1 or out_channels == 3, "out_channels must be 1 or 3"
+        assert in_channels >= out_channels
         if nc is None:
             nc = [64, 128, 256, 512, 512]
-        # assert nc[-1] == nc[-2], "Last two channels must be the same"
 
         super(MultiWiener, self).__init__()
         self.in_channels = in_channels
@@ -172,7 +172,7 @@ def __init__(
             self._psf, (self.left, self.right, self.top, self.bottom), mode="constant", value=0
         )
         self._n_iter = 1
-        self._convolver = RealFFTConvolve2D(psf, pad=True)
+        self._convolver = RealFFTConvolve2D(psf, pad=True, rgb=True if out_channels == 3 else False)
 
         self.set_pre_process(pre_process)
         self.skip_pre = skip_pre
diff --git a/lensless/recon/rfft_convolve.py b/lensless/recon/rfft_convolve.py
index 5518a651..c0a58236 100644
--- a/lensless/recon/rfft_convolve.py
+++ b/lensless/recon/rfft_convolve.py
@@ -24,7 +24,7 @@
 
 
 class RealFFTConvolve2D:
-    def __init__(self, psf, dtype=None, pad=True, norm="ortho", **kwargs):
+    def __init__(self, psf, dtype=None, pad=True, norm="ortho", rgb=None, **kwargs):
         """
         Linear operator that performs convolution in Fourier domain, and assumes
         real-valued signals.
@@ -56,7 +56,10 @@ def __init__(self, psf, dtype=None, pad=True, norm="ortho", **kwargs):
             len(psf.shape) >= 4
         ), "Expected 4D PSF of shape ([batch], depth, width, height, channels)"
         self._use_3d = psf.shape[-4] != 1
-        self._is_rgb = psf.shape[-1] == 3
+        if rgb is None:
+            self._is_rgb = psf.shape[-1] == 3
+        else:
+            self._is_rgb = rgb
         assert self._is_rgb or psf.shape[-1] == 1
 
         # save normalization
@@ -80,7 +83,7 @@
             self._padded_shape = 2 * self._psf_shape[-3:-1] - 1
             self._padded_shape = np.array([next_fast_len(i) for i in self._padded_shape])
             self._padded_shape = list(
-                np.r_[self._psf_shape[-4], self._padded_shape, self._psf_shape[-1]]
+                np.r_[self._psf_shape[-4], self._padded_shape, 3 if self._is_rgb else 1]
             )
             self._start_idx = (self._padded_shape[-3:-1] - self._psf_shape[-3:-1]) // 2
             self._end_idx = self._start_idx + self._psf_shape[-3:-1]
diff --git a/scripts/recon/train_learning_based.py b/scripts/recon/train_learning_based.py
index 129358dc..f0e394c0 100644
--- a/scripts/recon/train_learning_based.py
+++ b/scripts/recon/train_learning_based.py
@@ -213,6 +213,8 @@ def train_learned(config):
             split=split_train,
             display_res=config.files.image_res,
             rotate=config.files.rotate,
+            flipud=config.files.flipud,
+            flip_lensed=config.files.flip_lensed,
             downsample=config.files.downsample,
             downsample_lensed=config.files.downsample_lensed,
             alignment=config.alignment,
@@ -229,6 +231,8 @@ def train_learned(config):
             split=split_test,
             display_res=config.files.image_res,
             rotate=config.files.rotate,
+            flipud=config.files.flipud,
+            flip_lensed=config.files.flip_lensed,
             downsample=config.files.downsample,
             downsample_lensed=config.files.downsample_lensed,
             alignment=config.alignment,
@@ -478,11 +482,19 @@ def train_learned(config):
             else False,
         )
     elif config.reconstruction.method == "multi_wiener":
+
+        if config.files.single_channel_psf:
+            psf = psf[..., 0].unsqueeze(-1)
+            psf_channels = 1
+        else:
+            psf_channels = 3
+        print(psf.shape)
+
         recon = MultiWiener(
             in_channels=3,
             out_channels=3,
             psf=psf,
-            psf_channels=3,
+            psf_channels=psf_channels,
             nc=config.reconstruction.multi_wiener.nc,
             pre_process=pre_process if pre_proc_delay is None else None,
         )