From b1b3efd3bb915ff303bec6bde01a4e40dc5d5e4a Mon Sep 17 00:00:00 2001 From: Gorkem Polat <110481017+Gorkem-Encord@users.noreply.github.com> Date: Wed, 29 Nov 2023 16:26:21 +0000 Subject: [PATCH] Feat: Add public methods for tagging images (#676) feat: Add methods for tagging --- src/encord_active/public/active_project.py | 95 ++++++++++++++++++---- 1 file changed, 81 insertions(+), 14 deletions(-) diff --git a/src/encord_active/public/active_project.py b/src/encord_active/public/active_project.py index 1a1ab4379..141f0ae2b 100644 --- a/src/encord_active/public/active_project.py +++ b/src/encord_active/public/active_project.py @@ -1,7 +1,22 @@ +from dataclasses import dataclass +from typing import Any +from uuid import UUID, uuid4 + import pandas as pd from encord.objects import OntologyStructure from sqlalchemy import MetaData, Table, create_engine, select, text from sqlalchemy.engine import Engine +from sqlalchemy.sql import Select +from sqlmodel import Session +from sqlmodel import select as sqlmodel_select + +from encord_active.db.models import ProjectTag, ProjectTaggedDataUnit + + +@dataclass +class DataUnitItem: + du_hash: str + frame: int def get_active_engine(path_to_db: str) -> Engine: @@ -25,20 +40,30 @@ def __init__(self, engine: Engine, project_name: str): else: self._project_hash = None - def get_ontology(self) -> OntologyStructure: - active_project = Table("active_project", self._metadata, autoload_with=self._engine) - - stmt = select(active_project.c.project_ontology).where(active_project.c.project_hash == f"{self._project_hash}") + with Session(engine) as sess: + self._existing_tags = { + tag.name: tag.tag_hash + for tag in sess.exec( + sqlmodel_select(ProjectTag).where(ProjectTag.project_hash == self._project_hash) + ).all() + } + sess.commit() + def _execute_statement(self, stmt: Select) -> Any: with self._engine.connect() as connection: result = connection.execute(stmt).fetchone() if result is not None: - ontology = result[0] + return result[0] else: - ontology = None + return None - return OntologyStructure.from_dict(ontology) + def get_ontology(self) -> OntologyStructure: + active_project = Table("active_project", self._metadata, autoload_with=self._engine) + + stmt = select(active_project.c.project_ontology).where(active_project.c.project_hash == f"{self._project_hash}") + + return OntologyStructure.from_dict(self._execute_statement(stmt)) def get_prediction_metrics(self) -> pd.DataFrame: active_project_prediction = Table("active_project_prediction", self._metadata, autoload_with=self._engine) @@ -46,13 +71,7 @@ def get_prediction_metrics(self) -> pd.DataFrame: active_project_prediction.c.project_hash == f"{self._project_hash}" ) - with self._engine.connect() as connection: - result = connection.execute(stmt).fetchone() - - if result is not None: - prediction_hash = result[0] - else: - prediction_hash = None + prediction_hash = self._execute_statement(stmt) active_project_prediction_analytics = Table( "active_project_prediction_analytics", self._metadata, autoload_with=self._engine @@ -60,6 +79,7 @@ def get_prediction_metrics(self) -> pd.DataFrame: stmt = select( [ + active_project_prediction_analytics.c.du_hash, active_project_prediction_analytics.c.feature_hash, active_project_prediction_analytics.c.metric_area, active_project_prediction_analytics.c.metric_area_relative, @@ -86,3 +106,50 @@ def get_prediction_metrics(self) -> pd.DataFrame: df = pd.read_sql(stmt, conn) return df + + def get_images_metrics(self) -> pd.DataFrame: + active_project_analytics_data = Table( + "active_project_analytics_data", self._metadata, autoload_with=self._engine + ) + stmt = select(active_project_analytics_data).where( + active_project_analytics_data.c.project_hash == f"{self._project_hash}" + ) + + with self._engine.begin() as conn: + df = pd.read_sql(stmt, conn) + + return df + + def get_or_add_tag(self, tag: str) -> UUID: + with Session(self._engine) as sess: + if tag in self._existing_tags: + return self._existing_tags[tag] + + tag_hash = uuid4() + new_tag = ProjectTag( + tag_hash=tag_hash, + name=tag, + project_hash=self._project_hash, + description="", + ) + sess.add(new_tag) + self._existing_tags[tag] = new_tag.tag_hash + sess.commit() + return tag_hash + + def add_tag_to_data_units( + self, + tag_hash: UUID, + du_items: list[DataUnitItem], + ) -> None: + with Session(self._engine) as sess: + for du_item in du_items: + sess.add( + ProjectTaggedDataUnit( + project_hash=self._project_hash, + du_hash=du_item.du_hash, + frame=du_item.frame, + tag_hash=tag_hash, + ) + ) + sess.commit()