error handling, CACHE clearing, better progress bars
man-shu committed Oct 4, 2023
1 parent 2f6af28 commit d10099b
Showing 1 changed file with 45 additions and 5 deletions.
50 changes: 45 additions & 5 deletions src/ibc_api/utils.py
@@ -11,9 +11,10 @@
 from datetime import datetime
 from . import metadata as md
 import json
+import numpy as np

 # clear cache
-CACHE.run_maintenance()
+CACHE.clear()

 # dataset ids on ebrains
 METADATA = md.fetch_metadata()
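The import-time behaviour changes from pruning the cache to clearing it outright, while `download_data` (below) keeps pruning after each file. A minimal sketch of the two strategies, assuming `CACHE` is reachable from this module:

```python
from ibc_api.utils import CACHE  # assumption: CACHE is importable from this module

# wipe everything cached so far (what the module now does at import time)
CACHE.clear()

# evict the oldest files until the cache is back under its size limit
# (what download_data still does after each downloaded file)
CACHE.run_maintenance()
```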
@@ -178,7 +179,15 @@ def filter_data(db, subject_list=SUBJECTS, task_list=False):
     # filter the database on task if specified
     if task_list:
         filtered_db = filtered_db[filtered_db["task"].isin(task_list)]
-
+    length = len(filtered_db)
+    if length == 0:
+        raise ValueError(
+            f"No files found for subjects {subject_list} and tasks {task_list}."
+        )
+    else:
+        print(
+            f"Found {length} files for subjects {subject_list} and tasks {task_list}."
+        )
     return filtered_db
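With this check, an empty filter result fails fast instead of silently returning an empty dataframe. A hypothetical call, with toy column names standing in for the real metadata schema:

```python
import pandas as pd
from ibc_api.utils import filter_data

# toy stand-in for the metadata dataframe; column names are illustrative
db = pd.DataFrame(
    {"subject": ["sub-01"], "task": ["ArchiStandard"], "path": ["..."]}
)

filter_data(db, subject_list=["sub-01"], task_list=["ArchiStandard"])
# prints: Found 1 files for subjects ['sub-01'] and tasks ['ArchiStandard'].

filter_data(db, subject_list=["sub-01"], task_list=["NoSuchTask"])
# raises ValueError: No files found for subjects ['sub-01'] and tasks ['NoSuchTask'].
```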


@@ -292,9 +301,18 @@ def _write_file(file, data):
     elif type(data) == dict:
         with open(file, "w") as f:
             json.dump(data, f)
+    elif type(data) == bytes:
+        if file.endswith(".bvec") or file.endswith(".bval"):
+            with open(file, "wb") as f:
+                f.write(data)
+            f.close()
+        else:
+            raise ValueError(
+                f"Don't know how to save file {file} of type {type(data)}"
+            )
     else:
         raise ValueError(
-            f"Don't know how to save file {file}" f" of type {type(data)}"
+            f"Don't know how to save file {file} of type {type(data)}"
         )

     return file
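The new branch accepts raw bytes, but only for the `.bvec`/`.bval` diffusion sidecar files; any other bytes payload is rejected. A hypothetical call against this private helper:

```python
from ibc_api.utils import _write_file  # private helper, shown in this diff

# bytes payloads are written in binary mode, but only for .bvec/.bval files
_write_file("sub-01_dwi.bval", b"0 1000 1000 2000\n")

# any other extension with a bytes payload raises the new ValueError
_write_file("sub-01_dwi.nii.gz", b"\x1f\x8b...")
```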
@@ -349,6 +367,21 @@ def download_data(db, save_to=None):
     pandas.DataFrame
         dataframe with downloaded file names and times from the dataset
     """
+    # make sure the dataframe is not empty
+    db_length = len(db)
+    if db_length == 0:
+        raise ValueError(
+            "The input dataframe is empty. Please make sure that it at least has columns 'dataset' and 'path' and a row containing appropriate values corresponding to those columns."
+        )
+    else:
+        print(f"Found {db_length} files to download.")
+    # make sure the dataframe has dataset and path columns
+    db_columns = db.columns.tolist()
+    if "dataset" not in db_columns or "path" not in db_columns:
+        raise ValueError(
+            "The input dataframe should have columns 'dataset' and 'path' and a row containing appropriate values corresponding to those columns."
+        )
+
     # get data type from db
     data_type = db["dataset"].unique()[0]
     # connect to ebrains dataset
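Both checks run before any connection to EBRAINS is opened, so malformed input now fails immediately with a clear message. Each of these hypothetical calls raises:

```python
import pandas as pd
from ibc_api.utils import download_data

# empty dataframe -> ValueError before anything is downloaded
download_data(pd.DataFrame())

# non-empty but missing the required 'dataset' and 'path' columns -> ValueError
download_data(pd.DataFrame({"subject": ["sub-01"]}))
```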
@@ -360,13 +393,20 @@
     # track downloaded file names and times
     local_db_file = os.path.join(save_to, f"downloaded_{data_type}.csv")
     # download the files
-    for src_file, dst_file in tqdm(zip(src_file_names, dst_file_names)):
+    for src_file, dst_file in tqdm(
+        zip(src_file_names, dst_file_names),
+        position=1,
+        leave=True,
+        total=db_length,
+        desc="Overall Progress: ",
+        colour="green",
+    ):
         # final file path to save the data
         dst_file = os.path.join(save_to, dst_file)
         file_name = _download_file(src_file, dst_file, connector)
         file_time = datetime.now()
         local_db = _update_local_db(local_db_file, file_name, file_time)
-        # keep cache < 2GiB, delete oldest files first
+        # keep cache < 2GB, delete oldest files first
         CACHE.run_maintenance()
     print(
         f"Downloaded requested files from IBC {data_type} dataset. See "
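The bare tqdm wrapper becomes a configured overall-progress bar pinned at position 1, which suggests `_download_file` draws a per-file bar at position 0. A self-contained sketch of that nested-bar pattern (the inner loop is a stand-in for the real chunked download):

```python
import time
from tqdm import tqdm

files = ["sub-01_task-a.nii.gz", "sub-02_task-a.nii.gz"]

# overall bar at position 1, mirroring the kwargs used in download_data
for name in tqdm(
    files,
    position=1,
    leave=True,
    total=len(files),
    desc="Overall Progress: ",
    colour="green",
):
    # per-file bar at position 0; stands in for the chunked file transfer
    for _chunk in tqdm(range(20), position=0, leave=False, desc=name):
        time.sleep(0.01)
```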
