From 7e86787e2c0af7f710b2b8a46e52a9c70b134fcd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eloy=20P=C3=A9rez=20Torres?= Date: Fri, 23 Feb 2024 16:21:45 +0100 Subject: [PATCH] fix: one source of truth for paths in Encord cache --- clip_eval/dataset/encord_dataset.py | 27 +++++++++++++++++-------- clip_eval/dataset/encord_utils.py | 31 ++++++++++++++++++++++++----- 2 files changed, 45 insertions(+), 13 deletions(-) diff --git a/clip_eval/dataset/encord_dataset.py b/clip_eval/dataset/encord_dataset.py index 10bb812..e36fbe5 100644 --- a/clip_eval/dataset/encord_dataset.py +++ b/clip_eval/dataset/encord_dataset.py @@ -3,12 +3,12 @@ from pathlib import Path from encord import EncordUserClient -from encord.objects import Classification +from encord.objects import Classification, LabelRowV2 from encord.objects.common import PropertyType from PIL import Image from .dataset import Dataset -from .encord_utils import download_data_from_project +from .encord_utils import download_data_from_project, get_frame_file, get_label_row_annotations_file class EncordDataset(Dataset): @@ -39,11 +39,22 @@ def __getitem__(self, idx): def __len__(self): return len(self._frame_paths) - def _get_frame_path(self, label_row_hash: str, frame_title: str) -> Path: - return self._cache_dir / self._project.project_hash / label_row_hash / frame_title + def _get_frame_path(self, label_row: LabelRowV2, frame: int) -> Path: + frame_view = label_row.get_frame_view(frame) + return get_frame_file( + data_dir=self._cache_dir, + project_hash=self._project.project_hash, + label_row_hash=label_row.label_hash, + frame_hash=frame_view.image_hash, + frame_title=frame_view.image_title, + ) - def _get_label_row_annotations(self, label_row_hash: str) -> Path: - return self._cache_dir / self._project.project_hash / label_row_hash / "annotations.json" + def _get_label_row_annotations(self, label_row: LabelRowV2) -> Path: + return get_label_row_annotations_file( + data_dir=self._cache_dir, + project_hash=self._project.project_hash, + label_row_hash=label_row.label_hash, + ) def _setup( self, @@ -85,7 +96,7 @@ def _setup( self._frame_paths = [] self._labels = [] for label_row in self._project.list_label_rows_v2(): - anns_path = self._get_label_row_annotations(label_row.label_hash) + anns_path = self._get_label_row_annotations(label_row) label_row.from_labels_dict(json.loads(anns_path.read_text(encoding="utf-8"))) for frame_view in label_row.get_frame_views(): clf_instances = frame_view.get_classification_instances(self._classification) @@ -99,5 +110,5 @@ def _setup( if clf_answer is None: continue - self._frame_paths.append(self._get_frame_path(label_row.label_hash, frame_view.image_title)) + self._frame_paths.append(self._get_frame_path(label_row, frame_view.frame)) self._labels.append(clf_answer.title) diff --git a/clip_eval/dataset/encord_utils.py b/clip_eval/dataset/encord_utils.py index 31a6769..7e81dd6 100644 --- a/clip_eval/dataset/encord_utils.py +++ b/clip_eval/dataset/encord_utils.py @@ -11,20 +11,21 @@ def download_image(image_data: Image | Video, destination_dir: Path) -> Path: # TODO The type of `image_data` is also Video because of a SDK bug explained in `download_label_row_data`. - destination_path = destination_dir / image_data.title + file_name = get_frame_name(image_data.image_hash, image_data.title) + destination_path = destination_dir / file_name if not destination_path.exists(): download_file(image_data.file_link, destination_path) return destination_path def download_label_row_data( - project: Project, label_row: LabelRowV2, data_dir: Path, overwrite_annotations: bool = False + data_dir: Path, project: Project, label_row: LabelRowV2, overwrite_annotations: bool = False ) -> list[Path]: - label_row_dir = data_dir / project.project_hash / label_row.data_hash + label_row_annotations = get_label_row_annotations_file(data_dir, project.project_hash, label_row.label_hash) + label_row_dir = label_row_annotations.parent label_row_dir.mkdir(parents=True, exist_ok=True) # Download the annotations - label_row_annotations = label_row_dir / "annotations.json" if not label_row_annotations.exists() or overwrite_annotations: label_row.initialise_labels() label_row_annotations.write_text(json.dumps(label_row.to_encord_dict()), encoding="utf-8") @@ -40,7 +41,10 @@ def download_label_row_data( if label_row.data_type == DataType.IMAGE: # TODO This `if` is here because of a SDK bug, remove it when IMAGE data is stored in the proper image field [1] images_data = [project.get_data(label_row.data_hash, get_signed_url=True)[0]] + # Missing field caused by the SDK bug + images_data[0]["image_hash"] = label_row.data_hash else: + # TODO test this images_data = project.get_data(label_row.data_hash, get_signed_url=True)[1] return collect_async( lambda image_data: download_image(image_data, label_row_dir), @@ -72,4 +76,21 @@ def download_data_from_project(project: Project, data_dir: Path, overwrite_annot data_dir.mkdir(parents=True, exist_ok=True) for label_row in tqdm(project.list_label_rows_v2(), desc=f"Downloading [{project.title}]"): if label_row.data_type in {DataType.IMAGE, DataType.IMG_GROUP}: - download_label_row_data(project, label_row, data_dir, overwrite_annotations) + download_label_row_data(data_dir, project, label_row, overwrite_annotations) + + +def get_frame_name(frame_hash: str, frame_title: str) -> str: + file_extension = frame_title.rsplit(sep=".", maxsplit=1)[-1] + return f"{frame_hash}.{file_extension}" + + +def get_frame_file(data_dir: Path, project_hash: str, label_row_hash: str, frame_hash: str, frame_title: str) -> Path: + return get_label_row_dir(data_dir, project_hash, label_row_hash) / get_frame_name(frame_hash, frame_title) + + +def get_label_row_annotations_file(data_dir: Path, project_hash: str, label_row_hash: str) -> Path: + return get_label_row_dir(data_dir, project_hash, label_row_hash) / "annotations.json" + + +def get_label_row_dir(data_dir: Path, project_hash: str, label_row_hash: str) -> Path: + return data_dir / project_hash / label_row_hash