Skip to content

Commit

Permalink
Feat: Add public methods for tagging images (#676)
Browse files Browse the repository at this point in the history
feat: Add methods for tagging
  • Loading branch information
Gorkem-Encord authored Nov 29, 2023
1 parent 5d1ab47 commit b1b3efd
Showing 1 changed file with 81 additions and 14 deletions.
95 changes: 81 additions & 14 deletions src/encord_active/public/active_project.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,22 @@
from dataclasses import dataclass
from typing import Any
from uuid import UUID, uuid4

import pandas as pd
from encord.objects import OntologyStructure
from sqlalchemy import MetaData, Table, create_engine, select, text
from sqlalchemy.engine import Engine
from sqlalchemy.sql import Select
from sqlmodel import Session
from sqlmodel import select as sqlmodel_select

from encord_active.db.models import ProjectTag, ProjectTaggedDataUnit


@dataclass
class DataUnitItem:
du_hash: str
frame: int


def get_active_engine(path_to_db: str) -> Engine:
Expand All @@ -25,41 +40,46 @@ def __init__(self, engine: Engine, project_name: str):
else:
self._project_hash = None

def get_ontology(self) -> OntologyStructure:
active_project = Table("active_project", self._metadata, autoload_with=self._engine)

stmt = select(active_project.c.project_ontology).where(active_project.c.project_hash == f"{self._project_hash}")
with Session(engine) as sess:
self._existing_tags = {
tag.name: tag.tag_hash
for tag in sess.exec(
sqlmodel_select(ProjectTag).where(ProjectTag.project_hash == self._project_hash)
).all()
}
sess.commit()

def _execute_statement(self, stmt: Select) -> Any:
with self._engine.connect() as connection:
result = connection.execute(stmt).fetchone()

if result is not None:
ontology = result[0]
return result[0]
else:
ontology = None
return None

return OntologyStructure.from_dict(ontology)
def get_ontology(self) -> OntologyStructure:
active_project = Table("active_project", self._metadata, autoload_with=self._engine)

stmt = select(active_project.c.project_ontology).where(active_project.c.project_hash == f"{self._project_hash}")

return OntologyStructure.from_dict(self._execute_statement(stmt))

def get_prediction_metrics(self) -> pd.DataFrame:
active_project_prediction = Table("active_project_prediction", self._metadata, autoload_with=self._engine)
stmt = select(active_project_prediction.c.prediction_hash).where(
active_project_prediction.c.project_hash == f"{self._project_hash}"
)

with self._engine.connect() as connection:
result = connection.execute(stmt).fetchone()

if result is not None:
prediction_hash = result[0]
else:
prediction_hash = None
prediction_hash = self._execute_statement(stmt)

active_project_prediction_analytics = Table(
"active_project_prediction_analytics", self._metadata, autoload_with=self._engine
)

stmt = select(
[
active_project_prediction_analytics.c.du_hash,
active_project_prediction_analytics.c.feature_hash,
active_project_prediction_analytics.c.metric_area,
active_project_prediction_analytics.c.metric_area_relative,
Expand All @@ -86,3 +106,50 @@ def get_prediction_metrics(self) -> pd.DataFrame:
df = pd.read_sql(stmt, conn)

return df

def get_images_metrics(self) -> pd.DataFrame:
active_project_analytics_data = Table(
"active_project_analytics_data", self._metadata, autoload_with=self._engine
)
stmt = select(active_project_analytics_data).where(
active_project_analytics_data.c.project_hash == f"{self._project_hash}"
)

with self._engine.begin() as conn:
df = pd.read_sql(stmt, conn)

return df

def get_or_add_tag(self, tag: str) -> UUID:
with Session(self._engine) as sess:
if tag in self._existing_tags:
return self._existing_tags[tag]

tag_hash = uuid4()
new_tag = ProjectTag(
tag_hash=tag_hash,
name=tag,
project_hash=self._project_hash,
description="",
)
sess.add(new_tag)
self._existing_tags[tag] = new_tag.tag_hash
sess.commit()
return tag_hash

def add_tag_to_data_units(
self,
tag_hash: UUID,
du_items: list[DataUnitItem],
) -> None:
with Session(self._engine) as sess:
for du_item in du_items:
sess.add(
ProjectTaggedDataUnit(
project_hash=self._project_hash,
du_hash=du_item.du_hash,
frame=du_item.frame,
tag_hash=tag_hash,
)
)
sess.commit()

0 comments on commit b1b3efd

Please sign in to comment.