Skip to content

Commit

Permalink
Fix normalization, and downsample before color correction. (#134)
Browse files Browse the repository at this point in the history
* Fix normalization, and downsample before color correction.

* Fix starting from already started measurement.

* Change dataset check on saturation ratio.
  • Loading branch information
ebezzam authored Jul 9, 2024
1 parent 17462ff commit 01d6d5d
Show file tree
Hide file tree
Showing 6 changed files with 129 additions and 63 deletions.
4 changes: 3 additions & 1 deletion configs/analyze_dataset.yaml
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
# python scripts/measure/analyze_measured_dataset.py
hydra:
job:
chdir: True # change to output folder

dataset_path: null
desired_range: [150, 254]
desired_range: [150, 255]
saturation_percent: 0.05
delete_bad: False
n_files: null
start_idx: null
6 changes: 4 additions & 2 deletions configs/collect_dataset.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,12 @@ display:

capture:
skip: False # to test looping over displaying images
config_pause: 2
config_pause: 3
iso: 100
res: null
down: 4
exposure: 0.02 # min exposure
awb_gains: [1.9, 1.2] # red, blue
# awb_gains: null
# awb_gains: null
fact_increase: 2 # multiplicative factor to increase exposure
fact_decrease: 1.5
6 changes: 6 additions & 0 deletions lensless/utils/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ def get_max_val(img, nbits=None):
def bayer2rgb_cc(
img,
nbits,
down=None,
blue_gain=None,
red_gain=None,
black_level=RPI_HQ_CAMERA_BLACK_LEVEL,
Expand Down Expand Up @@ -269,6 +270,10 @@ def bayer2rgb_cc(
# demosaic Bayer data
img = cv2.cvtColor(img, cv2.COLOR_BayerRG2RGB)

# downsample
if down is not None:
img = resize(img[None, ...], factor=1 / down, interpolation=cv2.INTER_CUBIC)[0]

# correction
img = img - black_level
if red_gain:
Expand All @@ -277,6 +282,7 @@ def bayer2rgb_cc(
img[:, :, 2] *= blue_gain
img = img / (2**nbits - 1 - black_level)
img[img > 1] = 1

img = (img.reshape(-1, 3, order="F") @ ccm.T).reshape(img.shape, order="F")
img[img < 0] = 0
img[img > 1] = 1
Expand Down
40 changes: 24 additions & 16 deletions lensless/utils/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,31 +576,39 @@ def save_image(img, fp, max_val=255, normalize=True):

img_tmp = img.copy()

if img_tmp.dtype == np.uint16 or img_tmp.dtype == np.uint8:
img_tmp = img_tmp.astype(np.float32)

if normalize:

if img_tmp.dtype == np.uint16 or img_tmp.dtype == np.uint8:
img_tmp = img_tmp.astype(np.float32)

img_tmp -= img_tmp.min()
img_tmp /= img_tmp.max()
else:
normalized = False
if img_tmp.min() < 0:
img_tmp -= img_tmp.min()
normalize = True
if img_tmp.max() > 1:
img_tmp /= img_tmp.max()
normalize = True
if normalized:
print(f"Warning (out of range): {fp} normalizing data to [0, 1]")

if img_tmp.dtype == np.float64 or img_tmp.dtype == np.float32:
img_tmp *= max_val
img_tmp = img_tmp.astype(np.uint8)

# RGB
else:

if img_tmp.dtype == np.float64 or img_tmp.dtype == np.float32:
# check within [0, 1] and convert to uint8

normalized = False
if img_tmp.min() < 0:
img_tmp -= img_tmp.min()
normalized = True
if img_tmp.max() > 1:
img_tmp /= img_tmp.max()
normalized = True
if normalized:
print(f"Warning (out of range): {fp} normalizing data to [0, 1]")
img_tmp *= max_val
img_tmp = img_tmp.astype(np.uint8)

# save
if len(img_tmp.shape) == 3 and img_tmp.shape[2] == 3:
# RGB
img_tmp = Image.fromarray(img_tmp)
else:
# grayscale
img_tmp = Image.fromarray(img_tmp.squeeze())
img_tmp.save(fp)

Expand Down
59 changes: 47 additions & 12 deletions scripts/measure/analyze_measured_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,19 @@
import matplotlib.pyplot as plt
import time
import tqdm
import re


def convert(text):
return int(text) if text.isdigit() else text.lower()


def alphanum_key(key):
return [convert(c) for c in re.split("([0-9]+)", key)]


def natural_sort(arr):
return sorted(arr, key=alphanum_key)


@hydra.main(version_base=None, config_path="../../configs", config_name="analyze_dataset")
Expand All @@ -24,13 +37,14 @@ def analyze_dataset(config):
desired_range = config.desired_range
delete_bad = config.delete_bad
start_idx = config.start_idx
saturation_percent = config.saturation_percent

assert (
folder is not None
), "Must specify folder to analyze in config or through command line (folder=PATH)."

# get all PNG files in folder
files = sorted(glob.glob(os.path.join(folder, "*.png")))
files = natural_sort(glob.glob(os.path.join(folder, "*.png")))
print("Found {} files".format(len(files)))
if start_idx is not None:
files = files[start_idx:]
Expand All @@ -48,10 +62,9 @@ def analyze_dataset(config):
im = np.array(Image.open(fn))
max_val = im.max()
max_vals.append(max_val)
saturation_ratio = np.sum(im >= desired_range[1]) / im.size

# if out of desired range, print filename
if max_val < desired_range[0] or max_val > desired_range[1]:
# print("File {} has max value {}".format(fn, max_val))
if max_val < desired_range[0]:
n_bad_files += 1
bad_files.append(fn)

Expand All @@ -61,6 +74,28 @@ def analyze_dataset(config):
else:
print("File {} has max value {}".format(fn, max_val))

elif saturation_ratio > saturation_percent:
n_bad_files += 1
bad_files.append(fn)

if delete_bad:
os.remove(fn)
print("REMOVED file {}".format(fn))
else:
print("File {} has saturation ratio {}".format(fn, saturation_ratio))

# # if out of desired range, print filename
# if max_val < desired_range[0] or saturation_ratio > saturation_percent:
# # print("File {} has max value {}".format(fn, max_val))
# n_bad_files += 1
# bad_files.append(fn)

# if delete_bad:
# os.remove(fn)
# print("REMOVED file {}".format(fn))
# else:
# print("File {} has max value {}".format(fn, max_val))

proc_time = time.time() - start_time
print("Went through {} files in {:.2f} seconds".format(len(files), proc_time))
print(
Expand All @@ -69,6 +104,14 @@ def analyze_dataset(config):
)
)

# plot histogram
output_folder = os.getcwd()
output_fp = os.path.join(output_folder, "max_vals.png")
plt.hist(max_vals, bins=100)
plt.savefig(output_fp)

print("Saved histogram to {}".format(output_fp))

# command line input on whether to delete bad files
if not delete_bad:
response = None
Expand All @@ -80,14 +123,6 @@ def analyze_dataset(config):
else:
print("Not deleting bad files")

# plot histogram
output_folder = os.getcwd()
output_fp = os.path.join(output_folder, "max_vals.png")
plt.hist(max_vals, bins=100)
plt.savefig(output_fp)

print("Saved histogram to {}".format(output_fp))


if __name__ == "__main__":
analyze_dataset()
77 changes: 45 additions & 32 deletions scripts/measure/collect_dataset_on_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import numpy as np
import hydra
from hydra.utils import to_absolute_path
import time
import os
import pathlib as plib
Expand Down Expand Up @@ -67,8 +68,8 @@ def collect_dataset(config):
start_idx = config.start_idx
if os.path.exists(output_dir):
files = list(plib.Path(output_dir).glob(f"*.{config.output_file_ext}"))
start_idx = len(files)
print("\nNumber of completed measurements :", start_idx)
n_completed_files = len(files)
print("\nNumber of completed measurements :", n_completed_files)
output_dir = plib.Path(output_dir)
if config.masks is not None:
mask_dir = plib.Path(output_dir) / "masks"
Expand All @@ -91,19 +92,23 @@ def collect_dataset(config):

recon = None
if config.recon is not None:
print("Initializing ADMM recon...")
# initialize ADMM reconstruction
from lensless import ADMM
from lensless.utils.io import load_psf

psf, bg = load_psf(
fp=config.recon.psf,
downsample=config.capture.down, # assume full resolution PSF
return_bg=True
fp=to_absolute_path(config.recon.psf),
downsample=config.capture.down, # assume full resolution PSF
return_bg=True,
)

print("PSF shape: ", psf.shape)
recon = ADMM(psf, n_iter=config.recon.n_iter)

recon_dir = plib.Path(output_dir) / "recon"
recon_dir.mkdir(exist_ok=True)
print("Finished initializing ADMM recon.")

# assert input directory exists
assert os.path.exists(input_dir)
Expand Down Expand Up @@ -250,8 +255,8 @@ def collect_dataset(config):

# -- take picture
max_pixel_val = 0
fact_increase = 2
fact_decrease = 1.5
fact_increase = config.capture.fact_increase
fact_decrease = config.capture.fact_decrease
n_tries = 0

camera.shutter_speed = init_shutter_speed
Expand All @@ -272,6 +277,7 @@ def collect_dataset(config):
# convert to RGB
output = bayer2rgb_cc(
output_bayer,
down=down,
nbits=12,
blue_gain=float(g[1]),
red_gain=float(g[0]),
Expand All @@ -280,10 +286,10 @@ def collect_dataset(config):
nbits_out=8,
)

if down:
output = resize(
output[None, ...], factor=1 / down, interpolation=cv2.INTER_CUBIC
)[0]
# if down:
# output = resize(
# output[None, ...], factor=1 / down, interpolation=cv2.INTER_CUBIC
# )[0]

# save image
save_image(output, output_fp, normalize=False)
Expand All @@ -305,27 +311,34 @@ def collect_dataset(config):

elif max_pixel_val > MAX_LEVEL:

# decrease exposure
current_shutter_speed = int(current_shutter_speed / fact_decrease)
camera.shutter_speed = current_shutter_speed
time.sleep(config.capture.config_pause)
print(f"decreasing shutter speed to {current_shutter_speed}")

# # decrease screen brightness
# current_screen_brightness = current_screen_brightness - 10
# screen_res = np.array(config.display.screen_res)
# hshift = config.display.hshift
# vshift = config.display.vshift
# pad = config.display.pad
# brightness = current_screen_brightness
# display_image_path = config.display.output_fp
# rot90 = config.display.rot90
# os.system(
# f"python scripts/measure/prep_display_image.py --fp {_file} --output_path {display_image_path} --screen_res {screen_res[0]} {screen_res[1]} --hshift {hshift} --vshift {vshift} --pad {pad} --brightness {brightness} --rot90 {rot90}"
# )
# print(f"decreasing screen brightness to {current_screen_brightness}")

# time.sleep(config.display.delay)
if current_shutter_speed > 13098: # TODO: minimum for RPi HQ
# decrease exposure
current_shutter_speed = int(current_shutter_speed / fact_decrease)
camera.shutter_speed = current_shutter_speed
time.sleep(config.capture.config_pause)
print(f"decreasing shutter speed to {current_shutter_speed}")

else:

# decrease screen brightness
current_screen_brightness = current_screen_brightness - 10
screen_res = np.array(config.display.screen_res)
hshift = config.display.hshift
vshift = config.display.vshift
pad = config.display.pad
brightness = current_screen_brightness
display_image_path = config.display.output_fp
rot90 = config.display.rot90

display_command = f"python scripts/measure/prep_display_image.py --fp {_file} --output_path {display_image_path} --screen_res {screen_res[0]} {screen_res[1]} --hshift {hshift} --vshift {vshift} --pad {pad} --brightness {brightness} --rot90 {rot90}"
if config.display.landscape:
display_command += " --landscape"
if config.display.image_res is not None:
display_command += f" --image_res {config.display.image_res[0]} {config.display.image_res[1]}"
# print(display_command)
os.system(display_command)

time.sleep(config.display.delay)

exposure_vals.append(current_shutter_speed / 1e6)
brightness_vals.append(current_screen_brightness)
Expand Down

0 comments on commit 01d6d5d

Please sign in to comment.