Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for dataset with measured background #142

Merged
merged 21 commits into from
Sep 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ Unreleased
Added
~~~~~

- Option to pass background image to ``utils.io.load_data``.
- Option to set image resolution with ``hardware.utils.display`` function.
- Add utility for mask adapter generation in ``lenseless.hardware.fabrication``
- Option to add simulated background in ``util.dataset``
Expand All @@ -28,11 +27,19 @@ Added
- HFSimulated object for simulating lensless data from ground-truth and PSF.
- Option to set cache directory for Hugging Face datasets.
- Option to initialize training with another model.
- Option to pass background image to ``utils.io.load_data``.
- Option to use background in ``lensless.eval.benchmark``.
- Different techniques to use measured background: direct subtraction, learned subtraction, integrated subtraction, concatenated to input.
- Learnable background subtraction for classes that derive from ``lensless.recon.trainable_recon.TrainableReconstructionAlgorithm``.
- Integrated background subtraction object ``lensless.recon.integrated_background.IntegratedBackgroundSub``.
- Option to concatenate background to input to pre-processor.
- Add support for datasets with measured background to ``lensless.utils.dataset.HFDataset``.


Changed
~~~~~~~

- Nothing
- ``lensless.utils.dataset.HFDataset`` no longer inherits from ``lensless.utils.dataset.DualDataset``.

Bugfix
~~~~~~
Expand Down
1 change: 1 addition & 0 deletions configs/benchmark.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ huggingface:
downsample_lensed: 1
split_seed: null
single_channel_psf: False
use_background: True

device: "cuda"
# numbers of iterations to benchmark
Expand Down
45 changes: 45 additions & 0 deletions configs/benchmark_multilens_mirflickr_ambient.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# python scripts/eval/benchmark_recon.py -cn benchmark_multilens_mirflickr_ambient
defaults:
- benchmark
- _self_

dataset: HFDataset
batchsize: 8
device: "cuda:0"

huggingface:
repo: Lensless/MultiLens-Mirflickr-Ambient
cache_dir: /dev/shm
psf: psf.png
image_res: [600, 600] # used during measurement
rotate: False # if measurement is upside-down
alignment:
top_left: [118, 220] # height, width
height: 123
use_background: True

## -- reconstructions trained with same dataset/system
algorithms: [
"ADMM",
"hf:multilens:mirflickr_ambient:U5+Unet8M",
"hf:multilens:mirflickr_ambient:U5+Unet8M_direct_sub",
"hf:multilens:mirflickr_ambient:U5+Unet8M_learned_sub",
"hf:multilens:mirflickr_ambient:Unet4M+U5+Unet4M",
"hf:multilens:mirflickr_ambient:Unet4M+U5+Unet4M_direct_sub",
"hf:multilens:mirflickr_ambient:Unet4M+U5+Unet4M_learned_sub",
"hf:multilens:mirflickr_ambient:Unet4M+U5+Unet4M_concat",
"hf:multilens:mirflickr_ambient:TrainInv+Unet8M",
"hf:multilens:mirflickr_ambient:TrainInv+Unet8M_learned_sub",
"hf:multilens:mirflickr_ambient:Unet4M+TrainInv+Unet4M",
"hf:multilens:mirflickr_ambient:Unet4M+TrainInv+Unet4M_learned_sub",
"hf:multilens:mirflickr_ambient:Unet4M+TrainInv+Unet4M_concat",
"hf:multilens:mirflickr_ambient:TrainInv+Unet8M_direct_sub",
"hf:multilens:mirflickr_ambient:Unet4M+TrainInv+Unet4M_direct_sub",
]

save_idx: [
1, 2, 4, 5, 9, 64, # bottom right
2141, 2155, 2162, 2225, 2502, 2602, # top right (door, flower, cookies, wolf, plush, sky)
3262, 3304, 3438, 3451, 3644, 3667 # bottom left (pancakes, flower, grapes, pencils, bird, sign)
]
n_iter_range: [100] # for ADMM
51 changes: 51 additions & 0 deletions configs/recon_multilens_ambient_mirflickr.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# python scripts/recon/multilens_ambient_mirflickr.py
defaults:
- defaults_recon
- _self_

cache_dir: /dev/shm

## - Uncomment to reconstruct from dataset (screen capture)
idx: 1 # index from test set to reconstruct
fn: null # if not null, set local path or download this file from https://huggingface.co/datasets/Lensless/MultiLens-Mirflickr-Ambient/tree/main
background_fn: null

