From 01d6d5d3382d7e107d0b50f998b4779766f3e1e1 Mon Sep 17 00:00:00 2001 From: Eric Bezzam Date: Tue, 9 Jul 2024 01:17:44 -0700 Subject: [PATCH] Fix normalization, and downsample before color correction. (#134) * Fix normalization, and downsample before color correction. * Fix starting from already started measurement. * Change dataset check on saturation ratio. --- configs/analyze_dataset.yaml | 4 +- configs/collect_dataset.yaml | 6 +- lensless/utils/image.py | 6 ++ lensless/utils/io.py | 40 ++++++---- scripts/measure/analyze_measured_dataset.py | 59 ++++++++++++--- scripts/measure/collect_dataset_on_device.py | 77 ++++++++++++-------- 6 files changed, 129 insertions(+), 63 deletions(-) diff --git a/configs/analyze_dataset.yaml b/configs/analyze_dataset.yaml index 53d6a130..475f452c 100644 --- a/configs/analyze_dataset.yaml +++ b/configs/analyze_dataset.yaml @@ -1,9 +1,11 @@ +# python scripts/measure/analyze_measured_dataset.py hydra: job: chdir: True # change to output folder dataset_path: null -desired_range: [150, 254] +desired_range: [150, 255] +saturation_percent: 0.05 delete_bad: False n_files: null start_idx: null diff --git a/configs/collect_dataset.yaml b/configs/collect_dataset.yaml index c5f56d45..0fa87fa2 100644 --- a/configs/collect_dataset.yaml +++ b/configs/collect_dataset.yaml @@ -42,10 +42,12 @@ display: capture: skip: False # to test looping over displaying images - config_pause: 2 + config_pause: 3 iso: 100 res: null down: 4 exposure: 0.02 # min exposure awb_gains: [1.9, 1.2] # red, blue - # awb_gains: null \ No newline at end of file + # awb_gains: null + fact_increase: 2 # multiplicative factor to increase exposure + fact_decrease: 1.5 \ No newline at end of file diff --git a/lensless/utils/image.py b/lensless/utils/image.py index edb213fc..cc2f4936 100644 --- a/lensless/utils/image.py +++ b/lensless/utils/image.py @@ -221,6 +221,7 @@ def get_max_val(img, nbits=None): def bayer2rgb_cc( img, nbits, + down=None, blue_gain=None, red_gain=None, black_level=RPI_HQ_CAMERA_BLACK_LEVEL, @@ -269,6 +270,10 @@ def bayer2rgb_cc( # demosaic Bayer data img = cv2.cvtColor(img, cv2.COLOR_BayerRG2RGB) + # downsample + if down is not None: + img = resize(img[None, ...], factor=1 / down, interpolation=cv2.INTER_CUBIC)[0] + # correction img = img - black_level if red_gain: @@ -277,6 +282,7 @@ def bayer2rgb_cc( img[:, :, 2] *= blue_gain img = img / (2**nbits - 1 - black_level) img[img > 1] = 1 + img = (img.reshape(-1, 3, order="F") @ ccm.T).reshape(img.shape, order="F") img[img < 0] = 0 img[img > 1] = 1 diff --git a/lensless/utils/io.py b/lensless/utils/io.py index 7e44975b..4d1eec70 100644 --- a/lensless/utils/io.py +++ b/lensless/utils/io.py @@ -576,31 +576,39 @@ def save_image(img, fp, max_val=255, normalize=True): img_tmp = img.copy() - if img_tmp.dtype == np.uint16 or img_tmp.dtype == np.uint8: - img_tmp = img_tmp.astype(np.float32) - if normalize: + + if img_tmp.dtype == np.uint16 or img_tmp.dtype == np.uint8: + img_tmp = img_tmp.astype(np.float32) + img_tmp -= img_tmp.min() img_tmp /= img_tmp.max() - else: - normalized = False - if img_tmp.min() < 0: - img_tmp -= img_tmp.min() - normalize = True - if img_tmp.max() > 1: - img_tmp /= img_tmp.max() - normalize = True - if normalized: - print(f"Warning (out of range): {fp} normalizing data to [0, 1]") - - if img_tmp.dtype == np.float64 or img_tmp.dtype == np.float32: img_tmp *= max_val img_tmp = img_tmp.astype(np.uint8) - # RGB + else: + + if img_tmp.dtype == np.float64 or img_tmp.dtype == np.float32: + # check within [0, 1] and convert to uint8 + + normalized = False + if img_tmp.min() < 0: + img_tmp -= img_tmp.min() + normalized = True + if img_tmp.max() > 1: + img_tmp /= img_tmp.max() + normalized = True + if normalized: + print(f"Warning (out of range): {fp} normalizing data to [0, 1]") + img_tmp *= max_val + img_tmp = img_tmp.astype(np.uint8) + + # save if len(img_tmp.shape) == 3 and img_tmp.shape[2] == 3: + # RGB img_tmp = Image.fromarray(img_tmp) else: + # grayscale img_tmp = Image.fromarray(img_tmp.squeeze()) img_tmp.save(fp) diff --git a/scripts/measure/analyze_measured_dataset.py b/scripts/measure/analyze_measured_dataset.py index 6137e176..2d5c7050 100644 --- a/scripts/measure/analyze_measured_dataset.py +++ b/scripts/measure/analyze_measured_dataset.py @@ -15,6 +15,19 @@ import matplotlib.pyplot as plt import time import tqdm +import re + + +def convert(text): + return int(text) if text.isdigit() else text.lower() + + +def alphanum_key(key): + return [convert(c) for c in re.split("([0-9]+)", key)] + + +def natural_sort(arr): + return sorted(arr, key=alphanum_key) @hydra.main(version_base=None, config_path="../../configs", config_name="analyze_dataset") @@ -24,13 +37,14 @@ def analyze_dataset(config): desired_range = config.desired_range delete_bad = config.delete_bad start_idx = config.start_idx + saturation_percent = config.saturation_percent assert ( folder is not None ), "Must specify folder to analyze in config or through command line (folder=PATH)." # get all PNG files in folder - files = sorted(glob.glob(os.path.join(folder, "*.png"))) + files = natural_sort(glob.glob(os.path.join(folder, "*.png"))) print("Found {} files".format(len(files))) if start_idx is not None: files = files[start_idx:] @@ -48,10 +62,9 @@ def analyze_dataset(config): im = np.array(Image.open(fn)) max_val = im.max() max_vals.append(max_val) + saturation_ratio = np.sum(im >= desired_range[1]) / im.size - # if out of desired range, print filename - if max_val < desired_range[0] or max_val > desired_range[1]: - # print("File {} has max value {}".format(fn, max_val)) + if max_val < desired_range[0]: n_bad_files += 1 bad_files.append(fn) @@ -61,6 +74,28 @@ def analyze_dataset(config): else: print("File {} has max value {}".format(fn, max_val)) + elif saturation_ratio > saturation_percent: + n_bad_files += 1 + bad_files.append(fn) + + if delete_bad: + os.remove(fn) + print("REMOVED file {}".format(fn)) + else: + print("File {} has saturation ratio {}".format(fn, saturation_ratio)) + + # # if out of desired range, print filename + # if max_val < desired_range[0] or saturation_ratio > saturation_percent: + # # print("File {} has max value {}".format(fn, max_val)) + # n_bad_files += 1 + # bad_files.append(fn) + + # if delete_bad: + # os.remove(fn) + # print("REMOVED file {}".format(fn)) + # else: + # print("File {} has max value {}".format(fn, max_val)) + proc_time = time.time() - start_time print("Went through {} files in {:.2f} seconds".format(len(files), proc_time)) print( @@ -69,6 +104,14 @@ def analyze_dataset(config): ) ) + # plot histogram + output_folder = os.getcwd() + output_fp = os.path.join(output_folder, "max_vals.png") + plt.hist(max_vals, bins=100) + plt.savefig(output_fp) + + print("Saved histogram to {}".format(output_fp)) + # command line input on whether to delete bad files if not delete_bad: response = None @@ -80,14 +123,6 @@ def analyze_dataset(config): else: print("Not deleting bad files") - # plot histogram - output_folder = os.getcwd() - output_fp = os.path.join(output_folder, "max_vals.png") - plt.hist(max_vals, bins=100) - plt.savefig(output_fp) - - print("Saved histogram to {}".format(output_fp)) - if __name__ == "__main__": analyze_dataset() diff --git a/scripts/measure/collect_dataset_on_device.py b/scripts/measure/collect_dataset_on_device.py index 69bd65d9..96c7c727 100644 --- a/scripts/measure/collect_dataset_on_device.py +++ b/scripts/measure/collect_dataset_on_device.py @@ -14,6 +14,7 @@ import numpy as np import hydra +from hydra.utils import to_absolute_path import time import os import pathlib as plib @@ -67,8 +68,8 @@ def collect_dataset(config): start_idx = config.start_idx if os.path.exists(output_dir): files = list(plib.Path(output_dir).glob(f"*.{config.output_file_ext}")) - start_idx = len(files) - print("\nNumber of completed measurements :", start_idx) + n_completed_files = len(files) + print("\nNumber of completed measurements :", n_completed_files) output_dir = plib.Path(output_dir) if config.masks is not None: mask_dir = plib.Path(output_dir) / "masks" @@ -91,19 +92,23 @@ def collect_dataset(config): recon = None if config.recon is not None: + print("Initializing ADMM recon...") # initialize ADMM reconstruction from lensless import ADMM from lensless.utils.io import load_psf psf, bg = load_psf( - fp=config.recon.psf, - downsample=config.capture.down, # assume full resolution PSF - return_bg=True + fp=to_absolute_path(config.recon.psf), + downsample=config.capture.down, # assume full resolution PSF + return_bg=True, ) + + print("PSF shape: ", psf.shape) recon = ADMM(psf, n_iter=config.recon.n_iter) recon_dir = plib.Path(output_dir) / "recon" recon_dir.mkdir(exist_ok=True) + print("Finished initializing ADMM recon.") # assert input directory exists assert os.path.exists(input_dir) @@ -250,8 +255,8 @@ def collect_dataset(config): # -- take picture max_pixel_val = 0 - fact_increase = 2 - fact_decrease = 1.5 + fact_increase = config.capture.fact_increase + fact_decrease = config.capture.fact_decrease n_tries = 0 camera.shutter_speed = init_shutter_speed @@ -272,6 +277,7 @@ def collect_dataset(config): # convert to RGB output = bayer2rgb_cc( output_bayer, + down=down, nbits=12, blue_gain=float(g[1]), red_gain=float(g[0]), @@ -280,10 +286,10 @@ def collect_dataset(config): nbits_out=8, ) - if down: - output = resize( - output[None, ...], factor=1 / down, interpolation=cv2.INTER_CUBIC - )[0] + # if down: + # output = resize( + # output[None, ...], factor=1 / down, interpolation=cv2.INTER_CUBIC + # )[0] # save image save_image(output, output_fp, normalize=False) @@ -305,27 +311,34 @@ def collect_dataset(config): elif max_pixel_val > MAX_LEVEL: - # decrease exposure - current_shutter_speed = int(current_shutter_speed / fact_decrease) - camera.shutter_speed = current_shutter_speed - time.sleep(config.capture.config_pause) - print(f"decreasing shutter speed to {current_shutter_speed}") - - # # decrease screen brightness - # current_screen_brightness = current_screen_brightness - 10 - # screen_res = np.array(config.display.screen_res) - # hshift = config.display.hshift - # vshift = config.display.vshift - # pad = config.display.pad - # brightness = current_screen_brightness - # display_image_path = config.display.output_fp - # rot90 = config.display.rot90 - # os.system( - # f"python scripts/measure/prep_display_image.py --fp {_file} --output_path {display_image_path} --screen_res {screen_res[0]} {screen_res[1]} --hshift {hshift} --vshift {vshift} --pad {pad} --brightness {brightness} --rot90 {rot90}" - # ) - # print(f"decreasing screen brightness to {current_screen_brightness}") - - # time.sleep(config.display.delay) + if current_shutter_speed > 13098: # TODO: minimum for RPi HQ + # decrease exposure + current_shutter_speed = int(current_shutter_speed / fact_decrease) + camera.shutter_speed = current_shutter_speed + time.sleep(config.capture.config_pause) + print(f"decreasing shutter speed to {current_shutter_speed}") + + else: + + # decrease screen brightness + current_screen_brightness = current_screen_brightness - 10 + screen_res = np.array(config.display.screen_res) + hshift = config.display.hshift + vshift = config.display.vshift + pad = config.display.pad + brightness = current_screen_brightness + display_image_path = config.display.output_fp + rot90 = config.display.rot90 + + display_command = f"python scripts/measure/prep_display_image.py --fp {_file} --output_path {display_image_path} --screen_res {screen_res[0]} {screen_res[1]} --hshift {hshift} --vshift {vshift} --pad {pad} --brightness {brightness} --rot90 {rot90}" + if config.display.landscape: + display_command += " --landscape" + if config.display.image_res is not None: + display_command += f" --image_res {config.display.image_res[0]} {config.display.image_res[1]}" + # print(display_command) + os.system(display_command) + + time.sleep(config.display.delay) exposure_vals.append(current_shutter_speed / 1e6) brightness_vals.append(current_screen_brightness)