Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move drunet utilities to recon module. #71

Merged
merged 1 commit into from
Jul 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions lensless/recon/trainable_recon.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,10 @@ def _prepare_process_block(self, process):
"""
if isinstance(process, torch.nn.Module):
# If the post_process is a torch module, we assume it is a DruNet like network.
from lensless.utils.image import process_with_DruNet
from lensless.recon.utils import get_drunet_function

process_model = process
process_function = process_with_DruNet(process_model, self._psf.device, mode="train")
process_function = get_drunet_function(process_model, self._psf.device, mode="train")
elif process is not None:
# Otherwise, we assume it is a function.
assert callable(process), "pre_process must be a callable function"
Expand Down
131 changes: 131 additions & 0 deletions lensless/recon/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
import torch
from lensless.recon.drunet.network_unet import UNetRes


def load_drunet(model_path, n_channels=3, requires_grad=False):
"""
Load a pre-trained Drunet model.

Parameters
----------
model_path : str
Path to pre-trained model.
n_channels : int
Number of channels in input image.
requires_grad : bool
Whether to require gradients for model parameters.

Returns
-------
model : :py:class:`~torch.nn.Module`
Loaded model.
"""

model = UNetRes(
in_nc=n_channels + 1,
out_nc=n_channels,
nc=[64, 128, 256, 512],
nb=4,
act_mode="R",
downsample_mode="strideconv",
upsample_mode="convtranspose",
)
model.load_state_dict(torch.load(model_path), strict=True)
model.eval()
for _, v in model.named_parameters():
v.requires_grad = requires_grad

return model


def apply_denoiser(model, image, noise_level=10, device="cpu", mode="inference"):
"""
Apply a pre-trained denoising model with input in the format Channel, Height, Width.
An additionnal channel is added for the noise level as done in Drunet.

Parameters
----------
model : :py:class:`~torch.nn.Module`
Drunet compatible model. Its input must consist of 4 channels (RGB + noise level) and output an RGB image both in CHW format.
image : :py:class:`~torch.Tensor`
Input image.
noise_level : float or :py:class:`~torch.Tensor`
Noise level in the image.
device : str
Device to use for computation. Can be "cpu" or "cuda".
mode : str
Mode to use for model. Can be "inference" or "train".

Returns
-------
image : :py:class:`~torch.Tensor`
Reconstructed image.
"""
# convert from NDHWC to NCHW
depth = image.shape[-4]
image = image.movedim(-1, -3)
image = image.reshape(-1, *image.shape[-3:])
# pad image H and W to next multiple of 8
top = (8 - image.shape[-2] % 8) // 2
bottom = (8 - image.shape[-2] % 8) - top
left = (8 - image.shape[-1] % 8) // 2
right = (8 - image.shape[-1] % 8) - left
image = torch.nn.functional.pad(image, (left, right, top, bottom), mode="constant", value=0)
# add noise level as extra channel
image = image.to(device)
if isinstance(noise_level, torch.Tensor):
noise_level = noise_level / 255.0
else:
noise_level = torch.tensor([noise_level / 255.0]).to(device)
image = torch.cat(
(
image,
noise_level.repeat(image.shape[0], 1, image.shape[2], image.shape[3]),
),
dim=1,
)

# apply model
if mode == "inference":
with torch.no_grad():
image = model(image)
elif mode == "train":
image = model(image)
else:
raise ValueError("mode must be 'inference' or 'train'")

# remove padding
image = image[:, :, top:-bottom, left:-right]
# convert back to NDHWC
image = image.movedim(-3, -1)
image = image.reshape(-1, depth, *image.shape[-3:])
return image


def get_drunet_function(model, device="cpu", mode="inference"):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think mode = "train" would be a safer default. I expect this function to be mostly use during training, and even if it is mistakenly used during inference, it should only make things slightly slower (while the other ways would cause multiple bugs)

"""
Return a porcessing function that applies the DruNet model to an image.

