Skip to content

Commit

Permalink
Expose PSF simulation parameters, ensure RGB for training.
Browse files Browse the repository at this point in the history
  • Loading branch information
ebezzam committed May 25, 2024
1 parent d13ef6c commit a5f6edf
Show file tree
Hide file tree
Showing 9 changed files with 152 additions and 41 deletions.
10 changes: 8 additions & 2 deletions configs/sim_digicam_psf.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,16 @@ dtype: float32
torch_device: cuda
requires_grad: False

# if repo not provided, check for local file
huggingface_repo: bezzam/DigiCam-CelebA-26K
huggingface_mask_pattern: mask_pattern.npy
huggingface_psf: psf_measured.png

digicam:

slm: adafruit
sensor: rpi_hq
downsample: null # null for no downsampling
downsample: 1

# https://drive.switch.ch/index.php/s/NdgHlcDeHVDH5ww?path=%2Fpsf
pattern: data/psf/adafruit_random_pattern_20230719.npy
Expand All @@ -33,7 +38,8 @@ sim:
flipud: True

# in practice found waveprop=True or False doesn't make difference
waveprop: False
waveprop: True
deadspace: True

# below are ignored if waveprop=False
scene2mask: 0.3 # [m]
Expand Down
8 changes: 7 additions & 1 deletion configs/train_digicam_singlemask.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ files:
# TODO: these parameters should be in the dataset?
image_res: [900, 1200] # used during measurement
rotate: True # if measurement is upside-down
save_psf: False
save_psf: True

extra_eval:
multimask:
Expand Down Expand Up @@ -56,3 +56,9 @@ reconstruction:
network : UnetRes # UnetRes or DruNet or null
depth : 4 # depth of each up/downsampling layer. Ignore if network is DruNet
nc: [32,64,116,128]

simulation:
use_waveprop: True
deadspace: True
scene2mask: 0.3
mask2sensor: 0.002
6 changes: 4 additions & 2 deletions configs/train_unrolledADMM.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ files:
# horizontal: null
image_res: null # for measured data, what resolution used at screen
extra_eval: null # dict of extra datasets to evaluate on
force_rgb: False

alignment: null
# top_left: null # height, width
Expand All @@ -62,7 +63,7 @@ display:
save: True

reconstruction:
# Method: unrolled_admm, unrolled_fista
# Method: unrolled_admm, unrolled_fista, trainable_inv
method: unrolled_admm
skip_unrolled: False
init_processors: null # model name
Expand Down Expand Up @@ -127,7 +128,8 @@ simulation:
# these distance parameters are typically fixed for a given PSF
# for DiffuserCam psf # for tape_rgb psf
scene2mask: 10e-2 # scene2mask: 40e-2
mask2sensor: 9e-3 # mask2sensor: 4e-3
mask2sensor: 9e-3 # mask2sensor: 4e-3
deadspace: True # whether to account for deadspace for programmable mask
# see waveprop.devices
use_waveprop: False # for PSF simulation
sensor: "rpi_hq"
Expand Down
96 changes: 70 additions & 26 deletions lensless/hardware/slm.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
try:
import torch
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode

