From 1de818d79c89ead4d787af28d9ea4bbb5608e966 Mon Sep 17 00:00:00 2001 From: Christy Lazar Date: Mon, 18 Mar 2024 22:27:44 +0100 Subject: [PATCH] Load the dataset and preprocess --- image_classification.py | 93 +++++++++++++++++++++++++++++++++++++++++ requirements.txt | 50 ++++++++++++++++++++++ 2 files changed, 143 insertions(+) create mode 100644 image_classification.py create mode 100644 requirements.txt diff --git a/image_classification.py b/image_classification.py new file mode 100644 index 0000000..765ab45 --- /dev/null +++ b/image_classification.py @@ -0,0 +1,93 @@ +import os +import pathlib +import requests +import zipfile +from concurrent.futures import ThreadPoolExecutor +from PIL import Image + +""" +Load the rice image dataset and preprocessing +""" + + +def download_dataset(url, filename): + response = requests.get(url) + if response == 200: + with open(filename, "wb") as f: + f.write(response.content) + return True + else: + return False + + +def unzip_file(file_name, location): + with zipfile.ZipFile(file_name, location) as f: + f.extractall(location) + + +def check_file_exists(directory_path) -> bool: + return os.path.exists(directory_path) + + +def print_number_of_files(data_dir): + Arborio = list(data_dir.glob("Arborio/*")) + Basmati = list(data_dir.glob("Basmati/*")) + Ipsala = list(data_dir.glob("Ipsala/*")) + Jasmine = list(data_dir.glob("Jasmine/*")) + Karacadag = list(data_dir.glob("Karacadag/*")) + + print("The length of Arborio: %d" % len(Arborio)) + print("The length of Jasmine: %d" % len(Jasmine)) + print("The length of Basmati: %d" % len(Basmati)) + print("The length of Ipsala: %d" % len(Ipsala)) + print("The length of Karacadag: %d" % len(Karacadag)) + + +def get_image_files(dir_path): + all_image_dirs = [f.path for f in os.scandir(dir_path) if f.is_dir()] + for dir in all_image_dirs: + for img in os.listdir(dir): + yield os.path.join(dir, img) + + +def get_image_size(image_path) -> tuple: + image = Image.open(image_path) + return image.size + + +def check_image_size(image_path, actual_width, actual_height) -> bool: + width, height = get_image_size(image_path) + if actual_width is width and actual_height is height: + print("checked") + return True + else: + return False + + +def resize_image(image_path, target_size=(250, 250)): + image = Image.open(image_path) + resized_image = image.resize(target_size) + resized_image.save(image_path) + print(f'Resized image "{image_path}" to {target_size}.') + + +def resize_if_required(data_dir, actual_width, actual_height): + with ThreadPoolExecutor(max_workers=8) as executor: + for img in get_image_files(data_dir): + result = executor.submit(check_image_size, img, actual_width, actual_height) + if not result: + resize_image(img) + + +directory_path = "Rice_Image_Dataset" +if not check_file_exists(directory_path): + if not download_dataset(link, directory_path): + print("Download failed") + exit() + +data_dir = pathlib.Path(directory_path).absolute() +print_number_of_files(data_dir) + +first_image = next(get_image_files(data_dir)) +actual_width, actual_height = get_image_size(first_image) +resize_if_required(data_dir, actual_width, actual_height) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..abc32c9 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,50 @@ +absl-py==2.1.0 +astunparse==1.6.3 +cachetools==5.3.3 +certifi==2024.2.2 +charset-normalizer==3.3.2 +contourpy==1.1.1 +cycler==0.12.1 +flatbuffers==23.5.26 +fonttools==4.49.0 +gast==0.4.0 +google-auth==2.28.1 +google-auth-oauthlib==1.0.0 +google-pasta==0.2.0 +grpcio==1.62.0 +h5py==3.10.0 +idna==3.6 +importlib-metadata==7.0.1 +importlib-resources==6.1.3 +keras==2.13.1 +kiwisolver==1.4.5 +libclang==16.0.6 +Markdown==3.5.2 +MarkupSafe==2.1.5 +matplotlib==3.7.5 +numpy==1.24.3 +oauthlib==3.2.2 +opencv-python==4.9.0.80 +opt-einsum==3.3.0 +packaging==23.2 +pillow==10.2.0 +protobuf==4.25.3 +pyasn1==0.5.1 +pyasn1-modules==0.3.0 +pyparsing==3.1.2 +python-dateutil==2.9.0.post0 +requests==2.31.0 +requests-oauthlib==1.3.1 +rsa==4.9 +six==1.16.0 +tensorboard==2.13.0 +tensorboard-data-server==0.7.2 +tensorflow==2.13.1 +tensorflow-estimator==2.13.0 +tensorflow-io-gcs-filesystem==0.34.0 +termcolor==2.4.0 +typing-extensions==4.5.0 +urllib3==2.2.1 +werkzeug==3.0.1 +wrapt==1.16.0 +zipp==3.17.0