diff --git a/encord/objects/classification_instance.py b/encord/objects/classification_instance.py index 850ae94a..67556c67 100644 --- a/encord/objects/classification_instance.py +++ b/encord/objects/classification_instance.py @@ -55,6 +55,23 @@ from encord.objects import LabelRowV2 +# For Audio and Text files, classifications can only be applied to Range(start=0, end=0) +# Because we treat the entire file as being on one frame (for classifications, its different for objects) +def _verify_non_geometric_classifications_range( + ranges_to_add: Ranges, + label_row: Optional[LabelRowV2] +) -> None: + is_range_only_on_frame_0 = (len(ranges_to_add) == 1 + and ranges_to_add[0].start == 0 + and ranges_to_add[0].end == 0) + if (label_row is not None + and not is_geometric(label_row.data_type) + and not is_range_only_on_frame_0): + raise LabelRowError( + "For audio files and text files, classifications can only be attached to frame=0 " + "You may use `ClassificationInstance.set_for_frames(frames=Range(start=0, end=0))`.") + + class ClassificationInstance: def __init__( self, @@ -104,8 +121,9 @@ def feature_hash(self) -> str: def _last_frame(self) -> Union[int, float]: if self._parent is None or self._parent.data_type is DataType.DICOM: return float("inf") - elif self._parent is not None and not is_geometric(self._parent.data_type): - # Non-geometric types have only ONE frame. Classifications only apply to the whole file. + elif self._parent is not None and self._parent.data_type == "text/html": + # For HTML files, the entire file is treated as one frame + # Note: for Audio and Plain Text, classifications must be applied to ALL the "frames" return 1 else: return self._parent.number_of_frames @@ -134,6 +152,7 @@ def _set_for_ranges( frames: Frames, overwrite: bool, created_at: Optional[datetime], + created_by: Optional[str], confidence: float, manual_annotation: bool, @@ -142,7 +161,11 @@ def _set_for_ranges( reviews: Optional[List[dict]], ): new_range_manager = RangeManager(frame_class=frames) - conflicting_ranges = self._is_classification_already_present_on_range(new_range_manager.get_ranges()) + ranges_to_add = new_range_manager.get_ranges() + + _verify_non_geometric_classifications_range(ranges_to_add, self._parent) + + conflicting_ranges = self._is_classification_already_present_on_range(ranges_to_add) if conflicting_ranges and not overwrite: raise LabelRowError( f"The classification '{self.classification_hash}' already exists " @@ -150,7 +173,7 @@ def _set_for_ranges( f"Set 'overwrite' parameter to True to override." ) - ranges_to_add = new_range_manager.get_ranges() + for range_to_add in ranges_to_add: self._check_within_range(range_to_add.end) @@ -224,8 +247,6 @@ def set_for_frames( last_edited_at = datetime.now() if self._range_only: - # Audio range should cover entire audio file - # Text range should always be [0, 0] self._set_for_ranges( frames=frames, overwrite=overwrite, diff --git a/encord/objects/ontology_labels_impl.py b/encord/objects/ontology_labels_impl.py index 4a82d4dc..c0b25528 100644 --- a/encord/objects/ontology_labels_impl.py +++ b/encord/objects/ontology_labels_impl.py @@ -40,7 +40,7 @@ BundledWorkflowReopenPayload, ) from encord.objects.classification import Classification -from encord.objects.classification_instance import ClassificationInstance +from encord.objects.classification_instance import ClassificationInstance, _verify_non_geometric_classifications_range from encord.objects.constants import ( # pylint: disable=unused-import # for backward compatibility DATETIME_LONG_STRING_FORMAT, DEFAULT_CONFIDENCE, @@ -905,10 +905,10 @@ def add_classification_instance(self, classification_instance: ClassificationIns ) """ - Implementation here diverges because audio data will operate on ranges, whereas + Implementation here diverges because non-geometric data will operate on ranges, whereas everything else will operate on frames. """ - if self.data_type == DataType.AUDIO: + if not is_geometric(self.data_type): self._add_classification_instance_for_range( classification_instance=classification_instance, force=force, @@ -937,7 +937,7 @@ def add_classification_instance(self, classification_instance: ClassificationIns self._classifications_to_frames[classification_instance.ontology_item].update(frames) self._add_to_frame_to_hashes_map(classification_instance, frames) - # This should only be used for Audio classification instances + # This should only be used for Non-geometric data def _add_classification_instance_for_range( self, classification_instance: ClassificationInstance, @@ -945,6 +945,10 @@ def _add_classification_instance_for_range( ): classification_hash = classification_instance.classification_hash ranges_to_add = classification_instance.range_list + + if classification_instance.is_range_only(): + _verify_non_geometric_classifications_range(ranges_to_add, self) + already_present_ranges = self._is_classification_already_present_on_ranges( classification_instance.ontology_item, ranges_to_add ) @@ -977,7 +981,7 @@ def remove_classification(self, classification_instance: ClassificationInstance) classification_hash = classification_instance.classification_hash self._classifications_map.pop(classification_hash) - if self.data_type == DataType.AUDIO: + if not is_geometric(self.data_type): range_manager = self._classifications_to_ranges[classification_instance.ontology_item] ranges_to_remove = classification_instance.range_list range_manager.remove_ranges(ranges_to_remove) diff --git a/tests/objects/data/audio_labels.py b/tests/objects/data/audio_labels.py index 5a5d398f..547054f8 100644 --- a/tests/objects/data/audio_labels.py +++ b/tests/objects/data/audio_labels.py @@ -39,7 +39,7 @@ "manualAnnotation": True, }, ], - "range": [[0, 1]], + "range": [[0, 0]], "createdBy": "user1Hash", "createdAt": "Tue, 05 Nov 2024 09:41:37 ", "lastEditedBy": "user1Hash", @@ -71,7 +71,7 @@ "manualAnnotation": True, }, ], - "range": [[0, 1]], + "range": [[0, 0]], "createdBy": "user1Hash", "createdAt": "Tue, 05 Nov 2024 09:41:37 ", "lastEditedBy": "user1Hash", @@ -101,7 +101,7 @@ "manualAnnotation": True, }, ], - "range": [[0, 1]], + "range": [[0, 0]], "createdBy": "user1Hash", "createdAt": "Tue, 05 Nov 2024 09:41:37 ", "lastEditedBy": "user1Hash", diff --git a/tests/objects/test_label_structure.py b/tests/objects/test_label_structure.py index 7b556070..9e232422 100644 --- a/tests/objects/test_label_structure.py +++ b/tests/objects/test_label_structure.py @@ -685,14 +685,13 @@ def test_add_and_get_classification_instances_to_audio_label_row(ontology): overlapping_classification_instance = ClassificationInstance(text_classification, range_only=True) overlapping_classification_instance.set_for_frames(0) - with pytest.raises(LabelRowError): + with pytest.raises(LabelRowError) as e: label_row.add_classification_instance(overlapping_classification_instance) - - overlapping_classification_instance.remove_from_frames(0) - overlapping_classification_instance.set_for_frames(5) - label_row.add_classification_instance(overlapping_classification_instance) - with pytest.raises(LabelRowError): - overlapping_classification_instance.set_for_frames(0) + + assert e.value.message == (f"A ClassificationInstance '{overlapping_classification_instance.classification_hash}' was already added " + "and has overlapping frames. Overlapping frames that were " + "found are `[(0:0)]`. Make sure that you only add classifications " + "which are on frames where the same type of classification does not yet exist.") # Do not raise if overwrite flag is passed overlapping_classification_instance.set_for_frames(0, overwrite=True) @@ -1127,21 +1126,25 @@ def test_non_range_classification_cannot_be_added_to_audio_label_row(ontology): label_row.add_classification_instance(classification_instance) -def test_audio_classification_exceed_max_frames(ontology, empty_audio_label_row: LabelRowV2): - classification_instance = ClassificationInstance(checklist_classification, range_only=True) - classification_instance.set_for_frames(Range(start=0, end=0)) - empty_audio_label_row.add_classification_instance(classification_instance) +def test_non_geometric_label_rows_can_only_have_classifications_on_frame_0( + ontology, + empty_audio_label_row: LabelRowV2, + empty_plain_text_label_row: LabelRowV2, + empty_html_text_label_row: LabelRowV2, +): + for label_row in [empty_audio_label_row, empty_html_text_label_row, empty_plain_text_label_row]: + classification_instance = ClassificationInstance(checklist_classification, range_only=True) + classification_instance.set_for_frames(Range(start=0, end=0)) + label_row.add_classification_instance(classification_instance) - with pytest.raises(LabelRowError) as e: - classification_instance.set_for_frames(frames=5000) + with pytest.raises(LabelRowError) as e: + classification_instance.set_for_frames(Range(start=0, end=1)) - assert e.value.message == ( - "The supplied frame of `5000` is not within the acceptable bounds of `0` to `1`. " - "Note: for non-geometric data (e.g. AUDIO and PLAIN_TEXT), the entire file only has 1 frame." - ) + assert e.value.message == ("For audio files and text files, classifications can only be " + "attached to frame=0 You may use " + "`ClassificationInstance.set_for_frames(frames=Range(start=0, end=0))`." + ) - range_list = classification_instance.range_list - assert len(range_list) == 1 def test_audio_object_exceed_max_frames(ontology, empty_audio_label_row: LabelRowV2):