From 190a90f67fbe65b38f7d9bd0ae4fdbbfb4d1758c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eloy=20P=C3=A9rez=20Torres?= <99720527+eloy-encord@users.noreply.github.com> Date: Fri, 12 Apr 2024 13:09:33 +0100 Subject: [PATCH] feat: avoid refetch of downloaded encord data (#52) Use an internal project json to record the label rows that have been completely downloaded, thus avoiding any further check on their images. Also, after the first time the image files and labels are calculated they will be saved for easy access in future runs. Also, label rows' contents that are already up-to-date won't be downloaded again, even when `overwrite_annotations` is set to `True` (via last edited datetime comparison). --- clip_eval/dataset/encord.py | 86 ++++++++++++++-------- clip_eval/dataset/encord_utils.py | 116 +++++++++++++++++++++--------- 2 files changed, 141 insertions(+), 61 deletions(-) diff --git a/clip_eval/dataset/encord.py b/clip_eval/dataset/encord.py index 9513abe..b36aa97 100644 --- a/clip_eval/dataset/encord.py +++ b/clip_eval/dataset/encord.py @@ -1,6 +1,8 @@ import json import os +from dataclasses import dataclass from pathlib import Path +from typing import Any from encord import EncordUserClient from encord.objects import Classification, LabelRowV2 @@ -12,6 +14,7 @@ download_data_from_project, get_frame_file, get_label_row_annotations_file, + get_label_rows_info_file, simple_project_split, ) @@ -34,9 +37,9 @@ def __init__( self._setup(project_hash, classification_hash, ssh_key_path, **kwargs) def __getitem__(self, idx): - frame_path = self._frame_paths[idx] + frame_path = self._dataset_indices_info[idx].image_file img = Image.open(frame_path) - label = self._labels[idx] + label = self._dataset_indices_info[idx].label if self.transform is not None: _d = self.transform(dict(image=[img], label=[label])) @@ -46,9 +49,9 @@ def __getitem__(self, idx): return res_item def __len__(self): - return len(self._frame_paths) + return len(self._dataset_indices_info) - def _get_frame_path(self, label_row: LabelRowV2, frame: int) -> Path: + def _get_frame_file(self, label_row: LabelRowV2, frame: int) -> Path: return get_frame_file( data_dir=self._cache_dir, project_hash=self._project.project_hash, @@ -56,13 +59,48 @@ def _get_frame_path(self, label_row: LabelRowV2, frame: int) -> Path: frame=frame, ) - def _get_label_row_annotations(self, label_row: LabelRowV2) -> Path: + def _get_label_row_annotations_file(self, label_row: LabelRowV2) -> Path: return get_label_row_annotations_file( data_dir=self._cache_dir, project_hash=self._project.project_hash, label_row_hash=label_row.label_hash, ) + def _ensure_answers_availability(self) -> dict: + lrs_info_file = get_label_rows_info_file(self._cache_dir, self._project.project_hash) + label_rows_info: dict = json.loads(lrs_info_file.read_text(encoding="utf-8")) + should_update_info = False + class_name_to_idx = {name: idx for idx, name in enumerate(self.class_names)} # Fast lookup of class indices + for label_row in self._label_rows: + if "answers" not in label_rows_info[label_row.label_hash]: + if not label_row.is_labelling_initialised: + # Retrieve label row content from local storage + anns_path = self._get_label_row_annotations_file(label_row) + label_row.from_labels_dict(json.loads(anns_path.read_text(encoding="utf-8"))) + + answers = dict() + for frame_view in label_row.get_frame_views(): + clf_instances = frame_view.get_classification_instances(self._classification) + # Skip frames where the input classification is missing + if len(clf_instances) == 0: + continue + + clf_instance = clf_instances[0] + clf_answer = clf_instance.get_answer(self._attribute) + # Skip frames where the input classification has no answer (probable annotation error) + if clf_answer is None: + continue + + answers[frame_view.frame] = { + "image_file": self._get_frame_file(label_row, frame_view.frame).as_posix(), + "label": class_name_to_idx[clf_answer.title], + } + label_rows_info[label_row.label_hash]["answers"] = answers + should_update_info = True + if should_update_info: + lrs_info_file.write_text(json.dumps(label_rows_info), encoding="utf-8") + return label_rows_info + def _setup( self, project_hash: str, @@ -96,34 +134,26 @@ def _setup( else: split_to_lr_hashes = simple_project_split(self._project) splits_file.write_text(json.dumps(split_to_lr_hashes), encoding="utf-8") - lr_hashes = split_to_lr_hashes[self.split] + self._label_rows = self._project.list_label_rows_v2(label_hashes=split_to_lr_hashes[self.split]) # Get data from source. Users may supply the `overwrite_annotations` keyword in the init to download everything download_data_from_project( self._project, self._cache_dir, - lr_hashes, - tqdm_desc=f"Fetching {self.split} data from Encord project `{self._project.title}`", + self._label_rows, + tqdm_desc=f"Downloading {self.split} data from Encord project `{self._project.title}`", **kwargs, ) - self._frame_paths = [] - self._labels = [] - class_name_to_idx = {name: idx for idx, name in enumerate(self.class_names)} # Fast lookup of class indices - for label_row in self._project.list_label_rows_v2(label_hashes=lr_hashes): - anns_path = self._get_label_row_annotations(label_row) - label_row.from_labels_dict(json.loads(anns_path.read_text(encoding="utf-8"))) - for frame_view in label_row.get_frame_views(): - clf_instances = frame_view.get_classification_instances(self._classification) - # Skip frames where the input classification is missing - if len(clf_instances) == 0: - continue - - clf_instance = clf_instances[0] - clf_answer = clf_instance.get_answer(self._attribute) - # Skip frames where the input classification has no answer (probable annotation error) - if clf_answer is None: - continue - - self._frame_paths.append(self._get_frame_path(label_row, frame_view.frame)) - self._labels.append(class_name_to_idx[clf_answer.title]) + # Prepare data for the __getitem__ method + self._dataset_indices_info: list[EncordDataset.DatasetIndexInfo] = [] + label_rows_info = self._ensure_answers_availability() + for label_row in self._label_rows: + answers: dict[int, Any] = label_rows_info[label_row.label_hash]["answers"] + for frame_num in sorted(answers.keys()): + self._dataset_indices_info.append(EncordDataset.DatasetIndexInfo(**answers[frame_num])) + + @dataclass + class DatasetIndexInfo: + image_file: Path | str + label: int diff --git a/clip_eval/dataset/encord_utils.py b/clip_eval/dataset/encord_utils.py index 9588252..997e514 100644 --- a/clip_eval/dataset/encord_utils.py +++ b/clip_eval/dataset/encord_utils.py @@ -1,7 +1,9 @@ import json from pathlib import Path +from typing import Any from encord import Project +from encord.common.constants import DATETIME_STRING_FORMAT from encord.orm.dataset import DataType, Image, Video from encord.project import LabelRowV2 from tqdm.auto import tqdm @@ -9,8 +11,8 @@ from .utils import Split, collect_async, download_file, simple_random_split -def download_image(image_data: Image | Video, destination_dir: Path) -> Path: - # TODO The type of `image_data` is also Video because of a SDK bug explained in `download_label_row_data`. +def _download_image(image_data: Image | Video, destination_dir: Path) -> Path: + # TODO The type of `image_data` is also Video because of a SDK bug explained in `_download_label_row_image_data`. file_name = get_frame_name(image_data.image_hash, image_data.title) destination_path = destination_dir / file_name if not destination_path.exists(): @@ -18,29 +20,11 @@ def download_image(image_data: Image | Video, destination_dir: Path) -> Path: return destination_path -def download_label_row_data( - data_dir: Path, project: Project, label_row: LabelRowV2, overwrite_annotations: bool = False -) -> list[Path]: - label_row_annotations = get_label_row_annotations_file(data_dir, project.project_hash, label_row.label_hash) - label_row_dir = label_row_annotations.parent +def _download_label_row_image_data(data_dir: Path, project: Project, label_row: LabelRowV2) -> list[Path]: + label_row.initialise_labels() + label_row_dir = get_label_row_dir(data_dir, project.project_hash, label_row.label_hash) label_row_dir.mkdir(parents=True, exist_ok=True) - # Download the annotations - if not label_row_annotations.exists() or overwrite_annotations: - label_row.initialise_labels() - label_row_annotations.write_text(json.dumps(label_row.to_encord_dict()), encoding="utf-8") - else: - # Needed to iterate the label row's frames - label_row.from_labels_dict(json.loads(label_row_annotations.read_text(encoding="utf-8"))) - - # Download the images - is_frame_missing = any( - not get_frame_file(data_dir, project.project_hash, label_row, idx).exists() - for idx in range(label_row.number_of_frames) - ) - if not is_frame_missing: - return [label_row_dir / fv.image_title for fv in label_row.get_frame_views()] - if label_row.data_type == DataType.IMAGE: # TODO This `if` is here because of a SDK bug, remove it when IMAGE data is stored in the proper image field [1] images_data = [project.get_data(label_row.data_hash, get_signed_url=True)[0]] @@ -49,17 +33,51 @@ def download_label_row_data( else: images_data = project.get_data(label_row.data_hash, get_signed_url=True)[1] return collect_async( - lambda image_data: download_image(image_data, label_row_dir), + lambda image_data: _download_image(image_data, label_row_dir), images_data, max_workers=4, disable=True, ) +def _download_label_rows( + project: Project, + data_dir: Path, + label_rows: list[LabelRowV2], + overwrite_annotations: bool, + label_rows_info: dict[str, Any], + tqdm_desc: str | None, +): + if tqdm_desc is None: + tqdm_desc = f"Downloading data from Encord project `{project.title}`" + + for label_row in tqdm(label_rows, desc=tqdm_desc): + if label_row.data_type not in {DataType.IMAGE, DataType.IMG_GROUP}: + continue + save_annotations = False + # Trigger the images download if the label hash is not found or is None (never downloaded). + if label_row.label_hash not in label_rows_info.keys(): + _download_label_row_image_data(data_dir, project, label_row) + save_annotations = True + # Overwrite annotations only if `last_edited_at` values differ between the existing and new annotations. + elif ( + overwrite_annotations + and label_row.last_edited_at.strftime(DATETIME_STRING_FORMAT) + != label_rows_info[label_row.label_hash]["last_edited_at"] + ): + label_row.initialise_labels() + save_annotations = True + + if save_annotations: + annotations_file = get_label_row_annotations_file(data_dir, project.project_hash, label_row.label_hash) + annotations_file.write_text(json.dumps(label_row.to_encord_dict()), encoding="utf-8") + label_rows_info[label_row.label_hash] = {"last_edited_at": label_row.last_edited_at} + + def download_data_from_project( project: Project, data_dir: Path, - label_hashes: list[str] | None = None, + label_rows: list[LabelRowV2] | None = None, overwrite_annotations: bool = False, tqdm_desc: str | None = None, ) -> None: @@ -79,16 +97,37 @@ def download_data_from_project( └── ... :param project: The project containing the images with their annotations. :param data_dir: The directory where the project data will be downloaded. - :param label_hashes: The hashes of the label rows that will be downloaded. If None, all label rows - will be downloaded. + :param label_rows: The label rows that will be downloaded. If None, all label rows will be downloaded. :param overwrite_annotations: Flag that indicates whether to overwrite existing annotations if they exist. + :param tqdm_desc: Optional description for tqdm progress bar. + Defaults to 'Downloading data from Encord project `{project.title}`' """ - if tqdm_desc is None: - tqdm_desc = f"Fetching data from Encord project `{project.title}`" - data_dir.mkdir(parents=True, exist_ok=True) - for label_row in tqdm(project.list_label_rows_v2(label_hashes=label_hashes), desc=tqdm_desc): - if label_row.data_type in {DataType.IMAGE, DataType.IMG_GROUP}: - download_label_row_data(data_dir, project, label_row, overwrite_annotations) + # Read file that tracks the downloaded data progress + lrs_info_file = get_label_rows_info_file(data_dir, project.project_hash) + lrs_info_file.parent.mkdir(parents=True, exist_ok=True) + label_rows_info = json.loads(lrs_info_file.read_text(encoding="utf-8")) if lrs_info_file.is_file() else dict() + + # Retrieve only the unseen data if there is no explicit annotation update + filtered_label_rows = ( + label_rows + if overwrite_annotations + else [lr for lr in label_rows if lr.label_hash not in label_rows_info.keys()] + ) + if len(filtered_label_rows) == 0: + return + + try: + _download_label_rows( + project, + data_dir, + filtered_label_rows, + overwrite_annotations, + label_rows_info, + tqdm_desc=tqdm_desc, + ) + finally: + # Save the current download progress in case of failure + lrs_info_file.write_text(json.dumps(label_rows_info), encoding="utf-8") def get_frame_name(frame_hash: str, frame_title: str) -> str: @@ -116,6 +155,10 @@ def get_label_row_dir(data_dir: Path, project_hash: str, label_row_hash: str) -> return data_dir / project_hash / label_row_hash +def get_label_rows_info_file(data_dir: Path, project_hash: str) -> Path: + return data_dir / project_hash / "label_rows_info.json" + + def simple_project_split( project: Project, seed: int = 42, @@ -137,4 +180,11 @@ def simple_project_split( """ label_rows = project.list_label_rows_v2() split_to_indices = simple_random_split(len(label_rows), seed, train_split, validation_split) + enforce_label_rows_initialization(label_rows) # Ensure that all label rows have a label hash return {split: [label_rows[i].label_hash for i in indices] for split, indices in split_to_indices.items()} + + +def enforce_label_rows_initialization(label_rows: list[LabelRowV2]): + for lr in label_rows: + if lr.label_hash is None: + lr.initialise_labels()