chore: Improve download speed of Encord Datasets #84

Merged · 3 commits · Jul 15, 2024
4 changes: 1 addition & 3 deletions tti_eval/constants.py
@@ -8,9 +8,7 @@
 # If the cache directory is not explicitly specified, use the `.cache` directory located in the project's root.
 _TTI_EVAL_ROOT_DIR = Path(__file__).parent.parent
 CACHE_PATH = Path(os.environ.get("TTI_EVAL_CACHE_PATH", _TTI_EVAL_ROOT_DIR / ".cache"))
-_OUTPUT_PATH = Path(
-    os.environ.get("TTI_EVAL_OUTPUT_PATH", _TTI_EVAL_ROOT_DIR / "output")
-)
+_OUTPUT_PATH = Path(os.environ.get("TTI_EVAL_OUTPUT_PATH", _TTI_EVAL_ROOT_DIR / "output"))
 _SOURCES_PATH = _TTI_EVAL_ROOT_DIR / "sources"
82 changes: 59 additions & 23 deletions tti_eval/dataset/types/encord_ds.py
@@ -1,6 +1,10 @@
 import json
+import multiprocessing
 import os
+from collections.abc import Callable
+from concurrent.futures import ThreadPoolExecutor
 from dataclasses import dataclass
+from functools import partial
 from pathlib import Path
 from typing import Any

@@ -30,7 +34,13 @@ def __init__(
         ssh_key_path: str | None = None,
         **kwargs,
     ):
-        super().__init__(title, split=split, title_in_source=title_in_source, transform=transform, cache_dir=cache_dir)
+        super().__init__(
+            title,
+            split=split,
+            title_in_source=title_in_source,
+            transform=transform,
+            cache_dir=cache_dir,
+        )
         self._setup(project_hash, classification_hash, ssh_key_path, **kwargs)

     def __getitem__(self, idx):
@@ -191,6 +201,37 @@ def _download_label_row_image_data(data_dir: Path, project: Project, label_row:
     )


+def _download_label_row(
+    label_row: LabelRowV2,
+    project: Project,
+    data_dir: Path,
+    overwrite_annotations: bool,
+    label_rows_info: dict[str, Any],
+    update_pbar: Callable[[], Any],
+):
+    if label_row.data_type not in {DataType.IMAGE, DataType.IMG_GROUP}:
+        return
+    save_annotations = False
+    # Trigger the images download if the label hash is not found or is None (never downloaded).
+    if label_row.label_hash not in label_rows_info.keys():
+        _download_label_row_image_data(data_dir, project, label_row)
+        save_annotations = True
+    # Overwrite annotations only if `last_edited_at` values differ between the existing and new annotations.
+    elif (
+        overwrite_annotations
+        and label_row.last_edited_at.strftime(DATETIME_STRING_FORMAT)
+        != label_rows_info[label_row.label_hash]["last_edited_at"]
+    ):
+        label_row.initialise_labels()
+        save_annotations = True
+
+    if save_annotations:
+        annotations_file = get_label_row_annotations_file(data_dir, project.project_hash, label_row.label_hash)
+        annotations_file.write_text(json.dumps(label_row.to_encord_dict()), encoding="utf-8")
+        label_rows_info[label_row.label_hash] = {"last_edited_at": label_row.last_edited_at}
+    update_pbar()
+
+
 def _download_label_rows(
     project: Project,
     data_dir: Path,
@@ -202,27 +243,18 @@ def _download_label_rows(
     if tqdm_desc is None:
         tqdm_desc = f"Downloading data from Encord project `{project.title}`"

-    for label_row in tqdm(label_rows, desc=tqdm_desc):
-        if label_row.data_type not in {DataType.IMAGE, DataType.IMG_GROUP}:
-            continue
-        save_annotations = False
-        # Trigger the images download if the label hash is not found or is None (never downloaded).
-        if label_row.label_hash not in label_rows_info.keys():
-            _download_label_row_image_data(data_dir, project, label_row)
-            save_annotations = True
-        # Overwrite annotations only if `last_edited_at` values differ between the existing and new annotations.
-        elif (
-            overwrite_annotations
-            and label_row.last_edited_at.strftime(DATETIME_STRING_FORMAT)
-            != label_rows_info[label_row.label_hash]["last_edited_at"]
-        ):
-            label_row.initialise_labels()
-            save_annotations = True
-
-        if save_annotations:
-            annotations_file = get_label_row_annotations_file(data_dir, project.project_hash, label_row.label_hash)
-            annotations_file.write_text(json.dumps(label_row.to_encord_dict()), encoding="utf-8")
-            label_rows_info[label_row.label_hash] = {"last_edited_at": label_row.last_edited_at}
+    pbar = tqdm(total=len(label_rows), desc=tqdm_desc)
+    _do_download = partial(
+        _download_label_row,
+        project=project,
+        data_dir=data_dir,
+        overwrite_annotations=overwrite_annotations,
+        label_rows_info=label_rows_info,
+        update_pbar=lambda: pbar.update(1),
+    )
+
+    with ThreadPoolExecutor(min(multiprocessing.cpu_count(), 24)) as exe:
+        exe.map(_do_download, label_rows)


 def download_data_from_project(
@@ -293,7 +325,11 @@ def get_frame_file(data_dir: Path, project_hash: str, label_row: LabelRowV2, fra


 def get_frame_file_raw(
-    data_dir: Path, project_hash: str, label_row_hash: str, frame_hash: str, frame_title: str
+    data_dir: Path,
+    project_hash: str,
+    label_row_hash: str,
+    frame_hash: str,
+    frame_title: str,
 ) -> Path:
     return get_label_row_dir(data_dir, project_hash, label_row_hash) / get_frame_name(frame_hash, frame_title)
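The heart of the speedup is this file: the sequential `for label_row in tqdm(...)` loop becomes a per-row worker (`_download_label_row`) fanned out over a `ThreadPoolExecutor`, with the shared state bound in via `functools.partial` and the progress bar advanced through a callback. Because the work is I/O-bound (image downloads), threads overlap the network waits even under CPython's GIL. The following is a minimal, self-contained sketch of that pattern, using hypothetical stand-ins (`_process_item`, `process_all`, `items`) rather than the actual Encord objects:

import multiprocessing
from concurrent.futures import ThreadPoolExecutor
from functools import partial

from tqdm.auto import tqdm


def _process_item(item: str, results: dict[str, int], update_pbar) -> None:
    # Stand-in for the real per-row work (downloading images, saving annotations).
    # I/O-bound calls release the GIL, so threads genuinely overlap here.
    results[item] = len(item)
    update_pbar()


def process_all(items: list[str]) -> dict[str, int]:
    results: dict[str, int] = {}
    pbar = tqdm(total=len(items), desc="Downloading")
    worker = partial(_process_item, results=results, update_pbar=lambda: pbar.update(1))
    # Cap the pool size, as the PR does, so a machine with many cores
    # does not flood the remote API with concurrent requests.
    with ThreadPoolExecutor(min(multiprocessing.cpu_count(), 24)) as exe:
        # Consuming the iterator re-raises any exception thrown in a worker;
        # an unconsumed map() generator would swallow failures silently.
        list(exe.map(worker, items))
    pbar.close()
    return results

Two details make this safe without explicit locking: `tqdm` guards `update()` with an internal lock, so the callback can be invoked from worker threads, and a plain dict assignment is atomic under the GIL, which is what lets `label_rows_info` be shared across workers. Note also that `Executor.map` only surfaces a worker's exception when its result is consumed, which is why the sketch iterates the returned generator.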
4 changes: 1 addition & 3 deletions tti_eval/evaluation/image_retrieval.py
@@ -56,9 +56,7 @@ def evaluate(self) -> float:

         # To compute retrieval accuracy, we ensure that a maximum of Q elements per sample are retrieved,
         # where Q represents the size of the respective class in the validation embeddings
-        top_nearest_per_class = np.where(
-            self._class_counts < self.k, self._class_counts, self.k
-        )
+        top_nearest_per_class = np.where(self._class_counts < self.k, self._class_counts, self.k)
         top_nearest_per_sample = top_nearest_per_class[self._train_embeddings.labels]

         # Add a placeholder value for indices outside the retrieval scope
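For context, the reflowed `np.where` line is an elementwise minimum: each class may contribute at most `k` retrieved elements, but never more than that class actually has in the validation embeddings. A toy check with made-up class counts:

import numpy as np

class_counts = np.array([3, 10, 7])  # hypothetical per-class validation-set sizes
k = 5
top_nearest_per_class = np.where(class_counts < k, class_counts, k)
print(top_nearest_per_class)  # [3 5 5]
assert np.array_equal(top_nearest_per_class, np.minimum(class_counts, k))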
4 changes: 1 addition & 3 deletions tti_eval/evaluation/knn.py
@@ -59,9 +59,7 @@ def predict(self) -> tuple[ProbabilityArray, ClassArray]:
         # Calculate class votes from the distances (avoiding division by zero)
         # Note: Values stored in `dists` are the squared 2-norm values of the respective distance vectors
         max_value = np.finfo(np.float32).max
-        scores = np.divide(
-            1, dists, out=np.full_like(dists, max_value), where=dists != 0
-        )
+        scores = np.divide(1, dists, out=np.full_like(dists, max_value), where=dists != 0)
         # NOTE: if self.k and self.num_classes are both large, this might become a big one.
         # We can shave off a factor of self.k if we count differently here.
         n = len(self._val_embeddings.images)
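Here too, the reflowed `np.divide` call computes inverse-distance voting weights while sidestepping division by zero: wherever a squared distance is exactly 0 (a duplicate of a training point), the `where=` mask leaves the pre-filled `float32` maximum in place instead of dividing. A standalone check with made-up distances:

import numpy as np

dists = np.array([[0.0, 0.25, 4.0]], dtype=np.float32)  # squared 2-norm distances
max_value = np.finfo(np.float32).max
scores = np.divide(1, dists, out=np.full_like(dists, max_value), where=dists != 0)
print(scores)  # zero distance -> float32 max; the rest -> 1/dist (4.0 and 0.25)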