Skip to content

Commit

Permalink
Add residual block to compensation.
Browse files Browse the repository at this point in the history
  • Loading branch information
ebezzam committed Jun 27, 2024
1 parent 9afc864 commit e72aa11
Show file tree
Hide file tree
Showing 6 changed files with 238 additions and 13 deletions.
2 changes: 2 additions & 0 deletions configs/train_unrolledADMM.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ files:
# psf: data/psf/diffusercam_psf.tiff
# diffusercam_psf: True

cache_dir: null # where to read/write dataset. Defaults to `"~/.cache/huggingface/datasets"`.

# -- using huggingface dataset
dataset: bezzam/DiffuserCam-Lensless-Mirflickr-Dataset-NORM
huggingface_dataset: True
Expand Down
1 change: 0 additions & 1 deletion lensless/recon/drunet/basicblock.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,6 @@ def __init__(
)

def forward(self, x):
# res = self.res(x)
return x + self.res(x)


Expand Down
42 changes: 41 additions & 1 deletion lensless/recon/model_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,16 @@
# baseline benchmarks which don't have model file but use ADMM
"admm_fista": "bezzam/diffusercam-mirflickr-admm-fista",
"admm_pnp": "bezzam/diffusercam-mirflickr-admm-pnp",
# -- TCI submission
"TrainInv+Unet8M": "bezzam/diffusercam-mirflickr-trainable-inv-unet8M",
"Unet4M+U5+Unet4M": "bezzam/diffusercam-mirflickr-unet4M-unrolled-admm5-unet4M",
"MWDN8M": "bezzam/diffusercam-mirflickr-mwdn-8M",
"Unet2M+MWDN6M": "bezzam/diffusercam-mirflickr-unet2M-mwdn-6M",
"Unet4M+TrainInv+Unet4M": "bezzam/diffusercam-mirflickr-unet4M-trainable-inv-unet4M",
"MMCN4M+Unet4M": "bezzam/diffusercam-mirflickr-mmcn-unet4M",
"U5+Unet8M": "bezzam/diffusercam-mirflickr-unrolled-admm5-unet8M",
"Unet2M+MMCN+Unet2M": "bezzam/diffusercam-mirflickr-unet2M-mmcn-unet2M",
"Unet4M+U20+Unet4M": "bezzam/diffusercam-mirflickr-unet4M-unrolled-admm20-unet4M",
},
},
"digicam": {
Expand All @@ -70,6 +80,12 @@
# baseline benchmarks which don't have model file but use ADMM
"admm_measured_psf": "bezzam/digicam-celeba-admm-measured-psf",
"admm_simulated_psf": "bezzam/digicam-celeba-admm-simulated-psf",
# TCI submission (using waveprop simulation)
"U5+Unet8M_wave": "bezzam/digicam-celeba-unrolled-admm5-unet8M",
"TrainInv+Unet8M_wave": "bezzam/digicam-celeba-trainable-inv-unet8M_wave",
"MWDN8M_wave": "bezzam/digicam-celeba-mwnn-8M",
"MMCN4M+Unet4M_wave": "bezzam/digicam-celeba-mmcn-unet4M",
"Unet2M+MWDN6M_wave": "bezzam/digicam-celeba-unet2M-mwdn-6M",
},
"mirflickr_single_25k": {
# simulated PSF (without waveprop, with deadspace)
Expand All @@ -86,9 +102,12 @@
"Unet4M+U10+Unet4M_wave": "bezzam/digicam-mirflickr-single-25k-unet4M-unrolled-admm10-unet4M-wave",
"TrainInv+Unet8M_wave": "bezzam/digicam-mirflickr-single-25k-trainable-inv-unet8M-wave",
"U5+Unet8M_wave": "bezzam/digicam-mirflickr-single-25k-unrolled-admm5-unet8M-wave",
"Unet4M+U5+Unet4M_wave": "bezzam/digicam-mirflickr-single-25k-unet4M-unrolled-admm5-unet4M-wave",
"MWDN8M_wave": "bezzam/digicam-mirflickr-single-25k-mwdn-8M",
"MMCN4M+Unet4M_wave": "bezzam/digicam-mirflickr-single-25k-mmcn-unet4M",
"Unet2M+MMCN+Unet2M_wave": "bezzam/digicam-mirflickr-single-25k-unet2M-mmcn-unet2M-wave",
"Unet4M+TrainInv+Unet4M_wave": "bezzam/digicam-mirflickr-single-25k-unet4M-trainable-inv-unet4M-wave",
"Unet2M+MWDN6M_wave": "bezzam/digicam-mirflickr-single-25k-unet2M-mwdn-6M",
# measured PSF
"Unet4M+U10+Unet4M_measured": "bezzam/digicam-mirflickr-single-25k-unet4M-unrolled-admm10-unet4M-measured",
# simulated PSF (with waveprop, no deadspace)
Expand Down Expand Up @@ -270,11 +289,25 @@ def load_model(
skip_post=skip_post,
)
elif config["reconstruction"]["method"] == "multi_wiener":

if config["files"].get("single_channel_psf", False):

if torch.sum(psf[..., 0] - psf[..., 1]) != 0:
# need to sum difference channels
raise ValueError("PSF channels are not the same")
# psf = np.sum(psf, axis=3)

else:
psf = psf[..., 0].unsqueeze(-1)
psf_channels = 1
else:
psf_channels = 3

recon = MultiWiener(
in_channels=3,
out_channels=3,
psf=psf,
psf_channels=3,
psf_channels=psf_channels,
nc=config["reconstruction"]["multi_wiener"]["nc"],
pre_process=pre_process,
)
Expand All @@ -287,6 +320,13 @@ def load_model(
if "device_ids" in config.keys() and config["device_ids"] is not None:
model_state_dict = remove_data_parallel(model_state_dict)

# hotfixes for loading models
if config["reconstruction"]["method"] == "multi_wiener":
# replace "avgpool_conv" with "pool_conv"
model_state_dict = {
k.replace("avgpool_conv", "pool_conv"): v for k, v in model_state_dict.items()
}

recon.load_state_dict(model_state_dict)

return recon
80 changes: 72 additions & 8 deletions lensless/recon/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,13 @@
from lensless.utils.dataset import SimulatedDatasetTrainableMask


def double_cnn_max_pool(c_in, c_out, cnn_kernel=3, max_pool=2):
def double_cnn_max_pool(c_in, c_out, cnn_kernel=3, max_pool=2, padding=1, stride=1):
return nn.Sequential(
nn.Conv2d(
in_channels=c_in,
out_channels=c_out,
kernel_size=cnn_kernel,
padding="same",
padding=padding,
bias=False,
),
nn.BatchNorm2d(c_out),
Expand All @@ -39,21 +39,52 @@ def double_cnn_max_pool(c_in, c_out, cnn_kernel=3, max_pool=2):
in_channels=c_out,
out_channels=c_out,
kernel_size=cnn_kernel,
padding="same",
padding=padding,
bias=False,
),
nn.BatchNorm2d(c_out),
nn.ReLU(),
nn.MaxPool2d(kernel_size=max_pool),
# don't pass stride=1, otherwise no pooling/downsampling..
nn.MaxPool2d(kernel_size=max_pool, padding=0) if max_pool else nn.Identity(),
)


class ResBlock(nn.Module):
def __init__(self, c_in, c_out, cnn_kernel=3, max_pool=2, padding=1, stride=1):
super(ResBlock, self).__init__()
# assert c_in == c_out, "Input and output channels must be the same for residual block."

# conv layers for residual need to be the same size
self.double_conv = double_cnn_max_pool(
c_in, c_in, cnn_kernel=cnn_kernel, max_pool=False, padding=padding, stride=stride
)

# pooling
self.pooling = nn.Sequential(
nn.Conv2d(
in_channels=c_in,
out_channels=c_out,
kernel_size=cnn_kernel,
padding=padding,
bias=False,
),
nn.BatchNorm2d(c_out),
nn.ReLU(),
nn.MaxPool2d(kernel_size=max_pool, padding=0),
)

def forward(self, x):
return self.pooling(x + self.double_conv(x))


class CompensationBranch(nn.Module):
"""
Compensation branch for unrolled algorithm, as in "Robust Reconstruction With Deep Learning to Handle Model Mismatch in Lensless Imaging" (2021).
"""

def __init__(self, nc, cnn_kernel=3, max_pool=2, in_channel=3):
def __init__(
self, nc, cnn_kernel=3, max_pool=2, in_channel=3, residual=True, padding=1, stride=1
):
"""
Parameters
Expand All @@ -66,14 +97,23 @@ def __init__(self, nc, cnn_kernel=3, max_pool=2, in_channel=3):
Kernel size for max pooling layers, by default 2.
in_channel : int, optional
Number of input channels, by default 3 for RGB.
residual : bool, optional
Whether to use residual block or simply double conv for intermediate layers, by default True.
"""
super(CompensationBranch, self).__init__()

self.n_iter = len(nc)

# layers along the compensation branch, f^C in paper
branch_layers = [
double_cnn_max_pool(in_channel, nc[0], cnn_kernel=cnn_kernel, max_pool=max_pool)
double_cnn_max_pool(
in_channel,
nc[0],
cnn_kernel=cnn_kernel,
max_pool=max_pool,
padding=padding,
stride=stride,
)
]
self.branch_layers = nn.ModuleList(
branch_layers
Expand All @@ -83,6 +123,8 @@ def __init__(self, nc, cnn_kernel=3, max_pool=2, in_channel=3):
nc[i + 1],
cnn_kernel=cnn_kernel,
max_pool=max_pool,
padding=padding,
stride=stride,
)
for i in range(self.n_iter - 1)
]
Expand All @@ -92,8 +134,29 @@ def __init__(self, nc, cnn_kernel=3, max_pool=2, in_channel=3):
# -- not mentinoed in paper, but added more max-pooling for later residual layers, otherwise dimensions don't match
self.residual_layers = nn.ModuleList(
[
double_cnn_max_pool(
in_channel, nc[i], cnn_kernel=cnn_kernel, max_pool=max_pool ** (i + 1)
# double_cnn_max_pool(
# in_channel, nc[i], cnn_kernel=cnn_kernel, max_pool=max_pool ** (i + 1)
# )
# B.sequential(
# B.ResBlock(in_channel, in_channel, bias=False, mode="CRC", padding=padding, stride=stride),
# B.downsample_maxpool(in_channel, nc[i], bias=False, mode=str(max_pool ** (i + 1)), padding=padding, stride=stride)
# ) if residual
ResBlock(
in_channel,
nc[i],
cnn_kernel=cnn_kernel,
max_pool=max_pool ** (i + 1),
padding=padding,
stride=stride,
)
if residual
else double_cnn_max_pool(
in_channel,
nc[i],
cnn_kernel=cnn_kernel,
max_pool=max_pool ** (i + 1),
padding=padding,
stride=stride,
)
for i in range(self.n_iter - 1)
]
Expand All @@ -110,6 +173,7 @@ def forward(self, x, return_NCHW=True):
h_apo_k = self.branch_layers[0](convert_to_NCHW(x[0])) # h^{'}_k
for k in range(self.n_iter - 1): # eq. 18-21
# \tilde{h}_k
# import pudb; pudb.set_trace()
h_k = torch.cat([h_apo_k, self.residual_layers[k](convert_to_NCHW(x[k + 1]))], axis=1)
h_apo_k = self.branch_layers[k + 1](h_k) # h^{'}_k

Expand Down
115 changes: 115 additions & 0 deletions notebooks/multi_lens_array_2024-06-05_09-01-04.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
{
"seed": 4,
"n_lens": 15,
"radius_range": [
100.0,
400.0
],
"min_separation": 5.0,
"focal_length": [
859.0,
854.0,
852.0,
848.0,
780.0,
726.0,
684.0,
672.0,
574.0,
501.0,
382.0,
358.0,
346.0,
324.0,
221.0
],
"radius": [
395.0,
393.0,
392.0,
390.0,
359.0,
334.0,
314.0,
309.0,
264.0,
230.0,
176.0,
165.0,
159.0,
149.0,
102.0
],
"min_height": 1000.0,
"loc": [
[
2014.0,
419.0
],
[
2989.0,
1577.0
],
[
2970.0,
2528.0
],
[
594.0,
2024.0
],
[
1495.0,
1828.0
],
[
640.0,
782.0
],
[
2145.0,
1467.0
],
[
2107.0,
2520.0
],
[
796.0,
1380.0
],
[
785.0,
3022.0
],
[
1535.0,
2790.0
],
[
1448.0,
1215.0
],
[
940.0,
2490.0
],
[
2607.0,
305.0
],
[
3076.0,
3331.0
]
],
"mask_size": [
3500.0,
3500.0
],
"frame_size": [
3500.0,
3500.0
],
"refractive_index": 1.46
}
Loading

0 comments on commit e72aa11

Please sign in to comment.