Skip to content

Commit

Permalink
Add models, optional background subtraction.
Browse files Browse the repository at this point in the history
  • Loading branch information
ebezzam committed Jul 9, 2024
1 parent 6f2ff1e commit c0124a4
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 19 deletions.
55 changes: 39 additions & 16 deletions configs/benchmark_digicam_mirflickr_single.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@ defaults:

dataset: HFDataset
batchsize: 4
device: "cuda:0"
device: "cuda:3"

huggingface:
repo: "bezzam/DigiCam-Mirflickr-SingleMask-25K"
cache_dir: /dev/shm
psf: null # null for simulating PSF
image_res: [900, 1200] # used during measurement
rotate: True # if measurement is upside-down
Expand All @@ -20,26 +21,48 @@ huggingface:
height: 200
downsample: 1


algorithms: [
## - simulated PSF (simulation.use_waveprop=True, simulation.deadspace=True)
"hf:digicam:mirflickr_single_25k:U10_wave",
"hf:digicam:mirflickr_single_25k:Unet8M_wave",
"ADMM",
"hf:digicam:mirflickr_single_25k:U5+Unet8M_wave",
"hf:digicam:mirflickr_single_25k:TrainInv+Unet8M_wave",
"hf:digicam:mirflickr_single_25k:U10+Unet8M_wave",
"hf:digicam:mirflickr_single_25k:MMCN4M+Unet4M_wave",
"hf:digicam:mirflickr_single_25k:MWDN8M_wave",
"hf:digicam:mirflickr_single_25k:Unet4M+TrainInv+Unet4M_wave",
"hf:digicam:mirflickr_single_25k:Unet4M+U10+Unet4M_wave"

## -- below models need to set correct PSF simulation
# ## - measured PSF (huggingface.psf=psf_measured.png)
# "hf:digicam:mirflickr_single_25k:Unet4M+U10+Unet4M_measured",
# ## - simulated PSF (simulation.use_waveprop=True, simulation.deadspace=False)
# "hf:digicam:mirflickr_single_25k:Unet4M+U10+Unet4M_wave_nodead",
# ## - simulated PSF (simulation.use_waveprop=False, simulation.deadspace=True)
# "hf:digicam:mirflickr_single_25k:Unet4M+U10+Unet4M",
# ## - simulated PSF (simulation.use_waveprop=False, simulation.deadspace=False)
# "hf:digicam:mirflickr_single_25k:Unet4M+U10+Unet4M_nodead"
"hf:digicam:mirflickr_single_25k:Unet4M+U5+Unet4M_wave",
"hf:digicam:mirflickr_single_25k:Unet2M+MMCN+Unet2M_wave",
"hf:digicam:mirflickr_single_25k:Unet2M+MWDN6M_wave",
"hf:digicam:mirflickr_single_25k:Unet4M+U10+Unet4M_wave",

# ## -- reconstructions trained on other datasets/systems
# "hf:diffusercam:mirflickr:Unet4M+U10+Unet4M",
# "hf:tapecam:mirflickr:Unet4M+U10+Unet4M",
# "hf:digicam:mirflickr_single_25k:Unet4M+U5+Unet4M_wave",
# "hf:digicam:celeba_26k:Unet4M+U5+Unet4M_wave",
]


# algorithms: [
# # ## - simulated PSF (simulation.use_waveprop=True, simulation.deadspace=True)
# # "hf:digicam:mirflickr_single_25k:U10_wave",
# # "hf:digicam:mirflickr_single_25k:Unet8M_wave",
# # "hf:digicam:mirflickr_single_25k:TrainInv+Unet8M_wave",
# # "hf:digicam:mirflickr_single_25k:U10+Unet8M_wave",
# # "hf:digicam:mirflickr_single_25k:Unet4M+TrainInv+Unet4M_wave",
# # "hf:digicam:mirflickr_single_25k:Unet4M+U10+Unet4M_wave"

