Pre and post denoising (#58)
* Add support for DruNet

(cherry picked from commit 8976a186b00b23d25805deef60a5957cd616cd0b)

* Better gpu selection

* Fix normalization

* Add cuda support for DruNet

* Add baseline result from learned reconstruction

(cherry picked from commit d0d49fd99985595a3e1b419042a0276ff16e5902)

* Support for post ADMM denoising

* Better error

* Fix LPIPS normalization

* Fix LPIPS

* Added more denoising options

* Fix docstrings

* Update Changelog

* Fix name and doc

* Added original repo for DruNet

* Move post process to trainable reconstruction

* Better post processing with torch Module

* Fix trainable recon apply

* Fix downsample limited to 8

* Added Test for post processing

* Fix for PR

* Cleaning for PR

* Add inference with unrolled ADMM

* Add pre process support

* Update changelog

* Add test for preprocessing

* Fix callable assert even with None

* Fix for process = None

* Fix name in log

* More stable training

* Fix NAN during training

* Fix no output during training

* Clean Up

* Cleanup process creation in training

* Add support for reconstruction with denoising

* Fix bug without pre/post process

* Move drunet to recon module
YohannPerron authored Jul 21, 2023
1 parent 82e09b0 commit 0bb0beb
Showing 15 changed files with 1,014 additions and 108 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.rst
@@ -18,6 +18,9 @@ Added
 - Link and citation for JOSS.
 - Authors at top of source code files.
 - Add paramiko as dependency for remote capture and display.
+- Support for preprocessing and postprocessing, such as denoising, in ``TrainableReconstructionAlgorithm``. Both trainable and fixed postprocessing can be used.
+- Utilities to load a trained DruNet model for use as postprocessing in ``TrainableReconstructionAlgorithm``.
+- Support for unrolled loading and inference in the script ``admm.py``.


Changed
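As a rough illustration of the Added entries above (a minimal sketch: ``UnrolledADMM``, its keyword arguments, and the stand-in modules are assumptions based on the config keys in this commit, not the repository's confirmed API), a fixed preprocessor and a trainable postprocessor wrap the unrolled iterations:

```python
import torch

# Hypothetical sketch of the new pre/post processing hooks. Only the idea
# that a trainable reconstruction accepts optional pre_process/post_process
# callables comes from the changelog and configs in this commit.

def pre_process(lensless: torch.Tensor) -> torch.Tensor:
    # fixed (non-trainable) preprocessing: clip negatives, rescale to [0, 1]
    lensless = torch.clamp(lensless, min=0)
    return lensless / lensless.amax(dim=(-1, -2, -3), keepdim=True)

# trainable stand-in for a learned denoiser such as DruNet
post_process = torch.nn.Sequential(
    torch.nn.Conv2d(3, 3, kernel_size=3, padding=1),
    torch.nn.ReLU(),
)

# recon = UnrolledADMM(psf, pre_process=pre_process, post_process=post_process)
```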
3 changes: 2 additions & 1 deletion configs/benchmark.yaml
@@ -5,14 +5,15 @@ hydra:
job:
chdir: True

+device: "cuda"
 # number of iterations to benchmark
 n_iter_range: [5, 10, 30, 60, 100, 200, 300]
 # number of files to benchmark
 n_files: 200
 # how much the image should be downsampled
 downsample: 8
 # algorithms to benchmark
-algorithms: ["ADMM", "ADMM_Monakhova2019", "FISTA", "GradientDescent", "NesterovGradientDescent"]
+algorithms: ["ADMM", "ADMM_Monakhova2019", "FISTA"] #["ADMM", "ADMM_Monakhova2019", "FISTA", "GradientDescent", "NesterovGradientDescent"]

# Hyperparameters
nesterov:
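The sweep implied by this config iterates every algorithm over every iteration count; a minimal sketch (``run_benchmark`` below is a hypothetical stand-in for the repository's benchmark entry point, only the config keys are real):

```python
from omegaconf import OmegaConf

def run_benchmark(algo, **kwargs):  # stand-in; prints instead of benchmarking
    print(algo, kwargs)

cfg = OmegaConf.load("configs/benchmark.yaml")
for algo in cfg.algorithms:          # ["ADMM", "ADMM_Monakhova2019", "FISTA"]
    for n_iter in cfg.n_iter_range:  # [5, 10, ..., 300]
        run_benchmark(
            algo,
            n_iter=n_iter,
            n_files=cfg.n_files,        # 200 files
            downsample=cfg.downsample,  # downsample images by 8
            device=cfg.device,          # "cuda"
        )
```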
9 changes: 9 additions & 0 deletions configs/defaults_recon.yaml
@@ -61,6 +61,15 @@ admm:
mu2: 1e-5
mu3: 4e-5
tau: 0.0001
+  # loading unrolled model
+  unrolled: false
+  checkpoint_fp: null
+  pre_process_model:
+    network: null  # UnetRes or DruNet or null
+    depth: 2  # depth of each up/downsampling layer; ignored if network is DruNet
+  post_process_model:
+    network: null  # UnetRes or DruNet or null
+    depth: 2  # depth of each up/downsampling layer; ignored if network is DruNet

apgd:
# Stopping criteria
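The ``network``/``depth`` keys suggest a small factory that branches on the requested model; a sketch under assumptions (``build_unet_res`` and ``load_drunet`` are hypothetical names, only the config structure is from this commit):

```python
from omegaconf import OmegaConf

cfg = OmegaConf.load("configs/defaults_recon.yaml")
model_cfg = cfg.admm.post_process_model

if model_cfg.network is None:
    post_process = None  # no postprocessing (the default above)
elif model_cfg.network == "UnetRes":
    # hypothetical builder; `depth` sets the number of up/downsampling stages
    post_process = build_unet_res(depth=model_cfg.depth)
elif model_cfg.network == "DruNet":
    # hypothetical loader for pretrained weights; `depth` is ignored
    post_process = load_drunet()
else:
    raise ValueError(f"Unknown network: {model_cfg.network}")
```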
21 changes: 15 additions & 6 deletions configs/unrolled_recon.yaml
@@ -27,7 +27,7 @@ preprocess:
display:
# How many iterations to wait for intermediate plot.
# Set to negative value for no intermediate plots.
-  disp: 100
+  disp: 400
# Whether to plot results.
plot: True
# Gamma factor for plotting.
@@ -54,7 +54,12 @@ reconstruction:
mu2: 1e-4
mu3: 1e-4
tau: 2e-4

+  pre_process:
+    network: UnetRes  # UnetRes or DruNet or null
+    depth: 2  # depth of each up/downsampling layer; ignored if network is DruNet
+  post_process:
+    network: UnetRes  # UnetRes or DruNet or null
+    depth: 2  # depth of each up/downsampling layer; ignored if network is DruNet

# Train Dataset

@@ -90,13 +95,17 @@ simulation:
#Training

training:
-  batch_size: 16
-  epoch: 10
+  batch_size: 8
+  epoch: 50
+  # in case of unstable training
+  skip_NAN: True
+  slow_start: False  # float: how much to reduce the lr for the first epoch


optimizer:
type: Adam
-  lr: 1e-4
+  lr: 1e-6

loss: 'l2'
 # set lpips to false to deactivate. Otherwise, give the weight for the loss (the main l2/l1 loss always has a weight of 1)
-lpips: 0.6
+lpips: 1.0
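The ``skip_NAN`` flag and the ``lpips`` weight fit together in the training step; a minimal sketch (``model``, ``dataloader``, ``optimizer``, and ``lpips_fn`` are assumed placeholders, only the flag semantics come from the config comments above):

```python
import torch

mse = torch.nn.MSELoss()  # loss: 'l2'
lpips_weight = 1.0        # lpips: 1.0 (False would disable the term)

for lensless, lensed in dataloader:
    prediction = model(lensless)
    loss = mse(prediction, lensed)  # the main l2 loss keeps weight 1
    if lpips_weight:
        loss = loss + lpips_weight * lpips_fn(prediction, lensed)

    if torch.isnan(loss):
        # skip_NAN: True -- skip unstable batches instead of corrupting weights
        continue

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```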
9 changes: 7 additions & 2 deletions lensless/eval/benchmark.py
@@ -256,7 +256,12 @@ def benchmark(model, dataset, batchsize=1, metrics=None, **kwargs):
metrics = {
"MSE": MSELoss().to(device),
"MAE": L1Loss().to(device),
"LPIPS": lpip.LearnedPerceptualImagePatchSimilarity(net_type="vgg").to(device),
"LPIPS_Vgg": lpip.LearnedPerceptualImagePatchSimilarity(
net_type="vgg", normalize=True
).to(device),
"LPIPS_Alex": lpip.LearnedPerceptualImagePatchSimilarity(
net_type="alex", normalize=True
).to(device),
"PSNR": psnr.PeakSignalNoiseRatio().to(device),
"SSIM": StructuralSimilarityIndexMeasure().to(device),
"ReconstructionError": None,
@@ -283,7 +288,7 @@ def benchmark(model, dataset, batchsize=1, metrics=None, **kwargs):
prediction = prediction.reshape(-1, *prediction.shape[-3:]).movedim(-1, -3)
lensed = lensed.reshape(-1, *lensed.shape[-3:]).movedim(-1, -3)
# normalization
-        prediction_max = torch.amax(prediction, dim=(1, 2, 3), keepdim=True)
+        prediction_max = torch.amax(prediction, dim=(-1, -2, -3), keepdim=True)
if torch.all(prediction_max != 0):
prediction = prediction / prediction_max
else:
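Why change ``dim=(1, 2, 3)`` to ``dim=(-1, -2, -3)``? For a plain 4D batch the two address the same axes, but negative indices keep pointing at (channels, height, width) even if an extra leading dimension survives the reshape, so the per-image normalization stays correct. A small runnable sketch of the guarded normalization:

```python
import torch

# 5D tensor with an extra leading axis: (-1, -2, -3) still selects C, H, W,
# whereas (1, 2, 3) would wrongly include the second batch-like axis
prediction = torch.rand(2, 4, 3, 64, 64) * 5.0

prediction_max = torch.amax(prediction, dim=(-1, -2, -3), keepdim=True)
if torch.all(prediction_max != 0):
    prediction = prediction / prediction_max  # each image scaled into [0, 1]
```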
