Skip to content

Commit

Permalink
Fix normalization, and downsample before color correction.
Browse files Browse the repository at this point in the history
  • Loading branch information
ebezzam committed Jun 6, 2024
1 parent 17462ff commit 3062e7d
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 48 deletions.
1 change: 1 addition & 0 deletions configs/analyze_dataset.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# python scripts/measure/analyze_measured_dataset.py
hydra:
job:
chdir: True # change to output folder
Expand Down
6 changes: 4 additions & 2 deletions configs/collect_dataset.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,12 @@ display:

capture:
skip: False # to test looping over displaying images
config_pause: 2
config_pause: 3
iso: 100
res: null
down: 4
exposure: 0.02 # min exposure
awb_gains: [1.9, 1.2] # red, blue
# awb_gains: null
# awb_gains: null
fact_increase: 2 # multiplicative factor to increase exposure
fact_decrease: 1.5
6 changes: 6 additions & 0 deletions lensless/utils/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ def get_max_val(img, nbits=None):
def bayer2rgb_cc(
img,
nbits,
down=None,
blue_gain=None,
red_gain=None,
black_level=RPI_HQ_CAMERA_BLACK_LEVEL,
Expand Down Expand Up @@ -269,6 +270,10 @@ def bayer2rgb_cc(
# demosaic Bayer data
img = cv2.cvtColor(img, cv2.COLOR_BayerRG2RGB)

# downsample
if down is not None:
img = resize(img[None, ...], factor=1 / down, interpolation=cv2.INTER_CUBIC)[0]

# correction
img = img - black_level
if red_gain:
Expand All @@ -277,6 +282,7 @@ def bayer2rgb_cc(
img[:, :, 2] *= blue_gain
img = img / (2**nbits - 1 - black_level)
img[img > 1] = 1

img = (img.reshape(-1, 3, order="F") @ ccm.T).reshape(img.shape, order="F")
img[img < 0] = 0
img[img > 1] = 1
Expand Down
40 changes: 24 additions & 16 deletions lensless/utils/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,31 +576,39 @@ def save_image(img, fp, max_val=255, normalize=True):

img_tmp = img.copy()

if img_tmp.dtype == np.uint16 or img_tmp.dtype == np.uint8:
img_tmp = img_tmp.astype(np.float32)

if normalize:

if img_tmp.dtype == np.uint16 or img_tmp.dtype == np.uint8:
img_tmp = img_tmp.astype(np.float32)

img_tmp -= img_tmp.min()
img_tmp /= img_tmp.max()
else:
normalized = False
if img_tmp.min() < 0:
img_tmp -= img_tmp.min()
normalize = True
if img_tmp.max() > 1:
img_tmp /= img_tmp.max()
normalize = True
if normalized:
print(f"Warning (out of range): {fp} normalizing data to [0, 1]")

if img_tmp.dtype == np.float64 or img_tmp.dtype == np.float32:
img_tmp *= max_val
img_tmp = img_tmp.astype(np.uint8)

# RGB
else:

if img_tmp.dtype == np.float64 or img_tmp.dtype == np.float32:
# check within [0, 1] and convert to uint8

normalized = False
if img_tmp.min() < 0:
img_tmp -= img_tmp.min()
normalized = True
if img_tmp.max() > 1:
img_tmp /= img_tmp.max()
normalized = True
if normalized:
print(f"Warning (out of range): {fp} normalizing data to [0, 1]")
img_tmp *= max_val
img_tmp = img_tmp.astype(np.uint8)

# save
if len(img_tmp.shape) == 3 and img_tmp.shape[2] == 3:
# RGB
img_tmp = Image.fromarray(img_tmp)
else:
# grayscale
img_tmp = Image.fromarray(img_tmp.squeeze())
img_tmp.save(fp)

Expand Down
73 changes: 43 additions & 30 deletions scripts/measure/collect_dataset_on_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import numpy as np
import hydra
from hydra.utils import to_absolute_path
import time
import os
import pathlib as plib
Expand Down Expand Up @@ -91,19 +92,23 @@ def collect_dataset(config):

recon = None
if config.recon is not None:
print("Initializing ADMM recon...")
# initialize ADMM reconstruction
from lensless import ADMM
from lensless.utils.io import load_psf

psf, bg = load_psf(
fp=config.recon.psf,
downsample=config.capture.down, # assume full resolution PSF
return_bg=True
fp=to_absolute_path(config.recon.psf),
downsample=config.capture.down, # assume full resolution PSF
return_bg=True,
)

print("PSF shape: ", psf.shape)
recon = ADMM(psf, n_iter=config.recon.n_iter)

recon_dir = plib.Path(output_dir) / "recon"
recon_dir.mkdir(exist_ok=True)
print("Finished initializing ADMM recon.")

# assert input directory exists
assert os.path.exists(input_dir)
Expand Down Expand Up @@ -250,8 +255,8 @@ def collect_dataset(config):

# -- take picture
max_pixel_val = 0
fact_increase = 2
fact_decrease = 1.5
fact_increase = config.capture.fact_increase
fact_decrease = config.capture.fact_decrease
n_tries = 0

camera.shutter_speed = init_shutter_speed
Expand All @@ -272,6 +277,7 @@ def collect_dataset(config):
# convert to RGB
output = bayer2rgb_cc(
output_bayer,
down=down,
nbits=12,
blue_gain=float(g[1]),
red_gain=float(g[0]),
Expand All @@ -280,10 +286,10 @@ def collect_dataset(config):
nbits_out=8,
)

if down:
output = resize(
output[None, ...], factor=1 / down, interpolation=cv2.INTER_CUBIC
)[0]
# if down:
# output = resize(
# output[None, ...], factor=1 / down, interpolation=cv2.INTER_CUBIC
# )[0]

# save image
save_image(output, output_fp, normalize=False)
Expand All @@ -305,27 +311,34 @@ def collect_dataset(config):

elif max_pixel_val > MAX_LEVEL:

# decrease exposure
current_shutter_speed = int(current_shutter_speed / fact_decrease)
camera.shutter_speed = current_shutter_speed
time.sleep(config.capture.config_pause)
print(f"decreasing shutter speed to {current_shutter_speed}")

# # decrease screen brightness
# current_screen_brightness = current_screen_brightness - 10
# screen_res = np.array(config.display.screen_res)
# hshift = config.display.hshift
# vshift = config.display.vshift
# pad = config.display.pad
# brightness = current_screen_brightness
# display_image_path = config.display.output_fp
# rot90 = config.display.rot90
# os.system(
# f"python scripts/measure/prep_display_image.py --fp {_file} --output_path {display_image_path} --screen_res {screen_res[0]} {screen_res[1]} --hshift {hshift} --vshift {vshift} --pad {pad} --brightness {brightness} --rot90 {rot90}"
# )
# print(f"decreasing screen brightness to {current_screen_brightness}")

# time.sleep(config.display.delay)
if current_shutter_speed > 13098: # TODO: minimum for RPi HQ
# decrease exposure
current_shutter_speed = int(current_shutter_speed / fact_decrease)
camera.shutter_speed = current_shutter_speed
time.sleep(config.capture.config_pause)
print(f"decreasing shutter speed to {current_shutter_speed}")

else:

# decrease screen brightness
current_screen_brightness = current_screen_brightness - 10
screen_res = np.array(config.display.screen_res)
hshift = config.display.hshift
vshift = config.display.vshift
pad = config.display.pad
brightness = current_screen_brightness
display_image_path = config.display.output_fp
rot90 = config.display.rot90

display_command = f"python scripts/measure/prep_display_image.py --fp {_file} --output_path {display_image_path} --screen_res {screen_res[0]} {screen_res[1]} --hshift {hshift} --vshift {vshift} --pad {pad} --brightness {brightness} --rot90 {rot90}"
if config.display.landscape:
display_command += " --landscape"
if config.display.image_res is not None:
display_command += f" --image_res {config.display.image_res[0]} {config.display.image_res[1]}"
# print(display_command)
os.system(display_command)

time.sleep(config.display.delay)

exposure_vals.append(current_shutter_speed / 1e6)
brightness_vals.append(current_screen_brightness)
Expand Down

0 comments on commit 3062e7d

Please sign in to comment.