Skip to content

Commit

Permalink
Add support for preproc aux and initializing from HF model.
Browse files Browse the repository at this point in the history
  • Loading branch information
ebezzam committed Jun 14, 2024
1 parent ef6e9c0 commit b1ad3a7
Show file tree
Hide file tree
Showing 10 changed files with 182 additions and 54 deletions.
12 changes: 9 additions & 3 deletions configs/train_unrolledADMM.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ files:
dataset: bezzam/DiffuserCam-Lensless-Mirflickr-Dataset-NORM
huggingface_dataset: True
huggingface_psf: psf.tiff
single_channel_psf: False # whether to sum all PSF channels into one

# -- train/test split
split_seed: null # if null use train/test split from dataset
Expand Down Expand Up @@ -67,7 +68,11 @@ reconstruction:
# Method: unrolled_admm, unrolled_fista, trainable_inv
method: unrolled_admm
skip_unrolled: False
init_processors: null # model name

# initialize with "init_processors"
# -- for HuggingFace model use "hf:camera:dataset:model_name"
# -- for local model use "local:model_path"
init_processors: null
init_pre: True # if `init_processors`, set pre-procesor is available
init_post: True # if `init_processors`, set post-procesor is available

Expand Down Expand Up @@ -149,7 +154,6 @@ simulation:
max_val: 255

#Training

training:
batch_size: 8
epoch: 25
Expand All @@ -174,4 +178,6 @@ loss: 'l2'
# set lpips to false to deactivate. Otherwise, give the weigth for the loss (the main loss l2/l1 always having a weigth of 1)
lpips: 1.0
unrolled_output_factor: False # whether to account for unrolled output in loss (there must post-processor)
pre_proc_aux: False # factor for auxiliary pre-processor loss to promote measurement consistency -> ||pre_proc(y) - Ax||
# factor for auxiliary pre-processor loss to promote measurement consistency -> ||pre_proc(y) - A * camera_inversion(y)||
# -- use camera inversion output so that doesn't include enhancements / coloring by post-processor
pre_proc_aux: False
36 changes: 30 additions & 6 deletions lensless/eval/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def benchmark(
save_idx=None,
output_dir=None,
unrolled_output_factor=False,
pre_process_aux=False,
return_average=True,
snr=None,
use_wandb=False,
Expand Down Expand Up @@ -106,6 +107,8 @@ def benchmark(
for key in output_metrics:
if key != "ReconstructionError":
metrics_values[key + "_unrolled"] = []
if pre_process_aux:
metrics_values["ReconstructionError_PreProc"] = []

# loop over batches
dataloader = DataLoader(dataset, batch_size=batchsize, pin_memory=(device != "cpu"))
Expand Down Expand Up @@ -141,14 +144,18 @@ def benchmark(
model._set_psf(psfs[0])
model.set_data(lensless)
prediction = model.apply(
plot=False, save=False, output_intermediate=unrolled_output_factor, **kwargs
plot=False,
save=False,
output_intermediate=unrolled_output_factor or pre_process_aux,
**kwargs,
)

else:
prediction = model.forward(lensless, psfs, **kwargs)

if unrolled_output_factor:
unrolled_out = prediction[-1]
if unrolled_output_factor or pre_process_aux:
pre_process_out = prediction[2]
unrolled_out = prediction[1]
prediction = prediction[0]
prediction_original = prediction.clone()

Expand Down Expand Up @@ -245,7 +252,17 @@ def benchmark(
unrolled_out = unrolled_out.reshape(-1, *unrolled_out.shape[-3:]).movedim(-1, -3)

# -- extraction region of interest
if crop is not None:
if hasattr(dataset, "alignment"):
if dataset.alignment is not None:
unrolled_out = dataset.extract_roi(unrolled_out, axis=(-2, -1))
else:
unrolled_out = dataset.extract_roi(
unrolled_out,
axis=(-2, -1),
# lensed=lensed # lensed already extracted before
)
assert np.all(lensed.shape == unrolled_out.shape)
elif crop is not None:
unrolled_out = unrolled_out[
...,
crop["vertical"][0] : crop["vertical"][1],
Expand Down Expand Up @@ -288,13 +305,20 @@ def benchmark(
else:
metrics_values[metric + "_unrolled"].append(vals.item())

# compute metrics for pre-processed output
if pre_process_aux:
metrics_values["ReconstructionError_PreProc"] += model.reconstruction_error(
prediction=prediction_original, lensless=pre_process_out
).tolist()

model.reset()
idx += batchsize

# average metrics
if return_average:
for metric in metrics:
if "MSE" in metric or "ReconstructionError" in metric or "LPIPS" in metric:
for metric in metrics_values.keys():
if "MSE" in metric or "LPIPS" in metric:
# differently because metrics are grouped into bathces
metrics_values[metric] = np.sum(metrics_values[metric]) / len(dataset)
else:
metrics_values[metric] = np.mean(metrics_values[metric])
Expand Down
26 changes: 26 additions & 0 deletions lensless/recon/model_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from lensless.recon.trainable_inversion import TrainableInversion
from lensless.hardware.trainable_mask import prep_trainable_mask
import yaml
from lensless.recon.multi_wiener import MultiWiener
from huggingface_hub import snapshot_download
from collections import OrderedDict

Expand Down Expand Up @@ -100,6 +101,17 @@
"Unet4M+U10+Unet4M_wave": "bezzam/digicam-mirflickr-multi-25k-unet4M-unrolled-admm10-unet4M-wave",
},
},
"tapecam": {
"mirflickr": {
"U5+Unet8M": "bezzam/tapecam-mirflickr-unrolled-admm5-unet8M",
"TrainInv+Unet8M": "bezzam/tapecam-mirflickr-trainable-inv-unet8M",
"MMCN4M+Unet4M": "bezzam/tapecam-mirflickr-mmcn-unet4M",
"MWDN8M": "bezzam/tapecam-mirflickr-mwdn-8M",
"Unet4M+TrainInv+Unet4M": "bezzam/tapecam-mirflickr-unet4M-trainable-inv-unet4M",
"Unet4M+U5+Unet4M": "bezzam/tapecam-mirflickr-unet4M-unrolled-admm5-unet4M",
"Unet2M+MMCN+Unet2M": "bezzam/tapecam-mirflickr-unet2M-mmcn-unet2M",
},
},
}


Expand Down Expand Up @@ -225,6 +237,10 @@ def load_model(
if "nc" in config["reconstruction"]["post_process"].keys()
else None,
device=device,
# get from dict
concatenate_compensation=True
if config["reconstruction"].get("compensation", None) is not None
else False,
)

if config["reconstruction"]["method"] == "unrolled_admm":
Expand All @@ -237,6 +253,7 @@ def load_model(
legacy_denoiser=legacy_denoiser,
skip_pre=skip_pre,
skip_post=skip_post,
compensation=config["reconstruction"].get("compensation", None),
)
elif config["reconstruction"]["method"] == "trainable_inv":
recon = TrainableInversion(
Expand All @@ -248,6 +265,15 @@ def load_model(
skip_pre=skip_pre,
skip_post=skip_post,
)
elif config["reconstruction"]["method"] == "multi_wiener":
recon = MultiWiener(
in_channels=3,
out_channels=3,
psf=psf,
psf_channels=3,
nc=config["reconstruction"]["multi_wiener"]["nc"],
)
recon.to(device)

if mask is not None:
psf_learned = torch.nn.Parameter(psf_learned)
Expand Down
17 changes: 7 additions & 10 deletions lensless/recon/recon.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,7 @@ def __init__(
assert len(psf.shape) == 4, "PSF must be 4D: (depth, height, width, channels)."
assert psf.shape[3] == 3 or psf.shape[3] == 1, "PSF must either be rgb (3) or grayscale (1)"
self._psf = psf
self._npix = np.prod(self._psf.shape)
self._n_iter = n_iter

self._psf_shape = np.array(self._psf.shape)
Expand Down Expand Up @@ -611,22 +612,18 @@ def reconstruction_error(self, prediction=None, lensless=None):
if lensless is None:
lensless = self._data

convolver = self._convolver
# convolver = self._convolver
convolver = RealFFTConvolve2D(self._psf.to(prediction.device), **self._convolver_param)
if not convolver.pad:
prediction = convolver._pad(prediction)
Fx = convolver.convolve(prediction)
Fy = lensless
Hx = convolver.convolve(prediction)

if not convolver.pad:
Fx = convolver._crop(Fx)
Hx = convolver._crop(Hx)

# don't reduce batch dimension
if self.is_torch:
return torch.sum(torch.sqrt((Fx - Fy) ** 2), dim=(-1, -2, -3, -4)) / np.prod(
prediction.shape[1:]
)
return torch.sum(torch.sqrt((Hx - lensless) ** 2), dim=(-1, -2, -3, -4)) / self._npix

else:
return np.sum(np.sqrt((Fx - Fy) ** 2), axis=(-1, -2, -3, -4)) / np.prod(
prediction.shape[1:]
)
return np.sum(np.sqrt((Hx - lensless) ** 2), axis=(-1, -2, -3, -4)) / self._npix
18 changes: 10 additions & 8 deletions lensless/recon/trainable_recon.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def __init__(
skip_unrolled=False,
skip_pre=False,
skip_post=False,
return_unrolled_output=False,
return_intermediate=False,
legacy_denoiser=False,
compensation=None,
**kwargs,
Expand Down Expand Up @@ -100,7 +100,7 @@ def __init__(
self.skip_unrolled = skip_unrolled
self.skip_pre = skip_pre
self.skip_post = skip_post
self.return_unrolled_output = return_unrolled_output
self.return_intermediate = return_intermediate
self.compensation_branch = compensation
if compensation is not None:
from lensless.recon.utils import CompensationBranch
Expand All @@ -112,11 +112,12 @@ def __init__(
len(compensation) == n_iter
), "compensation_nc must have the same length as n_iter"
self.compensation_branch = CompensationBranch(compensation)
self.compensation_branch = self.compensation_branch.to(self._psf.device)

if self.return_unrolled_output:
if self.return_intermediate:
assert (
post_process is not None
), "If return_unrolled_output is True, post_process must be defined."
post_process is not None or pre_process is not None
), "If return_intermediate is True, post_process or pre_process must be defined."
if self.skip_unrolled:
assert (
post_process is not None or pre_process is not None
Expand Down Expand Up @@ -246,6 +247,7 @@ def forward(self, batch, psfs=None):
device_before = self._data.device
self._data = self.pre_process(self._data, self.pre_process_param)
self._data = self._data.to(device_before)
pre_processed = self._data

self.reset(batch_size=batch_size)

Expand Down Expand Up @@ -273,8 +275,8 @@ def forward(self, batch, psfs=None):
else:
final_est = image_est

if self.return_unrolled_output:
return final_est, image_est
if self.return_intermediate:
return final_est, image_est, pre_processed
else:
return final_est

Expand Down Expand Up @@ -365,7 +367,7 @@ def apply(
plt.savefig(plib.Path(save) / "final.png")

if output_intermediate:
return im, pre_processed_image, pre_post_process_image
return im, pre_post_process_image, pre_processed_image
elif plot:
return im, ax
else:
Expand Down
3 changes: 2 additions & 1 deletion lensless/recon/unrolled_admm.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,5 +235,6 @@ def _update(self, iter):

def _form_image(self):
image = self._convolver._crop(self._image_est)
image = torch.clamp(image, min=0)
# image = torch.clamp(image, min=0)
image = torch.clip(image, min=0.0)
return image
Loading

0 comments on commit b1ad3a7

Please sign in to comment.