Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
12cd68b
WIP add FastMRI dataset and demosaicing task
Melvin-klein May 19, 2025
6a8108f
WIP
Melvin-klein May 26, 2025
80f232b
WIP Add ifft2 as baseline
Melvin-klein May 26, 2025
217436c
WIP skip ifft2 solver on datasets that are not FastMRI
Melvin-klein May 26, 2025
5b7602f
WIP
Melvin-klein Jun 5, 2025
65beb1a
WIP
Melvin-klein Jun 5, 2025
4525802
WIP : Add new denoiser for DPIR
Melvin-klein Jun 11, 2025
92e8791
Removed unused files
Melvin-klein Jun 11, 2025
a0227e5
Add DPIR_2C to handle imaginary images
Melvin-klein Jun 12, 2025
b8fa8ee
Refactor code
Melvin-klein Jun 12, 2025
2a1a435
Update DiffPIR for imaginary images
Melvin-klein Jun 12, 2025
5c1ea37
Every solvers run on SimpleFastMIR
Melvin-klein Jun 12, 2025
7442ff1
WIP: Change SimpleFastMRISliceDataset to FastMRISliceDataset
Melvin-klein Jun 16, 2025
c508509
WIP
Jun 18, 2025
27a3f07
WIP
Jun 18, 2025
4f9e189
WIP
Jul 17, 2025
0737ca6
WIP
Jul 18, 2025
386be6a
WIP
Melvin-klein Jul 18, 2025
ecc81ae
WIP
Melvin-klein Jul 19, 2025
8029157
WIP
Melvin-klein Jul 22, 2025
25b1375
UNet working
Melvin-klein Jul 23, 2025
ac88628
UNet, DPIR, DiffPIR working
Melvin-klein Jul 23, 2025
7a55efd
WIP
Melvin-klein Jul 23, 2025
be94a63
WIP
Melvin-klein Jul 29, 2025
23baf2b
WIP
Melvin-klein Aug 4, 2025
4af546b
Added inpainting, fix bugs
Melvin-klein Aug 25, 2025
65c1f65
Add inference time per degraded image to metrics
Melvin-klein Aug 27, 2025
32526b3
Fix U-Net solver's scheduler
Melvin-klein Aug 27, 2025
8ca02b8
Fix comments
Melvin-klein Sep 22, 2025
e04c973
merge main
Etyl Oct 21, 2025
20d9d35
FIX: linting
Etyl Oct 21, 2025
c54d702
more linting
Etyl Oct 21, 2025
919ea7c
FIX bsd500 dataset
Etyl Oct 21, 2025
354846e
fix: fix bsd500_cbsd84
Etyl Oct 29, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,6 @@ benchopt.ini

.DS_Store
coverage.xml

tmp
data/
22 changes: 22 additions & 0 deletions benchmark_utils/custom_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from deepinv.models import UNet


class MRIUNet(UNet):
def __init__(self, in_channels, out_channels, scales=3, batch_norm=False):
self.name = "MRIUNet"
self.in_channels = in_channels

super().__init__(
in_channels=in_channels,
out_channels=out_channels,
scales=scales,
batch_norm=batch_norm
)

def forward(self, x, sigma=None, **kwargs):
# Reshape for MRI specific processing
x = x.reshape(1, self.in_channels, x.shape[3], x.shape[4])
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't it forcing the batch_size to be 1? why not x.shape[0]? Why do you need to reshape at all actually?


x = super().forward(x, sigma=sigma, **kwargs)

return x
26 changes: 26 additions & 0 deletions benchmark_utils/denoiser_2c.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import torch
from deepinv.models import DRUNet
from deepinv.models import Denoiser


class Denoiser_2c(Denoiser):
def __init__(self, device):
super(Denoiser_2c, self).__init__()
self.model_c1 = DRUNet(
in_channels=1, out_channels=1,
pretrained="download", device=device
)
self.model_c2 = DRUNet(
in_channels=1, out_channels=1,
pretrained="download", device=device
)

def forward(self, y, sigma):
y1, y2 = torch.split(y, 1, dim=1)