torch_available = True
except ImportError:
Expand Down Expand Up @@ -123,13 +124,7 @@ def set_programmable_mask(pattern, device, rpi_username=None, rpi_hostname=None,


def get_programmable_mask(
vals,
sensor,
slm_param,
rotate=None,
flipud=False,
nbits=8,
color_filter=None,
vals, sensor, slm_param, rotate=None, flipud=False, nbits=8, color_filter=None, deadspace=True
):
"""
Get mask as a numpy or torch array. Return same type.
Expand All @@ -148,6 +143,8 @@ def get_programmable_mask(
Flip mask vertically.
nbits : int, optional
Number of bits/levels to quantize mask to.
deadspace: bool, optional
Whether to include deadspace around mask. Default is True.
"""

Expand All @@ -161,9 +158,8 @@ def get_programmable_mask(
# -- prepare SLM mask
n_active_slm_pixels = vals.shape
n_color_filter = np.prod(slm_param["color_filter"].shape[:2])
pixel_pitch = slm_param[SLMParam_wp.PITCH]
centers = get_centers(n_active_slm_pixels, pixel_pitch=pixel_pitch)

# -- prepare color filter
if color_filter is None and SLMParam_wp.COLOR_FILTER in slm_param.keys():
color_filter = slm_param[SLMParam_wp.COLOR_FILTER]
if isinstance(vals, torch.Tensor):
Expand All @@ -180,36 +176,84 @@ def get_programmable_mask(
else:
raise ValueError("color_filter must be numpy array or torch tensor")

d1 = sensor.pitch
_height_pixel, _width_pixel = (slm_param[SLMParam_wp.CELL_SIZE] / d1).astype(int)

# -- prepare mask
if use_torch:
mask = torch.zeros((n_color_filter,) + tuple(sensor.resolution)).to(vals)
slm_vals_flat = vals.flatten()
else:
mask = np.zeros((n_color_filter,) + tuple(sensor.resolution), dtype=dtype)
slm_vals_flat = vals.reshape(-1)
pixel_pitch = slm_param[SLMParam_wp.PITCH]
d1 = sensor.pitch
if deadspace:

centers = get_centers(n_active_slm_pixels, pixel_pitch=pixel_pitch)

_height_pixel, _width_pixel = (slm_param[SLMParam_wp.CELL_SIZE] / d1).astype(int)

for i, _center in enumerate(centers):

for i, _center in enumerate(centers):
_center_pixel = (_center / d1 + sensor.resolution / 2).astype(int)
_center_top_left_pixel = (
_center_pixel[0] - np.floor(_height_pixel / 2).astype(int),
_center_pixel[1] + 1 - np.floor(_width_pixel / 2).astype(int),
)
color_filter_idx = i // n_active_slm_pixels[1] % n_color_filter

mask_val = slm_vals_flat[i] * color_filter[color_filter_idx][0]
if isinstance(mask_val, np.ndarray):
mask_val = mask_val[:, np.newaxis, np.newaxis]
elif isinstance(mask_val, torch.Tensor):
mask_val = mask_val.unsqueeze(-1).unsqueeze(-1)
mask[
:,
_center_top_left_pixel[0] : _center_top_left_pixel[0] + _height_pixel,
_center_top_left_pixel[1] : _center_top_left_pixel[1] + _width_pixel,
] = mask_val

else:

_center_pixel = (_center / d1 + sensor.resolution / 2).astype(int)
_center_top_left_pixel = (
_center_pixel[0] - np.floor(_height_pixel / 2).astype(int),
_center_pixel[1] + 1 - np.floor(_width_pixel / 2).astype(int),
# use color filter to turn mask into RGB
if use_torch:
active_mask_rgb = torch.zeros((n_color_filter,) + n_active_slm_pixels).to(vals)
else:
active_mask_rgb = np.zeros((n_color_filter,) + n_active_slm_pixels, dtype=dtype)

# TODO avoid for loop
for i in range(n_active_slm_pixels[0]):
row_idx = i % color_filter.shape[0]
for j in range(n_active_slm_pixels[1]):

col_idx = j % color_filter.shape[1]
color_filter_idx = color_filter[row_idx, col_idx]
active_mask_rgb[
:, n_active_slm_pixels[0] - i - 1, n_active_slm_pixels[1] - j - 1
] = (vals[i, j] * color_filter_idx)

# size of active pixels in pixels
n_active_dim = np.around(slm_param[SLMParam_wp.PITCH] * n_active_slm_pixels / d1).astype(
int
)
# n_active_dim = np.around(slm_param[SLMParam_wp.CELL_SIZE] * n_active_slm_pixels / d1).astype(int)

color_filter_idx = i // n_active_slm_pixels[1] % n_color_filter
# resize to n_active_dim
if use_torch:
mask_active = transforms.functional.resize(
active_mask_rgb, n_active_dim, interpolation=InterpolationMode.NEAREST
)
else:
# TODO check
mask_active = np.zeros((n_color_filter,) + tuple(n_active_dim), dtype=dtype)
for i in range(n_color_filter):
mask_active[i] = np.resize(active_mask_rgb[i], n_active_dim)

mask_val = slm_vals_flat[i] * color_filter[color_filter_idx][0]
if isinstance(mask_val, np.ndarray):
mask_val = mask_val[:, np.newaxis, np.newaxis]
elif isinstance(mask_val, torch.Tensor):
mask_val = mask_val.unsqueeze(-1).unsqueeze(-1)
# pad to full mask
top_left = (sensor.resolution - n_active_dim) // 2
mask[
:,
_center_top_left_pixel[0] : _center_top_left_pixel[0] + _height_pixel,
_center_top_left_pixel[1] : _center_top_left_pixel[1] + _width_pixel,
] = mask_val
top_left[0] : top_left[0] + n_active_dim[0],
top_left[1] : top_left[1] + n_active_dim[1],
] = mask_active

# # quantize mask
# if use_torch:
Expand Down
3 changes: 3 additions & 0 deletions lensless/hardware/trainable_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ def __init__(
horizontal_shift=None,
scene2mask=None,
mask2sensor=None,
deadspace=True,
downsample=None,
min_val=0,
**kwargs,
Expand Down Expand Up @@ -197,6 +198,7 @@ def __init__(
self.use_waveprop = use_waveprop
self.scene2mask = scene2mask
self.mask2sensor = mask2sensor
self.deadspace = deadspace
self.vertical_shift = vertical_shift
self.horizontal_shift = horizontal_shift
self.min_val = min_val
Expand All @@ -217,6 +219,7 @@ def get_psf(self):
rotate=self.rotate,
flipud=self.flipud,
color_filter=self._color_filter,
deadspace=self.deadspace,
)

if self.vertical_shift is not None:
Expand Down
24 changes: 21 additions & 3 deletions lensless/utils/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import cv2
from lensless.hardware.sensor import sensor_dict, SensorParam
from scipy.ndimage import rotate
import warnings


def convert(text):
Expand Down Expand Up @@ -1035,6 +1036,7 @@ def __init__(
return_mask_label=False,
save_psf=False,
simulation_config=dict(),
force_rgb=False,
**kwargs,
):
"""
Expand Down Expand Up @@ -1087,6 +1089,7 @@ def __init__(
self.rotate = rotate
self.display_res = display_res
self.return_mask_label = return_mask_label
self.force_rgb = force_rgb # if some data is not 3D

# deduce downsampling factor from the first image
data_0 = self.dataset[0]
Expand All @@ -1105,7 +1108,6 @@ def __init__(
self.alignment = None
self.crop = None
if alignment is not None:
# preparing ground-truth in expected shape
if "top_left" in alignment:
self.alignment = dict(alignment.copy())
self.alignment["top_left"] = (
Expand All @@ -1117,9 +1119,7 @@ def __init__(
original_aspect_ratio = display_res[1] / display_res[0]
self.alignment["width"] = int(self.alignment["height"] * original_aspect_ratio)

# preparing ground-truth as simulated measurement of original
elif "crop" in alignment:
assert "simulation" in alignment, "Simulation config should be provided"
self.crop = dict(alignment["crop"].copy())
self.crop["vertical"][0] = int(self.crop["vertical"][0] / downsample)
self.crop["vertical"][1] = int(self.crop["vertical"][1] / downsample)
Expand All @@ -1139,6 +1139,7 @@ def __init__(
return_bg=True,
flip=rotate,
bg_pix=(0, 15),
force_rgb=force_rgb,
)
self.psf = torch.from_numpy(psf)

Expand Down Expand Up @@ -1167,6 +1168,7 @@ def __init__(
use_waveprop=simulation_config.get("use_waveprop", False),
scene2mask=simulation_config.get("scene2mask", None),
mask2sensor=simulation_config.get("mask2sensor", None),
deadspace=simulation_config.get("deadspace", True),
)
self.psf[label] = mask.get_psf().detach()

Expand Down Expand Up @@ -1194,6 +1196,7 @@ def __init__(
use_waveprop=simulation_config.get("use_waveprop", False),
scene2mask=simulation_config.get("scene2mask", None),
mask2sensor=simulation_config.get("mask2sensor", None),
deadspace=simulation_config.get("deadspace", True),
)
self.psf = mask.get_psf().detach()
assert (
Expand Down Expand Up @@ -1232,6 +1235,21 @@ def _get_images_pair(self, idx):
lensless_np = np.array(self.dataset[idx]["lensless"])
lensed_np = np.array(self.dataset[idx]["lensed"])

if self.force_rgb:
if len(lensless_np.shape) == 2:
warnings.warn(f"Converting lensless[{idx}] to RGB")
lensless_np = np.stack([lensless_np] * 3, axis=2)
elif len(lensless_np.shape) == 3:
pass
else:
raise ValueError(f"lensless[{idx}] should be 2D or 3D")

if len(lensed_np.shape) == 2:
warnings.warn(f"Converting lensed[{idx}] to RGB")
lensed_np = np.stack([lensed_np] * 3, axis=2)
elif len(lensed_np.shape) == 3:
pass

# convert to float
if lensless_np.dtype == np.uint8:
lensless_np = lensless_np.astype(np.float32) / 255
Expand Down
7 changes: 7 additions & 0 deletions lensless/utils/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,7 @@ def load_psf(
shape=None,
use_3d=False,
bgr_input=True,
force_rgb=False,
):
"""
Load and process PSF for analysis or for reconstruction.
Expand Down Expand Up @@ -305,6 +306,12 @@ def load_psf(
max_val = get_max_val(psf)
psf = np.array(psf, dtype=dtype)

if force_rgb:
if len(psf.shape) == 2:
psf = np.stack([psf] * 3, axis=2)
elif len(psf.shape) == 3:
pass

if use_3d:
if len(psf.shape) == 3:
grayscale = True
Expand Down
2 changes: 2 additions & 0 deletions scripts/recon/train_learning_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ def train_learned(config):
save_psf=config.files.save_psf,
n_files=config.files.n_files,
simulation_config=config.simulation,
force_rgb=config.files.force_rgb,
)
test_set = HFDataset(
huggingface_repo=config.files.dataset,
Expand All @@ -229,6 +230,7 @@ def train_learned(config):
save_psf=config.files.save_psf,
n_files=config.files.n_files,
simulation_config=config.simulation,
force_rgb=config.files.force_rgb,
)
if train_set.multimask:
# get first PSF for initialization
Expand Down
Loading

0 comments on commit a5f6edf

Please sign in to comment.