Commit
Add option to use waveprop in simulating PSF, improve ROI extraction API.
ebezzam committed May 13, 2024
1 parent 6632695 commit bd74342
Showing 7 changed files with 80 additions and 46 deletions.
3 changes: 2 additions & 1 deletion configs/benchmark.yaml
@@ -7,7 +7,7 @@ hydra:
chdir: True


dataset: DiffuserCam # DiffuserCam, DigiCamCelebA, DigiCamHF
dataset: DiffuserCam # DiffuserCam, DigiCamCelebA, HFDataset
seed: 0

huggingface:
@@ -88,6 +88,7 @@ simulation:
scene2mask: 0.25 # [m]
mask2sensor: 0.002 # [m]
# see waveprop.devices
use_waveprop: False # for PSF simulation
sensor: "rpi_hq"
snr_db: 10
# simulate different sensor resolution
3 changes: 2 additions & 1 deletion configs/train_unrolledADMM.yaml
@@ -28,7 +28,7 @@ files:

# -- processing parameters
downsample: 2 # factor by which to downsample the PSF, note that for DiffuserCam the PSF has 4x the resolution
downsample_lensed: 2
downsample_lensed: 2 # only used if lensed is measured
input_snr: null # adding shot noise at input (for measured dataset) at this SNR in dB
vertical_shift: null
horizontal_shift: null
@@ -129,6 +129,7 @@ simulation:
scene2mask: 10e-2 # scene2mask: 40e-2
mask2sensor: 9e-3 # mask2sensor: 4e-3
# see waveprop.devices
use_waveprop: False # for PSF simulation
sensor: "rpi_hq"
snr_db: 10
# simulate different sensor resolution
19 changes: 8 additions & 11 deletions lensless/eval/benchmark.py
@@ -84,10 +84,6 @@ def benchmark(
if not os.path.exists(output_dir):
os.mkdir(output_dir)

alignment = None
if hasattr(dataset, "alignment"):
alignment = dataset.alignment

if metrics is None:
metrics = {
"MSE": MSELoss().to(device),
@@ -156,13 +152,14 @@ def benchmark(
prediction = prediction.reshape(-1, *prediction.shape[-3:]).movedim(-1, -3)
lensed = lensed.reshape(-1, *lensed.shape[-3:]).movedim(-1, -3)

if alignment is not None:
prediction = prediction[
...,
alignment["topright"][0] : alignment["topright"][0] + alignment["height"],
alignment["topright"][1] : alignment["topright"][1] + alignment["width"],
]
# expected that lensed is also reshaped accordingly
if hasattr(dataset, "alignment"):
if dataset.alignment is not None:
prediction = dataset.extract_roi(prediction, axis=(-2, -1))
else:
prediction, lensed = dataset.extract_roi(
prediction, axis=(-2, -1), lensed=lensed
)
assert np.all(lensed.shape == prediction.shape)
elif crop is not None:
prediction = prediction[
...,
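For context, a minimal sketch of the tensor layout handled above (shapes are illustrative): batch and depth dimensions are flattened and channels are moved first, which is why the ROI is then extracted along the last two axes.

import torch

prediction = torch.rand(4, 1, 480, 640, 3)  # illustrative [batch, depth, H, W, C]

# flatten batch/depth and move channels first, as done before computing metrics
prediction = prediction.reshape(-1, *prediction.shape[-3:]).movedim(-1, -3)
print(prediction.shape)  # torch.Size([4, 3, 480, 640])

# the spatial dimensions are now the last two axes, hence extract_roi(..., axis=(-2, -1))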
17 changes: 6 additions & 11 deletions lensless/recon/utils.py
@@ -648,17 +648,12 @@ def train_epoch(self, data_loader):
y = y.reshape(-1, *y.shape[-3:]).movedim(-1, -3)

# extract region of interest for loss
if (
hasattr(self.train_dataset, "alignment")
and self.train_dataset.alignment is not None
):
alignment = self.train_dataset.alignment
y_pred = y_pred[
...,
alignment["topright"][0] : alignment["topright"][0] + alignment["height"],
alignment["topright"][1] : alignment["topright"][1] + alignment["width"],
]
# expected that lensed is also reshaped accordingly
if hasattr(self.train_dataset, "alignment"):
if self.train_dataset.alignment is not None:
y_pred = self.train_dataset.extract_roi(y_pred, axis=(-2, -1))
else:
y_pred, y = self.train_dataset.extract_roi(y_pred, axis=(-2, -1), lensed=y)

elif self.crop is not None:
y_pred = y_pred[
...,
55 changes: 43 additions & 12 deletions lensless/utils/dataset.py
@@ -14,6 +14,7 @@
from abc import abstractmethod
from torch.utils.data import Dataset, Subset
from torchvision import datasets, transforms
from torchvision.transforms import functional as F
from lensless.hardware.trainable_mask import prep_trainable_mask, AdafruitLCD
from lensless.utils.simulation import FarFieldSimulator
from lensless.utils.io import load_image, load_psf, save_image
@@ -26,6 +27,7 @@
from huggingface_hub import hf_hub_download
import cv2
from lensless.hardware.sensor import sensor_dict, SensorParam
from scipy.ndimage import rotate


def convert(text):
@@ -1032,6 +1034,7 @@ def __init__(
alignment=None,
return_mask_label=False,
save_psf=False,
simulation_config=dict(),
**kwargs,
):
"""
@@ -1067,6 +1070,8 @@ def __init__(
If multimask dataset, return the mask label (True) or the corresponding PSF (False).
save_psf : bool, optional
If multimask dataset, save the simulated PSFs.
simulation_config : dict, optional
Simulation parameters for PSF if using a mask pattern.
"""

@@ -1159,6 +1164,9 @@ def __init__(
slm=slm,
downsample=downsample_fact,
flipud=rotate,
use_waveprop=simulation_config.get("use_waveprop", False),
scene2mask=simulation_config.get("scene2mask", None),
mask2sensor=simulation_config.get("mask2sensor", None),
)
self.psf[label] = mask.get_psf().detach()

@@ -1172,6 +1180,7 @@ def __init__(

else:

# single mask pattern
mask_fp = hf_hub_download(
repo_id=huggingface_repo, filename="mask_pattern.npy", repo_type="dataset"
)
@@ -1182,12 +1191,19 @@ def __init__(
slm=slm,
downsample=downsample_fact,
flipud=rotate,
use_waveprop=simulation_config.get("use_waveprop", False),
scene2mask=simulation_config.get("scene2mask", None),
mask2sensor=simulation_config.get("mask2sensor", None),
)
self.psf = mask.get_psf().detach()
assert (
self.psf.shape[-3:-1] == lensless.shape[:2]
), "PSF shape should match lensless shape"

if save_psf:
# save viewable image of PSF
save_image(self.psf.squeeze().cpu().numpy(), "psf.png")

# create simulator
self.simulator = None
self.vertical_shift = None
@@ -1233,6 +1249,7 @@ def _get_images_pair(self, idx):

lensless = lensless_np
lensed = lensed_np

if self.simulator is not None:
# convert to torch
lensless = torch.from_numpy(lensless_np)
@@ -1282,27 +1299,41 @@ def __getitem__(self, idx):
else:
return lensless, lensed

def extract_roi(self, reconstruction, lensed=None):
assert len(reconstruction.shape) == 4, "Reconstruction should have shape [B, H, W, C]"
if lensed is not None:
assert len(lensed.shape) == 4, "Lensed should have shape [B, H, W, C]"
def extract_roi(self, reconstruction, lensed=None, axis=(1, 2)):
n_dim = len(reconstruction.shape)
assert max(axis) < n_dim, "Axis should be within the dimensions of the reconstruction."

if self.alignment is not None:
top_right = self.alignment["topright"]
height = self.alignment["height"]
width = self.alignment["width"]
reconstruction = reconstruction[
:, top_right[0] : top_right[0] + height, top_right[1] : top_right[1] + width
]

# extract according to axis
index = [slice(None)] * n_dim
index[axis[0]] = slice(top_right[0], top_right[0] + height)
index[axis[1]] = slice(top_right[1], top_right[1] + width)
reconstruction = reconstruction[tuple(index)]

# rotate if necessary
angle = self.alignment.get("angle", 0)
if isinstance(reconstruction, torch.Tensor):
reconstruction = F.rotate(reconstruction, angle, expand=False)
else:
reconstruction = rotate(reconstruction, angle, axes=axis, reshape=False)

elif self.crop is not None:
vertical = self.crop["vertical"]
horizontal = self.crop["horizontal"]
reconstruction = reconstruction[
:, vertical[0] : vertical[1], horizontal[0] : horizontal[1]
]

# extract according to axis
index = [slice(None)] * n_dim
index[axis[0]] = slice(vertical[0], vertical[1])
index[axis[1]] = slice(horizontal[0], horizontal[1])
reconstruction = reconstruction[tuple(index)]
if lensed is not None:
lensed = lensed[:, vertical[0] : vertical[1], horizontal[0] : horizontal[1]]
if lensed is not None:
lensed = lensed[tuple(index)]

if self.alignment is None and lensed is not None:
return reconstruction, lensed
else:
return reconstruction
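A minimal usage sketch of the revised extract_roi API (the repository id and tensor shapes are placeholders, and additional HFDataset constructor arguments may be needed for a given dataset):

import torch
from lensless.utils.dataset import HFDataset

dataset = HFDataset(
    huggingface_repo="user/lensless-dataset",  # placeholder repo id
    split="test",
)

recon = torch.rand(4, 3, 480, 640)   # hypothetical reconstruction, [B, C, H, W]
lensed = torch.rand(4, 3, 480, 640)  # hypothetical ground truth, same layout

if dataset.alignment is not None:
    # alignment-based datasets: only the reconstruction is cropped (and rotated if an
    # "angle" is given), since the lensed image is already aligned
    recon_roi = dataset.extract_roi(recon, axis=(-2, -1))
else:
    # crop-based datasets: the lensed image is cropped in the same call
    recon_roi, lensed_roi = dataset.extract_roi(recon, axis=(-2, -1), lensed=lensed)

# numpy images in (H, W, C) layout use axis=(0, 1); channel-last batches can rely on
# the default axis=(1, 2)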
3 changes: 2 additions & 1 deletion scripts/eval/benchmark_recon.py
@@ -84,14 +84,15 @@ def benchmark_recon(config):
_, benchmark_dataset = torch.utils.data.random_split(
dataset, [train_size, test_size], generator=generator
)
elif dataset == "DigiCamHF":
elif dataset == "HFDataset":
benchmark_dataset = HFDataset(
huggingface_repo=config.huggingface.repo,
split="test",
display_res=config.huggingface.image_res,
rotate=config.huggingface.rotate,
downsample=config.huggingface.downsample,
alignment=config.huggingface.alignment,
simulation_config=config.simulation,
)
if benchmark_dataset.multimask:
# get first PSF for initialization
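The new simulation_config argument is read with .get(), so the hydra node config.simulation (as above) or a plain dictionary both work; it only takes effect when the dataset's PSF is simulated from a mask pattern (the AdafruitLCD path in lensless/utils/dataset.py). A minimal sketch with the values from configs/benchmark.yaml and a placeholder repository id (other constructor arguments may be required):

from lensless.utils.dataset import HFDataset

simulation_config = {
    "use_waveprop": True,  # simulate the PSF with waveprop (new option in this commit)
    "scene2mask": 0.25,    # [m]
    "mask2sensor": 0.002,  # [m]
}

benchmark_dataset = HFDataset(
    huggingface_repo="user/lensless-dataset",  # placeholder
    split="test",
    simulation_config=simulation_config,
)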
26 changes: 17 additions & 9 deletions scripts/recon/train_learning_based.py
@@ -104,7 +104,6 @@ def train_learned(config):
test_set = None
psf = None
crop = None
alignment = None # very similar to crop, TODO: should switch to this approach
mask = None
if "DiffuserCam" in config.files.dataset and config.files.huggingface_dataset is False:

@@ -216,6 +215,7 @@
alignment=config.alignment,
save_psf=config.files.save_psf,
n_files=config.files.n_files,
simulation_config=config.simulation,
)
test_set = HFDataset(
huggingface_repo=config.files.dataset,
@@ -228,6 +228,7 @@
alignment=config.alignment,
save_psf=config.files.save_psf,
n_files=config.files.n_files,
simulation_config=config.simulation,
)
if train_set.multimask:
# get first PSF for initialization
@@ -239,7 +240,6 @@
else:
psf = train_set.psf.to(device)
crop = test_set.crop # same for train set
alignment = test_set.alignment

# -- if learning mask
mask = prep_trainable_mask(config, psf)
@@ -277,6 +277,7 @@
split="test",
downsample=config.files.downsample, # needs to be same size
n_files=config.files.n_files,
simulation_config=config.simulation,
**config.files.extra_eval[eval_set],
)

@@ -310,13 +311,20 @@
# -- plot lensed and res on top of each other
cropped = False

if alignment is not None:
top_right = alignment["topright"]
height = alignment["height"]
width = alignment["width"]
res_np = res_np[
top_right[0] : top_right[0] + height, top_right[1] : top_right[1] + width
]
if hasattr(test_set, "alignment"):
if test_set.alignment is not None:
res_np = test_set.extract_roi(res_np, axis=(0, 1))
else:
res_np, lensed_np = test_set.extract_roi(
res_np, lensed=lensed_np, axis=(0, 1)
)
# if alignment is not None:
# top_right = alignment["topright"]
# height = alignment["height"]
# width = alignment["width"]
# res_np = res_np[
# top_right[0] : top_right[0] + height, top_right[1] : top_right[1] + width
# ]
cropped = True

elif config.training.crop_preloss:
