Skip to content

Commit

Permalink
Audio and text classifications can now only be added on frame=0
Browse files Browse the repository at this point in the history
  • Loading branch information
clinton-encord committed Jan 3, 2025
1 parent 5c18d7a commit f2dd1fd
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 33 deletions.
33 changes: 27 additions & 6 deletions encord/objects/classification_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,23 @@
from encord.objects import LabelRowV2


# For Audio and Text files, classifications can only be applied to Range(start=0, end=0)
# Because we treat the entire file as being on one frame (for classifications, its different for objects)
def _verify_non_geometric_classifications_range(
ranges_to_add: Ranges,
label_row: Optional[LabelRowV2]
) -> None:
is_range_only_on_frame_0 = (len(ranges_to_add) == 1
and ranges_to_add[0].start == 0
and ranges_to_add[0].end == 0)
if (label_row is not None
and not is_geometric(label_row.data_type)
and not is_range_only_on_frame_0):
raise LabelRowError(
"For audio files and text files, classifications can only be attached to frame=0 "
"You may use `ClassificationInstance.set_for_frames(frames=Range(start=0, end=0))`.")


class ClassificationInstance:
def __init__(
self,
Expand Down Expand Up @@ -104,8 +121,9 @@ def feature_hash(self) -> str:
def _last_frame(self) -> Union[int, float]:
if self._parent is None or self._parent.data_type is DataType.DICOM:
return float("inf")
elif self._parent is not None and not is_geometric(self._parent.data_type):
# Non-geometric types have only ONE frame. Classifications only apply to the whole file.
elif self._parent is not None and self._parent.data_type == "text/html":
# For HTML files, the entire file is treated as one frame
# Note: for Audio and Plain Text, classifications must be applied to ALL the "frames"
return 1
else:
return self._parent.number_of_frames
Expand Down Expand Up @@ -134,6 +152,7 @@ def _set_for_ranges(
frames: Frames,
overwrite: bool,
created_at: Optional[datetime],

created_by: Optional[str],
confidence: float,
manual_annotation: bool,
Expand All @@ -142,15 +161,19 @@ def _set_for_ranges(
reviews: Optional[List[dict]],
):
new_range_manager = RangeManager(frame_class=frames)
conflicting_ranges = self._is_classification_already_present_on_range(new_range_manager.get_ranges())
ranges_to_add = new_range_manager.get_ranges()

_verify_non_geometric_classifications_range(ranges_to_add, self._parent)

conflicting_ranges = self._is_classification_already_present_on_range(ranges_to_add)
if conflicting_ranges and not overwrite:
raise LabelRowError(
f"The classification '{self.classification_hash}' already exists "
f"on the ranges {conflicting_ranges}. "
f"Set 'overwrite' parameter to True to override."
)

ranges_to_add = new_range_manager.get_ranges()

for range_to_add in ranges_to_add:
self._check_within_range(range_to_add.end)

Expand Down Expand Up @@ -224,8 +247,6 @@ def set_for_frames(
last_edited_at = datetime.now()

if self._range_only:
# Audio range should cover entire audio file
# Text range should always be [0, 0]
self._set_for_ranges(
frames=frames,
overwrite=overwrite,
Expand Down
14 changes: 9 additions & 5 deletions encord/objects/ontology_labels_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
BundledWorkflowReopenPayload,
)
from encord.objects.classification import Classification
from encord.objects.classification_instance import ClassificationInstance
from encord.objects.classification_instance import ClassificationInstance, _verify_non_geometric_classifications_range
from encord.objects.constants import ( # pylint: disable=unused-import # for backward compatibility
DATETIME_LONG_STRING_FORMAT,
DEFAULT_CONFIDENCE,
Expand Down Expand Up @@ -905,10 +905,10 @@ def add_classification_instance(self, classification_instance: ClassificationIns
)

"""
Implementation here diverges because audio data will operate on ranges, whereas
Implementation here diverges because non-geometric data will operate on ranges, whereas
everything else will operate on frames.
"""
if self.data_type == DataType.AUDIO:
if not is_geometric(self.data_type):
self._add_classification_instance_for_range(
classification_instance=classification_instance,
force=force,
Expand Down Expand Up @@ -937,14 +937,18 @@ def add_classification_instance(self, classification_instance: ClassificationIns
self._classifications_to_frames[classification_instance.ontology_item].update(frames)
self._add_to_frame_to_hashes_map(classification_instance, frames)

# This should only be used for Audio classification instances
# This should only be used for Non-geometric data
def _add_classification_instance_for_range(
self,
classification_instance: ClassificationInstance,
force: bool,
):
classification_hash = classification_instance.classification_hash
ranges_to_add = classification_instance.range_list

if classification_instance.is_range_only():
_verify_non_geometric_classifications_range(ranges_to_add, self)

already_present_ranges = self._is_classification_already_present_on_ranges(
classification_instance.ontology_item, ranges_to_add
)
Expand Down Expand Up @@ -977,7 +981,7 @@ def remove_classification(self, classification_instance: ClassificationInstance)
classification_hash = classification_instance.classification_hash
self._classifications_map.pop(classification_hash)

if self.data_type == DataType.AUDIO:
if not is_geometric(self.data_type):
range_manager = self._classifications_to_ranges[classification_instance.ontology_item]
ranges_to_remove = classification_instance.range_list
range_manager.remove_ranges(ranges_to_remove)
Expand Down
6 changes: 3 additions & 3 deletions tests/objects/data/audio_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
"manualAnnotation": True,
},
],
"range": [[0, 1]],
"range": [[0, 0]],
"createdBy": "user1Hash",
"createdAt": "Tue, 05 Nov 2024 09:41:37 ",
"lastEditedBy": "user1Hash",
Expand Down Expand Up @@ -71,7 +71,7 @@
"manualAnnotation": True,
},
],
"range": [[0, 1]],
"range": [[0, 0]],
"createdBy": "user1Hash",
"createdAt": "Tue, 05 Nov 2024 09:41:37 ",
"lastEditedBy": "user1Hash",
Expand Down Expand Up @@ -101,7 +101,7 @@
"manualAnnotation": True,
},
],
"range": [[0, 1]],
"range": [[0, 0]],
"createdBy": "user1Hash",
"createdAt": "Tue, 05 Nov 2024 09:41:37 ",
"lastEditedBy": "user1Hash",
Expand Down
41 changes: 22 additions & 19 deletions tests/objects/test_label_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -685,14 +685,13 @@ def test_add_and_get_classification_instances_to_audio_label_row(ontology):
overlapping_classification_instance = ClassificationInstance(text_classification, range_only=True)
overlapping_classification_instance.set_for_frames(0)