Parameters
----------
model : torch.nn.Module
DruNet like denoiser model
device : str
Device to use for computation. Can be "cpu" or "cuda".
mode : str
Mode to use for model. Can be "inference" or "train".
"""

def process(image, noise_level):
x_max = torch.amax(image, dim=(-2, -3), keepdim=True) + 1e-6
image = apply_denoiser(
model,
image,
noise_level=noise_level,
device=device,
mode=mode,
)
image = torch.clip(image, min=0.0) * x_max
return image

return process
130 changes: 0 additions & 130 deletions lensless/utils/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,133 +276,3 @@ def autocorr2d(vals, pad_mode="reflect"):

# remove padding
return autocorr[shape[0] // 2 : -shape[0] // 2, shape[1] // 2 : -shape[1] // 2]


def load_drunet(model_path, n_channels=3, requires_grad=False):
"""
Load a pre-trained Drunet model.

Parameters
----------
model_path : str
Path to pre-trained model.
n_channels : int
Number of channels in input image.
requires_grad : bool
Whether to require gradients for model parameters.

Returns
-------
model : :py:class:`~torch.nn.Module`
Loaded model.
"""
from lensless.recon.drunet.network_unet import UNetRes

model = UNetRes(
in_nc=n_channels + 1,
out_nc=n_channels,
nc=[64, 128, 256, 512],
nb=4,
act_mode="R",
downsample_mode="strideconv",
upsample_mode="convtranspose",
)
model.load_state_dict(torch.load(model_path), strict=True)
model.eval()
for k, v in model.named_parameters():
v.requires_grad = requires_grad

return model


def apply_denoiser(model, image, noise_level=10, device="cpu", mode="inference"):
"""
Apply a pre-trained denoising model with input in the format Channel, Height, Width.
An additionnal channel is added for the noise level as done in Drunet.

Parameters
----------
model : :py:class:`~torch.nn.Module`
Drunet compatible model. Its input must concist of 4 channels ( RGB + noise level) and outbut an RGB image both in CHW format.
image : :py:class:`~torch.Tensor`
Input image.
noise_level : float or :py:class:`~torch.Tensor`
Noise level in the image.
device : str
Device to use for computation. Can be "cpu" or "cuda".
mode : str
Mode to use for model. Can be "inference" or "train".

Returns
-------
image : :py:class:`~torch.Tensor`
Reconstructed image.
"""
# convert from NDHWC to NCHW
depth = image.shape[-4]
image = image.movedim(-1, -3)
image = image.reshape(-1, *image.shape[-3:])
# pad image H and W to next multiple of 8
top = (8 - image.shape[-2] % 8) // 2
bottom = (8 - image.shape[-2] % 8) - top
left = (8 - image.shape[-1] % 8) // 2
right = (8 - image.shape[-1] % 8) - left
image = torch.nn.functional.pad(image, (left, right, top, bottom), mode="constant", value=0)
# add noise level as extra channel
image = image.to(device)
if isinstance(noise_level, torch.Tensor):
noise_level = noise_level / 255.0
else:
noise_level = torch.tensor([noise_level / 255.0]).to(device)
image = torch.cat(
(
image,
noise_level.repeat(image.shape[0], 1, image.shape[2], image.shape[3]),
),
dim=1,
)

# apply model
if mode == "inference":
with torch.no_grad():
image = model(image)
elif mode == "train":
image = model(image)
else:
raise ValueError("mode must be 'inference' or 'train'")

# remove padding
image = image[:, :, top:-bottom, left:-right]
# convert back to NDHWC
image = image.movedim(-3, -1)
image = image.reshape(-1, depth, *image.shape[-3:])
return image


def process_with_DruNet(model, device="cpu", mode="inference"):
"""
Return a porcessing function that applies the DruNet model to an image.

Parameters
----------
model : torch.nn.Module
DruNet like denoiser model
device : str
Device to use for computation. Can be "cpu" or "cuda".
mode : str
Mode to use for model. Can be "inference" or "train".
"""

def process(image, noise_level):
x_max = torch.amax(image, dim=(-2, -3), keepdim=True) + 1e-6
image = apply_denoiser(
model,
image,
noise_level=noise_level,
device=device,
mode="train",
)
image = torch.clip(image, min=0.0) * x_max
return image

return process
2 changes: 1 addition & 1 deletion scripts/recon/train_unrolled.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def simulate_dataset(config, psf):

def create_process_network(network, depth, device="cpu"):
if network == "DruNet":
from lensless.utils.image import load_drunet
from lensless.recon.utils import load_drunet

process = load_drunet(
os.path.join(get_original_cwd(), "data/drunet_color.pth"), requires_grad=True
Expand Down