# ## -- below models need to set correct PSF simulation
# # ## - measured PSF (huggingface.psf=psf_measured.png)
# # "hf:digicam:mirflickr_single_25k:Unet4M+U10+Unet4M_measured",
# # ## - simulated PSF (simulation.use_waveprop=True, simulation.deadspace=False)
# # "hf:digicam:mirflickr_single_25k:Unet4M+U10+Unet4M_wave_nodead",
# # ## - simulated PSF (simulation.use_waveprop=False, simulation.deadspace=True)
# # "hf:digicam:mirflickr_single_25k:Unet4M+U10+Unet4M",
# # ## - simulated PSF (simulation.use_waveprop=False, simulation.deadspace=False)
# # "hf:digicam:mirflickr_single_25k:Unet4M+U10+Unet4M_nodead"
# ]
save_idx: [1, 2, 4, 5, 9]
n_iter_range: [100] # for ADMM

# simulating PSF
simulation:
Expand Down
4 changes: 4 additions & 0 deletions lensless/recon/model_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@
"Unet4M+TrainInv+Unet4M_wave": "bezzam/digicam-celeba-unet4M-trainable-inv-unet4M_wave",
"Unet2M+MMCN+Unet2M_wave": "bezzam/digicam-celeba-unet2M-mmcn-unet2M",
"Unet4M+U5+Unet4M_wave": "bezzam/digicam-celeba-unet4M-unrolled-admm5-unet4M",
"Unet4M+U10+Unet4M_wave": "bezzam/digicam-celeba-unet4M-unrolled-admm10-unet4M",
},
"mirflickr_single_25k": {
# simulated PSF (without waveprop, with deadspace)
Expand All @@ -112,6 +113,8 @@
"Unet2M+MMCN+Unet2M_wave": "bezzam/digicam-mirflickr-single-25k-unet2M-mmcn-unet2M-wave",
"Unet4M+TrainInv+Unet4M_wave": "bezzam/digicam-mirflickr-single-25k-unet4M-trainable-inv-unet4M-wave",
"Unet2M+MWDN6M_wave": "bezzam/digicam-mirflickr-single-25k-unet2M-mwdn-6M",
"Unet4M+U5+Unet4M_wave_aux1": "bezzam/digicam-mirflickr-single-25k-unet4M-unrolled-admm5-unet4M-wave-aux1",
"Unet4M+U5+Unet4M_wave_flips": "bezzam/digicam-mirflickr-single-25k-unet4M-unrolled-admm5-unet4M-wave-flips",
# measured PSF
"Unet4M+U10+Unet4M_measured": "bezzam/digicam-mirflickr-single-25k-unet4M-unrolled-admm10-unet4M-measured",
# simulated PSF (with waveprop, no deadspace)
Expand All @@ -125,6 +128,7 @@
"Unet4M+U10+Unet4M": "bezzam/digicam-mirflickr-multi-25k-unet4M-unrolled-admm10-unet4M",
# simulated PSF (with waveprop, with deadspace)
"Unet4M+U10+Unet4M_wave": "bezzam/digicam-mirflickr-multi-25k-unet4M-unrolled-admm10-unet4M-wave",
"Unet4M+U5+Unet4M_wave": "bezzam/digicam-mirflickr-multi-25k-unet4M-unrolled-admm5-unet4M-wave",
},
},
"tapecam": {
Expand Down
9 changes: 7 additions & 2 deletions lensless/utils/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,12 +471,13 @@ def load_data(
use_3d = psf_fp.endswith(".npy") or psf_fp.endswith(".npz")

# load and process PSF data
psf, bg = load_psf(
bg = None
res = load_psf(
psf_fp,
downsample=downsample,
return_float=return_float,
bg_pix=bg_pix,
return_bg=True,
return_bg=True if bg_pix is not None else False,
flip=flip,
flip_ud=flip_ud,
flip_lr=flip_lr,
Expand All @@ -489,6 +490,10 @@ def load_data(
use_3d=use_3d,
bgr_input=bgr_input,
)
if bg_pix is not None:
psf, bg = res
else:
psf = res

# load and process raw measurement
data = load_image(
Expand Down
2 changes: 1 addition & 1 deletion scripts/data/rename_mirflickr25k.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from lensless.utils.dataset import natural_sort


dir_path = "/root/mirflickr/mirflickr25k"
dir_path = "/dev/shm/mirflickr"

# get all jpg files
files = natural_sort(glob.glob(os.path.join(dir_path, "*.jpg")))
Expand Down

0 comments on commit c0124a4

Please sign in to comment.