diff --git a/src/encord_active/public/active_project.py b/src/encord_active/public/active_project.py index 141f0ae2b..65a3c9717 100644 --- a/src/encord_active/public/active_project.py +++ b/src/encord_active/public/active_project.py @@ -1,122 +1,201 @@ from dataclasses import dataclass -from typing import Any +from pathlib import Path +from typing import Optional, Type, Union from uuid import UUID, uuid4 import pandas as pd from encord.objects import OntologyStructure -from sqlalchemy import MetaData, Table, create_engine, select, text +from sqlalchemy import Integer, func from sqlalchemy.engine import Engine -from sqlalchemy.sql import Select -from sqlmodel import Session -from sqlmodel import select as sqlmodel_select +from sqlmodel import Session, select -from encord_active.db.models import ProjectTag, ProjectTaggedDataUnit +from encord_active.db.models import ( + Project, + ProjectAnnotationAnalytics, + ProjectDataAnalytics, + ProjectDataUnitMetadata, + ProjectPrediction, + ProjectPredictionAnalytics, + ProjectTag, + ProjectTaggedDataUnit, + get_engine, +) +from encord_active.lib.common.data_utils import url_to_file_path + +_P = Project +_T = ProjectTag @dataclass class DataUnitItem: - du_hash: str + du_hash: UUID frame: int -def get_active_engine(path_to_db: str) -> Engine: - return create_engine(f"sqlite:///{path_to_db}") +AnalyticsModel = Union[Type[ProjectDataAnalytics], Type[ProjectPredictionAnalytics], Type[ProjectAnnotationAnalytics]] + + +def get_active_engine(path_to_db: Union[str, Path]) -> Engine: + path = Path(path_to_db) if isinstance(path_to_db, str) else path_to_db + return get_engine(path, use_alembic=False) class ActiveProject: - def __init__(self, engine: Engine, project_name: str): + @classmethod + def from_db_file(cls, db_file_path: Union[str, Path], project_name: str): + path = Path(db_file_path) if isinstance(db_file_path, str) else db_file_path + engine = get_active_engine(db_file_path) + return ActiveProject(engine, project_name, root_path=path.parent) + + def __init__(self, engine: Optional[Engine], project_name: str, root_path: Optional[Path] = None): self._engine = engine - self._metadata = MetaData(bind=self._engine) self._project_name = project_name + self._root_path = root_path - active_project = Table("active_project", self._metadata, autoload_with=self._engine) - stmt = select(active_project.c.project_hash).where(active_project.c.project_name == f"{self._project_name}") + with Session(self._engine) as sess: + res = sess.exec( + select(_P.project_hash, _P.project_ontology).where(_P.project_name == project_name).limit(1) + ).first() - with self._engine.connect() as connection: - result = connection.execute(stmt).fetchone() + if res is None: + raise ValueError(f"Couldn't find project with name `{project_name}` in the DB.") - if result is not None: - self._project_hash = result[0] - else: - self._project_hash = None + self.project_hash, ontology = res + project_tuples = sess.exec(select(_T.name, _T.tag_hash).where(_T.project_hash == self.project_hash)).all() - with Session(engine) as sess: - self._existing_tags = { - tag.name: tag.tag_hash - for tag in sess.exec( - sqlmodel_select(ProjectTag).where(ProjectTag.project_hash == self._project_hash) - ).all() - } - sess.commit() + # Assuming that there's just one prediction model + # FIXME: With multiple sets of model predictions, we should select the right UUID here + self._model_hash = sess.exec( + select(ProjectPrediction.prediction_hash) + .where(ProjectPrediction.project_hash == self.project_hash) + .limit(1) + ).first() - def _execute_statement(self, stmt: Select) -> Any: - with self._engine.connect() as connection: - result = connection.execute(stmt).fetchone() - - if result is not None: - return result[0] - else: - return None + self._existing_tags = dict(project_tuples) + self.ontology = OntologyStructure.from_dict(ontology) # type: ignore + # For backward compatibility def get_ontology(self) -> OntologyStructure: - active_project = Table("active_project", self._metadata, autoload_with=self._engine) - - stmt = select(active_project.c.project_ontology).where(active_project.c.project_hash == f"{self._project_hash}") + return self.ontology + + def _join_path_statement(self, stmt, base_model: AnalyticsModel): + stmt = stmt.add_columns(ProjectDataUnitMetadata.data_uri, ProjectDataUnitMetadata.data_uri_is_video) + stmt = stmt.join( + ProjectDataUnitMetadata, + onclause=(base_model.du_hash == ProjectDataUnitMetadata.du_hash) + & (base_model.frame == ProjectDataUnitMetadata.frame), + ).where(ProjectDataUnitMetadata.project_hash == self.project_hash) + + def transform(df): + _root_path = self._root_path + if _root_path is None: + raise ValueError("Root path is not set. Provide it in the constructor or use `from_db_file`") + + df["data_uri"] = df.data_uri.map(lambda p: url_to_file_path(p, project_dir=_root_path) if p else p) + return df + + return stmt, transform + + def _join_data_tags_statement(self, stmt, base_model: AnalyticsModel, group_by_cols=None): + stmt = stmt.add_columns(("[" + func.group_concat('"' + ProjectTag.name + '"', ", ") + "]").label("data_tags")) + stmt = ( + stmt.outerjoin( + ProjectTaggedDataUnit, + onclause=(base_model.du_hash == ProjectTaggedDataUnit.du_hash) + & (base_model.frame == ProjectTaggedDataUnit.frame), + ) + .outerjoin( + ProjectTag, + onclause=( + (ProjectTag.tag_hash == ProjectTaggedDataUnit.tag_hash) + & (ProjectTaggedDataUnit.project_hash == self.project_hash) + ), + ) + .group_by(base_model.du_hash, base_model.frame, *(group_by_cols or [])) + ) - return OntologyStructure.from_dict(self._execute_statement(stmt)) + def transform(df): + df["data_tags"] = df.data_tags.map(lambda x: eval(x) if x else []) + return df - def get_prediction_metrics(self) -> pd.DataFrame: - active_project_prediction = Table("active_project_prediction", self._metadata, autoload_with=self._engine) - stmt = select(active_project_prediction.c.prediction_hash).where( - active_project_prediction.c.project_hash == f"{self._project_hash}" - ) + return stmt, transform - prediction_hash = self._execute_statement(stmt) + def get_prediction_metrics(self, include_data_uris: bool = False, include_data_tags: bool = False) -> pd.DataFrame: + """ + Returns a pandas data frame with all the prediction metrics. - active_project_prediction_analytics = Table( - "active_project_prediction_analytics", self._metadata, autoload_with=self._engine - ) + Args: + include_data_uris: If set to true, the data frame will contain a data_uri column containing the path to the image file. + include_data_tags: If set to true, the data frame will contain a data_tags column containing a list of tags for the underlying image. + Disclaimer. We take no measures here to counteract SQL injections so avoid using "funky" characters like '"\\/` in your tag names. - stmt = select( - [ - active_project_prediction_analytics.c.du_hash, - active_project_prediction_analytics.c.feature_hash, - active_project_prediction_analytics.c.metric_area, - active_project_prediction_analytics.c.metric_area_relative, - active_project_prediction_analytics.c.metric_aspect_ratio, - active_project_prediction_analytics.c.metric_brightness, - active_project_prediction_analytics.c.metric_contrast, - active_project_prediction_analytics.c.metric_sharpness, - active_project_prediction_analytics.c.metric_red, - active_project_prediction_analytics.c.metric_green, - active_project_prediction_analytics.c.metric_blue, - active_project_prediction_analytics.c.metric_label_border_closeness, - active_project_prediction_analytics.c.metric_label_confidence, - text( - """CASE - WHEN feature_hash == match_feature_hash THEN 1 - ELSE 0 - END AS true_positive - """ - ), - ] - ).where(active_project_prediction_analytics.c.prediction_hash == prediction_hash) + """ + if self._model_hash is None: + raise ValueError(f"Project with name `{self._project_name}` does not have any model predictions") - with self._engine.begin() as conn: - df = pd.read_sql(stmt, conn) + with Session(self._engine) as sess: + P = ProjectPredictionAnalytics + stmt = select( # type: ignore + P.du_hash, + P.feature_hash, + P.metric_area, + P.metric_area_relative, + P.metric_aspect_ratio, + P.metric_brightness, + P.metric_contrast, + P.metric_sharpness, + P.metric_red, + P.metric_green, + P.metric_blue, + P.metric_label_border_closeness, + P.metric_label_confidence, + (P.feature_hash == P.match_feature_hash).cast(Integer).label("true_positive"), # type: ignore + ).where(P.project_hash == self.project_hash, P.prediction_hash == self._model_hash) + + transforms = [] + if include_data_uris: + stmt, transform = self._join_path_statement(stmt, P) + transforms.append(transform) + if include_data_tags: + stmt, transform = self._join_data_tags_statement(stmt, P, group_by_cols=[P.object_hash]) + transforms.append(transform) + + df = pd.DataFrame(sess.exec(stmt).all()) + df.columns = list(sess.exec(stmt).keys()) or df.columns # type: ignore + + for transform in transforms: + df = transform(df) return df - def get_images_metrics(self) -> pd.DataFrame: - active_project_analytics_data = Table( - "active_project_analytics_data", self._metadata, autoload_with=self._engine - ) - stmt = select(active_project_analytics_data).where( - active_project_analytics_data.c.project_hash == f"{self._project_hash}" - ) + def get_images_metrics(self, *, include_data_uris: bool = False, include_data_tags: bool = True) -> pd.DataFrame: + """ + Returns a pandas data frame with all the prediction metrics. + + Args: + include_data_uris: If set to true, the data frame will contain a data_uri column containing the path to the image file. + include_data_tags: If set to true, the data frame will contain a data_tags column containing a list of tags for the underlying image. + Disclaimer. We take no measures here to counteract SQL injections so avoid using "funky" characters like '"\\/` in your tag names. + """ + with Session(self._engine) as sess: + # hack to get all columns without the pydantic model + stmt = select(*[c for c in ProjectDataAnalytics.__table__.c][:3]).where( # type: ignore + ProjectDataAnalytics.project_hash == self.project_hash + ) + transforms = [] + if include_data_uris: + stmt, transform = self._join_path_statement(stmt, ProjectDataAnalytics) + transforms.append(transform) + if include_data_tags: + stmt, transform = self._join_data_tags_statement(stmt, ProjectDataAnalytics) + transforms.append(transform) + image_metrics = sess.exec(stmt).all() + + df = pd.DataFrame(image_metrics) + df.columns = list(sess.execute(stmt).keys()) or df.columns # type: ignore - with self._engine.begin() as conn: - df = pd.read_sql(stmt, conn) + for transform in transforms: + df = transform(df) return df @@ -129,7 +208,7 @@ def get_or_add_tag(self, tag: str) -> UUID: new_tag = ProjectTag( tag_hash=tag_hash, name=tag, - project_hash=self._project_hash, + project_hash=self.project_hash, description="", ) sess.add(new_tag) @@ -146,7 +225,7 @@ def add_tag_to_data_units( for du_item in du_items: sess.add( ProjectTaggedDataUnit( - project_hash=self._project_hash, + project_hash=self.project_hash, du_hash=du_item.du_hash, frame=du_item.frame, tag_hash=tag_hash,