fix: pytorch dataset proliferate dataset properties (#685)
frederik-encord authored Jan 31, 2024
1 parent b4f46ff commit 08562e3
Showing 1 changed file with 37 additions and 13 deletions.
50 changes: 37 additions & 13 deletions src/encord_active/public/dataset.py
@@ -1,3 +1,4 @@
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Union
from uuid import UUID
@@ -17,13 +18,20 @@
get_engine,
)
from encord_active.lib.common.data_utils import url_to_file_path
from encord_active.lib.db.predictions import FrameClassification

P = Project
T = ProjectTaggedDataUnit
D = ProjectDataUnitMetadata
L = ProjectDataMetadata


@dataclass
class DataIdentifier:
data_hash: UUID
frame: int


class ActiveDataset(Dataset):
def __init__(
self,
@@ -143,6 +151,7 @@ def __init__(
ontology_hashes: Optional[list[str]] = None,
transform=None,
target_transform=None,
return_meta: bool = False,
):
"""
A dataset hooked up to an Encord Active database.
@@ -185,6 +194,7 @@ def __init__(
transform,
target_transform,
)
self.return_meta = return_meta

def setup(self):
super().setup()
@@ -196,7 +206,7 @@ def setup(self)
)
identifier_query = identifier_query.add_columns(D.data_uri, D.classifications, L.label_row_json)

identifiers = sess.exec(identifier_query).all()
self.identifiers = sess.exec(identifier_query).all()
ontology_pairs = [
(c, a)
for c in self.ontology.classifications
@@ -206,9 +216,10 @@ def setup(self)
]
if len(ontology_pairs) == 0:
raise ValueError("No ontology classifications were found to use for labels")
classification, attribute = ontology_pairs[0]
indices = {o.feature_node_hash: i for i, o in enumerate(attribute.options)}
self.class_names = [o.title for o in attribute.options]
self.classification, self.attribute = ontology_pairs[0]
self.option_indices = {o.feature_node_hash: i for i, o in enumerate(self.attribute.options)}
self._inv_option_indices = {v: k for k, v in self.option_indices.items()}
self.class_names = [o.title for o in self.attribute.options]

self.uris = []
self.labels = []
@@ -217,32 +228,39 @@ def setup(self)
data_uri,
classifications,
label_row_json,
) in identifiers:
) in self.identifiers:
classification_answers = label_row_json["classification_answers"]

clf_instance = next(
(c for c in classifications if c["featureHash"] == classification.feature_node_hash),
(c for c in classifications if c["featureHash"] == self.classification.feature_node_hash),
None,
)
if clf_instance is None:
continue
clf_hash = clf_instance["classificationHash"]
clf_classifications = classification_answers[clf_hash]["classifications"]
clf_answers = next(
(a for a in clf_classifications if a["featureHash"] == attribute.feature_node_hash),
(a for a in clf_classifications if a["featureHash"] == self.attribute.feature_node_hash),
None,
)
if clf_answers is None:
continue
clf_opt = next(
(o for o in clf_answers["answers"] if o["featureHash"] in indices),
(o for o in clf_answers["answers"] if o["featureHash"] in self.option_indices),
None,
)
if clf_opt is None:
continue

self.uris.append(url_to_file_path(data_uri, self.root_path)) # type: ignore
self.labels.append(indices[clf_opt["featureHash"]])
self.labels.append(self.option_indices[clf_opt["featureHash"]])

def get_frame_classification(self, label_index: int) -> FrameClassification:
fh = self.classification.feature_node_hash
ah = self.attribute.feature_node_hash
return FrameClassification(
feature_hash=fh, attribute_hash=ah, option_hash=self._inv_option_indices[label_index]
)

def __getitem__(self, idx):
data_uri = self.uris[idx]
@@ -256,7 +274,9 @@ def __getitem__(self, idx):
if self.target_transform:
label = self.target_transform(label)

return img, label
if not self.return_meta:
return img, label
return img, label, self.identifiers[idx][:2]

def __len__(self):
return len(self.labels)
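
With this change, the classification dataset keeps the identifier rows, the selected ontology classification and attribute, and the option-index mappings as instance properties, and can optionally return the data identifiers with each sample. Below is a minimal usage sketch; the subclass name (ActiveClassificationDataset), its leading constructor arguments (project root and project hash), and the explicit setup() call are assumptions for illustration, since they fall outside the visible hunks.

from pathlib import Path
from uuid import UUID

# Assumed import path and class name: the diff only shows a subclass of
# ActiveDataset; its actual name is outside the visible hunks.
from encord_active.public.dataset import ActiveClassificationDataset

dataset = ActiveClassificationDataset(
    Path("/path/to/encord-active/project"),        # assumption: project root
    UUID("00000000-0000-0000-0000-000000000000"),  # assumption: project hash
    return_meta=True,  # new flag introduced by this commit
)
dataset.setup()  # assumption: setup() must be called explicitly if __init__ does not do so

# With return_meta=True, __getitem__ also returns the first two identifier
# columns, presumably (data_hash, frame) to match the new DataIdentifier dataclass.
image, label, (data_hash, frame) = dataset[0]

# get_frame_classification maps the integer label back to the ontology hashes
# it was derived from.
frame_classification = dataset.get_frame_classification(label)
print(dataset.class_names[label], frame_classification)
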
@@ -271,6 +291,7 @@ def __init__(
ontology_hashes: Optional[list[str]] = None,
transform=None,
target_transform=None,
return_meta: bool = False,
):
"""
A dataset hooked up to an Encord Active database.
@@ -307,6 +328,7 @@ def __init__(
transform,
target_transform,
)
self.return_meta = return_meta

def setup(self):
super().setup()
@@ -317,7 +339,7 @@ def setup(self):
in_op(L.project_hash, self.project_hash)
)
identifier_query = identifier_query.add_columns(D.data_uri, D.objects, L.label_row_json)
identifiers = sess.exec(identifier_query).all()
self.identifiers = sess.exec(identifier_query).all()

feature_hash_to_ontology_object: dict[str, Object] = {
o.feature_node_hash: o
@@ -349,7 +371,7 @@ def setup(self):
data_uri,
all_objects,
label_row_json,
) in identifiers:
) in self.identifiers:
object_hash_to_object = {
o["objectHash"]: o for o in all_objects if o["featureHash"] in feature_hash_to_ontology_object
}
@@ -378,7 +400,9 @@ def __getitem__(self, idx):
if self.target_transform:
labels = self.target_transform(labels)

return img, labels
if not self.return_meta:
return img, labels
return img, labels, self.identifiers[idx]

def __len__(self):
return len(self.data_unit_paths)
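
The object-detection variant gains the same return_meta flag, but its meta is the full identifier row, i.e. the base identifier columns plus the data_uri, objects, and label_row_json columns added in setup(). The sketch below iterates with the flag enabled; the class name (ActiveObjectDataset) and the constructor's leading arguments are again assumptions. Because the meta rows contain UUIDs and nested JSON, the default DataLoader collation will likely reject them, so a pass-through collate_fn is used.

from pathlib import Path
from uuid import UUID

from torch.utils.data import DataLoader

# Assumed name for the object-detection subclass of ActiveDataset shown above.
from encord_active.public.dataset import ActiveObjectDataset

dataset = ActiveObjectDataset(
    Path("/path/to/encord-active/project"),        # assumption: project root
    UUID("00000000-0000-0000-0000-000000000000"),  # assumption: project hash
    return_meta=True,
)
dataset.setup()  # assumption: explicit setup, as above


def collate_keep_meta(batch):
    # Keep images, labels, and the raw identifier rows as plain lists instead of
    # stacking them, since the rows hold UUIDs and dicts that cannot be batched.
    images, labels, meta = zip(*batch)
    return list(images), list(labels), list(meta)


loader = DataLoader(dataset, batch_size=4, collate_fn=collate_keep_meta)
for images, labels, meta in loader:
    ...  # meta[i] is the identifier row for the i-th sample in the batch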
