diff --git a/encord/common/range_manager.py b/encord/common/range_manager.py index 67fde267..d060e099 100644 --- a/encord/common/range_manager.py +++ b/encord/common/range_manager.py @@ -70,6 +70,10 @@ def remove_ranges(self, ranges_to_remove: Ranges) -> None: for r in ranges_to_remove: self.remove_range(r) + def clear_ranges(self) -> None: + """Clear all ranges.""" + self.ranges = [] + def get_ranges(self) -> Ranges: """Return the sorted list of merged ranges.""" copied_ranges = [range.copy() for range in self.ranges] diff --git a/encord/constants/enums.py b/encord/constants/enums.py index d3360c4e..694ad79b 100644 --- a/encord/constants/enums.py +++ b/encord/constants/enums.py @@ -45,3 +45,17 @@ def from_upper_case_string(string: str) -> DataType: def to_upper_case_string(self) -> str: return self.value.upper() + + +GEOMETRIC_TYPES = { + DataType.VIDEO, + DataType.IMAGE, + DataType.IMG_GROUP, + DataType.DICOM, + DataType.DICOM_STUDY, + DataType.NIFTI, +} + + +def is_geometric(data_type: DataType) -> bool: + return data_type in GEOMETRIC_TYPES diff --git a/encord/objects/classification_instance.py b/encord/objects/classification_instance.py index a1cea1fa..4adff256 100644 --- a/encord/objects/classification_instance.py +++ b/encord/objects/classification_instance.py @@ -31,7 +31,7 @@ from encord.common.range_manager import RangeManager from encord.common.time_parser import parse_datetime -from encord.constants.enums import DataType +from encord.constants.enums import DataType, is_geometric from encord.exceptions import LabelRowError from encord.objects.answers import Answer, ValueType, _get_static_answer_map from encord.objects.attributes import ( @@ -55,6 +55,17 @@ from encord.objects import LabelRowV2 +# For Audio and Text files, classifications can only be applied to Range(start=0, end=0) +# Because we treat the entire file as being on one frame (for classifications, its different for objects) +def _verify_non_geometric_classifications_range(ranges_to_add: 
Ranges, label_row: Optional[LabelRowV2]) -> None: + is_range_only_on_frame_0 = len(ranges_to_add) == 1 and ranges_to_add[0].start == 0 and ranges_to_add[0].end == 0 + if label_row is not None and not is_geometric(label_row.data_type) and not is_range_only_on_frame_0: + raise LabelRowError( + "For audio files and text files, classifications can only be attached to frame=0 " + "You may use `ClassificationInstance.set_for_frames(frames=Range(start=0, end=0))`." + ) + + class ClassificationInstance: def __init__( self, @@ -104,6 +115,9 @@ def feature_hash(self) -> str: def _last_frame(self) -> Union[int, float]: if self._parent is None or self._parent.data_type is DataType.DICOM: return float("inf") + elif self._parent is not None and not is_geometric(self._parent.data_type): + # For audio and text files, the entire file is treated as one frame + return 1 else: return self._parent.number_of_frames @@ -139,7 +153,11 @@ def _set_for_ranges( reviews: Optional[List[dict]], ): new_range_manager = RangeManager(frame_class=frames) - conflicting_ranges = self._is_classification_already_present_on_range(new_range_manager.get_ranges()) + ranges_to_add = new_range_manager.get_ranges() + + _verify_non_geometric_classifications_range(ranges_to_add, self._parent) + + conflicting_ranges = self._is_classification_already_present_on_range(ranges_to_add) if conflicting_ranges and not overwrite: raise LabelRowError( f"The classification '{self.classification_hash}' already exists " @@ -147,14 +165,9 @@ def _set_for_ranges( f"Set 'overwrite' parameter to True to override." ) - ranges_to_add = new_range_manager.get_ranges() - for range_to_add in ranges_to_add: - self._check_within_range(range_to_add.end) - """ - At this point, this classification instance operates on ranges, NOT on frames. - We therefore leave only FRAME 0 in the map.The frame_data for FRAME 0 will be - treated as the data for all "frames" in this classification instance. 
+ For non-geometric files, the frame_data for FRAME 0 will be + treated as the data for the entire classification instance. """ self._set_frame_and_frame_data( frame=0, @@ -685,7 +698,9 @@ def _is_selectable_child_attribute(self, attribute: Attribute) -> bool: def _check_within_range(self, frame: int) -> None: if frame < 0 or frame >= self._last_frame: raise LabelRowError( - f"The supplied frame of `{frame}` is not within the acceptable bounds of `0` to `{self._last_frame}`." + f"The supplied frame of `{frame}` is not within the acceptable bounds of `0` to `{self._last_frame}`. " + f"Note: for non-geometric data (e.g. {DataType.AUDIO} and {DataType.PLAIN_TEXT}), " + f"the entire file has only 1 frame." ) def _is_classification_already_present(self, frames: Iterable[int]) -> Set[int]: diff --git a/encord/objects/common.py b/encord/objects/common.py index 96864517..a6d70151 100644 --- a/encord/objects/common.py +++ b/encord/objects/common.py @@ -36,6 +36,7 @@ class Shape(StringEnum): ROTATABLE_BOUNDING_BOX = "rotatable_bounding_box" BITMASK = "bitmask" AUDIO = "audio" + TEXT = "text" class DeidentifyRedactTextMode(Enum): diff --git a/encord/objects/coordinates.py b/encord/objects/coordinates.py index 52d4d88c..0516c77a 100644 --- a/encord/objects/coordinates.py +++ b/encord/objects/coordinates.py @@ -19,6 +19,8 @@ from encord.exceptions import LabelRowError from encord.objects.bitmask import BitmaskCoordinates from encord.objects.common import Shape +from encord.objects.frames import Ranges +from encord.objects.html_node import HtmlRange from encord.orm.analytics import CamelStrEnum from encord.orm.base_dto import BaseDTO @@ -339,11 +341,50 @@ def to_dict(self, by_alias=True, exclude_none=True) -> Dict[str, Any]: class AudioCoordinates(BaseDTO): - pass + """ + Represents coordinates for an audio file + + Attributes: + range (Ranges): Ranges in milliseconds for audio files + """ + + range: Ranges + + def __post_init__(self): + if len(self.range) == 0: + raise 
ValueError("Range list must contain at least one range.") + + +class TextCoordinates(BaseDTO): + """ + Represents coordinates for a text file + + Attributes: + range (Ranges): Ranges of chars for simple text files + """ + + range: Ranges + + +class HtmlCoordinates(BaseDTO): + """ + Represents coordinates for a html file + + Attributes: + range_html (List[HtmlRange]): A list of HtmlRange objects + """ + + range: List[HtmlRange] + + +NON_GEOMETRIC_COORDINATES = {AudioCoordinates, TextCoordinates, HtmlCoordinates} Coordinates = Union[ AudioCoordinates, + TextCoordinates, + Union[HtmlCoordinates, TextCoordinates], + HtmlCoordinates, BoundingBoxCoordinates, RotatableBoundingBoxCoordinates, PointCoordinate, @@ -352,13 +393,15 @@ class AudioCoordinates(BaseDTO): SkeletonCoordinates, BitmaskCoordinates, ] -ACCEPTABLE_COORDINATES_FOR_ONTOLOGY_ITEMS: Dict[Shape, Type[Coordinates]] = { - Shape.BOUNDING_BOX: BoundingBoxCoordinates, - Shape.ROTATABLE_BOUNDING_BOX: RotatableBoundingBoxCoordinates, - Shape.POINT: PointCoordinate, - Shape.POLYGON: PolygonCoordinates, - Shape.POLYLINE: PolylineCoordinates, - Shape.SKELETON: SkeletonCoordinates, - Shape.BITMASK: BitmaskCoordinates, - Shape.AUDIO: AudioCoordinates, + +ACCEPTABLE_COORDINATES_FOR_ONTOLOGY_ITEMS: Dict[Shape, List[Type[Coordinates]]] = { + Shape.BOUNDING_BOX: [BoundingBoxCoordinates], + Shape.ROTATABLE_BOUNDING_BOX: [RotatableBoundingBoxCoordinates], + Shape.POINT: [PointCoordinate], + Shape.POLYGON: [PolygonCoordinates], + Shape.POLYLINE: [PolylineCoordinates], + Shape.SKELETON: [SkeletonCoordinates], + Shape.BITMASK: [BitmaskCoordinates], + Shape.AUDIO: [AudioCoordinates], + Shape.TEXT: [TextCoordinates, HtmlCoordinates], } diff --git a/encord/objects/html_node.py b/encord/objects/html_node.py new file mode 100644 index 00000000..0bfa43ce --- /dev/null +++ b/encord/objects/html_node.py @@ -0,0 +1,69 @@ +""" +--- +title: "Objects - HTML Node" +slug: "sdk-ref-objects-html-node" +hidden: false +metadata: + title: "Objects 
- HTML Node" + description: "Encord SDK Objects - HTML Node." +category: "64e481b57b6027003f20aaa0" +--- +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Collection, List, Union, cast + +from encord.orm.base_dto import BaseDTO + + +class HtmlNode(BaseDTO): + """ + A class representing a single HTML node, with the node and offset. + + Attributes: + node (str): The xpath of the node + offset (int): The offset of the content from the xpath + """ + + node: str + offset: int + + def __repr__(self): + return f"(Node: {self.node} Offset: {self.offset})" + + +class HtmlRange(BaseDTO): + """ + A class representing a section of HTML with a start and end node. + + Attributes: + start (HtmlNode): The starting node of the range. + end (HtmlNode): The ending node of the range. + """ + + start: HtmlNode + end: HtmlNode + + def __repr__(self): + return f"({self.start} - {self.end})" + + def to_dict(self): + return { + "start": {"node": self.start.node, "offset": self.start.offset}, + "end": {"node": self.end.node, "offset": self.end.offset}, + } + + def __hash__(self): + return hash((self.start.node, self.start.offset, self.end.node, self.end.offset)) + + @classmethod + def from_dict(cls, d: dict): + return HtmlRange( + start=HtmlNode(node=d["start"]["node"], offset=d["start"]["offset"]), + end=HtmlNode(node=d["end"]["node"], offset=d["end"]["offset"]), + ) + + +HtmlRanges = List[HtmlRange] diff --git a/encord/objects/ontology_labels_impl.py b/encord/objects/ontology_labels_impl.py index 65f1c09c..dd317e16 100644 --- a/encord/objects/ontology_labels_impl.py +++ b/encord/objects/ontology_labels_impl.py @@ -22,7 +22,7 @@ from encord.client import EncordClientProject from encord.client import LabelRow as OrmLabelRow from encord.common.range_manager import RangeManager -from encord.constants.enums import DataType +from encord.constants.enums import DataType, is_geometric from encord.exceptions import LabelRowError, 
WrongProjectTypeError from encord.http.bundle import Bundle, BundleResultHandler, BundleResultMapper, bundled_operation from encord.http.limits import ( @@ -40,7 +40,7 @@ BundledWorkflowReopenPayload, ) from encord.objects.classification import Classification -from encord.objects.classification_instance import ClassificationInstance +from encord.objects.classification_instance import ClassificationInstance, _verify_non_geometric_classifications_range from encord.objects.constants import ( # pylint: disable=unused-import # for backward compatibility DATETIME_LONG_STRING_FORMAT, DEFAULT_CONFIDENCE, @@ -51,14 +51,17 @@ BitmaskCoordinates, BoundingBoxCoordinates, Coordinates, + HtmlCoordinates, PointCoordinate, PolygonCoordinates, PolylineCoordinates, RotatableBoundingBoxCoordinates, SkeletonCoordinates, + TextCoordinates, Visibility, ) from encord.objects.frames import Frames, Range, Ranges, frames_class_to_frames_list, frames_to_ranges +from encord.objects.html_node import HtmlRange from encord.objects.metadata import DICOMSeriesMetadata, DICOMSliceMetadata from encord.objects.ontology_object import Object from encord.objects.ontology_object_instance import ObjectInstance @@ -829,6 +832,19 @@ def add_object_instance(self, object_instance: ObjectInstance, force: bool = Tru object_instance.is_valid() + # We want to ensure that we are only adding the object_instance to a label_row + # IF AND ONLY IF the file type is text/html and the object_instance has range_html set + if self.file_type == "text/html" and object_instance.range_html is None: + raise LabelRowError( + "Unable to assign object instance without a html range to a html file. " + f"Please ensure the object instance exists on frame=0, and has coordinates of type {HtmlCoordinates}." + ) + elif self.file_type != "text/html" and object_instance.range_html is not None: + raise LabelRowError( + "Unable to assign object instance with a html range to a non-html file. 
" + f"Please ensure the object instance does not have coordinates of type {HtmlCoordinates}." + ) + if object_instance.is_assigned_to_label_row(): raise LabelRowError( "The supplied ObjectInstance is already part of a LabelRowV2. You can only add a ObjectInstance to one " @@ -847,9 +863,8 @@ def add_object_instance(self, object_instance: ObjectInstance, force: bool = Tru self._objects_map[object_hash] = object_instance object_instance._parent = self - if not object_instance.is_range_only(): - frames = set(_frame_views_to_frame_numbers(object_instance.get_annotations())) - self._add_to_frame_to_hashes_map(object_instance, frames) + frames = set(_frame_views_to_frame_numbers(object_instance.get_annotations())) + self._add_to_frame_to_hashes_map(object_instance, frames) def add_classification_instance(self, classification_instance: ClassificationInstance, force: bool = False) -> None: """ @@ -867,9 +882,9 @@ def add_classification_instance(self, classification_instance: ClassificationIns classification_instance.is_valid() # TODO: Need to update the docstring for this method, talk to Laverne. - if not classification_instance.is_range_only() and self.data_type == DataType.AUDIO: + if not classification_instance.is_range_only() and not is_geometric(self.data_type): raise LabelRowError( - "To add a ClassificationInstance object to an Audio LabelRow," + f"To add a ClassificationInstance object to a label row where data_type = {self.data_type}," "the ClassificationInstance object needs to be created with the " "range_only property set to True." "You can do ClassificationInstance(range_only=True) or " @@ -891,10 +906,10 @@ def add_classification_instance(self, classification_instance: ClassificationIns ) """ - Implementation here diverges because audio data will operate on ranges, whereas + Implementation here diverges because non-geometric data will operate on ranges, whereas everything else will operate on frames. 
""" - if self.data_type == DataType.AUDIO: + if not is_geometric(self.data_type): self._add_classification_instance_for_range( classification_instance=classification_instance, force=force, @@ -923,7 +938,7 @@ def add_classification_instance(self, classification_instance: ClassificationIns self._classifications_to_frames[classification_instance.ontology_item].update(frames) self._add_to_frame_to_hashes_map(classification_instance, frames) - # This should only be used for Audio classification instances + # This should only be used for Non-geometric data def _add_classification_instance_for_range( self, classification_instance: ClassificationInstance, @@ -931,6 +946,10 @@ def _add_classification_instance_for_range( ): classification_hash = classification_instance.classification_hash ranges_to_add = classification_instance.range_list + + if classification_instance.is_range_only(): + _verify_non_geometric_classifications_range(ranges_to_add, self) + already_present_ranges = self._is_classification_already_present_on_ranges( classification_instance.ontology_item, ranges_to_add ) @@ -963,7 +982,7 @@ def remove_classification(self, classification_instance: ClassificationInstance) classification_hash = classification_instance.classification_hash self._classifications_map.pop(classification_hash) - if self.data_type == DataType.AUDIO: + if not is_geometric(self.data_type): range_manager = self._classifications_to_ranges[classification_instance.ontology_item] ranges_to_remove = classification_instance.range_list range_manager.remove_ranges(ranges_to_remove) @@ -1631,19 +1650,30 @@ def _to_object_answers(self) -> Dict[str, Any]: } # At some point, we also want to add these to the other modalities - if self.data_type == DataType.AUDIO: - annotation = obj.get_annotations()[0] - ret[obj.object_hash]["range"] = [[range.start, range.end] for range in obj.range_list] - ret[obj.object_hash]["createdBy"] = annotation.created_by - ret[obj.object_hash]["createdAt"] = 
annotation.created_at.strftime(DATETIME_LONG_STRING_FORMAT) - ret[obj.object_hash]["lastEditedBy"] = annotation.last_edited_by - ret[obj.object_hash]["lastEditedAt"] = annotation.last_edited_at.strftime(DATETIME_LONG_STRING_FORMAT) - ret[obj.object_hash]["manualAnnotation"] = annotation.manual_annotation - ret[obj.object_hash]["featureHash"] = obj.feature_hash - ret[obj.object_hash]["name"] = obj.ontology_item.name - ret[obj.object_hash]["color"] = obj.ontology_item.color - ret[obj.object_hash]["shape"] = obj.ontology_item.shape.value - ret[obj.object_hash]["value"] = _lower_snake_case(obj.ontology_item.name) + if not is_geometric(self.data_type): + # For non-frame entities, all annotations exist only on one frame + annotation = obj.get_annotation(0) + object_answer_dict = ret[obj.object_hash] + object_answer_dict["createdBy"] = annotation.created_by + object_answer_dict["createdAt"] = annotation.created_at.strftime(DATETIME_LONG_STRING_FORMAT) + object_answer_dict["lastEditedBy"] = annotation.last_edited_by + object_answer_dict["lastEditedAt"] = annotation.last_edited_at.strftime(DATETIME_LONG_STRING_FORMAT) + object_answer_dict["manualAnnotation"] = annotation.manual_annotation + object_answer_dict["featureHash"] = obj.feature_hash + object_answer_dict["name"] = obj.ontology_item.name + object_answer_dict["color"] = obj.ontology_item.color + object_answer_dict["shape"] = obj.ontology_item.shape.value + object_answer_dict["value"] = _lower_snake_case(obj.ontology_item.name) + + if self.file_type == "text/html": + if obj.range_html is None: + raise LabelRowError("Html annotations should have range_html set within the TextCoordinates") + object_answer_dict["range_html"] = [x.to_dict() for x in obj.range_html] + object_answer_dict["range"] = [] + else: + if obj.range_list is None: + raise LabelRowError("Non-geometric annotations should have range set within the Coordinates") + object_answer_dict["range"] = [[range.start, range.end] for range in obj.range_list] return 
ret @@ -1671,11 +1701,11 @@ def _to_classification_answers(self) -> Dict[str, Any]: } # At some point, we also want to add these to the other modalities - if self.data_type == DataType.AUDIO: + if not is_geometric(self.data_type): annotation = classification.get_annotations()[0] - ret[classification.classification_hash]["range"] = [ - [range.start, range.end] for range in classification.range_list - ] + + # For non-geometric data, classifications apply to whole file + ret[classification.classification_hash]["range"] = [] ret[classification.classification_hash]["createdBy"] = annotation.created_by ret[classification.classification_hash]["createdAt"] = annotation.created_at.strftime( DATETIME_LONG_STRING_FORMAT @@ -1745,24 +1775,40 @@ def _to_encord_data_unit(self, frame_level_data: FrameLevelImageGroupData) -> Di ret["data_link"] = frame_level_data.data_link ret["data_type"] = frame_level_data.file_type - ret["data_sequence"] = data_sequence - if self.data_type != DataType.AUDIO: - ret["width"] = frame_level_data.width - ret["height"] = frame_level_data.height - - else: + if self.data_type == DataType.AUDIO: ret["audio_codec"] = self._label_row_read_only_data.audio_codec ret["audio_sample_rate"] = self._label_row_read_only_data.audio_sample_rate ret["audio_bit_depth"] = self._label_row_read_only_data.audio_bit_depth ret["audio_num_channels"] = self._label_row_read_only_data.audio_num_channels + elif self.data_type == DataType.PLAIN_TEXT or self.data_type == DataType.PDF: + pass + elif ( + self.data_type == DataType.IMAGE + or self.data_type == DataType.NIFTI + or self.data_type == DataType.VIDEO + or self.data_type == DataType.IMG_GROUP + or self.data_type == DataType.DICOM + or self.data_type == DataType.DICOM_STUDY + ): + ret["width"] = frame_level_data.width + ret["height"] = frame_level_data.height + elif self.data_type == DataType.MISSING_DATA_TYPE: + raise LabelRowError("Label row is missing data type.") + else: + exhaustive_guard(self.data_type) ret["labels"] 
= self._to_encord_labels(frame_level_data) - if self._label_row_read_only_data.duration is not None: + if self._label_row_read_only_data.duration is not None and self.data_type != DataType.PLAIN_TEXT: ret["data_duration"] = self._label_row_read_only_data.duration - if self._label_row_read_only_data.fps is not None and self.data_type != DataType.AUDIO: + + if ( + self._label_row_read_only_data.fps is not None + and self.data_type != DataType.AUDIO + and self.data_type != DataType.PLAIN_TEXT + ): ret["data_fps"] = self._label_row_read_only_data.fps return ret @@ -1998,6 +2044,7 @@ def _parse_label_row_dict(self, label_row_dict: dict) -> LabelRowReadOnlyData: if data_type == DataType.VIDEO or data_type == DataType.IMAGE: data_dict = list(label_row_dict["data_units"].values())[0] data_link = data_dict["data_link"] + file_type = data_dict.get("data_type") # Dimensions should be always there # But we have some older entries that don't have them # So setting them to None for now until the format is not guaranteed to be enforced @@ -2007,12 +2054,15 @@ def _parse_label_row_dict(self, label_row_dict: dict) -> LabelRowReadOnlyData: elif data_type == DataType.DICOM or data_type == DataType.NIFTI: dicom_dict = list(label_row_dict["data_units"].values())[0] data_link = None + file_type = dicom_dict["data_type"] + height = dicom_dict["height"] width = dicom_dict["width"] elif data_type == DataType.AUDIO: data_dict = list(label_row_dict["data_units"].values())[0] data_link = data_dict["data_link"] + file_type = data_dict.get("data_type") height = None width = None audio_codec = data_dict["audio_codec"] @@ -2020,8 +2070,16 @@ def _parse_label_row_dict(self, label_row_dict: dict) -> LabelRowReadOnlyData: audio_num_channels = data_dict["audio_num_channels"] audio_bit_depth = data_dict["audio_bit_depth"] + elif data_type == DataType.PLAIN_TEXT: + data_dict = list(label_row_dict["data_units"].values())[0] + data_link = data_dict["data_link"] + file_type = data_dict.get("data_type") + 
height = None + width = None + elif data_type == DataType.IMG_GROUP: data_link = None + file_type = None height = None width = None @@ -2031,6 +2089,7 @@ def _parse_label_row_dict(self, label_row_dict: dict) -> LabelRowReadOnlyData: elif data_type == DataType.PLAIN_TEXT or data_type == DataType.PDF: data_dict = list(label_row_dict["data_units"].values())[0] data_link = data_dict["data_link"] + file_type = data_dict.get("data_type") height = None width = None @@ -2072,7 +2131,7 @@ def _parse_label_row_dict(self, label_row_dict: dict) -> LabelRowReadOnlyData: priority=label_row_dict.get("priority", self._label_row_read_only_data.priority), client_metadata=label_row_dict.get("client_metadata", self._label_row_read_only_data.client_metadata), images_data=label_row_dict.get("images_data", self._label_row_read_only_data.images_data), - file_type=label_row_dict.get("file_type", None), + file_type=file_type, is_valid=bool(label_row_dict.get("is_valid", True)), backing_item_uuid=self.backing_item_uuid, ) @@ -2111,7 +2170,8 @@ def _parse_labels_from_dict(self, label_row_dict: dict): raise NotImplementedError(f"Got an unexpected data type `{data_type}`") elif data_type == DataType.AUDIO or data_type == DataType.PDF or data_type == DataType.PLAIN_TEXT: - self._add_objects_instances_from_objects_without_frames(object_answers) + is_html = data_unit["data_type"] == "text/html" + self._add_objects_instances_from_objects_without_frames(object_answers, html=is_html) self._add_classification_instances_from_classifications_without_frames(classification_answers) else: @@ -2142,14 +2202,21 @@ def _add_data_unit_metadata(self, data_type: DataType, metadata: Optional[Dict[s def _add_objects_instances_from_objects_without_frames( self, object_answers: dict, + html: bool = False, ): - for object_answer in object_answers.values(): - ranges: Ranges = [] - for range_elem in object_answer["range"]: - ranges.append(Range(range_elem[0], range_elem[1])) + if html: + for object_answer in 
object_answers.values(): + object_instance = self._create_new_html_object_instance(object_answer, object_answer["range_html"]) + self.add_object_instance(object_instance) + + else: + for object_answer in object_answers.values(): + ranges: Ranges = [] + for range_elem in object_answer["range"]: + ranges.append(Range(range_elem[0], range_elem[1])) - object_instance = self._create_new_object_instance_with_ranges(object_answer, ranges) - self.add_object_instance(object_instance) + object_instance = self._create_new_object_instance_with_ranges(object_answer, ranges) + self.add_object_instance(object_instance) def _add_object_instances_from_objects( self, @@ -2223,17 +2290,56 @@ def _create_new_object_instance_with_ranges( object_frame_instance_info = ObjectInstance.FrameInfo.from_dict(frame_info_dict) expected_shape: Shape + coordinates: Union[AudioCoordinates, TextCoordinates] if self._label_row_read_only_data.data_type == DataType.AUDIO: expected_shape = Shape.AUDIO + coordinates = AudioCoordinates(range=ranges) + elif self._label_row_read_only_data.data_type == DataType.PLAIN_TEXT: + expected_shape = Shape.TEXT + coordinates = TextCoordinates(range=ranges) else: unknown_data_type = self._label_row_read_only_data.data_type raise RuntimeError(f"Unexpected data type[{unknown_data_type}] for range based objects") if label_class.shape != expected_shape: raise LabelRowError("Unsupported object shape for data type") object_instance = ObjectInstance(label_class, object_hash=object_hash) + + object_instance.set_for_frames( + coordinates, + frames=0, + created_at=object_frame_instance_info.created_at, + created_by=object_frame_instance_info.created_by, + confidence=object_frame_instance_info.confidence, + manual_annotation=object_frame_instance_info.manual_annotation, + last_edited_at=object_frame_instance_info.last_edited_at, + last_edited_by=object_frame_instance_info.last_edited_by, + reviews=object_frame_instance_info.reviews, + overwrite=True, + # Always overwrite during 
label row dict parsing, as older dicts known to have duplicates + ) + answer_list = object_answer["classifications"] + object_instance.set_answer_from_list(answer_list) + + return object_instance + + def _create_new_html_object_instance( + self, + object_answer: dict, + range_html: dict, + ) -> ObjectInstance: + feature_hash = object_answer["featureHash"] + object_hash = object_answer["objectHash"] + + label_class = self._ontology.structure.get_child_by_hash(feature_hash, type_=Object) + + frame_info_dict = {k: v for k, v in object_answer.items() if v is not None} + frame_info_dict.setdefault("confidence", 1.0) # confidence sometimes not present. + object_frame_instance_info = ObjectInstance.FrameInfo.from_dict(frame_info_dict) + + object_instance = ObjectInstance(label_class, object_hash=object_hash) object_instance.set_for_frames( - AudioCoordinates(), - ranges, + HtmlCoordinates(range=[HtmlRange.from_dict(x) for x in range_html]), + frames=0, created_at=object_frame_instance_info.created_at, created_by=object_frame_instance_info.created_by, confidence=object_frame_instance_info.confidence, @@ -2323,13 +2429,7 @@ def _add_classification_instances_from_classifications_without_frames( classification_answers: dict, ): for classification_answer in classification_answers.values(): - ranges: Ranges = [] - for range_elem in classification_answer["range"]: - ranges.append(Range(range_elem[0], range_elem[1])) - - classification_instance = self._create_new_classification_instance_with_ranges( - classification_answer, ranges - ) + classification_instance = self._create_new_classification_instance_with_ranges(classification_answer) self.add_classification_instance(classification_instance) def _parse_image_group_frame_level_data(self, label_row_data_units: dict) -> Dict[int, FrameLevelImageGroupData]: @@ -2386,9 +2486,7 @@ def _create_new_classification_instance( return None # This is only to be used by non-frame modalities (e.g. 
Audio) - def _create_new_classification_instance_with_ranges( - self, classification_answer: dict, ranges: Ranges - ) -> ClassificationInstance: + def _create_new_classification_instance_with_ranges(self, classification_answer: dict) -> ClassificationInstance: feature_hash = classification_answer["featureHash"] classification_hash = classification_answer["classificationHash"] @@ -2399,8 +2497,10 @@ def _create_new_classification_instance_with_ranges( classification_instance = ClassificationInstance( label_class, classification_hash=classification_hash, range_only=True ) + + # For non-geometric data, the classification will always be treated as being on frame=0, + # which is the entire file classification_instance.set_for_frames( - ranges, created_at=range_view.created_at, created_by=range_view.created_by, confidence=range_view.confidence, diff --git a/encord/objects/ontology_object_instance.py b/encord/objects/ontology_object_instance.py index c835978b..ad8e3ad0 100644 --- a/encord/objects/ontology_object_instance.py +++ b/encord/objects/ontology_object_instance.py @@ -43,8 +43,11 @@ from encord.objects.constants import DEFAULT_CONFIDENCE, DEFAULT_MANUAL_ANNOTATION from encord.objects.coordinates import ( ACCEPTABLE_COORDINATES_FOR_ONTOLOGY_ITEMS, + NON_GEOMETRIC_COORDINATES, AudioCoordinates, Coordinates, + HtmlCoordinates, + TextCoordinates, ) from encord.objects.frames import ( Frames, @@ -54,6 +57,7 @@ frames_to_ranges, ranges_list_to_ranges, ) +from encord.objects.html_node import HtmlRange, HtmlRanges from encord.objects.internal_helpers import ( _infer_attribute_from_answer, _search_child_attributes, @@ -82,10 +86,8 @@ def __init__(self, ontology_object: Object, *, object_hash: Optional[str] = None self._dynamic_answer_manager = DynamicAnswerManager(self) # Only used for non-frame entities - self._range_only = ontology_object.shape in (Shape.AUDIO,) - self._range_manager: RangeManager = RangeManager() + self._non_geometric = ontology_object.shape in 
(Shape.AUDIO, Shape.TEXT) - # Only used for frame entities self._frames_to_instance_data: Dict[int, ObjectInstance.FrameData] = {} def is_assigned_to_label_row(self) -> Optional[LabelRowV2]: @@ -145,9 +147,19 @@ def _last_frame(self) -> Union[int, float]: return self._parent.number_of_frames @property - def range_list(self) -> Ranges: - if self._range_only: - return self._range_manager.get_ranges() + def range_list(self) -> Ranges | None: + if self._non_geometric: + non_geometric_annotation = self._get_non_geometric_annotation() + if non_geometric_annotation is None: + return None + + coordinates = non_geometric_annotation.coordinates + + if isinstance(coordinates, (AudioCoordinates, TextCoordinates)): + return coordinates.range + else: + return None + else: raise LabelRowError( "No ranges available for this object instance." @@ -156,8 +168,24 @@ def range_list(self) -> Ranges: "You can do ObjectInstance(audio_ontology_object) to achieve this." ) + @property + def range_html(self) -> Optional[HtmlRanges]: + if not self._non_geometric: + return None + + non_geometric_annotation = self._get_non_geometric_annotation() + if non_geometric_annotation is None: + return None + + coordinates = non_geometric_annotation.coordinates + + if isinstance(coordinates, HtmlCoordinates): + return coordinates.range + else: + return None + def is_range_only(self) -> bool: - return self._range_only + return self._non_geometric def get_answer( self, @@ -414,16 +442,6 @@ def check_within_range(self, frame: int) -> None: f"The supplied frame of `{frame}` is not within the acceptable bounds of `0` to `{self._last_frame}`." 
) - def _set_for_ranges( - self, - frames: Frames, - ) -> None: - new_range_manager = RangeManager(frame_class=frames) - ranges_to_add = new_range_manager.get_ranges() - for range_to_add in ranges_to_add: - self.check_within_range(range_to_add.end) - self._range_manager.add_ranges(ranges_to_add) - def set_for_frames( self, coordinates: Coordinates, @@ -465,73 +483,62 @@ def set_for_frames( reviews: Should only be set by internal functions. is_deleted: Should only be set by internal functions. """ - if self._range_only: - if not isinstance(coordinates, AudioCoordinates): - raise LabelRowError("Expecting range only coordinate type") - existing_frame_data = self._frames_to_instance_data.get(0) - if overwrite is False and existing_frame_data is not None and self._range_manager.intersection(frames): + if self._non_geometric: + if not isinstance(coordinates, tuple(NON_GEOMETRIC_COORDINATES)): + raise LabelRowError("Expecting non-geometric coordinate type") + + elif frames != 0: + raise LabelRowError( + f"For objects with a non-geometric shape (e.g. {Shape.TEXT} and {Shape.AUDIO}), " + f"There is only one frame. Please ensure `set_for_frames` is called with `frames=0`." + ) + + frames_list = frames_class_to_frames_list(frames) + + for frame in frames_list: + existing_frame_data = self._frames_to_instance_data.get(frame) + + if overwrite is False and existing_frame_data is not None: raise LabelRowError( "Cannot overwrite existing data for a frame. Set `overwrite` to `True` to overwrite." 
) - self._set_for_ranges( - frames=frames, - ) + check_coordinate_type(coordinates, self._ontology_object, self._parent) + + if isinstance(coordinates, (TextCoordinates, AudioCoordinates)): + for non_geometric_range in coordinates.range: + self.check_within_range(non_geometric_range.end) + else: + self.check_within_range(frame) if existing_frame_data is None: existing_frame_data = ObjectInstance.FrameData( coordinates=coordinates, object_frame_instance_info=ObjectInstance.FrameInfo() ) - self._frames_to_instance_data[0] = existing_frame_data - - existing_frame_data.object_frame_instance_info.update_from_optional_fields( - created_at=created_at, - created_by=created_by, - last_edited_at=last_edited_at, - last_edited_by=last_edited_by, - confidence=confidence, - manual_annotation=manual_annotation, - reviews=reviews, - is_deleted=is_deleted, - ) - - else: - if isinstance(coordinates, AudioCoordinates): - raise LabelRowError("Cannot add audio coordinates to object with frames") - frames_list = frames_class_to_frames_list(frames) - - for frame in frames_list: - existing_frame_data = self._frames_to_instance_data.get(frame) + self._frames_to_instance_data[frame] = existing_frame_data - if overwrite is False and existing_frame_data is not None: - raise LabelRowError( - "Cannot overwrite existing data for a frame. Set `overwrite` to `True` to overwrite." 
- ) - - check_coordinate_type(coordinates, self._ontology_object) - self.check_within_range(frame) + existing_frame_data.object_frame_instance_info.update_from_optional_fields( + created_at=created_at, + created_by=created_by, + last_edited_at=last_edited_at, + last_edited_by=last_edited_by, + confidence=confidence, + manual_annotation=manual_annotation, + reviews=reviews, + is_deleted=is_deleted, + ) + existing_frame_data.coordinates = coordinates - if existing_frame_data is None: - existing_frame_data = ObjectInstance.FrameData( - coordinates=coordinates, object_frame_instance_info=ObjectInstance.FrameInfo() - ) - self._frames_to_instance_data[frame] = existing_frame_data - - existing_frame_data.object_frame_instance_info.update_from_optional_fields( - created_at=created_at, - created_by=created_by, - last_edited_at=last_edited_at, - last_edited_by=last_edited_by, - confidence=confidence, - manual_annotation=manual_annotation, - reviews=reviews, - is_deleted=is_deleted, - ) - existing_frame_data.coordinates = coordinates + if self._parent: + self._parent.add_to_single_frame_to_hashes_map(self, frame) - if self._parent: - self._parent.add_to_single_frame_to_hashes_map(self, frame) + def _get_non_geometric_annotation(self) -> Optional[Annotation]: + # Non-geometric annotations (e.g. Audio and Text) only have one frame. + if 0 not in self._frames_to_instance_data: + return None + else: + return self.get_annotation(0) def get_annotation(self, frame: Union[int, str] = 0) -> Annotation: """ @@ -547,7 +554,7 @@ def get_annotation(self, frame: Union[int, str] = 0) -> Annotation: Raises: LabelRowError: If the frame is not present in the label row. """ - if self._range_only and frame != 0: + if self._non_geometric and frame != 0: raise LabelRowError( 'This annotation data for this object instance is stored on only one "frame". ' "Use `get_annotation(0)` to get the frame data of the first frame." 
@@ -598,10 +605,8 @@ def get_annotation_frames(self) -> set[int]: Returns: List[Annotation]: A list of `ObjectInstance.Annotation` in order of available frames. """ - if self._range_only: - return self._range_manager.get_ranges_as_frames() - else: - return {self.get_annotation(frame_num).frame for frame_num in sorted(self._frames_to_instance_data.keys())} + + return {self.get_annotation(frame_num).frame for frame_num in sorted(self._frames_to_instance_data.keys())} def remove_from_frames(self, frames: Frames) -> None: """ @@ -610,19 +615,18 @@ def remove_from_frames(self, frames: Frames) -> None: Args: frames: The frames from which to remove the object instance. """ - if self._range_only: - new_range_manager = RangeManager(frame_class=frames) - ranges_to_add = new_range_manager.get_ranges() - for range_to_add in ranges_to_add: - self.check_within_range(range_to_add.end) - self._range_manager.remove_ranges(ranges_to_add) - else: - frames_list = frames_class_to_frames_list(frames) - for frame in frames_list: - self._frames_to_instance_data.pop(frame) + if self._non_geometric and frames != 0: + raise LabelRowError( + f"For objects with a non-geometric shape (e.g. {Shape.TEXT} and {Shape.AUDIO}), " + f"There is only one frame. Please ensure `remove_from_frames` is called with `frames=0`." + ) - if self._parent: - self._parent._remove_from_frame_to_hashes_map(frames_list, self.object_hash) + frames_list = frames_class_to_frames_list(frames) + for frame in frames_list: + self._frames_to_instance_data.pop(frame) + + if self._parent: + self._parent._remove_from_frame_to_hashes_map(frames_list, self.object_hash) def is_valid(self) -> None: """ @@ -631,12 +635,8 @@ def is_valid(self) -> None: Raises: LabelRowError: If the ObjectInstance is not on any frames. """ - if self._range_only: - if len(self._range_manager.get_ranges()) == 0: - raise LabelRowError("ObjectInstance is not on any frames. 
Please add it to at least one frame.") - else: - if len(self._frames_to_instance_data) == 0: - raise LabelRowError("ObjectInstance is not on any frames. Please add it to at least one frame.") + if len(self._frames_to_instance_data) == 0: + raise LabelRowError("ObjectInstance is not on any frames. Please add it to at least one frame.") self.are_dynamic_answers_valid() @@ -822,7 +822,7 @@ def from_dict(d: dict) -> ObjectInstance.FrameInfo: last_edited_at=last_edited_at, last_edited_by=d.get("lastEditedBy"), confidence=d["confidence"], - manual_annotation=d["manualAnnotation"], + manual_annotation=d.get("manualAnnotation", True), reviews=d.get("reviews"), is_deleted=d.get("isDeleted"), ) @@ -949,23 +949,38 @@ def __lt__(self, other: ObjectInstance) -> bool: return self._object_hash < other._object_hash -def check_coordinate_type(coordinates: Coordinates, ontology_object: Object) -> None: +def check_coordinate_type(coordinates: Coordinates, ontology_object: Object, parent: Optional[LabelRowV2]) -> None: """ Check if the coordinate type matches the expected type for the ontology object. Args: coordinates (Coordinates): The coordinates to check. ontology_object (Object): The ontology object to check against. + parent (LabelRowV2): The parent label row (if any) of the ontology object. Raises: LabelRowError: If the coordinate type does not match the expected type. """ - expected_coordinate_type = ACCEPTABLE_COORDINATES_FOR_ONTOLOGY_ITEMS[ontology_object.shape] - if not isinstance(coordinates, expected_coordinate_type): + expected_coordinate_types = ACCEPTABLE_COORDINATES_FOR_ONTOLOGY_ITEMS[ontology_object.shape] + if all( + not isinstance(coordinates, expected_coordinate_type) for expected_coordinate_type in expected_coordinate_types + ): raise LabelRowError( - f"Expected a coordinate of type `{expected_coordinate_type}`, but got type `{type(coordinates)}`." 
+ f"Expected coordinates of one of the following types: `{expected_coordinate_types}`, but got type `{type(coordinates)}`." ) + # An ontology object with `Text` shape can have both coordinates `HtmlCoordinates` and `TextCoordinates` + # Therefore, we need to further check the file type, to ensure that `HtmlCoordinates` are only used for + # HTML files, and `TextCoordinates` are only used for plain text files. + if isinstance(coordinates, TextCoordinates): + if parent is not None and parent == "text/html": + raise LabelRowError(f"Expected coordinates of type {HtmlCoordinates}`, but got type `{type(coordinates)}`.") + elif isinstance(coordinates, HtmlCoordinates): + if parent is not None and parent.file_type != "text/html": + raise LabelRowError( + "For non-html labels, ensure the `range` property " "is set when instantiating the TextCoordinates." + ) + class DynamicAnswerManager: """ diff --git a/tests/objects/data/all_ontology_types.py b/tests/objects/data/all_ontology_types.py index 799e709e..3df0f98c 100644 --- a/tests/objects/data/all_ontology_types.py +++ b/tests/objects/data/all_ontology_types.py @@ -103,6 +103,13 @@ {"id": "9", "name": "audio 2", "color": "#A4DD00", "shape": "audio", "featureNodeHash": "VDeQk05m"}, {"id": "10", "name": "audio 3", "color": "#A4DD00", "shape": "audio", "featureNodeHash": "bjvtzFgi"}, {"id": "11", "name": "audio 4", "color": "#A4DD00", "shape": "audio", "featureNodeHash": "3X3+Ydcy"}, + { + "id": "12", + "name": "text object", + "color": "#A4DD00", + "shape": "text", + "featureNodeHash": "textObjectFeatureNodeHash", + }, ], "classifications": [ { diff --git a/tests/objects/data/all_types_ontology_structure.py b/tests/objects/data/all_types_ontology_structure.py index 3bc5036c..200d1cca 100644 --- a/tests/objects/data/all_types_ontology_structure.py +++ b/tests/objects/data/all_types_ontology_structure.py @@ -208,6 +208,14 @@ feature_node_hash="KVfzNkFy", attributes=[], ), + Object( + uid=9, + name="text object", + 
color="#A4FF00", + shape=Shape.TEXT, + feature_node_hash="textFeatureNodeHash", + attributes=[], + ), ], classifications=[ Classification( diff --git a/tests/objects/data/audio_labels.py b/tests/objects/data/audio_labels.py index 5a5d398f..892f74f7 100644 --- a/tests/objects/data/audio_labels.py +++ b/tests/objects/data/audio_labels.py @@ -39,7 +39,7 @@ "manualAnnotation": True, }, ], - "range": [[0, 1]], + "range": [], "createdBy": "user1Hash", "createdAt": "Tue, 05 Nov 2024 09:41:37 ", "lastEditedBy": "user1Hash", @@ -71,7 +71,7 @@ "manualAnnotation": True, }, ], - "range": [[0, 1]], + "range": [], "createdBy": "user1Hash", "createdAt": "Tue, 05 Nov 2024 09:41:37 ", "lastEditedBy": "user1Hash", @@ -101,7 +101,7 @@ "manualAnnotation": True, }, ], - "range": [[0, 1]], + "range": [], "createdBy": "user1Hash", "createdAt": "Tue, 05 Nov 2024 09:41:37 ", "lastEditedBy": "user1Hash", diff --git a/tests/objects/data/html_text_labels.py b/tests/objects/data/html_text_labels.py new file mode 100644 index 00000000..07750271 --- /dev/null +++ b/tests/objects/data/html_text_labels.py @@ -0,0 +1,166 @@ +HTML_TEXT_LABELS = { + "label_hash": "0aea5ac7-cbc0-4451-a242-e22445d2c9fa", + "branch_name": "main", + "created_at": "2023-02-09 14:12:03", + "last_edited_at": "2023-02-09 14:12:03", + "data_hash": "aaa6bc82-9f89-4545-adbb-f271bf28cf99", + "annotation_task_status": "QUEUED", + "is_shadow_data": False, + "dataset_hash": "b02ba3d9-883b-4c5e-ba09-751072ccfc57", + "dataset_title": "Text Dataset", + "data_title": "airbnb.html", + "data_type": "plain_text", + "data_units": { + "cd53f484-c9ab-4fd1-9c14-5b34d4e42ba2": { + "data_hash": "cd53f484-c9ab-4fd1-9c14-5b34d4e42ba2", + "data_title": "sample-audio.mp3", + "data_type": "text/html", + "data_sequence": 0, + "data_link": "text-link", + "labels": {}, + } + }, + "object_answers": { + "textObjectHash": { + "objectHash": "textObjectHash", + "featureHash": "textObjectFeatureNodeHash", + "classifications": [], + "range": [], + 
"range_html": [ + { + "start": { + "node": "start_node", + "offset": 5, + }, + "end": { + "node": "end_node", + "offset": 10, + }, + } + ], + "createdBy": "user1Hash", + "createdAt": "Tue, 05 Nov 2024 09:41:37 ", + "lastEditedBy": "user1Hash", + "lastEditedAt": "Tue, 05 Nov 2024 09:41:37 ", + "manualAnnotation": True, + "name": "text object", + "value": "text_object", + "color": "#A4DD00", + "shape": "text", + }, + }, + "classification_answers": { + "textClassificationHash": { + "classificationHash": "textClassificationHash", + "featureHash": "jPOcEsbw", + "classifications": [ + { + "name": "Text classification", + "value": "text_classification", + "answers": "Text Answer", + "featureHash": "OxrtEM+v", + "manualAnnotation": True, + }, + ], + "range": [], + "createdBy": "user1Hash", + "createdAt": "Tue, 05 Nov 2024 09:41:37 ", + "lastEditedBy": "user1Hash", + "lastEditedAt": "Tue, 05 Nov 2024 09:41:37 ", + "manualAnnotation": True, + }, + "radioClassificationHash": { + "classificationHash": "radioClassificationHash", + "featureHash": "NzIxNTU1", + "classifications": [ + { + "name": "Radio classification 1", + "value": "radio_classification_1", + "answers": [ + { + "name": "cl 1 option 1", + "value": "cl_1_option_1", + "featureHash": "MTcwMjM5", + } + ], + "featureHash": "MjI5MTA5", + "manualAnnotation": True, + }, + { + "name": "cl 1 2 text", + "value": "cl_1_2_text", + "answers": "Nested Text Answer", + "featureHash": "MTg0MjIw", + "manualAnnotation": True, + }, + ], + "range": [], + "createdBy": "user1Hash", + "createdAt": "Tue, 05 Nov 2024 09:41:37 ", + "lastEditedBy": "user1Hash", + "lastEditedAt": "Tue, 05 Nov 2024 09:41:37 ", + "manualAnnotation": True, + }, + "checklistClassificationHash": { + "classificationHash": "checklistClassificationHash", + "featureHash": "3DuQbFxo", + "classifications": [ + { + "name": "Checklist classification", + "value": "checklist_classification", + "answers": [ + { + "name": "Checklist classification answer 1", + "value": 
"checklist_classification_answer_1", + "featureHash": "fvLjF0qZ", + }, + { + "name": "Checklist classification answer 2", + "value": "checklist_classification_answer_2", + "featureHash": "a4r7nK9i", + }, + ], + "featureHash": "9mwWr3OE", + "manualAnnotation": True, + }, + ], + "range": [], + "createdBy": "user1Hash", + "createdAt": "Tue, 05 Nov 2024 09:41:37 ", + "lastEditedBy": "user1Hash", + "lastEditedAt": "Tue, 05 Nov 2024 09:41:37 ", + "manualAnnotation": True, + }, + }, + "object_actions": {}, + "label_status": "LABEL_IN_PROGRESS", +} + +EMPTY_HTML_TEXT_LABELS = { + "label_hash": "0aea5ac7-cbc0-4451-a242-e22445d2c9fa", + "branch_name": "main", + "created_at": "2023-02-09 14:12:03", + "last_edited_at": "2023-02-09 14:12:03", + "data_hash": "aaa6bc82-9f89-4545-adbb-f271bf28cf99", + "annotation_task_status": "QUEUED", + "is_shadow_data": False, + "dataset_hash": "b02ba3d9-883b-4c5e-ba09-751072ccfc57", + "dataset_title": "Text Dataset", + "data_title": "airbnb.html", + "file_type": "text/html", + "data_type": "plain_text", + "data_units": { + "cd53f484-c9ab-4fd1-9c14-5b34d4e42ba2": { + "data_hash": "cd53f484-c9ab-4fd1-9c14-5b34d4e42ba2", + "data_title": "sample-audio.mp3", + "data_type": "text/html", + "data_sequence": 0, + "data_link": "text-link", + "labels": {}, + } + }, + "object_answers": {}, + "classification_answers": {}, + "object_actions": {}, + "label_status": "LABEL_IN_PROGRESS", +} diff --git a/tests/objects/data/plain_text.py b/tests/objects/data/plain_text.py new file mode 100644 index 00000000..c23f7cb6 --- /dev/null +++ b/tests/objects/data/plain_text.py @@ -0,0 +1,153 @@ +PLAIN_TEXT_LABELS = { + "label_hash": "0aea5ac7-cbc0-4451-a242-e22445d2c9fa", + "branch_name": "main", + "created_at": "2023-02-09 14:12:03", + "last_edited_at": "2023-02-09 14:12:03", + "data_hash": "aaa6bc82-9f89-4545-adbb-f271bf28cf99", + "annotation_task_status": "QUEUED", + "is_shadow_data": False, + "dataset_hash": "b02ba3d9-883b-4c5e-ba09-751072ccfc57", + "dataset_title": 
"Text Dataset", + "data_title": "text.txt", + "data_type": "plain_text", + "data_units": { + "cd53f484-c9ab-4fd1-9c14-5b34d4e42ba2": { + "data_hash": "cd53f484-c9ab-4fd1-9c14-5b34d4e42ba2", + "data_title": "text.txt", + "data_type": "text/plain", + "data_sequence": 0, + "data_link": "text-link", + "labels": {}, + } + }, + "object_answers": { + "textObjectHash": { + "objectHash": "textObjectHash", + "featureHash": "textObjectFeatureNodeHash", + "classifications": [], + "range": [[0, 5]], + "createdBy": "user1Hash", + "createdAt": "Tue, 05 Nov 2024 09:41:37 ", + "lastEditedBy": "user1Hash", + "lastEditedAt": "Tue, 05 Nov 2024 09:41:37 ", + "manualAnnotation": True, + "name": "text object", + "value": "text_object", + "color": "#A4DD00", + "shape": "text", + }, + }, + "classification_answers": { + "textClassificationHash": { + "classificationHash": "textClassificationHash", + "featureHash": "jPOcEsbw", + "classifications": [ + { + "name": "Text classification", + "value": "text_classification", + "answers": "Text Answer", + "featureHash": "OxrtEM+v", + "manualAnnotation": True, + }, + ], + "range": [], + "createdBy": "user1Hash", + "createdAt": "Tue, 05 Nov 2024 09:41:37 ", + "lastEditedBy": "user1Hash", + "lastEditedAt": "Tue, 05 Nov 2024 09:41:37 ", + "manualAnnotation": True, + }, + "radioClassificationHash": { + "classificationHash": "radioClassificationHash", + "featureHash": "NzIxNTU1", + "classifications": [ + { + "name": "Radio classification 1", + "value": "radio_classification_1", + "answers": [ + { + "name": "cl 1 option 1", + "value": "cl_1_option_1", + "featureHash": "MTcwMjM5", + } + ], + "featureHash": "MjI5MTA5", + "manualAnnotation": True, + }, + { + "name": "cl 1 2 text", + "value": "cl_1_2_text", + "answers": "Nested Text Answer", + "featureHash": "MTg0MjIw", + "manualAnnotation": True, + }, + ], + "range": [], + "createdBy": "user1Hash", + "createdAt": "Tue, 05 Nov 2024 09:41:37 ", + "lastEditedBy": "user1Hash", + "lastEditedAt": "Tue, 05 Nov 2024 
09:41:37 ", + "manualAnnotation": True, + }, + "checklistClassificationHash": { + "classificationHash": "checklistClassificationHash", + "featureHash": "3DuQbFxo", + "classifications": [ + { + "name": "Checklist classification", + "value": "checklist_classification", + "answers": [ + { + "name": "Checklist classification answer 1", + "value": "checklist_classification_answer_1", + "featureHash": "fvLjF0qZ", + }, + { + "name": "Checklist classification answer 2", + "value": "checklist_classification_answer_2", + "featureHash": "a4r7nK9i", + }, + ], + "featureHash": "9mwWr3OE", + "manualAnnotation": True, + }, + ], + "range": [], + "createdBy": "user1Hash", + "createdAt": "Tue, 05 Nov 2024 09:41:37 ", + "lastEditedBy": "user1Hash", + "lastEditedAt": "Tue, 05 Nov 2024 09:41:37 ", + "manualAnnotation": True, + }, + }, + "object_actions": {}, + "label_status": "LABEL_IN_PROGRESS", +} + +EMPTY_PLAIN_TEXT_LABELS = { + "label_hash": "0aea5ac7-cbc0-4451-a242-e22445d2c9fa", + "branch_name": "main", + "created_at": "2023-02-09 14:12:03", + "last_edited_at": "2023-02-09 14:12:03", + "data_hash": "aaa6bc82-9f89-4545-adbb-f271bf28cf99", + "annotation_task_status": "QUEUED", + "is_shadow_data": False, + "dataset_hash": "b02ba3d9-883b-4c5e-ba09-751072ccfc57", + "dataset_title": "Text Dataset", + "data_title": "text.txt", + "data_type": "plain_text", + "data_units": { + "cd53f484-c9ab-4fd1-9c14-5b34d4e42ba2": { + "data_hash": "cd53f484-c9ab-4fd1-9c14-5b34d4e42ba2", + "data_title": "text.txt", + "data_type": "text/plain", + "data_sequence": 0, + "data_link": "text-link", + "labels": {}, + } + }, + "object_answers": {}, + "classification_answers": {}, + "object_actions": {}, + "label_status": "LABEL_IN_PROGRESS", +} diff --git a/tests/objects/test_label_structure.py b/tests/objects/test_label_structure.py index 0c9a111c..8ded1de7 100644 --- a/tests/objects/test_label_structure.py +++ b/tests/objects/test_label_structure.py @@ -22,10 +22,13 @@ from encord.objects.coordinates import ( 
AudioCoordinates, BoundingBoxCoordinates, + HtmlCoordinates, PointCoordinate, PolygonCoordinates, + TextCoordinates, ) from encord.objects.frames import Range +from encord.objects.html_node import HtmlNode, HtmlRange from encord.objects.options import Option from encord.orm.label_row import LabelRowMetadata, LabelStatus from tests.objects.common import FAKE_LABEL_ROW_METADATA @@ -33,6 +36,8 @@ from tests.objects.data.all_types_ontology_structure import all_types_structure from tests.objects.data.audio_labels import EMPTY_AUDIO_LABELS from tests.objects.data.empty_image_group import empty_image_group_labels +from tests.objects.data.html_text_labels import EMPTY_HTML_TEXT_LABELS +from tests.objects.data.plain_text import EMPTY_PLAIN_TEXT_LABELS from tests.objects.test_label_structure_converter import ontology_from_dict box_ontology_item = all_types_structure.get_child_by_hash("MjI2NzEy", Object) @@ -41,43 +46,45 @@ audio_obj_ontology_item = all_types_structure.get_child_by_hash("KVfzNkFy", Object) -nested_box_ontology_item = all_types_structure.get_child_by_hash("MTA2MjAx") -text_attribute_1 = all_types_structure.get_child_by_hash("OTkxMjU1") -checklist_attribute_1 = all_types_structure.get_child_by_hash("ODcxMDAy") -checklist_attribute_1_option_1 = all_types_structure.get_child_by_hash("MTE5MjQ3") -checklist_attribute_1_option_2 = all_types_structure.get_child_by_hash("Nzg3MDE3") - -deeply_nested_polygon_item = all_types_structure.get_child_by_hash("MTM1MTQy") -nested_polygon_text = all_types_structure.get_child_by_hash("OTk555U1") -nested_polygon_checklist = all_types_structure.get_child_by_hash("ODc555Ay") -nested_polygon_checklist_option_1 = all_types_structure.get_child_by_hash("MT5555Q3") -nested_polygon_checklist_option_2 = all_types_structure.get_child_by_hash("Nzg5555E3") -radio_attribute_level_1 = all_types_structure.get_child_by_hash("MTExMjI3") -radio_nested_option_1 = all_types_structure.get_child_by_hash("MTExNDQ5") -radio_nested_option_1_text = 
all_types_structure.get_child_by_hash("MjE2OTE0") -radio_nested_option_2 = all_types_structure.get_child_by_hash("MTcxMjAy") -radio_nested_option_2_checklist = all_types_structure.get_child_by_hash("ODc666Ay") -radio_nested_option_2_checklist_option_1 = all_types_structure.get_child_by_hash("MT66665Q3") -radio_nested_option_2_checklist_option_2 = all_types_structure.get_child_by_hash("Nzg66665E3") - -keypoint_dynamic = all_types_structure.get_child_by_hash("MTY2MTQx") -dynamic_text: TextAttribute = all_types_structure.get_child_by_hash("OTkxMjU1") -dynamic_checklist = all_types_structure.get_child_by_hash("ODcxMDAy") -dynamic_checklist_option_1 = all_types_structure.get_child_by_hash("MTE5MjQ3") -dynamic_checklist_option_2 = all_types_structure.get_child_by_hash("Nzg3MDE3") -dynamic_radio = all_types_structure.get_child_by_hash("MTExM9I3", type_=RadioAttribute) -dynamic_radio_option_1 = all_types_structure.get_child_by_hash("MT9xNDQ5", type_=Option) # Dynamic and deeply nested. -dynamic_radio_option_2 = all_types_structure.get_child_by_hash("9TcxMjAy", type_=Option) # Dynamic and deeply nested. 
+text_obj_ontology_item = all_types_structure.get_child_by_hash("textFeatureNodeHash", Object) + +nested_box_ontology_item = all_types_structure.get_child_by_hash("MTA2MjAx", Object) +text_attribute_1 = all_types_structure.get_child_by_hash("OTkxMjU1", TextAttribute) +checklist_attribute_1 = all_types_structure.get_child_by_hash("ODcxMDAy", ChecklistAttribute) +checklist_attribute_1_option_1 = all_types_structure.get_child_by_hash("MTE5MjQ3", Option) +checklist_attribute_1_option_2 = all_types_structure.get_child_by_hash("Nzg3MDE3", Option) + +deeply_nested_polygon_item = all_types_structure.get_child_by_hash("MTM1MTQy", Object) +nested_polygon_text = all_types_structure.get_child_by_hash("OTk555U1", TextAttribute) +nested_polygon_checklist = all_types_structure.get_child_by_hash("ODc555Ay", ChecklistAttribute) +nested_polygon_checklist_option_1 = all_types_structure.get_child_by_hash("MT5555Q3", Option) +nested_polygon_checklist_option_2 = all_types_structure.get_child_by_hash("Nzg5555E3", Option) +radio_attribute_level_1 = all_types_structure.get_child_by_hash("MTExMjI3", RadioAttribute) +radio_nested_option_1 = all_types_structure.get_child_by_hash("MTExNDQ5", Option) +radio_nested_option_1_text = all_types_structure.get_child_by_hash("MjE2OTE0", TextAttribute) +radio_nested_option_2 = all_types_structure.get_child_by_hash("MTcxMjAy", Option) +radio_nested_option_2_checklist = all_types_structure.get_child_by_hash("ODc666Ay", ChecklistAttribute) +radio_nested_option_2_checklist_option_1 = all_types_structure.get_child_by_hash("MT66665Q3", Option) +radio_nested_option_2_checklist_option_2 = all_types_structure.get_child_by_hash("Nzg66665E3", Option) + +keypoint_dynamic = all_types_structure.get_child_by_hash("MTY2MTQx", Object) +dynamic_text = all_types_structure.get_child_by_hash("OTkxMjU1", TextAttribute) +dynamic_checklist = all_types_structure.get_child_by_hash("ODcxMDAy", ChecklistAttribute) +dynamic_checklist_option_1 = 
all_types_structure.get_child_by_hash("MTE5MjQ3", Option) +dynamic_checklist_option_2 = all_types_structure.get_child_by_hash("Nzg3MDE3", Option) +dynamic_radio = all_types_structure.get_child_by_hash("MTExM9I3", RadioAttribute) +dynamic_radio_option_1 = all_types_structure.get_child_by_hash("MT9xNDQ5", Option) # Dynamic and deeply nested. +dynamic_radio_option_2 = all_types_structure.get_child_by_hash("9TcxMjAy", Option) # Dynamic and deeply nested. text_classification = all_types_structure.get_child_by_hash("jPOcEsbw", Classification) -text_classification_attribute: TextAttribute = all_types_structure.get_child_by_hash("OxrtEM+v", TextAttribute) -radio_classification = all_types_structure.get_child_by_hash("NzIxNTU1") -radio_classification_option_1 = all_types_structure.get_child_by_hash("MTcwMjM5") -radio_classification_option_2 = all_types_structure.get_child_by_hash("MjUzMTg1") -radio_classification_option_2_text = all_types_structure.get_child_by_hash("MTg0MjIw") -checklist_classification: Classification = all_types_structure.get_child_by_hash("3DuQbFxo") -checklist_classification_option_1 = all_types_structure.get_child_by_hash("fvLjF0qZ") -checklist_classification_option_2 = all_types_structure.get_child_by_hash("a4r7nK9i") +text_classification_attribute = all_types_structure.get_child_by_hash("OxrtEM+v", TextAttribute) +radio_classification = all_types_structure.get_child_by_hash("NzIxNTU1", Classification) +radio_classification_option_1 = all_types_structure.get_child_by_hash("MTcwMjM5", Option) +radio_classification_option_2 = all_types_structure.get_child_by_hash("MjUzMTg1", Option) +radio_classification_option_2_text = all_types_structure.get_child_by_hash("MTg0MjIw", TextAttribute) +checklist_classification = all_types_structure.get_child_by_hash("3DuQbFxo", Classification) +checklist_classification_option_1 = all_types_structure.get_child_by_hash("fvLjF0qZ", Option) +checklist_classification_option_2 = 
all_types_structure.get_child_by_hash("a4r7nK9i", Option) BOX_COORDINATES = BoundingBoxCoordinates( height=0.1, @@ -460,7 +467,7 @@ def test_classification_answering_with_ontology_access() -> None: radio_classification_ = all_types_structure.get_child_by_title("Radio classification 1", Classification) radio_instance = radio_classification_.create_instance() - radio_classification_attribute_1: ChecklistAttribute = radio_classification_.get_child_by_title( + radio_classification_attribute_1 = radio_classification_.get_child_by_title( "Radio classification 1", type_=RadioAttribute ) # Different `type_` with generic `Attribute` @@ -657,50 +664,42 @@ def test_add_and_get_classification_instances_to_audio_label_row(ontology): label_row.from_labels_dict(EMPTY_AUDIO_LABELS) classification_instance_1 = ClassificationInstance(text_classification, range_only=True) - classification_instance_2 = ClassificationInstance(text_classification, range_only=True) - classification_instance_3 = ClassificationInstance(checklist_classification, range_only=True) + classification_instance_2 = ClassificationInstance(checklist_classification, range_only=True) - classification_instance_1.set_for_frames(Range(1, 2)) - classification_instance_2.set_for_frames(Range(3, 4)) - classification_instance_3.set_for_frames(Range(1, 4)) + classification_instance_1.set_for_frames(Range(0, 0)) + classification_instance_2.set_for_frames(Range(0, 0)) label_row.add_classification_instance(classification_instance_1) label_row.add_classification_instance(classification_instance_2) - label_row.add_classification_instance(classification_instance_3) classification_instances = label_row.get_classification_instances() + assert set(classification_instances) == { classification_instance_1, classification_instance_2, - classification_instance_3, } filtered_classification_instances = label_row.get_classification_instances(text_classification) - assert set(filtered_classification_instances) == {classification_instance_1, 
classification_instance_2} + assert set(filtered_classification_instances) == {classification_instance_1} overlapping_classification_instance = ClassificationInstance(text_classification, range_only=True) - overlapping_classification_instance.set_for_frames(1) + overlapping_classification_instance.set_for_frames(0) - with pytest.raises(LabelRowError): + with pytest.raises(LabelRowError) as e: label_row.add_classification_instance(overlapping_classification_instance) - overlapping_classification_instance.remove_from_frames(1) - overlapping_classification_instance.set_for_frames(5) - label_row.add_classification_instance(overlapping_classification_instance) - with pytest.raises(LabelRowError): - overlapping_classification_instance.set_for_frames(1) + assert e.value.message == ( + f"A ClassificationInstance '{overlapping_classification_instance.classification_hash}' was already added " + "and has overlapping frames. Overlapping frames that were " + "found are `[(0:0)]`. Make sure that you only add classifications " + "which are on frames where the same type of classification does not yet exist." 
+ ) # Do not raise if overwrite flag is passed - overlapping_classification_instance.set_for_frames(1, overwrite=True) + overlapping_classification_instance.set_for_frames(0, overwrite=True) label_row.remove_classification(classification_instance_1) - overlapping_classification_instance.set_for_frames(1) - - with pytest.raises(LabelRowError): - overlapping_classification_instance.set_for_frames(3) - - classification_instance_2.remove_from_frames(3) - overlapping_classification_instance.set_for_frames(3) + overlapping_classification_instance.set_for_frames(0) def test_object_instance_answer_for_static_attributes(): @@ -1005,6 +1004,8 @@ def test_frame_view(ontology) -> None: frame_view.add_classification_instance(classification_instance) frames = label_row.get_frame_views() + assert label_row_metadata.duration is not None + assert label_row_metadata.frames_per_second is not None assert len(frames) == label_row_metadata.duration * label_row_metadata.frames_per_second frame_num = 0 @@ -1062,6 +1063,52 @@ def empty_audio_label_row() -> LabelRowV2: return label_row +@pytest.fixture +def empty_html_text_label_row() -> LabelRowV2: + label_row_metadata_dict = asdict(FAKE_LABEL_ROW_METADATA) + label_row_metadata_dict["data_type"] = "plain_text" + label_row_metadata_dict["file_type"] = "text/html" + label_row_metadata = LabelRowMetadata(**label_row_metadata_dict) + + label_row = LabelRowV2(label_row_metadata, Mock(), ontology_from_dict(all_ontology_types)) + label_row.from_labels_dict(EMPTY_HTML_TEXT_LABELS) + + return label_row + + +@pytest.fixture +def empty_plain_text_label_row() -> LabelRowV2: + label_row_metadata_dict = asdict(FAKE_LABEL_ROW_METADATA) + label_row_metadata_dict["data_type"] = "plain_text" + label_row_metadata_dict["file_type"] = "text/plain" + label_row_metadata = LabelRowMetadata(**label_row_metadata_dict) + + label_row = LabelRowV2(label_row_metadata, Mock(), ontology_from_dict(all_ontology_types)) + label_row.from_labels_dict(EMPTY_PLAIN_TEXT_LABELS) 
+ + return label_row + + +def test_non_geometric_label_rows_must_use_classification_instance_with_range_only( + ontology, + empty_audio_label_row: LabelRowV2, + empty_plain_text_label_row: LabelRowV2, + empty_html_text_label_row: LabelRowV2, +): + classification_instance = ClassificationInstance(checklist_classification) + classification_instance.set_for_frames(Range(start=0, end=0)) + for label_row in [empty_plain_text_label_row, empty_html_text_label_row, empty_html_text_label_row]: + with pytest.raises(LabelRowError) as e: + label_row.add_classification_instance(classification_instance) + assert str(e.value.message) == ( + f"To add a ClassificationInstance object to a label row where data_type = {label_row.data_type}," + "the ClassificationInstance object needs to be created with the " + "range_only property set to True." + "You can do ClassificationInstance(range_only=True) or " + "Classification.create_instance(range_only=True) to achieve this." + ) + + def test_non_range_classification_cannot_be_added_to_audio_label_row(ontology): label_row_metadata_dict = asdict(FAKE_LABEL_ROW_METADATA) label_row_metadata_dict["frames_per_second"] = 1000 @@ -1081,54 +1128,43 @@ def test_non_range_classification_cannot_be_added_to_audio_label_row(ontology): label_row.add_classification_instance(classification_instance) -def test_audio_classification_overwrite(ontology, empty_audio_label_row: LabelRowV2): - classification_instance = ClassificationInstance(checklist_classification, range_only=True) - classification_instance.set_for_frames(Range(start=0, end=100)) - empty_audio_label_row.add_classification_instance(classification_instance) - - with pytest.raises(LabelRowError): - classification_instance.set_for_frames(Range(start=5, end=20)) - - with pytest.raises(LabelRowError): - classification_instance.set_for_frames(Range(start=100, end=101)) - - # No error when set overwrite to True - classification_instance.set_for_frames(Range(start=100, end=101), overwrite=True) - 
range_list = classification_instance.range_list - assert len(range_list) == 1 - assert range_list[0].start == 0 - assert range_list[0].end == 101 - - -def test_audio_classification_exceed_max_frames(ontology, empty_audio_label_row: LabelRowV2): - classification_instance = ClassificationInstance(checklist_classification, range_only=True) - classification_instance.set_for_frames(Range(start=0, end=100)) - empty_audio_label_row.add_classification_instance(classification_instance) +def test_non_geometric_label_rows_can_only_have_classifications_on_frame_0( + ontology, + empty_audio_label_row: LabelRowV2, + empty_plain_text_label_row: LabelRowV2, + empty_html_text_label_row: LabelRowV2, +): + for label_row in [empty_audio_label_row, empty_html_text_label_row, empty_plain_text_label_row]: + classification_instance = ClassificationInstance(checklist_classification, range_only=True) + classification_instance.set_for_frames(Range(start=0, end=0)) + label_row.add_classification_instance(classification_instance) - with pytest.raises(LabelRowError): - classification_instance.set_for_frames(Range(start=200, end=5000)) + with pytest.raises(LabelRowError) as e: + classification_instance.set_for_frames(Range(start=0, end=1)) - range_list = classification_instance.range_list - assert len(range_list) == 1 - assert range_list[0].start == 0 - assert range_list[0].end == 100 + assert e.value.message == ( + "For audio files and text files, classifications can only be " + "attached to frame=0 You may use " + "`ClassificationInstance.set_for_frames(frames=Range(start=0, end=0))`." 
+ ) def test_audio_object_exceed_max_frames(ontology, empty_audio_label_row: LabelRowV2): object_instance = ObjectInstance(audio_obj_ontology_item) - object_instance.set_for_frames(AudioCoordinates(), Range(start=0, end=100)) + object_instance.set_for_frames(AudioCoordinates(range=[Range(start=0, end=100)])) empty_audio_label_row.add_object_instance(object_instance) with pytest.raises(LabelRowError): - object_instance.set_for_frames(AudioCoordinates(), Range(start=200, end=5000)) + object_instance.set_for_frames(AudioCoordinates(range=[Range(start=200, end=5000)])) range_list = object_instance.range_list + assert range_list is not None assert len(range_list) == 1 assert range_list[0].start == 0 assert range_list[0].end == 100 -def test_get_annotations_from_audio_classification(ontology) -> None: +def test_get_annotations_from_non_geometric_classification(ontology) -> None: now = datetime.datetime.now() classification_instance = ClassificationInstance(checklist_classification, range_only=True) @@ -1159,8 +1195,7 @@ def test_get_annotations_from_audio_object(ontology) -> None: object_instance = ObjectInstance(audio_obj_ontology_item) object_instance.set_for_frames( - AudioCoordinates(), - Range(start=0, end=1500), + AudioCoordinates(range=[Range(start=0, end=1500)]), created_at=now, created_by="user1", last_edited_at=now, @@ -1181,24 +1216,17 @@ def test_get_annotations_from_audio_object(ontology) -> None: assert annotation.reviews is None -def test_audio_classification_can_be_added_edited_and_removed(ontology, empty_audio_label_row: LabelRowV2): +def test_audio_classification_can_be_added_and_removed(ontology, empty_audio_label_row: LabelRowV2): label_row = empty_audio_label_row classification_instance = ClassificationInstance(checklist_classification, range_only=True) - classification_instance.set_for_frames(Range(start=0, end=1500)) + classification_instance.set_for_frames(Range(start=0, end=0)) range_list = classification_instance.range_list assert 
len(range_list) == 1 assert range_list[0].start == 0 - assert range_list[0].end == 1500 + assert range_list[0].end == 0 label_row.add_classification_instance(classification_instance) assert len(label_row.get_classification_instances()) == 1 - classification_instance.set_for_frames(Range(start=2000, end=2499)) - range_list = classification_instance.range_list - assert len(range_list) == 2 - assert range_list[0].start == 0 - assert range_list[0].end == 1500 - assert range_list[1].start == 2000 - assert range_list[1].end == 2499 label_row.remove_classification(classification_instance) assert len(label_row.get_classification_instances()) == 0 @@ -1207,8 +1235,9 @@ def test_audio_classification_can_be_added_edited_and_removed(ontology, empty_au def test_audio_object_can_be_added_edited_and_removed(ontology, empty_audio_label_row: LabelRowV2): label_row = empty_audio_label_row obj_instance = ObjectInstance(audio_obj_ontology_item) - obj_instance.set_for_frames(AudioCoordinates(), Range(start=0, end=1500)) + obj_instance.set_for_frames(AudioCoordinates(range=[Range(start=0, end=1500)])) range_list = obj_instance.range_list + assert range_list is not None assert len(range_list) == 1 assert range_list[0].start == 0 assert range_list[0].end == 1500 @@ -1216,19 +1245,279 @@ def test_audio_object_can_be_added_edited_and_removed(ontology, empty_audio_labe label_row.add_object_instance(obj_instance) assert len(label_row.get_classification_instances()) == 0 assert len(label_row.get_object_instances()) == 1 - obj_instance.set_for_frames(AudioCoordinates(), Range(start=2000, end=2499)) - range_list = obj_instance.range_list - assert len(range_list) == 2 - assert range_list[0].start == 0 - assert range_list[0].end == 1500 - assert range_list[1].start == 2000 - assert range_list[1].end == 2499 - - obj_instance.remove_from_frames(Range(start=0, end=1500)) + obj_instance.set_for_frames(AudioCoordinates(range=[Range(start=2000, end=2499)]), overwrite=True) range_list = 
obj_instance.range_list + assert range_list is not None assert len(range_list) == 1 assert range_list[0].start == 2000 assert range_list[0].end == 2499 + obj_instance.remove_from_frames(frames=0) + range_list = obj_instance.range_list + assert range_list is None + label_row.remove_object(obj_instance) assert len(label_row.get_object_instances()) == 0 + + +def test_get_annotations_from_html_text_object(ontology) -> None: + now = datetime.datetime.now() + + range = HtmlRange( + start=HtmlNode(node="/html[1]/body[1]/div[1]/text()[1]", offset=50), + end=HtmlNode(node="/html[1]/body[1]/div[1]/text()[1]", offset=60), + ) + + object_instance = ObjectInstance(text_obj_ontology_item) + object_instance.set_for_frames( + HtmlCoordinates(range=[range]), + created_at=now, + created_by="user1", + last_edited_at=now, + last_edited_by="user2", + ) + + annotations = object_instance.get_annotations() + + assert len(annotations) == 1 + + annotation = annotations[0] + assert annotation.manual_annotation == DEFAULT_MANUAL_ANNOTATION + assert annotation.confidence == DEFAULT_CONFIDENCE + assert annotation.created_at == now + assert annotation.created_by == "user1" + assert annotation.last_edited_at == now + assert annotation.last_edited_by == "user2" + assert annotation.reviews is None + + +def test_html_text_classification_can_be_added_removed(ontology, empty_html_text_label_row: LabelRowV2): + label_row = empty_html_text_label_row + classification_instance = ClassificationInstance(checklist_classification, range_only=True) + classification_instance.set_for_frames(Range(start=0, end=0)) + range_list = classification_instance.range_list + assert len(range_list) == 1 + assert range_list[0].start == 0 + assert range_list[0].end == 0 + + label_row.add_classification_instance(classification_instance) + assert len(label_row.get_classification_instances()) == 1 + + label_row.remove_classification(classification_instance) + assert len(label_row.get_classification_instances()) == 0 + + +def 
test_html_text_object_can_be_added_edited_and_removed(ontology, empty_html_text_label_row: LabelRowV2): + label_row = empty_html_text_label_row + obj_instance = ObjectInstance(text_obj_ontology_item) + + initial_range = [ + HtmlRange( + start=HtmlNode(node="start_node", offset=50), + end=HtmlNode(node="end_node", offset=100), + ) + ] + + obj_instance.set_for_frames(HtmlCoordinates(range=initial_range)) + range = obj_instance.range_html + + assert range is not None + assert len(range) == 1 + assert range[0].start.node == "start_node" + assert range[0].start.offset == 50 + assert range[0].end.node == "end_node" + assert range[0].end.offset == 100 + + label_row.add_object_instance(obj_instance) + assert len(label_row.get_classification_instances()) == 0 + assert len(label_row.get_object_instances()) == 1 + + edited_range = [ + HtmlRange( + start=HtmlNode(node="start_node_edited", offset=70), + end=HtmlNode(node="end_node_edited", offset=90), + ), + HtmlRange( + start=HtmlNode(node="start_node_new", offset=5), + end=HtmlNode(node="end_node_new", offset=7), + ), + ] + + obj_instance.set_for_frames(HtmlCoordinates(range=edited_range), overwrite=True) + range = obj_instance.range_html + assert range is not None + assert len(range) == 2 + assert range[0].start.node == "start_node_edited" + assert range[0].start.offset == 70 + assert range[0].end.node == "end_node_edited" + assert range[0].end.offset == 90 + + assert range[1].start.node == "start_node_new" + assert range[1].start.offset == 5 + assert range[1].end.node == "end_node_new" + assert range[1].end.offset == 7 + + obj_instance.remove_from_frames(frames=0) + range = obj_instance.range_html + assert range is None + + +def test_html_text_object_cannot_be_added_to_non_html_label_row( + ontology, empty_audio_label_row: LabelRowV2, empty_plain_text_label_row: LabelRowV2 +) -> None: + obj_instance = ObjectInstance(text_obj_ontology_item) + + initial_range = [ + HtmlRange( + start=HtmlNode(node="start_node", offset=50), + 
end=HtmlNode(node="end_node", offset=100), + ) + ] + + obj_instance.set_for_frames(HtmlCoordinates(range=initial_range)) + range = obj_instance.range_html + + assert range is not None + assert len(range) == 1 + assert range[0].start.node == "start_node" + assert range[0].start.offset == 50 + assert range[0].end.node == "end_node" + assert range[0].end.offset == 100 + + with pytest.raises(LabelRowError) as e: + empty_audio_label_row.add_object_instance(obj_instance) + + assert str(e.value.message) == ( + "Unable to assign object instance with a html range to a non-html file. " + f"Please ensure the object instance does not have coordinates of type {HtmlCoordinates}." + ) + + with pytest.raises(LabelRowError) as e: + empty_plain_text_label_row.add_object_instance(obj_instance) + + assert str(e.value.message) == ( + "Unable to assign object instance with a html range to a non-html file. " + f"Please ensure the object instance does not have coordinates of type {HtmlCoordinates}." + ) + + +def test_set_for_frames_with_range_html_throws_error_if_used_incorrectly( + ontology, empty_html_text_label_row: LabelRowV2, empty_plain_text_label_row: LabelRowV2 +): + range_html = [ + HtmlRange( + start=HtmlNode(node="start_node", offset=50), + end=HtmlNode(node="end_node", offset=100), + ) + ] + + # Adding HtmlCoordinates to an object instance where the object's shape is NOT text + audio_obj_instance = ObjectInstance(audio_obj_ontology_item) + with pytest.raises(LabelRowError) as e: + audio_obj_instance.set_for_frames(coordinates=HtmlCoordinates(range=range_html)) + + assert ( + str(e.value.message) + == f"Expected coordinates of one of the following types: `[{AudioCoordinates}]`, but got type `{HtmlCoordinates}`." 
+ ) + + # Adding HtmlCoordinates to an object instance which is attached to a label row where the + # file type is NOT 'text/html' + html_text_obj_instance = ObjectInstance(text_obj_ontology_item) + html_text_obj_instance.set_for_frames(coordinates=HtmlCoordinates(range=range_html)) + + with pytest.raises(LabelRowError) as e: + empty_plain_text_label_row.add_object_instance(html_text_obj_instance) + + assert ( + str(e.value.message) == "Unable to assign object instance with a html range to a non-html file. " + f"Please ensure the object instance does not have coordinates of type {HtmlCoordinates}." + ) + + +def test_get_annotations_from_plain_text_object(ontology) -> None: + now = datetime.datetime.now() + + object_instance = ObjectInstance(text_obj_ontology_item) + object_instance.set_for_frames( + TextCoordinates(range=[Range(start=0, end=1500)]), + created_at=now, + created_by="user1", + last_edited_at=now, + last_edited_by="user2", + ) + + annotations = object_instance.get_annotations() + + assert len(annotations) == 1 + + annotation = annotations[0] + assert annotation.manual_annotation == DEFAULT_MANUAL_ANNOTATION + assert annotation.confidence == DEFAULT_CONFIDENCE + assert annotation.created_at == now + assert annotation.created_by == "user1" + assert annotation.last_edited_at == now + assert annotation.last_edited_by == "user2" + assert annotation.reviews is None + + +def test_plain_text_classification_can_be_added_and_removed(ontology, empty_plain_text_label_row: LabelRowV2): + label_row = empty_plain_text_label_row + classification_instance = ClassificationInstance(checklist_classification, range_only=True) + classification_instance.set_for_frames(Range(start=0, end=0)) + range_list = classification_instance.range_list + assert len(range_list) == 1 + assert range_list[0].start == 0 + assert range_list[0].end == 0 + + label_row.add_classification_instance(classification_instance) + assert len(label_row.get_classification_instances()) == 1 + + 
label_row.remove_classification(classification_instance)
+    assert len(label_row.get_classification_instances()) == 0
+
+
+def test_plain_text_object_can_be_added_edited_and_removed(ontology, empty_plain_text_label_row: LabelRowV2):
+    label_row = empty_plain_text_label_row
+    obj_instance = ObjectInstance(text_obj_ontology_item)
+
+    initial_range = [Range(start=0, end=50)]
+    obj_instance.set_for_frames(TextCoordinates(range=initial_range))
+    range_list = obj_instance.range_list
+
+    assert range_list is not None
+    assert len(range_list) == 1
+    assert range_list[0].start == 0
+    assert range_list[0].end == 50
+
+    label_row.add_object_instance(obj_instance)
+    assert len(label_row.get_classification_instances()) == 0
+    assert len(label_row.get_object_instances()) == 1
+
+    edited_range = [Range(start=5, end=10)]
+
+    obj_instance.set_for_frames(TextCoordinates(range=edited_range), overwrite=True)
+    range_list = obj_instance.range_list
+    assert range_list is not None
+    assert len(range_list) == 1
+    assert range_list[0].start == 5
+    assert range_list[0].end == 10
+
+    obj_instance.remove_from_frames(frames=0)
+    range_list = obj_instance.range_list
+    assert range_list is None
+
+
+def test_plain_text_object_cannot_be_added_to_html_label_row(ontology, empty_html_text_label_row: LabelRowV2) -> None:
+    label_row = empty_html_text_label_row
+    obj_instance = ObjectInstance(text_obj_ontology_item)
+
+    obj_instance.set_for_frames(TextCoordinates(range=[Range(start=0, end=50)]))
+
+    with pytest.raises(LabelRowError) as e:
+        label_row.add_object_instance(obj_instance)
+
+    assert str(e.value.message) == (
+        "Unable to assign object instance without a html range to a html file. "
+        f"Please ensure the object instance exists on frame=0, and has coordinates of type {HtmlCoordinates}."
+ ) diff --git a/tests/objects/test_label_structure_converter.py b/tests/objects/test_label_structure_converter.py index 9aa5037f..58974861 100644 --- a/tests/objects/test_label_structure_converter.py +++ b/tests/objects/test_label_structure_converter.py @@ -34,10 +34,12 @@ empty_image_group_labels, empty_image_group_ontology, ) +from tests.objects.data.html_text_labels import HTML_TEXT_LABELS from tests.objects.data.image_group import image_group_labels, image_group_ontology from tests.objects.data.ontology_with_many_dynamic_classifications import ( ontology as ontology_with_many_dynamic_classifications, ) +from tests.objects.data.plain_text import PLAIN_TEXT_LABELS def ontology_from_dict(ontology_structure_dict: Dict): @@ -159,6 +161,40 @@ def test_serialise_audio_objects() -> None: ) +def test_serialise_html_text(): + label_row_metadata_dict = asdict(FAKE_LABEL_ROW_METADATA) + label_row_metadata_dict["frames_per_second"] = 1000 + label_row_metadata_dict["data_type"] = "plain_text" + label_row_metadata = LabelRowMetadata(**label_row_metadata_dict) + + label_row = LabelRowV2(label_row_metadata, Mock(), ontology_from_dict(all_ontology_types)) + label_row.from_labels_dict(HTML_TEXT_LABELS) + + actual = label_row.to_encord_dict() + deep_diff_enhanced( + HTML_TEXT_LABELS, + actual, + exclude_regex_paths=[r"\['reviews'\]", r"\['isDeleted'\]"], + ) + + +def test_serialise_plain_text(): + label_row_metadata_dict = asdict(FAKE_LABEL_ROW_METADATA) + label_row_metadata_dict["frames_per_second"] = 1000 + label_row_metadata_dict["data_type"] = "plain_text" + label_row_metadata = LabelRowMetadata(**label_row_metadata_dict) + + label_row = LabelRowV2(label_row_metadata, Mock(), ontology_from_dict(all_ontology_types)) + label_row.from_labels_dict(PLAIN_TEXT_LABELS) + + actual = label_row.to_encord_dict() + deep_diff_enhanced( + PLAIN_TEXT_LABELS, + actual, + exclude_regex_paths=[r"\['reviews'\]", r"\['isDeleted'\]"], + ) + + def 
test_serialise_dicom_with_dynamic_classifications(): label_row_metadata_dict = asdict(FAKE_LABEL_ROW_METADATA) label_row_metadata_dict["duration"] = None