diff --git a/encord/objects/coordinates.py b/encord/objects/coordinates.py index a690d081..52d4d88c 100644 --- a/encord/objects/coordinates.py +++ b/encord/objects/coordinates.py @@ -13,12 +13,13 @@ from __future__ import annotations from dataclasses import dataclass -from enum import Flag, auto +from enum import auto from typing import Any, Dict, List, Optional, Type, Union from encord.exceptions import LabelRowError from encord.objects.bitmask import BitmaskCoordinates from encord.objects.common import Shape +from encord.orm.analytics import CamelStrEnum from encord.orm.base_dto import BaseDTO @@ -272,7 +273,7 @@ def to_dict(self) -> dict: return {str(idx): {"x": value.x, "y": value.y} for idx, value in enumerate(self.values)} -class Visibility(Flag): +class Visibility(CamelStrEnum): """ An enumeration to represent the visibility state of an item. diff --git a/encord/objects/ontology_labels_impl.py b/encord/objects/ontology_labels_impl.py index d8ddcdd2..65f1c09c 100644 --- a/encord/objects/ontology_labels_impl.py +++ b/encord/objects/ontology_labels_impl.py @@ -56,6 +56,7 @@ PolylineCoordinates, RotatableBoundingBoxCoordinates, SkeletonCoordinates, + Visibility, ) from encord.objects.frames import Frames, Range, Ranges, frames_class_to_frames_list, frames_to_ranges from encord.objects.metadata import DICOMSeriesMetadata, DICOMSliceMetadata @@ -2284,9 +2285,20 @@ def _get_coordinates(self, frame_object_label: dict) -> Coordinates: elif "polyline" in frame_object_label: return PolylineCoordinates.from_dict(frame_object_label) elif "skeleton" in frame_object_label: + + def _with_visibility_enum(point: dict): + if point.get(Visibility.INVISIBLE.value): + point["visibility"] = Visibility.INVISIBLE + elif point.get(Visibility.OCCLUDED.value): + point["visibility"] = Visibility.OCCLUDED + elif point.get(Visibility.VISIBLE.value): + point["visibility"] = Visibility.VISIBLE + return point + + values = [_with_visibility_enum(pnt) for pnt in frame_object_label["skeleton"].values()] skeleton_frame_object_label = { "name": frame_object_label["name"], - "values": list(frame_object_label["skeleton"].values()), + "values": values, } return SkeletonCoordinates.from_dict(skeleton_frame_object_label) elif "bitmask" in frame_object_label: diff --git a/tests/objects/data/skeleton_coordinates.py b/tests/objects/data/skeleton_coordinates.py index 0fc659ef..1633ca8e 100644 --- a/tests/objects/data/skeleton_coordinates.py +++ b/tests/objects/data/skeleton_coordinates.py @@ -1,4 +1,4 @@ -from encord.objects.coordinates import SkeletonCoordinate, SkeletonCoordinates +from encord.objects.coordinates import SkeletonCoordinate, SkeletonCoordinates, Visibility ontology = { "objects": [ @@ -106,6 +106,7 @@ "color": "#000000", "value": "point_0", "featureHash": "1wthOoHe", + "visibility": "visible", }, "1": { "x": 0.4649, @@ -114,6 +115,7 @@ "color": "#000000", "value": "point_1", "featureHash": "KGp1oToz", + "visibility": "occluded", }, "2": { "x": 0.2356, @@ -122,6 +124,7 @@ "color": "#000000", "value": "point_2", "featureHash": "OqR+F4dN", + "visibility": "invisible", }, }, "manualAnnotation": True, @@ -139,6 +142,7 @@ "object_actions": {}, "label_status": "LABELLED", } + expected_coordinates = SkeletonCoordinates( values=[ SkeletonCoordinate( @@ -148,7 +152,7 @@ color="#000000", feature_hash="1wthOoHe", value="point_0", - visibility=None, + visibility=Visibility.VISIBLE, ), SkeletonCoordinate( x=0.4649, @@ -157,7 +161,7 @@ color="#000000", feature_hash="KGp1oToz", value="point_1", - visibility=None, + visibility=Visibility.OCCLUDED, ), SkeletonCoordinate( x=0.2356, @@ -166,7 +170,7 @@ color="#000000", feature_hash="OqR+F4dN", value="point_2", - visibility=None, + visibility=Visibility.INVISIBLE, ), ], name="Triangle", diff --git a/tests/objects/test_label_structure_converter.py b/tests/objects/test_label_structure_converter.py index 38f4743d..9aa5037f 100644 --- a/tests/objects/test_label_structure_converter.py +++ b/tests/objects/test_label_structure_converter.py @@ -339,9 +339,11 @@ def test_skeleton_template_coordinates(): assert len(obj_instances) == 1 obj_instance = obj_instances[0] - ann = obj_instance.get_annotations()[0] - assert ann.coordinates == skeleton_coordinates.expected_coordinates + annotation = obj_instance.get_annotations()[0] + assert annotation.coordinates == skeleton_coordinates.expected_coordinates + label_dict = label_row.to_encord_dict() - label_dict_obj = list(skeleton_coordinates.labels["data_units"].values())[0]["labels"]["objects"][0] - origin_obj = list(label_dict["data_units"].values())[0]["labels"]["objects"][0] - assert origin_obj["skeleton"] == label_dict_obj["skeleton"] + skeleton_dict = list(label_dict["data_units"].values())[0]["labels"]["objects"][0] + expected_skeleton_dict = list(skeleton_coordinates.labels["data_units"].values())[0]["labels"]["objects"][0] + + assert skeleton_dict["skeleton"] == expected_skeleton_dict["skeleton"]