Merge pull request #11 from individual-brain-charting/parallel_download
Parallel download
ferponcem authored Sep 3, 2024
2 parents 124ff49 + 2303012 commit 6e5cc23
Showing 4 changed files with 97 additions and 63 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -2,4 +2,5 @@ src/ibc_api.egg-info
src/ibc_api/__pycache__
src/ibc-api
ibc_data
.ipynb_checkpoints
.ipynb_checkpoints
src/ibc_api/data/token
8 changes: 4 additions & 4 deletions examples/example.py
@@ -1,13 +1,13 @@
import ibc_api.utils as ibc
from ibc_api import utils as ibc

# Fetch info on all available files
# Load as a pandas dataframe and save as ibc_data/available_{data_type}.csv
db = ibc.get_info(data_type="volume_maps")

# Keep statistic maps for sub-08, for task-Discount
filtered_db = ibc.filter_data(db, subject_list=["08"], task_list=["Discount"])
filtered_db = ibc.filter_data(db, subject_list=["08"], task_list=["Lec1"])

# Download all statistic maps for sub-08, task-Discount
# Download all statistic maps for sub-08, task-Lec1
# Also creates ibc_data/downloaded_volume_maps.csv
# which contains local file paths and time of download
downloaded_db = ibc.download_data(filtered_db)
downloaded_db = ibc.download_data(filtered_db, n_jobs=2)
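For context, the new n_jobs argument is handed to joblib, and the download loop in utils.py follows joblib's standard Parallel/delayed idiom with a threading backend. A minimal, self-contained sketch of that idiom (the fetch_one helper and the URL list are hypothetical, purely for illustration, not part of the package):

from joblib import Parallel, delayed

def fetch_one(url):
    # stand-in for an I/O-bound download; threads are a good fit for I/O
    return url.upper()

urls = ["map_a", "map_b", "map_c"]
# n_jobs=2 mirrors the new default in download_data; -1 would use all CPUs
results = Parallel(n_jobs=2, backend="threading")(
    delayed(fetch_one)(url) for url in urls
)
print(results)  # ['MAP_A', 'MAP_B', 'MAP_C']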
1 change: 1 addition & 0 deletions pyproject.toml
@@ -11,4 +11,5 @@ dependencies = [
"nibabel",
"pandas",
"tqdm",
"joblib",
]
148 changes: 90 additions & 58 deletions src/ibc_api/utils.py
@@ -1,18 +1,20 @@
"""API to fetch IBC data from EBRAINS via Human Data Gateway using siibra.
"""

import siibra
from siibra.retrieval.repositories import EbrainsHdgConnector
from siibra.retrieval.requests import EbrainsRequest, SiibraHttpRequestError
# %$
import json
import os
from tqdm import tqdm
from datetime import datetime

import nibabel
from siibra.retrieval.cache import CACHE
import pandas as pd
from datetime import datetime
import siibra
from joblib import Memory, Parallel, delayed
from siibra.retrieval.cache import CACHE
from siibra.retrieval.repositories import EbrainsHdgConnector
from siibra.retrieval.requests import EbrainsRequest, SiibraHttpRequestError

from . import metadata as md
import json
import numpy as np

# clear cache
CACHE.clear()
@@ -27,6 +29,11 @@
TOKEN_ROOT = os.path.join(os.path.dirname(__file__), "data")
os.makedirs(TOKEN_ROOT, exist_ok=True)

# memory cache
joblib_cache_dir = os.path.join(os.path.dirname(__file__), "cache")
os.makedirs(joblib_cache_dir, exist_ok=True)
memory = Memory(joblib_cache_dir, verbose=0)


def _authenticate(token_dir=TOKEN_ROOT):
"""This function authenticates you to EBRAINS. It would return a link that
@@ -226,7 +233,7 @@ def filter_data(db, subject_list=SUBJECTS, task_list=False):
return filtered_db


def get_file_paths(db, metadata=METADATA):
def get_file_paths(db, metadata=METADATA, save_to_dir=None):
"""Get the remote and local file paths for each file in a (filtered) dataframe.
Parameters
@@ -238,7 +245,8 @@ def get_file_paths(db, metadata=METADATA):
Returns
-------
filenames, list
lists of file paths for each file in the input dataframe. First list is the remote file paths and second list is the local file paths
lists of file paths for each file in the input dataframe. First list
is the remote file paths and second list is the local file paths
"""
# get the data type from the db
data_type = db["dataset"].unique()
@@ -251,7 +259,10 @@
remote_file_names = []
local_file_names = []
remote_root_dir = md.select_dataset(data_type, metadata)["root"]
local_root_dir = data_type
if save_to_dir == None:
local_root_dir = data_type
else:
local_root_dir = os.path.join(save_to_dir, data_type)
for file in file_names:
# put the file path together
# always use "/" as the separator for remote file paths
@@ -267,44 +278,44 @@
return remote_file_names, local_file_names
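A quick illustration of the new save_to_dir argument: when it is given, local paths are rooted under that directory instead of under the bare data-type folder. A usage sketch (filtered_db is a hypothetical dataframe from filter_data, and the directory name is just an example):

# remote paths stay as they are on EBRAINS; local paths gain the prefix
remote_paths, local_paths = get_file_paths(filtered_db, save_to_dir="ibc_data")
# e.g. local_paths entries now look like "ibc_data/volume_maps/..."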


def _update_local_db(db_file, file_names, file_times):
def _update_local_db(db_file, files_data):
"""Update the local database of downloaded files.
Parameters
----------
db_file : str
path to the local database file
file_names : str or list
path to the downloaded file(s)
file_times : str or list
time at which the file(s) were downloaded
files_data : list of tuples
list of tuples where each tuple contains (file_name, file_time)
Returns
-------
pandas.DataFrame
updated local database
"""

if type(file_names) is str:
file_names = [file_names]
file_times = [file_times]

if not os.path.exists(db_file):
# create a new database
db = pd.DataFrame(
{"local_path": file_names, "downloaded_on": file_times}
)
db = pd.DataFrame(columns=["local_path", "downloaded_on"])
else:
# load the database
db = pd.read_csv(db_file, index_col=False)
new_db = pd.DataFrame(
{"local_path": file_names, "downloaded_on": file_times}
)
# update the database
db = pd.concat([db, new_db])
db.reset_index(drop=True, inplace=True)
try:
# load the database
db = pd.read_csv(db_file, index_col=False)
except (
pd.errors.EmptyDataError,
pd.errors.ParserError,
FileNotFoundError,
):
print("Empty database file. Creating a new one.")
db = pd.DataFrame(columns=["local_path", "downloaded_on"])

downloaded_db = pd.DataFrame(
files_data, columns=["local_path", "downloaded_on"]
)

new_db = pd.concat([db, downloaded_db], ignore_index=True)
new_db.reset_index(drop=True, inplace=True)
# save the database
db.to_csv(db_file, index=False)
new_db.to_csv(db_file, index=False)

return new_db
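To make the new signature concrete, _update_local_db now takes a list of (local_path, downloaded_on) tuples rather than parallel lists of names and times. A small usage sketch (the paths and timestamps are made-up examples, and the target directory is assumed to exist):

from datetime import datetime

files_data = [
    ("ibc_data/volume_maps/sub-08_task-Lec1_map.nii.gz", datetime.now()),
    ("ibc_data/volume_maps/sub-08_task-Lec1_other.nii.gz", datetime.now()),
]
# appends the tuples to the CSV log, creating it if it does not exist yet
updated_db = _update_local_db("ibc_data/downloaded_volume_maps.csv", files_data)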

@@ -318,6 +329,11 @@ def _write_file(file, data):
path to the file to write to
data : data fetched from ebrains
data to write to the file
Returns
-------
file: str
path to the written file
"""
# check file type and write accordingly
if type(data) == nibabel.nifti1.Nifti1Image:
@@ -353,6 +369,7 @@ def _write_file(file, data):
return file


@memory.cache
def _download_file(src_file, dst_file, connector):
"""Download a file from ebrains.
@@ -370,6 +387,7 @@ def _download_file(src_file, dst_file, connector):
str, datetime
path to the downloaded file and time at which it was downloaded
"""
# CACHE.run_maintenance()
if not os.path.exists(dst_file):
# load the file from ebrains
src_data = connector.get(src_file)
@@ -381,18 +399,20 @@ def _download_file(src_file, dst_file, connector):
return dst_file
else:
print(f"File {dst_file} already exists, skipping download.")

return []
return dst_file
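The @memory.cache decorator added above means a repeated call with identical arguments is served from the joblib disk cache instead of re-running the function. A standalone sketch of that behaviour with a toy function (not the real connector-based download; the cache directory name is arbitrary):

from joblib import Memory

memory = Memory("example_cache", verbose=0)

@memory.cache
def slow_square(x):
    print(f"computing {x}")  # printed only on a cache miss
    return x * x

slow_square(4)  # runs the function and caches the result on disk
slow_square(4)  # returns the cached value; no "computing" message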


def download_data(db, save_to=None):
def download_data(db, n_jobs=2, save_to=None):
"""Download the files in a (filtered) dataframe.
Parameters
----------
db : pandas.DataFrame
dataframe with information about files in the dataset, ideally a subset
of the full dataset
n_jobs : int, optional
number of parallel jobs to run, by default 2. -1 would use all the CPUs.
See: https://joblib.readthedocs.io/en/latest/generated/joblib.Parallel.html
save_to : str, optional
where to save the data, by default None, in which case the data is
saved in a directory called "ibc_data" in the current working directory
@@ -406,7 +426,7 @@ def download_data(db, save_to=None):
db_length = len(db)
if db_length == 0:
raise ValueError(
f"The input dataframe is empty. Please make sure that it atleast has columns 'dataset' and 'path' and a row containing appropriate values corresponding to those columns."
f"The input dataframe is empty. Please make sure that it at least has columns 'dataset' and 'path' and a row containing appropriate values corresponding to those columns."
)
else:
print(f"Found {db_length} files to download.")
@@ -420,32 +440,44 @@
# get data type from db
data_type = db["dataset"].unique()[0]
# connect to ebrains dataset
print("... Fetching token and connecting to EBRAINS ...")
connector = _connect_ebrains(data_type)
# get the file names as they are on ebrains
src_file_names, dst_file_names = get_file_paths(db)
# set the save directory
save_to = _create_root_dir(save_to)
# track downloaded file names and times
# file to track downloaded file names and times
local_db_file = os.path.join(save_to, f"downloaded_{data_type}.csv")
# download the files
for src_file, dst_file in tqdm(
zip(src_file_names, dst_file_names),
position=1,
leave=True,
total=db_length,
desc="Overall Progress: ",
colour="green",
):
# final file path to save the data
dst_file = os.path.join(save_to, dst_file)
file_name = _download_file(src_file, dst_file, connector)
file_time = datetime.now()
local_db = _update_local_db(local_db_file, file_name, file_time)
# keep cache < 2GB, delete oldest files first
CACHE.run_maintenance()
# get the file names as they are on ebrains
src_file_names, dst_file_names = get_file_paths(db, save_to_dir=save_to)

# helper to process the parallel download
def _download_and_update_progress(src_file, dst_file, connector):
try:
file_name = _download_file(src_file, dst_file, connector)
file_time = datetime.now()
CACHE.run_maintenance() # keep cache < 2GB
return file_name, file_time
except Exception as e:
raise(f"Error downloading {src_file}. Error: {e}")

# download finally
print(f"\n...Starting download of {len(src_file_names)} files...")
results = Parallel(n_jobs=n_jobs, backend="threading", verbose=10)(
delayed(_download_and_update_progress)(src_file, dst_file, connector)
for src_file, dst_file in zip(src_file_names, dst_file_names)
)

# update the local database
results = [res for res in results if res[0] is not None]
if len(results) == 0:
raise RuntimeError(f"No files downloaded ! Please try again.")
download_details = _update_local_db(local_db_file, results)
print(
f"Downloaded requested files from IBC {data_type} dataset. See "
f"{local_db_file} for details."
f"{local_db_file} for details.\n"
)

return local_db
# clean up the cache
CACHE.clear()
memory.clear()

return download_details
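Taken together with the example script above, a short usage sketch of the updated return value and of the per-download log written by _update_local_db (subject and task values mirror the example; the CSV location follows the default save_to behaviour described in the docstring):

import pandas as pd
from ibc_api import utils as ibc

db = ibc.get_info(data_type="volume_maps")
filtered_db = ibc.filter_data(db, subject_list=["08"], task_list=["Lec1"])

# runs the downloads in a joblib thread pool and returns the updated log
download_details = ibc.download_data(filtered_db, n_jobs=2)
print(download_details.tail())

# the same log is persisted next to the downloaded data
log = pd.read_csv("ibc_data/downloaded_volume_maps.csv")
print(log[["local_path", "downloaded_on"]].tail())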
