Skip to content

Commit

Permalink
fix: one source of truth for paths in Encord cache
Browse files Browse the repository at this point in the history
  • Loading branch information
eloy-encord committed Feb 23, 2024
1 parent 84d876e commit 7e86787
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 13 deletions.
27 changes: 19 additions & 8 deletions clip_eval/dataset/encord_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
from pathlib import Path

from encord import EncordUserClient
from encord.objects import Classification
from encord.objects import Classification, LabelRowV2
from encord.objects.common import PropertyType
from PIL import Image

from .dataset import Dataset
from .encord_utils import download_data_from_project
from .encord_utils import download_data_from_project, get_frame_file, get_label_row_annotations_file


class EncordDataset(Dataset):
Expand Down Expand Up @@ -39,11 +39,22 @@ def __getitem__(self, idx):
def __len__(self):
return len(self._frame_paths)

def _get_frame_path(self, label_row_hash: str, frame_title: str) -> Path:
return self._cache_dir / self._project.project_hash / label_row_hash / frame_title
def _get_frame_path(self, label_row: LabelRowV2, frame: int) -> Path:
    """Return the cached file path of the image shown at ``frame`` of ``label_row``.

    The frame view supplies the image hash and title; the path itself is
    assembled by ``get_frame_file`` rather than built by hand here.
    """
    view = label_row.get_frame_view(frame)
    location = {
        "data_dir": self._cache_dir,
        "project_hash": self._project.project_hash,
        "label_row_hash": label_row.label_hash,
        "frame_hash": view.image_hash,
        "frame_title": view.image_title,
    }
    return get_frame_file(**location)

def _get_label_row_annotations(self, label_row_hash: str) -> Path:
return self._cache_dir / self._project.project_hash / label_row_hash / "annotations.json"
def _get_label_row_annotations(self, label_row: LabelRowV2) -> Path:
    """Return the cached annotations file path for ``label_row``.

    The path is assembled by ``get_label_row_annotations_file`` from the cache
    directory, the project hash and the label row hash.
    """
    project_hash = self._project.project_hash
    row_hash = label_row.label_hash
    return get_label_row_annotations_file(
        data_dir=self._cache_dir, project_hash=project_hash, label_row_hash=row_hash
    )

def _setup(
self,
Expand Down Expand Up @@ -85,7 +96,7 @@ def _setup(
self._frame_paths = []
self._labels = []
for label_row in self._project.list_label_rows_v2():
anns_path = self._get_label_row_annotations(label_row.label_hash)
anns_path = self._get_label_row_annotations(label_row)
label_row.from_labels_dict(json.loads(anns_path.read_text(encoding="utf-8")))
for frame_view in label_row.get_frame_views():
clf_instances = frame_view.get_classification_instances(self._classification)
Expand All @@ -99,5 +110,5 @@ def _setup(
if clf_answer is None:
continue

self._frame_paths.append(self._get_frame_path(label_row.label_hash, frame_view.image_title))
self._frame_paths.append(self._get_frame_path(label_row, frame_view.frame))
self._labels.append(clf_answer.title)
31 changes: 26 additions & 5 deletions clip_eval/dataset/encord_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,21 @@

def download_image(image_data: Image | Video, destination_dir: Path) -> Path:
# TODO The type of `image_data` is also Video because of a SDK bug explained in `download_label_row_data`.
destination_path = destination_dir / image_data.title
file_name = get_frame_name(image_data.image_hash, image_data.title)
destination_path = destination_dir / file_name
if not destination_path.exists():
download_file(image_data.file_link, destination_path)
return destination_path


def download_label_row_data(
project: Project, label_row: LabelRowV2, data_dir: Path, overwrite_annotations: bool = False
data_dir: Path, project: Project, label_row: LabelRowV2, overwrite_annotations: bool = False
) -> list[Path]:
label_row_dir = data_dir / project.project_hash / label_row.data_hash
label_row_annotations = get_label_row_annotations_file(data_dir, project.project_hash, label_row.label_hash)
label_row_dir = label_row_annotations.parent
label_row_dir.mkdir(parents=True, exist_ok=True)

# Download the annotations
label_row_annotations = label_row_dir / "annotations.json"
if not label_row_annotations.exists() or overwrite_annotations:
label_row.initialise_labels()
label_row_annotations.write_text(json.dumps(label_row.to_encord_dict()), encoding="utf-8")
Expand All @@ -40,7 +41,10 @@ def download_label_row_data(
if label_row.data_type == DataType.IMAGE:
# TODO This `if` is here because of a SDK bug, remove it when IMAGE data is stored in the proper image field [1]
images_data = [project.get_data(label_row.data_hash, get_signed_url=True)[0]]
# Missing field caused by the SDK bug
images_data[0]["image_hash"] = label_row.data_hash
else:
# TODO: test this image-group branch (images are read from the second element returned by `get_data`)
images_data = project.get_data(label_row.data_hash, get_signed_url=True)[1]
return collect_async(
lambda image_data: download_image(image_data, label_row_dir),
Expand Down Expand Up @@ -72,4 +76,21 @@ def download_data_from_project(project: Project, data_dir: Path, overwrite_annot
data_dir.mkdir(parents=True, exist_ok=True)
for label_row in tqdm(project.list_label_rows_v2(), desc=f"Downloading [{project.title}]"):
if label_row.data_type in {DataType.IMAGE, DataType.IMG_GROUP}:
download_label_row_data(project, label_row, data_dir, overwrite_annotations)
download_label_row_data(data_dir, project, label_row, overwrite_annotations)


def get_frame_name(frame_hash: str, frame_title: str) -> str:
    """Return the cache file name for a frame: its hash plus the title's extension.

    Args:
        frame_hash: Hash identifying the frame; used as the file stem.
        frame_title: Title of the frame; only its file extension is kept.

    Returns:
        ``"<frame_hash><suffix>"``, where ``<suffix>`` is the extension of the
        final path component of ``frame_title`` (including the leading dot),
        or just ``frame_hash`` when the title has no extension.
    """
    # `Path.suffix` only inspects the last path component, so a title without
    # an extension no longer leaks the whole title into the file name, and a
    # dot inside a directory-like prefix (e.g. "v1.0/frame") cannot smuggle a
    # path separator into the cached file name.
    suffix = Path(frame_title).suffix
    return f"{frame_hash}{suffix}"


def get_frame_file(data_dir: Path, project_hash: str, label_row_hash: str, frame_hash: str, frame_title: str) -> Path:
    """Return the path of a frame's image file inside its label row directory.

    The directory comes from ``get_label_row_dir`` and the file name from
    ``get_frame_name``.
    """
    frame_dir = get_label_row_dir(data_dir, project_hash, label_row_hash)
    frame_name = get_frame_name(frame_hash, frame_title)
    return frame_dir / frame_name


def get_label_row_annotations_file(data_dir: Path, project_hash: str, label_row_hash: str) -> Path:
    """Return the path of the ``annotations.json`` file of a label row.

    The file sits directly inside the directory given by ``get_label_row_dir``.
    """
    row_dir = get_label_row_dir(data_dir, project_hash, label_row_hash)
    return row_dir.joinpath("annotations.json")


def get_label_row_dir(data_dir: Path, project_hash: str, label_row_hash: str) -> Path:
    """Return the directory of a label row: ``<data_dir>/<project_hash>/<label_row_hash>``."""
    return data_dir.joinpath(project_hash, label_row_hash)

0 comments on commit 7e86787

Please sign in to comment.