Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
12cd68b
WIP add FastMRI dataset and demosaicing task
Melvin-klein May 19, 2025
6a8108f
WIP
Melvin-klein May 26, 2025
80f232b
WIP Add ifft2 as baseline
Melvin-klein May 26, 2025
217436c
WIP skip ifft2 solver on datasets that are not FastMRI
Melvin-klein May 26, 2025
5b7602f
WIP
Melvin-klein Jun 5, 2025
65beb1a
WIP
Melvin-klein Jun 5, 2025
4525802
WIP : Add new denoiser for DPIR
Melvin-klein Jun 11, 2025
92e8791
Removed unused files
Melvin-klein Jun 11, 2025
a0227e5
Add DPIR_2C to handle imaginary images
Melvin-klein Jun 12, 2025
b8fa8ee
Refactor code
Melvin-klein Jun 12, 2025
2a1a435
Update DiffPIR for imaginary images
Melvin-klein Jun 12, 2025
5c1ea37
Every solvers run on SimpleFastMIR
Melvin-klein Jun 12, 2025
7442ff1
WIP: Change SimpleFastMRISliceDataset to FastMRISliceDataset
Melvin-klein Jun 16, 2025
c508509
WIP
Jun 18, 2025
27a3f07
WIP
Jun 18, 2025
4f9e189
WIP
Jul 17, 2025
0737ca6
WIP
Jul 18, 2025
386be6a
WIP
Melvin-klein Jul 18, 2025
ecc81ae
WIP
Melvin-klein Jul 19, 2025
8029157
WIP
Melvin-klein Jul 22, 2025
25b1375
UNet working
Melvin-klein Jul 23, 2025
ac88628
UNet, DPIR, DiffPIR working
Melvin-klein Jul 23, 2025
7a55efd
WIP
Melvin-klein Jul 23, 2025
be94a63
WIP
Melvin-klein Jul 29, 2025
23baf2b
WIP
Melvin-klein Aug 4, 2025
4af546b
Added inpainting, fix bugs
Melvin-klein Aug 25, 2025
65c1f65
Add inference time per degraded image to metrics
Melvin-klein Aug 27, 2025
32526b3
Fix U-Net solver's scheduler
Melvin-klein Aug 27, 2025
8ca02b8
Fix comments
Melvin-klein Sep 22, 2025
e04c973
merge main
Etyl Oct 21, 2025
20d9d35
FIX: linting
Etyl Oct 21, 2025
c54d702
more linting
Etyl Oct 21, 2025
919ea7c
FIX bsd500 dataset
Etyl Oct 21, 2025
354846e
fix: fix bsd500_cbsd84
Etyl Oct 29, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 6 additions & 12 deletions benchmark_utils/custom_models.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,16 @@
from deepinv.models import UNet

class CustomUNet(UNet):
def __init__(self, in_channels, out_channels, scales=3, batch_norm=False, is_mri=False):
self.name = "CustomUNet"
class MRIUNet(UNet):
def __init__(self, in_channels, out_channels, scales=3, batch_norm=False):
self.name = "MRIUNet"
self.in_channels = in_channels
self.is_mri = is_mri

super().__init__(in_channels=in_channels, out_channels=out_channels, scales=scales, batch_norm=batch_norm)

def forward(self, x, sigma=None, **kwargs):
if self.is_mri:
# Reshape for MRI specific processing
x = x.reshape(1, self.in_channels, x.shape[3], x.shape[4])

# Reshape for MRI specific processing
x = x.reshape(1, self.in_channels, x.shape[3], x.shape[4])
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't it forcing the batch_size to be 1? why not x.shape[0]? Why do you need to reshape at all actually?


x = super().forward(x, sigma=sigma, **kwargs)

if self.is_mri:
# Reshape for when MRI
x = x.reshape(2, 4, x.shape[2], x.shape[3])

return x
2 changes: 1 addition & 1 deletion benchmark_utils/fastmri_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,6 @@ def __getitem__(self, idx):

# We add an imaginary part of zeros
x = torch.cat([x, torch.zeros_like(x)], dim=0)
y = y.reshape(8, y.shape[2], y.shape[3])
#y = y.reshape(8, y.shape[2], y.shape[3])

return x, y
7 changes: 5 additions & 2 deletions objective.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def evaluate_result(self, model, model_name, device):
model(y_i[None], self.physics) for y_i in y
])
else:
x_hat = model(y)
x_hat = model(y, self.physics)

