Skip to content

Commit

Permalink
Fixes for when no background is present.
Browse files Browse the repository at this point in the history
  • Loading branch information
ebezzam committed Oct 2, 2024
1 parent 3ea395f commit 5fc0a05
Showing 1 changed file with 12 additions and 6 deletions.
18 changes: 12 additions & 6 deletions lensless/utils/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1565,20 +1565,26 @@ def _get_images_pair(self, idx):

if len(background_np.shape) == 2:
warnings.warn(f"Converting background[{idx}] to RGB")
background_np = np.stack([background_np] * 3, axis=2) if not None else None
background_np = (
np.stack([background_np] * 3, axis=2) if background_np is not None else None
)
elif len(background_np.shape) == 3:
pass

# convert to float
if lensless_np.dtype == np.uint8:
lensless_np = lensless_np.astype(np.float32) / 255
lensed_np = lensed_np.astype(np.float32) / 255
background_np = background_np.astype(np.float32) / 255 if not None else None
background_np = (
background_np.astype(np.float32) / 255 if background_np is not None else None
)
else:
# 16 bit
lensless_np = lensless_np.astype(np.float32) / 65535
lensed_np = lensed_np.astype(np.float32) / 65535
background_np = background_np.astype(np.float32) / 65535 if not None else None
background_np = (
background_np.astype(np.float32) / 65535 if background_np is not None else None
)

# downsample if necessary
if self.downsample_lensless != 1.0:
Expand All @@ -1591,13 +1597,13 @@ def _get_images_pair(self, idx):
factor=1 / self.downsample_lensless,
interpolation=cv2.INTER_NEAREST,
)
if not None
if background_np is not None
else None
)

lensless = lensless_np
lensed = lensed_np
background = background_np if not None else None
background = background_np if background_np is not None else None

if self.simulator is not None:
# convert to torch
Expand Down Expand Up @@ -1640,7 +1646,7 @@ def __getitem__(self, idx):
# to torch
lensless = torch.from_numpy(lensless)
lensed = torch.from_numpy(lensed)
background = torch.from_numpy(background) if not None else None
background = torch.from_numpy(background) if background is not None else None
# If [H, W, C] -> [D, H, W, C]
if len(lensless.shape) == 3:
lensless = lensless.unsqueeze(0)
Expand Down

0 comments on commit 5fc0a05

Please sign in to comment.