diff --git a/clip_eval/dataset/encord_dataset.py b/clip_eval/dataset/encord_dataset.py index e36fbe5..bd72923 100644 --- a/clip_eval/dataset/encord_dataset.py +++ b/clip_eval/dataset/encord_dataset.py @@ -32,9 +32,12 @@ def __getitem__(self, idx): img = Image.open(frame_path) label = self._labels[idx] - if self.transform: - img, label = self.transform(img, label) - return img, label + if self.transform is not None: + _d = self.transform(dict(image=[img], label=[label])) + res_item = dict(image=_d["image"][0], label=_d["label"][0]) + else: + res_item = dict(image=img, label=label) + return res_item def __len__(self): return len(self._frame_paths) @@ -95,6 +98,7 @@ def _setup( self._frame_paths = [] self._labels = [] + class_name_to_idx = {name: idx for idx, name in enumerate(self.class_names)} # Fast lookup of class indices for label_row in self._project.list_label_rows_v2(): anns_path = self._get_label_row_annotations(label_row) label_row.from_labels_dict(json.loads(anns_path.read_text(encoding="utf-8"))) @@ -111,4 +115,4 @@ def _setup( continue self._frame_paths.append(self._get_frame_path(label_row, frame_view.frame)) - self._labels.append(clf_answer.title) + self._labels.append(class_name_to_idx[clf_answer.title])