diff --git a/.gitignore b/.gitignore
index 7761abd..4f6f7e8 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,4 +2,5 @@ src/ibc_api.egg-info
 src/ibc_api/__pycache__
 src/ibc-api
 ibc_data
-.ipynb_checkpoints
\ No newline at end of file
+.ipynb_checkpoints
+src/ibc_api/data/token
\ No newline at end of file
diff --git a/examples/example.py b/examples/example.py
index 4e4b887..cc064f4 100644
--- a/examples/example.py
+++ b/examples/example.py
@@ -1,13 +1,13 @@
-import ibc_api.utils as ibc
+from ibc_api import utils as ibc
 
 # Fetch info on all available files
 # Load as a pandas dataframe and save as ibc_data/available_{data_type}.csv
 db = ibc.get_info(data_type="volume_maps")
 
-# Keep statistic maps for sub-08, for task-Discount
-filtered_db = ibc.filter_data(db, subject_list=["08"], task_list=["Discount"])
+# Keep statistic maps for sub-08, for task-Lec1
+filtered_db = ibc.filter_data(db, subject_list=["08"], task_list=["Lec1"])
 
-# Download all statistic maps for sub-08, task-Discount
+# Download all statistic maps for sub-08, task-Lec1
 # Also creates ibc_data/downloaded_volume_maps.csv
 # which contains local file paths and time of download
-downloaded_db = ibc.download_data(filtered_db)
+downloaded_db = ibc.download_data(filtered_db, n_jobs=2)
diff --git a/pyproject.toml b/pyproject.toml
index 9e13a83..e7a0331 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -11,4 +11,5 @@ dependencies = [
     "nibabel",
     "pandas",
     "tqdm",
+    "joblib",
 ]
\ No newline at end of file
diff --git a/src/ibc_api/utils.py b/src/ibc_api/utils.py
index 8255764..915e7b2 100644
--- a/src/ibc_api/utils.py
+++ b/src/ibc_api/utils.py
@@ -1,18 +1,20 @@
 """API to fetch IBC data from EBRAINS via Human Data Gateway using siibra.
 """
 
-import siibra
-from siibra.retrieval.repositories import EbrainsHdgConnector
-from siibra.retrieval.requests import EbrainsRequest, SiibraHttpRequestError
+# %$
+import json
 import os
-from tqdm import tqdm
+from datetime import datetime
+
 import nibabel
-from siibra.retrieval.cache import CACHE
 import pandas as pd
-from datetime import datetime
+import siibra
+from joblib import Memory, Parallel, delayed
+from siibra.retrieval.cache import CACHE
+from siibra.retrieval.repositories import EbrainsHdgConnector
+from siibra.retrieval.requests import EbrainsRequest, SiibraHttpRequestError
+
 from . import metadata as md
-import json
-import numpy as np
 
 # clear cache
 CACHE.clear()
@@ -27,6 +29,11 @@
 TOKEN_ROOT = os.path.join(os.path.dirname(__file__), "data")
 os.makedirs(TOKEN_ROOT, exist_ok=True)
 
+# memory cache
+joblib_cache_dir = os.path.join(os.path.dirname(__file__), "cache")
+os.makedirs(joblib_cache_dir, exist_ok=True)
+memory = Memory(joblib_cache_dir, verbose=0)
+
 
 def _authenticate(token_dir=TOKEN_ROOT):
     """This function authenticates you to EBRAINS. It would return a link that
@@ -226,7 +233,7 @@ def filter_data(db, subject_list=SUBJECTS, task_list=False):
     return filtered_db
 
 
-def get_file_paths(db, metadata=METADATA):
+def get_file_paths(db, metadata=METADATA, save_to_dir=None):
     """Get the remote and local file paths for each file in a (filtered)
     dataframe.
 
     Parameters
     ----------
@@ -238,7 +245,8 @@ def get_file_paths(db, metadata=METADATA):
     db : pandas.DataFrame
         dataframe with information about files in the dataset, ideally a
         subset of the full dataset
 
     Returns
     -------
     filenames, list
-        lists of file paths for each file in the input dataframe. First list is the remote file paths and second list is the local file paths
+        lists of file paths for each file in the input dataframe. First list
+        is the remote file paths and second list is the local file paths
     """
     # get the data type from the db
     data_type = db["dataset"].unique()
@@ -251,7 +259,10 @@ def get_file_paths(db, metadata=METADATA):
     remote_file_names = []
     local_file_names = []
     remote_root_dir = md.select_dataset(data_type, metadata)["root"]
-    local_root_dir = data_type
+    if save_to_dir is None:
+        local_root_dir = data_type
+    else:
+        local_root_dir = os.path.join(save_to_dir, data_type)
     for file in file_names:
         # put the file path together
         # always use "/" as the separator for remote file paths
@@ -267,44 +278,44 @@ def get_file_paths(db, metadata=METADATA):
     return remote_file_names, local_file_names
 
 
-def _update_local_db(db_file, file_names, file_times):
+def _update_local_db(db_file, files_data):
     """Update the local database of downloaded files.
 
     Parameters
     ----------
     db_file : str
         path to the local database file
