From d10099b0619b0efad7be08123180224e5606af47 Mon Sep 17 00:00:00 2001
From: Himanshu Aggarwal
Date: Wed, 4 Oct 2023 18:03:34 +0200
Subject: [PATCH] error handling, CACHE clearing, better progress bars

---
 src/ibc_api/utils.py | 50 +++++++++++++++++++++++++++++++++++++++-----
 1 file changed, 45 insertions(+), 5 deletions(-)

diff --git a/src/ibc_api/utils.py b/src/ibc_api/utils.py
index ce2fe46..23228f4 100644
--- a/src/ibc_api/utils.py
+++ b/src/ibc_api/utils.py
@@ -11,9 +11,10 @@
 from datetime import datetime
 from . import metadata as md
 import json
+import numpy as np
 
 # clear cache
-CACHE.run_maintenance()
+CACHE.clear()
 
 # dataset ids on ebrains
 METADATA = md.fetch_metadata()
@@ -178,7 +179,15 @@ def filter_data(db, subject_list=SUBJECTS, task_list=False):
     # filter the database on task if specified
     if task_list:
         filtered_db = filtered_db[filtered_db["task"].isin(task_list)]
-
+    length = len(filtered_db)
+    if length == 0:
+        raise ValueError(
+            f"No files found for subjects {subject_list} and tasks {task_list}."
+        )
+    else:
+        print(
+            f"Found {length} files for subjects {subject_list} and tasks {task_list}."
+        )
     return filtered_db
 
 
@@ -292,9 +301,18 @@
     elif type(data) == dict:
         with open(file, "w") as f:
             json.dump(data, f)
+    elif type(data) == bytes:
+        if file.endswith(".bvec") or file.endswith(".bval"):
+            with open(file, "wb") as f:
+                f.write(data)
+                f.close()
+        else:
+            raise ValueError(
+                f"Don't know how to save file {file} of type {type(data)}"
+            )
     else:
         raise ValueError(
-            f"Don't know how to save file {file}" f" of type {type(data)}"
+            f"Don't know how to save file {file} of type {type(data)}"
         )
 
     return file
@@ -349,6 +367,21 @@
     pandas.DataFrame
         dataframe with downloaded file names and times from the dataset
     """
+    # make sure the dataframe is not empty
+    db_length = len(db)
+    if db_length == 0:
+        raise ValueError(
+            f"The input dataframe is empty. Please make sure that it at least has columns 'dataset' and 'path' and a row containing appropriate values corresponding to those columns."
+        )
+    else:
+        print(f"Found {db_length} files to download.")
+    # make sure the dataframe has dataset and path columns
+    db_columns = db.columns.tolist()
+    if "dataset" not in db_columns or "path" not in db_columns:
+        raise ValueError(
+            f"The input dataframe should have columns 'dataset' and 'path' and a row containing appropriate values corresponding to those columns."
+        )
+
     # get data type from db
     data_type = db["dataset"].unique()[0]
     # connect to ebrains dataset
@@ -360,13 +393,20 @@
     # track downloaded file names and times
     local_db_file = os.path.join(save_to, f"downloaded_{data_type}.csv")
     # download the files
-    for src_file, dst_file in tqdm(zip(src_file_names, dst_file_names)):
+    for src_file, dst_file in tqdm(
+        zip(src_file_names, dst_file_names),
+        position=1,
+        leave=True,
+        total=db_length,
+        desc="Overall Progress: ",
+        colour="green",
+    ):
         # final file path to save the data
         dst_file = os.path.join(save_to, dst_file)
         file_name = _download_file(src_file, dst_file, connector)
         file_time = datetime.now()
         local_db = _update_local_db(local_db_file, file_name, file_time)
-        # keep cache < 2GiB, delete oldest files first
+        # keep cache < 2GB, delete oldest files first
         CACHE.run_maintenance()
     print(
         f"Downloaded requested files from IBC {data_type} dataset. See "
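
A minimal usage sketch (not part of the patch) of the new validation paths in download_data, assuming the patched module is importable as ibc_api.utils; the example dataframes and their column values are hypothetical:

    import pandas as pd
    import ibc_api.utils as ibc

    # An empty dataframe now fails fast with a descriptive ValueError
    # instead of breaking later during the EBRAINS transfer.
    empty_db = pd.DataFrame(columns=["dataset", "path"])
    try:
        ibc.download_data(empty_db)
    except ValueError as err:
        print(err)  # "The input dataframe is empty. ..."

    # A non-empty dataframe missing the required 'dataset'/'path' columns
    # is also rejected, right after "Found 1 files to download." prints.
    bad_db = pd.DataFrame({"subject": ["01"]})  # hypothetical column/value
    try:
        ibc.download_data(bad_db)
    except ValueError as err:
        print(err)  # "The input dataframe should have columns 'dataset' and 'path' ..."

filter_data behaves analogously: when no rows survive the subject/task filtering, it raises ValueError instead of silently returning an empty dataframe, so an empty selection surfaces before any download is attempted.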
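
The tqdm keywords added in the last hunk follow the library's standard nested-bars pattern; a self-contained sketch (independent of IBC data) of the same setup, with a hypothetical inner per-file bar at position=0 such as _download_file might draw:

    from tqdm import tqdm
    import time

    files = ["a.nii.gz", "b.nii.gz", "c.nii.gz"]  # hypothetical file list
    # Outer bar: pinned to the lower line (position=1), kept on screen after
    # completion (leave=True), green, with an explicit total because zip()
    # alone exposes no length (hence total=db_length in the patch).
    for _ in tqdm(files, position=1, leave=True, total=len(files),
                  desc="Overall Progress: ", colour="green"):
        # Inner bar: upper line (position=0), erased when each file finishes.
        for _ in tqdm(range(20), position=0, leave=False, desc="Current file: "):
            time.sleep(0.01)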