Skip to content

Commit

Permalink
Add support for single channel and flipping.
Browse files Browse the repository at this point in the history
  • Loading branch information
ebezzam committed Jun 19, 2024
1 parent 26e860d commit 9afc864
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 9 deletions.
37 changes: 37 additions & 0 deletions configs/train_mirflickr_diffuser.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# python scripts/recon/train_learning_based.py -cn train_mirflickr_tape
defaults:
- train_unrolledADMM
- _self_

torch_device: 'cuda:0'
device_ids: [0, 1, 2, 3]
eval_disp_idx: [0, 1, 3, 4, 8]

# Dataset
files:
dataset: bezzam/DiffuserCam-Lensless-Mirflickr-Dataset-NORM
huggingface_dataset: True
huggingface_psf: psf.tiff
single_channel_psf: True
downsample: 2 # factor by which to downsample the PSF, note that for DiffuserCam the PSF has 4x the resolution
downsample_lensed: 2 # only used if lensed if measured
flipud: True
flip_lensed: True

training:
batch_size: 4
epoch: 25
eval_batch_size: 4

reconstruction:
method: unrolled_admm
unrolled_admm:
n_iter: 5
pre_process:
network : UnetRes # UnetRes or DruNet or null
depth : 4 # depth of each up/downsampling layer. Ignore if network is DruNet
nc: [32,64,116,128]
post_process:
network : UnetRes # UnetRes or DruNet or null
depth : 4 # depth of each up/downsampling layer. Ignore if network is DruNet
nc: [32,64,116,128]
2 changes: 2 additions & 0 deletions configs/train_unrolledADMM.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ files:
vertical_shift: null
horizontal_shift: null
rotate: False
flipud: False
flip_lensed: False
save_psf: False
crop: null
# vertical: null
Expand Down
10 changes: 5 additions & 5 deletions lensless/recon/multi_wiener.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,14 @@ def forward(self, x):
class Down(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.avgpool_conv = nn.Sequential(
self.pool_conv = nn.Sequential(
# nn.AvgPool2d(2),
nn.MaxPool2d(2),
nn.MaxPool2d(2), # original paper says max-pooling
DoubleConv(in_channels, out_channels),
)

def forward(self, x):
return self.avgpool_conv(x)
return self.pool_conv(x)


class Up(nn.Module):
Expand Down Expand Up @@ -120,9 +120,9 @@ def __init__(
"""
assert in_channels == 1 or in_channels == 3, "in_channels must be 1 or 3"
assert out_channels == 1 or out_channels == 3, "out_channels must be 1 or 3"
assert in_channels >= out_channels
if nc is None:
nc = [64, 128, 256, 512, 512]
# assert nc[-1] == nc[-2], "Last two channels must be the same"

super(MultiWiener, self).__init__()
self.in_channels = in_channels
Expand Down Expand Up @@ -172,7 +172,7 @@ def __init__(
self._psf, (self.left, self.right, self.top, self.bottom), mode="constant", value=0
)
self._n_iter = 1
self._convolver = RealFFTConvolve2D(psf, pad=True)
self._convolver = RealFFTConvolve2D(psf, pad=True, rgb=True if out_channels == 3 else False)

self.set_pre_process(pre_process)
self.skip_pre = skip_pre
Expand Down
9 changes: 6 additions & 3 deletions lensless/recon/rfft_convolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@


class RealFFTConvolve2D:
def __init__(self, psf, dtype=None, pad=True, norm="ortho", **kwargs):
def __init__(self, psf, dtype=None, pad=True, norm="ortho", rgb=None, **kwargs):
"""
Linear operator that performs convolution in Fourier domain, and assumes
real-valued signals.
Expand Down Expand Up @@ -56,7 +56,10 @@ def __init__(self, psf, dtype=None, pad=True, norm="ortho", **kwargs):
len(psf.shape) >= 4
), "Expected 4D PSF of shape ([batch], depth, width, height, channels)"
self._use_3d = psf.shape[-4] != 1
self._is_rgb = psf.shape[-1] == 3
if rgb is None:
self._is_rgb = psf.shape[-1] == 3
else:
self._is_rgb = rgb
assert self._is_rgb or psf.shape[-1] == 1

# save normalization
Expand All @@ -80,7 +83,7 @@ def __init__(self, psf, dtype=None, pad=True, norm="ortho", **kwargs):
self._padded_shape = 2 * self._psf_shape[-3:-1] - 1
self._padded_shape = np.array([next_fast_len(i) for i in self._padded_shape])
self._padded_shape = list(
np.r_[self._psf_shape[-4], self._padded_shape, self._psf_shape[-1]]
np.r_[self._psf_shape[-4], self._padded_shape, 3 if self._is_rgb else 1]
)
self._start_idx = (self._padded_shape[-3:-1] - self._psf_shape[-3:-1]) // 2
self._end_idx = self._start_idx + self._psf_shape[-3:-1]
Expand Down
14 changes: 13 additions & 1 deletion scripts/recon/train_learning_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,8 @@ def train_learned(config):
split=split_train,
display_res=config.files.image_res,
rotate=config.files.rotate,
flipud=config.files.flipud,
flip_lensed=config.files.flip_lensed,
downsample=config.files.downsample,
downsample_lensed=config.files.downsample_lensed,
alignment=config.alignment,
Expand All @@ -229,6 +231,8 @@ def train_learned(config):
split=split_test,
display_res=config.files.image_res,
rotate=config.files.rotate,
flipud=config.files.flipud,
flip_lensed=config.files.flip_lensed,
downsample=config.files.downsample,
downsample_lensed=config.files.downsample_lensed,
alignment=config.alignment,
Expand Down Expand Up @@ -478,11 +482,19 @@ def train_learned(config):
else False,
)
elif config.reconstruction.method == "multi_wiener":

if config.files.single_channel_psf:
psf = psf[..., 0].unsqueeze(-1)
psf_channels = 1
else:
psf_channels = 3
print(psf.shape)

recon = MultiWiener(
in_channels=3,
out_channels=3,
psf=psf,
psf_channels=3,
psf_channels=psf_channels,
nc=config.reconstruction.multi_wiener.nc,
pre_process=pre_process if pre_proc_delay is None else None,
)
Expand Down

0 comments on commit 9afc864

Please sign in to comment.