x_hat_1 = self.model_c1(y1, sigma=sigma)
x_hat_2 = self.model_c2(y2, sigma=sigma)

x_hat = torch.cat([x_hat_1, x_hat_2], dim=1)

return x_hat
49 changes: 49 additions & 0 deletions benchmark_utils/fastmri_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import torch
from torch.utils.data import Dataset
from deepinv.datasets import FastMRISliceDataset
import torch.nn.functional as F


class FastMRIDataset(Dataset):
def __init__(self, dataset: FastMRISliceDataset, mask, max_coils=32):
self.dataset = dataset
self.max_coils = max_coils
self.mask = mask

def __len__(self):
return len(self.dataset)

def __getitem__(self, idx):
x, y = self.dataset[idx]
x, y = x.to(device=self.mask.device), y.to(device=self.mask.device)

# Pad the width
target_width = 400
pad_total = target_width - y.shape[3]
pad_left = pad_total // 2
pad_right = pad_total - pad_left
y = F.pad(y, (pad_left, pad_right, 0, 0), mode='constant', value=0)

# Pad the height
target_height = 700
pad_total = target_height - y.shape[2]
pad_left = pad_total // 2
pad_right = pad_total - pad_left
y = F.pad(y, (0, 0, pad_left, pad_right), mode='constant', value=0)

# Transform the mask to match the kspace shape
mask = self.mask.repeat(y.shape[0], y.shape[1], 1, 1)

# Apply the mask to the k-space data
y = y * mask

# Add an imaginary part of zeros
x = torch.cat([x, torch.zeros_like(x)], dim=0)

# Pad the coil dimension if necessary
coil_dim = y.shape[1]
if coil_dim < self.max_coils:
pad_size = self.max_coils - coil_dim
y = F.pad(y, (0, 0, 0, 0, 0, pad_size))

return x, y
15 changes: 11 additions & 4 deletions benchmark_utils/hugging_face_torch_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,26 @@


class HuggingFaceTorchDataset(torch.utils.data.Dataset):
def __init__(self, hf_dataset, key, transform=None):
def __init__(self, hf_dataset, key, physics, device, transform=None):
self.hf_dataset = hf_dataset
self.transform = transform
self.key = key
self.device = device
self.physics = physics

def __len__(self):
return len(self.hf_dataset)

def __getitem__(self, idx):
sample = self.hf_dataset[idx]
image = sample[self.key] # Image PIL
x = sample[self.key] # Image PIL

if self.transform:
image = self.transform(image)
x = self.transform(x)

return image
x = x.to(self.device)

y = self.physics(x.unsqueeze(0))
y = y.squeeze(0)

return x, y
33 changes: 33 additions & 0 deletions benchmark_utils/metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import deepinv as dinv


class CustomMSE(dinv.metric.MSE):

transform = lambda x: x # noqa: E731

def forward(self, x_net=None, x=None, *args, **kwargs):
return super().forward(self.transform(x_net), x, *args, **kwargs)


class CustomPSNR(dinv.metric.PSNR):

transform = lambda x: x # noqa: E731

def forward(self, x_net=None, x=None, *args, **kwargs):
return super().forward(self.transform(x_net), x, *args, **kwargs)


class CustomSSIM(dinv.metric.SSIM):

transform = lambda x: x # noqa: E731

def forward(self, x_net=None, x=None, *args, **kwargs):
return super().forward(self.transform(x_net), x, *args, **kwargs)


class CustomLPIPS(dinv.metric.LPIPS):

transform = lambda x: x # noqa: E731

def forward(self, x_net=None, x=None, *args, **kwargs):
return super().forward(self.transform(x_net), x, *args, **kwargs)
3 changes: 3 additions & 0 deletions config.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
data_paths:
fastmri_train: /data/parietal/store3/data/fastMRI-multicoil/multicoil_train
fastmri_test: /data/parietal/store3/data/fastMRI-multicoil/multicoil_val
72 changes: 37 additions & 35 deletions datasets/bsd500_cbsd68.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,12 @@
from benchmark_utils.hugging_face_torch_dataset import (
HuggingFaceTorchDataset
)
from deepinv.physics import Denoising, GaussianNoise, Downsampling
from deepinv.physics import (
Denoising,
GaussianNoise,
Downsampling,
Demosaicing
)
from deepinv.physics.generator import MotionBlurGenerator