if (self.dataset_name == 'FastMRI'):
transform = torchvision.transforms.Compose(
Expand Down Expand Up @@ -185,4 +185,7 @@ def get_objective(self):
# for `Solver.set_objective`. This defines the
# benchmark's API for passing the objective to the solver.
# It is customizable for each benchmark.
return dict(train_dataset=self.train_dataset, physics=self.physics, image_size=self.image_size)
return dict(train_dataset=self.train_dataset,
physics=self.physics,
image_size=self.image_size,
dataset_name=self.dataset_name,)
2 changes: 1 addition & 1 deletion solvers/diffpir.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class Solver(BaseSolver):

requirements = []

def set_objective(self, train_dataset, physics, image_size):
def set_objective(self, train_dataset, physics, image_size, dataset_name):
batch_size = 2
self.train_dataloader = DataLoader(
train_dataset, batch_size=batch_size, shuffle=False
Expand Down
37 changes: 20 additions & 17 deletions solvers/dpir.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class Solver(BaseSolver):

requirements = []

def set_objective(self, train_dataset, physics, image_size):
def set_objective(self, train_dataset, physics, image_size, dataset_name):
batch_size = 2
self.train_dataloader = DataLoader(
train_dataset, batch_size=batch_size, shuffle=False
Expand All @@ -33,6 +33,7 @@ def set_objective(self, train_dataset, physics, image_size):
)
self.physics = physics
self.image_size = image_size
self.dataset_name = dataset_name

def run(self, n_iter):
best_sigma = 0
Expand All @@ -50,25 +51,27 @@ def run(self, n_iter):
for sigma in np.linspace(0.01, 0.1, 10):
model = model_class(sigma=sigma, device=self.device)

psnr = []

for x, y in self.train_dataloader:
x, y = x.to(self.device), y.to(self.device)

x_hat = model(y)
if (self.dataset_name == 'FastMRI'):
transform = torchvision.transforms.Compose(
[
torchvision.transforms.CenterCrop(x.shape[-2:]),
dinv.metric.functional.complex_abs,
]
)

CustomPSNR.transform = transform
psnr.append(CustomPSNR()(x_hat, x))
else:
psnr.append(dinv.metric.PSNR()(x_hat, x))
x_hat = model(y, self.physics)

if (self.dataset_name == 'FastMRI'):
transform = torchvision.transforms.Compose(
[
torchvision.transforms.CenterCrop(x.shape[-2:]),
dinv.metric.functional.complex_abs,
]
)

CustomPSNR.transform = transform

psnr.append(CustomPSNR()(x_hat, x))
else:
psnr.append(dinv.metric.PSNR()(x_hat, x))

psnr = torch.mean(torch.cat(psnr)).item()

#results = dinv.test(
Expand Down
18 changes: 13 additions & 5 deletions solvers/u-net.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import deepinv as dinv
import torchvision
from benchmark_utils.metrics import CustomMSE, CustomPSNR
from benchmark_utils.custom_models import CustomUNet
from benchmark_utils.custom_models import MRIUNet


class Solver(BaseSolver):
Expand All @@ -22,7 +22,7 @@ class Solver(BaseSolver):

requirements = []

def set_objective(self, train_dataset, physics, image_size):
def set_objective(self, train_dataset, physics, image_size, dataset_name):
batch_size = 1
self.train_dataloader = DataLoader(
train_dataset, batch_size=batch_size, shuffle=False
Expand All @@ -32,15 +32,23 @@ def set_objective(self, train_dataset, physics, image_size):
)
self.physics = physics.to(self.device)
self.image_size = image_size
self.dataset_name = dataset_name

def run(self, n_iter):
epochs = 4

x, y = next(iter(self.train_dataloader))

model = dinv.models.UNet(
in_channels=y.shape[1], out_channels=x.shape[1], scales=3, batch_norm=False
).to(self.device)
if self.dataset_name == 'FastMRI':
model = MRIUNet(
in_channels=y.shape[1] * y.shape[2], out_channels=x.shape[1], scales=3,
batch_norm=False
).to(self.device)
else:
model = dinv.models.UNet(
in_channels=y.shape[1], out_channels=x.shape[1], scales=3,
batch_norm=False
).to(self.device)

verbose = True # print training information
wandb_vis = False # plot curves and images in Weight&Bias
Expand Down
Loading