feat: avoid refetch of downloaded encord data (#52)
Use an internal project JSON file to record the label rows that have been
fully downloaded, so their images are not checked again on later runs.
After the first run, the image file paths and labels computed for each
frame are also saved, so future runs can load them directly.

In addition, label rows whose contents are already up to date are not
downloaded again, even when `overwrite_annotations` is set to `True`
(decided by comparing last-edited datetimes).
eloy-encord authored Apr 12, 2024
1 parent 5b53023 commit 190a90f
Showing 2 changed files with 141 additions and 61 deletions.
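
In outline, the cache introduced by this commit lives in
`<data_dir>/<project_hash>/label_rows_info.json` and maps each label hash to its
download state. The sketch below is a hedged reconstruction of the skip logic:
the helper name `needs_refetch` and its return convention are hypothetical, but
the conditions mirror `_download_label_rows` and `download_data_from_project`.

import json
from pathlib import Path


def needs_refetch(
    info_file: Path, label_hash: str, last_edited_at: str, overwrite_annotations: bool
) -> bool:
    # A label row triggers new work only if it was never recorded in the info
    # file, or if annotation overwriting was requested and the stored
    # last-edited timestamp no longer matches the one reported by Encord.
    if not info_file.is_file():
        return True
    info = json.loads(info_file.read_text(encoding="utf-8"))
    entry = info.get(label_hash)
    if entry is None:
        return True  # images and annotations were never downloaded
    return overwrite_annotations and entry.get("last_edited_at") != last_edited_at
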
86 changes: 58 additions & 28 deletions clip_eval/dataset/encord.py
@@ -1,6 +1,8 @@
import json
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Any

from encord import EncordUserClient
from encord.objects import Classification, LabelRowV2
@@ -12,6 +14,7 @@
download_data_from_project,
get_frame_file,
get_label_row_annotations_file,
get_label_rows_info_file,
simple_project_split,
)

@@ -34,9 +37,9 @@ def __init__(
self._setup(project_hash, classification_hash, ssh_key_path, **kwargs)

def __getitem__(self, idx):
frame_path = self._frame_paths[idx]
frame_path = self._dataset_indices_info[idx].image_file
img = Image.open(frame_path)
label = self._labels[idx]
label = self._dataset_indices_info[idx].label

if self.transform is not None:
_d = self.transform(dict(image=[img], label=[label]))
@@ -46,23 +49,58 @@ def __getitem__(self, idx):
return res_item

def __len__(self):
return len(self._frame_paths)
return len(self._dataset_indices_info)

def _get_frame_path(self, label_row: LabelRowV2, frame: int) -> Path:
def _get_frame_file(self, label_row: LabelRowV2, frame: int) -> Path:
return get_frame_file(
data_dir=self._cache_dir,
project_hash=self._project.project_hash,
label_row=label_row,
frame=frame,
)

def _get_label_row_annotations(self, label_row: LabelRowV2) -> Path:
def _get_label_row_annotations_file(self, label_row: LabelRowV2) -> Path:
return get_label_row_annotations_file(
data_dir=self._cache_dir,
project_hash=self._project.project_hash,
label_row_hash=label_row.label_hash,
)

def _ensure_answers_availability(self) -> dict:
lrs_info_file = get_label_rows_info_file(self._cache_dir, self._project.project_hash)
label_rows_info: dict = json.loads(lrs_info_file.read_text(encoding="utf-8"))
should_update_info = False
class_name_to_idx = {name: idx for idx, name in enumerate(self.class_names)} # Fast lookup of class indices
for label_row in self._label_rows:
if "answers" not in label_rows_info[label_row.label_hash]:
if not label_row.is_labelling_initialised:
# Retrieve label row content from local storage
anns_path = self._get_label_row_annotations_file(label_row)
label_row.from_labels_dict(json.loads(anns_path.read_text(encoding="utf-8")))

answers = dict()
for frame_view in label_row.get_frame_views():
clf_instances = frame_view.get_classification_instances(self._classification)
# Skip frames where the input classification is missing
if len(clf_instances) == 0:
continue

clf_instance = clf_instances[0]
clf_answer = clf_instance.get_answer(self._attribute)
# Skip frames where the input classification has no answer (probable annotation error)
if clf_answer is None:
continue

answers[frame_view.frame] = {
"image_file": self._get_frame_file(label_row, frame_view.frame).as_posix(),
"label": class_name_to_idx[clf_answer.title],
}
label_rows_info[label_row.label_hash]["answers"] = answers
should_update_info = True
if should_update_info:
lrs_info_file.write_text(json.dumps(label_rows_info), encoding="utf-8")
return label_rows_info

def _setup(
self,
project_hash: str,
@@ -96,34 +134,26 @@ def _setup(
else:
split_to_lr_hashes = simple_project_split(self._project)
splits_file.write_text(json.dumps(split_to_lr_hashes), encoding="utf-8")
lr_hashes = split_to_lr_hashes[self.split]
self._label_rows = self._project.list_label_rows_v2(label_hashes=split_to_lr_hashes[self.split])

# Get data from source. Users may supply the `overwrite_annotations` keyword in the init to download everything
download_data_from_project(
self._project,
self._cache_dir,
lr_hashes,
tqdm_desc=f"Fetching {self.split} data from Encord project `{self._project.title}`",
self._label_rows,
tqdm_desc=f"Downloading {self.split} data from Encord project `{self._project.title}`",
**kwargs,
)

self._frame_paths = []
self._labels = []
class_name_to_idx = {name: idx for idx, name in enumerate(self.class_names)} # Fast lookup of class indices
for label_row in self._project.list_label_rows_v2(label_hashes=lr_hashes):
anns_path = self._get_label_row_annotations(label_row)
label_row.from_labels_dict(json.loads(anns_path.read_text(encoding="utf-8")))
for frame_view in label_row.get_frame_views():
clf_instances = frame_view.get_classification_instances(self._classification)
# Skip frames where the input classification is missing
if len(clf_instances) == 0:
continue

clf_instance = clf_instances[0]
clf_answer = clf_instance.get_answer(self._attribute)
# Skip frames where the input classification has no answer (probable annotation error)
if clf_answer is None:
continue

self._frame_paths.append(self._get_frame_path(label_row, frame_view.frame))
self._labels.append(class_name_to_idx[clf_answer.title])
# Prepare data for the __getitem__ method
self._dataset_indices_info: list[EncordDataset.DatasetIndexInfo] = []
label_rows_info = self._ensure_answers_availability()
for label_row in self._label_rows:
answers: dict[int, Any] = label_rows_info[label_row.label_hash]["answers"]
for frame_num in sorted(answers.keys()):
self._dataset_indices_info.append(EncordDataset.DatasetIndexInfo(**answers[frame_num]))

@dataclass
class DatasetIndexInfo:
image_file: Path | str
label: int
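
For reference, a minimal sketch of how the new per-frame index is consumed; the
`load_sample` helper is hypothetical, while the dataclass mirrors the
`DatasetIndexInfo` added above.

from dataclasses import dataclass
from pathlib import Path

from PIL import Image


@dataclass
class DatasetIndexInfo:  # same shape as the dataclass added to EncordDataset
    image_file: Path | str
    label: int


def load_sample(index: list[DatasetIndexInfo], idx: int) -> tuple[Image.Image, int]:
    # Both the image path and the label come from the precomputed entry, so no
    # annotation file has to be re-read at access time.
    info = index[idx]
    return Image.open(info.image_file), info.label
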
116 changes: 83 additions & 33 deletions clip_eval/dataset/encord_utils.py
@@ -1,46 +1,30 @@
import json
from pathlib import Path
from typing import Any

from encord import Project
from encord.common.constants import DATETIME_STRING_FORMAT
from encord.orm.dataset import DataType, Image, Video
from encord.project import LabelRowV2
from tqdm.auto import tqdm

from .utils import Split, collect_async, download_file, simple_random_split


def download_image(image_data: Image | Video, destination_dir: Path) -> Path:
# TODO The type of `image_data` is also Video because of a SDK bug explained in `download_label_row_data`.
def _download_image(image_data: Image | Video, destination_dir: Path) -> Path:
# TODO The type of `image_data` is also Video because of a SDK bug explained in `_download_label_row_image_data`.
file_name = get_frame_name(image_data.image_hash, image_data.title)
destination_path = destination_dir / file_name
if not destination_path.exists():
download_file(image_data.file_link, destination_path)
return destination_path


def download_label_row_data(
data_dir: Path, project: Project, label_row: LabelRowV2, overwrite_annotations: bool = False
) -> list[Path]:
label_row_annotations = get_label_row_annotations_file(data_dir, project.project_hash, label_row.label_hash)
label_row_dir = label_row_annotations.parent
def _download_label_row_image_data(data_dir: Path, project: Project, label_row: LabelRowV2) -> list[Path]:
label_row.initialise_labels()
label_row_dir = get_label_row_dir(data_dir, project.project_hash, label_row.label_hash)
label_row_dir.mkdir(parents=True, exist_ok=True)

# Download the annotations
if not label_row_annotations.exists() or overwrite_annotations:
label_row.initialise_labels()
label_row_annotations.write_text(json.dumps(label_row.to_encord_dict()), encoding="utf-8")
else:
# Needed to iterate the label row's frames
label_row.from_labels_dict(json.loads(label_row_annotations.read_text(encoding="utf-8")))

# Download the images
is_frame_missing = any(
not get_frame_file(data_dir, project.project_hash, label_row, idx).exists()
for idx in range(label_row.number_of_frames)
)
if not is_frame_missing:
return [label_row_dir / fv.image_title for fv in label_row.get_frame_views()]

if label_row.data_type == DataType.IMAGE:
# TODO This `if` is here because of a SDK bug, remove it when IMAGE data is stored in the proper image field [1]
images_data = [project.get_data(label_row.data_hash, get_signed_url=True)[0]]
@@ -49,17 +33,51 @@ def download_label_row_data(
else:
images_data = project.get_data(label_row.data_hash, get_signed_url=True)[1]
return collect_async(
lambda image_data: download_image(image_data, label_row_dir),
lambda image_data: _download_image(image_data, label_row_dir),
images_data,
max_workers=4,
disable=True,
)


def _download_label_rows(
project: Project,
data_dir: Path,
label_rows: list[LabelRowV2],
overwrite_annotations: bool,
label_rows_info: dict[str, Any],
tqdm_desc: str | None,
):
if tqdm_desc is None:
tqdm_desc = f"Downloading data from Encord project `{project.title}`"

for label_row in tqdm(label_rows, desc=tqdm_desc):
if label_row.data_type not in {DataType.IMAGE, DataType.IMG_GROUP}:
continue
save_annotations = False
# Trigger the images download if the label hash is not found or is None (never downloaded).
if label_row.label_hash not in label_rows_info.keys():
_download_label_row_image_data(data_dir, project, label_row)
save_annotations = True
# Overwrite annotations only if `last_edited_at` values differ between the existing and new annotations.
elif (
overwrite_annotations
and label_row.last_edited_at.strftime(DATETIME_STRING_FORMAT)
!= label_rows_info[label_row.label_hash]["last_edited_at"]
):
label_row.initialise_labels()
save_annotations = True

if save_annotations:
annotations_file = get_label_row_annotations_file(data_dir, project.project_hash, label_row.label_hash)
annotations_file.write_text(json.dumps(label_row.to_encord_dict()), encoding="utf-8")
label_rows_info[label_row.label_hash] = {"last_edited_at": label_row.last_edited_at}


def download_data_from_project(
project: Project,
data_dir: Path,
label_hashes: list[str] | None = None,
label_rows: list[LabelRowV2] | None = None,
overwrite_annotations: bool = False,
tqdm_desc: str | None = None,
) -> None:
Expand All @@ -79,16 +97,37 @@ def download_data_from_project(
└── ...
:param project: The project containing the images with their annotations.
:param data_dir: The directory where the project data will be downloaded.
:param label_hashes: The hashes of the label rows that will be downloaded. If None, all label rows
will be downloaded.
:param label_rows: The label rows that will be downloaded. If None, all label rows will be downloaded.
:param overwrite_annotations: Flag that indicates whether to overwrite existing annotations if they exist.
:param tqdm_desc: Optional description for tqdm progress bar.
Defaults to 'Downloading data from Encord project `{project.title}`'
"""
if tqdm_desc is None:
tqdm_desc = f"Fetching data from Encord project `{project.title}`"
data_dir.mkdir(parents=True, exist_ok=True)
for label_row in tqdm(project.list_label_rows_v2(label_hashes=label_hashes), desc=tqdm_desc):
if label_row.data_type in {DataType.IMAGE, DataType.IMG_GROUP}:
download_label_row_data(data_dir, project, label_row, overwrite_annotations)
# Read file that tracks the downloaded data progress
lrs_info_file = get_label_rows_info_file(data_dir, project.project_hash)
lrs_info_file.parent.mkdir(parents=True, exist_ok=True)
label_rows_info = json.loads(lrs_info_file.read_text(encoding="utf-8")) if lrs_info_file.is_file() else dict()

# Retrieve only the unseen data if there is no explicit annotation update
filtered_label_rows = (
label_rows
if overwrite_annotations
else [lr for lr in label_rows if lr.label_hash not in label_rows_info.keys()]
)
if len(filtered_label_rows) == 0:
return

try:
_download_label_rows(
project,
data_dir,
filtered_label_rows,
overwrite_annotations,
label_rows_info,
tqdm_desc=tqdm_desc,
)
finally:
# Save the current download progress in case of failure
lrs_info_file.write_text(json.dumps(label_rows_info), encoding="utf-8")


def get_frame_name(frame_hash: str, frame_title: str) -> str:
@@ -116,6 +155,10 @@ def get_label_row_dir(data_dir: Path, project_hash: str, label_row_hash: str) ->
return data_dir / project_hash / label_row_hash


def get_label_rows_info_file(data_dir: Path, project_hash: str) -> Path:
return data_dir / project_hash / "label_rows_info.json"


def simple_project_split(
project: Project,
seed: int = 42,
@@ -137,4 +180,11 @@
"""
label_rows = project.list_label_rows_v2()
split_to_indices = simple_random_split(len(label_rows), seed, train_split, validation_split)
enforce_label_rows_initialization(label_rows) # Ensure that all label rows have a label hash
return {split: [label_rows[i].label_hash for i in indices] for split, indices in split_to_indices.items()}


def enforce_label_rows_initialization(label_rows: list[LabelRowV2]):
for lr in label_rows:
if lr.label_hash is None:
lr.initialise_labels()
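
A hedged end-to-end usage sketch follows; the SSH key path, cache directory and
project hash are placeholders, and the import path is the module changed in
this commit.

from pathlib import Path

from encord import EncordUserClient

from clip_eval.dataset.encord_utils import download_data_from_project

ssh_key = Path("~/.ssh/id_ed25519").expanduser().read_text()
client = EncordUserClient.create_with_ssh_private_key(ssh_key)
project = client.get_project("<project_hash>")
label_rows = project.list_label_rows_v2()

data_dir = Path("~/.cache/clip_eval").expanduser()
# The first call downloads images and annotations and records its progress in
# label_rows_info.json; running it again is close to a no-op, and
# overwrite_annotations=True only refreshes rows whose last-edited timestamp
# changed on Encord.
download_data_from_project(project, data_dir, label_rows)
download_data_from_project(project, data_dir, label_rows, overwrite_annotations=True)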