Expand All @@ -21,7 +26,9 @@ class Dataset(BaseDataset):
'task': ['denoising',
'gaussian-debluring',
'motion-debluring',
'SRx4'],
'SRx4',
'inpainting',
'demosaicing'],
'img_size': [256],
}

Expand All @@ -32,23 +39,24 @@ def get_data(self):
device = (
dinv.utils.get_freer_gpu()) if torch.cuda.is_available() else "cpu"

n_channels = 3
img_size = (n_channels, self.img_size, self.img_size)

if self.task == "denoising":
noise_level_img = 0.03
noise_level_img = 0.1
physics = Denoising(GaussianNoise(sigma=noise_level_img))
elif self.task == "gaussian-debluring":
filter_torch = dinv.physics.blur.gaussian_blur(sigma=(3, 3))
noise_level_img = 0.03
n_channels = 3

physics = dinv.physics.BlurFFT(
img_size=(n_channels, self.img_size, self.img_size),
img_size=img_size,
filter=filter_torch,
noise_model=dinv.physics.GaussianNoise(sigma=noise_level_img),
device=device
)
elif self.task == "motion-debluring":
psf_size = 31
n_channels = 3
motion_generator = MotionBlurGenerator(
(psf_size, psf_size),
device=device
Expand All @@ -57,18 +65,22 @@ def get_data(self):
filters = motion_generator.step(batch_size=1)

physics = dinv.physics.BlurFFT(
img_size=(n_channels, self.img_size, self.img_size),
img_size=img_size,
filter=filters["filter"],
device=device
)
elif self.task == "SRx4":
n_channels = 3
physics = Downsampling(img_size=(n_channels,
self.img_size,
self.img_size),
physics = Downsampling(img_size=img_size,
filter="bicubic",
factor=4,
device=device)
elif self.task == "inpainting":
physics = dinv.physics.Inpainting(img_size,
mask=0.7,
device=device)
elif self.task == "demosaicing":
physics = Demosaicing(img_size=img_size,
device=device)
else:
raise Exception("Unknown task")

Expand All @@ -78,41 +90,31 @@ def get_data(self):
])

path = get_data_path("BSD500")
train_dataset = dinv.datasets.BSDS500(
bsd500_dataset = dinv.datasets.BSDS500(
path, download=True, transform=transform
)
train_dataset = HuggingFaceTorchDataset(
bsd500_dataset,
key=...,
physics=physics,
device=device,
transform=transforms.Resize((self.img_size, self.img_size))
)

dataset_cbsd68 = load_dataset("deepinv/CBSD68")
test_dataset = HuggingFaceTorchDataset(
dataset_cbsd68["train"], key="png", transform=transform
)

dinv_dataset_path = dinv.datasets.generate_dataset(
train_dataset=train_dataset,
test_dataset=test_dataset,
dataset_cbsd68["train"],
key="png",
physics=physics,
save_dir=get_data_path("bsd500_cbsd68"),
dataset_filename=self.task,
device=device
)

train_dataset = dinv.datasets.HDF5Dataset(
path=dinv_dataset_path, train=True
device=device,
transform=transform
)
test_dataset = dinv.datasets.HDF5Dataset(
path=dinv_dataset_path, train=False
)

x, y = train_dataset[0]
dinv.utils.plot([x.unsqueeze(0), y.unsqueeze(0)])

x, y = test_dataset[0]
dinv.utils.plot([x.unsqueeze(0), y.unsqueeze(0)])

return dict(
train_dataset=train_dataset,
test_dataset=test_dataset,
physics=physics,
dataset_name="BSD68",
task_name=self.task
task_name=self.task,
image_size=img_size
)
Loading
Loading