Skip to content

Commit

Permalink
Add downsample option before Hugging Face upload.
Browse files Browse the repository at this point in the history
  • Loading branch information
ebezzam committed May 2, 2024
1 parent 49fa828 commit 9164e95
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 5 deletions.
1 change: 1 addition & 0 deletions configs/upload_dataset_huggingface.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ lensless:
dir: null
ext: null # for example: .png, .jpg
eight_norm: False # save as 8-bit normalized image
downsample: null # factor by which width and height are divided before upload (e.g. 2), null to skip

lensed:
dir: null
Expand Down
35 changes: 30 additions & 5 deletions scripts/data/upload_dataset_huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
from lensless.utils.dataset import natural_sort
from tqdm import tqdm
from lensless.utils.io import save_image
import cv2
from joblib import Parallel, delayed


@hydra.main(
Expand Down Expand Up @@ -82,24 +84,46 @@ def upload_dataset(config):
]
lensed_files = [os.path.join(config.lensed.dir, f + config.lensed.ext) for f in common_files]

if config.lensless.downsample is not None:

tmp_dir = config.lensless.dir + "_tmp"
os.makedirs(tmp_dir, exist_ok=True)

def downsample(f, output_dir):
    """Downsample one image and save the result into ``output_dir``.

    Both axes are scaled by ``1 / config.lensless.downsample`` using
    bilinear interpolation.

    Parameters
    ----------
    f : str
        Path of the image to read.
    output_dir : str
        Directory the resized copy is written to. The output keeps the
        base name of ``f`` but takes the extension ``config.lensless.ext``.

    Raises
    ------
    OSError
        If OpenCV cannot read ``f``.
    """
    img = cv2.imread(f, cv2.IMREAD_UNCHANGED)
    if img is None:
        # cv2.imread reports failure by returning None rather than raising;
        # fail here with the offending path instead of inside cvtColor.
        raise OSError(f"Could not read image: {f}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    print("Original shape:", img.shape)
    scale = 1 / config.lensless.downsample
    img = cv2.resize(
        img,
        (0, 0),
        fx=scale,
        fy=scale,
        interpolation=cv2.INTER_LINEAR,
    )
    print("Downsampled shape:", img.shape)
    # Strip the extension from the file name only. Splitting the full path
    # on "." would cut it at the first dot anywhere, including a dot in a
    # directory name or a multi-dot file name.
    stem = os.path.splitext(os.path.basename(f))[0]
    new_fp = os.path.join(output_dir, stem + config.lensless.ext)
    save_image(img, new_fp, normalize=False)

Parallel(n_jobs=n_jobs)(delayed(downsample)(f, tmp_dir) for f in tqdm(lensless_files))
lensless_files = glob.glob(os.path.join(tmp_dir, f"*{config.lensless.ext[1:]}"))

# convert to normalized 8 bit
if config.lensless.eight_norm:

import cv2
from joblib import Parallel, delayed

tmp_dir = config.lensless.dir + "_tmp"
os.makedirs(tmp_dir, exist_ok=True)

# -- parallelize with joblib
def save_8bit(f, output_dir, normalize=True):
    """Re-save one image into ``output_dir`` through ``save_image``.

    Parameters
    ----------
    f : str
        Path of the image to read.
    output_dir : str
        Directory the converted copy is written to. The output keeps the
        base name of ``f`` but takes the extension ``config.lensless.ext``.
    normalize : bool, optional
        Forwarded to ``save_image``. Presumably enables the normalized
        8-bit conversion; confirm against ``lensless.utils.io.save_image``.

    Raises
    ------
    OSError
        If OpenCV cannot read ``f``.
    """
    img = cv2.imread(f, cv2.IMREAD_UNCHANGED)
    if img is None:
        # cv2.imread reports failure by returning None rather than raising;
        # fail here with the offending path instead of inside cvtColor.
        raise OSError(f"Could not read image: {f}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # Strip the extension from the file name only. Splitting the full path
    # on "." would cut it at the first dot anywhere, including a dot in a
    # directory name or a multi-dot file name.
    stem = os.path.splitext(os.path.basename(f))[0]
    new_fp = os.path.join(output_dir, stem + config.lensless.ext)
    save_image(img, new_fp, normalize=normalize)

Parallel(n_jobs=n_jobs)(delayed(save_8bit)(f, tmp_dir) for f in tqdm(lensless_files))
lensless_files = glob.glob(os.path.join(tmp_dir, "*png"))
lensless_files = glob.glob(os.path.join(tmp_dir, f"*{config.lensless.ext[1:]}"))

# check for attribute
df_attr = None
Expand Down Expand Up @@ -222,6 +246,7 @@ def create_dataset(lensless_files, lensed_files, df_attr=None):

upload_file(
path_or_fileobj=lensless_files[0],
# path_in_repo=f"lensless_example{config.lensless.ext}" if not config.lensless.eight_norm else f"lensless_example.png",
path_in_repo=f"lensless_example{config.lensless.ext}",
repo_id=repo_id,
repo_type="dataset",
Expand Down

0 comments on commit 9164e95

Please sign in to comment.