From d44ab57a44f65d548ebb48c51028fee32714fe60 Mon Sep 17 00:00:00 2001 From: Eric Bezzam Date: Thu, 2 May 2024 17:15:42 +0200 Subject: [PATCH 1/2] Option to downsample data before HF upload. (#128) * Add option for BGR conversion. * Add downsample option before Hugging Face upload. --- configs/upload_dataset_huggingface.yaml | 1 + lensless/utils/io.py | 8 ++++- scripts/data/upload_dataset_huggingface.py | 35 ++++++++++++++++++---- 3 files changed, 37 insertions(+), 7 deletions(-) diff --git a/configs/upload_dataset_huggingface.yaml b/configs/upload_dataset_huggingface.yaml index 31ec9192..61c73b55 100644 --- a/configs/upload_dataset_huggingface.yaml +++ b/configs/upload_dataset_huggingface.yaml @@ -15,6 +15,7 @@ lensless: dir: null ext: null # for example: .png, .jpg eight_norm: False # save as 8-bit normalized image + downsample: null lensed: dir: null diff --git a/lensless/utils/io.py b/lensless/utils/io.py index a51feff5..47fd94f4 100644 --- a/lensless/utils/io.py +++ b/lensless/utils/io.py @@ -37,6 +37,7 @@ def load_image( shape=None, dtype=None, normalize=True, + bgr_input=True, ): """ Load image as numpy array. @@ -151,7 +152,7 @@ def load_image( ) else: - if len(img.shape) == 3: + if len(img.shape) == 3 and bgr_input: img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) original_dtype = img.dtype @@ -223,6 +224,7 @@ def load_psf( single_psf=False, shape=None, use_3d=False, + bgr_input=True, ): """ Load and process PSF for analysis or for reconstruction. @@ -296,6 +298,7 @@ def load_psf( blue_gain=blue_gain, red_gain=red_gain, nbits_out=nbits_out, + bgr_input=bgr_input, ) original_dtype = psf.dtype @@ -391,6 +394,7 @@ def load_data( torch=False, torch_device="cpu", normalize=False, + bgr_input=True, ): """ Load data for image reconstruction. 
@@ -471,6 +475,7 @@ def load_data( single_psf=single_psf, shape=shape, use_3d=use_3d, + bgr_input=bgr_input, ) # load and process raw measurement @@ -485,6 +490,7 @@ def load_data( return_float=return_float, shape=shape, normalize=normalize, + bgr_input=bgr_input, ) if data.shape != psf.shape: diff --git a/scripts/data/upload_dataset_huggingface.py b/scripts/data/upload_dataset_huggingface.py index 8edf598d..2b212d97 100644 --- a/scripts/data/upload_dataset_huggingface.py +++ b/scripts/data/upload_dataset_huggingface.py @@ -24,6 +24,8 @@ from lensless.utils.dataset import natural_sort from tqdm import tqdm from lensless.utils.io import save_image +import cv2 +from joblib import Parallel, delayed @hydra.main( @@ -82,24 +84,44 @@ def upload_dataset(config): ] lensed_files = [os.path.join(config.lensed.dir, f + config.lensed.ext) for f in common_files] + if config.lensless.downsample is not None: + + tmp_dir = config.lensless.dir + "_tmp" + os.makedirs(tmp_dir, exist_ok=True) + + def downsample(f, output_dir): + img = cv2.imread(f, cv2.IMREAD_UNCHANGED) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = cv2.resize( + img, + (0, 0), + fx=1 / config.lensless.downsample, + fy=1 / config.lensless.downsample, + interpolation=cv2.INTER_LINEAR, + ) + new_fp = os.path.join(output_dir, os.path.basename(f)) + new_fp = new_fp.split(".")[0] + config.lensless.ext + save_image(img, new_fp, normalize=False) + + Parallel(n_jobs=n_jobs)(delayed(downsample)(f, tmp_dir) for f in tqdm(lensless_files)) + lensless_files = glob.glob(os.path.join(tmp_dir, f"*{config.lensless.ext[1:]}")) + # convert to normalized 8 bit if config.lensless.eight_norm: - import cv2 - from joblib import Parallel, delayed - tmp_dir = config.lensless.dir + "_tmp" os.makedirs(tmp_dir, exist_ok=True) # -- parallelize with joblib def save_8bit(f, output_dir, normalize=True): img = cv2.imread(f, cv2.IMREAD_UNCHANGED) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) new_fp = os.path.join(output_dir, os.path.basename(f)) - 
new_fp = new_fp.split(".")[0] + ".png" + new_fp = new_fp.split(".")[0] + config.lensless.ext save_image(img, new_fp, normalize=normalize) Parallel(n_jobs=n_jobs)(delayed(save_8bit)(f, tmp_dir) for f in tqdm(lensless_files)) - lensless_files = glob.glob(os.path.join(tmp_dir, "*png")) + lensless_files = glob.glob(os.path.join(tmp_dir, f"*{config.lensless.ext[1:]}")) # check for attribute df_attr = None @@ -222,6 +244,7 @@ def create_dataset(lensless_files, lensed_files, df_attr=None): upload_file( path_or_fileobj=lensless_files[0], + # path_in_repo=f"lensless_example{config.lensless.ext}" if not config.lensless.eight_norm else f"lensless_example.png", path_in_repo=f"lensless_example{config.lensless.ext}", repo_id=repo_id, repo_type="dataset", @@ -248,7 +271,7 @@ def create_dataset(lensless_files, lensed_files, df_attr=None): print(f"Total time: {(time.time() - start_time) / 60} minutes") # delete PNG files - if config.lensless.eight_norm: + if config.lensless.eight_norm or config.lensless.downsample: os.system(f"rm -rf {tmp_dir}") From e1886451cb92ecdf7c47c16537e146e121f46d36 Mon Sep 17 00:00:00 2001 From: Eric Bezzam Date: Tue, 14 May 2024 12:12:34 +0200 Subject: [PATCH 2/2] Improve alignment API, phase mask design, expose waveprop for PSF simulation (#130) * Add option for BGR conversion. * Add downsample option before Hugging Face upload. * Remove prints. * Add option to use waveprop in simulating PSF, improve ROI extraction API. * Improve load_data interface. * Add one-shot method for ADMM and GD. * Refactor is_torch to use_torch. * Add option to center mask. * Height map as result for phase mask. * Make normalize explicit. * Fix top_left naming. * Update CHANGELOG. * Fix mask tests. 
--- CHANGELOG.rst | 2 + configs/benchmark.yaml | 5 +- configs/train_digicam_multimask.yaml | 4 +- configs/train_digicam_singlemask.yaml | 4 +- configs/train_unrolledADMM.yaml | 5 +- lensless/eval/benchmark.py | 19 ++- lensless/hardware/mask.py | 137 ++++++++++++------- lensless/hardware/slm.py | 6 +- lensless/recon/admm.py | 26 +++- lensless/recon/gd.py | 25 +++- lensless/recon/utils.py | 17 +-- lensless/utils/dataset.py | 67 ++++++--- lensless/utils/io.py | 10 +- profile/admm.py | 4 +- profile/gradient_descent.py | 4 +- scripts/eval/benchmark_recon.py | 3 +- scripts/measure/collect_dataset_on_device.py | 2 +- scripts/measure/remote_capture.py | 2 +- scripts/recon/admm.py | 2 +- scripts/recon/dataset.py | 4 +- scripts/recon/digicam_mirflickr.py | 4 +- scripts/recon/gradient_descent.py | 2 +- scripts/recon/train_learning_based.py | 20 +-- test/test_algos.py | 10 +- test/test_io.py | 2 +- test/test_masks.py | 11 +- 26 files changed, 262 insertions(+), 135 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 7dd4d036..579d73cf 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -36,6 +36,8 @@ Changed - For trainable masks, set trainable parameters inside the child class. - ``distance_sensor`` optional for ``lensless.hardware.mask.Mask``, e.g. don't need for fabrication. - More intuitive interface for MURA for coded aperture (``lensless.hardware.mask.CodedAperture``), i.e. directly pass prime number. 
+- ``is_torch`` to ``use_torch`` in ``lensless.hardware.mask.Mask`` +- ``self.height_map`` as characterization of phase masks (instead of phase pattern which can change for each wavelength) Bugfix diff --git a/configs/benchmark.yaml b/configs/benchmark.yaml index edf5eac6..24e0366e 100644 --- a/configs/benchmark.yaml +++ b/configs/benchmark.yaml @@ -7,7 +7,7 @@ hydra: chdir: True -dataset: DiffuserCam # DiffuserCam, DigiCamCelebA, DigiCamHF +dataset: DiffuserCam # DiffuserCam, DigiCamCelebA, HFDataset seed: 0 huggingface: @@ -15,7 +15,7 @@ huggingface: image_res: [900, 1200] # used during measurement rotate: True # if measurement is upside-down alignment: - topright: [80, 100] # height, width + top_left: [80, 100] # height, width height: 200 downsample: 1 @@ -88,6 +88,7 @@ simulation: scene2mask: 0.25 # [m] mask2sensor: 0.002 # [m] # see waveprop.devices + use_waveprop: False # for PSF simulation sensor: "rpi_hq" snr_db: 10 # simulate different sensor resolution diff --git a/configs/train_digicam_multimask.yaml b/configs/train_digicam_multimask.yaml index e05dda06..654c2468 100644 --- a/configs/train_digicam_multimask.yaml +++ b/configs/train_digicam_multimask.yaml @@ -25,13 +25,13 @@ files: display_res: [900, 1200] # used during measurement rotate: True # if measurement is upside-down alignment: - topright: [80, 100] # height, width + top_left: [80, 100] # height, width height: 200 # TODO: these parameters should be in the dataset? 
alignment: # when there is no downsampling - topright: [80, 100] # height, width + top_left: [80, 100] # height, width height: 200 training: diff --git a/configs/train_digicam_singlemask.yaml b/configs/train_digicam_singlemask.yaml index 932d68a8..c919c195 100644 --- a/configs/train_digicam_singlemask.yaml +++ b/configs/train_digicam_singlemask.yaml @@ -24,13 +24,13 @@ files: display_res: [900, 1200] # used during measurement rotate: True # if measurement is upside-down alignment: - topright: [80, 100] # height, width + top_left: [80, 100] # height, width height: 200 # TODO: these parameters should be in the dataset? alignment: # when there is no downsampling - topright: [80, 100] # height, width + top_left: [80, 100] # height, width height: 200 training: diff --git a/configs/train_unrolledADMM.yaml b/configs/train_unrolledADMM.yaml index 47fba326..72d9e621 100644 --- a/configs/train_unrolledADMM.yaml +++ b/configs/train_unrolledADMM.yaml @@ -28,7 +28,7 @@ files: # -- processing parameters downsample: 2 # factor by which to downsample the PSF, note that for DiffuserCam the PSF has 4x the resolution - downsample_lensed: 2 + downsample_lensed: 2 # only used if lensed is measured input_snr: null # adding shot noise at input (for measured dataset) at this SNR in dB vertical_shift: null horizontal_shift: null @@ -41,7 +41,7 @@ files: extra_eval: null # dict of extra datasets to evaluate on alignment: null -# topright: null # height, width +# top_left: null # height, width # height: null torch: True @@ -129,6 +129,7 @@ simulation: scene2mask: 10e-2 # scene2mask: 40e-2 mask2sensor: 9e-3 # mask2sensor: 4e-3 # see waveprop.devices + use_waveprop: False # for PSF simulation sensor: "rpi_hq" snr_db: 10 # simulate different sensor resolution diff --git a/lensless/eval/benchmark.py b/lensless/eval/benchmark.py index 8df388c1..75f54c3b 100644 --- a/lensless/eval/benchmark.py +++ b/lensless/eval/benchmark.py @@ -84,10 +84,6 @@ def benchmark( if not os.path.exists(output_dir): 
os.mkdir(output_dir) - alignment = None - if hasattr(dataset, "alignment"): - alignment = dataset.alignment - if metrics is None: metrics = { "MSE": MSELoss().to(device), @@ -156,13 +152,14 @@ def benchmark( prediction = prediction.reshape(-1, *prediction.shape[-3:]).movedim(-1, -3) lensed = lensed.reshape(-1, *lensed.shape[-3:]).movedim(-1, -3) - if alignment is not None: - prediction = prediction[ - ..., - alignment["topright"][0] : alignment["topright"][0] + alignment["height"], - alignment["topright"][1] : alignment["topright"][1] + alignment["width"], - ] - # expected that lensed is also reshaped accordingly + if hasattr(dataset, "alignment"): + if dataset.alignment is not None: + prediction = dataset.extract_roi(prediction, axis=(-2, -1)) + else: + prediction, lensed = dataset.extract_roi( + prediction, axis=(-2, -1), lensed=lensed + ) + assert np.all(lensed.shape == prediction.shape) elif crop is not None: prediction = prediction[ ..., diff --git a/lensless/hardware/mask.py b/lensless/hardware/mask.py index a14d8c61..0e8c9456 100644 --- a/lensless/hardware/mask.py +++ b/lensless/hardware/mask.py @@ -54,8 +54,9 @@ def __init__( size=None, feature_size=None, psf_wavelength=[460e-9, 550e-9, 640e-9], - is_torch=False, + use_torch=False, torch_device="cpu", + centered=True, **kwargs, ): """ @@ -73,6 +74,12 @@ def __init__( Size of the feature (m). Only one of ``size`` or ``feature_size`` needs to be specified. psf_wavelength: list, optional List of wavelengths to simulate PSF (m). Default is [460e-9, 550e-9, 640e-9] nm (blue, green, red). + use_torch : bool, optional + If True, the mask is created as a torch tensor. Default is False. + torch_device : str, optional + Device to use for torch tensor. Default is 'cpu'. + centered: bool, optional + If True, the mask is centered. Default is True. 
""" resolution = np.array(resolution) @@ -100,21 +107,22 @@ def __init__( self.resolution = resolution self.resolution = (int(self.resolution[0]), int(self.resolution[1])) self.size = size + self.centered = centered if feature_size is None: self.feature_size = self.size / self.resolution else: self.feature_size = feature_size self.distance_sensor = distance_sensor - if is_torch: + if use_torch: assert torch_available, "PyTorch is not available" - self.is_torch = is_torch + self.use_torch = use_torch self.torch_device = torch_device # create mask - self.phase_pattern = None # for phase masks + self.height_map = None # for phase masks self.create_mask() # creates self.mask - self.shape = self.mask.shape + self.shape = self.height_map.shape if self.height_map is not None else self.mask.shape # PSF assert hasattr(psf_wavelength, "__len__"), "psf_wavelength should be a list" @@ -161,7 +169,31 @@ def create_mask(self): """ pass - def compute_psf(self, distance_sensor=None): + def height_map_to_field(self, wavelength, return_phase=False): + """ + Compute phase from height map. + + Parameters + ---------- + height_map: :py:class:`~numpy.ndarray` + Height map. + wavelength: float + Wavelength of the light (m). + return_phase: bool, optional + If True, return the phase instead of the field. Default is False. + """ + assert self.height_map is not None, "Height map should be computed first." + assert self.refractive_index is not None, "Refractive index should be specified." + + phase_pattern = self.height_map * (self.refractive_index - 1) * 2 * np.pi / wavelength + if return_phase: + return phase_pattern + else: + return ( + np.exp(1j * phase_pattern) if not self.use_torch else torch.exp(1j * phase_pattern) + ) + + def compute_psf(self, distance_sensor=None, wavelength=None, intensity=True): """ Compute the intensity PSF with bandlimited angular spectrum (BLAS) for each wavelength. Common to all types of masks. 
@@ -170,6 +202,8 @@ def compute_psf(self, distance_sensor=None): ---------- distance_sensor: float, optional Distance between mask and sensor (m). Default is the distance specified at initialization. + wavelength: float or array_like, optional + Wavelength(s) to compute the PSF (m). Default is the list of wavelengths specified at initialization. """ if distance_sensor is not None: self.distance_sensor = distance_sensor @@ -177,30 +211,38 @@ def compute_psf(self, distance_sensor=None): self.distance_sensor is not None ), "Distance between mask and sensor should be specified." - if self.is_torch: + if wavelength is None: + wavelength = self.psf_wavelength + else: + if not hasattr(wavelength, "__len__"): + wavelength = [wavelength] + + if self.use_torch: psf = torch.zeros( tuple(self.resolution) + (len(self.psf_wavelength),), dtype=torch.complex64, device=self.torch_device, ) else: - psf = np.zeros(tuple(self.resolution) + (len(self.psf_wavelength),), dtype=np.complex64) - for i, wv in enumerate(self.psf_wavelength): + psf = np.zeros(tuple(self.resolution) + (len(wavelength),), dtype=np.complex64) + for i, wv in enumerate(wavelength): psf[:, :, i] = angular_spectrum( - u_in=self.mask, + u_in=self.mask if self.height_map is None else self.height_map_to_field(wv), wv=wv, d1=self.feature_size, dz=self.distance_sensor, - dtype=np.float32 if not self.is_torch else torch.float32, + dtype=np.float32 if not self.use_torch else torch.float32, bandlimit=True, - device=self.torch_device if self.is_torch else None, + device=self.torch_device if self.use_torch else None, )[0] # intensity PSF - if self.is_torch: - self.psf = torch.abs(psf) ** 2 + if intensity: + self.psf = np.abs(psf) ** 2 if not self.use_torch else torch.abs(psf) ** 2 else: - self.psf = np.abs(psf) ** 2 + self.psf = psf + + return self.psf def plot(self, ax=None, **kwargs): """ @@ -217,18 +259,26 @@ def plot(self, ax=None, **kwargs): if ax is None: _, ax = plt.subplots() - if self.phase_pattern is not None: - mask 
= self.phase_pattern - title = "Phase pattern" + if self.height_map is not None: + mask = self.height_map + title = "Height map" else: mask = self.mask - title = "Mask" - if self.is_torch: + title = "Amplitude mask" + if self.use_torch: mask = mask.cpu().numpy() - ax.imshow( - mask, extent=(0, 1e3 * self.size[1], 1e3 * self.size[0], 0), cmap="gray", **kwargs - ) + if self.centered: + extent = ( + -self.size[1] / 2 * 1e3, + self.size[1] / 2 * 1e3, + self.size[0] / 2 * 1e3, + -self.size[0] / 2 * 1e3, + ) + else: + extent = (0, self.size[1] * 1e3, self.size[0] * 1e3, 0) + + ax.imshow(mask, extent=extent, cmap="gray", **kwargs) ax.set_title(title) ax.set_xlabel("[mm]") ax.set_ylabel("[mm]") @@ -301,7 +351,7 @@ def create_mask(self, row=None, col=None, mask=None): # output product if necessary if self.row is not None: - if self.is_torch: + if self.use_torch: self.mask = torch.outer(self.row, self.col) self.mask = torch.round((self.mask + 1) / 2).to(torch.uint8) else: @@ -312,7 +362,7 @@ def create_mask(self, row=None, col=None, mask=None): # resize to sensor shape if np.any(self.resolution != self.mask.shape): - if self.is_torch: + if self.use_torch: self.mask = self.mask.unsqueeze(0).unsqueeze(0) self.mask = torch.nn.functional.interpolate( self.mask, size=tuple(self.resolution), mode="nearest" @@ -434,7 +484,6 @@ def __init__( radius=None, loc=None, refractive_index=1.2, - design_wv=532e-9, seed=0, min_height=1e-5, radius_range=(1e-4, 4e-4), @@ -455,8 +504,6 @@ def __init__( Location of the lenses (m). refractive_index: float Refractive index of the mask substrate. Default is 1.2. - design_wv: float - Wavelength used to design the mask (m). Default is 532e-9. seed: int Seed for the random number generator. Default is 0. 
min_height: float @@ -472,7 +519,6 @@ def __init__( self.radius = radius self.loc = loc self.refractive_index = refractive_index - self.wavelength = design_wv self.seed = seed self.min_height = min_height self.radius_range = radius_range @@ -491,7 +537,7 @@ def check_asserts(self): self.radius_range[0] < self.radius_range[1] ), "Minimum radius should be smaller than maximum radius" if self.radius is not None: - if self.is_torch: + if self.use_torch: assert torch.all(self.radius >= 0) else: assert np.all(self.radius >= 0) @@ -504,7 +550,7 @@ def check_asserts(self): self.N = len(self.radius) circles = ( np.array([(self.loc[i][0], self.loc[i][1], self.radius[i]) for i in range(self.N)]) - if not self.is_torch + if not self.use_torch else torch.tensor( [(self.loc[i][0], self.loc[i][1], self.radius[i]) for i in range(self.N)] ).to(self.torch_device) @@ -520,7 +566,9 @@ def check_asserts(self): self.radius = np.random.uniform(self.radius_range[0], self.radius_range[1], self.N) # radius get sorted in descending order self.loc, self.radius = self.place_spheres_on_plane(self.radius) - if self.is_torch: + if self.centered: + self.loc = self.loc - np.array(self.size) / 2 + if self.use_torch: self.radius = torch.tensor(self.radius).to(self.torch_device) self.loc = torch.tensor(self.loc).to(self.torch_device) @@ -594,35 +642,32 @@ def create_mask(self, loc=None, radius=None): # convert to pixels (assume same size for x and y) locs_pix = self.loc * (1 / self.feature_size[0]) radius_pix = self.radius * (1 / self.feature_size[0]) - height = self.create_height_map(radius_pix, locs_pix) - self.phase_pattern = height * (self.refractive_index - 1) * 2 * np.pi / self.wavelength - self.mask = ( - np.exp(1j * self.phase_pattern) - if not self.is_torch - else torch.exp(1j * self.phase_pattern) - ) + self.height_map = self.create_height_map(radius_pix, locs_pix) def create_height_map(self, radius, locs): height = ( np.full((self.resolution[0], self.resolution[1]), 
self.min_height).astype(np.float32) - if not self.is_torch + if not self.use_torch else torch.full((self.resolution[0], self.resolution[1]), self.min_height).to( self.torch_device, dtype=torch.float32 ) ) x = ( np.arange(self.resolution[0]).astype(np.float32) - if not self.is_torch + if not self.use_torch else torch.arange(self.resolution[0]).to(self.torch_device) ) y = ( np.arange(self.resolution[1]).astype(np.float32) - if not self.is_torch + if not self.use_torch else torch.arange(self.resolution[1]).to(self.torch_device) ) + if self.centered: + x = x - self.resolution[0] / 2 + y = y - self.resolution[1] / 2 X, Y = ( np.meshgrid(x, y, indexing="ij") - if not self.is_torch + if not self.use_torch else torch.meshgrid(x, y, indexing="ij") ) for idx, rad in enumerate(radius): @@ -635,7 +680,7 @@ def create_height_map(self, radius, locs): def lens_contribution(self, x, y, radius, loc): return ( np.sqrt(radius**2 - (x - loc[1]) ** 2 - (y - loc[0]) ** 2) - if not self.is_torch + if not self.use_torch else torch.sqrt(radius**2 - (x - loc[1]) ** 2 - (y - loc[0]) ** 2) ) @@ -697,7 +742,7 @@ def create_mask(self): assert ( self.distance_sensor is not None ), "Distance between mask and sensor should be specified." 
- phase_mask, height_map = phase_retrieval( + _, height_map = phase_retrieval( target_psf=self.target_psf, wv=self.design_wv, d1=self.feature_size, @@ -707,8 +752,6 @@ def create_mask(self): height_map=True, ) self.height_map = height_map - self.phase_pattern = phase_mask - self.mask = np.exp(1j * phase_mask) def phase_retrieval(target_psf, wv, d1, dz, n=1.2, n_iter=10, height_map=False): diff --git a/lensless/hardware/slm.py b/lensless/hardware/slm.py index 5d0b70be..29ea9d3a 100644 --- a/lensless/hardware/slm.py +++ b/lensless/hardware/slm.py @@ -240,11 +240,11 @@ def adafruit_sub2full( # pad to full pattern pattern = np.zeros((3, 128, 160), dtype=np.uint8) - topleft = [center[0] - controllable_shape[1] // 2, center[1] - controllable_shape[2] // 2] + top_left = [center[0] - controllable_shape[1] // 2, center[1] - controllable_shape[2] // 2] pattern[ :, - topleft[0] : topleft[0] + controllable_shape[1], - topleft[1] : topleft[1] + controllable_shape[2], + top_left[0] : top_left[0] + controllable_shape[1], + top_left[1] : top_left[1] + controllable_shape[2], ] = subpattern_rgb.astype(np.uint8) return pattern diff --git a/lensless/recon/admm.py b/lensless/recon/admm.py index bd50d6db..6be0d276 100644 --- a/lensless/recon/admm.py +++ b/lensless/recon/admm.py @@ -10,6 +10,8 @@ import numpy as np from lensless.recon.recon import ReconstructionAlgorithm from scipy import fft +from lensless.utils.io import load_data +import time try: import torch @@ -45,7 +47,7 @@ def __init__( norm="backward", # PnP denoiser=None, - **kwargs + **kwargs, ): """ @@ -393,3 +395,25 @@ def finite_diff_gram(shape, dtype=None, is_torch=False): return torch.fft.rfft2(gram, dim=(-3, -2)) else: return fft.rfft2(gram, axes=(-3, -2)) + + +def apply_admm(psf_fp, data_fp, n_iter, verbose=False, **kwargs): + + # load data + psf, data = load_data(psf_fp=psf_fp, data_fp=data_fp, plot=False, **kwargs) + + # create reconstruction object + recon = ADMM(psf, n_iter=n_iter) + + # set data + 
recon.set_data(data) + + # perform reconstruction + start_time = time.time() + res = recon.apply(plot=False) + proc_time = time.time() - start_time + + if verbose: + print(f"Reconstruction time : {proc_time} s") + print(f"Reconstruction shape: {res.shape}") + return res diff --git a/lensless/recon/gd.py b/lensless/recon/gd.py index 646906f4..dc61e809 100644 --- a/lensless/recon/gd.py +++ b/lensless/recon/gd.py @@ -10,6 +10,8 @@ import numpy as np from lensless.recon.recon import ReconstructionAlgorithm import inspect +from lensless.utils.io import load_data +import time try: import torch @@ -229,9 +231,30 @@ def reset(self, tk=None): def _update(self, iter): self._image_est -= self._alpha * self._grad() - # xk = self._proj(self._image_est) xk = self._form_image() tk = (1 + np.sqrt(1 + 4 * self._tk**2)) / 2 self._image_est = xk + (self._tk - 1) / tk * (xk - self._xk) self._tk = tk self._xk = xk + + +def apply_gradient_descent(psf_fp, data_fp, n_iter, verbose=False, proj=non_neg, **kwargs): + + # load data + psf, data = load_data(psf_fp=psf_fp, data_fp=data_fp, plot=False, **kwargs) + + # create reconstruction object + recon = GradientDescent(psf, n_iter=n_iter, proj=proj) + + # set data + recon.set_data(data) + + # perform reconstruction + start_time = time.time() + res = recon.apply(plot=False) + proc_time = time.time() - start_time + + if verbose: + print(f"Reconstruction time : {proc_time} s") + print(f"Reconstruction shape: {res.shape}") + return res diff --git a/lensless/recon/utils.py b/lensless/recon/utils.py index 80349001..0d4b44f4 100644 --- a/lensless/recon/utils.py +++ b/lensless/recon/utils.py @@ -648,17 +648,12 @@ def train_epoch(self, data_loader): y = y.reshape(-1, *y.shape[-3:]).movedim(-1, -3) # extraction region of interest for loss - if ( - hasattr(self.train_dataset, "alignment") - and self.train_dataset.alignment is not None - ): - alignment = self.train_dataset.alignment - y_pred = y_pred[ - ..., - alignment["topright"][0] : 
alignment["topright"][0] + alignment["height"], - alignment["topright"][1] : alignment["topright"][1] + alignment["width"], - ] - # expected that lensed is also reshaped accordingly + if hasattr(self.train_dataset, "alignment"): + if self.train_dataset.alignment is not None: + y_pred = self.train_dataset.extract_roi(y_pred, axis=(-2, -1)) + else: + y_pred, y = self.train_dataset.extract_roi(y_pred, axis=(-2, -1), lensed=y) + elif self.crop is not None: y_pred = y_pred[ ..., diff --git a/lensless/utils/dataset.py b/lensless/utils/dataset.py index 5a01e770..44ebeaec 100644 --- a/lensless/utils/dataset.py +++ b/lensless/utils/dataset.py @@ -14,6 +14,7 @@ from abc import abstractmethod from torch.utils.data import Dataset, Subset from torchvision import datasets, transforms +from torchvision.transforms import functional as F from lensless.hardware.trainable_mask import prep_trainable_mask, AdafruitLCD from lensless.utils.simulation import FarFieldSimulator from lensless.utils.io import load_image, load_psf, save_image @@ -26,6 +27,7 @@ from huggingface_hub import hf_hub_download import cv2 from lensless.hardware.sensor import sensor_dict, SensorParam +from scipy.ndimage import rotate def convert(text): @@ -1032,6 +1034,7 @@ def __init__( alignment=None, return_mask_label=False, save_psf=False, + simulation_config=dict(), **kwargs, ): """ @@ -1061,12 +1064,14 @@ def __init__( If `psf` not provided, the SLM to use for the PSF simulation, by default "adafruit". alignment : dict, optional Alignment parameters between lensless and lensed data. - If "topright", "height", and "width" are provided, the region-of-interest from the reconstruction of ``lensless`` is extracted and ``lensed`` is reshaped to match. + If "top_left", "height", and "width" are provided, the region-of-interest from the reconstruction of ``lensless`` is extracted and ``lensed`` is reshaped to match. 
If "crop" is provided, the region-of-interest is extracted from the simulated lensed image, namely a ``simulation`` configuration should be provided within ``alignment``. return_mask_label : bool, optional If multimask dataset, return the mask label (True) or the corresponding PSF (False). save_psf : bool, optional If multimask dataset, save the simulated PSFs. + simulation_config : dict, optional + Simulation parameters for PSF if using a mask pattern. """ @@ -1101,11 +1106,11 @@ def __init__( self.crop = None if alignment is not None: # preparing ground-truth in expected shape - if "topright" in alignment: + if "top_left" in alignment: self.alignment = dict(alignment.copy()) - self.alignment["topright"] = ( - int(self.alignment["topright"][0] / downsample), - int(self.alignment["topright"][1] / downsample), + self.alignment["top_left"] = ( + int(self.alignment["top_left"][0] / downsample), + int(self.alignment["top_left"][1] / downsample), ) self.alignment["height"] = int(self.alignment["height"] / downsample) @@ -1159,6 +1164,9 @@ def __init__( slm=slm, downsample=downsample_fact, flipud=rotate, + use_waveprop=simulation_config.get("use_waveprop", False), + scene2mask=simulation_config.get("scene2mask", None), + mask2sensor=simulation_config.get("mask2sensor", None), ) self.psf[label] = mask.get_psf().detach() @@ -1172,6 +1180,7 @@ def __init__( else: + # single mask pattern mask_fp = hf_hub_download( repo_id=huggingface_repo, filename="mask_pattern.npy", repo_type="dataset" ) @@ -1182,12 +1191,19 @@ def __init__( slm=slm, downsample=downsample_fact, flipud=rotate, + use_waveprop=simulation_config.get("use_waveprop", False), + scene2mask=simulation_config.get("scene2mask", None), + mask2sensor=simulation_config.get("mask2sensor", None), ) self.psf = mask.get_psf().detach() assert ( self.psf.shape[-3:-1] == lensless.shape[:2] ), "PSF shape should match lensless shape" + if save_psf: + # save viewable image of PSF + save_image(self.psf.squeeze().cpu().numpy(), 
"psf.png") + # create simulator self.simulator = None self.vertical_shift = None @@ -1233,6 +1249,7 @@ def _get_images_pair(self, idx): lensless = lensless_np lensed = lensed_np + if self.simulator is not None: # convert to torch lensless = torch.from_numpy(lensless_np) @@ -1282,27 +1299,41 @@ def __getitem__(self, idx): else: return lensless, lensed - def extract_roi(self, reconstruction, lensed=None): - assert len(reconstruction.shape) == 4, "Reconstruction should have shape [B, H, W, C]" - if lensed is not None: - assert len(lensed.shape) == 4, "Lensed should have shape [B, H, W, C]" + def extract_roi(self, reconstruction, lensed=None, axis=(1, 2)): + n_dim = len(reconstruction.shape) + assert max(axis) < n_dim, "Axis should be within the dimensions of the reconstruction." if self.alignment is not None: - top_right = self.alignment["topright"] + top_left = self.alignment["top_left"] height = self.alignment["height"] width = self.alignment["width"] - reconstruction = reconstruction[ - :, top_right[0] : top_right[0] + height, top_right[1] : top_right[1] + width - ] + + # extract according to axis + index = [slice(None)] * n_dim + index[axis[0]] = slice(top_left[0], top_left[0] + height) + index[axis[1]] = slice(top_left[1], top_left[1] + width) + reconstruction = reconstruction[tuple(index)] + + # rotate if necessary + angle = self.alignment.get("angle", 0) + if isinstance(reconstruction, torch.Tensor): + reconstruction = F.rotate(reconstruction, angle, expand=False) + else: + reconstruction = rotate(reconstruction, angle, axes=axis, reshape=False) + elif self.crop is not None: vertical = self.crop["vertical"] horizontal = self.crop["horizontal"] - reconstruction = reconstruction[ - :, vertical[0] : vertical[1], horizontal[0] : horizontal[1] - ] + + # extract according to axis + index = [slice(None)] * n_dim + index[axis[0]] = slice(vertical[0], vertical[1]) + index[axis[1]] = slice(horizontal[0], horizontal[1]) + reconstruction = reconstruction[tuple(index)] if 
lensed is not None: - lensed = lensed[:, vertical[0] : vertical[1], horizontal[0] : horizontal[1]] - if lensed is not None: + lensed = lensed[tuple(index)] + + if self.alignment is None and lensed is not None: return reconstruction, lensed else: return reconstruction diff --git a/lensless/utils/io.py b/lensless/utils/io.py index 47fd94f4..6d07bc27 100644 --- a/lensless/utils/io.py +++ b/lensless/utils/io.py @@ -383,6 +383,8 @@ def load_data( bg_pix=(5, 25), plot=True, flip=False, + flip_ud=False, + flip_lr=False, bayer=False, blue_gain=None, red_gain=None, @@ -391,7 +393,7 @@ def load_data( dtype=None, single_psf=False, shape=None, - torch=False, + use_torch=False, torch_device="cpu", normalize=False, bgr_input=True, @@ -468,6 +470,8 @@ def load_data( bg_pix=bg_pix, return_bg=True, flip=flip, + flip_ud=flip_ud, + flip_lr=flip_lr, bayer=bayer, blue_gain=blue_gain, red_gain=red_gain, @@ -482,6 +486,8 @@ def load_data( data = load_image( data_fp, flip=flip, + flip_ud=flip_ud, + flip_lr=flip_lr, bayer=bayer, blue_gain=blue_gain, red_gain=red_gain, @@ -528,7 +534,7 @@ def load_data( psf = np.array(psf, dtype=dtype) data = np.array(data, dtype=dtype) - if torch: + if use_torch: import torch if dtype == np.float32: diff --git a/profile/admm.py b/profile/admm.py index d3942362..bb70b715 100644 --- a/profile/admm.py +++ b/profile/admm.py @@ -68,7 +68,7 @@ plot=False, gray=gray, dtype=dtype, - torch=True, + use_torch=True, ) recon = ADMM(psf, dtype=dtype) @@ -91,7 +91,7 @@ plot=False, gray=gray, dtype=dtype, - torch=True, + use_torch=True, torch_device="cuda", ) diff --git a/profile/gradient_descent.py b/profile/gradient_descent.py index 2516e071..4f19d1c2 100644 --- a/profile/gradient_descent.py +++ b/profile/gradient_descent.py @@ -97,7 +97,7 @@ plot=False, gray=gray, dtype=dtype, - torch=True, + use_torch=True, ) recon = FISTA(psf, dtype=dtype) @@ -120,7 +120,7 @@ plot=False, gray=gray, dtype=dtype, - torch=True, + use_torch=True, torch_device="cuda", ) diff --git 
a/scripts/eval/benchmark_recon.py b/scripts/eval/benchmark_recon.py index ece0bcfa..7be38ddd 100644 --- a/scripts/eval/benchmark_recon.py +++ b/scripts/eval/benchmark_recon.py @@ -84,7 +84,7 @@ def benchmark_recon(config): _, benchmark_dataset = torch.utils.data.random_split( dataset, [train_size, test_size], generator=generator ) - elif dataset == "DigiCamHF": + elif dataset == "HFDataset": benchmark_dataset = HFDataset( huggingface_repo=config.huggingface.repo, split="test", @@ -92,6 +92,7 @@ def benchmark_recon(config): rotate=config.huggingface.rotate, downsample=config.huggingface.downsample, alignment=config.huggingface.alignment, + simulation_config=config.simulation, ) if benchmark_dataset.multimask: # get first PSF for initialization diff --git a/scripts/measure/collect_dataset_on_device.py b/scripts/measure/collect_dataset_on_device.py index 596d4819..76393b9c 100644 --- a/scripts/measure/collect_dataset_on_device.py +++ b/scripts/measure/collect_dataset_on_device.py @@ -270,7 +270,7 @@ def collect_dataset(config): )[0] # save image - save_image(output, output_fp) + save_image(output, output_fp, normalize=False) # print range print(f"{output_fp}, range: {output.min()} - {output.max()}") diff --git a/scripts/measure/remote_capture.py b/scripts/measure/remote_capture.py index 6777347b..b7b5c6c0 100644 --- a/scripts/measure/remote_capture.py +++ b/scripts/measure/remote_capture.py @@ -240,7 +240,7 @@ def liveview(config): # save image as viewable 8 bit fp = os.path.join(save, f"{fn}_8bit.png") - save_image(img, fp) + save_image(img, fp, normalize=True) # plot RGB if plot: diff --git a/scripts/recon/admm.py b/scripts/recon/admm.py index 1fdd68ba..12584e59 100644 --- a/scripts/recon/admm.py +++ b/scripts/recon/admm.py @@ -40,7 +40,7 @@ def admm(config): gray=config["preprocess"]["gray"], single_psf=config["preprocess"]["single_psf"], shape=config["preprocess"]["shape"], - torch=config.torch, + use_torch=config.torch, torch_device=config.torch_device, 
bg_pix=config.preprocess.bg_pix, normalize=config.preprocess.normalize, diff --git a/scripts/recon/dataset.py b/scripts/recon/dataset.py index 906508db..eefbfe22 100644 --- a/scripts/recon/dataset.py +++ b/scripts/recon/dataset.py @@ -115,7 +115,7 @@ def recover(i): scores = psnr(lensed[0], res), lpips(lensed[0], res) output_fp = os.path.join(output_folder, f"{i}.png") - save_image(res, output_fp) + save_image(res, output_fp, normalize=True) return scores n_jobs = config.apgd.n_jobs @@ -178,7 +178,7 @@ def recover(i): else: img = res[0] output_fp = os.path.join(output_folder, f"{i}.png") - save_image(img, output_fp) + save_image(img, output_fp, normalize=True) if len(psnr_scores) > 0: # print average metrics diff --git a/scripts/recon/digicam_mirflickr.py b/scripts/recon/digicam_mirflickr.py index 60411fd0..a7e25ff9 100644 --- a/scripts/recon/digicam_mirflickr.py +++ b/scripts/recon/digicam_mirflickr.py @@ -84,10 +84,10 @@ def apply_pretrained(config): if save: print(f"Saving images to {os.getcwd()}") alignment = test_set.alignment - top_right = alignment["topright"] + top_left = alignment["top_left"] height = alignment["height"] width = alignment["width"] - res_np = img[top_right[0] : top_right[0] + height, top_right[1] : top_right[1] + width] + res_np = img[top_left[0] : top_left[0] + height, top_left[1] : top_left[1] + width] lensed_np = lensed[0].cpu().numpy() save_image(lensed_np, f"original_idx{idx}.png") save_image(res_np, f"{model_name}_idx{idx}.png") diff --git a/scripts/recon/gradient_descent.py b/scripts/recon/gradient_descent.py index 16f36955..f2a9c883 100644 --- a/scripts/recon/gradient_descent.py +++ b/scripts/recon/gradient_descent.py @@ -41,7 +41,7 @@ def gradient_descent( gray=config["preprocess"]["gray"], single_psf=config["preprocess"]["single_psf"], shape=config["preprocess"]["shape"], - torch=config.torch, + use_torch=config.torch, torch_device=config.torch_device, ) diff --git a/scripts/recon/train_learning_based.py 
b/scripts/recon/train_learning_based.py index 9ad7a016..d5e6a111 100644 --- a/scripts/recon/train_learning_based.py +++ b/scripts/recon/train_learning_based.py @@ -104,7 +104,6 @@ def train_learned(config): test_set = None psf = None crop = None - alignment = None # very similar to crop, TODO: should switch to this approach mask = None if "DiffuserCam" in config.files.dataset and config.files.huggingface_dataset is False: @@ -216,6 +215,7 @@ def train_learned(config): alignment=config.alignment, save_psf=config.files.save_psf, n_files=config.files.n_files, + simulation_config=config.simulation, ) test_set = HFDataset( huggingface_repo=config.files.dataset, @@ -228,6 +228,7 @@ def train_learned(config): alignment=config.alignment, save_psf=config.files.save_psf, n_files=config.files.n_files, + simulation_config=config.simulation, ) if train_set.multimask: # get first PSF for initialization @@ -239,7 +240,6 @@ def train_learned(config): else: psf = train_set.psf.to(device) crop = test_set.crop # same for train set - alignment = test_set.alignment # -- if learning mask mask = prep_trainable_mask(config, psf) @@ -277,6 +277,7 @@ def train_learned(config): split="test", downsample=config.files.downsample, # needs to be same size n_files=config.files.n_files, + simulation_config=config.simulation, **config.files.extra_eval[eval_set], ) @@ -310,13 +311,14 @@ def train_learned(config): # -- plot lensed and res on top of each other cropped = False - if alignment is not None: - top_right = alignment["topright"] - height = alignment["height"] - width = alignment["width"] - res_np = res_np[ - top_right[0] : top_right[0] + height, top_right[1] : top_right[1] + width - ] + if hasattr(test_set, "alignment"): + if test_set.alignment is not None: + res_np = test_set.extract_roi(res_np, axis=(0, 1)) + else: + res_np, lensed_np = test_set.extract_roi( + res_np, lensed=lensed_np, axis=(0, 1) + ) + cropped = True elif config.training.crop_preloss: diff --git a/test/test_algos.py 
b/test/test_algos.py index b5e3d94c..0ea89f14 100644 --- a/test/test_algos.py +++ b/test/test_algos.py @@ -45,7 +45,7 @@ def test_set_initial_est(algorithm): downsample=downsample, plot=False, gray=gray, - torch=False, + use_torch=False, ) recon = algorithm(psf) assert recon._initial_est is None @@ -71,7 +71,7 @@ def test_set_initial_est_unrolled(algorithm): downsample=downsample, plot=False, gray=gray, - torch=True, + use_torch=True, ) recon = algorithm(psf) assert recon._initial_est is None @@ -97,7 +97,7 @@ def test_recon_numpy(algorithm): plot=False, gray=gray, dtype=dtype, - torch=False, + use_torch=False, ) recon = algorithm(psf, dtype=dtype) recon.set_data(data) @@ -120,7 +120,7 @@ def test_recon_torch(algorithm): plot=False, gray=gray, dtype=dtype, - torch=True, + use_torch=True, ) recon = algorithm(psf, dtype=dtype, n_iter=_n_iter) recon.set_data(data) @@ -144,7 +144,7 @@ def test_apgd(): plot=False, gray=gray, dtype=dtype, - torch=False, + use_torch=False, ) recon = APGD( psf, diff --git a/test/test_io.py b/test/test_io.py index 5c2f8884..84a822ec 100644 --- a/test/test_io.py +++ b/test/test_io.py @@ -33,7 +33,7 @@ def test_rgb2gray(): downsample=downsample, plot=False, dtype="float32", - torch=is_torch, + use_torch=is_torch, ) data = data[0] # drop first depth dimension diff --git a/test/test_masks.py b/test/test_masks.py index 6a7fa81a..b2c70c7c 100644 --- a/test/test_masks.py +++ b/test/test_masks.py @@ -43,12 +43,13 @@ def test_phlatcam(): feature_size=d1, distance_sensor=dz, ) - assert np.all(mask.mask.shape == resolution) + assert np.all(mask.shape == resolution) desired_psf_shape = np.array(tuple(resolution) + (len(mask.psf_wavelength),)) assert np.all(mask.psf.shape == desired_psf_shape) + u1 = mask.height_map_to_field(wavelength=mask.design_wv) Mp = np.sqrt(mask.target_psf) * np.exp( - 1j * np.angle(fresnel_conv(mask.mask, mask.design_wv, d1, dz, dtype=np.float32)[0]) + 1j * np.angle(fresnel_conv(u1, mask.design_wv, d1, dz, dtype=np.float32)[0]) 
) assert mse(abs(Mp), np.sqrt(mask.target_psf)) < 0.1 assert psnr(abs(Mp), np.sqrt(mask.target_psf)) > 30 @@ -72,21 +73,21 @@ def test_classmethod(): mask1 = CodedAperture.from_sensor( sensor_name="rpi_hq", downsample=downsample, distance_sensor=dz ) - assert np.all(mask1.mask.shape == resolution) + assert np.all(mask1.shape == resolution) desired_psf_shape = np.array(tuple(resolution) + (len(mask1.psf_wavelength),)) assert np.all(mask1.psf.shape == desired_psf_shape) mask2 = PhaseContour.from_sensor( sensor_name="rpi_hq", downsample=downsample, distance_sensor=dz ) - assert np.all(mask2.mask.shape == resolution) + assert np.all(mask2.shape == resolution) desired_psf_shape = np.array(tuple(resolution) + (len(mask2.psf_wavelength),)) assert np.all(mask2.psf.shape == desired_psf_shape) mask3 = FresnelZoneAperture.from_sensor( sensor_name="rpi_hq", downsample=downsample, distance_sensor=dz ) - assert np.all(mask3.mask.shape == resolution) + assert np.all(mask3.shape == resolution) desired_psf_shape = np.array(tuple(resolution) + (len(mask3.psf_wavelength),)) assert np.all(mask3.psf.shape == desired_psf_shape)