chore: Improve download speed of Encord Datasets (#84)
Co-authored-by: Frederik Hvilshøj <[email protected]>
eloy-encord and frederik-encord authored Jul 15, 2024
1 parent 4bd2cae commit cd14be8
Showing 4 changed files with 62 additions and 32 deletions.
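
The core of the change is in `tti_eval/dataset/types/encord_ds.py`: the sequential per-label-row download loop in `_download_label_rows` is replaced by a thread pool that maps a partially-applied worker over the label rows and reports progress through a tqdm callback. A minimal, self-contained sketch of that pattern under illustrative names (`_process_one` and its arguments are not from the repository; the real worker is `_download_label_row` below):

import multiprocessing
from concurrent.futures import ThreadPoolExecutor
from functools import partial

from tqdm import tqdm


def _process_one(item: int, factor: int, update_pbar) -> int:
    # Stand-in for the per-item work (in the real code: image download + annotation writes).
    result = item * factor
    update_pbar()
    return result


items = list(range(100))
pbar = tqdm(total=len(items), desc="Processing items")
worker = partial(_process_one, factor=2, update_pbar=lambda: pbar.update(1))

# I/O-bound tasks overlap well in threads; the pool size is capped so we do not
# open an unbounded number of connections (the commit uses min(cpu_count(), 24)).
with ThreadPoolExecutor(min(multiprocessing.cpu_count(), 24)) as exe:
    results = list(exe.map(worker, items))
pbar.close()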
4 changes: 1 addition & 3 deletions tti_eval/constants.py
@@ -8,9 +8,7 @@
# If the cache directory is not explicitly specified, use the `.cache` directory located in the project's root.
_TTI_EVAL_ROOT_DIR = Path(__file__).parent.parent
CACHE_PATH = Path(os.environ.get("TTI_EVAL_CACHE_PATH", _TTI_EVAL_ROOT_DIR / ".cache"))
_OUTPUT_PATH = Path(
os.environ.get("TTI_EVAL_OUTPUT_PATH", _TTI_EVAL_ROOT_DIR / "output")
)
_OUTPUT_PATH = Path(os.environ.get("TTI_EVAL_OUTPUT_PATH", _TTI_EVAL_ROOT_DIR / "output"))
_SOURCES_PATH = _TTI_EVAL_ROOT_DIR / "sources"


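Both constants follow the same pattern: an environment variable overrides a default under the project root. A small sketch of that lookup, with a hypothetical checkout path (`/home/user/tti-eval` is illustrative, not from the repository):

import os
from pathlib import Path

root = Path("/home/user/tti-eval")  # hypothetical project root

# Without an override, the default under the project root is used.
output = Path(os.environ.get("TTI_EVAL_OUTPUT_PATH", root / "output"))
print(output)  # /home/user/tti-eval/output

# With the environment variable set, the override wins.
os.environ["TTI_EVAL_OUTPUT_PATH"] = "/data/tti-eval-output"
output = Path(os.environ.get("TTI_EVAL_OUTPUT_PATH", root / "output"))
print(output)  # /data/tti-eval-output
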
82 changes: 59 additions & 23 deletions tti_eval/dataset/types/encord_ds.py
@@ -1,6 +1,10 @@
import json
import multiprocessing
import os
from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from typing import Any

@@ -30,7 +34,13 @@ def __init__(
ssh_key_path: str | None = None,
**kwargs,
):
super().__init__(title, split=split, title_in_source=title_in_source, transform=transform, cache_dir=cache_dir)
super().__init__(
title,
split=split,
title_in_source=title_in_source,
transform=transform,
cache_dir=cache_dir,
)
self._setup(project_hash, classification_hash, ssh_key_path, **kwargs)

def __getitem__(self, idx):
@@ -191,6 +201,37 @@ def _download_label_row_image_data(data_dir: Path, project: Project, label_row:
)


def _download_label_row(
label_row: LabelRowV2,
project: Project,
data_dir: Path,
overwrite_annotations: bool,
label_rows_info: dict[str, Any],
update_pbar: Callable[[], Any],
):
if label_row.data_type not in {DataType.IMAGE, DataType.IMG_GROUP}:
return
save_annotations = False
# Trigger the images download if the label hash is not found or is None (never downloaded).
if label_row.label_hash not in label_rows_info.keys():
_download_label_row_image_data(data_dir, project, label_row)
save_annotations = True
# Overwrite annotations only if `last_edited_at` values differ between the existing and new annotations.
elif (
overwrite_annotations
and label_row.last_edited_at.strftime(DATETIME_STRING_FORMAT)
!= label_rows_info[label_row.label_hash]["last_edited_at"]
):
label_row.initialise_labels()
save_annotations = True

if save_annotations:
annotations_file = get_label_row_annotations_file(data_dir, project.project_hash, label_row.label_hash)
annotations_file.write_text(json.dumps(label_row.to_encord_dict()), encoding="utf-8")
label_rows_info[label_row.label_hash] = {"last_edited_at": label_row.last_edited_at}
update_pbar()


def _download_label_rows(
project: Project,
data_dir: Path,
@@ -202,27 +243,18 @@ def _download_label_rows(
if tqdm_desc is None:
tqdm_desc = f"Downloading data from Encord project `{project.title}`"

for label_row in tqdm(label_rows, desc=tqdm_desc):
if label_row.data_type not in {DataType.IMAGE, DataType.IMG_GROUP}:
continue
save_annotations = False
# Trigger the images download if the label hash is not found or is None (never downloaded).
if label_row.label_hash not in label_rows_info.keys():
_download_label_row_image_data(data_dir, project, label_row)
save_annotations = True
# Overwrite annotations only if `last_edited_at` values differ between the existing and new annotations.
elif (
overwrite_annotations
and label_row.last_edited_at.strftime(DATETIME_STRING_FORMAT)
!= label_rows_info[label_row.label_hash]["last_edited_at"]
):
label_row.initialise_labels()
save_annotations = True

if save_annotations:
annotations_file = get_label_row_annotations_file(data_dir, project.project_hash, label_row.label_hash)
annotations_file.write_text(json.dumps(label_row.to_encord_dict()), encoding="utf-8")
label_rows_info[label_row.label_hash] = {"last_edited_at": label_row.last_edited_at}
pbar = tqdm(total=len(label_rows), desc=tqdm_desc)
_do_download = partial(
_download_label_row,
project=project,
data_dir=data_dir,
overwrite_annotations=overwrite_annotations,
label_rows_info=label_rows_info,
update_pbar=lambda: pbar.update(1),
)

with ThreadPoolExecutor(min(multiprocessing.cpu_count(), 24)) as exe:
exe.map(_do_download, label_rows)


def download_data_from_project(
@@ -293,7 +325,11 @@ def get_frame_file(data_dir: Path, project_hash: str, label_row: LabelRowV2, fra


def get_frame_file_raw(
data_dir: Path, project_hash: str, label_row_hash: str, frame_hash: str, frame_title: str
data_dir: Path,
project_hash: str,
label_row_hash: str,
frame_hash: str,
frame_title: str,
) -> Path:
return get_label_row_dir(data_dir, project_hash, label_row_hash) / get_frame_name(frame_hash, frame_title)

4 changes: 1 addition & 3 deletions tti_eval/evaluation/image_retrieval.py
@@ -56,9 +56,7 @@ def evaluate(self) -> float:

# To compute retrieval accuracy, we ensure that a maximum of Q elements per sample are retrieved,
# where Q represents the size of the respective class in the validation embeddings
top_nearest_per_class = np.where(
self._class_counts < self.k, self._class_counts, self.k
)
top_nearest_per_class = np.where(self._class_counts < self.k, self._class_counts, self.k)
top_nearest_per_sample = top_nearest_per_class[self._train_embeddings.labels]

# Add a placeholder value for indices outside the retrieval scope
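The reformatted line keeps its original behavior: each class's retrieval budget is the smaller of `k` and that class's size in the validation embeddings. A tiny illustration with made-up counts:

import numpy as np

class_counts = np.array([3, 10, 7])  # hypothetical per-class sizes Q
k = 5
# At most k retrieved per class, but never more than the class actually contains.
top_nearest_per_class = np.where(class_counts < k, class_counts, k)
print(top_nearest_per_class)  # [3 5 5], i.e. the same as np.minimum(class_counts, k)
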
4 changes: 1 addition & 3 deletions tti_eval/evaluation/knn.py
@@ -59,9 +59,7 @@ def predict(self) -> tuple[ProbabilityArray, ClassArray]:
# Calculate class votes from the distances (avoiding division by zero)
# Note: Values stored in `dists` are the squared 2-norm values of the respective distance vectors
max_value = np.finfo(np.float32).max
scores = np.divide(
1, dists, out=np.full_like(dists, max_value), where=dists != 0
)
scores = np.divide(1, dists, out=np.full_like(dists, max_value), where=dists != 0)
# NOTE: if self.k and self.num_classes are both large, this array might become large.
# We can shave off a factor of self.k if we count differently here.
n = len(self._val_embeddings.images)
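Again only the formatting changed; the call still turns squared distances into inverse-distance scores while guarding zero distances. A small illustration with made-up values:

import numpy as np

dists = np.array([[0.0, 2.0, 4.0]], dtype=np.float32)  # hypothetical squared distances; 0.0 is an exact match
max_value = np.finfo(np.float32).max
# Where dists != 0, compute 1 / dists; elsewhere the pre-filled max_value is kept,
# so an exact match gets the largest possible score instead of a division-by-zero warning.
scores = np.divide(1, dists, out=np.full_like(dists, max_value), where=dists != 0)
print(scores)  # [[3.4028235e+38 5.0000000e-01 2.5000000e-01]]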