with pytest.raises(LabelRowError):
with pytest.raises(LabelRowError) as e:
label_row.add_classification_instance(overlapping_classification_instance)

overlapping_classification_instance.remove_from_frames(0)
overlapping_classification_instance.set_for_frames(5)
label_row.add_classification_instance(overlapping_classification_instance)
with pytest.raises(LabelRowError):
overlapping_classification_instance.set_for_frames(0)

assert e.value.message == (f"A ClassificationInstance '{overlapping_classification_instance.classification_hash}' was already added "
"and has overlapping frames. Overlapping frames that were "
"found are `[(0:0)]`. Make sure that you only add classifications "
"which are on frames where the same type of classification does not yet exist.")

# Do not raise if overwrite flag is passed
overlapping_classification_instance.set_for_frames(0, overwrite=True)
Expand Down Expand Up @@ -1127,21 +1126,25 @@ def test_non_range_classification_cannot_be_added_to_audio_label_row(ontology):
label_row.add_classification_instance(classification_instance)


def test_audio_classification_exceed_max_frames(ontology, empty_audio_label_row: LabelRowV2):
classification_instance = ClassificationInstance(checklist_classification, range_only=True)
classification_instance.set_for_frames(Range(start=0, end=0))
empty_audio_label_row.add_classification_instance(classification_instance)
def test_non_geometric_label_rows_can_only_have_classifications_on_frame_0(
ontology,
empty_audio_label_row: LabelRowV2,
empty_plain_text_label_row: LabelRowV2,
empty_html_text_label_row: LabelRowV2,
):
for label_row in [empty_audio_label_row, empty_html_text_label_row, empty_plain_text_label_row]:
classification_instance = ClassificationInstance(checklist_classification, range_only=True)
classification_instance.set_for_frames(Range(start=0, end=0))
label_row.add_classification_instance(classification_instance)

with pytest.raises(LabelRowError) as e:
classification_instance.set_for_frames(frames=5000)
with pytest.raises(LabelRowError) as e:
classification_instance.set_for_frames(Range(start=0, end=1))

assert e.value.message == (
"The supplied frame of `5000` is not within the acceptable bounds of `0` to `1`. "
"Note: for non-geometric data (e.g. AUDIO and PLAIN_TEXT), the entire file only has 1 frame."
)
assert e.value.message == ("For audio files and text files, classifications can only be "
"attached to frame=0 You may use "
"`ClassificationInstance.set_for_frames(frames=Range(start=0, end=0))`."
)

range_list = classification_instance.range_list
assert len(range_list) == 1


def test_audio_object_exceed_max_frames(ontology, empty_audio_label_row: LabelRowV2):
Expand Down

0 comments on commit f2dd1fd

Please sign in to comment.