Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix normalization, and downsample before color correction. #134

Merged
merged 3 commits into from
Jul 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion configs/analyze_dataset.yaml
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
# python scripts/measure/analyze_measured_dataset.py
hydra:
job:
chdir: True # change to output folder

dataset_path: null
desired_range: [150, 254]
desired_range: [150, 255]
saturation_percent: 0.05
delete_bad: False
n_files: null
start_idx: null
6 changes: 4 additions & 2 deletions configs/collect_dataset.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,12 @@ display:

capture:
skip: False # to test looping over displaying images
config_pause: 2
config_pause: 3
iso: 100
res: null
down: 4
exposure: 0.02 # min exposure
awb_gains: [1.9, 1.2] # red, blue
# awb_gains: null
# awb_gains: null
fact_increase: 2 # multiplicative factor to increase exposure
fact_decrease: 1.5
6 changes: 6 additions & 0 deletions lensless/utils/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ def get_max_val(img, nbits=None):
def bayer2rgb_cc(
img,
nbits,
down=None,
blue_gain=None,
red_gain=None,
black_level=RPI_HQ_CAMERA_BLACK_LEVEL,
Expand Down Expand Up @@ -269,6 +270,10 @@ def bayer2rgb_cc(
# demosaic Bayer data
img = cv2.cvtColor(img, cv2.COLOR_BayerRG2RGB)

# downsample
if down is not None:
img = resize(img[None, ...], factor=1 / down, interpolation=cv2.INTER_CUBIC)[0]

# correction
img = img - black_level
if red_gain:
Expand All @@ -277,6 +282,7 @@ def bayer2rgb_cc(
img[:, :, 2] *= blue_gain
img = img / (2**nbits - 1 - black_level)
img[img > 1] = 1

img = (img.reshape(-1, 3, order="F") @ ccm.T).reshape(img.shape, order="F")
img[img < 0] = 0
img[img > 1] = 1
Expand Down
40 changes: 24 additions & 16 deletions lensless/utils/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,31 +576,39 @@ def save_image(img, fp, max_val=255, normalize=True):

img_tmp = img.copy()

if img_tmp.dtype == np.uint16 or img_tmp.dtype == np.uint8:
img_tmp = img_tmp.astype(np.float32)

if normalize:

if img_tmp.dtype == np.uint16 or img_tmp.dtype == np.uint8:
img_tmp = img_tmp.astype(np.float32)

img_tmp -= img_tmp.min()
img_tmp /= img_tmp.max()
else:
normalized = False
if img_tmp.min() < 0:
img_tmp -= img_tmp.min()
normalize = True
if img_tmp.max() > 1:
img_tmp /= img_tmp.max()
normalize = True
if normalized:
print(f"Warning (out of range): {fp} normalizing data to [0, 1]")

if img_tmp.dtype == np.float64 or img_tmp.dtype == np.float32:
img_tmp *= max_val
img_tmp = img_tmp.astype(np.uint8)

# RGB
else:

if img_tmp.dtype == np.float64 or img_tmp.dtype == np.float32:
# check within [0, 1] and convert to uint8

normalized = False
if img_tmp.min() < 0:
img_tmp -= img_tmp.min()
normalized = True
if img_tmp.max() > 1:
img_tmp /= img_tmp.max()
normalized = True
if normalized:
print(f"Warning (out of range): {fp} normalizing data to [0, 1]")
img_tmp *= max_val
img_tmp = img_tmp.astype(np.uint8)

# save
if len(img_tmp.shape) == 3 and img_tmp.shape[2] == 3:
# RGB
img_tmp = Image.fromarray(img_tmp)
else:
# grayscale
img_tmp = Image.fromarray(img_tmp.squeeze())
img_tmp.save(fp)

Expand Down
59 changes: 47 additions & 12 deletions scripts/measure/analyze_measured_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,19 @@
import matplotlib.pyplot as plt
import time
import tqdm
import re


def convert(text):
return int(text) if text.isdigit() else text.lower()


def alphanum_key(key):
return [convert(c) for c in re.split("([0-9]+)", key)]


def natural_sort(arr):
return sorted(arr, key=alphanum_key)


@hydra.main(version_base=None, config_path="../../configs", config_name="analyze_dataset")
Expand All @@ -24,13 +37,14 @@ def analyze_dataset(config):
desired_range = config.desired_range
delete_bad = config.delete_bad
start_idx = config.start_idx
saturation_percent = config.saturation_percent

assert (
folder is not None
), "Must specify folder to analyze in config or through command line (folder=PATH)."

# get all PNG files in folder
files = sorted(glob.glob(os.path.join(folder, "*.png")))
files = natural_sort(glob.glob(os.path.join(folder, "*.png")))
print("Found {} files".format(len(files)))
if start_idx is not None:
files = files[start_idx:]
Expand All @@ -48,10 +62,9 @@ def analyze_dataset(config):
im = np.array(Image.open(fn))
max_val = im.max()
max_vals.append(max_val)
saturation_ratio = np.sum(im >= desired_range[1]) / im.size

# if out of desired range, print filename
if max_val < desired_range[0] or max_val > desired_range[1]:
# print("File {} has max value {}".format(fn, max_val))
if max_val < desired_range[0]:
n_bad_files += 1
bad_files.append(fn)

Expand All @@ -61,6 +74,28 @@ def analyze_dataset(config):
else:
print("File {} has max value {}".format(fn, max_val))

elif saturation_ratio > saturation_percent:
n_bad_files += 1
bad_files.append(fn)

if delete_bad:
os.remove(fn)
print("REMOVED file {}".format(fn))
else:
print("File {} has saturation ratio {}".format(fn, saturation_ratio))

# # if out of desired range, print filename
# if max_val < desired_range[0] or saturation_ratio > saturation_percent:
# # print("File {} has max value {}".format(fn, max_val))
# n_bad_files += 1
# bad_files.append(fn)

# if delete_bad:
# os.remove(fn)
# print("REMOVED file {}".format(fn))
# else:
# print("File {} has max value {}".format(fn, max_val))

proc_time = time.time() - start_time
print("Went through {} files in {:.2f} seconds".format(len(files), proc_time))
print(
Expand All @@ -69,6 +104,14 @@ def analyze_dataset(config):
)
)

# plot histogram
output_folder = os.getcwd()
output_fp = os.path.join(output_folder, "max_vals.png")
plt.hist(max_vals, bins=100)
plt.savefig(output_fp)

print("Saved histogram to {}".format(output_fp))

# command line input on whether to delete bad files
if not delete_bad:
response = None
Expand All @@ -80,14 +123,6 @@ def analyze_dataset(config):
else:
print("Not deleting bad files")

# plot histogram
output_folder = os.getcwd()
output_fp = os.path.join(output_folder, "max_vals.png")
plt.hist(max_vals, bins=100)
plt.savefig(output_fp)

print("Saved histogram to {}".format(output_fp))


if __name__ == "__main__":
analyze_dataset()
77 changes: 45 additions & 32 deletions scripts/measure/collect_dataset_on_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import numpy as np
import hydra
from hydra.utils import to_absolute_path
import time
import os
import pathlib as plib
Expand Down Expand Up @@ -67,8 +68,8 @@ def collect_dataset(config):
start_idx = config.start_idx
if os.path.exists(output_dir):
files = list(plib.Path(output_dir).glob(f"*.{config.output_file_ext}"))
start_idx = len(files)
print("\nNumber of completed measurements :", start_idx)
n_completed_files = len(files)
print("\nNumber of completed measurements :", n_completed_files)
output_dir = plib.Path(output_dir)
if config.masks is not None:
mask_dir = plib.Path(output_dir) / "masks"
Expand All @@ -91,19 +92,23 @@ def collect_dataset(config):

recon = None
if config.recon is not None:
print("Initializing ADMM recon...")
# initialize ADMM reconstruction
from lensless import ADMM
from lensless.utils.io import load_psf

psf, bg = load_psf(
fp=config.recon.psf,
downsample=config.capture.down, # assume full resolution PSF
return_bg=True
fp=to_absolute_path(config.recon.psf),
downsample=config.capture.down, # assume full resolution PSF
return_bg=True,
)

print("PSF shape: ", psf.shape)
recon = ADMM(psf, n_iter=config.recon.n_iter)

recon_dir = plib.Path(output_dir) / "recon"
recon_dir.mkdir(exist_ok=True)
print("Finished initializing ADMM recon.")

# assert input directory exists
assert os.path.exists(input_dir)
Expand Down Expand Up @@ -250,8 +255,8 @@ def collect_dataset(config):

# -- take picture
max_pixel_val = 0
fact_increase = 2
fact_decrease = 1.5
fact_increase = config.capture.fact_increase
fact_decrease = config.capture.fact_decrease
n_tries = 0

camera.shutter_speed = init_shutter_speed
Expand All @@ -272,6 +277,7 @@ def collect_dataset(config):
# convert to RGB
output = bayer2rgb_cc(
output_bayer,
down=down,
nbits=12,
blue_gain=float(g[1]),
red_gain=float(g[0]),
Expand All @@ -280,10 +286,10 @@ def collect_dataset(config):
nbits_out=8,
)

if down:
output = resize(
output[None, ...], factor=1 / down, interpolation=cv2.INTER_CUBIC
)[0]
# if down:
# output = resize(
# output[None, ...], factor=1 / down, interpolation=cv2.INTER_CUBIC
# )[0]

# save image
save_image(output, output_fp, normalize=False)
Expand All @@ -305,27 +311,34 @@ def collect_dataset(config):

elif max_pixel_val > MAX_LEVEL:

# decrease exposure
current_shutter_speed = int(current_shutter_speed / fact_decrease)
camera.shutter_speed = current_shutter_speed
time.sleep(config.capture.config_pause)
print(f"decreasing shutter speed to {current_shutter_speed}")

# # decrease screen brightness
# current_screen_brightness = current_screen_brightness - 10
# screen_res = np.array(config.display.screen_res)
# hshift = config.display.hshift
# vshift = config.display.vshift
# pad = config.display.pad
# brightness = current_screen_brightness
# display_image_path = config.display.output_fp
# rot90 = config.display.rot90
# os.system(
# f"python scripts/measure/prep_display_image.py --fp {_file} --output_path {display_image_path} --screen_res {screen_res[0]} {screen_res[1]} --hshift {hshift} --vshift {vshift} --pad {pad} --brightness {brightness} --rot90 {rot90}"
# )
# print(f"decreasing screen brightness to {current_screen_brightness}")

# time.sleep(config.display.delay)
if current_shutter_speed > 13098: # TODO: minimum for RPi HQ
# decrease exposure
current_shutter_speed = int(current_shutter_speed / fact_decrease)
camera.shutter_speed = current_shutter_speed
time.sleep(config.capture.config_pause)
print(f"decreasing shutter speed to {current_shutter_speed}")

else:

# decrease screen brightness
current_screen_brightness = current_screen_brightness - 10
screen_res = np.array(config.display.screen_res)
hshift = config.display.hshift
vshift = config.display.vshift
pad = config.display.pad
brightness = current_screen_brightness
display_image_path = config.display.output_fp
rot90 = config.display.rot90

display_command = f"python scripts/measure/prep_display_image.py --fp {_file} --output_path {display_image_path} --screen_res {screen_res[0]} {screen_res[1]} --hshift {hshift} --vshift {vshift} --pad {pad} --brightness {brightness} --rot90 {rot90}"
if config.display.landscape:
display_command += " --landscape"
if config.display.image_res is not None:
display_command += f" --image_res {config.display.image_res[0]} {config.display.image_res[1]}"
# print(display_command)
os.system(display_command)

time.sleep(config.display.delay)

exposure_vals.append(current_shutter_speed / 1e6)
brightness_vals.append(current_screen_brightness)
Expand Down
Loading