From 0bb0beb41b7d87e8b4d0bfb542b71a67480e3f2d Mon Sep 17 00:00:00 2001 From: YohannPerron <73244423+YohannPerron@users.noreply.github.com> Date: Fri, 21 Jul 2023 11:12:28 +0200 Subject: [PATCH] Pre and post denoising (#58) * Add support for DruNet (cherry picked from commit 8976a186b00b23d25805deef60a5957cd616cd0b) * Better GPU selection * Fix normalization * Add CUDA support for DruNet * Add baseline result from learned reconstruction (cherry picked from commit d0d49fd99985595a3e1b419042a0276ff16e5902) * Support for post ADMM denoising * Better error * Fix LPIPS normalization * Fix LPIPS * Add more denoising options * Fix docstrings * Update changelog * Fix name and doc * Add original repo for DruNet * Move post process to trainable reconstruction * Better post processing with torch Module * Fix trainable recon apply * Fix downsample limited to 8 * Add test for post processing * Fix for PR * Cleaning for PR * Add inference with unrolled ADMM * Add pre process support * Update changelog * Add test for preprocessing * Fix callable assert even with None * Fix for process = None * Fix name in log * More stable training * Fix NaN during training * Fix no output during training * Clean up * Clean up process creation in training * Add support for reconstruction with denoising * Fix bug without pre/post process * Move DruNet to recon module --- CHANGELOG.rst | 3 + configs/benchmark.yaml | 3 +- configs/defaults_recon.yaml | 9 + configs/unrolled_recon.yaml | 21 +- lensless/eval/benchmark.py | 9 +- lensless/recon/drunet/basicblock.py | 363 ++++++++++++++++++++++++++ lensless/recon/drunet/network_unet.py | 180 +++++++++++++ lensless/recon/trainable_recon.py | 89 ++++++- lensless/recon/unrolled_admm.py | 28 +- lensless/recon/unrolled_fista.py | 26 -- lensless/utils/image.py | 130 +++++++++ scripts/eval/benchmark_recon.py | 49 +++- scripts/recon/admm.py | 52 +++- scripts/recon/train_unrolled.py | 139 ++++++++-- test/test_algos.py | 21 +- 15 files changed, 1014 insertions(+), 108 deletions(-) create mode 100644 lensless/recon/drunet/basicblock.py create mode 100644 lensless/recon/drunet/network_unet.py diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 0fd9008e..a6072041 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -18,6 +18,9 @@ Added - Link and citation for JOSS. - Authors at top of source code files. - Add paramiko as dependency for remote capture and display. +- Support for preprocessing and postprocessing, such as denoising, in ``TrainableReconstructionAlgorithm``. Both trainable and fixed postprocessing can be used. +- Utilities to load a trained DruNet model for use as postprocessing in ``TrainableReconstructionAlgorithm``. +- Support for unrolled loading and inference in the script ``admm.py``. Changed
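A minimal sketch of the headline feature described above (the PSF tensor is a placeholder, and the weight path mirrors the one expected by ``train_unrolled.py``; the DruNet weights must be downloaded first): a pre-trained DruNet passed as ``post_process`` is wrapped automatically into a denoising step that runs after the unrolled iterations.

import torch
from lensless.recon.unrolled_admm import UnrolledADMM
from lensless.utils.image import load_drunet

psf = torch.rand(1, 32, 64, 3)  # placeholder PSF of shape (depth, height, width, channels)
drunet = load_drunet("data/drunet_color.pth", requires_grad=False)
# nn.Module processes are wrapped internally via process_with_DruNet
recon = UnrolledADMM(psf, n_iter=5, post_process=drunet)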
diff --git a/configs/benchmark.yaml b/configs/benchmark.yaml index 4de33d91..c1169551 100644 --- a/configs/benchmark.yaml +++ b/configs/benchmark.yaml @@ -5,6 +5,7 @@ hydra: job: chdir: True +device: "cuda" # numbers of iterations to benchmark n_iter_range: [5, 10, 30, 60, 100, 200, 300] # number of files to benchmark @@ -12,7 +13,7 @@ n_files: 200 #How much should the image be downsampled downsample: 8 #algorithm to benchmark -algorithms: ["ADMM", "ADMM_Monakhova2019", "FISTA", "GradientDescent", "NesterovGradientDescent"] +algorithms: ["ADMM", "ADMM_Monakhova2019", "FISTA"] #["ADMM", "ADMM_Monakhova2019", "FISTA", "GradientDescent", "NesterovGradientDescent"] # Hyperparameters nesterov: diff --git a/configs/defaults_recon.yaml b/configs/defaults_recon.yaml index 4cb78b97..5cd05d6c 100644 --- a/configs/defaults_recon.yaml +++ b/configs/defaults_recon.yaml @@ -61,6 +61,15 @@ admm: mu2: 1e-5 mu3: 4e-5 tau: 0.0001 + # Loading an unrolled model + unrolled: false + checkpoint_fp: null + pre_process_model: + network : null # UnetRes or DruNet or null + depth : 2 # depth of each up/downsampling layer. Ignored if network is DruNet + post_process_model: + network : null # UnetRes or DruNet or null + depth : 2 # depth of each up/downsampling layer. Ignored if network is DruNet apgd: # Stopping criteria diff --git a/configs/unrolled_recon.yaml b/configs/unrolled_recon.yaml index bfe5ad3c..621e3cfa 100644 --- a/configs/unrolled_recon.yaml +++ b/configs/unrolled_recon.yaml @@ -27,7 +27,7 @@ preprocess: display: # How many iterations to wait for intermediate plot. # Set to negative value for no intermediate plots. - disp: 100 + disp: 400 # Whether to plot results. plot: True # Gamma factor for plotting. @@ -54,7 +54,12 @@ reconstruction: mu2: 1e-4 mu3: 1e-4 tau: 2e-4 - + pre_process: + network : UnetRes # UnetRes or DruNet or null + depth : 2 # depth of each up/downsampling layer. Ignored if network is DruNet + post_process: + network : UnetRes # UnetRes or DruNet or null + depth : 2 # depth of each up/downsampling layer. Ignored if network is DruNet # Train Dataset @@ -90,13 +95,17 @@ simulation: #Training training: - batch_size: 16 - epoch: 10 + batch_size: 8 + epoch: 50 + # In case of unstable training + skip_NAN: True + slow_start: False # float: factor by which to reduce the lr for the first epoch + optimizer: type: Adam - lr: 1e-4 + lr: 1e-6 loss: 'l2' # set lpips to false to deactivate. Otherwise, give the weight for the loss (the main l2/l1 loss always has a weight of 1) -lpips: 0.6 \ No newline at end of file +lpips: 1.0 \ No newline at end of file
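The training section above pairs the main ``l2`` loss (fixed weight of 1) with an LPIPS term weighted by the ``lpips`` value. A minimal sketch of the combined objective, assuming the torchmetrics LPIPS used in ``lensless/eval/benchmark.py`` and NCHW RGB tensors normalized to [0, 1] (the rescaling to [-1, 1] matches the fix in ``train_unrolled.py`` further down):

import torch
from torchmetrics.image import lpip

mse = torch.nn.MSELoss()
loss_lpips = lpip.LearnedPerceptualImagePatchSimilarity(net_type="vgg")
lpips_weight = 1.0  # the `lpips` config value

def total_loss(y_pred, y):
    # LPIPS with default normalize=False expects inputs in [-1, 1]
    return mse(y_pred, y) + lpips_weight * torch.mean(loss_lpips(2 * y_pred - 1, 2 * y - 1))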
diff --git a/lensless/eval/benchmark.py b/lensless/eval/benchmark.py index bd422da1..b4aa6b79 100644 --- a/lensless/eval/benchmark.py +++ b/lensless/eval/benchmark.py @@ -256,7 +256,12 @@ def benchmark(model, dataset, batchsize=1, metrics=None, **kwargs): metrics = { "MSE": MSELoss().to(device), "MAE": L1Loss().to(device), - "LPIPS": lpip.LearnedPerceptualImagePatchSimilarity(net_type="vgg").to(device), + "LPIPS_Vgg": lpip.LearnedPerceptualImagePatchSimilarity( + net_type="vgg", normalize=True + ).to(device), + "LPIPS_Alex": lpip.LearnedPerceptualImagePatchSimilarity( + net_type="alex", normalize=True + ).to(device), "PSNR": psnr.PeakSignalNoiseRatio().to(device), "SSIM": StructuralSimilarityIndexMeasure().to(device), "ReconstructionError": None, @@ -283,7 +288,7 @@ def benchmark(model, dataset, batchsize=1, metrics=None, **kwargs): prediction = prediction.reshape(-1, *prediction.shape[-3:]).movedim(-1, -3) lensed = lensed.reshape(-1, *lensed.shape[-3:]).movedim(-1, -3) # normalization - prediction_max = torch.amax(prediction, dim=(1, 2, 3), keepdim=True) + prediction_max = torch.amax(prediction, dim=(-1, -2, -3), keepdim=True) if torch.all(prediction_max != 0): prediction = prediction / prediction_max else: diff --git a/lensless/recon/drunet/basicblock.py b/lensless/recon/drunet/basicblock.py new file mode 100644 index 00000000..ed17a10b --- /dev/null +++ b/lensless/recon/drunet/basicblock.py @@ -0,0 +1,363 @@ +# ############################################################################# +# basicblock.py +# ================= +# Original Repo: +# https://github.com/cszn/DPIR/blob/15bca3fcc1f3cc51a1f99ccf027691e278c19354/models/basicblock.py +# ############################################################################# + +from collections import OrderedDict +import torch +import torch.nn as nn +import torch.nn.functional as F + + +""" +# -------------------------------------------- +# Advanced nn.Sequential +# https://github.com/xinntao/BasicSR +# -------------------------------------------- +""" + + +def sequential(*args): + """Advanced nn.Sequential. + + Args: + nn.Sequential, nn.Module + + Returns: + nn.Sequential + """ + if len(args) == 1: + if isinstance(args[0], OrderedDict): + raise NotImplementedError("sequential does not support OrderedDict input.") + return args[0] # No sequential is needed. + modules = [] + for module in args: + if isinstance(module, nn.Sequential): + for submodule in module.children(): + modules.append(submodule) + elif isinstance(module, nn.Module): + modules.append(module) + return nn.Sequential(*modules)
+ + +""" +# -------------------------------------------- +# Useful blocks +# https://github.com/xinntao/BasicSR +# -------------------------------- +# conv + normalization + relu (conv) +# (PixelUnShuffle) +# (ConditionalBatchNorm2d) +# concat (ConcatBlock) +# sum (ShortcutBlock) +# resblock (ResBlock) +# Channel Attention (CA) Layer (CALayer) +# Residual Channel Attention Block (RCABlock) +# Residual Channel Attention Group (RCAGroup) +# Residual Dense Block (ResidualDenseBlock_5C) +# Residual in Residual Dense Block (RRDB) +# -------------------------------------------- +""" + + +# -------------------------------------------- +# return nn.Sequential of (Conv + BN + ReLU) +# -------------------------------------------- +def conv( + in_channels=64, + out_channels=64, + kernel_size=3, + stride=1, + padding=1, + bias=True, + mode="CBR", + negative_slope=0.2, +): + L = [] + for t in mode: + if t == "C": + L.append( + nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + bias=bias, + ) + ) + elif t == "T": + L.append( + nn.ConvTranspose2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + bias=bias, + ) + ) + elif t == "B": + L.append(nn.BatchNorm2d(out_channels, momentum=0.9, eps=1e-04, affine=True)) + elif t == "I": + L.append(nn.InstanceNorm2d(out_channels, affine=True)) + elif t == "R": + L.append(nn.ReLU(inplace=True)) + elif t == "r": + L.append(nn.ReLU(inplace=False)) + elif t == "L": + L.append(nn.LeakyReLU(negative_slope=negative_slope, inplace=True)) + elif t == "l": + L.append(nn.LeakyReLU(negative_slope=negative_slope, inplace=False)) + elif t == "2": + L.append(nn.PixelShuffle(upscale_factor=2)) + elif t == "3": + L.append(nn.PixelShuffle(upscale_factor=3)) + elif t == "4": + L.append(nn.PixelShuffle(upscale_factor=4)) + elif t == "U": + L.append(nn.Upsample(scale_factor=2, mode="nearest")) + elif t == "u": + L.append(nn.Upsample(scale_factor=3, mode="nearest")) + elif t == "v": + L.append(nn.Upsample(scale_factor=4, mode="nearest")) + elif t == "M": + L.append(nn.MaxPool2d(kernel_size=kernel_size, stride=stride, padding=0)) + elif t == "A": + L.append(nn.AvgPool2d(kernel_size=kernel_size, stride=stride, padding=0)) + else: + raise NotImplementedError("Undefined type: {}".format(t)) + return sequential(*L)
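A quick usage sketch of the mode-string mini-language above: each character selects a layer, so "CBR" assembles Conv2d -> BatchNorm2d -> ReLU into a single nn.Sequential.

import torch

block = conv(in_channels=3, out_channels=64, mode="CBR")
x = torch.rand(1, 3, 32, 32)
y = block(x)  # (1, 64, 32, 32): kernel_size=3, stride=1, padding=1 preserve H and W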
+ if mode[0] in ["R", "L"]: + mode = mode[0].lower() + mode[1:] + + self.res = conv( + in_channels, out_channels, kernel_size, stride, padding, bias, mode, negative_slope + ) + + def forward(self, x): + # res = self.res(x) + return x + self.res(x) + + +""" +# -------------------------------------------- +# Upsampler +# Kai Zhang, https://github.com/cszn/KAIR +# -------------------------------------------- +# upsample_pixelshuffle +# upsample_upconv +# upsample_convtranspose +# -------------------------------------------- +""" + + +# -------------------------------------------- +# conv + subp (+ relu) +# -------------------------------------------- +def upsample_pixelshuffle( + in_channels=64, + out_channels=3, + kernel_size=3, + stride=1, + padding=1, + bias=True, + mode="2R", + negative_slope=0.2, +): + assert len(mode) < 4 and mode[0] in ["2", "3", "4"], "mode examples: 2, 2R, 2BR, 3, ..., 4BR." + up1 = conv( + in_channels, + out_channels * (int(mode[0]) ** 2), + kernel_size, + stride, + padding, + bias, + mode="C" + mode, + negative_slope=negative_slope, + ) + return up1 + + +# -------------------------------------------- +# nearest_upsample + conv (+ R) +# -------------------------------------------- +def upsample_upconv( + in_channels=64, + out_channels=3, + kernel_size=3, + stride=1, + padding=1, + bias=True, + mode="2R", + negative_slope=0.2, +): + assert len(mode) < 4 and mode[0] in ["2", "3", "4"], "mode examples: 2, 2R, 2BR, 3, ..., 4BR" + if mode[0] == "2": + uc = "UC" + elif mode[0] == "3": + uc = "uC" + elif mode[0] == "4": + uc = "vC" + mode = mode.replace(mode[0], uc) + up1 = conv( + in_channels, + out_channels, + kernel_size, + stride, + padding, + bias, + mode=mode, + negative_slope=negative_slope, + ) + return up1 + + +# -------------------------------------------- +# convTranspose (+ relu) +# -------------------------------------------- +def upsample_convtranspose( + in_channels=64, + out_channels=3, + kernel_size=2, + stride=2, + padding=0, + bias=True, + mode="2R", + negative_slope=0.2, +): + assert len(mode) < 4 and mode[0] in ["2", "3", "4"], "mode examples: 2, 2R, 2BR, 3, ..., 4BR." + kernel_size = int(mode[0]) + stride = int(mode[0]) + mode = mode.replace(mode[0], "T") + up1 = conv(in_channels, out_channels, kernel_size, stride, padding, bias, mode, negative_slope) + return up1 + + +""" +# -------------------------------------------- +# Downsampler +# Kai Zhang, https://github.com/cszn/KAIR +# -------------------------------------------- +# downsample_strideconv +# downsample_maxpool +# downsample_avgpool +# -------------------------------------------- +""" + + +# -------------------------------------------- +# strideconv (+ relu) +# -------------------------------------------- +def downsample_strideconv( + in_channels=64, + out_channels=64, + kernel_size=2, + stride=2, + padding=0, + bias=True, + mode="2R", + negative_slope=0.2, +): + assert len(mode) < 4 and mode[0] in ["2", "3", "4"], "mode examples: 2, 2R, 2BR, 3, ..., 4BR." 
+ kernel_size = int(mode[0]) + stride = int(mode[0]) + mode = mode.replace(mode[0], "C") + down1 = conv( + in_channels, out_channels, kernel_size, stride, padding, bias, mode, negative_slope + ) + return down1 + + +# -------------------------------------------- +# maxpooling + conv (+ relu) +# -------------------------------------------- +def downsample_maxpool( + in_channels=64, + out_channels=64, + kernel_size=3, + stride=1, + padding=0, + bias=True, + mode="2R", + negative_slope=0.2, +): + assert len(mode) < 4 and mode[0] in ["2", "3"], "mode examples: 2, 2R, 2BR, 3, ..., 3BR." + kernel_size_pool = int(mode[0]) + stride_pool = int(mode[0]) + mode = mode.replace(mode[0], "MC") + pool = conv( + kernel_size=kernel_size_pool, + stride=stride_pool, + mode=mode[0], + negative_slope=negative_slope, + ) + pool_tail = conv( + in_channels, + out_channels, + kernel_size, + stride, + padding, + bias, + mode=mode[1:], + negative_slope=negative_slope, + ) + return sequential(pool, pool_tail) + + +# -------------------------------------------- +# averagepooling + conv (+ relu) +# -------------------------------------------- +def downsample_avgpool( + in_channels=64, + out_channels=64, + kernel_size=3, + stride=1, + padding=1, + bias=True, + mode="2R", + negative_slope=0.2, +): + assert len(mode) < 4 and mode[0] in ["2", "3"], "mode examples: 2, 2R, 2BR, 3, ..., 3BR." + kernel_size_pool = int(mode[0]) + stride_pool = int(mode[0]) + mode = mode.replace(mode[0], "AC") + pool = conv( + kernel_size=kernel_size_pool, + stride=stride_pool, + mode=mode[0], + negative_slope=negative_slope, + ) + pool_tail = conv( + in_channels, + out_channels, + kernel_size, + stride, + padding, + bias, + mode=mode[1:], + negative_slope=negative_slope, + ) + return sequential(pool, pool_tail) diff --git a/lensless/recon/drunet/network_unet.py b/lensless/recon/drunet/network_unet.py new file mode 100644 index 00000000..6f9c390e --- /dev/null +++ b/lensless/recon/drunet/network_unet.py @@ -0,0 +1,180 @@ +# ############################################################################# +# network_unet.py +# ================= +# Original Repo: +# https://github.com/cszn/DPIR/blob/15bca3fcc1f3cc51a1f99ccf027691e278c19354/models/network_unet.py +# ############################################################################# + +import torch +import torch.nn as nn +import lensless.recon.drunet.basicblock as B +import numpy as np + +""" +# ==================== +# unet +# From https://github.com/cszn/DPIR/blob/master/main_dpir_denoising.py +# ==================== +""" + + +class UNet(nn.Module): + def __init__( + self, + in_nc=1, + out_nc=1, + nc=[64, 128, 256, 512], + nb=2, + act_mode="R", + downsample_mode="strideconv", + upsample_mode="convtranspose", + ): + super(UNet, self).__init__() + + self.m_head = B.conv(in_nc, nc[0], mode="C" + act_mode[-1]) + + # downsample + if downsample_mode == "avgpool": + downsample_block = B.downsample_avgpool + elif downsample_mode == "maxpool": + downsample_block = B.downsample_maxpool + elif downsample_mode == "strideconv": + downsample_block = B.downsample_strideconv + else: + raise NotImplementedError("downsample mode [{:s}] is not found".format(downsample_mode)) + + self.m_down1 = B.sequential( + *[B.conv(nc[0], nc[0], mode="C" + act_mode) for _ in range(nb)], + downsample_block(nc[0], nc[1], mode="2" + act_mode) + ) + self.m_down2 = B.sequential( + *[B.conv(nc[1], nc[1], mode="C" + act_mode) for _ in range(nb)], + downsample_block(nc[1], nc[2], mode="2" + act_mode) + ) + self.m_down3 = 
B.sequential( + *[B.conv(nc[2], nc[2], mode="C" + act_mode) for _ in range(nb)], + downsample_block(nc[2], nc[3], mode="2" + act_mode) + ) + + self.m_body = B.sequential( + *[B.conv(nc[3], nc[3], mode="C" + act_mode) for _ in range(nb + 1)] + ) + + # upsample + if upsample_mode == "upconv": + upsample_block = B.upsample_upconv + elif upsample_mode == "pixelshuffle": + upsample_block = B.upsample_pixelshuffle + elif upsample_mode == "convtranspose": + upsample_block = B.upsample_convtranspose + else: + raise NotImplementedError("upsample mode [{:s}] is not found".format(upsample_mode)) + + self.m_up3 = B.sequential( + upsample_block(nc[3], nc[2], mode="2" + act_mode), + *[B.conv(nc[2], nc[2], mode="C" + act_mode) for _ in range(nb)] + ) + self.m_up2 = B.sequential( + upsample_block(nc[2], nc[1], mode="2" + act_mode), + *[B.conv(nc[1], nc[1], mode="C" + act_mode) for _ in range(nb)] + ) + self.m_up1 = B.sequential( + upsample_block(nc[1], nc[0], mode="2" + act_mode), + *[B.conv(nc[0], nc[0], mode="C" + act_mode) for _ in range(nb)] + ) + + self.m_tail = B.conv(nc[0], out_nc, bias=True, mode="C") + + def forward(self, x0): + x1 = self.m_head(x0) + x2 = self.m_down1(x1) + x3 = self.m_down2(x2) + x4 = self.m_down3(x3) + x = self.m_body(x4) + x = self.m_up3(x + x4) + x = self.m_up2(x + x3) + x = self.m_up1(x + x2) + x = self.m_tail(x + x1) + x0 + + return x + + +class UNetRes(nn.Module): + def __init__( + self, + in_nc=1, + out_nc=1, + nc=[64, 128, 256, 512], + nb=4, + act_mode="R", + downsample_mode="strideconv", + upsample_mode="convtranspose", + ): + super(UNetRes, self).__init__() + + self.m_head = B.conv(in_nc, nc[0], bias=False, mode="C") + + # downsample + if downsample_mode == "avgpool": + downsample_block = B.downsample_avgpool + elif downsample_mode == "maxpool": + downsample_block = B.downsample_maxpool + elif downsample_mode == "strideconv": + downsample_block = B.downsample_strideconv + else: + raise NotImplementedError("downsample mode [{:s}] is not found".format(downsample_mode)) + + self.m_down1 = B.sequential( + *[B.ResBlock(nc[0], nc[0], bias=False, mode="C" + act_mode + "C") for _ in range(nb)], + downsample_block(nc[0], nc[1], bias=False, mode="2") + ) + self.m_down2 = B.sequential( + *[B.ResBlock(nc[1], nc[1], bias=False, mode="C" + act_mode + "C") for _ in range(nb)], + downsample_block(nc[1], nc[2], bias=False, mode="2") + ) + self.m_down3 = B.sequential( + *[B.ResBlock(nc[2], nc[2], bias=False, mode="C" + act_mode + "C") for _ in range(nb)], + downsample_block(nc[2], nc[3], bias=False, mode="2") + ) + + self.m_body = B.sequential( + *[B.ResBlock(nc[3], nc[3], bias=False, mode="C" + act_mode + "C") for _ in range(nb)] + ) + + # upsample + if upsample_mode == "upconv": + upsample_block = B.upsample_upconv + elif upsample_mode == "pixelshuffle": + upsample_block = B.upsample_pixelshuffle + elif upsample_mode == "convtranspose": + upsample_block = B.upsample_convtranspose + else: + raise NotImplementedError("upsample mode [{:s}] is not found".format(upsample_mode)) + + self.m_up3 = B.sequential( + upsample_block(nc[3], nc[2], bias=False, mode="2"), + *[B.ResBlock(nc[2], nc[2], bias=False, mode="C" + act_mode + "C") for _ in range(nb)] + ) + self.m_up2 = B.sequential( + upsample_block(nc[2], nc[1], bias=False, mode="2"), + *[B.ResBlock(nc[1], nc[1], bias=False, mode="C" + act_mode + "C") for _ in range(nb)] + ) + self.m_up1 = B.sequential( + upsample_block(nc[1], nc[0], bias=False, mode="2"), + *[B.ResBlock(nc[0], nc[0], bias=False, mode="C" + act_mode + "C") for _ in range(nb)] + ) + + self.m_tail = B.conv(nc[0], out_nc, bias=False, mode="C") + + def forward(self, x0): + x1 = self.m_head(x0) + x2 = self.m_down1(x1) + x3 = self.m_down2(x2) + x4 = self.m_down3(x3) + x = self.m_body(x4) + x = self.m_up3(x + x4) + x = self.m_up2(x + x3) + x = self.m_up1(x + x2) + x = self.m_tail(x + x1) + + return x
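A shape sketch for UNetRes as configured by ``load_drunet`` in this PR (DruNet-style: a fourth input channel carries the noise level, and H and W must be multiples of 8 because of the three 2x downsamplings):

import torch

net = UNetRes(in_nc=3 + 1, out_nc=3, nc=[64, 128, 256, 512], nb=4, act_mode="R")
x = torch.rand(1, 4, 64, 64)  # RGB + constant noise-level channel
out = net(x)  # (1, 3, 64, 64)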
diff --git a/lensless/recon/trainable_recon.py b/lensless/recon/trainable_recon.py index 1a27d07d..a6268032 100644 --- a/lensless/recon/trainable_recon.py +++ b/lensless/recon/trainable_recon.py @@ -41,7 +41,15 @@ class TrainableReconstructionAlgorithm(ReconstructionAlgorithm, torch.nn.Module) """ - def __init__(self, psf, dtype=None, n_iter=1, **kwargs): + def __init__( + self, + psf, + dtype=None, + n_iter=1, + pre_process=None, + post_process=None, + **kwargs, + ): """ Base constructor. Derived constructor may define new state variables here and also reset them in `reset`. @@ -58,29 +66,93 @@ def __init__(self, psf, dtype=None, n_iter=1, **kwargs): Data type to use for optimization. n_iter : int Number of iterations for unrolled algorithm. + pre_process : :py:class:`function` or :py:class:`~torch.nn.Module`, optional + If :py:class:`function` : Function to apply to the image estimate before the algorithm. Its input must be (image to process, noise_level), where noise_level is a learnable parameter. If it includes additional learnable parameters, they will not be added to the parameter list of the algorithm. To allow for training, the function must be autograd compatible. + If :py:class:`~torch.nn.Module` : A DruNet compatible network to apply to the image estimate before the algorithm. See ``utils.image.apply_denoiser`` for more details. The network will be included as a submodule of the algorithm and its parameters will be added to the parameter list of the algorithm. If this isn't the intended behavior, set requires_grad=False. + post_process : :py:class:`function` or :py:class:`~torch.nn.Module`, optional + If :py:class:`function` : Function to apply to the image estimate after the whole algorithm. Its input must be (image to process, noise_level), where noise_level is a learnable parameter. If it includes additional learnable parameters, they will not be added to the parameter list of the algorithm. To allow for training, the function must be autograd compatible. + If :py:class:`~torch.nn.Module` : A DruNet compatible network to apply to the image estimate after the whole algorithm. See ``utils.image.apply_denoiser`` for more details. The network will be included as a submodule of the algorithm and its parameters will be added to the parameter list of the algorithm. If this isn't the intended behavior, set requires_grad=False. """ assert isinstance(psf, torch.Tensor), "PSF must be a torch.Tensor" super(TrainableReconstructionAlgorithm, self).__init__( psf, dtype=dtype, n_iter=n_iter, **kwargs ) - @abc.abstractmethod + # pre processing + ( + self.pre_process, + self.pre_process_model, + self.pre_process_param, + ) = self._prepare_process_block(pre_process) + + # post processing + ( + self.post_process, + self.post_process_model, + self.post_process_param, + ) = self._prepare_process_block(post_process) + + def _prepare_process_block(self, process): + """ + Method for preparing the pre or post process block. + Parameters + ---------- + process : :py:class:`function` or :py:class:`~torch.nn.Module`, optional + Pre or post process block to prepare. + """ + if isinstance(process, torch.nn.Module): + # If the process is a torch module, we assume it is a DruNet-like network.
+ from lensless.utils.image import process_with_DruNet + + process_model = process + process_function = process_with_DruNet(process_model, self._psf.device, mode="train") + elif process is not None: + # Otherwise, we assume it is a function. + assert callable(process), "process must be a callable function" + process_function = process + process_model = None + else: + process_function = None + process_model = None + if process_function is not None: + process_param = torch.nn.Parameter(torch.tensor([1.0], device=self._psf.device)) + else: + process_param = None + + return process_function, process_model, process_param + def batch_call(self, batch): """ Method for performing iterative reconstruction on a batch of images. - This implementation simply calls `apply` on each image in the batch. - Training algorithms are expected to override this method with a properly vectorized implementation. + This implementation is properly vectorized and includes the pre and post processing steps. Parameters ---------- batch : :py:class:`~torch.Tensor` of shape (batch, depth, height, width, channels) The lensless images to reconstruct. Returns ------- :py:class:`~torch.Tensor` of shape (batch, depth, height, width, channels) The reconstructed images. """ + self._data = batch + assert len(self._data.shape) == 5, "batch must be of shape (N, D, H, W, C)" + batch_size = batch.shape[0] + + # pre process data + if self.pre_process is not None: + self._data = self.pre_process(self._data, self.pre_process_param) + + self.reset(batch_size=batch_size) + + for i in range(self._n_iter): + self._update(i) + + image_est = self._form_image() + if self.post_process is not None: + image_est = self.post_process(image_est, self.post_process_param) + return image_est def apply( self, disp_iter=10, plot_pause=0.2, plot=True, save=False, gamma=None, ax=None, reset=True @@ -118,6 +190,9 @@ def apply( returning if `plot` or `save` is True. """ + if self.pre_process is not None: + self._data = self.pre_process(self._data, self.pre_process_param) + im = super(TrainableReconstructionAlgorithm, self).apply( n_iter=self._n_iter, disp_iter=disp_iter, plot_pause=plot_pause, plot=plot, save=save, gamma=gamma, @@ -128,4 +203,6 @@ def apply( ax=ax, reset=reset, ) + if self.post_process is not None: + im = self.post_process(im, self.post_process_param) return im diff --git a/lensless/recon/unrolled_admm.py b/lensless/recon/unrolled_admm.py index 06a9f5d6..43b6b956 100644 --- a/lensless/recon/unrolled_admm.py +++ b/lensless/recon/unrolled_admm.py @@ -8,6 +8,7 @@ from lensless.recon.trainable_recon import TrainableReconstructionAlgorithm from lensless.recon.admm import soft_thresh, finite_diff, finite_diff_adj, finite_diff_gram + try: import torch @@ -215,30 +216,3 @@ def _form_image(self): image = self._convolver._crop(self._image_est) image[image < 0] = 0 return image - - def batch_call(self, batch): - """ - Method for performing iterative reconstruction on a batch of images. - This implementation is a properly vectorized implementation of ADMM. - - Parameters - ---------- - batch : :py:class:`~torch.Tensor` of shape (N, D, C, H, W) - The lensless images to reconstruct. - - Returns - ------- - :py:class:`~torch.Tensor` of shape (N, D, C, H, W) - The reconstructed images. - """ - self._data = batch - assert len(self._data.shape) == 5, "batch must be of shape (N, D, C, H, W)" - batch_size = batch.shape[0] - - self.reset(batch_size=batch_size) - - for i in range(self._n_iter): - self._update(i) - - image_est = self._form_image() - return image_est
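Mirroring test/test_algos.py at the end of this patch: any callable with signature (image, noise_level) works as a pre/post process, and ``batch_call`` now runs pre process -> unrolled iterations -> post process. A minimal sketch with identity processes:

import torch
from lensless.recon.unrolled_admm import UnrolledADMM

def identity(x, noise_level):
    return x  # stand-in for a real denoiser

psf = torch.rand(1, 32, 64, 3)
recon = UnrolledADMM(psf, n_iter=5, pre_process=identity, post_process=identity)
batch = torch.rand(2, 1, 32, 64, 3)  # (N, D, H, W, C)
out = recon.batch_call(batch)  # reconstructed batch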
- """ - self._data = batch - assert len(self._data.shape) == 5, "batch must be of shape (N, D, C, H, W)" - batch_size = batch.shape[0] - - self.reset(batch_size=batch_size) - - for i in range(self._n_iter): - self._update(i) - - image_est = self._form_image() - return image_est diff --git a/lensless/recon/unrolled_fista.py b/lensless/recon/unrolled_fista.py index 75824812..1361cda1 100644 --- a/lensless/recon/unrolled_fista.py +++ b/lensless/recon/unrolled_fista.py @@ -99,29 +99,3 @@ def _update(self, iter): xk = self._proj(self._image_est) self._image_est = xk + (self._tk[iter] - 1) / self._tk[iter + 1] * (xk - self._xk) self._xk = xk - - def batch_call(self, batch): - """ - Method for performing iterative reconstruction on a batch of images. - This implementation is a properly vectorized implementation of FISTA. - - Parameters - ---------- - batch : :py:class:`~torch.Tensor` of shape (N, D, C, H, W) - The lensless images to reconstruct. - - Returns - ------- - :py:class:`~torch.Tensor` of shape (N, D, C, H, W) - The reconstructed images. - """ - self._data = batch - assert len(self._data.shape) == 5, "Input must be of shape (N, D, H, W, C)" - batch_size = batch.shape[0] - - self.reset(batch_size) - - for i in range(self._n_iter): - self._update(i) - - return self._proj(self._image_est) diff --git a/lensless/utils/image.py b/lensless/utils/image.py index b267bb75..c72ca10b 100644 --- a/lensless/utils/image.py +++ b/lensless/utils/image.py @@ -276,3 +276,133 @@ def autocorr2d(vals, pad_mode="reflect"): # remove padding return autocorr[shape[0] // 2 : -shape[0] // 2, shape[1] // 2 : -shape[1] // 2] + + +def load_drunet(model_path, n_channels=3, requires_grad=False): + """ + Load a pre-trained Drunet model. + + Parameters + ---------- + model_path : str + Path to pre-trained model. + n_channels : int + Number of channels in input image. + requires_grad : bool + Whether to require gradients for model parameters. + + Returns + ------- + model : :py:class:`~torch.nn.Module` + Loaded model. + """ + from lensless.recon.drunet.network_unet import UNetRes + + model = UNetRes( + in_nc=n_channels + 1, + out_nc=n_channels, + nc=[64, 128, 256, 512], + nb=4, + act_mode="R", + downsample_mode="strideconv", + upsample_mode="convtranspose", + ) + model.load_state_dict(torch.load(model_path), strict=True) + model.eval() + for k, v in model.named_parameters(): + v.requires_grad = requires_grad + + return model + + +def apply_denoiser(model, image, noise_level=10, device="cpu", mode="inference"): + """ + Apply a pre-trained denoising model with input in the format Channel, Height, Width. + An additionnal channel is added for the noise level as done in Drunet. + + Parameters + ---------- + model : :py:class:`~torch.nn.Module` + Drunet compatible model. Its input must concist of 4 channels ( RGB + noise level) and outbut an RGB image both in CHW format. + image : :py:class:`~torch.Tensor` + Input image. + noise_level : float or :py:class:`~torch.Tensor` + Noise level in the image. + device : str + Device to use for computation. Can be "cpu" or "cuda". + mode : str + Mode to use for model. Can be "inference" or "train". + + Returns + ------- + image : :py:class:`~torch.Tensor` + Reconstructed image. 
+ """ + # convert from NDHWC to NCHW + depth = image.shape[-4] + image = image.movedim(-1, -3) + image = image.reshape(-1, *image.shape[-3:]) + # pad image H and W to next multiple of 8 + top = (8 - image.shape[-2] % 8) // 2 + bottom = (8 - image.shape[-2] % 8) - top + left = (8 - image.shape[-1] % 8) // 2 + right = (8 - image.shape[-1] % 8) - left + image = torch.nn.functional.pad(image, (left, right, top, bottom), mode="constant", value=0) + # add noise level as extra channel + image = image.to(device) + if isinstance(noise_level, torch.Tensor): + noise_level = noise_level / 255.0 + else: + noise_level = torch.tensor([noise_level / 255.0]).to(device) + image = torch.cat( + ( + image, + noise_level.repeat(image.shape[0], 1, image.shape[2], image.shape[3]), + ), + dim=1, + ) + + # apply model + if mode == "inference": + with torch.no_grad(): + image = model(image) + elif mode == "train": + image = model(image) + else: + raise ValueError("mode must be 'inference' or 'train'") + + # remove padding + image = image[:, :, top:-bottom, left:-right] + # convert back to NDHWC + image = image.movedim(-3, -1) + image = image.reshape(-1, depth, *image.shape[-3:]) + return image + + +def process_with_DruNet(model, device="cpu", mode="inference"): + """ + Return a porcessing function that applies the DruNet model to an image. + + Parameters + ---------- + model : torch.nn.Module + DruNet like denoiser model + device : str + Device to use for computation. Can be "cpu" or "cuda". + mode : str + Mode to use for model. Can be "inference" or "train". + """ + + def process(image, noise_level): + x_max = torch.amax(image, dim=(-2, -3), keepdim=True) + 1e-6 + image = apply_denoiser( + model, + image, + noise_level=noise_level, + device=device, + mode="train", + ) + image = torch.clip(image, min=0.0) * x_max + return image + + return process diff --git a/scripts/eval/benchmark_recon.py b/scripts/eval/benchmark_recon.py index 4d81dcea..b79647da 100644 --- a/scripts/eval/benchmark_recon.py +++ b/scripts/eval/benchmark_recon.py @@ -37,8 +37,8 @@ def benchmark_recon(config): n_iter_range = config.n_iter_range # check if GPU is available - if torch.cuda.is_available(): - device = "cuda" + if torch.cuda.is_available() and config.device[:4] == "cuda": + device = config.device else: device = "cpu" @@ -113,8 +113,15 @@ def benchmark_recon(config): else: unrolled_results[model_name][metric] = result[metric] + # Baseline results + baseline_results = { + "MSE": 0.0618, + "LPIPS_Alex": 0.4434, + "ReconstructionError": 13.70, + } + # for each metrics plot the results comparing each model - metrics_to_plot = ["SSIM", "PSNR", "MSE", "LPIPS", "ReconstructionError"] + metrics_to_plot = ["SSIM", "PSNR", "MSE", "LPIPS_Vgg", "LPIPS_Alex", "ReconstructionError"] for metric in metrics_to_plot: plt.figure() # plot benchmarked algorithm @@ -124,6 +131,16 @@ def benchmark_recon(config): [result[metric] for result in results[model_name]], label=model_name, ) + # plot baseline as horizontal dotted line + if metric in baseline_results.keys(): + plt.hlines( + baseline_results[metric], + 0, + max(n_iter_range), + linestyles="dashed", + label="Unrolled MONAKHOVA 5iter", + color="orange", + ) # plot unrolled algorithms results color_list = ["red", "green", "blue", "orange", "purple"] @@ -136,8 +153,10 @@ def benchmark_recon(config): plot_name = model_name # set color depending on plot name using same color for same algorithm + first = False if plot_name not in algorithm_colors.keys(): algorithm_colors[plot_name] = color_list.pop() + first = 
diff --git a/scripts/eval/benchmark_recon.py b/scripts/eval/benchmark_recon.py index 4d81dcea..b79647da 100644 --- a/scripts/eval/benchmark_recon.py +++ b/scripts/eval/benchmark_recon.py @@ -37,8 +37,8 @@ def benchmark_recon(config): n_iter_range = config.n_iter_range # check if GPU is available - if torch.cuda.is_available(): - device = "cuda" + if torch.cuda.is_available() and config.device[:4] == "cuda": + device = config.device else: device = "cpu" @@ -113,8 +113,15 @@ def benchmark_recon(config): else: unrolled_results[model_name][metric] = result[metric] + # Baseline results + baseline_results = { + "MSE": 0.0618, + "LPIPS_Alex": 0.4434, + "ReconstructionError": 13.70, + } + # for each metric plot the results comparing each model - metrics_to_plot = ["SSIM", "PSNR", "MSE", "LPIPS", "ReconstructionError"] + metrics_to_plot = ["SSIM", "PSNR", "MSE", "LPIPS_Vgg", "LPIPS_Alex", "ReconstructionError"] for metric in metrics_to_plot: plt.figure() # plot benchmarked algorithm for model_name in results.keys(): plt.plot( n_iter_range, [result[metric] for result in results[model_name]], label=model_name, ) + # plot baseline as horizontal dotted line + if metric in baseline_results.keys(): + plt.hlines( + baseline_results[metric], + 0, + max(n_iter_range), + linestyles="dashed", + label="Unrolled MONAKHOVA 5iter", + color="orange", + ) # plot unrolled algorithms results color_list = ["red", "green", "blue", "orange", "purple"] @@ -136,8 +153,10 @@ def benchmark_recon(config): plot_name = model_name # set color depending on plot name using same color for same algorithm + first = False if plot_name not in algorithm_colors.keys(): algorithm_colors[plot_name] = color_list.pop() + first = True color = algorithm_colors[plot_name] # check if metric is defined @@ -155,17 +174,25 @@ def benchmark_recon(config): ) else: # plot as point - plt.plot( - unrolled_results[model_name]["n_iter"], - unrolled_results[model_name][metric], - label=plot_name, - marker="o", - color=color, - ) + if first: + plt.plot( + unrolled_results[model_name]["n_iter"], + unrolled_results[model_name][metric], + label=plot_name, + marker="o", + color=color, + ) + else: + plt.plot( + unrolled_results[model_name]["n_iter"], + unrolled_results[model_name][metric], + marker="o", + color=color, + ) plt.title(metric) plt.xlabel("Number of iterations") plt.ylabel(metric) - plt.legend() + plt.legend(fontsize=8) plt.savefig(f"{metric}.png") diff --git a/scripts/recon/admm.py b/scripts/recon/admm.py index cb63be9c..17a88461 100644 --- a/scripts/recon/admm.py +++ b/scripts/recon/admm.py @@ -19,6 +19,12 @@ @hydra.main(version_base=None, config_path="../../configs", config_name="defaults_recon") def admm(config): + if config.torch: + try: + import torch + except ImportError: + raise ImportError("Pytorch not found. Please install pytorch to use torch mode.") + psf, data = load_data( psf_fp=to_absolute_path(config.input.psf), data_fp=to_absolute_path(config.input.data), @@ -46,18 +52,48 @@ def admm(config): save = os.getcwd() start_time = time.time() - recon = ADMM(psf, **config.admm) + if not config.admm.unrolled: + recon = ADMM(psf, **config.admm) + else: + assert config.torch, "Unrolled ADMM only works with torch" + from lensless.recon.unrolled_admm import UnrolledADMM + import train_unrolled + + pre_process, _ = train_unrolled.create_process_network( + network=config.admm.pre_process_model.network, + depth=config.admm.pre_process_model.depth, + device=config.torch_device, + ) + post_process, _ = train_unrolled.create_process_network( + network=config.admm.post_process_model.network, + depth=config.admm.post_process_model.depth, + device=config.torch_device, + ) + + recon = UnrolledADMM(psf, pre_process=pre_process, post_process=post_process, **config.admm) + path = to_absolute_path(config.admm.checkpoint_fp) + print("Loading checkpoint from:", path) + assert os.path.exists(path), "Checkpoint does not exist" + recon.load_state_dict(torch.load(path, map_location=config.torch_device)) recon.set_data(data) print(f"Setup time : {time.time() - start_time} s") start_time = time.time() - res = recon.apply( - n_iter=config["admm"]["n_iter"], - disp_iter=disp, - save=save, - gamma=config["display"]["gamma"], - plot=config["display"]["plot"], - ) + if config.torch: + with torch.no_grad(): + res = recon.apply( + disp_iter=disp, + save=save, + gamma=config["display"]["gamma"], + plot=config["display"]["plot"], + ) + else: + res = recon.apply( + disp_iter=disp, + save=save, + gamma=config["display"]["gamma"], + plot=config["display"]["plot"], + ) print(f"Processing time : {time.time() - start_time} s") if config.torch: diff --git a/scripts/recon/train_unrolled.py b/scripts/recon/train_unrolled.py index 1db64a00..8972bd49 100644 --- a/scripts/recon/train_unrolled.py +++ b/scripts/recon/train_unrolled.py @@ -14,6 +14,7 @@ """ +import math import hydra from hydra.utils import get_original_cwd import os @@ -85,6 +86,35 @@ def simulate_dataset(config, psf): return ds_loader +def create_process_network(network, depth, device="cpu"): + if network == "DruNet": + from lensless.utils.image import load_drunet + + process = load_drunet( + os.path.join(get_original_cwd(), "data/drunet_color.pth"), requires_grad=True + ).to(device) + process_name = "DruNet"
= "DruNet" + elif network == "UnetRes": + from lensless.recon.drunet.network_unet import UNetRes + + n_channels = 3 + process = UNetRes( + in_nc=n_channels + 1, + out_nc=n_channels, + nc=[64, 128, 256, 512], + nb=depth, + act_mode="R", + downsample_mode="strideconv", + upsample_mode="convtranspose", + ).to(device) + process_name = "UnetRes_d" + str(depth) + else: + process = None + process_name = None + + return (process, process_name) + + def measure_gradient(model): # return the L2 norm of the gradient total_norm = 0.0 @@ -113,7 +143,9 @@ def train_unrolled( # benchmarking dataset: path = os.path.join(get_original_cwd(), "data") - benchmark_dataset = DiffuserCamTestDataset(data_dir=path) + benchmark_dataset = DiffuserCamTestDataset( + data_dir=path, downsample=config.simulation.downsample + ) psf = benchmark_dataset.psf.to(device) background = benchmark_dataset.background @@ -135,6 +167,20 @@ def train_unrolled( save = os.getcwd() start_time = time.time() + + # Load pre process model + pre_process, pre_process_name = create_process_network( + config.reconstruction.pre_process.network, + config.reconstruction.pre_process.depth, + device=device, + ) + # Load post process model + post_process, post_process_name = create_process_network( + config.reconstruction.post_process.network, + config.reconstruction.post_process.depth, + device=device, + ) + # create reconstruction algorithm if config.reconstruction.method == "unrolled_fista": recon = UnrolledFISTA( psf, @@ -142,6 +188,8 @@ def train_unrolled( tk=config.reconstruction.unrolled_fista.tk, pad=True, learn_tk=config.reconstruction.unrolled_fista.learn_tk, + pre_process=pre_process, + post_process=post_process, ).to(device) n_iter = config.reconstruction.unrolled_fista.n_iter elif config.reconstruction.method == "unrolled_admm": @@ -152,11 +200,22 @@ def train_unrolled( mu2=config.reconstruction.unrolled_admm.mu2, mu3=config.reconstruction.unrolled_admm.mu3, tau=config.reconstruction.unrolled_admm.tau, + pre_process=pre_process, + post_process=post_process, ).to(device) n_iter = config.reconstruction.unrolled_admm.n_iter else: raise ValueError(f"{config.reconstruction.method} is not a supported algorithm") + # constructing algorithm name by appending pre and post process + algorithm_name = config.reconstruction.method + if config.reconstruction.pre_process.network is not None: + algorithm_name = pre_process_name + "_" + algorithm_name + if config.reconstruction.post_process.network is not None: + algorithm_name += "_" + post_process_name + + # print number of parameters + print(f"Training model with {sum(p.numel() for p in recon.parameters())} parameters") # transform from BGR to RGB transform_BRG2RGB = transforms.Lambda(lambda x: x[..., [2, 1, 0]]) @@ -212,36 +271,63 @@ def train_unrolled( # optimizer if config.optimizer.type == "Adam": - optimizer = torch.optim.Adam(recon.parameters(), lr=config.optimizer.lr) + # the parameters of the base model and non torch.Module process must be added separatly + parameters = [{"params": recon.parameters()}] + optimizer = torch.optim.Adam(parameters, lr=config.optimizer.lr) else: raise ValueError(f"Unsuported optimizer : {config.optimizer.type}") + # Scheduler + if config.training.slow_start: + + def learning_rate_function(epoch): + if epoch == 0: + return config.training.slow_start + elif epoch == 1: + return math.sqrt(config.training.slow_start) + else: + return 1 + + else: + + def learning_rate_function(epoch): + return 1 + + scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, 
metrics = { "LOSS": [], "MSE": [], "MAE": [], - "LPIPS": [], + "LPIPS_Vgg": [], + "LPIPS_Alex": [], "PSNR": [], "SSIM": [], "ReconstructionError": [], "n_iter": n_iter, - "algorithm": config.reconstruction.method, + "algorithm": algorithm_name, } # Backward hook that detects NaN in the gradient and prints the layer weights - def detect_nan(grad): - if torch.isnan(grad).any(): - print(grad) - for name, param in recon.named_parameters(): - print(name, param) - raise ValueError("Gradient is NaN") - return grad - - for param in recon.parameters(): - param.register_hook(detect_nan) + if not config.training.skip_NAN: + + def detect_nan(grad): + if torch.isnan(grad).any(): + print(grad, flush=True) + for name, param in recon.named_parameters(): + if param.requires_grad: + print(name, param) + raise ValueError("Gradient is NaN") + return grad + + for param in recon.parameters(): + if param.requires_grad: + param.register_hook(detect_nan) # Training loop for epoch in range(config.training.epoch): - print(f"Epoch {epoch}") + print(f"Epoch {epoch} with learning rate {scheduler.get_last_lr()}") mean_loss = 0.0 i = 1.0 pbar = tqdm(data_loader) @@ -255,12 +341,13 @@ def detect_nan(grad): y_pred = recon.batch_call(X.to(device)) # normalizing each output - y_pred_max = torch.amax(y_pred, dim=(-1, -2, -3), keepdim=True) + eps = 1e-12 + y_pred_max = torch.amax(y_pred, dim=(-1, -2, -3), keepdim=True) + eps y_pred = y_pred / y_pred_max # normalizing y y = y.to(device) - y_max = torch.amax(y, dim=(-1, -2, -3), keepdim=True) + y_max = torch.amax(y, dim=(-1, -2, -3), keepdim=True) + eps y = y / y_max if i % disp == 1 and config.display.plot: @@ -279,13 +366,24 @@ def detect_nan(grad): loss_v = Loss(y_pred, y) if config.lpips: - loss_v = loss_v + config.lpips * torch.mean(loss_lpips(y_pred, y)) + # value for LPIPS needs to be in range [-1, 1] + loss_v = loss_v + config.lpips * torch.mean(loss_lpips(2 * y_pred - 1, 2 * y - 1)) loss_v.backward() torch.nn.utils.clip_grad_norm_(recon.parameters(), 1.0) + + # if any gradient is NaN, skip training step + is_NAN = False + for param in recon.parameters(): + if param.grad is not None and torch.isnan(param.grad).any(): + is_NAN = True + break + if is_NAN: + print("NaN detected in gradient, skipping training step") + i += 1 + continue optimizer.step() mean_loss += (loss_v.item() - mean_loss) * (1 / i) - pbar.set_description(f"loss : {mean_loss}") i += 1 @@ -296,6 +394,9 @@ def detect_nan(grad): for key in current_metrics: metrics[key].append(current_metrics[key]) + # Update learning rate + scheduler.step() + print(f"Train time : {time.time() - start_time} s") # save dictionary metrics to file with json diff --git a/test/test_algos.py b/test/test_algos.py index d1be035d..b63b5a42 100644 --- a/test/test_algos.py +++ b/test/test_algos.py @@ -168,7 +168,16 @@ def test_trainable_recon(algorithm): for dtype, torch_type in [("float32", torch.float32), ("float64", torch.float64)]: psf = torch.rand(1, 32, 64, 3, dtype=torch_type) data = torch.rand(2, 1, 32, 64, 3, dtype=torch_type) - recon = UnrolledFISTA(psf, n_iter=_n_iter, dtype=dtype) + + def pre_process(x, noise): + return x + + def post_process(x, noise): + return x + + recon = algorithm( + psf, n_iter=_n_iter, dtype=dtype, pre_process=pre_process, post_process=post_process + ) assert ( next(recon.parameters(), None) is not None @@ -197,7 +206,15 @@ def test_trainable_batch(algorithm): data2 = torch.rand(1, 1, 34, 64, 3, dtype=torch_type) data2[0, 0, ...]
= data1[0, 0, ...] - recon = algorithm(psf, dtype=dtype, n_iter=_n_iter) + def pre_process(x, noise): + return x + + def post_process(x, noise): + return x + + recon = algorithm( + psf, dtype=dtype, n_iter=_n_iter, pre_process=pre_process, post_process=post_process + ) res1 = recon.batch_call(data1) res2 = recon.batch_call(data2) recon.set_data(data2[0])