From e72aa110db6a054be50d6c76bcf7e673a9b735f1 Mon Sep 17 00:00:00 2001 From: Eric Bezzam Date: Thu, 27 Jun 2024 09:41:00 +0000 Subject: [PATCH] Add residual block to compensation. --- configs/train_unrolledADMM.yaml | 2 + lensless/recon/drunet/basicblock.py | 1 - lensless/recon/model_dict.py | 42 ++++++- lensless/recon/utils.py | 80 ++++++++++-- .../multi_lens_array_2024-06-05_09-01-04.json | 115 ++++++++++++++++++ scripts/recon/train_learning_based.py | 11 +- 6 files changed, 238 insertions(+), 13 deletions(-) create mode 100644 notebooks/multi_lens_array_2024-06-05_09-01-04.json diff --git a/configs/train_unrolledADMM.yaml b/configs/train_unrolledADMM.yaml index 5e9fa940..c27825c1 100644 --- a/configs/train_unrolledADMM.yaml +++ b/configs/train_unrolledADMM.yaml @@ -16,6 +16,8 @@ files: # psf: data/psf/diffusercam_psf.tiff # diffusercam_psf: True + cache_dir: null # where to read/write dataset. Defaults to `"~/.cache/huggingface/datasets"`. + # -- using huggingface dataset dataset: bezzam/DiffuserCam-Lensless-Mirflickr-Dataset-NORM huggingface_dataset: True diff --git a/lensless/recon/drunet/basicblock.py b/lensless/recon/drunet/basicblock.py index ae1108c1..865d43bb 100644 --- a/lensless/recon/drunet/basicblock.py +++ b/lensless/recon/drunet/basicblock.py @@ -156,7 +156,6 @@ def __init__( ) def forward(self, x): - # res = self.res(x) return x + self.res(x) diff --git a/lensless/recon/model_dict.py b/lensless/recon/model_dict.py index 0491196e..068f9765 100644 --- a/lensless/recon/model_dict.py +++ b/lensless/recon/model_dict.py @@ -53,6 +53,16 @@ # baseline benchmarks which don't have model file but use ADMM "admm_fista": "bezzam/diffusercam-mirflickr-admm-fista", "admm_pnp": "bezzam/diffusercam-mirflickr-admm-pnp", + # -- TCI submission + "TrainInv+Unet8M": "bezzam/diffusercam-mirflickr-trainable-inv-unet8M", + "Unet4M+U5+Unet4M": "bezzam/diffusercam-mirflickr-unet4M-unrolled-admm5-unet4M", + "MWDN8M": "bezzam/diffusercam-mirflickr-mwdn-8M", + "Unet2M+MWDN6M": "bezzam/diffusercam-mirflickr-unet2M-mwdn-6M", + "Unet4M+TrainInv+Unet4M": "bezzam/diffusercam-mirflickr-unet4M-trainable-inv-unet4M", + "MMCN4M+Unet4M": "bezzam/diffusercam-mirflickr-mmcn-unet4M", + "U5+Unet8M": "bezzam/diffusercam-mirflickr-unrolled-admm5-unet8M", + "Unet2M+MMCN+Unet2M": "bezzam/diffusercam-mirflickr-unet2M-mmcn-unet2M", + "Unet4M+U20+Unet4M": "bezzam/diffusercam-mirflickr-unet4M-unrolled-admm20-unet4M", }, }, "digicam": { @@ -70,6 +80,12 @@ # baseline benchmarks which don't have model file but use ADMM "admm_measured_psf": "bezzam/digicam-celeba-admm-measured-psf", "admm_simulated_psf": "bezzam/digicam-celeba-admm-simulated-psf", + # TCI submission (using waveprop simulation) + "U5+Unet8M_wave": "bezzam/digicam-celeba-unrolled-admm5-unet8M", + "TrainInv+Unet8M_wave": "bezzam/digicam-celeba-trainable-inv-unet8M_wave", + "MWDN8M_wave": "bezzam/digicam-celeba-mwnn-8M", + "MMCN4M+Unet4M_wave": "bezzam/digicam-celeba-mmcn-unet4M", + "Unet2M+MWDN6M_wave": "bezzam/digicam-celeba-unet2M-mwdn-6M", }, "mirflickr_single_25k": { # simulated PSF (without waveprop, with deadspace) @@ -86,9 +102,12 @@ "Unet4M+U10+Unet4M_wave": "bezzam/digicam-mirflickr-single-25k-unet4M-unrolled-admm10-unet4M-wave", "TrainInv+Unet8M_wave": "bezzam/digicam-mirflickr-single-25k-trainable-inv-unet8M-wave", "U5+Unet8M_wave": "bezzam/digicam-mirflickr-single-25k-unrolled-admm5-unet8M-wave", + "Unet4M+U5+Unet4M_wave": "bezzam/digicam-mirflickr-single-25k-unet4M-unrolled-admm5-unet4M-wave", "MWDN8M_wave": "bezzam/digicam-mirflickr-single-25k-mwdn-8M", "MMCN4M+Unet4M_wave": "bezzam/digicam-mirflickr-single-25k-mmcn-unet4M", + "Unet2M+MMCN+Unet2M_wave": "bezzam/digicam-mirflickr-single-25k-unet2M-mmcn-unet2M-wave", "Unet4M+TrainInv+Unet4M_wave": "bezzam/digicam-mirflickr-single-25k-unet4M-trainable-inv-unet4M-wave", + "Unet2M+MWDN6M_wave": "bezzam/digicam-mirflickr-single-25k-unet2M-mwdn-6M", # measured PSF "Unet4M+U10+Unet4M_measured": "bezzam/digicam-mirflickr-single-25k-unet4M-unrolled-admm10-unet4M-measured", # simulated PSF (with waveprop, no deadspace) @@ -270,11 +289,25 @@ def load_model( skip_post=skip_post, ) elif config["reconstruction"]["method"] == "multi_wiener": + + if config["files"].get("single_channel_psf", False): + + if torch.sum(psf[..., 0] - psf[..., 1]) != 0: + # need to sum difference channels + raise ValueError("PSF channels are not the same") + # psf = np.sum(psf, axis=3) + + else: + psf = psf[..., 0].unsqueeze(-1) + psf_channels = 1 + else: + psf_channels = 3 + recon = MultiWiener( in_channels=3, out_channels=3, psf=psf, - psf_channels=3, + psf_channels=psf_channels, nc=config["reconstruction"]["multi_wiener"]["nc"], pre_process=pre_process, ) @@ -287,6 +320,13 @@ def load_model( if "device_ids" in config.keys() and config["device_ids"] is not None: model_state_dict = remove_data_parallel(model_state_dict) + # hotfixes for loading models + if config["reconstruction"]["method"] == "multi_wiener": + # replace "avgpool_conv" with "pool_conv" + model_state_dict = { + k.replace("avgpool_conv", "pool_conv"): v for k, v in model_state_dict.items() + } + recon.load_state_dict(model_state_dict) return recon diff --git a/lensless/recon/utils.py b/lensless/recon/utils.py index e9d79542..b3fae9e6 100644 --- a/lensless/recon/utils.py +++ b/lensless/recon/utils.py @@ -24,13 +24,13 @@ from lensless.utils.dataset import SimulatedDatasetTrainableMask -def double_cnn_max_pool(c_in, c_out, cnn_kernel=3, max_pool=2): +def double_cnn_max_pool(c_in, c_out, cnn_kernel=3, max_pool=2, padding=1, stride=1): return nn.Sequential( nn.Conv2d( in_channels=c_in, out_channels=c_out, kernel_size=cnn_kernel, - padding="same", + padding=padding, bias=False, ), nn.BatchNorm2d(c_out), @@ -39,21 +39,52 @@ def double_cnn_max_pool(c_in, c_out, cnn_kernel=3, max_pool=2): in_channels=c_out, out_channels=c_out, kernel_size=cnn_kernel, - padding="same", + padding=padding, bias=False, ), nn.BatchNorm2d(c_out), nn.ReLU(), - nn.MaxPool2d(kernel_size=max_pool), + # don't pass stride=1, otherwise no pooling/downsampling.. + nn.MaxPool2d(kernel_size=max_pool, padding=0) if max_pool else nn.Identity(), ) +class ResBlock(nn.Module): + def __init__(self, c_in, c_out, cnn_kernel=3, max_pool=2, padding=1, stride=1): + super(ResBlock, self).__init__() + # assert c_in == c_out, "Input and output channels must be the same for residual block." + + # conv layers for residual need to be the same size + self.double_conv = double_cnn_max_pool( + c_in, c_in, cnn_kernel=cnn_kernel, max_pool=False, padding=padding, stride=stride + ) + + # pooling + self.pooling = nn.Sequential( + nn.Conv2d( + in_channels=c_in, + out_channels=c_out, + kernel_size=cnn_kernel, + padding=padding, + bias=False, + ), + nn.BatchNorm2d(c_out), + nn.ReLU(), + nn.MaxPool2d(kernel_size=max_pool, padding=0), + ) + + def forward(self, x): + return self.pooling(x + self.double_conv(x)) + + class CompensationBranch(nn.Module): """ Compensation branch for unrolled algorithm, as in "Robust Reconstruction With Deep Learning to Handle Model Mismatch in Lensless Imaging" (2021). """ - def __init__(self, nc, cnn_kernel=3, max_pool=2, in_channel=3): + def __init__( + self, nc, cnn_kernel=3, max_pool=2, in_channel=3, residual=True, padding=1, stride=1 + ): """ Parameters @@ -66,6 +97,8 @@ def __init__(self, nc, cnn_kernel=3, max_pool=2, in_channel=3): Kernel size for max pooling layers, by default 2. in_channel : int, optional Number of input channels, by default 3 for RGB. + residual : bool, optional + Whether to use residual block or simply double conv for intermediate layers, by default True. """ super(CompensationBranch, self).__init__() @@ -73,7 +106,14 @@ def __init__(self, nc, cnn_kernel=3, max_pool=2, in_channel=3): # layers along the compensation branch, f^C in paper branch_layers = [ - double_cnn_max_pool(in_channel, nc[0], cnn_kernel=cnn_kernel, max_pool=max_pool) + double_cnn_max_pool( + in_channel, + nc[0], + cnn_kernel=cnn_kernel, + max_pool=max_pool, + padding=padding, + stride=stride, + ) ] self.branch_layers = nn.ModuleList( branch_layers @@ -83,6 +123,8 @@ def __init__(self, nc, cnn_kernel=3, max_pool=2, in_channel=3): nc[i + 1], cnn_kernel=cnn_kernel, max_pool=max_pool, + padding=padding, + stride=stride, ) for i in range(self.n_iter - 1) ] @@ -92,8 +134,29 @@ def __init__(self, nc, cnn_kernel=3, max_pool=2, in_channel=3): # -- not mentinoed in paper, but added more max-pooling for later residual layers, otherwise dimensions don't match self.residual_layers = nn.ModuleList( [ - double_cnn_max_pool( - in_channel, nc[i], cnn_kernel=cnn_kernel, max_pool=max_pool ** (i + 1) + # double_cnn_max_pool( + # in_channel, nc[i], cnn_kernel=cnn_kernel, max_pool=max_pool ** (i + 1) + # ) + # B.sequential( + # B.ResBlock(in_channel, in_channel, bias=False, mode="CRC", padding=padding, stride=stride), + # B.downsample_maxpool(in_channel, nc[i], bias=False, mode=str(max_pool ** (i + 1)), padding=padding, stride=stride) + # ) if residual + ResBlock( + in_channel, + nc[i], + cnn_kernel=cnn_kernel, + max_pool=max_pool ** (i + 1), + padding=padding, + stride=stride, + ) + if residual + else double_cnn_max_pool( + in_channel, + nc[i], + cnn_kernel=cnn_kernel, + max_pool=max_pool ** (i + 1), + padding=padding, + stride=stride, ) for i in range(self.n_iter - 1) ] @@ -110,6 +173,7 @@ def forward(self, x, return_NCHW=True): h_apo_k = self.branch_layers[0](convert_to_NCHW(x[0])) # h^{'}_k for k in range(self.n_iter - 1): # eq. 18-21 # \tilde{h}_k + # import pudb; pudb.set_trace() h_k = torch.cat([h_apo_k, self.residual_layers[k](convert_to_NCHW(x[k + 1]))], axis=1) h_apo_k = self.branch_layers[k + 1](h_k) # h^{'}_k diff --git a/notebooks/multi_lens_array_2024-06-05_09-01-04.json b/notebooks/multi_lens_array_2024-06-05_09-01-04.json new file mode 100644 index 00000000..ae419220 --- /dev/null +++ b/notebooks/multi_lens_array_2024-06-05_09-01-04.json @@ -0,0 +1,115 @@ +{ + "seed": 4, + "n_lens": 15, + "radius_range": [ + 100.0, + 400.0 + ], + "min_separation": 5.0, + "focal_length": [ + 859.0, + 854.0, + 852.0, + 848.0, + 780.0, + 726.0, + 684.0, + 672.0, + 574.0, + 501.0, + 382.0, + 358.0, + 346.0, + 324.0, + 221.0 + ], + "radius": [ + 395.0, + 393.0, + 392.0, + 390.0, + 359.0, + 334.0, + 314.0, + 309.0, + 264.0, + 230.0, + 176.0, + 165.0, + 159.0, + 149.0, + 102.0 + ], + "min_height": 1000.0, + "loc": [ + [ + 2014.0, + 419.0 + ], + [ + 2989.0, + 1577.0 + ], + [ + 2970.0, + 2528.0 + ], + [ + 594.0, + 2024.0 + ], + [ + 1495.0, + 1828.0 + ], + [ + 640.0, + 782.0 + ], + [ + 2145.0, + 1467.0 + ], + [ + 2107.0, + 2520.0 + ], + [ + 796.0, + 1380.0 + ], + [ + 785.0, + 3022.0 + ], + [ + 1535.0, + 2790.0 + ], + [ + 1448.0, + 1215.0 + ], + [ + 940.0, + 2490.0 + ], + [ + 2607.0, + 305.0 + ], + [ + 3076.0, + 3331.0 + ] + ], + "mask_size": [ + 3500.0, + 3500.0 + ], + "frame_size": [ + 3500.0, + 3500.0 + ], + "refractive_index": 1.46 +} \ No newline at end of file diff --git a/scripts/recon/train_learning_based.py b/scripts/recon/train_learning_based.py index f0e394c0..c4ecb851 100644 --- a/scripts/recon/train_learning_based.py +++ b/scripts/recon/train_learning_based.py @@ -195,8 +195,12 @@ def train_learned(config): if config.files.n_files is not None: train_split = f"train[:{config.files.n_files}]" test_split = f"test[:{config.files.n_files}]" - train_dataset = load_dataset(config.files.dataset, split=train_split) - test_dataset = load_dataset(config.files.dataset, split=test_split) + train_dataset = load_dataset( + config.files.dataset, split=train_split, cache_dir=config.files.cache_dir + ) + test_dataset = load_dataset( + config.files.dataset, split=test_split, cache_dir=config.files.cache_dir + ) dataset = concatenate_datasets([test_dataset, train_dataset]) # - split into train and test @@ -208,6 +212,7 @@ def train_learned(config): train_set = HFDataset( huggingface_repo=config.files.dataset, + cache_dir=config.files.cache_dir, psf=config.files.huggingface_psf, single_channel_psf=config.files.single_channel_psf, split=split_train, @@ -226,6 +231,7 @@ def train_learned(config): ) test_set = HFDataset( huggingface_repo=config.files.dataset, + cache_dir=config.files.cache_dir, psf=config.files.huggingface_psf, single_channel_psf=config.files.single_channel_psf, split=split_test, @@ -488,7 +494,6 @@ def train_learned(config): psf_channels = 1 else: psf_channels = 3 - print(psf.shape) recon = MultiWiener( in_channels=3,