-    file_names : str or list
-        path to the downloaded file(s)
-    file_times : str or list
-        time at which the file(s) were downloaded
+    files_data : list of tuples
+        list of (file_name, file_time) tuples, one per downloaded file
 
     Returns
     -------
     pandas.DataFrame
         updated local database
     """
-
-    if type(file_names) is str:
-        file_names = [file_names]
-        file_times = [file_times]
-
     if not os.path.exists(db_file):
         # create a new database
-        db = pd.DataFrame(
-            {"local_path": file_names, "downloaded_on": file_times}
-        )
+        db = pd.DataFrame(columns=["local_path", "downloaded_on"])
     else:
-        # load the database
-        db = pd.read_csv(db_file, index_col=False)
-        new_db = pd.DataFrame(
-            {"local_path": file_names, "downloaded_on": file_times}
-        )
-        # update the database
-        db = pd.concat([db, new_db])
-        db.reset_index(drop=True, inplace=True)
+        try:
+            # load the database
+            db = pd.read_csv(db_file, index_col=False)
+        except (
+            pd.errors.EmptyDataError,
+            pd.errors.ParserError,
+            FileNotFoundError,
+        ):
+            print("Empty database file. Creating a new one.")
+            db = pd.DataFrame(columns=["local_path", "downloaded_on"])
+
+    downloaded_db = pd.DataFrame(
+        files_data, columns=["local_path", "downloaded_on"]
+    )
+
+    new_db = pd.concat([db, downloaded_db], ignore_index=True)
+    new_db.reset_index(drop=True, inplace=True)
     # save the database
-    db.to_csv(db_file, index=False)
+    new_db.to_csv(db_file, index=False)
 
-    return db
+    return new_db
@@ -318,6 +329,11 @@ def _write_file(file, data):
         path to the file to write to
     data : data fetched from ebrains
         data to write to the file
+
+    Returns
+    -------
+    file : str
+        path to the written file
     """
     # check file type and write accordingly
     if type(data) == nibabel.nifti1.Nifti1Image:
@@ -353,6 +369,7 @@ def _write_file(file, data):
     return file
 
 
+@memory.cache
 def _download_file(src_file, dst_file, connector):
     """Download a file from ebrains.
@@ -370,6 +387,7 @@ def _download_file(src_file, dst_file, connector):
     str, datetime
         path to the downloaded file and time at which it was downloaded
     """
+    # CACHE.run_maintenance()
     if not os.path.exists(dst_file):
         # load the file from ebrains
         src_data = connector.get(src_file)
@@ -381,11 +399,10 @@ def _download_file(src_file, dst_file, connector):
         return dst_file
     else:
         print(f"File {dst_file} already exists, skipping download.")
-
-        return []
+        return dst_file
 
 
-def download_data(db, save_to=None):
+def download_data(db, n_jobs=2, save_to=None):
     """Download the files in a (filtered) dataframe.
 
     Parameters
@@ -393,6 +410,9 @@
     db : pandas.DataFrame
         dataframe with information about files in the dataset, ideally a
         subset of the full dataset
+    n_jobs : int, optional
+        number of parallel jobs to run, by default 2. -1 uses all available CPUs.
+        See: https://joblib.readthedocs.io/en/latest/generated/joblib.Parallel.html
     save_to : str, optional
         where to save the data, by default None, in which case the data is
         saved in a directory called "ibc_data" in the current working directory
@@ -406,7 +426,7 @@
     db_length = len(db)
     if db_length == 0:
         raise ValueError(
-            f"The input dataframe is empty. Please make sure that it atleast has columns 'dataset' and 'path' and a row containing appropriate values corresponding to those columns."
+            f"The input dataframe is empty. Please make sure that it at least has columns 'dataset' and 'path' and a row containing appropriate values corresponding to those columns."
         )
     else:
         print(f"Found {db_length} files to download.")
@@ -420,32 +440,44 @@
     # get data type from db
     data_type = db["dataset"].unique()[0]
     # connect to ebrains dataset
+    print("... Fetching token and connecting to EBRAINS ...")
     connector = _connect_ebrains(data_type)
-    # get the file names as they are on ebrains
-    src_file_names, dst_file_names = get_file_paths(db)
     # set the save directory
     save_to = _create_root_dir(save_to)
-    # track downloaded file names and times
+    # file to track downloaded file names and times
    local_db_file = os.path.join(save_to, f"downloaded_{data_type}.csv")
-    # download the files
-    for src_file, dst_file in tqdm(
-        zip(src_file_names, dst_file_names),
-        position=1,
-        leave=True,
-        total=db_length,
-        desc="Overall Progress: ",
-        colour="green",
-    ):
-        # final file path to save the data
-        dst_file = os.path.join(save_to, dst_file)
-        file_name = _download_file(src_file, dst_file, connector)
-        file_time = datetime.now()
-        local_db = _update_local_db(local_db_file, file_name, file_time)
-        # keep cache < 2GB, delete oldest files first
-        CACHE.run_maintenance()
+    # get the file names as they are on ebrains
+    src_file_names, dst_file_names = get_file_paths(db, save_to_dir=save_to)
+
+    # helper to download a single file and record its download time
+    def _download_and_update_progress(src_file, dst_file, connector):
+        try:
+            file_name = _download_file(src_file, dst_file, connector)
+            file_time = datetime.now()
+            CACHE.run_maintenance()  # keep cache < 2GB
+            return file_name, file_time
+        except Exception as e:
+            raise RuntimeError(f"Error downloading {src_file}. Error: {e}")
+
+    # download the files in parallel
+    print(f"\n...Starting download of {len(src_file_names)} files...")
+    results = Parallel(n_jobs=n_jobs, backend="threading", verbose=10)(
+        delayed(_download_and_update_progress)(src_file, dst_file, connector)
+        for src_file, dst_file in zip(src_file_names, dst_file_names)
+    )
+
+    # update the local database
+    results = [res for res in results if res[0] is not None]
+    if len(results) == 0:
+        raise RuntimeError(f"No files downloaded! Please try again.")
+    download_details = _update_local_db(local_db_file, results)
     print(
         f"Downloaded requested files from IBC {data_type} dataset. See "
-        f"{local_db_file} for details."
+        f"{local_db_file} for details.\n"
     )
-    return local_db
+    # clean up the cache
+    CACHE.clear()
+    memory.clear()
+
+    return download_details