-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1 from lazarchris/load-and-preprocess-dataset
Load the dataset and preprocess
- Loading branch information
Showing
2 changed files
with
143 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
import os | ||
import pathlib | ||
import requests | ||
import zipfile | ||
from concurrent.futures import ThreadPoolExecutor | ||
from PIL import Image | ||
|
||
""" | ||
Load the rice image dataset and preprocessing | ||
""" | ||
|
||
|
||
def download_dataset(url, filename): | ||
response = requests.get(url) | ||
if response == 200: | ||
with open(filename, "wb") as f: | ||
f.write(response.content) | ||
return True | ||
else: | ||
return False | ||
|
||
|
||
def unzip_file(file_name, location): | ||
with zipfile.ZipFile(file_name, location) as f: | ||
f.extractall(location) | ||
|
||
|
||
def check_file_exists(directory_path) -> bool: | ||
return os.path.exists(directory_path) | ||
|
||
|
||
def print_number_of_files(data_dir): | ||
Arborio = list(data_dir.glob("Arborio/*")) | ||
Basmati = list(data_dir.glob("Basmati/*")) | ||
Ipsala = list(data_dir.glob("Ipsala/*")) | ||
Jasmine = list(data_dir.glob("Jasmine/*")) | ||
Karacadag = list(data_dir.glob("Karacadag/*")) | ||
|
||
print("The length of Arborio: %d" % len(Arborio)) | ||
print("The length of Jasmine: %d" % len(Jasmine)) | ||
print("The length of Basmati: %d" % len(Basmati)) | ||
print("The length of Ipsala: %d" % len(Ipsala)) | ||
print("The length of Karacadag: %d" % len(Karacadag)) | ||
|
||
|
||
def get_image_files(dir_path): | ||
all_image_dirs = [f.path for f in os.scandir(dir_path) if f.is_dir()] | ||
for dir in all_image_dirs: | ||
for img in os.listdir(dir): | ||
yield os.path.join(dir, img) | ||
|
||
|
||
def get_image_size(image_path) -> tuple: | ||
image = Image.open(image_path) | ||
return image.size | ||
|
||
|
||
def check_image_size(image_path, actual_width, actual_height) -> bool: | ||
width, height = get_image_size(image_path) | ||
if actual_width is width and actual_height is height: | ||
print("checked") | ||
return True | ||
else: | ||
return False | ||
|
||
|
||
def resize_image(image_path, target_size=(250, 250)): | ||
image = Image.open(image_path) | ||
resized_image = image.resize(target_size) | ||
resized_image.save(image_path) | ||
print(f'Resized image "{image_path}" to {target_size}.') | ||
|
||
|
||
def resize_if_required(data_dir, actual_width, actual_height): | ||
with ThreadPoolExecutor(max_workers=8) as executor: | ||
for img in get_image_files(data_dir): | ||
result = executor.submit(check_image_size, img, actual_width, actual_height) | ||
if not result: | ||
resize_image(img) | ||
|
||
|
||
directory_path = "Rice_Image_Dataset" | ||
if not check_file_exists(directory_path): | ||
if not download_dataset(link, directory_path): | ||
print("Download failed") | ||
exit() | ||
|
||
data_dir = pathlib.Path(directory_path).absolute() | ||
print_number_of_files(data_dir) | ||
|
||
first_image = next(get_image_files(data_dir)) | ||
actual_width, actual_height = get_image_size(first_image) | ||
resize_if_required(data_dir, actual_width, actual_height) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
absl-py==2.1.0 | ||
astunparse==1.6.3 | ||
cachetools==5.3.3 | ||
certifi==2024.2.2 | ||
charset-normalizer==3.3.2 | ||
contourpy==1.1.1 | ||
cycler==0.12.1 | ||
flatbuffers==23.5.26 | ||
fonttools==4.49.0 | ||
gast==0.4.0 | ||
google-auth==2.28.1 | ||
google-auth-oauthlib==1.0.0 | ||
google-pasta==0.2.0 | ||
grpcio==1.62.0 | ||
h5py==3.10.0 | ||
idna==3.6 | ||
importlib-metadata==7.0.1 | ||
importlib-resources==6.1.3 | ||
keras==2.13.1 | ||
kiwisolver==1.4.5 | ||
libclang==16.0.6 | ||
Markdown==3.5.2 | ||
MarkupSafe==2.1.5 | ||
matplotlib==3.7.5 | ||
numpy==1.24.3 | ||
oauthlib==3.2.2 | ||
opencv-python==4.9.0.80 | ||
opt-einsum==3.3.0 | ||
packaging==23.2 | ||
pillow==10.2.0 | ||
protobuf==4.25.3 | ||
pyasn1==0.5.1 | ||
pyasn1-modules==0.3.0 | ||
pyparsing==3.1.2 | ||
python-dateutil==2.9.0.post0 | ||
requests==2.31.0 | ||
requests-oauthlib==1.3.1 | ||
rsa==4.9 | ||
six==1.16.0 | ||
tensorboard==2.13.0 | ||
tensorboard-data-server==0.7.2 | ||
tensorflow==2.13.1 | ||
tensorflow-estimator==2.13.0 | ||
tensorflow-io-gcs-filesystem==0.34.0 | ||
termcolor==2.4.0 | ||
typing-extensions==4.5.0 | ||
urllib3==2.2.1 | ||
werkzeug==3.0.1 | ||
wrapt==1.16.0 | ||
zipp==3.17.0 |