Add compensation branch option.
ebezzam committed Jun 9, 2024
1 parent 2685ec3 commit 515f62f
Showing 11 changed files with 228 additions and 39 deletions.
1 change: 1 addition & 0 deletions configs/benchmark.yaml
@@ -13,6 +13,7 @@ batchsize: 1 # must be 1 for iterative approaches

huggingface:
repo: "bezzam/DigiCam-Mirflickr-MultiMask-25K"
cache_dir: null # where to read/write the dataset. Defaults to `"~/.cache/huggingface/datasets"`.
psf: null # null for simulating PSF
image_res: [900, 1200] # used during measurement
rotate: True # if measurement is upside-down
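For context, a minimal sketch of what the new `cache_dir` option maps to if the dataset is fetched with the Hugging Face `datasets` library (an assumption about the loading path; the repository name is the one listed in the config above):

from datasets import load_dataset

# cache_dir mirrors the new config option; None falls back to the default
# "~/.cache/huggingface/datasets".
dataset = load_dataset(
    "bezzam/DigiCam-Mirflickr-MultiMask-25K",
    cache_dir=None,  # or e.g. "/scratch/hf_datasets" to redirect the cache
)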
3 changes: 3 additions & 0 deletions configs/train_unrolledADMM.yaml
@@ -102,6 +102,9 @@ reconstruction:
freeze: null
unfreeze: null
train_last_layer: False
# number of channels for each compensation layer; the list length should equal the number of unrolled iterations (n_iter)
# and the last element should equal the last element of post_process.nc
compensation: null

#Trainable Mask
trainable_mask:
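To make the constraint on the new `compensation` option concrete, here is a hypothetical setting (values are illustrative, not repository defaults) checked in plain Python:

# Hypothetical values; the compensation list set in the config above must have
# one entry per unrolled iteration, and its last entry must match the last
# element of post_process.nc.
n_iter = 10
post_process_nc = [32, 64, 116, 128]
compensation = [24, 64, 128, 256, 400, 400, 400, 400, 400, 128]

assert len(compensation) == n_iter
assert compensation[-1] == post_process_nc[-1]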
2 changes: 0 additions & 2 deletions lensless/recon/drunet/basicblock.py
@@ -6,9 +6,7 @@
# #############################################################################

from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F


"""
42 changes: 34 additions & 8 deletions lensless/recon/drunet/network_unet.py
@@ -5,10 +5,11 @@
# https://github.com/cszn/DPIR/blob/15bca3fcc1f3cc51a1f99ccf027691e278c19354/models/network_unet.py
# #############################################################################


import torch
import torch.nn as nn
import lensless.recon.drunet.basicblock as B
import numpy as np
from torchvision.transforms.functional import resize

"""
# ====================
@@ -109,11 +110,13 @@ def __init__(
act_mode="R",
downsample_mode="strideconv",
upsample_mode="convtranspose",
concatenate_compensation=False,
):
super(UNetRes, self).__init__()

assert len(nc) == 4, "nc's length should be 4."

self.nc = nc
self.m_head = B.conv(in_nc, nc[0], bias=False, mode="C")

# downsample
@@ -139,9 +142,21 @@ def __init__(
downsample_block(nc[2], nc[3], bias=False, mode="2")
)

self.m_body = B.sequential(
*[B.ResBlock(nc[3], nc[3], bias=False, mode="C" + act_mode + "C") for _ in range(nb)]
)
self.concatenate_compensation = concatenate_compensation
if concatenate_compensation:
self.m_body = B.sequential(
*[
B.ResBlock(nc[3] * 2, nc[3] * 2, bias=False, mode="C" + act_mode + "C")
for _ in range(nb)
]
)
else:
self.m_body = B.sequential(
*[
B.ResBlock(nc[3], nc[3], bias=False, mode="C" + act_mode + "C")
for _ in range(nb)
]
)

# upsample
if upsample_mode == "upconv":
@@ -154,7 +169,9 @@ def __init__(
raise NotImplementedError("upsample mode [{:s}] is not found".format(upsample_mode))

self.m_up3 = B.sequential(
upsample_block(nc[3], nc[2], bias=False, mode="2"),
upsample_block(
nc[3] * 2 if concatenate_compensation else nc[3], nc[2], bias=False, mode="2"
),
*[B.ResBlock(nc[2], nc[2], bias=False, mode="C" + act_mode + "C") for _ in range(nb)]
)
self.m_up2 = B.sequential(
@@ -168,13 +185,22 @@

self.m_tail = B.conv(nc[0], out_nc, bias=False, mode="C")

def forward(self, x0):
def forward(self, x0, compensation_output=None):
if self.concatenate_compensation:
assert compensation_output is not None, "compensation_output should not be None."
x1 = self.m_head(x0)
x2 = self.m_down1(x1)
x3 = self.m_down2(x2)
x4 = self.m_down3(x3)
x = self.m_body(x4)
x = self.m_up3(x + x4)

if compensation_output is not None:
compensation_output_re = resize(compensation_output, tuple(x4.shape[-2:]))
latent = torch.cat([x4, compensation_output_re], dim=1)
else:
latent = x4

x = self.m_body(latent)
x = self.m_up3(x + latent)
x = self.m_up2(x + x3)
x = self.m_up1(x + x2)
x = self.m_tail(x + x1)
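A minimal usage sketch of the new concatenation path, assuming the constructor keeps the DPIR-style arguments (in_nc, out_nc, nc, nb) and that the compensation feature map carries nc[3] channels (its spatial size is resized internally to match the bottleneck):

import torch
from lensless.recon.drunet.network_unet import UNetRes  # module changed in this commit

net = UNetRes(in_nc=3, out_nc=3, nc=[64, 128, 256, 512], nb=2, concatenate_compensation=True)
x = torch.rand(1, 3, 128, 128)      # height/width divisible by 8 so the skip connections line up
comp = torch.rand(1, 512, 16, 16)   # nc[3] channels; resized to the bottleneck's spatial size
out = net(x, compensation_output=comp)
print(out.shape)                    # expected: torch.Size([1, 3, 128, 128])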
3 changes: 3 additions & 0 deletions lensless/recon/model_dict.py
@@ -93,8 +93,11 @@
"Unet4M+U10+Unet4M_nodead": "bezzam/digicam-mirflickr-single-25k-unet4M-unrolled-admm10-unet4M-nodead",
},
"mirflickr_multi_25k": {
# simulated PSFs (without waveprop, with deadspace)
"Unet8M": "bezzam/digicam-mirflickr-multi-25k-unet8M",
"Unet4M+U10+Unet4M": "bezzam/digicam-mirflickr-multi-25k-unet4M-unrolled-admm10-unet4M",
# simulated PSF (with waveprop, with deadspace)
"Unet4M+U10+Unet4M_wave": "bezzam/digicam-mirflickr-multi-25k-unet4M-unrolled-admm10-unet4M-wave",
},
},
}
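The new registry entries point to Hugging Face model repositories; a minimal sketch of fetching one of them with huggingface_hub (loading the weights into a reconstruction model is left to the repository's own helpers in lensless.recon.model_dict):

from huggingface_hub import snapshot_download

# repository name taken from the "Unet4M+U10+Unet4M_wave" entry added above
model_path = snapshot_download(
    repo_id="bezzam/digicam-mirflickr-multi-25k-unet4M-unrolled-admm10-unet4M-wave"
)
print(model_path)  # local cache directory containing the checkpoint files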
66 changes: 46 additions & 20 deletions lensless/recon/trainable_recon.py
@@ -54,6 +54,7 @@ def __init__(
skip_unrolled=False,
return_unrolled_output=False,
legacy_denoiser=False,
compensation=None,
**kwargs,
):
"""
Expand All @@ -63,25 +64,28 @@ def __init__(
Parameters
----------
psf : :py:class:`~torch.Tensor`
Point spread function (PSF) that models forward propagation.
Must be of shape (depth, height, width, channels) even if
depth = 1 and channels = 1. You can use :py:func:`~lensless.io.load_psf`
to load a PSF from a file such that it is in the correct format.
dtype : float32 or float64
Data type to use for optimization.
n_iter : int
Number of iterations for unrolled algorithm.
pre_process : :py:class:`function` or :py:class:`~torch.nn.Module`, optional
If :py:class:`function` : Function to apply to the image estimate before the algorithm. Its input must be (image to process, noise_level), where noise_level is a learnable parameter. If it includes additional learnable parameters, they will not be added to the parameter list of the algorithm. To allow for training, the function must be autograd-compatible.
If :py:class:`~torch.nn.Module` : A DruNet compatible network to apply to the image estimate before algorithm. See ``utils.image.apply_denoiser`` for more details. The network will be included as a submodule of the algorithm and its parameters will be added to the parameter list of the algorithm. If this isn't intended behavior, set requires_grad=False.
post_process : :py:class:`function` or :py:class:`~torch.nn.Module`, optional
If :py:class:`function` : Function to apply to the image estimate after the whole algorithm. Its input must be (image to process, noise_level), where noise_level is a learnable parameter. If it includes additional learnable parameters, they will not be added to the parameter list of the algorithm. To allow for training, the function must be autograd-compatible.
If :py:class:`~torch.nn.Module` : A DruNet compatible network to apply to the image estimate after the whole algorithm. See ``utils.image.apply_denoiser`` for more details. The network will be included as a submodule of the algorithm and its parameters will be added to the parameter list of the algorithm. If this isn't intended behavior, set requires_grad=False.
skip_unrolled : bool, optional
Whether to skip the unrolled algorithm and only apply the pre- or post-processor block (e.g. to just use a U-Net for reconstruction).
return_unrolled_output : bool, optional
Whether to return the output of the unrolled algorithm if also using a post-processor block.
compensation : list, optional
Number of channels for each intermediate output of the compensation branch, as in "Robust Reconstruction With Deep Learning to Handle Model Mismatch in Lensless Imaging" (2021).
A post-processor must be defined if compensation is provided.
"""
assert isinstance(psf, torch.Tensor), "PSF must be a torch.Tensor"
super(TrainableReconstructionAlgorithm, self).__init__(
@@ -93,6 +97,18 @@ def __init__(
self.set_post_process(post_process)
self.skip_unrolled = skip_unrolled
self.return_unrolled_output = return_unrolled_output
self.compensation_branch = compensation
if compensation is not None:
from lensless.recon.utils import CompensationBranch

assert (
post_process is not None
), "If compensation_branch is True, post_process must be defined."
assert (
len(compensation) == n_iter
), "compensation_nc must have the same length as n_iter"
self.compensation_branch = CompensationBranch(compensation)

if self.return_unrolled_output:
assert (
post_process is not None
@@ -231,15 +247,25 @@ def forward(self, batch, psfs=None):

# unrolled algorithm
if not self.skip_unrolled:
if self.compensation_branch is not None:
compensation_branch_inputs = [self._data]

for i in range(self._n_iter):
self._update(i)
if self.compensation_branch is not None and i < self._n_iter - 1:
compensation_branch_inputs.append(self._form_image())

image_est = self._form_image()
else:
image_est = self._data

# post process data
if self.post_process is not None:
final_est = self.post_process(image_est, self.post_process_param)
compensation_output = None
if self.compensation_branch is not None:
compensation_output = self.compensation_branch(compensation_branch_inputs)

final_est = self.post_process(image_est, self.post_process_param, compensation_output)
else:
final_est = image_est

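For intuition about what the compensation branch computes, below is a minimal, self-contained sketch of the idea from the cited 2021 paper: each intermediate estimate of the unrolled loop is encoded and fused into a running feature map whose final channel count matches the post-processor bottleneck. This is an illustrative stand-in, not the repository's CompensationBranch implementation, and it assumes (B, C, H, W) inputs:

import torch
import torch.nn as nn
from torchvision.transforms.functional import resize


class CompensationBranchSketch(nn.Module):
    def __init__(self, nc, in_channels=3):
        super().__init__()
        # one encoder per intermediate estimate
        self.encoders = nn.ModuleList(
            [nn.Conv2d(in_channels, c, kernel_size=3, padding=1) for c in nc]
        )
        # fuse the running feature map with each newly encoded estimate, downsampling as we go
        self.fusers = nn.ModuleList(
            [
                nn.Conv2d(nc[i - 1] + nc[i], nc[i], kernel_size=3, padding=1, stride=2)
                for i in range(1, len(nc))
            ]
        )

    def forward(self, estimates):
        # estimates: list of image estimates from the unrolled iterations, each (B, C, H, W)
        feat = self.encoders[0](estimates[0])
        for i, est in enumerate(estimates[1:], start=1):
            enc = self.encoders[i](resize(est, list(feat.shape[-2:])))
            feat = self.fusers[i - 1](torch.cat([enc, feat], dim=1))
        return feat  # (B, nc[-1], H', W'), to be concatenated at the post-processor bottleneck


branch = CompensationBranchSketch(nc=[24, 64, 128], in_channels=3)
estimates = [torch.rand(1, 3, 240, 320) for _ in range(3)]
print(branch(estimates).shape)  # torch.Size([1, 128, 60, 80])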
2 changes: 1 addition & 1 deletion lensless/recon/unrolled_admm.py
@@ -235,5 +235,5 @@ def _update(self, iter):

def _form_image(self):
image = self._convolver._crop(self._image_est)
image[image < 0] = 0
image = torch.clamp(image, min=0)
return image
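The clamp change replaces an in-place masked assignment with an out-of-place projection onto the non-negative values; a small illustration of the behavior on a tensor that requires gradients:

import torch

x = torch.tensor([-0.5, 0.3, 1.2], requires_grad=True)
y = torch.clamp(x, min=0)   # out-of-place: returns a new tensor, x is left untouched
y.sum().backward()

print(y)       # tensor([0.0000, 0.3000, 1.2000], grad_fn=...)
print(x.grad)  # tensor([0., 1., 1.]); gradient flows only through the unclamped entries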