From 08562e36ab927398f2d7340cf7fe164fa94bfe3e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Frederik=20Hvilsh=C3=B8j?= <93145535+frederik-encord@users.noreply.github.com>
Date: Wed, 31 Jan 2024 11:00:00 +0100
Subject: [PATCH] fix: pytorch dataset proliferate dataset properties (#685)

---
 src/encord_active/public/dataset.py | 50 +++++++++++++++++++++--------
 1 file changed, 37 insertions(+), 13 deletions(-)

diff --git a/src/encord_active/public/dataset.py b/src/encord_active/public/dataset.py
index 3311802d5..4d6c5c05c 100644
--- a/src/encord_active/public/dataset.py
+++ b/src/encord_active/public/dataset.py
@@ -1,3 +1,4 @@
+from dataclasses import dataclass
 from pathlib import Path
 from typing import Optional, Union
 from uuid import UUID
@@ -17,6 +18,7 @@
     get_engine,
 )
 from encord_active.lib.common.data_utils import url_to_file_path
+from encord_active.lib.db.predictions import FrameClassification
 
 P = Project
 T = ProjectTaggedDataUnit
@@ -24,6 +26,12 @@
 L = ProjectDataMetadata
 
 
+@dataclass
+class DataIdentifier:
+    data_hash: UUID
+    frame: int
+
+
 class ActiveDataset(Dataset):
     def __init__(
         self,
@@ -143,6 +151,7 @@ def __init__(
        ontology_hashes: Optional[list[str]] = None,
        transform=None,
        target_transform=None,
+       return_meta: bool = False,
    ):
        """
        A dataset hooked up to an Encord Active database.
@@ -185,6 +194,7 @@ def __init__(
            transform,
            target_transform,
        )
+       self.return_meta = return_meta
 
    def setup(self):
        super().setup()
@@ -196,7 +206,7 @@ def setup(self):
            )
            identifier_query = identifier_query.add_columns(D.data_uri, D.classifications, L.label_row_json)
-           identifiers = sess.exec(identifier_query).all()
+           self.identifiers = sess.exec(identifier_query).all()
 
        ontology_pairs = [
            (c, a)
            for c in self.ontology.classifications
@@ -206,9 +216,10 @@ def setup(self):
        ]
        if len(ontology_pairs) == 0:
            raise ValueError("No ontology classifications were found to use for labels")
-       classification, attribute = ontology_pairs[0]
-       indices = {o.feature_node_hash: i for i, o in enumerate(attribute.options)}
-       self.class_names = [o.title for o in attribute.options]
+       self.classification, self.attribute = ontology_pairs[0]
+       self.option_indices = {o.feature_node_hash: i for i, o in enumerate(self.attribute.options)}
+       self._inv_option_indices = {v: k for k, v in self.option_indices.items()}
+       self.class_names = [o.title for o in self.attribute.options]
 
        self.uris = []
        self.labels = []
@@ -217,11 +228,11 @@ def setup(self):
        for (
            data_uri,
            classifications,
            label_row_json,
-       ) in identifiers:
+       ) in self.identifiers:
            classification_answers = label_row_json["classification_answers"]
            clf_instance = next(
-               (c for c in classifications if c["featureHash"] == classification.feature_node_hash),
+               (c for c in classifications if c["featureHash"] == self.classification.feature_node_hash),
                None,
            )
            if clf_instance is None:
@@ -229,20 +240,27 @@ def setup(self):
                continue
            clf_hash = clf_instance["classificationHash"]
            clf_classifications = classification_answers[clf_hash]["classifications"]
            clf_answers = next(
-               (a for a in clf_classifications if a["featureHash"] == attribute.feature_node_hash),
+               (a for a in clf_classifications if a["featureHash"] == self.attribute.feature_node_hash),
                None,
            )
            if clf_answers is None:
                continue
            clf_opt = next(
-               (o for o in clf_answers["answers"] if o["featureHash"] in indices),
+               (o for o in clf_answers["answers"] if o["featureHash"] in self.option_indices),
                None,
            )
            if clf_opt is None:
                continue
            self.uris.append(url_to_file_path(data_uri, self.root_path))  # type: ignore
-           self.labels.append(indices[clf_opt["featureHash"]])
+           self.labels.append(self.option_indices[clf_opt["featureHash"]])
+
+   def get_frame_classification(self, label_index: int) -> FrameClassification:
+       fh = self.classification.feature_node_hash
+       ah = self.attribute.feature_node_hash
+       return FrameClassification(
+           feature_hash=fh, attribute_hash=ah, option_hash=self._inv_option_indices[label_index]
+       )
 
    def __getitem__(self, idx):
        data_uri = self.uris[idx]
@@ -256,7 +274,9 @@ def __getitem__(self, idx):
        if self.target_transform:
            label = self.target_transform(label)
 
-       return img, label
+       if not self.return_meta:
+           return img, label
+       return img, label, self.identifiers[idx][:2]
 
    def __len__(self):
        return len(self.labels)
@@ -271,6 +291,7 @@ def __init__(
        ontology_hashes: Optional[list[str]] = None,
        transform=None,
        target_transform=None,
+       return_meta: bool = False,
    ):
        """
        A dataset hooked up to an Encord Active database.
@@ -307,6 +328,7 @@ def __init__(
            transform,
            target_transform,
        )
+       self.return_meta = return_meta
 
    def setup(self):
        super().setup()
@@ -317,7 +339,7 @@ def setup(self):
                in_op(L.project_hash, self.project_hash)
            )
            identifier_query = identifier_query.add_columns(D.data_uri, D.objects, L.label_row_json)
-           identifiers = sess.exec(identifier_query).all()
+           self.identifiers = sess.exec(identifier_query).all()
 
        feature_hash_to_ontology_object: dict[str, Object] = {
            o.feature_node_hash: o
@@ -349,7 +371,7 @@ def setup(self):
            data_uri,
            all_objects,
            label_row_json,
-       ) in identifiers:
+       ) in self.identifiers:
            object_hash_to_object = {
                o["objectHash"]: o for o in all_objects if o["featureHash"] in feature_hash_to_ontology_object
            }
@@ -378,7 +400,9 @@ def __getitem__(self, idx):
        if self.target_transform:
            labels = self.target_transform(labels)
 
-       return img, labels
+       if not self.return_meta:
+           return img, labels
+       return img, labels, self.identifiers[idx]
 
    def __len__(self):
        return len(self.data_unit_paths)
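
Reviewer note: a minimal sketch of how the new surface in this patch (return_meta=True, the stored self.identifiers, and get_frame_classification) could be consumed. The dataset constructor signature and the exact column order of the identifier rows are not visible in this diff, so the dataset and model arguments below, and the assumption that the first two identifier columns are (data_hash, frame), are illustrative only.

    # Illustrative sketch, not part of the patch. Assumes `dataset` is one of the
    # classification datasets from this file, built with return_meta=True and
    # already set up, and that `model` maps an image to per-class scores.
    def review_predictions(dataset, model):
        for idx in range(len(dataset)):
            # With return_meta=True, __getitem__ also returns the first two
            # columns of the identifier row (assumed here: data_hash, frame).
            img, label, (data_hash, frame) = dataset[idx]
            predicted = int(model(img).argmax())
            # Map the predicted class index back to ontology hashes.
            frame_clf = dataset.get_frame_classification(predicted)
            print(data_hash, frame, dataset.class_names[label], frame_clf)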