## - Uncomment to reconstruct plush parrot (direct capture)
# fn: parrot_raw.png
# background_fn: parrot_background.png
# rotate: False
# alignment:
# dim: [160, 160]
# top_left: [110, 200]

## - Uncomment to reconstruct plush monkey (direct capture)
# fn: monkey_raw.png
# background_fn: monkey_background.png
# rotate: False
# alignment:
# dim: [123, 123]
# top_left: [118, 220]

## - Uncomment to reconstruct plant (direct capture)
# fn: plant_raw.png
# background_fn: plant_background.png
# rotate: False
# alignment:
# dim: [200, 200]
# top_left: [60, 186]

## Reconstruction
background_sub: True # whether to subtract background

# -- for learning-based methods (uncommment one line)
model: Unet4M+U5+Unet4M_concat
# model: U5+Unet8M
# model: Unet4M+U5+Unet4M_learned_sub

# # -- for ADMM with fixed parameters (uncomment and comment learning-based methods)
# model: admm
# n_iter: 100

device: cuda:0
n_trials: 1 # to get average inference time
save: True
24 changes: 24 additions & 0 deletions configs/train_mirflickr_multilens_ambient.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# python scripts/recon/train_learning_based.py -cn train_mirflickr_multilens_ambient
defaults:
- train_mirflickr_tape
- _self_

wandb_project: multilens_ambient

# Dataset
files:
dataset: Lensless/MultiLens-Mirflickr-Ambient
cache_dir: /dev/shm
image_res: [600, 600]

reconstruction:
direct_background_subtraction: True

alignment:
# when there is no downsampling
top_left: [118, 220] # height, width
height: 123

optimizer:
type: AdamW
cosine_decay_warmup: True
24 changes: 24 additions & 0 deletions configs/train_mirflickr_tape_ambient.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# python scripts/recon/train_learning_based.py -cn train_mirflickr_tape_ambient
defaults:
- train_mirflickr_tape
- _self_

wandb_project: tapecam_ambient
device_ids:

# Dataset
files:
dataset: Lensless/TapeCam-Mirflickr-Ambient
image_res: [600, 600]

reconstruction:
direct_background_subtraction: True

alignment:
# when there is no downsampling
top_left: [85, 185] # height, width
height: 178

optimizer:
type: AdamW
cosine_decay_warmup: True
33 changes: 33 additions & 0 deletions configs/train_mirflickr_tape_ambient_integrated_sub.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# python scripts/recon/train_learning_based.py -cn train_mirflickr_tape_ambient_integrated_sub
defaults:
- train_mirflickr_tape
- _self_

wandb_project: tapecam_ambient
device_ids: [0, 1, 2, 3]
torch_device: cuda:0

# Dataset
files:
dataset: Lensless/TapeCam-Mirflickr-Ambient # 16K examples
cache_dir: /dev/shm
#dataset: Lensless/TapeCam-Mirflickr-Ambient-100 # 100 examples
image_res: [600, 600]

reconstruction:
# one or the other
direct_background_subtraction: False
learned_background_subtraction: False
integrated_background_subtraction: [32, 64, 128, 210, 210]
down_subtraction: False
pre_process:
network: null # TODO assert null when integrated_background_subtraction is not False

alignment:
# when there is no downsampling
top_left: [85, 185] # height, width
height: 178

optimizer:
type: AdamW
cosine_decay_warmup: True
36 changes: 36 additions & 0 deletions configs/train_mirflickr_tape_ambient_learned_sub.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# python scripts/recon/train_learning_based.py -cn train_mirflickr_tape_ambient_learned_sub
defaults:
- train_mirflickr_tape
- _self_

wandb_project: tapecam_ambient
device_ids: [0, 1 ,2, 3]
torch_device: cuda:0

# Dataset
files:
#n_files: 10
dataset: Lensless/TapeCam-Mirflickr-Ambient # 16K examples
#dataset: Lensless/TapeCam-Mirflickr-Ambient-100 # 100 examples
cache_dir: /dev/shm
image_res: [600, 600]

reconstruction:
# one or the other
direct_background_subtraction: False
learned_background_subtraction: [4, 8, 16, 32] # 127740 parameters, False to turn off
integrated_background_subtraction: False

pre_process: ## Targeting 3923428 parameters
network : UnetRes # UnetRes or DruNet or null
depth : 4 # depth of each up/downsampling layer. Ignore if network is DruNet
nc: [32,64,112,128]

