Skip to content

Commit

Permalink
Add preproc support for multi-wiener, add single psf option for bench…
Browse files Browse the repository at this point in the history
…marking.
  • Loading branch information
ebezzam committed Jun 18, 2024
1 parent b1ad3a7 commit 26e860d
Show file tree
Hide file tree
Showing 7 changed files with 94 additions and 7 deletions.
3 changes: 3 additions & 0 deletions configs/benchmark.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ huggingface:
downsample: 1
downsample_lensed: 1
split_seed: null
single_channel_psf: False

device: "cuda"
# numbers of iterations to benchmark
Expand All @@ -40,6 +41,8 @@ algorithms: ["ADMM", "ADMM_Monakhova2019", "FISTA"] #["ADMM", "ADMM_Monakhova201
baseline: "MONAKHOVA 100iter"

save_idx: [0, 1, 2, 3, 4] # provide index of files to save e.g. [1, 5, 10]
gamma_psf: 1.5 # gamma factor for PSF


# Hyperparameters
nesterov:
Expand Down
9 changes: 7 additions & 2 deletions lensless/recon/model_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,14 @@
"Unet4M+U10+Unet4M": "bezzam/digicam-mirflickr-single-25k-unet4M-unrolled-admm10-unet4M",
# simulated PSF (with waveprop, with deadspace)
"U10_wave": "bezzam/digicam-mirflickr-single-25k-unrolled-admm10-wave",
"U10+Unet8M_wave": "bezzam/digicam-mirflickr-single-25k-unrolled-admm10-unet8M-wave",
"Unet8M_wave": "bezzam/digicam-mirflickr-single-25k-unet8M-wave",
"Unet4M+U10+Unet4M_wave": "bezzam/digicam-mirflickr-single-25k-unet4M-unrolled-admm10-unet4M-wave",
"TrainInv+Unet8M_wave": "bezzam/digicam-mirflickr-single-25k-trainable-inv-unet8M-wave",
"U10+Unet8M_wave": "bezzam/digicam-mirflickr-single-25k-unrolled-admm10-unet8M-wave",
"U5+Unet8M_wave": "bezzam/digicam-mirflickr-single-25k-unrolled-admm5-unet8M-wave",
"MWDN8M_wave": "bezzam/digicam-mirflickr-single-25k-mwdn-8M",
"MMCN4M+Unet4M_wave": "bezzam/digicam-mirflickr-single-25k-mmcn-unet4M",
"Unet4M+TrainInv+Unet4M_wave": "bezzam/digicam-mirflickr-single-25k-unet4M-trainable-inv-unet4M-wave",
"Unet4M+U10+Unet4M_wave": "bezzam/digicam-mirflickr-single-25k-unet4M-unrolled-admm10-unet4M-wave",
# measured PSF
"Unet4M+U10+Unet4M_measured": "bezzam/digicam-mirflickr-single-25k-unet4M-unrolled-admm10-unet4M-measured",
# simulated PSF (with waveprop, no deadspace)
Expand All @@ -110,6 +113,7 @@
"Unet4M+TrainInv+Unet4M": "bezzam/tapecam-mirflickr-unet4M-trainable-inv-unet4M",
"Unet4M+U5+Unet4M": "bezzam/tapecam-mirflickr-unet4M-unrolled-admm5-unet4M",
"Unet2M+MMCN+Unet2M": "bezzam/tapecam-mirflickr-unet2M-mmcn-unet2M",
"Unet2M+MWDN6M": "bezzam/tapecam-mirflickr-unet2M-mwdn-6M",
},
},
}
Expand Down Expand Up @@ -272,6 +276,7 @@ def load_model(
psf=psf,
psf_channels=3,
nc=config["reconstruction"]["multi_wiener"]["nc"],
pre_process=pre_process,
)
recon.to(device)

Expand Down
57 changes: 56 additions & 1 deletion lensless/recon/multi_wiener.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,16 @@ def WieNer(blur, psf, delta):


