Skip to content

Commit

Permalink
chore: save classification answers for future runs
Browse files Browse the repository at this point in the history
  • Loading branch information
eloy-encord committed Apr 9, 2024
1 parent 77b799d commit 12c18dc
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 39 deletions.
78 changes: 54 additions & 24 deletions clip_eval/dataset/encord.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import json
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Any

from encord import EncordUserClient
from encord.objects import Classification, LabelRowV2
Expand All @@ -12,6 +14,7 @@
download_data_from_project,
get_frame_file,
get_label_row_annotations_file,
get_label_rows_info_file,
simple_project_split,
)

Expand All @@ -34,9 +37,9 @@ def __init__(
self._setup(project_hash, classification_hash, ssh_key_path, **kwargs)

def __getitem__(self, idx):
frame_path = self._frame_paths[idx]
frame_path = self._dataset_indices_info[idx].image_file
img = Image.open(frame_path)
label = self._labels[idx]
label = self._dataset_indices_info[idx].label

if self.transform is not None:
_d = self.transform(dict(image=[img], label=[label]))
Expand All @@ -46,23 +49,58 @@ def __getitem__(self, idx):
return res_item

def __len__(self):
return len(self._frame_paths)
return len(self._dataset_indices_info)

def _get_frame_path(self, label_row: LabelRowV2, frame: int) -> Path:
def _get_frame_file(self, label_row: LabelRowV2, frame: int) -> Path:
    """Resolve the file path of a frame through `get_frame_file`.

    The lookup uses this dataset's cache directory and the hash of its project.

    :param label_row: Label row the frame belongs to.
    :param frame: Frame number within the label row.
    :return: The path returned by `get_frame_file` for that frame.
    """
    lookup_args = dict(
        data_dir=self._cache_dir,
        project_hash=self._project.project_hash,
        label_row=label_row,
        frame=frame,
    )
    return get_frame_file(**lookup_args)

def _get_label_row_annotations(self, label_row: LabelRowV2) -> Path:
def _get_label_row_annotations_file(self, label_row: LabelRowV2) -> Path:
    """Resolve the annotations file path of a label row through `get_label_row_annotations_file`.

    The lookup uses this dataset's cache directory and the hash of its project.

    :param label_row: Label row whose annotations file is requested; only its `label_hash` is used.
    :return: The path returned by `get_label_row_annotations_file` for that label row.
    """
    lookup_args = dict(
        data_dir=self._cache_dir,
        project_hash=self._project.project_hash,
        label_row_hash=label_row.label_hash,
    )
    return get_label_row_annotations_file(**lookup_args)

def _ensure_answers_availability(self) -> dict:
    """Make sure every label row has its classification answers stored, and return the label rows info.

    Reads the label rows info file of the project and, for each label row in `self._label_rows` that has no
    `"answers"` entry yet, builds one from the label row's frame views. Frames without an instance of the
    target classification, or whose instance has no answer for the target attribute, are left out.
    The info file is rewritten only when at least one `"answers"` entry was added.

    The `"answers"` of every processed label row are returned with `int` frame numbers as keys, whether they
    were just computed or restored from the file. JSON object keys are always strings, so restored answers
    are converted back; otherwise sorting the keys would be numeric on the first run but lexicographic
    ("10" before "2") on later runs, changing the order of the dataset items between runs.

    :return: The label rows info keyed by label hash. Each processed entry holds an `"answers"` dict that maps
        a frame number to `{"image_file": <posix path string>, "label": <class index>}`.
    :raises KeyError: If a label row has no entry in the info file,
        or an answer title is not one of `self.class_names`.
    """
    lrs_info_file = get_label_rows_info_file(self._cache_dir, self._project.project_hash)
    label_rows_info: dict = json.loads(lrs_info_file.read_text(encoding="utf-8"))
    should_update_info = False
    class_name_to_idx = {name: idx for idx, name in enumerate(self.class_names)}  # Fast lookup of class indices
    for label_row in self._label_rows:
        label_row_info = label_rows_info[label_row.label_hash]
        if "answers" in label_row_info:
            # Restored from JSON: frame numbers come back as strings, so convert them to the `int` keys
            # that freshly computed answers use (a no-op if the keys are already `int`).
            label_row_info["answers"] = {
                int(frame_num): answer for frame_num, answer in label_row_info["answers"].items()
            }
            continue

        if not label_row.is_labelling_initialised:
            # Retrieve label row content from local storage
            anns_path = self._get_label_row_annotations_file(label_row)
            label_row.from_labels_dict(json.loads(anns_path.read_text(encoding="utf-8")))

        answers = {}
        for frame_view in label_row.get_frame_views():
            clf_instances = frame_view.get_classification_instances(self._classification)
            # Skip frames where the input classification is missing
            if not clf_instances:
                continue

            clf_answer = clf_instances[0].get_answer(self._attribute)
            # Skip frames where the input classification has no answer (probable annotation error)
            if clf_answer is None:
                continue

            answers[frame_view.frame] = {
                "image_file": self._get_frame_file(label_row, frame_view.frame).as_posix(),
                "label": class_name_to_idx[clf_answer.title],
            }
        label_row_info["answers"] = answers
        should_update_info = True

    if should_update_info:
        # `json.dumps` turns the `int` frame keys into strings; they are converted back on the next load.
        lrs_info_file.write_text(json.dumps(label_rows_info), encoding="utf-8")
    return label_rows_info

def _setup(
self,
project_hash: str,
Expand Down Expand Up @@ -107,23 +145,15 @@ def _setup(
**kwargs,
)

self._frame_paths = []
self._labels = []
class_name_to_idx = {name: idx for idx, name in enumerate(self.class_names)} # Fast lookup of class indices
# Prepare data for the __getitem__ method
self._dataset_indices_info: list[EncordDataset.DatasetIndexInfo] = []
label_rows_info = self._ensure_answers_availability()
for label_row in self._label_rows:
anns_path = self._get_label_row_annotations(label_row)
label_row.from_labels_dict(json.loads(anns_path.read_text(encoding="utf-8")))
for frame_view in label_row.get_frame_views():
clf_instances = frame_view.get_classification_instances(self._classification)
# Skip frames where the input classification is missing
if len(clf_instances) == 0:
continue

clf_instance = clf_instances[0]
clf_answer = clf_instance.get_answer(self._attribute)
# Skip frames where the input classification has no answer (probable annotation error)
if clf_answer is None:
continue

self._frame_paths.append(self._get_frame_path(label_row, frame_view.frame))
self._labels.append(class_name_to_idx[clf_answer.title])
answers: dict[int, Any] = label_rows_info[label_row.label_hash]["answers"]
for frame_num in sorted(answers.keys()):
self._dataset_indices_info.append(EncordDataset.DatasetIndexInfo(**answers[frame_num]))

@dataclass
class DatasetIndexInfo:
    """Image file and label of a single dataset item, as read by `__getitem__` for a given index."""

    # File of the frame image; `__getitem__` opens it with `Image.open`.
    image_file: Path | str
    # Label of the item; stored answers hold the index of the answer's class within `class_names`.
    label: int
31 changes: 16 additions & 15 deletions clip_eval/dataset/encord_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
from pathlib import Path
from typing import Any

from encord import Project
from encord.common.constants import DATETIME_STRING_FORMAT
Expand Down Expand Up @@ -44,7 +45,7 @@ def _download_label_rows(
data_dir: Path,
label_rows: list[LabelRowV2],
overwrite_annotations: bool,
downloaded_label_rows_tracker: dict,
label_rows_info: dict[str, Any],
tqdm_desc: str | None,
):
if tqdm_desc is None:
Expand All @@ -55,22 +56,22 @@ def _download_label_rows(
continue
save_annotations = False
# Trigger the images download if the label hash is not found or is None (never downloaded).
if label_row.label_hash not in downloaded_label_rows_tracker.keys():
if label_row.label_hash not in label_rows_info.keys():
_download_label_row_image_data(data_dir, project, label_row)
save_annotations = True
# Overwrite annotations only if `last_edited_at` values differ between the existing and new annotations.
elif (
overwrite_annotations
and label_row.last_edited_at.strftime(DATETIME_STRING_FORMAT)
!= downloaded_label_rows_tracker[label_row.label_hash]["last_edited_at"]
!= label_rows_info[label_row.label_hash]["last_edited_at"]
):
label_row.initialise_labels()
save_annotations = True

if save_annotations:
annotations_file = get_label_row_annotations_file(data_dir, project.project_hash, label_row.label_hash)
annotations_file.write_text(json.dumps(label_row.to_encord_dict()), encoding="utf-8")
downloaded_label_rows_tracker[label_row.label_hash] = {"last_edited_at": label_row.last_edited_at}
label_rows_info[label_row.label_hash] = {"last_edited_at": label_row.last_edited_at}


def download_data_from_project(
Expand Down Expand Up @@ -101,20 +102,16 @@ def download_data_from_project(
:param tqdm_desc: Optional description for tqdm progress bar.
Defaults to 'Downloading data from Encord project `{project.title}`'
"""
# Read internal file that controls the downloaded data progress
downloaded_label_rows_tracker_file = data_dir / project.project_hash / "label_rows.json"
downloaded_label_rows_tracker_file.parent.mkdir(parents=True, exist_ok=True)
downloaded_label_rows_tracker = (
json.loads(downloaded_label_rows_tracker_file.read_text(encoding="utf-8"))
if downloaded_label_rows_tracker_file.is_file()
else dict()
)
# Read file that tracks the downloaded data progress
lrs_info_file = get_label_rows_info_file(data_dir, project.project_hash)
lrs_info_file.parent.mkdir(parents=True, exist_ok=True)
label_rows_info = json.loads(lrs_info_file.read_text(encoding="utf-8")) if lrs_info_file.is_file() else dict()

# Retrieve only the unseen data if there is no explicit annotation update
filtered_label_rows = (
label_rows
if overwrite_annotations
else [lr for lr in label_rows if lr.label_hash not in downloaded_label_rows_tracker.keys()]
else [lr for lr in label_rows if lr.label_hash not in label_rows_info.keys()]
)
if len(filtered_label_rows) == 0:
return
Expand All @@ -125,12 +122,12 @@ def download_data_from_project(
data_dir,
filtered_label_rows,
overwrite_annotations,
downloaded_label_rows_tracker,
label_rows_info,
tqdm_desc=tqdm_desc,
)
finally:
# Save the current download progress in case of failure
downloaded_label_rows_tracker_file.write_text(json.dumps(downloaded_label_rows_tracker), encoding="utf-8")
lrs_info_file.write_text(json.dumps(label_rows_info), encoding="utf-8")


def get_frame_name(frame_hash: str, frame_title: str) -> str:
Expand Down Expand Up @@ -158,6 +155,10 @@ def get_label_row_dir(data_dir: Path, project_hash: str, label_row_hash: str) ->
return data_dir / project_hash / label_row_hash


def get_label_rows_info_file(data_dir: Path, project_hash: str) -> Path:
    """Return the path of the `label_rows_info.json` file of a project.

    :param data_dir: Root directory of the locally stored data.
    :param project_hash: Hash of the project; used as the name of the project's directory inside `data_dir`.
    :return: `data_dir / project_hash / "label_rows_info.json"`.
    """
    project_dir = data_dir / project_hash
    return project_dir / "label_rows_info.json"


def simple_project_split(
project: Project,
seed: int = 42,
Expand Down

0 comments on commit 12c18dc

Please sign in to comment.