alignment:
# when there is no downsampling
top_left: [85, 185] # height, width
height: 178

optimizer:
type: AdamW
cosine_decay_warmup: True
2 changes: 1 addition & 1 deletion configs/train_tapecam_simulated_background.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# python scripts/recon/train_learning_based.py -cn train_mirflickr_tape
# python scripts/recon/train_learning_based.py -cn train_tapecam_simulated_background
defaults:
- train_mirflickr_tape
- _self_
Expand Down
14 changes: 11 additions & 3 deletions configs/train_unrolledADMM.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,15 @@ reconstruction:
init_pre: True # if `init_processors`, set pre-procesor is available
init_post: True # if `init_processors`, set post-procesor is available

# background subtraction (if dataset has corresponding background images)
direct_background_subtraction: False # True or False
learned_background_subtraction: False # False, or set number of channels for UnetRes, e.g. [8,16,32,64]
integrated_background_subtraction: False # False, or set number of channels for UnetRes, e.g. [8,16,32,64]
down_subtraction: False # for integrated_background_subtraction, whether to concatenate background subtraction during downsample or upsample
integrated_background_unetres: False # whether to integrate within UNetRes
unetres_input_background: False # whether to input background to UNetRes


# Hyperparameters for each method
unrolled_fista: # for unrolled_fista
# Number of iterations
Expand Down Expand Up @@ -181,18 +190,17 @@ training:
crop_preloss: False # crop region for computing loss, files.crop should be set

optimizer:
type: Adam # Adam, SGD... (Pytorch class)
type: AdamW # Adam, SGD... (Pytorch class)
lr: 1e-4
lr_step_epoch: True # True -> update LR at end of each epoch, False at the end of each mini-batch
final_lr: False # if set, exponentially decay *to* this value
exp_decay: False # if set, exponentially decay *with* this value
slow_start: False #float how much to reduce lr for first epoch
cosine_decay_warmup: False # if set, cosine decay with warmup of 5%
cosine_decay_warmup: True # if set, cosine decay with warmup of 5%
# Decay LR in step fashion: https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.StepLR.html
step: False # int, period of learning rate decay. False to not apply
gamma: 0.1 # float, factor for learning rate decay


loss: 'l2'
# set lpips to false to deactivate. Otherwise, give the weigth for the loss (the main loss l2/l1 always having a weigth of 1)
lpips: 1.0
Expand Down
6 changes: 5 additions & 1 deletion configs/upload_multilens_mirflickr_ambient.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,17 @@ defaults:
repo_id: "Lensless/MultiLens-Mirflickr-Ambient"
n_files:
test_size: 0.15

# # -- to match TapeCam dataset content distribution, and same light distribution in train/test
# split: 100 # "first: first `nfiles*test_size` for test, `int`: test_size*split for test (interleaved) as if multimask with this many masks

lensless:
dir: /dev/shm/all_measured_20240813-183259
ambient: True
ext: ".png"

lensed:
dir: data/mirflickr/mirflickr
dir: /root/LenslessPiCam/data/mirflickr/mirflickr
ext: ".jpg"

files:
Expand Down
9 changes: 8 additions & 1 deletion lensless/eval/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def benchmark(
use_wandb=False,
label=None,
epoch=None,
use_background=True,
**kwargs,
):
"""
Expand Down Expand Up @@ -69,6 +70,8 @@ def benchmark(
If True, return the average value of the metrics, by default True.
snr : float, optional
Signal to noise ratio for adding shot noise. If None, no noise is added, by default None.
use_background: bool, optional
If dataset has background, use it for reconstruction, by default True.

Returns
-------
Expand Down Expand Up @@ -121,8 +124,11 @@ def benchmark(

flip_lr = None
flip_ud = None
background = None
lensless = batch[0].to(device)
lensed = batch[1].to(device)
if dataset.measured_bg and use_background:
background = batch[-1].to(device)
if dataset.multimask or dataset.random_flip:
psfs = batch[2]
psfs = psfs.to(device)
Expand All @@ -146,11 +152,12 @@ def benchmark(
plot=False,
save=False,
output_intermediate=unrolled_output_factor or pre_process_aux,
background=background,
**kwargs,
)

else:
prediction = model.forward(lensless, psfs, **kwargs)
prediction = model.forward(lensless, psfs, background=background, **kwargs)

if unrolled_output_factor or pre_process_aux:
pre_process_out = prediction[2]
Expand Down
Loading
Loading