From 36be608b22c614e2104bfaf1c6b091c0eda9ff68 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Frederik=20Hvilsh=C3=B8j?= <93145535+frederik-encord@users.noreply.github.com> Date: Tue, 12 Sep 2023 10:39:19 +0200 Subject: [PATCH] feat: key point label transformer (#642) --- .../src/components/explorer/Explorer.tsx | 8 ++++---- .../lib/labels/label_transformer.py | 18 +++++++++++++++++- 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/src/encord_active/frontend/src/components/explorer/Explorer.tsx b/src/encord_active/frontend/src/components/explorer/Explorer.tsx index 4b2aee17c..a87d43ccd 100644 --- a/src/encord_active/frontend/src/components/explorer/Explorer.tsx +++ b/src/encord_active/frontend/src/components/explorer/Explorer.tsx @@ -1164,15 +1164,15 @@ const ImageWithPolygons = ({ str: return f"BounndingBoxLabel({self.class_}, {self.bounding_box})" +class KeypointLabel(Label): + point: Point + + def __str__(self) -> str: + return f"KeypointLabel({self.class_})" + + @dataclass class PolygonLabel: class_: str @@ -50,7 +57,7 @@ def __str__(self) -> str: return f"PolygonLabel({self.class_}, [{n}, 2])" -LabelOptions = Union[BoundingBoxLabel, ClassificationLabel, PolygonLabel] +LabelOptions = Union[BoundingBoxLabel, ClassificationLabel, PolygonLabel, KeypointLabel] class DataLabel(NamedTuple): @@ -191,6 +198,13 @@ def _add_bounding_box_label(self, label: BoundingBoxLabel, data_unit: dict): bbox_obj = make_object_dict(ont_obj, bbox_dict) data_unit.setdefault("labels", {}).setdefault("objects", []).append(bbox_obj) + def _add_keypoint_label(self, label: KeypointLabel, data_unit: dict): + shape = Shape.POINT + + ont_obj = self._get_object_by_name(label.class_, shape=shape, create_when_missing=True) + bbox_obj = make_object_dict(ont_obj, label.point) + data_unit.setdefault("labels", {}).setdefault("objects", []).append(bbox_obj) + def _add_polygon_label(self, label: PolygonLabel, data_unit: dict): ont_obj = self._get_object_by_name(label.class_, shape=Shape.POLYGON, create_when_missing=True) points = [Point(*r) for r in label.polygon] @@ -204,6 +218,8 @@ def _add_label(self, label: LabelOptions, label_row: dict, data_unit: dict): self._add_bounding_box_label(label, data_unit) elif isinstance(label, PolygonLabel): self._add_polygon_label(label, data_unit) + elif isinstance(label, KeypointLabel): + self._add_keypoint_label(label, data_unit) def add_labels(self, label_paths: List[Path], data_paths: List[Path]): if not self.label_transformer: