Skip to content

Commit

Permalink
Add script for psf err vs metric.
Browse files Browse the repository at this point in the history
  • Loading branch information
ebezzam committed Oct 26, 2024
1 parent fc89d28 commit 1bc1ad5
Show file tree
Hide file tree
Showing 12 changed files with 399 additions and 27 deletions.
2 changes: 2 additions & 0 deletions configs/authen.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ save_idx: [1, 2, 4, 5, 9]

font_scale: 1.5 # for plotting confusion matrix

metric: "recon" # "recon", "mse", "lpips"

# Dataset parameters
huggingface:
repo: "bezzam/DigiCam-Mirflickr-MultiMask-25K"
Expand Down
4 changes: 2 additions & 2 deletions configs/benchmark_digicam_mirflickr_multi.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ algorithms: [

## -- reconstructions trained on measured data
"hf:digicam:mirflickr_multi_25k:Unet4M+U5+Unet4M_wave",
"hf:digicam:mirflickr_multi_25k:Unet4M+U10+Unet4M_wave",
# "hf:digicam:mirflickr_multi_25k:Unet4M+U10+Unet4M_wave",
"hf:digicam:mirflickr_multi_25k:Unet4M+U5+Unet4M_wave_aux1",
"hf:digicam:mirflickr_multi_25k:Unet4M+U5+Unet4M_wave_flips",
# "hf:digicam:mirflickr_multi_25k:Unet4M+U5+Unet4M_wave_flips",

# ## -- reconstructions trained on other datasets/systems
# "hf:diffusercam:mirflickr:Unet4M+U10+Unet4M",
Expand Down
1 change: 1 addition & 0 deletions configs/recon_digicam_mirflickr.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ defaults:
- defaults_recon
- _self_

dataset: mirflickr_single_25k # for loading model, "mirflickr_single_25k" or "mirflickr_multi_25k"
cache_dir: /dev/shm

# fn: null # if not null, download this file from https://huggingface.co/datasets/bezzam/DigiCam-Mirflickr-SingleMask-25K/tree/main
Expand Down
22 changes: 22 additions & 0 deletions configs/recon_digicam_mirflickr_err.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# python scripts/recon/digicam_mirflickr.py
defaults:
- defaults_recon
- _self_

cache_dir: null
metrics_fp : null
hf_repo: null # by default use one in model config

# set model
# -- for learning-based methods (comment if using ADMM)
model: Unet4M+U5+Unet4M_wave

# # -- for ADMM with fixed parameters
# model: admm
# n_iter: 10

device: cuda:1
save_idx: [1, 2, 4, 5, 9]
n_files: null
percent_pixels_wrong: [0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100]
flip: True # whether to flip mask values (True) or reset them (False)
6 changes: 5 additions & 1 deletion configs/train_unrolledADMM.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,10 @@ reconstruction:
init_pre: True # if `init_processors`, set pre-procesor is available
init_post: True # if `init_processors`, set post-procesor is available

# processing PSF
psf_network: False # False, or set number of channels for UnetRes, e.g. [8,16,32,64], with skip connection
psf_residual: True # if psf_network used, whether to use residual connection for original PSF estimate

# background subtraction (if dataset has corresponding background images)
direct_background_subtraction: False # True or False
learned_background_subtraction: False # False, or set number of channels for UnetRes, e.g. [8,16,32,64]
Expand Down Expand Up @@ -193,10 +197,10 @@ optimizer:
type: AdamW # Adam, SGD... (Pytorch class)
lr: 1e-4
lr_step_epoch: True # True -> update LR at end of each epoch, False at the end of each mini-batch
cosine_decay_warmup: True # if set, cosine decay with warmup of 5%
final_lr: False # if set, exponentially decay *to* this value
exp_decay: False # if set, exponentially decay *with* this value
slow_start: False #float how much to reduce lr for first epoch
cosine_decay_warmup: True # if set, cosine decay with warmup of 5%
# Decay LR in step fashion: https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.StepLR.html
step: False # int, period of learning rate decay. False to not apply
gamma: 0.1 # float, factor for learning rate decay
Expand Down
4 changes: 3 additions & 1 deletion lensless/eval/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,9 @@ def benchmark(
)

else:
prediction = model.forward(lensless, psfs, background=background, **kwargs)
prediction = model.forward(
batch=lensless, psfs=psfs, background=background, **kwargs
)

if unrolled_output_factor or pre_process_aux:
pre_process_out = prediction[2]
Expand Down
30 changes: 30 additions & 0 deletions lensless/recon/trainable_recon.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ def __init__(
legacy_denoiser=False,
compensation=None,
compensation_residual=True,
psf_network=None,
psf_residual=True,
# background subtraction
direct_background_subtraction=False,
background_network=None,
Expand Down Expand Up @@ -95,6 +97,10 @@ def __init__(
Post-processor must be defined if compensation provided.
compensation_residual : bool, optional
Whether to use residual connection in compensation layer.
psf_network : :py:class:`function` or :py:class:`~torch.nn.Module`, optional
Function or model to apply to PSF prior to camera inversion.
psf_residual : bool, optional
Whether to use residual connection in PSF network.
"""

assert isinstance(psf, torch.Tensor), "PSF must be a torch.Tensor"
Expand Down Expand Up @@ -141,6 +147,12 @@ def __init__(
), "Cannot use direct_background_subtraction and background_network at the same time."
self.set_background_network(background_network)

# PSF network
self.psf_network = None
self.psf_residual = psf_residual
if psf_network is not None:
self.set_psf_network(psf_network)

# compensation branch
self.return_intermediate = return_intermediate
self.compensation_branch = compensation
Expand Down Expand Up @@ -227,6 +239,13 @@ def set_background_network(self, background_network):
self.background_network_param,
) = self._prepare_process_block(background_network)

def set_psf_network(self, psf_network):
(
self.psf_network,
self.psf_network_model,
self.psf_network_param,
) = self._prepare_process_block(psf_network)

def freeze_pre_process(self):
"""
Method for freezing the pre process block.
Expand Down Expand Up @@ -307,6 +326,15 @@ def forward(self, batch, psfs=None, background=None):
).to(self._data.device)
self._data = torch.clamp(self._data, 0, 1)

# set / transform PSFs if need be
if self.psf_network is not None:
if psfs is None:
psfs = self._psf
if self.psf_residual:
psfs = self.psf_network(psfs, self.psf_network_param).to(psfs.device) + psfs
else:
psfs = self.psf_network(psfs, self.psf_network_param).to(psfs.device)

if psfs is not None:
# assert same shape
assert psfs.shape == batch.shape, "psfs must have the same shape as batch"
Expand Down Expand Up @@ -381,6 +409,8 @@ def apply(
algorithm, the number of iteration isn't required. Note that `set_data` must be called
beforehand.
# TODO apply PSF network
Parameters
----------
disp_iter : int
Expand Down
70 changes: 51 additions & 19 deletions lensless/utils/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
import warnings
from waveprop.noise import add_shot_noise
from lensless.utils.image import shift_with_pad
from PIL import Image


def convert(text):
Expand Down Expand Up @@ -1351,6 +1350,9 @@ def __init__(
self.display_res = display_res
self.return_mask_label = return_mask_label
self.force_rgb = force_rgb # if some data is not 3D
self.sensor = sensor
self.slm = slm
self.simulation_config = simulation_config

# augmentation
self.random_flip = random_flip
Expand All @@ -1372,6 +1374,7 @@ def __init__(
downsample_fact = min(sensor_res / lensless.shape[:2])
else:
downsample_fact = 1
self.downsample_fact = downsample_fact

# deduce recon shape from original image
self.alignment = None
Expand Down Expand Up @@ -1410,6 +1413,7 @@ def __init__(
# download all masks
# TODO: reshape directly with lensless image shape
self.multimask = False
self.huggingface_repo = huggingface_repo
if psf is not None:
# download PSF from huggingface
psf_fp = hf_hub_download(repo_id=huggingface_repo, filename=psf, repo_type="dataset")
Expand All @@ -1435,28 +1439,32 @@ def __init__(
for i in range(len(self.dataset)):
mask_labels.append(self.dataset[i]["mask_label"])
mask_labels = list(set(mask_labels))
self.mask_labels = mask_labels

# simulate all PSFs
self.psf = dict()
for label in mask_labels:
mask_fp = hf_hub_download(
repo_id=huggingface_repo,
filename=f"masks/mask_{label}.npy",
repo_type="dataset",
)
mask_vals = np.load(mask_fp)
mask = AdafruitLCD(
initial_vals=torch.from_numpy(mask_vals.astype(np.float32)),
sensor=sensor,
slm=slm,
downsample=downsample_fact,
flipud=self.rotate or flipud, # TODO separate commands?
use_waveprop=simulation_config.get("use_waveprop", False),
scene2mask=simulation_config.get("scene2mask", None),
mask2sensor=simulation_config.get("mask2sensor", None),
deadspace=simulation_config.get("deadspace", True),
)
self.psf[label] = mask.get_psf().detach()

mask_vals = self.get_mask_vals(label)
self.psf[label] = self.simulate_psf(mask_vals)
# mask_fp = hf_hub_download(
# repo_id=huggingface_repo,
# filename=f"masks/mask_{label}.npy",
# repo_type="dataset",
# )
# mask_vals = np.load(mask_fp)
# mask = AdafruitLCD(
# initial_vals=torch.from_numpy(mask_vals.astype(np.float32)),
# sensor=sensor,
# slm=slm,
# downsample=downsample_fact,
# flipud=self.rotate or flipud, # TODO separate commands?
# use_waveprop=simulation_config.get("use_waveprop", False),
# scene2mask=simulation_config.get("scene2mask", None),
# mask2sensor=simulation_config.get("mask2sensor", None),
# deadspace=simulation_config.get("deadspace", True),
# )
# self.psf[label] = mask.get_psf().detach()

assert (
self.psf[label].shape[-3:-1] == lensless.shape[:2]
Expand Down Expand Up @@ -1541,6 +1549,30 @@ def __init__(
def __len__(self):
return len(self.dataset)

def get_mask_vals(self, idx):
assert self.multimask
assert idx in self.mask_labels
mask_fp = hf_hub_download(
repo_id=self.huggingface_repo,
filename=f"masks/mask_{idx}.npy",
repo_type="dataset",
)
return np.load(mask_fp)

def simulate_psf(self, mask_vals):
mask = AdafruitLCD(
initial_vals=torch.from_numpy(mask_vals.astype(np.float32)),
sensor=self.sensor,
slm=self.slm,
downsample=self.downsample_fact,
flipud=self.rotate or self.flipud, # TODO separate commands?
use_waveprop=self.simulation_config.get("use_waveprop", False),
scene2mask=self.simulation_config.get("scene2mask", None),
mask2sensor=self.simulation_config.get("mask2sensor", None),
deadspace=self.simulation_config.get("deadspace", True),
)
return mask.get_psf().detach()

def _get_images_pair(self, idx):

# load images
Expand Down
2 changes: 1 addition & 1 deletion scripts/data/authenticate.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ def authen(config):
if i in save_idx:
res_np = res[0].cpu().numpy()
res_np = res_np / res_np.max()
fp = os.path.join(save_dir, f"{psf_idx}.png")
fp = os.path.join(save_dir, f"psf{psf_idx}.png")
save_image(res_np, fp)

scores[str(mask_label)].append(np.array(scores_i).tolist())
Expand Down
5 changes: 2 additions & 3 deletions scripts/recon/digicam_mirflickr.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,13 @@ def apply_pretrained(config):
model_config = yaml.safe_load(stream)

else:
model_path = download_model(
camera="digicam", dataset="mirflickr_single_25k", model=model_name
)
model_path = download_model(camera="digicam", dataset=config.dataset, model=model_name)
config_path = os.path.join(model_path, ".hydra", "config.yaml")
with open(config_path, "r") as stream:
model_config = yaml.safe_load(stream)

# load data
# TODO try with multi-mask, should load single mask dataset...
test_set = HFDataset(
huggingface_repo=model_config["files"]["dataset"],
psf=(
Expand Down
Loading

0 comments on commit 1bc1ad5

Please sign in to comment.