class MultiWiener(nn.Module):
def __init__(self, in_channels, out_channels, psf, psf_channels=1, nc=None):
def __init__(
self,
in_channels,
out_channels,
psf,
psf_channels=1,
nc=None,
pre_process=None,
skip_pre=False,
):
"""
Parameters
----------
Expand Down Expand Up @@ -165,6 +174,46 @@ def __init__(self, in_channels, out_channels, psf, psf_channels=1, nc=None):
self._n_iter = 1
self._convolver = RealFFTConvolve2D(psf, pad=True)

self.set_pre_process(pre_process)
self.skip_pre = skip_pre

def _prepare_process_block(self, process):
"""
Method for preparing the pre or post process block.
Parameters
----------
process : :py:class:`function` or :py:class:`~torch.nn.Module`, optional
Pre or post process block to prepare.
"""
if isinstance(process, torch.nn.Module):
# If the post_process is a torch module, we assume it is a DruNet like network.
from lensless.recon.utils import get_drunet_function_v2

process_model = process
process_function = get_drunet_function_v2(process_model, mode="train")
elif process is not None:
# Otherwise, we assume it is a function.
assert callable(process), "pre_process must be a callable function"
process_function = process
process_model = None
else:
process_function = None
process_model = None

if process_function is not None:
process_param = torch.nn.Parameter(torch.tensor([1.0], device=self._psf.device))
else:
process_param = None

return process_function, process_model, process_param

def set_pre_process(self, pre_process):
(
self.pre_process,
self.pre_process_model,
self.pre_process_param,
) = self._prepare_process_block(pre_process)

def forward(self, batch, psfs=None):

if psfs is None:
Expand All @@ -178,6 +227,12 @@ def forward(self, batch, psfs=None):
if n_depth > 1:
raise NotImplementedError("3D not implemented yet.")

# pre process data
if self.pre_process is not None and not self.skip_pre:
device_before = batch.device
batch = self.pre_process(batch, self.pre_process_param)
batch = batch.to(device_before)

# pad to multiple of 8
batch = convert_to_NCHW(batch)
batch = torch.nn.functional.pad(
Expand Down
2 changes: 2 additions & 0 deletions lensless/recon/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -740,6 +740,7 @@ def train_epoch(self, data_loader):
mean_loss = 0.0
i = 1.0
pbar = tqdm(data_loader)
self.recon.train()
for batch in pbar:

# get batch
Expand Down Expand Up @@ -957,6 +958,7 @@ def evaluate(self, mean_loss, epoch, disp=None):
output_dir = os.path.join(output_dir, str(epoch))

# benchmarking
self.recon.eval()
current_metrics = benchmark(
self.recon,
self.test_dataset,
Expand Down
3 changes: 3 additions & 0 deletions lensless/utils/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1164,6 +1164,9 @@ def __init__(
single_psf=single_channel_psf,
)
self.psf = torch.from_numpy(psf)
if single_channel_psf:
# replicate across three channels
self.psf = self.psf.repeat(1, 1, 1, 3)

elif "mask_label" in data_0:
self.multimask = True
Expand Down
26 changes: 22 additions & 4 deletions scripts/eval/benchmark_recon.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from lensless import ADMM, FISTA, GradientDescent, NesterovGradientDescent
from lensless.utils.dataset import DiffuserCamTestDataset, DigiCamCelebA, HFDataset
from lensless.utils.io import save_image
from lensless.utils.image import gamma_correction
from lensless.recon.model_dict import download_model, load_model

import torch
Expand Down Expand Up @@ -129,6 +130,7 @@ def benchmark_recon(config):
downsample_lensed=config.huggingface.downsample_lensed,
alignment=config.huggingface.alignment,
simulation_config=config.simulation,
single_channel_psf=config.huggingface.single_channel_psf,
)
if benchmark_dataset.multimask:
# get first PSF for initialization
Expand Down Expand Up @@ -199,6 +201,9 @@ def benchmark_recon(config):
)
if "hf" in algo:
param = algo.split(":")
assert (
len(param) == 4
), "hf model requires following format: hf:camera:dataset:model_name"
camera = param[1]
dataset = param[2]
model_name = param[3]
Expand All @@ -211,23 +216,32 @@ def benchmark_recon(config):
skip_post = False

model_path = download_model(camera=camera, dataset=dataset, model=model_name)
model_list.append(
(algo, load_model(model_path, psf, device, skip_pre=skip_pre, skip_post=skip_post))
)
model = load_model(model_path, psf, device, skip_pre=skip_pre, skip_post=skip_post)
model.eval()
model_list.append((algo, model))

results = {}
output_dir = None

# save PSF
psf_np = psf.cpu().numpy()[0]
psf_np = psf_np / np.max(psf_np)
psf_np = gamma_correction(psf_np, gamma=config.gamma_psf)
save_image(psf_np, fp="psf.png")

# save ground truth and lensless images
if config.save_idx is not None:

assert np.max(config.save_idx) < len(
benchmark_dataset
), "save_idx values must be smaller than dataset size"

os.mkdir("GROUND_TRUTH")
os.mkdir("LENSLESS")
for idx in config.save_idx:
ground_truth = benchmark_dataset[idx][1]
lensless, ground_truth = benchmark_dataset[idx]
ground_truth_np = ground_truth.cpu().numpy()[0]
lensless_np = lensless.cpu().numpy()[0]

if crop is not None:
ground_truth_np = ground_truth_np[
Expand All @@ -239,6 +253,10 @@ def benchmark_recon(config):
ground_truth_np,
fp=os.path.join("GROUND_TRUTH", f"{idx}.png"),
)
save_image(
lensless_np,
fp=os.path.join("LENSLESS", f"{idx}.png"),
)
# benchmark each model for different number of iteration and append result to results
# -- batchsize has to equal 1 as baseline models don't support batch processing
start_time = time.time()
Expand Down
1 change: 1 addition & 0 deletions scripts/recon/train_learning_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,6 +484,7 @@ def train_learned(config):
psf=psf,
psf_channels=3,
nc=config.reconstruction.multi_wiener.nc,
pre_process=pre_process if pre_proc_delay is None else None,
)

else:
Expand Down

0 comments on commit 26e860d

Please sign in to comment.