From 2507f54bc984c21252a20355c4cd1c89047857d7 Mon Sep 17 00:00:00 2001
From: David Sapiro <115489098+Encord-davids@users.noreply.github.com>
Date: Wed, 4 Jan 2023 14:33:11 +0000
Subject: [PATCH] refactor: decouple rendering (#57)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* refactor: remove set project dir
* refactor: move db to lib (#29)
* refactor: adds a project file structure class
* refactor: remove st and state dependency from connections db
* refactor: unifies db queries for state population
* refactor: move tag count to lib
* refactor: moves metrics loading and its deps to lib (#36)
* refactor: dataframe validation and decoupling outliers computation
* refactor: adds lost changes due to merge conflicts
* refactor: decouple img utils
* refactor: metric explorer (#40)
* refactor: restructure metrics explorer statistics components
* moves pagination to a separate component
* moves get_histogram to a lib
* avoid using string enum value for embedding type
* split and decouple similarities and embeddings from explorer
* Refactor/decouple model quality pages (#46)
* chore: Add local ci-tests script to gitignore.
* refactor(model-quality): move data logic to lib
  - move data.py and map_mar.py from app to lib
  - add pandera schemas
  - refactor data frame merging for reuse
  - use session state setdefault for caching prediction data
  - type inputs to map_mar computations
  - schemas for model data frames
  - refactor grid view
  - redefine prediction scope
  - metric charts
  - move charts to lib
* refactor: export utils (#43)
* refactor: export actions utils
* refactor: dataset balancing
* refactor: move cli helper to lib
* Chore/gitignore (#45)
  chore: Add local ci-tests script to gitignore.
* refactor: lib restructure (#47)
* refactor: move metric from common to metrics
* refactor: dataset balance module
* refactor: adds sandbox projects module
* moves project stuff to project dir
* rename type definitions back to metric
* rename metrics files
* rename embedding files
* rename encord project to utils
* encord action utils
* moves common embeddings to utils
* reorganize writers
* moves tester to execute
* model predictions reader
* model predictions filters
* fix isort and black
* fix mypy errors
* fix(imports): broken imports after refactor
* fix: remove print
* fix: wrong dir in project fetch meta
* refactor: consolidated state and bugfixes (#56)
* refactor: consolidate state
* state cleanup and state mutator
* adds a lazy state initialization and some more cleanup
* refactor: organize state files
* removes global state const
* remove generic multiselect
* remove multiselect
* more cleanup of predictions state
* fix multiselect values
* completely gets rid of predictions state
* Fix: convert imported polygon points to absolute values (#52)
  fix: unnormalize polygon points
* adds db module
* fix missing similarities missing imports
* fix: double annotator stats title (#55)
  fix: remove duplicate title

Co-authored-by: Gorkem Polat <110481017+Gorkem-Encord@users.noreply.github.com>
Co-authored-by: Frederik Hvilshøj <93145535+frederik-encord@users.noreply.github.com>

* fix: linting error
* delete test file

Co-authored-by: Frederik Hvilshøj
Co-authored-by: Frederik Hvilshøj <93145535+frederik-encord@users.noreply.github.com>
Co-authored-by: Gorkem Polat <110481017+Gorkem-Encord@users.noreply.github.com>
---
.gitignore | 6 +- docs/builders/build_metrics_pages.py | 4 +- docs/docs/metrics/write-your-own.md | 6 +- docs/docs/sdk/importing-model-predictions.mdx | 16 +-
docs/docs/workflows/import-predictions.mdx | 8 +- ...ve_Building_your_own_metric_tutorial.ipynb | 8 +- poetry.lock | 56 +- pyproject.toml | 1 + .../app/actions_page/export_balance.py | 138 +---- .../app/actions_page/export_filter.py | 107 +++- src/encord_active/app/common/action_utils.py | 256 --------- src/encord_active/app/common/cli_helpers.py | 17 - .../app/common/components/__init__.py | 1 - .../common/components/annotator_statistics.py | 61 ++ .../app/common/components/data_tags.py | 6 +- .../app/common/components/label_statistics.py | 38 ++ .../app/common/components/metric_summary.py | 78 +++ .../app/common/components/multi_select.py | 39 -- .../app/common/components/paginator.py | 21 + .../app/common/components/prediction_grid.py | 115 ++++ .../app/common/components/similarities.py | 175 ++++++ .../app/common/components/slicer.py | 22 + .../components/tags}/__init__.py | 0 .../{ => tags}/bulk_tagging_form.py | 23 +- .../{ => tags}/individual_tagging.py | 21 +- .../components/{ => tags}/tag_creator.py | 37 +- src/encord_active/app/common/css.py | 2 +- .../app/common/embedding_utils.py | 219 -------- src/encord_active/app/common/metric.py | 43 -- src/encord_active/app/common/page.py | 53 +- src/encord_active/app/common/state.py | 136 ++--- src/encord_active/app/common/state_hooks.py | 65 +++ src/encord_active/app/common/utils.py | 213 +------ src/encord_active/app/data_quality/common.py | 82 --- .../app/data_quality/sub_pages/explorer.py | 520 +++++++----------- .../app/data_quality/sub_pages/summary.py | 116 +--- src/encord_active/app/db/connection.py | 19 - .../app/model_quality/components/__init__.py | 2 - .../components/false_negative_view.py | 97 ---- .../model_quality/components/index_view.py | 79 --- .../model_quality/components/metric_view.py | 85 --- .../app/model_quality/components/utils.py | 35 -- src/encord_active/app/model_quality/data.py | 218 -------- .../app/model_quality/settings.py | 37 +- .../app/model_quality/sub_pages/__init__.py | 71 ++- .../sub_pages/false_negatives.py | 51 +- .../sub_pages/false_positives.py | 42 +- .../app/model_quality/sub_pages/metrics.py | 161 +----- .../sub_pages/performance_by_metric.py | 284 +++------- .../model_quality/sub_pages/true_positives.py | 43 +- src/encord_active/app/streamlit_entrypoint.py | 18 +- src/encord_active/app/views/landing_page.py | 4 +- src/encord_active/app/views/metrics.py | 9 +- src/encord_active/app/views/model_quality.py | 79 +-- src/encord_active/cli/imports.py | 4 +- src/encord_active/cli/main.py | 6 +- src/encord_active/cli/print.py | 4 +- src/encord_active/cli/utils/coco.py | 4 +- src/encord_active/cli/utils/encord.py | 4 +- src/encord_active/cli/utils/streamlit.py | 3 + src/encord_active/lib/charts/__init__.py | 0 src/encord_active/lib/charts/histogram.py | 28 + .../lib/charts/metric_importance.py | 70 +++ .../lib/charts/partition_histogram.py | 17 + .../lib/charts/performance_by_metric.py | 195 +++++++ .../lib/charts/precision_recall.py | 82 +++ src/encord_active/lib/charts/scopes.py | 7 + .../{app => lib}/common/colors.py | 0 src/encord_active/lib/common/image_utils.py | 252 +++++++++ src/encord_active/lib/common/iterator.py | 2 +- src/encord_active/lib/common/tester.py | 71 --- src/encord_active/lib/common/utils.py | 48 +- src/encord_active/lib/common/writer.py | 118 ---- src/encord_active/lib/dataset/__init__.py | 0 src/encord_active/lib/dataset/balance.py | 79 +++ src/encord_active/lib/dataset/outliers.py | 54 ++ src/encord_active/lib/db/__init__.py | 0 src/encord_active/lib/db/connection.py | 26 + 
src/encord_active/lib/db/helpers/__init__.py | 0 src/encord_active/lib/db/helpers/tags.py | 20 + .../{app => lib}/db/merged_metrics.py | 15 +- .../{app => lib}/db/predictions.py | 6 +- src/encord_active/{app => lib}/db/tags.py | 28 +- .../lib/embeddings/{cnn_embed.py => cnn.py} | 25 +- .../embeddings/{hu_embed.py => hu_moments.py} | 2 +- src/encord_active/lib/embeddings/utils.py | 150 +++++ src/encord_active/lib/embeddings/writer.py | 60 ++ src/encord_active/lib/encord/__init__.py | 0 src/encord_active/lib/encord/actions.py | 250 +++++++++ .../lib/encord/{project.py => utils.py} | 0 src/encord_active/lib/metrics/__init__.py | 0 src/encord_active/lib/metrics/example.py | 23 +- src/encord_active/lib/metrics/execute.py | 149 +++++ .../geometric/annotation_duplicates.py | 15 +- .../lib/metrics/geometric/hu_static.py | 18 +- .../lib/metrics/geometric/hu_temporal.py | 17 +- .../geometric/image_border_closeness.py | 11 +- .../lib/metrics/geometric/object_size.py | 17 +- .../geometric/occlusion_detection_video.py | 11 +- .../lib/metrics/heuristic/_annotation_time.py | 11 +- .../heuristic/high_iou_changing_classes.py | 19 +- .../lib/metrics/heuristic/img_features.py | 39 +- .../missing_objects_and_wrong_tracks.py | 21 +- .../lib/metrics/heuristic/object_counting.py | 11 +- .../lib/{common => metrics}/metric.py | 22 +- src/encord_active/lib/metrics/run_all.py | 77 --- .../metrics/semantic/_class_uncertainty.py | 15 +- .../metrics/semantic/_heatmap_uncertainty.py | 11 +- .../semantic/img_classification_quality.py | 8 +- .../metrics/semantic/img_object_quality.py | 16 +- src/encord_active/lib/metrics/utils.py | 139 +++++ src/encord_active/lib/metrics/writer.py | 70 +++ .../lib/model_predictions/filters.py | 36 ++ .../lib/model_predictions/importers.py | 4 +- .../model_predictions}/map_mar.py | 198 ++++--- .../lib/model_predictions/reader.py | 169 ++++++ .../{prediction_writer.py => writer.py} | 6 +- src/encord_active/lib/project/__init__.py | 0 .../lib/{common => project}/project.py | 51 +- .../lib/project/project_file_structure.py | 49 ++ .../sandbox_projects.py} | 0 121 files changed, 3752 insertions(+), 3154 deletions(-) delete mode 100644 src/encord_active/app/common/action_utils.py delete mode 100644 src/encord_active/app/common/cli_helpers.py create mode 100644 src/encord_active/app/common/components/annotator_statistics.py create mode 100644 src/encord_active/app/common/components/label_statistics.py create mode 100644 src/encord_active/app/common/components/metric_summary.py delete mode 100644 src/encord_active/app/common/components/multi_select.py create mode 100644 src/encord_active/app/common/components/paginator.py create mode 100644 src/encord_active/app/common/components/prediction_grid.py create mode 100644 src/encord_active/app/common/components/similarities.py create mode 100644 src/encord_active/app/common/components/slicer.py rename src/encord_active/app/{actions_page/coco_parser => common/components/tags}/__init__.py (100%) rename src/encord_active/app/common/components/{ => tags}/bulk_tagging_form.py (72%) rename src/encord_active/app/common/components/{ => tags}/individual_tagging.py (71%) rename src/encord_active/app/common/components/{ => tags}/tag_creator.py (61%) delete mode 100644 src/encord_active/app/common/embedding_utils.py delete mode 100644 src/encord_active/app/common/metric.py create mode 100644 src/encord_active/app/common/state_hooks.py delete mode 100644 src/encord_active/app/data_quality/common.py delete mode 100644 src/encord_active/app/db/connection.py delete 
mode 100644 src/encord_active/app/model_quality/components/__init__.py delete mode 100644 src/encord_active/app/model_quality/components/false_negative_view.py delete mode 100644 src/encord_active/app/model_quality/components/index_view.py delete mode 100644 src/encord_active/app/model_quality/components/metric_view.py delete mode 100644 src/encord_active/app/model_quality/components/utils.py delete mode 100644 src/encord_active/app/model_quality/data.py create mode 100644 src/encord_active/lib/charts/__init__.py create mode 100644 src/encord_active/lib/charts/histogram.py create mode 100644 src/encord_active/lib/charts/metric_importance.py create mode 100644 src/encord_active/lib/charts/partition_histogram.py create mode 100644 src/encord_active/lib/charts/performance_by_metric.py create mode 100644 src/encord_active/lib/charts/precision_recall.py create mode 100644 src/encord_active/lib/charts/scopes.py rename src/encord_active/{app => lib}/common/colors.py (100%) create mode 100644 src/encord_active/lib/common/image_utils.py delete mode 100644 src/encord_active/lib/common/tester.py create mode 100644 src/encord_active/lib/dataset/__init__.py create mode 100644 src/encord_active/lib/dataset/balance.py create mode 100644 src/encord_active/lib/dataset/outliers.py create mode 100644 src/encord_active/lib/db/__init__.py create mode 100644 src/encord_active/lib/db/connection.py create mode 100644 src/encord_active/lib/db/helpers/__init__.py create mode 100644 src/encord_active/lib/db/helpers/tags.py rename src/encord_active/{app => lib}/db/merged_metrics.py (89%) rename src/encord_active/{app => lib}/db/predictions.py (93%) rename src/encord_active/{app => lib}/db/tags.py (70%) rename src/encord_active/lib/embeddings/{cnn_embed.py => cnn.py} (93%) rename src/encord_active/lib/embeddings/{hu_embed.py => hu_moments.py} (96%) create mode 100644 src/encord_active/lib/embeddings/utils.py create mode 100644 src/encord_active/lib/embeddings/writer.py create mode 100644 src/encord_active/lib/encord/__init__.py create mode 100644 src/encord_active/lib/encord/actions.py rename src/encord_active/lib/encord/{project.py => utils.py} (100%) create mode 100644 src/encord_active/lib/metrics/__init__.py create mode 100644 src/encord_active/lib/metrics/execute.py rename src/encord_active/lib/{common => metrics}/metric.py (90%) delete mode 100644 src/encord_active/lib/metrics/run_all.py create mode 100644 src/encord_active/lib/metrics/utils.py create mode 100644 src/encord_active/lib/metrics/writer.py create mode 100644 src/encord_active/lib/model_predictions/filters.py rename src/encord_active/{app/model_quality => lib/model_predictions}/map_mar.py (55%) create mode 100644 src/encord_active/lib/model_predictions/reader.py rename src/encord_active/lib/model_predictions/{prediction_writer.py => writer.py} (98%) create mode 100644 src/encord_active/lib/project/__init__.py rename src/encord_active/lib/{common => project}/project.py (86%) create mode 100644 src/encord_active/lib/project/project_file_structure.py rename src/encord_active/lib/{metrics/fetch_prebuilt_metrics.py => project/sandbox_projects.py} (100%) diff --git a/.gitignore b/.gitignore index 86eb869ff..e263f46ba 100644 --- a/.gitignore +++ b/.gitignore @@ -339,7 +339,7 @@ Sessionx.vim # Temporary .netrwhist # Auto-generated tag files -tags +# tags # Persistent undo [._]*.un~ @@ -348,9 +348,7 @@ tags .idea -/viewer/outputs/ -/outputs/ -/viewer/pages/outputs/ /local_tests/ imports.prof +ci-tests diff --git a/docs/builders/build_metrics_pages.py 
b/docs/builders/build_metrics_pages.py index 02ba2dc86..ce8c9cc3d 100644 --- a/docs/builders/build_metrics_pages.py +++ b/docs/builders/build_metrics_pages.py @@ -8,8 +8,8 @@ from tabulate import tabulate -import encord_active.lib.common.metric as metrics -import encord_active.lib.metrics.run_all as run_all +import encord_active.lib.metrics.execute as run_all +import encord_active.lib.metrics.metric as metrics github_url = "https://github.com/encord-team/encord-active" descriptions = { diff --git a/docs/docs/metrics/write-your-own.md b/docs/docs/metrics/write-your-own.md index e6ac01375..467ec1053 100644 --- a/docs/docs/metrics/write-your-own.md +++ b/docs/docs/metrics/write-your-own.md @@ -25,8 +25,8 @@ Your implementation should call `writer.write(, )` for eve ```python from encord_active.lib.common.iterator import Iterator -from encord_active.lib.common.metric import AnnotationType, DataType, MetricType, Metric -from encord_active.lib.common.writer import CSVMetricWriter +from encord_active.lib.metrics.metric import AnnotationType, DataType, MetricType, Metric +from encord_active.lib.metrics.writer import CSVMetricWriter class ExampleMetric(Metric): TITLE = "Example Title" @@ -63,7 +63,7 @@ class ExampleMetric(Metric): if __name__ == "__main__": import sys from pathlib import Path - from encord_active.lib.common.tester import perform_test + from encord_active.lib.metric.execute import perform_test path = sys.argv[1] perform_test(ExampleMetric(), data_dir=Path(path)) diff --git a/docs/docs/sdk/importing-model-predictions.mdx b/docs/docs/sdk/importing-model-predictions.mdx index 01c715f56..09c657fed 100644 --- a/docs/docs/sdk/importing-model-predictions.mdx +++ b/docs/docs/sdk/importing-model-predictions.mdx @@ -80,7 +80,7 @@ That is: - `h`: box pixel height / image height ```python -from encord_active.app.db.predictions import BoundingBox, Prediction, Format +from encord_active.lib.db.predictions import BoundingBox, Prediction, Format prediction = Prediction( data_hash = "", @@ -112,7 +112,7 @@ BoundingBox(x=10/img_w, y=25/img_h, w=200/img_w, h=150/img_h) You specify masks as binary `numpy` arrays of size [height, width] and with `dtype` `np.uint8`. ```python -from encord_active.app.db.predictions import Prediction, Format +from encord_active.lib.db.predictions import Prediction, Format prediction = Prediction( data_hash = "", @@ -136,7 +136,7 @@ That is, an array of relative (`x`, `y`) coordinates: - `y`: relative y-coordinate of each point of the polygon (pixel coordinate / image height) ```python -from encord_active.app.db.predictions import Prediction, Format +from encord_active.lib.db.predictions import Prediction, Format import numpy as np polygon = np.array([ @@ -182,7 +182,7 @@ You can copy the appropriate snippet based on your prediction format from above Note the highlighted line, which defines where the `.pkl` file will be stored. 
```python showLineNumbers -from encord_active.app.db.predictions import Prediction, Format +from encord_active.lib.db.predictions import Prediction, Format predictions_to_store = [] @@ -224,7 +224,7 @@ The code would change to something similar to this: ```python # highlight-next-line -from encord_active.lib.model_predictions.prediction_writer import PredictionWriter +from encord_active.lib.model_predictions.writer import PredictionWriter def predict(test_loader): # highlight-next-line @@ -367,7 +367,7 @@ To import the predictions, you do the following import json from encord_active.lib.model_predictions.importers import import_KITTI_labels -from encord_active.lib.model_predictions.prediction_writer import PredictionWriter +from encord_active.lib.model_predictions.writer import PredictionWriter # highlight-next-line predictions_root = Path("/path/to/your/predictions") @@ -416,7 +416,7 @@ You can use this template where the highlighted lined are what you need to chang ```python from encord_active.lib.model_predictions.importers import import_mask_predictions -from encord_active.lib.model_predictions.prediction_writer import PredictionWriter +from encord_active.lib.model_predictions.writer import PredictionWriter # highlight-start class_map = { @@ -470,7 +470,7 @@ For this, you can use these lines of code: ```python from encord_active.lib.model_predictions.iterator import PredictionIterator -from encord_active.lib.metrics.run_all import run_metrics +from encord_active.lib.metrics.execute import run_metrics run_metrics(data_dir=data_dir, iterator_cls=PredictionIterator) ``` diff --git a/docs/docs/workflows/import-predictions.mdx b/docs/docs/workflows/import-predictions.mdx index b1600c2db..f17a63409 100644 --- a/docs/docs/workflows/import-predictions.mdx +++ b/docs/docs/workflows/import-predictions.mdx @@ -126,7 +126,7 @@ That is: - `h`: box pixel height / image height ```python -from encord_active.app.db.predictions import BoundingBox, Prediction, Format +from encord_active.lib.db.predictions import BoundingBox, Prediction, Format prediction = Prediction( data_hash = "", @@ -158,7 +158,7 @@ BoundingBox(x=10/img_w, y=25/img_h, w=200/img_w, h=150/img_h) You specify masks as binary `numpy` arrays of size [height, width] and with `dtype` `np.uint8`. ```python -from encord_active.app.db.predictions import Prediction, Format +from encord_active.lib.db.predictions import Prediction, Format prediction = Prediction( data_hash = "", @@ -182,7 +182,7 @@ That is, an array of relative (`x`, `y`) coordinates: - `y`: relative y-coordinate of each point of the polygon (pixel coordinate / image height) ```python -from encord_active.app.db.predictions import Prediction, Format +from encord_active.lib.db.predictions import Prediction, Format import numpy as np polygon = np.array([ @@ -228,7 +228,7 @@ You can copy the appropriate snippet based on your prediction format from above Note the highlighted line, which defines where the `.pkl` file will be stored. 
```python showLineNumbers -from encord_active.app.db.predictions import Prediction, Format +from encord_active.lib.db.predictions import Prediction, Format predictions_to_store = [] diff --git a/examples/Encord_Active_Building_your_own_metric_tutorial.ipynb b/examples/Encord_Active_Building_your_own_metric_tutorial.ipynb index db97ea26e..de27cd50c 100644 --- a/examples/Encord_Active_Building_your_own_metric_tutorial.ipynb +++ b/examples/Encord_Active_Building_your_own_metric_tutorial.ipynb @@ -504,8 +504,8 @@ "import numpy as np\n", "from encord_active.lib.common import utils\n", "from encord_active.lib.common.iterator import Iterator\n", - "from encord_active.lib.common.metric import AnnotationType, DataType, Metric, MetricType\n", - "from encord_active.lib.common.writer import CSVMetricWriter\n", + "from encord_active.lib.metrics.metric import AnnotationType, DataType, Metric, MetricType\n", + "from encord_active.lib.metrics.writer import CSVMetricWriter\n", "from loguru import logger\n", "\n", "logger = logger.opt(colors=True)\n", @@ -563,7 +563,7 @@ " import sys\n", " from pathlib import Path\n", "\n", - " from encord_active.lib.common.tester import perform_test\n", + " from encord_active.lib.metric.execute import perform_test\n", "\n", " path = sys.argv[1]\n", " perform_test(InstanceDeviation(), data_dir=Path(path), use_cache_only=True)" @@ -717,4 +717,4 @@ ] } ] -} \ No newline at end of file +} diff --git a/poetry.lock b/poetry.lock index 11a944e50..ca3426902 100644 --- a/poetry.lock +++ b/poetry.lock @@ -666,7 +666,7 @@ reports = ["lxml"] name = "mypy-extensions" version = "0.4.3" description = "Experimental type system extensions for programs checked with the mypy typechecker." -category = "dev" +category = "main" optional = false python-versions = "*" @@ -747,6 +747,36 @@ pytz = ">=2020.1" [package.extras] test = ["hypothesis (>=5.5.3)", "pytest (>=6.0)", "pytest-xdist (>=1.31)"] +[[package]] +name = "pandera" +version = "0.13.4" +description = "A light-weight and flexible data validation and testing tool for statistical data objects." +category = "main" +optional = false +python-versions = ">=3.7" + +[package.dependencies] +numpy = ">=1.19.0" +packaging = ">=20.0" +pandas = ">=1.2.0" +pydantic = "*" +typing-inspect = ">=0.6.0" +wrapt = "*" + +[package.extras] +all = ["black", "dask", "fastapi", "frictionless", "geopandas", "hypothesis (>=5.41.1)", "modin", "pandas-stubs (<=1.4.3.220807)", "pyspark (>=3.2.0)", "pyyaml (>=5.1)", "ray (<=1.7.0)", "scipy", "shapely"] +dask = ["dask"] +fastapi = ["fastapi"] +geopandas = ["geopandas", "shapely"] +hypotheses = ["scipy"] +io = ["black", "frictionless", "pyyaml (>=5.1)"] +modin = ["dask", "modin", "ray (<=1.7.0)"] +modin-dask = ["dask", "modin"] +modin-ray = ["modin", "ray (<=1.7.0)"] +mypy = ["pandas-stubs (<=1.4.3.220807)"] +pyspark = ["pyspark (>=3.2.0)"] +strategies = ["hypothesis (>=5.41.1)"] + [[package]] name = "parso" version = "0.8.3" @@ -1557,6 +1587,18 @@ category = "main" optional = false python-versions = ">=3.7" +[[package]] +name = "typing-inspect" +version = "0.8.0" +description = "Runtime inspection utilities for typing module." +category = "main" +optional = false +python-versions = "*" + +[package.dependencies] +mypy-extensions = ">=0.3.0" +typing-extensions = ">=3.7.4" + [[package]] name = "tzdata" version = "2022.5" @@ -1659,7 +1701,7 @@ dev = ["black (>=19.3b0)", "pytest (>=4.6.2)"] name = "wrapt" version = "1.14.1" description = "Module for decorators, wrappers and monkey patching." 
-category = "dev" +category = "main" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7" @@ -1681,7 +1723,7 @@ coco = ["pycocotools"] [metadata] lock-version = "1.1" python-versions = ">=3.9,<3.9.7 || >3.9.7,<3.11" -content-hash = "2fe8f75b5d1522e253b5b12031500d9a6aa207843f5021be26b097e25a411257" +content-hash = "acf2b6494409623e6fe787c33693cb5a678cb525d153113abd12a297080a5c57" [metadata.files] altair = [ @@ -2409,6 +2451,10 @@ pandas = [ {file = "pandas-1.5.1-cp39-cp39-win_amd64.whl", hash = "sha256:6bb391659a747cf4f181a227c3e64b6d197100d53da98dcd766cc158bdd9ec68"}, {file = "pandas-1.5.1.tar.gz", hash = "sha256:249cec5f2a5b22096440bd85c33106b6102e0672204abd2d5c014106459804ee"}, ] +pandera = [ + {file = "pandera-0.13.4-py3-none-any.whl", hash = "sha256:9e91687861406284270add1d467f204630377892e7a4b45809bb7546f0013153"}, + {file = "pandera-0.13.4.tar.gz", hash = "sha256:6ef2b7ee00d3439ac815d4347984421a08502da1020cec60c06dd0135e8aee2f"}, +] parso = [ {file = "parso-0.8.3-py2.py3-none-any.whl", hash = "sha256:c001d4636cd3aecdaf33cbb40aebb59b094be2a74c556778ef5576c175e19e75"}, {file = "parso-0.8.3.tar.gz", hash = "sha256:8c07be290bb59f03588915921e29e8a50002acaf2cdc5fa0e0114f91709fafa0"}, @@ -3009,6 +3055,10 @@ typing-extensions = [ {file = "typing_extensions-4.4.0-py3-none-any.whl", hash = "sha256:16fa4864408f655d35ec496218b85f79b3437c829e93320c7c9215ccfd92489e"}, {file = "typing_extensions-4.4.0.tar.gz", hash = "sha256:1511434bb92bf8dd198c12b1cc812e800d4181cfcb867674e0f8279cc93087aa"}, ] +typing-inspect = [ + {file = "typing_inspect-0.8.0-py3-none-any.whl", hash = "sha256:5fbf9c1e65d4fa01e701fe12a5bca6c6e08a4ffd5bc60bfac028253a447c5188"}, + {file = "typing_inspect-0.8.0.tar.gz", hash = "sha256:8b1ff0c400943b6145df8119c41c244ca8207f1f10c9c057aeed1560e4806e3d"}, +] tzdata = [ {file = "tzdata-2022.5-py2.py3-none-any.whl", hash = "sha256:323161b22b7802fdc78f20ca5f6073639c64f1a7227c40cd3e19fd1d0ce6650a"}, {file = "tzdata-2022.5.tar.gz", hash = "sha256:e15b2b3005e2546108af42a0eb4ccab4d9e225e2dfbf4f77aad50c70a4b1f3ab"}, diff --git a/pyproject.toml b/pyproject.toml index b7fdda6fc..88ab65e1e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -66,6 +66,7 @@ toml = "^0.10.2" pydantic = "^1.10.2" pycocotools = {version = "^2.0.6", optional = true} psutil = "^5.9.4" +pandera = "^0.13.4" [tool.poetry.extras] diff --git a/src/encord_active/app/actions_page/export_balance.py b/src/encord_active/app/actions_page/export_balance.py index ca3b4a6e5..e01b1174a 100644 --- a/src/encord_active/app/actions_page/export_balance.py +++ b/src/encord_active/app/actions_page/export_balance.py @@ -1,31 +1,19 @@ -import io -import json from datetime import datetime -from encodings import utf_8 from typing import Dict, List, Tuple -from zipfile import ZipFile -import altair as alt -import numpy as np import pandas as pd import streamlit as st -from tqdm import tqdm -import encord_active.app.common.state as state -from encord_active.app.common import metric as iutils -from encord_active.app.common.components import multiselect_with_all_option -from encord_active.app.common.metric import MetricData +from encord_active.app.common.state import get_state +from encord_active.app.common.state_hooks import use_state from encord_active.app.common.utils import set_page_config, setup_page -from encord_active.app.data_quality.common import MetricType, load_available_metrics -from encord_active.lib.coco.encoder import generate_coco_file - - -def add_partition(): - 
st.session_state[state.NUMBER_OF_PARTITIONS] += 1 - - -def remove_partition(): - st.session_state[state.NUMBER_OF_PARTITIONS] -= 1 +from encord_active.lib.charts.partition_histogram import get_partition_histogram +from encord_active.lib.dataset.balance import balance_dataframe, get_partitions_zip +from encord_active.lib.metrics.utils import ( + MetricData, + MetricScope, + load_available_metrics, +) def metrics_panel() -> Tuple[List[MetricData], int]: @@ -37,20 +25,19 @@ def metrics_panel() -> Tuple[List[MetricData], int]: seed (int): The seed for the random sampling. """ # TODO - add label metrics - metrics = load_available_metrics(MetricType.DATA_QUALITY.value) # type: ignore + metrics = load_available_metrics(get_state().project_paths.metrics, MetricScope.DATA_QUALITY) metric_names = [metric.name for metric in metrics] col1, col2 = st.columns([6, 1]) with col1: - selected_metric_names = multiselect_with_all_option( - label="Metrics to balance", + selected_metric_names = st.multiselect( + label="Filter by metric", options=metric_names, key="balance_metrics", - default=["All"], ) seed = col2.number_input("Seed", value=42, step=1, key="seed") - if "All" in selected_metric_names: + if not selected_metric_names: selected_metric_names = metric_names selected_metrics = [metric for metric in metrics if metric.name in selected_metric_names] return selected_metrics, int(seed) @@ -63,8 +50,16 @@ def partitions_panel() -> Dict[str, int]: Returns: A dictionary with the partition names as keys and the partition sizes as values. """ + get_partitions_number, set_partitions_number = use_state(1) + + def add_partition(): + set_partitions_number(lambda prev: prev + 1) + + def remove_partition(): + set_partitions_number(lambda prev: prev - 1) + partition_sizes = {} - for i in range(st.session_state[state.NUMBER_OF_PARTITIONS]): + for i in range(get_partitions_number()): partition_columns = st.columns((4, 12, 1)) partition_name = partition_columns[0].text_input( f"Name of partition {i + 1}", key=f"name_partition_{i + 1}", value=f"Partition {i + 1}" @@ -74,7 +69,7 @@ def partitions_panel() -> Dict[str, int]: key=f"size_partition_{i + 1}", min_value=1, max_value=100, - value=100 // st.session_state[state.NUMBER_OF_PARTITIONS], + value=100 // get_partitions_number(), step=1, ) if i > 0: @@ -92,71 +87,6 @@ def partitions_panel() -> Dict[str, int]: return partition_sizes -def balance_dataframe(selected_metrics: List[MetricData], partition_sizes: Dict[str, int], seed: int) -> pd.DataFrame: - """ - Balances the dataset over the selected metrics and partition sizes. - Currently, it is done by random sampling. - - Args: - selected_metrics (List[MetricData]): The metrics to balance over. - partition_sizes (Dict[str,int]): The dictionary of partition names : partition sizes. - seed (int): The seed for the random sampling. - - Returns: - pd.Dataframe: A dataframe with the following columns: sample identifiers, metric values and allocated partition. 
- """ - # Collect metric dataframes - merged_df_list = [] - for i, m in enumerate(selected_metrics): - df = iutils.load_metric(m, normalize=False).copy() - merged_df_list.append(df[["identifier", "score"]].rename(columns={"score": m.name})) - - # Merge all dataframes by identifier - merged_df = merged_df_list.pop() - for df_tmp in merged_df_list: - merged_df = merged_df.merge(df_tmp, on="identifier", how="outer") - - # Randomly sample from each partition and add column to merged_df - n_samples = len(merged_df) - selection_df = merged_df.copy() - merged_df["partition"] = "" - for partition_name, partition_size in [(k, v) for k, v in partition_sizes.items()][:-1]: - n_partition = int(np.floor(n_samples * partition_size / 100)) - partition_df = selection_df.sample(n=n_partition, replace=False, random_state=seed) - # Remove samples from selection_df - selection_df = selection_df[~selection_df["identifier"].isin(partition_df["identifier"])] - # Add partition column to merged_df - merged_df.loc[partition_df.index, "partition"] = partition_name - - # Assign the remaining samples to the last partition - merged_df.loc[merged_df["partition"] == "", "partition"] = list(partition_sizes.keys())[-1] - return merged_df - - -def get_partitions_zip(partition_dict: Dict[str, pd.DataFrame]) -> bytes: - """ - Creates a zip file with a COCO json object for each partition. - - Args: - partition_dict (Dict[str, pd.DataFrame]): A dictionary of partition names : partition dataframes. - - Returns: - bytes: The zip file as a byte array. - """ - with st.spinner("Generating COCO files"): - zip_io = io.BytesIO() - with ZipFile(zip_io, mode="w") as zf: - partition_dict.pop("Unassigned", None) - for partition_name, partition in tqdm(partition_dict.items(), desc="Generating COCO files"): - coco_json = generate_coco_file(partition, st.session_state.project_dir, st.session_state.ontology_file) - with zf.open(partition_name.replace(" ", "_").lower() + ".json", "w") as zip_file: - writer = utf_8.StreamWriter(zip_file) - json.dump(coco_json, writer) # type: ignore - zip_io.seek(0) - partitions_zip_file = zip_io.read() - return partitions_zip_file - - def export_balance(): setup_page() st.header("Balance & Export") @@ -164,9 +94,6 @@ def export_balance(): "Here you can create balanced partitions of your dataset over a set of metrics and export them as a CSV file." ) - if not st.session_state.get(state.NUMBER_OF_PARTITIONS): - st.session_state[state.NUMBER_OF_PARTITIONS] = 1 - selected_metrics, seed = metrics_panel() partition_sizes = partitions_panel() @@ -178,7 +105,7 @@ def export_balance(): st.warning("Due to rounding errors, the resulting partition sizes might not be exactly as specified. 
") cols = st.columns(len(partition_sizes)) partition_dict: Dict[str, pd.DataFrame] = {} - for col, (partition_name, partition_size) in zip(cols, partition_sizes.items()): + for col, (partition_name, _) in zip(cols, partition_sizes.items()): partition = balanced_df[balanced_df["partition"] == partition_name] partition_dict[partition_name] = partition n_partition_df = partition.shape[0] @@ -192,7 +119,8 @@ def export_balance(): help="Generate COCO file with filtered data", ) - partitions_zip_file = get_partitions_zip(partition_dict) if is_pressed else "" + with st.spinner("Generating COCO files"): + partitions_zip_file = get_partitions_zip(partition_dict, get_state().project_paths) if is_pressed else "" action_columns[1].download_button( "⬇ Download filtered data", @@ -206,19 +134,7 @@ def export_balance(): # Plot distribution of partitions for each metric for m in selected_metrics: with st.expander(f"{m.name} - Partition distribution"): - # Get altair layered histogram of partitions - chart = ( - alt.Chart(balanced_df) - .mark_bar( - binSpacing=0, - ) - .encode( - x=alt.X(f"{m.name}:Q", bin=alt.Bin(maxbins=50)), - y="count()", - color="partition:N", - tooltip=["partition", "count()"], - ) - ) + chart = get_partition_histogram(balanced_df, m.name) st.altair_chart(chart, use_container_width=True) diff --git a/src/encord_active/app/actions_page/export_filter.py b/src/encord_active/app/actions_page/export_filter.py index a71e250df..d59b09580 100644 --- a/src/encord_active/app/actions_page/export_filter.py +++ b/src/encord_active/app/actions_page/export_filter.py @@ -10,11 +10,15 @@ is_numeric_dtype, ) -import encord_active.app.common.state as state -from encord_active.app.common.action_utils import create_new_project_on_encord_platform +from encord_active.app.common.state import get_state +from encord_active.app.common.state_hooks import use_state from encord_active.app.common.utils import set_page_config, setup_page -from encord_active.app.db.tags import Tags from encord_active.lib.coco.encoder import generate_coco_file +from encord_active.lib.common.utils import ProjectNotFound +from encord_active.lib.db.tags import Tags +from encord_active.lib.encord.actions import ( # create_a_new_dataset,; create_new_project_on_encord_platform,; get_project_user_client, + EncordActions, +) def filter_dataframe(df: pd.DataFrame) -> pd.DataFrame: @@ -98,14 +102,18 @@ def filter_dataframe(df: pd.DataFrame) -> pd.DataFrame: def export_filter(): + get_filtered_row_count, set_filtered_row_count = use_state(0) + get_clone_button, set_clone_button = use_state(False) + setup_page() message_placeholder = st.empty() st.header("Filter & Export") - filtered_df = filter_dataframe(st.session_state[state.MERGED_DATAFRAME].copy()) + filtered_df = filter_dataframe(get_state().merged_metrics.copy()) filtered_df.reset_index(inplace=True) - st.markdown(f"**Total row:** {filtered_df.shape[0]}") + row_count = filtered_df.shape[0] + st.markdown(f"**Total row:** {row_count}") st.dataframe(filtered_df, use_container_width=True) action_columns = st.columns((3, 3, 2, 2, 2, 2, 2)) @@ -116,7 +124,7 @@ def export_filter(): with st.spinner(text="Generating COCO file"): coco_json = ( - generate_coco_file(filtered_df, st.session_state.project_dir, st.session_state.ontology_file) + generate_coco_file(filtered_df, get_state().project_paths.project_dir, get_state().project_paths.ontology) if is_pressed else "" ) @@ -129,11 +137,10 @@ def export_filter(): help="Ensure you have generated an updated COCO file before downloading", ) - def 
_clone_button_pressed(): - st.session_state[state.ACTION_PAGE_CLONE_BUTTON] = True - action_columns[3].button( - "πŸ— Clone", on_click=_clone_button_pressed, help="Clone the filtered data into a new Encord dataset and project" + "πŸ— Clone", + on_click=lambda: set_clone_button(True), + help="Clone the filtered data into a new Encord dataset and project", ) delete_btn = action_columns[4].button("❌ Review", help="Assign the filtered data for review on the Encord platform") edit_btn = action_columns[5].button( @@ -142,7 +149,7 @@ def _clone_button_pressed(): augment_btn = action_columns[6].button("βž• Augment", help="Augment your dataset based on the filered data") if any([delete_btn, edit_btn, augment_btn]): - st.session_state[state.ACTION_PAGE_CLONE_BUTTON] = False + set_clone_button(False) message_placeholder.markdown( """
@@ -156,14 +163,12 @@ def _clone_button_pressed(): unsafe_allow_html=True, ) - if ( - state.ACTION_PAGE_PREVIOUS_FILTERED_NUM not in st.session_state - or st.session_state[state.ACTION_PAGE_PREVIOUS_FILTERED_NUM] != filtered_df.shape[0] - ): - st.session_state[state.ACTION_PAGE_PREVIOUS_FILTERED_NUM] = filtered_df.shape[0] - st.session_state[state.ACTION_PAGE_CLONE_BUTTON] = False + prev_row_count = get_filtered_row_count() + if prev_row_count != row_count: + set_filtered_row_count(row_count) + set_clone_button(False) - if st.session_state[state.ACTION_PAGE_CLONE_BUTTON]: + if get_clone_button(): with st.form("new_project_form"): st.subheader("Create a new project with the selected items") l_column, r_column = st.columns(2) @@ -178,17 +183,63 @@ def _clone_button_pressed(): ) project_description = r_column.text_area("Project description") - create_new_project = st.form_submit_button("βž• Create") - if create_new_project: - if dataset_title == "": - st.error("Dataset title cannot be empty!") - return - if project_title == "": - st.error("Project title cannot be empty!") - return - create_new_project_on_encord_platform( - dataset_title, dataset_description, project_title, project_description, filtered_df + if not st.form_submit_button("βž• Create"): + return + + if dataset_title == "": + st.error("Dataset title cannot be empty!") + return + if project_title == "": + st.error("Project title cannot be empty!") + return + + try: + action_utils = EncordActions(st.session_state.project_dir) + label = st.empty() + progress, clear = render_progress_bar() + label.text("Step 1/2: Uploading data...") + dataset_creation_result = action_utils.create_dataset( + dataset_title, dataset_description, filtered_df, progress + ) + clear() + label.text("Step 2/2: Uploading labels...") + new_project = action_utils.create_project( + dataset_creation_result, project_title, project_description, progress + ) + clear() + label.info("πŸŽ‰ New project is created!") + + new_project_link = f"https://app.encord.com/projects/view/{new_project.project_hash}/summary" + new_dataset_link = f"https://app.encord.com/datasets/view/{dataset_creation_result.hash}" + st.markdown(f"[Go to new project]({new_project_link})") + st.markdown(f"[Go to new dataset]({new_dataset_link})") + + except ProjectNotFound as e: + st.markdown( + f""" + ❌ No `project_meta.yaml` file in the project folder. 
+ Please create `project_meta.yaml` file in **{e.project_dir}** folder with the following content + and try again: + ``` yaml + project_hash: + ssh_key_path: /path/to/your/encord/ssh_key + ``` + """ ) + except Exception as e: + st.error(str(e)) + + +def render_progress_bar(): + progress_bar = st.empty() + + def clear(): + progress_bar.empty() + + def progress_callback(value: float): + progress_bar.progress(value) + + return progress_callback, clear if __name__ == "__main__": diff --git a/src/encord_active/app/common/action_utils.py b/src/encord_active/app/common/action_utils.py deleted file mode 100644 index a23333474..000000000 --- a/src/encord_active/app/common/action_utils.py +++ /dev/null @@ -1,256 +0,0 @@ -import json -from pathlib import Path - -import pandas as pd -import streamlit as st -from encord import Dataset, EncordUserClient, Project -from encord.constants.enums import DataType -from encord.exceptions import AuthorisationError -from encord.orm.dataset import Image, StorageLocation -from encord.utilities.label_utilities import construct_answer_dictionaries -from tqdm import tqdm - -from encord_active.lib.common.utils import fetch_project_meta - - -def _update_mapping( - user_client: EncordUserClient, new_dataset_hash: str, label_row_hash: str, data_unit_hash: str, out_mapping: dict -): - updated_dataset = user_client.get_dataset(new_dataset_hash) - for new_data_row in updated_dataset.data_rows: - if new_data_row["data_hash"] not in out_mapping: - out_mapping[new_data_row["data_hash"]] = { - "label_row_hash": label_row_hash, - "data_unit_hash": data_unit_hash, - } - return - - -def create_a_new_dataset( - user_client: EncordUserClient, dataset_title: str, dataset_description: str, filtered_dataset: pd.DataFrame -) -> tuple[str, dict[str, dict[str, str]]]: - new_du_to_original: dict[str, dict] = {} - user_client.create_dataset( - dataset_title=dataset_title, dataset_type=StorageLocation.CORD_STORAGE, dataset_description=dataset_description - ) - dataset_hash: str = user_client.get_datasets(title_eq=dataset_title)[0]["dataset"].dataset_hash - dataset: Dataset = user_client.get_dataset(dataset_hash) - - # The following operation is for image groups (to upload them efficiently) - label_hash_to_data_units: dict[str, list] = {} - for index, item in tqdm(filtered_dataset.iterrows(), total=filtered_dataset.shape[0]): - label_row_hash, data_unit_hash, *rest = item["identifier"].split("_") - label_hash_to_data_units.setdefault(label_row_hash, []).append(data_unit_hash) - - temp_progress_bar = st.empty() - temp_progress_bar.progress(0.0) - uploaded_label_rows: set = set() - for counter, (index, item) in enumerate(filtered_dataset.iterrows()): - label_row_hash, data_unit_hash, frame_id, *rest = item["identifier"].split("_") - json_txt = (st.session_state.project_dir / "data" / label_row_hash / "label_row.json").expanduser().read_text() - label_row = json.loads(json_txt) - - if label_row_hash not in uploaded_label_rows: - if label_row["data_type"] == DataType.IMAGE.value: - image_path = list( - Path(st.session_state.project_dir / "data" / label_row_hash / "images").glob(f"{data_unit_hash}.*") - )[0] - uploaded_image: Image = dataset.upload_image( - file_path=image_path, title=label_row["data_units"][data_unit_hash]["data_title"] - ) - - new_du_to_original[uploaded_image["data_hash"]] = { - "label_row_hash": label_row_hash, - "data_unit_hash": data_unit_hash, - } - - elif label_row["data_type"] == DataType.IMG_GROUP.value: - image_paths = [] - image_names = [] - if 
len(label_hash_to_data_units[label_row_hash]) > 0: - for data_unit in label_hash_to_data_units[label_row_hash]: - img_path: Path = list( - Path(st.session_state.project_dir / "data" / label_row_hash / "images").glob( - f"{data_unit}.*" - ) - )[0] - image_paths.append(img_path.as_posix()) - image_names.append(img_path.name) - - # Unfortunately the following function does not return metadata related to the uploaded items - dataset.create_image_group(file_paths=image_paths, title=label_row["data_title"]) - - # Since create_image_group does not return info related to the uploaded images, we should find its - # data_hash in a hacky way - _update_mapping(user_client, dataset_hash, label_row_hash, data_unit_hash, new_du_to_original) - - elif label_row["data_type"] == DataType.VIDEO.value: - video_path = list( - Path(st.session_state.project_dir / "data" / label_row_hash / "images").glob(f"{data_unit_hash}.*") - )[0].as_posix() - - # Unfortunately the following function does not return metadata related to the uploaded items - dataset.upload_video(file_path=video_path, title=label_row["data_units"][data_unit_hash]["data_title"]) - - # Since upload_video does not return info related to the uploaded video, we should find its data_hash - # in a hacky way - _update_mapping(user_client, dataset_hash, label_row_hash, data_unit_hash, new_du_to_original) - - else: - st.error(f'Undefined data type {label_row["data_type"]} for label_row={label_row["label_hash"]}') - - uploaded_label_rows.add(label_row_hash) - - temp_progress_bar.progress(counter / filtered_dataset.shape[0]) - - temp_progress_bar.empty() - return dataset_hash, new_du_to_original - - -def create_new_project_on_encord_platform( - dataset_title: str, - dataset_description: str, - project_title: str, - project_description: str, - filtered_dataset: pd.DataFrame, -): - try: - project_meta = fetch_project_meta(st.session_state.project_dir) - except (KeyError, FileNotFoundError) as _: - st.markdown( - f""" - ❌ No `project_meta.yaml` file in the project folder. - Please create `project_meta.yaml` file in **{st.session_state.project_dir}** folder with the following content - and try again: - ``` yaml - project_hash: - ssh_key_path: /path/to/your/encord/ssh_key - ``` - """ - ) - return - - meta_file = st.session_state.project_dir / "project_meta.yaml" - ssh_key_path = project_meta.get("ssh_key_path") - if not ssh_key_path: - st.error(f"`ssh_key_path` not specified in the project meta data file `{meta_file}`.") - return - - ssh_key_path = Path(ssh_key_path).expanduser() - if not ssh_key_path.is_file(): - st.error(f"No SSH file in location:{ssh_key_path}") - return - - user_client = EncordUserClient.create_with_ssh_private_key( - Path(ssh_key_path).expanduser().read_text(encoding="utf-8"), - ) - - original_project_hash = project_meta.get("project_hash") - if not original_project_hash: - st.error(f"`project_hash` not specified in the project meta data file `{meta_file}`.") - return - - original_project: Project = user_client.get_project(original_project_hash) - try: - if original_project.project_hash == original_project_hash: - pass - except AuthorisationError: - st.error( - f'The user associated to the ssh key `{ssh_key_path}` does not have access to the project with project hash `{original_project_hash}`. Run "encord-active config set ssh_key_path /path/to/your/key_file" to set it.' 
- ) - return - - datasets_with_same_title = user_client.get_datasets(title_eq=dataset_title) - if len(datasets_with_same_title) > 0: - st.error( - f"Dataset title '{dataset_title}' already exists in your list of datasets at Encord. Please use a different title." - ) - return None - - label = st.empty() - label.text("Step 1/2: Uploading data...") - new_dataset_hash, new_du_to_original = create_a_new_dataset( - user_client, dataset_title, dataset_description, filtered_dataset - ) - - new_project_hash: str = user_client.create_project( - project_title=project_title, - dataset_hashes=[new_dataset_hash], - project_description=project_description, - ontology_hash=original_project.get_project().ontology_hash, - ) - - new_project: Project = user_client.get_project(new_project_hash) - - # Copy labels from old project to new project - # Three things to copy: labels, object_answers, classification_answers - - all_new_label_rows = new_project.label_rows - label.text("Step 2/2: Uploading labels...") - temp_progress_bar = st.empty() - temp_progress_bar.progress(0.0) - for counter, new_label_row in enumerate(all_new_label_rows): - initiated_label_row: dict = new_project.create_label_row(new_label_row["data_hash"]) - - with open( - ( - st.session_state.project_dir - / "data" - / new_du_to_original[new_label_row["data_hash"]]["label_row_hash"] - / "label_row.json" - ).expanduser(), - "r", - encoding="utf-8", - ) as file: - original_label_row = json.load(file) - - if initiated_label_row["data_type"] in [DataType.IMAGE.value, DataType.VIDEO.value]: - - original_labels = original_label_row["data_units"][ - new_du_to_original[new_label_row["data_hash"]]["data_unit_hash"] - ]["labels"] - initiated_label_row["data_units"][new_label_row["data_hash"]]["labels"] = original_labels - initiated_label_row["object_answers"] = original_label_row["object_answers"] - initiated_label_row["classification_answers"] = original_label_row["classification_answers"] - - if original_labels != {}: - initiated_label_row = construct_answer_dictionaries(initiated_label_row) - new_project.save_label_row(initiated_label_row["label_hash"], initiated_label_row) - - elif initiated_label_row["data_type"] == DataType.IMG_GROUP.value: - object_hashes: set = set() - classification_hashes: set = set() - - # Currently img_groups are matched using data_title, it should be fixed after SDK update - for data_unit in initiated_label_row["data_units"].values(): - for original_data_unit in original_label_row["data_units"].values(): - if original_data_unit["data_hash"] == data_unit["data_title"].split(".")[0]: - data_unit["labels"] = original_data_unit["labels"] - for obj in data_unit["labels"].get("objects", []): - object_hashes.add(obj["objectHash"]) - for classification in data_unit["labels"].get("classifications", []): - classification_hashes.add(classification["classificationHash"]) - - initiated_label_row["object_answers"] = original_label_row["object_answers"] - initiated_label_row["classification_answers"] = original_label_row["classification_answers"] - - # Remove unused object/classification answers - for object_hash in object_hashes: - initiated_label_row["object_answers"].pop(object_hash) - - for classification_hash in classification_hashes: - initiated_label_row["classification_answers"].pop(classification_hash) - - initiated_label_row = construct_answer_dictionaries(initiated_label_row) - new_project.save_label_row(initiated_label_row["label_hash"], initiated_label_row) - - # remove unused object and classification answers - - 
temp_progress_bar.progress(counter / len(all_new_label_rows)) - - temp_progress_bar.empty() - label.info("πŸŽ‰ New project is created!") - new_project_link = f"https://app.encord.com/projects/view/{new_project_hash}/summary" - new_dataset_link = f"https://app.encord.com/datasets/view/{new_dataset_hash}" - st.markdown(f"[Go to new project]({new_project_link})") - st.markdown(f"[Go to new dataset]({new_dataset_link})") diff --git a/src/encord_active/app/common/cli_helpers.py b/src/encord_active/app/common/cli_helpers.py deleted file mode 100644 index 5b16d186c..000000000 --- a/src/encord_active/app/common/cli_helpers.py +++ /dev/null @@ -1,17 +0,0 @@ -from pathlib import Path - -from encord import EncordUserClient -from encord import Project as EncordProject - -from encord_active.lib.common.utils import fetch_project_meta - - -def get_local_project(project_dir: Path) -> EncordProject: - project_meta = fetch_project_meta(project_dir) - - ssh_key_path = Path(project_meta["ssh_key_path"]) - with open(ssh_key_path.expanduser(), "r", encoding="utf-8") as f: - key = f.read() - - client = EncordUserClient.create_with_ssh_private_key(key) - return client.get_project(project_meta.get("project_hash")) diff --git a/src/encord_active/app/common/components/__init__.py b/src/encord_active/app/common/components/__init__.py index 0198a2439..644c34709 100644 --- a/src/encord_active/app/common/components/__init__.py +++ b/src/encord_active/app/common/components/__init__.py @@ -1,3 +1,2 @@ from .data_tags import build_data_tags -from .multi_select import multiselect_with_all_option from .sticky.sticky import sticky_header diff --git a/src/encord_active/app/common/components/annotator_statistics.py b/src/encord_active/app/common/components/annotator_statistics.py new file mode 100644 index 000000000..479b6fe30 --- /dev/null +++ b/src/encord_active/app/common/components/annotator_statistics.py @@ -0,0 +1,61 @@ +import numpy as np +import pandas as pd +import plotly.express as px +import streamlit as st +from pandera.typing import DataFrame + +from encord_active.lib.metrics.utils import ( + AnnotatorInfo, + MetricSchema, + get_annotator_level_info, +) + + +def render_annotator_properties(df: DataFrame[MetricSchema]): + annotators = get_annotator_level_info(df) + left_col, right_col = st.columns([2, 2]) + + # 1. Pie Chart + left_col.markdown( + "
Distribution of the annotations
", unsafe_allow_html=True + ) + annotators_df = pd.DataFrame(annotators.values()) + + fig = px.pie(annotators_df, values="total_annotations", names="name", hover_data=["mean_score"]) + + left_col.plotly_chart(fig, use_container_width=True) + + # 2. Table View + right_col.markdown( + "
Detailed annotator statistics
", unsafe_allow_html=True + ) + + total_mean_score = annotators_df["mean_score"].mean() + annotators_df.loc[len(annotators_df.index)] = AnnotatorInfo( + name="all", total_annotations=annotators_df.shape[0], mean_score=total_mean_score + ) + + deviations = 100 * ((np.array(annotators_df["mean_score"]) - total_mean_score) / total_mean_score) + annotators_df["deviations"] = deviations + + right_col.dataframe(annotators_df.style.pipe(make_pretty), use_container_width=True) + + +def _format_deviation(val): + return f"{val:.1f}%" + + +def _format_score(val): + return f"{val:.3f}" + + +def _color_red_or_green(val): + color = "red" if val < 0 else "green" + return f"color: {color}" + + +def make_pretty(styler): + styler.format(_format_deviation, subset=["deviations"]) + styler.format(_format_score, subset=["mean_score"]) + styler.applymap(_color_red_or_green, subset=["deviations"]) + return styler diff --git a/src/encord_active/app/common/components/data_tags.py b/src/encord_active/app/common/components/data_tags.py index eed8b0014..4c1d5187f 100644 --- a/src/encord_active/app/common/components/data_tags.py +++ b/src/encord_active/app/common/components/data_tags.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import List +from typing import List, Optional import pandas as pd import streamlit as st @@ -55,13 +55,13 @@ def get_icon_color(self, value): ] TAG_TEMPLATE = """
- + %s %s
""" -def build_data_tags(row: pd.Series, metric_name: str): +def build_data_tags(row: pd.Series, metric_name: Optional[str] = None): tag_list = [] for p in properties: key = p.get_key(metric_name) diff --git a/src/encord_active/app/common/components/label_statistics.py b/src/encord_active/app/common/components/label_statistics.py new file mode 100644 index 000000000..5db3db60c --- /dev/null +++ b/src/encord_active/app/common/components/label_statistics.py @@ -0,0 +1,38 @@ +import pandas as pd +import plotly.express as px +import streamlit as st +from natsort import natsorted +from pandera.typing import DataFrame + +from encord_active.lib.metrics.utils import MetricSchema + + +def render_dataset_properties(current_df: DataFrame[MetricSchema]): + dataset_columns = st.columns(3) + + cls_set = natsorted(list(current_df[MetricSchema.object_class].unique())) + + dataset_columns[0].metric("Number of labels", current_df.shape[0]) + dataset_columns[1].metric("Number of classes", len(cls_set)) + dataset_columns[2].metric("Number of images", get_unique_data_units_size(current_df)) + + if len(cls_set) > 1: + classes = {} + for cls in cls_set: + classes[cls] = (current_df[MetricSchema.object_class] == cls).sum() + + source = pd.DataFrame({"class": list(classes.keys()), "count": list(classes.values())}) + + fig = px.bar(source, x="class", y="count") + fig.update_layout(title_text="Distribution of the classes", title_x=0.5, title_font_size=20) + st.plotly_chart(fig, use_container_width=True) + + +def get_unique_data_units_size(current_df: pd.DataFrame): + data_units = set() + identifiers = current_df["identifier"] + for identifier in identifiers: + key_components = identifier.split("_") + data_units.add(key_components[0] + "_" + key_components[1]) + + return len(data_units) diff --git a/src/encord_active/app/common/components/metric_summary.py b/src/encord_active/app/common/components/metric_summary.py new file mode 100644 index 000000000..f8ba38a21 --- /dev/null +++ b/src/encord_active/app/common/components/metric_summary.py @@ -0,0 +1,78 @@ +from typing import List + +import pandas as pd +import streamlit as st +from pandera.typing import DataFrame + +from encord_active.app.common.components import build_data_tags +from encord_active.app.common.components.tags.individual_tagging import multiselect_tag +from encord_active.app.common.state import get_state +from encord_active.lib.common.image_utils import show_image_and_draw_polygons +from encord_active.lib.dataset.outliers import IqrOutliers, MetricWithDistanceSchema +from encord_active.lib.metrics.utils import MetricData, MetricScope + +_COLUMNS = MetricWithDistanceSchema + + +def render_metric_summary( + metric: MetricData, df: DataFrame[MetricWithDistanceSchema], iqr_outliers: IqrOutliers, metric_scope: MetricScope +): + n_cols = get_state().page_grid_settings.columns + n_rows = get_state().page_grid_settings.rows + page_size = n_cols * n_rows + + st.markdown(metric.meta["long_description"]) + + if iqr_outliers.n_severe_outliers + iqr_outliers.n_moderate_outliers == 0: + st.success("No outliers found!") + return None + + st.error(f"Number of severe outliers: {iqr_outliers.n_severe_outliers}/{len(df)}") + st.warning(f"Number of moderate outliers: {iqr_outliers.n_moderate_outliers}/{len(df)}") + + max_value = float(df[_COLUMNS.dist_to_iqr].max()) + min_value = float(df[_COLUMNS.dist_to_iqr].min()) + value = st.slider( + "distance to IQR", + min_value=min_value, + max_value=max_value, + step=max(0.1, float((max_value - min_value) / (len(df) / 
page_size))), + value=max_value, + key=f"dist_to_iqr{metric.name}", + ) + + selected_df: DataFrame[MetricWithDistanceSchema] = df[df[_COLUMNS.dist_to_iqr] <= value][:page_size] + + cols: List = [] + for _, row in selected_df.iterrows(): + if not cols: + cols = list(st.columns(n_cols)) + + with cols.pop(0): + render_summary_item(row, metric.name, iqr_outliers, metric_scope) + + +def render_summary_item(row, metric_name: str, iqr_outliers: IqrOutliers, metric_scope: MetricScope): + image = show_image_and_draw_polygons(row, get_state().project_paths.data) + st.image(image) + + multiselect_tag(row, f"{metric_name}_summary", metric_scope) + + # === Write scores and link to editor === # + + tags_row = row.copy() + if row["score"] > iqr_outliers.severe_ub or row["score"] < iqr_outliers.severe_lb: + tags_row["outlier"] = "Severe" + elif row["score"] > iqr_outliers.moderate_ub or row["score"] < iqr_outliers.moderate_lb: + tags_row["outlier"] = "Moderate" + else: + tags_row["outlier"] = "Low" + + if "object_class" in tags_row and not pd.isna(tags_row["object_class"]): + tags_row["label_class_name"] = tags_row["object_class"] + tags_row.drop("object_class") + tags_row[metric_name] = tags_row["score"] + build_data_tags(tags_row, metric_name) + + if not pd.isnull(row["description"]): + st.write(f"Description: {row['description']}. ") diff --git a/src/encord_active/app/common/components/multi_select.py b/src/encord_active/app/common/components/multi_select.py deleted file mode 100644 index b6bd9adc5..000000000 --- a/src/encord_active/app/common/components/multi_select.py +++ /dev/null @@ -1,39 +0,0 @@ -import logging -from typing import List, Optional, cast - -import streamlit as st -from natsort import natsorted - -logger = logging.getLogger(__name__) - - -def multiselect_with_all_option( - label: str, - options: List[str], - key: str, - custom_all_name: str = "All", - default: Optional[List[str]] = None, - **kwargs, -): - if "on_change" in kwargs: - logger.warning("`st.multiselect.on_change` is being overwritten by custom multiselect with All option") - - default_list = [custom_all_name] if default is None else default - - # Filter classes - def on_change(prev: List[str] = default_list): - next: List[str] = (st.session_state.get(key) or default_list).copy() - - if custom_all_name not in prev and custom_all_name in next: - next = [custom_all_name] - if len(next) > 1 and custom_all_name in next: - next.remove(custom_all_name) - - st.session_state[key] = next - - sorted_options = cast(List[str], natsorted(list(options))) - options = [custom_all_name] + sorted_options - - current: List[str] = st.session_state.get(key) or default_list - - return st.multiselect(label, options, key=key, on_change=on_change, args=(current,), default=default_list, **kwargs) diff --git a/src/encord_active/app/common/components/paginator.py b/src/encord_active/app/common/components/paginator.py new file mode 100644 index 000000000..81ad0f855 --- /dev/null +++ b/src/encord_active/app/common/components/paginator.py @@ -0,0 +1,21 @@ +import pandas as pd +import streamlit as st + + +def render_pagination(df: pd.DataFrame, n_cols: int, n_rows: int, sort_key: str): + n_items = n_cols * n_rows + col1, col2 = st.columns(spec=[1, 4]) + + with col1: + sorting_order = st.selectbox("Sort samples within selected interval", ["Ascending", "Descending"]) + + with col2: + last = len(df) // n_items + 1 + page_num = st.slider("Page", 1, last) if last > 1 else 1 + + low_lim = (page_num - 1) * n_items + high_lim = page_num * n_items + + 
sorted_subset = df.sort_values(by=sort_key, ascending=sorting_order == "Ascending") + paginated_subset = sorted_subset[low_lim:high_lim] + return paginated_subset diff --git a/src/encord_active/app/common/components/prediction_grid.py b/src/encord_active/app/common/components/prediction_grid.py new file mode 100644 index 000000000..a6de5fe01 --- /dev/null +++ b/src/encord_active/app/common/components/prediction_grid.py @@ -0,0 +1,115 @@ +from pathlib import Path +from typing import List, Optional + +import pandas as pd +import streamlit as st +from pandera.typing import DataFrame + +from encord_active.app.common.components import build_data_tags +from encord_active.app.common.components.paginator import render_pagination +from encord_active.app.common.components.slicer import render_df_slicer +from encord_active.app.common.components.tags.bulk_tagging_form import ( + BulkLevel, + action_bulk_tags, + bulk_tagging_form, +) +from encord_active.app.common.components.tags.individual_tagging import multiselect_tag +from encord_active.app.common.state import get_state +from encord_active.lib.common.colors import Color +from encord_active.lib.common.image_utils import ( + draw_object, + show_image_and_draw_polygons, + show_image_with_predictions_and_label, +) +from encord_active.lib.metrics.utils import MetricScope +from encord_active.lib.model_predictions.reader import ( + LabelMatchSchema, + PredictionMatchSchema, +) + + +def build_card_for_labels( + label: pd.Series, + predictions: DataFrame[PredictionMatchSchema], + data_dir: Path, + label_color: Color = Color.RED, +): + class_colors = {int(k): idx["color"] for k, idx in get_state().predictions.all_classes.items()} + image = show_image_with_predictions_and_label( + label, predictions, data_dir, label_color=label_color, class_colors=class_colors + ) + st.image(image) + multiselect_tag(label, "false_negatives", MetricScope.MODEL_QUALITY) + + cls = get_state().predictions.all_classes[str(label["class_id"])]["name"] + label = label.copy() + label["label_class_name"] = cls + # === Write scores and link to editor === # + build_data_tags(label, get_state().predictions.metric_datas.selected_label) + + +def build_card_for_predictions(row: pd.Series, data_dir: Path, box_color=Color.GREEN): + image = show_image_and_draw_polygons(row, data_dir, draw_polygons=True, skip_object_hash=True) + image = draw_object(image, row, color=box_color, with_box=True) + st.image(image) + multiselect_tag(row, "metric_view", MetricScope.MODEL_QUALITY) + + # === Write scores and link to editor === # + build_data_tags(row, get_state().predictions.metric_datas.selected_predicion) + + if row[PredictionMatchSchema.false_positive_reason] and not row[PredictionMatchSchema.is_true_positive]: + st.write(f"Reason: {row[PredictionMatchSchema.false_positive_reason]}") + + +def build_card( + row: pd.Series, + predictions: Optional[DataFrame[PredictionMatchSchema]], + data_dir: Path, + box_color: Color = Color.GREEN, +): + if predictions is not None: + build_card_for_labels(row, predictions, data_dir, box_color) + else: + build_card_for_predictions(row, data_dir, box_color) + + +def prediction_grid( + data_dir: Path, + model_predictions: DataFrame[PredictionMatchSchema], + labels: Optional[DataFrame[LabelMatchSchema]] = None, + box_color: Color = Color.GREEN, +): + use_labels = labels is not None + if use_labels: + df = labels + additionals = model_predictions + selected_metric = get_state().predictions.metric_datas.selected_label or "" + else: + df = model_predictions + additionals = 
None + selected_metric = get_state().predictions.metric_datas.selected_predicion or "" + + n_cols, n_rows = get_state().page_grid_settings.columns, get_state().page_grid_settings.rows + subset = render_df_slicer(df, selected_metric) + paginated_subset = render_pagination(subset, n_cols, n_rows, selected_metric) + + form = bulk_tagging_form(MetricScope.MODEL_QUALITY) + if form and form.submitted: + df = paginated_subset if form.level == BulkLevel.PAGE else subset + action_bulk_tags(df, form.tags, form.action) + + if len(paginated_subset) == 0: + st.error("No data in selected quality interval") + else: + cols: List = [] + for _, row in paginated_subset.iterrows(): + frame_additionals: Optional[DataFrame[PredictionMatchSchema]] = None + if additionals is not None: + frame_additionals = additionals[ + additionals[PredictionMatchSchema.img_id] == row[LabelMatchSchema.img_id] + ] + + if not cols: + cols = list(st.columns(n_cols)) + with cols.pop(0): + build_card(row, frame_additionals, data_dir, box_color=box_color) diff --git a/src/encord_active/app/common/components/similarities.py b/src/encord_active/app/common/components/similarities.py new file mode 100644 index 000000000..dd5081139 --- /dev/null +++ b/src/encord_active/app/common/components/similarities.py @@ -0,0 +1,175 @@ +import numpy as np +import streamlit as st +from pandas import Series +from streamlit.delta_generator import DeltaGenerator + +from encord_active.app.common.state import ( + COLLECTIONS_IMAGES, + COLLECTIONS_OBJECTS, + CURRENT_INDEX_HAS_ANNOTATION, + FAISS_INDEX_IMAGE, + FAISS_INDEX_IMAGE_NO_LABEL, + FAISS_INDEX_OBJECT, + IMAGE_KEYS_HAVING_SIMILARITIES, + IMAGE_SIMILARITIES, + IMAGE_SIMILARITIES_NO_LABEL, + K_NEAREST_NUM, + OBJECT_KEYS_HAVING_SIMILARITIES, + OBJECT_SIMILARITIES, + QUESTION_HASH_TO_COLLECTION_INDEXES, + get_state, +) +from encord_active.lib.common.image_utils import ( + load_or_fill_image, + show_image_and_draw_polygons, +) +from encord_active.lib.common.utils import ( + fix_duplicate_image_orders_in_knn_graph_single_row, +) +from encord_active.lib.embeddings.utils import get_key_from_index + + +def show_similar_classification_images(row: Series, expander: DeltaGenerator): + feature_hash = row["identifier"].split("_")[-1] + + if row["identifier"] not in st.session_state[IMAGE_SIMILARITIES][feature_hash].keys(): + add_labeled_image_neighbors_to_cache(row["identifier"], feature_hash) + + nearest_images = st.session_state[IMAGE_SIMILARITIES][feature_hash][row["identifier"]] + + division = 4 + column_id = 0 + st_columns = [] + + for nearest_image in nearest_images: + if column_id == 0: + st_columns = expander.columns(division) + + image = load_or_fill_image(nearest_image["key"], get_state().project_paths.data) + + st_columns[column_id].image(image) + st_columns[column_id].write(f"Annotated as `{nearest_image['name']}`") + column_id += 1 + column_id = column_id % division + + +def show_similar_images(row: Series, expander: DeltaGenerator): + image_identifier = "_".join(row["identifier"].split("_")[:3]) + + if image_identifier not in st.session_state[IMAGE_SIMILARITIES_NO_LABEL].keys(): + add_image_neighbors_to_cache(image_identifier) + + nearest_images = st.session_state[IMAGE_SIMILARITIES_NO_LABEL][image_identifier] + + division = 4 + column_id = 0 + st_columns = [] + + for nearest_image in nearest_images: + if column_id == 0: + st_columns = expander.columns(division) + + image = load_or_fill_image(nearest_image["key"], get_state().project_paths.data) + + st_columns[column_id].image(image) + 
st_columns[column_id].write(f"Annotated as `{nearest_image['name']}`") + column_id += 1 + column_id = column_id % division + + +def show_similar_object_images(row: Series, expander: DeltaGenerator): + object_identifier = "_".join(row["identifier"].split("_")[:4]) + + if object_identifier not in st.session_state[OBJECT_KEYS_HAVING_SIMILARITIES]: + expander.write("Similarity search is not available for this object.") + return + + if object_identifier not in st.session_state[OBJECT_SIMILARITIES].keys(): + add_object_neighbors_to_cache(object_identifier) + + nearest_images = st.session_state[OBJECT_SIMILARITIES][object_identifier] + + division = 4 + column_id = 0 + st_columns = [] + + for nearest_image in nearest_images: + if column_id == 0: + st_columns = expander.columns(division) + + image = show_image_and_draw_polygons(nearest_image["key"], get_state().project_paths.data) + + st_columns[column_id].image(image) + st_columns[column_id].write(f"Annotated as `{nearest_image['name']}`") + column_id += 1 + column_id = column_id % division + + +def add_labeled_image_neighbors_to_cache(image_identifier: str, question_feature_hash: str) -> None: + collection_id = st.session_state[IMAGE_KEYS_HAVING_SIMILARITIES]["_".join(image_identifier.split("_")[:3])] + collection_item_index = st.session_state[QUESTION_HASH_TO_COLLECTION_INDEXES][question_feature_hash].index( + collection_id + ) + embedding = np.array([st.session_state[COLLECTIONS_IMAGES][collection_id]["embedding"]]).astype(np.float32) + _, nearest_indexes = st.session_state[FAISS_INDEX_IMAGE][question_feature_hash].search( + embedding, int(st.session_state[K_NEAREST_NUM] + 1) + ) + nearest_indexes = fix_duplicate_image_orders_in_knn_graph_single_row(collection_item_index, nearest_indexes) + + temp_list = [] + for nearest_index in nearest_indexes[0, 1:]: + collection_index = st.session_state[QUESTION_HASH_TO_COLLECTION_INDEXES][question_feature_hash][nearest_index] + temp_list.append( + { + "key": get_key_from_index( + st.session_state[COLLECTIONS_IMAGES][collection_index], + question_hash=question_feature_hash, + has_annotation=st.session_state[CURRENT_INDEX_HAS_ANNOTATION], + ), + "name": st.session_state[COLLECTIONS_IMAGES][collection_index]["classification_answers"][ + question_feature_hash + ]["answer_name"], + } + ) + + st.session_state[IMAGE_SIMILARITIES][question_feature_hash][image_identifier] = temp_list + + +def _get_nearest_items_from_nearest_indexes(nearest_indexes: np.ndarray, collection_type: str) -> list[dict]: + temp_list = [] + for nearest_index in nearest_indexes[0, 1:]: + temp_list.append( + { + "key": get_key_from_index( + st.session_state[collection_type][nearest_index], + has_annotation=st.session_state[CURRENT_INDEX_HAS_ANNOTATION], + ), + "name": st.session_state[collection_type][nearest_index].get("name", "No label"), + } + ) + + return temp_list + + +def add_object_neighbors_to_cache(object_identifier: str) -> None: + item_index = st.session_state[OBJECT_KEYS_HAVING_SIMILARITIES][object_identifier] + embedding = np.array([st.session_state[COLLECTIONS_OBJECTS][item_index]["embedding"]]).astype(np.float32) + _, nearest_indexes = st.session_state[FAISS_INDEX_OBJECT].search( + embedding, int(st.session_state[K_NEAREST_NUM] + 1) + ) + + st.session_state[OBJECT_SIMILARITIES][object_identifier] = _get_nearest_items_from_nearest_indexes( + nearest_indexes, COLLECTIONS_OBJECTS + ) + + +def add_image_neighbors_to_cache(image_identifier: str): + collection_id = st.session_state[IMAGE_KEYS_HAVING_SIMILARITIES][image_identifier] + 
embedding = np.array([st.session_state[COLLECTIONS_IMAGES][collection_id]["embedding"]]).astype(np.float32) + _, nearest_indexes = st.session_state[FAISS_INDEX_IMAGE_NO_LABEL].search( + embedding, int(st.session_state[K_NEAREST_NUM] + 1) + ) + + st.session_state[IMAGE_SIMILARITIES_NO_LABEL][image_identifier] = _get_nearest_items_from_nearest_indexes( + nearest_indexes, COLLECTIONS_IMAGES + ) diff --git a/src/encord_active/app/common/components/slicer.py b/src/encord_active/app/common/components/slicer.py new file mode 100644 index 000000000..0058dcfcb --- /dev/null +++ b/src/encord_active/app/common/components/slicer.py @@ -0,0 +1,22 @@ +from typing import Optional + +import numpy as np +import pandas as pd +import streamlit as st + + +def render_df_slicer(df: pd.DataFrame, selected_metric: Optional[str]): + if selected_metric not in df: + return df + + max_val = float(df[selected_metric].max()) + np.finfo(float).eps.item() + min_val = float(df[selected_metric].min()) + + if max_val <= min_val: + return df + + step = max(0.01, (max_val - min_val) // 100) + start, end = st.slider("Choose quality", max_value=max_val, min_value=min_val, value=(min_val, max_val), step=step) + subset = df[df[selected_metric].between(start, end)] + + return subset diff --git a/src/encord_active/app/actions_page/coco_parser/__init__.py b/src/encord_active/app/common/components/tags/__init__.py similarity index 100% rename from src/encord_active/app/actions_page/coco_parser/__init__.py rename to src/encord_active/app/common/components/tags/__init__.py diff --git a/src/encord_active/app/common/components/bulk_tagging_form.py b/src/encord_active/app/common/components/tags/bulk_tagging_form.py similarity index 72% rename from src/encord_active/app/common/components/bulk_tagging_form.py rename to src/encord_active/app/common/components/tags/bulk_tagging_form.py index f69c55240..65450e690 100644 --- a/src/encord_active/app/common/components/bulk_tagging_form.py +++ b/src/encord_active/app/common/components/tags/bulk_tagging_form.py @@ -4,15 +4,14 @@ import streamlit as st from pandas import DataFrame -import encord_active.app.common.state as state -from encord_active.app.common.components.individual_tagging import target_identifier -from encord_active.app.common.components.tag_creator import ( - METRIC_TYPE_SCOPES, - scoped_tags, +from encord_active.app.common.components.tags.individual_tagging import ( + target_identifier, ) -from encord_active.app.data_quality.common import MetricType -from encord_active.app.db.merged_metrics import MergedMetrics -from encord_active.app.db.tags import Tag +from encord_active.app.common.components.tags.tag_creator import scoped_tags +from encord_active.app.common.state import get_state +from encord_active.lib.db.merged_metrics import MergedMetrics +from encord_active.lib.db.tags import METRIC_SCOPE_TAG_SCOPES, Tag +from encord_active.lib.metrics.utils import MetricScope class TagAction(str, Enum): @@ -36,7 +35,7 @@ def action_bulk_tags(subset: DataFrame, selected_tags: List[Tag], action: TagAct if not selected_tags: return - all_df: DataFrame = st.session_state[state.MERGED_DATAFRAME].copy() + all_df = get_state().merged_metrics.copy() for tag in selected_tags: target_ids = [target_identifier(id, tag.scope) for id in subset.identifier.to_list()] @@ -50,15 +49,15 @@ def action_bulk_tags(subset: DataFrame, selected_tags: List[Tag], action: TagAct all_df.at[id, "tags"] = next - st.session_state[state.MERGED_DATAFRAME] = all_df + get_state().merged_metrics = all_df 
MergedMetrics().replace_all(all_df) -def bulk_tagging_form(metric_type: MetricType) -> Optional[TaggingFormResult]: +def bulk_tagging_form(metric_type: MetricScope) -> Optional[TaggingFormResult]: with st.expander("Bulk Tagging"): with st.form("bulk_tagging"): select, level_radio, action_radio, button = st.columns([6, 2, 2, 1]) - allowed_tags = scoped_tags(METRIC_TYPE_SCOPES[metric_type]) + allowed_tags = scoped_tags(METRIC_SCOPE_TAG_SCOPES[metric_type]) selected_tags = select.multiselect( label="Tags", options=allowed_tags, format_func=lambda x: x[0], label_visibility="collapsed" ) diff --git a/src/encord_active/app/common/components/individual_tagging.py b/src/encord_active/app/common/components/tags/individual_tagging.py similarity index 71% rename from src/encord_active/app/common/components/individual_tagging.py rename to src/encord_active/app/common/components/tags/individual_tagging.py index 9d4b1eb74..0f6ad1954 100644 --- a/src/encord_active/app/common/components/individual_tagging.py +++ b/src/encord_active/app/common/components/tags/individual_tagging.py @@ -3,14 +3,11 @@ import streamlit as st from pandas import Series -import encord_active.app.common.state as state -from encord_active.app.common.components.tag_creator import ( - METRIC_TYPE_SCOPES, - scoped_tags, -) -from encord_active.app.data_quality.common import MetricType -from encord_active.app.db.merged_metrics import MergedMetrics -from encord_active.app.db.tags import Tag, TagScope +from encord_active.app.common.components.tags.tag_creator import scoped_tags +from encord_active.app.common.state import get_state +from encord_active.lib.db.merged_metrics import MergedMetrics +from encord_active.lib.db.tags import METRIC_SCOPE_TAG_SCOPES, Tag, TagScope +from encord_active.lib.metrics.utils import MetricScope def target_identifier(identifier: str, scope: TagScope) -> str: @@ -29,23 +26,23 @@ def update_tags(identifier: str, key: str): targeted_tags.setdefault(target_id, []).append(tag) for id, tags in targeted_tags.items(): - st.session_state[state.MERGED_DATAFRAME].at[id, "tags"] = tags + get_state().merged_metrics.at[id, "tags"] = tags MergedMetrics().update_tags(id, tags) -def multiselect_tag(row: Series, key_prefix: str, metric_type: MetricType): +def multiselect_tag(row: Series, key_prefix: str, metric_type: MetricScope): identifier = row["identifier"] if not isinstance(identifier, str): st.error("Multiple rows with the same identifier were found. 
Please create a new issue.") return - metric_scopes = METRIC_TYPE_SCOPES[metric_type] + metric_scopes = METRIC_SCOPE_TAG_SCOPES[metric_type] tag_status = [] for scope in metric_scopes: id = target_identifier(identifier, scope) - tag_status += st.session_state[state.MERGED_DATAFRAME].at[id, "tags"] + tag_status += get_state().merged_metrics.at[id, "tags"] key = f"{key_prefix}_multiselect_{identifier}" diff --git a/src/encord_active/app/common/components/tag_creator.py b/src/encord_active/app/common/components/tags/tag_creator.py similarity index 61% rename from src/encord_active/app/common/components/tag_creator.py rename to src/encord_active/app/common/components/tags/tag_creator.py index 1de99918c..8bbc06a80 100644 --- a/src/encord_active/app/common/components/tag_creator.py +++ b/src/encord_active/app/common/components/tags/tag_creator.py @@ -2,28 +2,17 @@ import streamlit as st -import encord_active.app.common.state as state -from encord_active.app.data_quality.common import MetricType -from encord_active.app.db.tags import Tag, Tags, TagScope - -SCOPE_EMOJI = { - TagScope.DATA.value: "πŸ–ΌοΈ", - TagScope.LABEL.value: "✏️", -} - -METRIC_TYPE_SCOPES = { - MetricType.DATA_QUALITY: {TagScope.DATA}, - MetricType.LABEL_QUALITY: {TagScope.DATA, TagScope.LABEL}, - MetricType.MODEL_QUALITY: {TagScope.DATA}, -} +from encord_active.app.common.state import get_state +from encord_active.lib.db.helpers.tags import count_of_tags +from encord_active.lib.db.tags import SCOPE_EMOJI, Tag, Tags, TagScope def scoped_tags(scopes: Set[TagScope]) -> List[Tag]: - all_tags: List[Tag] = st.session_state.get(state.ALL_TAGS) or [] + all_tags: List[Tag] = get_state().all_tags or [] return [tag for tag in all_tags if tag.scope in scopes] -def on_tag_entered(all_tags: List[Tag], name: str, scope: str): +def on_tag_entered(all_tags: List[Tag], name: str, scope: TagScope): tag = Tag(f"{SCOPE_EMOJI[scope]} {name}", scope) if tag in all_tags: @@ -33,7 +22,7 @@ def on_tag_entered(all_tags: List[Tag], name: str, scope: str): Tags().create_tag(tag) all_tags.append(tag) - st.session_state[state.ALL_TAGS] = all_tags + get_state().all_tags = all_tags def tag_creator(): @@ -47,7 +36,7 @@ def tag_creator(): del st.session_state.new_tag_message all_tags = Tags().all() - st.session_state[state.ALL_TAGS] = all_tags + get_state().all_tags = all_tags with st.form("tag_creation_form", clear_on_submit=False): left, right = st.columns([2, 1]) @@ -71,17 +60,9 @@ def tag_creator(): def tag_display_with_counts(): - all_tags = Tags().all() - if not all_tags: - return - - tag_counts = st.session_state[state.MERGED_DATAFRAME]["tags"].value_counts() - all_counts = {name: 0 for name, _ in all_tags} - for unique_list, count in tag_counts.items(): - for name, *_ in unique_list: - all_counts[name] = all_counts.get(name, 0) + count + all_counts = count_of_tags(get_state().merged_metrics) - sorted_tags = sorted(all_counts.items(), key=lambda x: x[0][0].lower()) + sorted_tags = sorted(all_counts.items(), key=lambda x: x[0].lower()) st.markdown( f"""
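Illustrative note (not part of the patch itself): the tagging hunks above replace raw `st.session_state[...]` key access with the consolidated `State` dataclass that this refactor introduces (see the `common/state.py` changes later in the diff), accessed through `get_state()`. A minimal sketch of how a page would be expected to consume it after this change is shown below; it assumes the package is installed and a valid project directory is available, and `example_page` is a hypothetical name used only for illustration.

    from pathlib import Path

    from encord_active.app.common.state import State, get_state


    def example_page(project_dir: Path) -> None:
        # Idempotent: populates st.session_state[GLOBAL_STATE] only on the first call.
        State.init(project_dir)

        # Components read typed dataclass fields instead of ad-hoc session-state keys.
        n_cols = get_state().page_grid_settings.columns
        all_tags = get_state().all_tags
        get_state().normalize_metrics = True
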
diff --git a/src/encord_active/app/common/css.py b/src/encord_active/app/common/css.py index be0684f46..4f25d1ad0 100644 --- a/src/encord_active/app/common/css.py +++ b/src/encord_active/app/common/css.py @@ -5,7 +5,7 @@ import streamlit.elements.image as st_image from PIL import Image -from encord_active.app.common.colors import Color +from encord_active.lib.common.colors import Color def write_page_css(): diff --git a/src/encord_active/app/common/embedding_utils.py b/src/encord_active/app/common/embedding_utils.py deleted file mode 100644 index 68f732951..000000000 --- a/src/encord_active/app/common/embedding_utils.py +++ /dev/null @@ -1,219 +0,0 @@ -import os -import pickle -from typing import List, Optional, Tuple - -import faiss -import numpy as np -import streamlit as st -from faiss import IndexFlatL2 - -import encord_active.app.common.state as state -from encord_active.lib.common.utils import ( - fix_duplicate_image_orders_in_knn_graph_single_row, -) - - -@st.experimental_memo(show_spinner=False) -def get_collections(embedding_name: str) -> list[dict]: - embedding_path = st.session_state.embeddings_dir / embedding_name - collections = [] - if os.path.isfile(embedding_path): - with open(embedding_path, "rb") as f: - collections = pickle.load(f) - return collections - - -@st.experimental_memo -def get_collections_and_metadata(embedding_name: str) -> Tuple[list[dict], dict]: - try: - collections = get_collections(embedding_name) - - embedding_metadata_file_name = "embedding_classifications_metadata.pkl" - embedding_metadata_path = st.session_state.embeddings_dir / embedding_metadata_file_name - if os.path.isfile(embedding_metadata_path): - with open(embedding_metadata_path, "rb") as f: - question_hash_to_collection_indexes_local = pickle.load(f) - else: - question_hash_to_collection_indexes_local = {} - - return collections, question_hash_to_collection_indexes_local - except Exception as e: - return [], {} - - -def get_key_from_index(collection: dict, question_hash: Optional[str] = None, has_annotation=True) -> str: - label_hash = collection["label_row"] - du_hash = collection["data_unit"] - frame_idx = int(collection["frame"]) - - if not has_annotation: - key = f"{label_hash}_{du_hash}_{frame_idx:05d}" - else: - if question_hash: - key = f"{label_hash}_{du_hash}_{frame_idx:05d}_{question_hash}" - else: - object_hash = collection["objectHash"] - key = f"{label_hash}_{du_hash}_{frame_idx:05d}_{object_hash}" - - return key - - -def get_identifier_to_neighbors( - collections: list[dict], nearest_indexes: np.ndarray, has_annotation=True -) -> dict[str, list]: - nearest_neighbors = {} - n, k = nearest_indexes.shape - for i in range(n): - key = get_key_from_index(collections[i], has_annotation=has_annotation) - temp_list = [] - for j in range(1, k): - temp_list.append( - { - "key": get_key_from_index(collections[nearest_indexes[i, j]], has_annotation=has_annotation), - "name": collections[nearest_indexes[i, j]].get("name", "Does not have a label"), - } - ) - nearest_neighbors[key] = temp_list - return nearest_neighbors - - -def convert_to_indexes(collections, question_hash_to_collection_indexes): - embedding_databases, indexes = {}, {} - - for question_hash in question_hash_to_collection_indexes: - selected_collections = [collections[i] for i in question_hash_to_collection_indexes[question_hash]] - - if len(selected_collections) > 10: - embedding_database = np.stack(list(map(lambda x: x["embedding"], selected_collections))) - - index = faiss.IndexFlatL2(embedding_database.shape[1]) - 
index.add(embedding_database) # pylint: disable=no-value-for-parameter - - embedding_databases[question_hash] = embedding_database - indexes[question_hash] = index - - return embedding_databases, indexes - - -@st.experimental_memo -def get_faiss_index_image(_collections: list) -> dict[str, IndexFlatL2]: - indexes = {} - - for question_hash in st.session_state[state.QUESTION_HASH_TO_COLLECTION_INDEXES]: - selected_collections = [ - _collections[i] for i in st.session_state[state.QUESTION_HASH_TO_COLLECTION_INDEXES][question_hash] - ] - - if len(selected_collections) > 10: - embedding_database = np.stack(list(map(lambda x: x["embedding"], selected_collections))) - - index = faiss.IndexFlatL2(embedding_database.shape[1]) - index.add(embedding_database) # pylint: disable=no-value-for-parameter - - indexes[question_hash] = index - - return indexes - - -@st.experimental_memo -def get_faiss_index_object(_collections: list[dict], faiss_index_name: str) -> IndexFlatL2: - """ - - Args: - _collections: Underscore is used to skip hashing for this object to make this function faster. - faiss_index_name: Since we skip hashing for collections, we need another parameter for memoization. - - Returns: Faiss Index object for searching embeddings. - - """ - embeddings_list: List[list] = [x["embedding"] for x in _collections] - embeddings = np.array(embeddings_list).astype(np.float32) - - if len(embeddings.shape) != 2: - return - - db_index = faiss.IndexFlatL2(embeddings.shape[1]) - db_index.add(embeddings) # pylint: disable=no-value-for-parameter - return db_index - - -@st.experimental_memo -def get_object_keys_having_similarities(_collections: list[dict]) -> dict: - return {get_key_from_index(collection): i for i, collection in enumerate(_collections)} - - -@st.experimental_memo -def get_image_keys_having_similarities(_collections: list[dict]) -> dict: - return {get_key_from_index(collection, has_annotation=False): i for i, collection in enumerate(_collections)} - - -def add_labeled_image_neighbors_to_cache(image_identifier: str, question_feature_hash: str) -> None: - collection_id = st.session_state[state.IMAGE_KEYS_HAVING_SIMILARITIES]["_".join(image_identifier.split("_")[:3])] - collection_item_index = st.session_state[state.QUESTION_HASH_TO_COLLECTION_INDEXES][question_feature_hash].index( - collection_id - ) - embedding = np.array([st.session_state[state.COLLECTIONS_IMAGES][collection_id]["embedding"]]).astype(np.float32) - nearest_distances, nearest_indexes = st.session_state[state.FAISS_INDEX_IMAGE][question_feature_hash].search( - embedding, int(st.session_state[state.K_NEAREST_NUM] + 1) - ) - nearest_indexes = fix_duplicate_image_orders_in_knn_graph_single_row(collection_item_index, nearest_indexes) - - temp_list = [] - for nearest_index in nearest_indexes[0, 1:]: - collection_index = st.session_state[state.QUESTION_HASH_TO_COLLECTION_INDEXES][question_feature_hash][ - nearest_index - ] - temp_list.append( - { - "key": get_key_from_index( - st.session_state[state.COLLECTIONS_IMAGES][collection_index], - question_hash=question_feature_hash, - has_annotation=st.session_state[state.CURRENT_INDEX_HAS_ANNOTATION], - ), - "name": st.session_state[state.COLLECTIONS_IMAGES][collection_index]["classification_answers"][ - question_feature_hash - ]["answer_name"], - } - ) - - st.session_state[state.IMAGE_SIMILARITIES][question_feature_hash][image_identifier] = temp_list - - -def _get_nearest_items_from_nearest_indexes(nearest_indexes: np.ndarray, collection_type: str) -> list[dict]: - temp_list = [] - for 
nearest_index in nearest_indexes[0, 1:]: - temp_list.append( - { - "key": get_key_from_index( - st.session_state[collection_type][nearest_index], - has_annotation=st.session_state[state.CURRENT_INDEX_HAS_ANNOTATION], - ), - "name": st.session_state[collection_type][nearest_index].get("name", "No label"), - } - ) - - return temp_list - - -def add_object_neighbors_to_cache(object_identifier: str) -> None: - item_index = st.session_state[state.OBJECT_KEYS_HAVING_SIMILARITIES][object_identifier] - embedding = np.array([st.session_state[state.COLLECTIONS_OBJECTS][item_index]["embedding"]]).astype(np.float32) - nearest_distances, nearest_indexes = st.session_state[state.FAISS_INDEX_OBJECT].search( - embedding, int(st.session_state[state.K_NEAREST_NUM] + 1) - ) - - st.session_state[state.OBJECT_SIMILARITIES][object_identifier] = _get_nearest_items_from_nearest_indexes( - nearest_indexes, state.COLLECTIONS_OBJECTS - ) - - -def add_image_neighbors_to_cache(image_identifier: str): - collection_id = st.session_state[state.IMAGE_KEYS_HAVING_SIMILARITIES][image_identifier] - embedding = np.array([st.session_state[state.COLLECTIONS_IMAGES][collection_id]["embedding"]]).astype(np.float32) - nearest_distances, nearest_indexes = st.session_state[state.FAISS_INDEX_IMAGE_NO_LABEL].search( - embedding, int(st.session_state[state.K_NEAREST_NUM] + 1) - ) - - st.session_state[state.IMAGE_SIMILARITIES_NO_LABEL][image_identifier] = _get_nearest_items_from_nearest_indexes( - nearest_indexes, state.COLLECTIONS_IMAGES - ) diff --git a/src/encord_active/app/common/metric.py b/src/encord_active/app/common/metric.py deleted file mode 100644 index 20cd2a0a4..000000000 --- a/src/encord_active/app/common/metric.py +++ /dev/null @@ -1,43 +0,0 @@ -from dataclasses import dataclass -from pathlib import Path -from typing import Any, Dict - -import pandas as pd -import streamlit as st - - -@dataclass -class MetricData: - name: str - path: Path - meta: Dict[str, Any] - level: str - - -@st.cache(allow_output_mutation=True) -def load_metric(metric: MetricData, normalize: bool, *, sorting_key="score") -> pd.DataFrame: - """ - Load and sort the selected csv file and cache it, so we don't need to perform this - heavy computation each time the slider in the UI is moved. - :param metric: The metric to load data from. - :param normalize: whether to apply normalisation to the scores or not. - :param sorting_key: key by which to sort dataframe (default: "score") - :return: a pandas data frame with all the scores. 
- """ - df = pd.read_csv(metric.path).sort_values([sorting_key, "identifier"], ascending=True).reset_index() - - if normalize: - min_val = metric.meta.get("min_value") - max_val = metric.meta.get("max_value") - if min_val is None: - min_val = df["score"].min() - if max_val is None: - max_val = df["score"].max() - - diff = max_val - min_val - if diff == 0: # Avoid dividing by zero - diff = 1.0 - - df["score"] = (df["score"] - min_val) / diff - - return df diff --git a/src/encord_active/app/common/page.py b/src/encord_active/app/common/page.py index daae56dc3..73a8cdc85 100644 --- a/src/encord_active/app/common/page.py +++ b/src/encord_active/app/common/page.py @@ -3,7 +3,8 @@ import streamlit as st import encord_active.app.common.state as state -from encord_active.lib.common.metric import EmbeddingType +from encord_active.app.common.state import get_state +from encord_active.lib.metrics.metric import EmbeddingType class Page(ABC): @@ -32,52 +33,46 @@ def __repr__(self): @staticmethod def row_col_settings_in_sidebar(): col_default_max, row_default_max = 10, 5 - default_mv_column_num = 4 - default_mv_row_num = 5 default_knn_num = 8 with st.expander("Settings"): - if state.MAIN_VIEW_COLUMN_NUM not in st.session_state: - st.session_state[state.MAIN_VIEW_COLUMN_NUM] = default_mv_column_num - - if state.MAIN_VIEW_ROW_NUM not in st.session_state: - st.session_state[state.MAIN_VIEW_ROW_NUM] = default_mv_row_num - if state.K_NEAREST_NUM not in st.session_state: st.session_state[state.K_NEAREST_NUM] = default_knn_num with st.form(key="settings_from"): - st.checkbox( + get_state().normalize_metrics = st.checkbox( "Metric normalization", - key=state.NORMALIZATION_STATUS, + value=get_state().normalize_metrics, help="If checked, score values will be normalized between 0 and 1. 
Otherwise, \ original values will be shown.", ) setting_columns = st.columns(2) - setting_columns[0].number_input( - "Columns", - min_value=2, - max_value=col_default_max, - value=st.session_state[state.MAIN_VIEW_COLUMN_NUM], - key=state.MAIN_VIEW_COLUMN_NUM, - help="Number of columns to show images in the main view", + get_state().page_grid_settings.columns = int( + setting_columns[0].number_input( + "Columns", + min_value=2, + max_value=col_default_max, + value=get_state().page_grid_settings.columns, + help="Number of columns to show images in the main view", + ) ) - setting_columns[1].number_input( - "Rows", - min_value=1, - max_value=row_default_max, - value=st.session_state[state.MAIN_VIEW_ROW_NUM], - key=state.MAIN_VIEW_ROW_NUM, - help="Number of rows to show images in the main view", + get_state().page_grid_settings.rows = int( + setting_columns[1].number_input( + "Rows", + min_value=1, + max_value=row_default_max, + value=get_state().page_grid_settings.rows, + help="Number of rows to show images in the main view", + ) ) - if state.DATA_PAGE_METRIC in st.session_state.keys(): - if st.session_state[state.DATA_PAGE_METRIC].meta.get( - "embedding_type", EmbeddingType.NONE.value - ) in [ + selected_metric = get_state().selected_metric + + if selected_metric: + if selected_metric.meta.get("embedding_type", EmbeddingType.NONE.value) in [ EmbeddingType.CLASSIFICATION.value, EmbeddingType.OBJECT.value, ]: diff --git a/src/encord_active/app/common/state.py b/src/encord_active/app/common/state.py index 0a61d6286..1a9b58004 100644 --- a/src/encord_active/app/common/state.py +++ b/src/encord_active/app/common/state.py @@ -1,20 +1,75 @@ +from dataclasses import dataclass, field from pathlib import Path +from typing import Any, Callable, Dict, List, Optional +import pandas as pd import streamlit as st +from pandera.typing import DataFrame -# CONSTANTS -PROJECT_CACHE_FILE = Path.home() / ".encord_quality" / "current_project_dir.txt" +from encord_active.lib.db.merged_metrics import MergedMetrics +from encord_active.lib.db.tags import Tag, Tags +from encord_active.lib.metrics.utils import MetricData +from encord_active.lib.model_predictions.reader import LabelSchema, OntologyObjectJSON +from encord_active.lib.project.project_file_structure import ProjectFileStructure -# DATABASE -DB_FILE_NAME = "sqlite.db" +GLOBAL_STATE = "global_state" -# STATE VARIABLE KEYS -CLASS_SELECTION = "class_selection" -IGNORE_FRAMES_WO_PREDICTIONS = "ignore_frames_wo_predictions" -IOU_THRESHOLD = "iou_threshold_scaled" # After normalization -IOU_THRESHOLD_ = "iou_threshold_" # Before normalization -MERGED_DATAFRAME = "merged_dataframe" -ALL_TAGS = "all_tags" + +@dataclass +class MetricNames: + predictions: Dict[str, MetricData] = field(default_factory=dict) + selected_predicion: Optional[str] = None + labels: Dict[str, MetricData] = field(default_factory=dict) + selected_label: Optional[str] = None + + +@dataclass +class PredictionsState: + decompose_classes = False + metric_datas = MetricNames() + all_classes: Dict[str, OntologyObjectJSON] = field(default_factory=dict) + selected_classes: Dict[str, OntologyObjectJSON] = field(default_factory=dict) + labels: Optional[DataFrame[LabelSchema]] = None + nbins = 50 + + +@dataclass +class PageGridSettings: + columns: int = 4 + rows: int = 5 + + +@dataclass +class State: + """This is not intended for usage, please use the `get_state` constant instead.""" + + project_paths: ProjectFileStructure + all_tags: List[Tag] + merged_metrics: pd.DataFrame + 
ignore_frames_without_predictions = False + iou_threshold = 0.5 + selected_metric: Optional[MetricData] = None + page_grid_settings = PageGridSettings() + normalize_metrics = False + predictions = PredictionsState() + + @classmethod + def init(cls, project_dir: Path): + if GLOBAL_STATE in st.session_state: + return + + st.session_state[GLOBAL_STATE] = State( + project_paths=ProjectFileStructure(project_dir), + merged_metrics=MergedMetrics().all(), + all_tags=Tags().all(), + ) + + +def get_state() -> State: + return st.session_state.get(GLOBAL_STATE) # type: ignore + + +# EVERYTHING BELOW SHOULD BE DEPRACATED # SIMILARITY KEYS OBJECT_KEYS_HAVING_SIMILARITIES = "object_keys_having_similarities" @@ -29,61 +84,10 @@ QUESTION_HASH_TO_COLLECTION_INDEXES = "question_hash_to_collection_indexes" COLLECTIONS_IMAGES = "collections_images" COLLECTIONS_OBJECTS = "collections_objects" - - -# DATA QUALITY PAGE -DATA_PAGE_METRIC = "data_page_metric" # metric -DATA_PAGE_METRIC_NAME = "data_page_metric_name" # metric name -DATA_PAGE_CLASS = "data_page_class_selection" # class -DATA_PAGE_ANNOTATOR = "data_page_annotator_selection" # annotator - - -# PREDICTIONS PAGE -PREDICTIONS_LABEL_METRIC = "predictions_label_metric" -PREDICTIONS_DECOMPOSE_CLASSES = "predictions_decompose_classes" -PREDICTIONS_METRIC = "predictions_metric" -PREDICTIONS_NBINS = "predictions_nbins" - -# TILING & PAGINATION -MAIN_VIEW_COLUMN_NUM = "main_view_column_num" -MAIN_VIEW_ROW_NUM = "main_view_row_num" K_NEAREST_NUM = "k_nearest_num" -METRIC_VIEW_PAGE_NUMBER = "metric_view_page_number" -FALSE_NEGATIVE_VIEW_PAGE_NUMBER = "false_negative_view_page_number" - -NORMALIZATION_STATUS = "normalization_status" -METRIC_METADATA_SCORE_NORMALIZATION = "score_normalization" - -# Export page -NUMBER_OF_PARTITIONS = "number_of_partitions" -ACTION_PAGE_CLONE_BUTTON = "action_page_clone_button" -ACTION_PAGE_PREVIOUS_FILTERED_NUM = "action_page_previous_filtered" - - -def set_project_dir(project_dir: str) -> bool: - _project_dir = Path(project_dir).expanduser().absolute() - - if not _project_dir.is_dir(): - return False - else: - PROJECT_CACHE_FILE.parent.mkdir(parents=True, exist_ok=True) - with PROJECT_CACHE_FILE.open("w", encoding="utf-8") as f: - f.write(_project_dir.as_posix()) - st.session_state.project_dir = _project_dir - return True - - -def populate_session_state(): - if "project_dir" not in st.session_state: - # Try using a cached one - if PROJECT_CACHE_FILE.is_file(): - with PROJECT_CACHE_FILE.open("r", encoding="utf-8") as f: - st.session_state.project_dir = Path(f.readline()) - st.session_state.metric_dir = st.session_state.project_dir / "metrics" - st.session_state.embeddings_dir = st.session_state.project_dir / "embeddings" - st.session_state.predictions_dir = st.session_state.project_dir / "predictions" - st.session_state.data_dir = st.session_state.project_dir / "data" - st.session_state.ontology_file = st.session_state.project_dir / "ontology.json" - st.session_state.db_path = st.session_state.project_dir / DB_FILE_NAME +def setdefault(key: str, fn: Callable, *args, **kwargs) -> Any: + if not key in st.session_state: + st.session_state[key] = fn(*args, **kwargs) + return st.session_state.get(key) diff --git a/src/encord_active/app/common/state_hooks.py b/src/encord_active/app/common/state_hooks.py new file mode 100644 index 000000000..63faaeb76 --- /dev/null +++ b/src/encord_active/app/common/state_hooks.py @@ -0,0 +1,65 @@ +import inspect +from typing import Callable, Optional, TypeVar, Union, overload + +import streamlit as st + 
+T = TypeVar("T") +Reducer = Callable[[T], T] + +SCOPED_STATES = "scoped_states" + + +def create_key(): + stk = inspect.stack() + frame = stk[2] + return f"{frame.filename}:{frame.function}:{frame.lineno}" + + +def use_memo(initial: Callable[[], T], key: Optional[str] = None): + key = key or create_key() + st.session_state.setdefault(SCOPED_STATES, {}) + + if key not in st.session_state[SCOPED_STATES]: + st.session_state[SCOPED_STATES][key] = initial() + + value: T = st.session_state[SCOPED_STATES][key] + return value + + +def use_lazy_state(initial: Callable[[], T], key: Optional[str] = None): + key = key or create_key() + st.session_state.setdefault(SCOPED_STATES, {}) + + if key not in st.session_state[SCOPED_STATES]: + st.session_state[SCOPED_STATES][key] = initial() + value: T = st.session_state[SCOPED_STATES][key] + + return use_state(value, key) + + +# TODO: is there a way to make this work and have proper types? +# def use_state(initial: Union[T, Callable[[], T]], key: Optional[str] = create_key()): +# st.session_state.setdefault(SCOPED_STATES, {}) +# st.session_state[SCOPED_STATES].setdefault(key, initial() if callable(initial) else initial) +def use_state(initial: T, key: Optional[str] = None): + key = key or create_key() + st.session_state.setdefault(SCOPED_STATES, {}).setdefault(key, initial) + + @overload + def set_state(arg: T): + ... + + @overload + def set_state(arg: Reducer[T]): + ... + + def set_state(arg: Union[T, Reducer[T]]): + if callable(arg): + st.session_state[SCOPED_STATES][key] = arg(st.session_state[SCOPED_STATES][key]) + else: + st.session_state[SCOPED_STATES][key] = arg + + def get_state() -> T: + return st.session_state[SCOPED_STATES][key] + + return get_state, set_state diff --git a/src/encord_active/app/common/utils.py b/src/encord_active/app/common/utils.py index fa3fe83bd..46f5ee744 100644 --- a/src/encord_active/app/common/utils.py +++ b/src/encord_active/app/common/utils.py @@ -1,24 +1,12 @@ -import json -from json import JSONDecodeError from pathlib import Path -from typing import List, Optional, Tuple, Union -import cv2 -import numpy as np -import pandas as pd import streamlit as st -import encord_active.app.common.state as state -from encord_active.app.common.colors import Color, hex_to_rgb from encord_active.app.common.css import write_page_css -from encord_active.app.common.state import populate_session_state -from encord_active.app.db.merged_metrics import MergedMetrics -from encord_active.lib.common.utils import get_du_size def set_page_config(): - project_root = Path(__file__).parents[1] - favicon_pth = project_root / "assets" / "favicon-32x32.png" + favicon_pth = Path(__file__).parents[1] / "assets" / "favicon-32x32.png" st.set_page_config( page_title="Encord Active", layout="wide", @@ -27,203 +15,4 @@ def set_page_config(): def setup_page(): - populate_session_state() write_page_css() - - -def load_json(json_file: Path) -> Optional[dict]: - if not json_file.exists(): - return None - - with json_file.open("r", encoding="utf-8") as f: - try: - return json.load(f) - except JSONDecodeError: - return None - - -def load_or_fill_image(row: Union[pd.Series, str]) -> np.ndarray: - """ - Tries to read the infered image path. If not possible, generates a white image - and indicates what the error seemd to be embedded in the image. - :param row: A csv row from either a metric, a prediction, or a label csv file. - :return: Numpy / cv2 image. 
- """ - - read_error = False - key = __get_key(row) - - img_pth: Optional[Path] = key_to_image_path(key) - - if img_pth and img_pth.is_file(): - try: - image = cv2.imread(img_pth.as_posix()) - return cv2.cvtColor(image, cv2.COLOR_BGR2RGB) - except cv2.error: - pass - - # Read not successful, so tell the user why - error_text = "Image not found" if not img_pth else "File seems broken" - - _, du_hash, *_ = key.split("_") - lr = json.loads(key_to_lr_path(key).read_text(encoding="utf-8")) - - h, w = get_du_size(lr["data_units"].get(du_hash, {}), None) or (600, 900) - - image = np.ones((h, w, 3), dtype=np.uint8) * 255 - image[:4, :] = [255, 0, 0] - image[-4:, :] = [255, 0, 0] - image[:, :4] = [255, 0, 0] - image[:, -4:] = [255, 0, 0] - font = cv2.FONT_HERSHEY_SIMPLEX - pos = int(0.05 * min(w, h)) - cv2.putText(image, error_text, (pos, 2 * pos), font, w / 900, hex_to_rgb("#999999"), 2, cv2.LINE_AA) - - return image - - -def get_df_subset(df: pd.DataFrame, selected_metric: Optional[str]): - if selected_metric not in df: - return df - - max_val = float(df[selected_metric].max()) + np.finfo(float).eps.item() - min_val = float(df[selected_metric].min()) - - if max_val <= min_val: - return df - - step = max(0.01, (max_val - min_val) // 100) - start, end = st.slider("Choose quality", max_value=max_val, min_value=min_val, value=(min_val, max_val), step=step) - subset = df[df[selected_metric].between(start, end)] - - return subset - - -def build_pagination(subset, n_cols, n_rows, selected_metric): - n_items = n_cols * n_rows - col1, col2 = st.columns(spec=[1, 4]) - - with col1: - sorting_order = st.selectbox("Sort samples within selected interval", ["Ascending", "Descending"]) - - with col2: - last = len(subset) // n_items + 1 - page_num = st.slider("Page", 1, last) if last > 1 else 1 - - low_lim = (page_num - 1) * n_items - high_lim = page_num * n_items - - sorted_subset = subset.sort_values(by=selected_metric, ascending=sorting_order == "Ascending") - paginated_subset = sorted_subset[low_lim:high_lim] - return paginated_subset - - -def __get_key(row: Union[pd.Series, str]): - if isinstance(row, pd.Series): - if "identifier" not in row: - raise ValueError("A Series passed but the series doesn't contain 'identifier'") - return str(row["identifier"]) - elif isinstance(row, str): - return row - else: - raise Exception(f"Undefined row type {row}") - - -def __get_geometry(obj: dict, img_h: int, img_w: int) -> Optional[Tuple[str, np.ndarray]]: - """ - Convert Encord object dictionary to polygon coordinates used to draw geometries - with opencv. - - :param obj: the encord object dict - :param w: the image width - :param h: the image height - :return: The polygon coordinates - """ - - if obj["shape"] == "polygon": - p = obj["polygon"] - polygon = np.array([[p[str(i)]["x"] * img_w, p[str(i)]["y"] * img_h] for i in range(len(p))]) - elif obj["shape"] == "bounding_box": - b = obj["boundingBox"] - polygon = np.array( - [ - [b["x"] * img_w, b["y"] * img_h], - [(b["x"] + b["w"]) * img_w, b["y"] * img_h], - [(b["x"] + b["w"]) * img_w, (b["y"] + b["h"]) * img_h], - [b["x"] * img_w, (b["y"] + b["h"]) * img_h], - ] - ) - else: - return None - - polygon = polygon.reshape((-1, 1, 2)).astype(int) - return obj.get("color", Color.PURPLE.value), polygon - - -def get_geometries( - row: Union[pd.Series, str], img_h: int, img_w: int, skip_object_hash: bool = False -) -> List[Tuple[str, np.ndarray]]: - """ - Loads cached label row and computes geometries from the label row. 
- If the ``identifier`` in the ``row`` contains an object hash, only that object will - be drawn. If no object hash exists, all polygons / bboxes will be drawn. - :param row: the pandas row of the selected csv file. - :return: List of tuples of (hex color, polygon: [[x, y], ...]) - """ - key = __get_key(row) - _, du_hash, frame, *remainder = key.split("_") - - lr_pth = key_to_lr_path(key) - with lr_pth.open("r") as f: - label_row = json.load(f) - - du = label_row["data_units"][du_hash] - - geometries = [] - objects = ( - du["labels"].get("objects", []) - if "video" not in du["data_type"] - else du["labels"][str(int(frame))].get("objects", []) - ) - - if remainder and not skip_object_hash: - # Return specific geometries - geometry_object_hashes = set(remainder) - for obj in objects: - if obj["objectHash"] in geometry_object_hashes: - geometries.append(__get_geometry(obj, img_h=img_h, img_w=img_w)) - else: - # Get all geometries - for obj in objects: - if obj["shape"] not in {"polygon", "bounding_box"}: - continue - geometries.append(__get_geometry(obj, img_h=img_h, img_w=img_w)) - - valid_geometries = list(filter(None, geometries)) - return valid_geometries - - -def key_to_lr_path(key: str) -> Path: - label_hash, *_ = key.split("_") - return st.session_state.data_dir / label_hash / "label_row.json" - - -def key_to_image_path(key: str) -> Optional[Path]: - """ - Infer image path from the identifier stored in the csv files. - :param key: the row["identifier"] from a csv row - :return: The associated image path if it exists or a path to a placeholder otherwise - """ - label_hash, du_hash, frame, *_ = key.split("_") - img_folder = st.session_state.data_dir / label_hash / "images" - - # check if it is a video frame - frame_pth = next(img_folder.glob(f"{du_hash}_{int(frame)}.*"), None) - if frame_pth is not None: - return frame_pth - return next(img_folder.glob(f"{du_hash}.*"), None) # So this is an img_group image - - -def load_merged_df(): - if state.MERGED_DATAFRAME not in st.session_state: - st.session_state[state.MERGED_DATAFRAME] = MergedMetrics().all() diff --git a/src/encord_active/app/data_quality/common.py b/src/encord_active/app/data_quality/common.py deleted file mode 100644 index 2fd7da8c8..000000000 --- a/src/encord_active/app/data_quality/common.py +++ /dev/null @@ -1,82 +0,0 @@ -from enum import Enum -from pathlib import Path -from typing import List, Union - -import cv2 -import numpy as np -import streamlit as st -from natsort import natsorted -from pandas import Series - -from encord_active.app.common.colors import hex_to_rgb -from encord_active.app.common.metric import MetricData -from encord_active.app.common.utils import get_geometries, load_json, load_or_fill_image - - -class MetricType(Enum): - DATA_QUALITY = "data_quality" - LABEL_QUALITY = "label_quality" - MODEL_QUALITY = "model_quality" - - -def get_metric_operation_level(pth: Path) -> str: - if not all([pth.exists(), pth.is_file(), pth.suffix == ".csv"]): - return "" - - with pth.open("r", encoding="utf-8") as f: - _ = f.readline() # Header, which we don't care about - csv_row = f.readline() # Content line - - if not csv_row: # Empty metric - return "" - - key, _ = csv_row.split(",", 1) - _, _, _, *object_hashes = key.split("_") - return "O" if object_hashes else "F" - - -@st.experimental_memo -def load_available_metrics(metric_type_selection: str) -> List[MetricData]: - def criterion(x): - return x is None if metric_type_selection == MetricType.DATA_QUALITY.value else x is not None - - metric_dir = 
st.session_state.metric_dir - if not metric_dir.is_dir(): - return [] - - paths = natsorted([p for p in metric_dir.iterdir() if p.suffix == ".csv"], key=lambda x: x.stem.split("_", 1)[1]) - levels = list(map(get_metric_operation_level, paths)) - - make_name = lambda p: p.name.split("_", 1)[1].rsplit(".", 1)[0].replace("_", " ").title() - names = [f"{make_name(p)}" for p, l in zip(paths, levels)] - meta_data = [load_json(f.with_suffix(".meta.json")) for f in paths] - - out: List[MetricData] = [] - - if not meta_data: - return out - - for p, n, m, l in zip(paths, names, meta_data, levels): - if m is None or not l or not criterion(m.get("annotation_type")): - continue - - out.append(MetricData(name=n, path=p, meta=m, level=l)) - - out = natsorted(out, key=lambda i: (i.level, i.name)) # type: ignore - return out - - -def show_image_and_draw_polygons(row: Union[Series, str], draw_polygons: bool = True) -> np.ndarray: - # === Read and annotate the image === # - image = load_or_fill_image(row) - - # === Draw polygons / bboxes if available === # - is_closed = True - thickness = int(image.shape[1] / 150) - - img_h, img_w = image.shape[:2] - if draw_polygons: - for color, geometry in get_geometries(row, img_h=img_h, img_w=img_w): - image = cv2.polylines(image, [geometry], is_closed, hex_to_rgb(color), thickness) - - return image diff --git a/src/encord_active/app/data_quality/sub_pages/explorer.py b/src/encord_active/app/data_quality/sub_pages/explorer.py index 699801106..a42d4d13d 100644 --- a/src/encord_active/app/data_quality/sub_pages/explorer.py +++ b/src/encord_active/app/data_quality/sub_pages/explorer.py @@ -1,37 +1,74 @@ import re -from typing import List +from typing import Any, List, Optional -import altair as alt -import numpy as np import pandas as pd -import plotly.express as px import streamlit as st -from natsort import natsorted from pandas import Series +from pandera.typing import DataFrame from streamlit.delta_generator import DeltaGenerator -import encord_active.app.common.state as state -from encord_active.app.common import embedding_utils -from encord_active.app.common.components import ( - build_data_tags, - multiselect_with_all_option, +from encord_active.app.common.components import build_data_tags +from encord_active.app.common.components.annotator_statistics import ( + render_annotator_properties, ) -from encord_active.app.common.components.bulk_tagging_form import ( +from encord_active.app.common.components.label_statistics import ( + render_dataset_properties, +) +from encord_active.app.common.components.paginator import render_pagination +from encord_active.app.common.components.similarities import ( + show_similar_classification_images, + show_similar_images, + show_similar_object_images, +) +from encord_active.app.common.components.slicer import render_df_slicer +from encord_active.app.common.components.tags.bulk_tagging_form import ( BulkLevel, action_bulk_tags, bulk_tagging_form, ) -from encord_active.app.common.components.individual_tagging import multiselect_tag -from encord_active.app.common.components.tag_creator import tag_creator -from encord_active.app.common.metric import MetricData, load_metric +from encord_active.app.common.components.tags.individual_tagging import multiselect_tag +from encord_active.app.common.components.tags.tag_creator import tag_creator from encord_active.app.common.page import Page -from encord_active.app.common.utils import build_pagination, get_df_subset -from encord_active.app.data_quality.common import ( - MetricType, +from 
encord_active.app.common.state import ( + COLLECTIONS_IMAGES, + COLLECTIONS_OBJECTS, + CURRENT_INDEX_HAS_ANNOTATION, + FAISS_INDEX_IMAGE, + FAISS_INDEX_IMAGE_NO_LABEL, + FAISS_INDEX_OBJECT, + IMAGE_KEYS_HAVING_SIMILARITIES, + IMAGE_SIMILARITIES, + IMAGE_SIMILARITIES_NO_LABEL, + OBJECT_KEYS_HAVING_SIMILARITIES, + OBJECT_SIMILARITIES, + QUESTION_HASH_TO_COLLECTION_INDEXES, + get_state, +) +from encord_active.lib.charts.histogram import get_histogram +from encord_active.lib.common.image_utils import ( load_or_fill_image, show_image_and_draw_polygons, ) -from encord_active.lib.common.metric import AnnotationType, EmbeddingType +from encord_active.lib.embeddings.utils import ( + get_collections, + get_collections_and_metadata, + get_faiss_index_image, + get_faiss_index_object, + get_image_keys_having_similarities, + get_object_keys_having_similarities, +) +from encord_active.lib.metrics.metric import ( + AnnotationType, + EmbeddingType, + MetricMetadata, +) +from encord_active.lib.metrics.utils import ( + MetricData, + MetricSchema, + MetricScope, + get_annotator_level_info, + load_metric_dataframe, +) class ExplorerPage(Page): @@ -44,7 +81,9 @@ def sidebar_options(self, available_metrics: List[MetricData]): st.error("Your data has not been indexed. Make sure you have imported your data correctly.") st.stop() - non_empty_metrics = [metric for metric in available_metrics if not load_metric(metric, normalize=False).empty] + non_empty_metrics = [ + metric for metric in available_metrics if not load_metric_dataframe(metric, normalize=False).empty + ] sorted_metrics = sorted(non_empty_metrics, key=lambda i: i.name) metric_names = list(map(lambda i: i.name, sorted_metrics)) @@ -56,193 +95,106 @@ def sidebar_options(self, available_metrics: List[MetricData]): help="The data in the main view will be sorted by the selected metric. 
", ) - metric_idx = metric_names.index(selected_metric_name) - st.session_state[state.DATA_PAGE_METRIC] = sorted_metrics[metric_idx] - st.session_state[state.DATA_PAGE_METRIC_NAME] = selected_metric_name - - if state.NORMALIZATION_STATUS not in st.session_state: - st.session_state[state.NORMALIZATION_STATUS] = st.session_state[state.DATA_PAGE_METRIC].meta.get( - state.METRIC_METADATA_SCORE_NORMALIZATION, True - ) # If there is no information on the meta file, just normalize (its probability is higher) - - df = load_metric( - st.session_state[state.DATA_PAGE_METRIC], normalize=st.session_state[state.NORMALIZATION_STATUS] - ) - - if df.shape[0] > 0: - class_set = sorted(list(df["object_class"].unique())) - with col2: - selected_classes = multiselect_with_all_option("Filter by class", class_set, key=state.DATA_PAGE_CLASS) - - is_class_selected = ( - df.shape[0] * [True] if "All" in selected_classes else df["object_class"].isin(selected_classes) - ) - df_class_selected = df[is_class_selected] - - annotators = get_annotator_level_info(df_class_selected) - annotator_set = sorted(annotators["annotator"]) - - with col3: - selected_annotators = multiselect_with_all_option( - "Filter by annotator", - annotator_set, - key=state.DATA_PAGE_ANNOTATOR, - ) - - annotator_selected = ( - df_class_selected.shape[0] * [True] - if "All" in selected_annotators - else df_class_selected["annotator"].isin(selected_annotators) - ) - - self.row_col_settings_in_sidebar() - # For now go the easy route and just filter the dataframe here - return df_class_selected[annotator_selected] - - def build(self, selected_df: pd.DataFrame, metric_type: MetricType): - st.markdown(f"# {self.title}") - meta = st.session_state[state.DATA_PAGE_METRIC].meta - st.markdown(f"## {meta['title']}") - st.markdown(meta["long_description"]) - - if selected_df.empty: + if not selected_metric_name: return - fill_dataset_properties_window(selected_df) - fill_annotator_properties_window(selected_df) - fill_data_quality_window(selected_df, metric_type) - - -def get_annotator_level_info(df: pd.DataFrame) -> dict[str, list]: - annotator_set = natsorted(list(df["annotator"].unique())) - annotators: dict[str, list] = {"annotator": annotator_set, "total annotation": [], "score": []} - - for annotator in annotator_set: - annotators["total annotation"].append(df[df["annotator"] == annotator].shape[0]) - annotators["score"].append(df[df["annotator"] == annotator]["score"].mean()) - - return annotators - - -def fill_dataset_properties_window(current_df: pd.DataFrame): - dataset_expander = st.expander("Dataset Properties", expanded=True) - dataset_columns = dataset_expander.columns(3) - - cls_set = natsorted(list(current_df["object_class"].unique())) - - dataset_columns[0].metric("Number of labels", current_df.shape[0]) - dataset_columns[1].metric("Number of classes", len(cls_set)) - dataset_columns[2].metric("Number of images", get_unique_data_units_size(current_df)) - - if len(cls_set) > 1: - classes = {} - for cls in cls_set: - classes[cls] = (current_df["object_class"] == cls).sum() - - source = pd.DataFrame({"class": list(classes.keys()), "count": list(classes.values())}) - - fig = px.bar(source, x="class", y="count") - fig.update_layout(title_text="Distribution of the classes", title_x=0.5, title_font_size=20) - dataset_expander.plotly_chart(fig, use_container_width=True) + metric_idx = metric_names.index(selected_metric_name) + selected_metric = sorted_metrics[metric_idx] + get_state().selected_metric = selected_metric + normalize = 
selected_metric.meta.get("score_normalization", get_state().normalize_metrics) + df = load_metric_dataframe(selected_metric, normalize=normalize) -def fill_annotator_properties_window(current_df: pd.DataFrame): - annotators = get_annotator_level_info(current_df) - if not (len(annotators["annotator"]) == 1 and (not isinstance(annotators["annotator"][0], str))): - annotator_expander = st.expander("Annotator Statistics", expanded=False) + if df.shape[0] <= 0: + return - annotator_columns = annotator_expander.columns([2, 2]) + class_set = sorted(list(df["object_class"].unique())) + with col2: + selected_classes = st.multiselect("Filter by class", class_set) - # 1. Pie Chart - annotator_columns[0].markdown( - "
Distribution of the annotations
", unsafe_allow_html=True - ) - source = pd.DataFrame( - { - "annotator": annotators["annotator"], - "total": annotators["total annotation"], - "score": [f"{score:.3f}" for score in annotators["score"]], - } - ) + is_class_selected = df.shape[0] * [True] if not selected_classes else df["object_class"].isin(selected_classes) + df_class_selected: DataFrame[MetricSchema] = df[is_class_selected] - fig = px.pie(source, values="total", names="annotator", hover_data=["score"]) - # fig.update_layout(title_text="Distribution of the annotations", title_x=0.5, title_font_size=20) + annotators = get_annotator_level_info(df_class_selected) + annotator_set = sorted(annotators.keys()) - annotator_columns[0].plotly_chart(fig, use_container_width=True) + with col3: + selected_annotators = st.multiselect("Filter by annotator", annotator_set) - # 2. Table View - annotator_columns[1].markdown( - "
Detailed annotator statistics
", unsafe_allow_html=True + annotator_selected = ( + df_class_selected.shape[0] * [True] + if not selected_annotators + else df_class_selected["annotator"].isin(selected_annotators) ) - annotators["annotator"].append("all") - annotators["total annotation"].append(current_df.shape[0]) + self.row_col_settings_in_sidebar() + # For now go the easy route and just filter the dataframe here + return df_class_selected[annotator_selected] - df_mean_score = current_df["score"].mean() - annotators["score"].append(df_mean_score) - deviations = 100 * ((np.array(annotators["score"]) - df_mean_score) / df_mean_score) - annotators["deviations"] = deviations - annotators_df = pd.DataFrame.from_dict(annotators) - - def _format_deviation(val): - return f"{val:.1f}%" + def build(self, selected_df: DataFrame[MetricSchema], metric_scope: MetricScope): + selected_metric = get_state().selected_metric + if not selected_metric: + return - def _format_score(val): - return f"{val:.3f}" + st.markdown(f"# {self.title}") + st.markdown(f"## {selected_metric.meta['title']}") + st.markdown(selected_metric.meta["long_description"]) - def _color_red_or_green(val): - color = "red" if val < 0 else "green" - return f"color: {color}" + if selected_df.empty: + return - def make_pretty(styler): - styler.format(_format_deviation, subset=["deviations"]) - styler.format(_format_score, subset=["score"]) - styler.applymap(_color_red_or_green, subset=["deviations"]) - return styler + with st.expander("Dataset Properties", expanded=True): + render_dataset_properties(selected_df) + with st.expander("Annotator Statistics", expanded=False): + render_annotator_properties(selected_df) - annotator_columns[1].dataframe(annotators_df.style.pipe(make_pretty), use_container_width=True) + fill_data_quality_window(selected_df, metric_scope, selected_metric) -def fill_data_quality_window(current_df: pd.DataFrame, metric_type: MetricType): - annotation_type = st.session_state[state.DATA_PAGE_METRIC].meta.get("annotation_type") +# TODO: move me to lib +def get_embedding_type(metric_title: str, annotation_type: Optional[List[Any]]) -> EmbeddingType: if ( - (annotation_type is None) + annotation_type is None or (len(annotation_type) == 1 and annotation_type[0] == str(AnnotationType.CLASSIFICATION.RADIO.value)) - or ( - st.session_state[state.DATA_PAGE_METRIC].meta.get("title") - in [ - "Frame object density", - "Object Count", - ] - ) + or (metric_title in ["Frame object density", "Object Count"]) ): # TODO find a better way to filter these later because titles can change - embedding_type = str(EmbeddingType.CLASSIFICATION.value) + return EmbeddingType.CLASSIFICATION else: - embedding_type = str(EmbeddingType.OBJECT.value) + return EmbeddingType.OBJECT - populate_embedding_information(embedding_type) - if (embedding_type == str(EmbeddingType.CLASSIFICATION.value)) and len( - st.session_state[state.COLLECTIONS_IMAGES] - ) == 0: +def fill_data_quality_window( + current_df: DataFrame[MetricSchema], metric_scope: MetricScope, selected_metric: MetricData +): + meta = selected_metric.meta + embedding_type = get_embedding_type(meta["title"], meta["annotation_type"]) + + populate_embedding_information(embedding_type, meta) + + if (embedding_type == str(EmbeddingType.CLASSIFICATION.value)) and len(st.session_state[COLLECTIONS_IMAGES]) == 0: st.write("Image-level embedding file is not available for this project.") return - if (embedding_type == str(EmbeddingType.OBJECT.value)) and len(st.session_state[state.COLLECTIONS_OBJECTS]) == 0: + if (embedding_type == 
str(EmbeddingType.OBJECT.value)) and len(st.session_state[COLLECTIONS_OBJECTS]) == 0: st.write("Object-level embedding file is not available for this project.") return - n_cols = int(st.session_state[state.MAIN_VIEW_COLUMN_NUM]) - n_rows = int(st.session_state[state.MAIN_VIEW_ROW_NUM]) + n_cols = get_state().page_grid_settings.columns + n_rows = get_state().page_grid_settings.rows + + metric = get_state().selected_metric + if not metric: + st.error("Metric not selected.") + return - chart = get_histogram(current_df) + chart = get_histogram(current_df, "score", metric.name) st.altair_chart(chart, use_container_width=True) - subset = get_df_subset(current_df, "score") + subset = render_df_slicer(current_df, "score") st.write(f"Interval contains {subset.shape[0]} of {current_df.shape[0]} annotations") - paginated_subset = build_pagination(subset, n_cols, n_rows, "score") + paginated_subset = render_pagination(subset, n_cols, n_rows, "score") - form = bulk_tagging_form(metric_type) + form = bulk_tagging_form(metric_scope) if form and form.submitted: df = paginated_subset if form.level == BulkLevel.PAGE else subset @@ -259,88 +211,95 @@ def fill_data_quality_window(current_df: pd.DataFrame, metric_type: MetricType): similarity_expanders.append(st.expander("Similarities", expanded=True)) with cols.pop(0): - build_card(embedding_type, i, row, similarity_expanders, metric_type) + build_card(embedding_type, i, row, similarity_expanders, metric_scope, metric) -def populate_embedding_information(embedding_type: str): - if embedding_type == EmbeddingType.CLASSIFICATION.value: - if st.session_state[state.DATA_PAGE_METRIC].meta.get("title") == "Image-level Annotation Quality": - collections, question_hash_to_collection_indexes = embedding_utils.get_collections_and_metadata( - "cnn_classifications.pkl" +# @dataclass +# class EmbeddingInformation: +# embdedding_type: EmbeddingType +# collection_images: List[LabelEmbedding] +# question_hash_to_collection_indexes: Dict[str, Any] +# image_keys_having_similarity: Dict[str, Any] +# + + +def populate_embedding_information(embedding_type: EmbeddingType, meta: MetricMetadata): + embeddings_dir = get_state().project_paths.embeddings + + if embedding_type == EmbeddingType.CLASSIFICATION: + if meta["title"] == "Image-level Annotation Quality": + collections, question_hash_to_collection_indexes = get_collections_and_metadata( + "cnn_classifications.pkl", embeddings_dir ) - st.session_state[state.COLLECTIONS_IMAGES] = collections - st.session_state[state.QUESTION_HASH_TO_COLLECTION_INDEXES] = question_hash_to_collection_indexes - st.session_state[state.IMAGE_KEYS_HAVING_SIMILARITIES] = embedding_utils.get_image_keys_having_similarities( - collections + st.session_state[COLLECTIONS_IMAGES] = collections + st.session_state[QUESTION_HASH_TO_COLLECTION_INDEXES] = question_hash_to_collection_indexes + st.session_state[IMAGE_KEYS_HAVING_SIMILARITIES] = get_image_keys_having_similarities(collections) + st.session_state[FAISS_INDEX_IMAGE] = get_faiss_index_image( + collections, question_hash_to_collection_indexes ) - st.session_state[state.FAISS_INDEX_IMAGE] = embedding_utils.get_faiss_index_image(collections) - st.session_state[state.CURRENT_INDEX_HAS_ANNOTATION] = True + st.session_state[CURRENT_INDEX_HAS_ANNOTATION] = True - if state.IMAGE_SIMILARITIES not in st.session_state: - st.session_state[state.IMAGE_SIMILARITIES] = {} - for question_hash in st.session_state[state.QUESTION_HASH_TO_COLLECTION_INDEXES].keys(): - 
st.session_state[state.IMAGE_SIMILARITIES][question_hash] = {} + if IMAGE_SIMILARITIES not in st.session_state: + st.session_state[IMAGE_SIMILARITIES] = {} + for question_hash in st.session_state[QUESTION_HASH_TO_COLLECTION_INDEXES].keys(): + st.session_state[IMAGE_SIMILARITIES][question_hash] = {} else: - collections = embedding_utils.get_collections("cnn_classifications.pkl") - st.session_state[state.COLLECTIONS_IMAGES] = collections - st.session_state[state.IMAGE_KEYS_HAVING_SIMILARITIES] = embedding_utils.get_image_keys_having_similarities( - collections - ) - st.session_state[state.FAISS_INDEX_IMAGE_NO_LABEL] = embedding_utils.get_faiss_index_object( - collections, state.FAISS_INDEX_IMAGE_NO_LABEL - ) - st.session_state[state.CURRENT_INDEX_HAS_ANNOTATION] = False + collections = get_collections("cnn_classifications.pkl", embeddings_dir) + st.session_state[COLLECTIONS_IMAGES] = collections + st.session_state[IMAGE_KEYS_HAVING_SIMILARITIES] = get_image_keys_having_similarities(collections) + st.session_state[FAISS_INDEX_IMAGE_NO_LABEL] = get_faiss_index_object(collections) + st.session_state[CURRENT_INDEX_HAS_ANNOTATION] = False - if state.IMAGE_SIMILARITIES_NO_LABEL not in st.session_state: - st.session_state[state.IMAGE_SIMILARITIES_NO_LABEL] = {} + if IMAGE_SIMILARITIES_NO_LABEL not in st.session_state: + st.session_state[IMAGE_SIMILARITIES_NO_LABEL] = {} - elif embedding_type == EmbeddingType.OBJECT.value: - collections = embedding_utils.get_collections("cnn_objects.pkl") - st.session_state[state.COLLECTIONS_OBJECTS] = collections - st.session_state[state.OBJECT_KEYS_HAVING_SIMILARITIES] = embedding_utils.get_object_keys_having_similarities( - collections - ) - st.session_state[state.FAISS_INDEX_OBJECT] = embedding_utils.get_faiss_index_object( - collections, state.FAISS_INDEX_OBJECT - ) - st.session_state[state.CURRENT_INDEX_HAS_ANNOTATION] = True + elif embedding_type == EmbeddingType.OBJECT: + collections = get_collections("cnn_objects.pkl", embeddings_dir) + st.session_state[COLLECTIONS_OBJECTS] = collections + st.session_state[OBJECT_KEYS_HAVING_SIMILARITIES] = get_object_keys_having_similarities(collections) + st.session_state[FAISS_INDEX_OBJECT] = get_faiss_index_object(collections) + st.session_state[CURRENT_INDEX_HAS_ANNOTATION] = True - if state.OBJECT_SIMILARITIES not in st.session_state: - st.session_state[state.OBJECT_SIMILARITIES] = {} + if OBJECT_SIMILARITIES not in st.session_state: + st.session_state[OBJECT_SIMILARITIES] = {} def build_card( - card_type: str, card_no: int, row: Series, similarity_expanders: list[DeltaGenerator], metric_type: MetricType + embedding_type: EmbeddingType, + card_no: int, + row: Series, + similarity_expanders: list[DeltaGenerator], + metric_scope: MetricScope, + metric: MetricData, ): """ Builds each sub card (the content displayed for each row in a csv file). 
""" + data_dir = get_state().project_paths.data - if card_type == EmbeddingType.CLASSIFICATION.value: - + if embedding_type == EmbeddingType.CLASSIFICATION: button_name = "show similar images" - if st.session_state[state.DATA_PAGE_METRIC].meta.get("title") == "Image-level Annotation Quality": - image = load_or_fill_image(row) + if metric.meta["title"] == "Image-level Annotation Quality": + image = load_or_fill_image(row, data_dir) similarity_callback = show_similar_classification_images else: - if st.session_state[state.DATA_PAGE_METRIC].meta.get("annotation_type") is None: - image = load_or_fill_image(row) + if metric.meta["annotation_type"] is None: + image = load_or_fill_image(row, data_dir) else: - image = show_image_and_draw_polygons(row) + image = show_image_and_draw_polygons(row, data_dir) similarity_callback = show_similar_images - elif card_type == EmbeddingType.OBJECT.value: - image = show_image_and_draw_polygons(row) + elif embedding_type == EmbeddingType.OBJECT: + image = show_image_and_draw_polygons(row, data_dir) button_name = "show similar objects" similarity_callback = show_similar_object_images - else: - st.write(f"{card_type} card type is not defined in EmbeddingTypes") + st.write(f"{embedding_type.value} card type is not defined in EmbeddingTypes") return st.image(image) - multiselect_tag(row, "explorer", metric_type) + multiselect_tag(row, "explorer", metric_scope) - target_expander = similarity_expanders[card_no // st.session_state[state.MAIN_VIEW_COLUMN_NUM]] + target_expander = similarity_expanders[card_no // get_state().page_grid_settings.columns] st.button( str(button_name), @@ -351,121 +310,14 @@ def build_card( # === Write scores and link to editor === # tags_row = row.copy() - metric_name = st.session_state[state.DATA_PAGE_METRIC_NAME] + if "object_class" in tags_row and not pd.isna(tags_row["object_class"]): tags_row["label_class_name"] = tags_row["object_class"] tags_row.drop("object_class") - tags_row[metric_name] = tags_row["score"] - build_data_tags(tags_row, metric_name) + tags_row[metric.name] = tags_row["score"] + build_data_tags(tags_row, metric.name) if not pd.isnull(row["description"]): # Hacky way for now (with incorrect rounding) description = re.sub(r"(\d+\.\d{0,3})\d*", r"\1", row["description"]) st.write(f"Description: {description}") - - -def get_histogram(current_df: pd.DataFrame): - # TODO: Unify with app/model_quality/sub_pages/__init__.py:SamplesPage.get_histogram - metric_name = st.session_state[state.DATA_PAGE_METRIC_NAME] - if metric_name: - title_suffix = f" - {metric_name}" - else: - metric_name = "Score" # Used for plotting - - bar_chart = ( - alt.Chart(current_df, title=f"Data distribution{title_suffix}") - .mark_bar() - .encode( - alt.X("score:Q", bin=alt.Bin(maxbins=100), title=metric_name), - alt.Y("count()", title="Num. samples"), - tooltip=[ - alt.Tooltip("score:Q", title=metric_name, format=",.3f", bin=True), - alt.Tooltip("count():Q", title="Num. 
samples", format="d"), - ], - ) - .properties(height=200) - ) - return bar_chart - - -def show_similar_classification_images(row: Series, expander: DeltaGenerator): - feature_hash = row["identifier"].split("_")[-1] - - if row["identifier"] not in st.session_state[state.IMAGE_SIMILARITIES][feature_hash].keys(): - embedding_utils.add_labeled_image_neighbors_to_cache(row["identifier"], feature_hash) - - nearest_images = st.session_state[state.IMAGE_SIMILARITIES][feature_hash][row["identifier"]] - - division = 4 - column_id = 0 - - for nearest_image in nearest_images: - if column_id == 0: - st_columns = expander.columns(division) - - image = load_or_fill_image(nearest_image["key"]) - - st_columns[column_id].image(image) - st_columns[column_id].write(f"Annotated as `{nearest_image['name']}`") - column_id += 1 - column_id = column_id % division - - -def show_similar_images(row: Series, expander: DeltaGenerator): - image_identifier = "_".join(row["identifier"].split("_")[:3]) - - if image_identifier not in st.session_state[state.IMAGE_SIMILARITIES_NO_LABEL].keys(): - embedding_utils.add_image_neighbors_to_cache(image_identifier) - - nearest_images = st.session_state[state.IMAGE_SIMILARITIES_NO_LABEL][image_identifier] - - division = 4 - column_id = 0 - - for nearest_image in nearest_images: - if column_id == 0: - st_columns = expander.columns(division) - - image = load_or_fill_image(nearest_image["key"]) - - st_columns[column_id].image(image) - st_columns[column_id].write(f"Annotated as `{nearest_image['name']}`") - column_id += 1 - column_id = column_id % division - - -def show_similar_object_images(row: Series, expander: DeltaGenerator): - object_identifier = "_".join(row["identifier"].split("_")[:4]) - - if object_identifier not in st.session_state[state.OBJECT_KEYS_HAVING_SIMILARITIES]: - expander.write("Similarity search is not available for this object.") - return - - if object_identifier not in st.session_state[state.OBJECT_SIMILARITIES].keys(): - embedding_utils.add_object_neighbors_to_cache(object_identifier) - - nearest_images = st.session_state[state.OBJECT_SIMILARITIES][object_identifier] - - division = 4 - column_id = 0 - - for nearest_image in nearest_images: - if column_id == 0: - st_columns = expander.columns(division) - - image = show_image_and_draw_polygons(nearest_image["key"]) - - st_columns[column_id].image(image) - st_columns[column_id].write(f"Annotated as `{nearest_image['name']}`") - column_id += 1 - column_id = column_id % division - - -def get_unique_data_units_size(current_df: pd.DataFrame): - data_units = set() - identifiers = current_df["identifier"] - for identifier in identifiers: - key_components = identifier.split("_") - data_units.add(key_components[0] + "_" + key_components[1]) - - return len(data_units) diff --git a/src/encord_active/app/data_quality/sub_pages/summary.py b/src/encord_active/app/data_quality/sub_pages/summary.py index 5fe73676c..d71ac857f 100644 --- a/src/encord_active/app/data_quality/sub_pages/summary.py +++ b/src/encord_active/app/data_quality/sub_pages/summary.py @@ -1,18 +1,14 @@ -from typing import List - -import pandas as pd import streamlit as st -import encord_active.app.common.state as state -from encord_active.app.common import metric as iutils -from encord_active.app.common.components import build_data_tags -from encord_active.app.common.components.individual_tagging import multiselect_tag -from encord_active.app.common.components.tag_creator import tag_creator +from encord_active.app.common.components.metric_summary import 
render_metric_summary +from encord_active.app.common.components.tags.tag_creator import tag_creator from encord_active.app.common.page import Page -from encord_active.app.data_quality.common import ( - MetricType, +from encord_active.app.common.state import get_state +from encord_active.lib.dataset.outliers import get_iqr_outliers +from encord_active.lib.metrics.utils import ( + MetricScope, load_available_metrics, - show_image_and_draw_polygons, + load_metric_dataframe, ) @@ -23,7 +19,7 @@ def sidebar_options(self): self.row_col_settings_in_sidebar() tag_creator() - def build(self, metric_type_selection: MetricType): + def build(self, metric_scope: MetricScope): st.markdown(f"# {self.title}") st.subheader("Outliers by IQR range for every metric") @@ -55,93 +51,17 @@ def build(self, metric_type_selection: MetricType): "an outlier for one metric and a non-outlier for another." ) - n_cols = int(st.session_state[state.MAIN_VIEW_COLUMN_NUM]) - n_rows = int(st.session_state[state.MAIN_VIEW_ROW_NUM]) - n_items_in_page = n_cols * n_rows - - metrics = load_available_metrics(metric_type_selection.value) - - for idx in metrics: - df = iutils.load_metric(idx, normalize=False) + metrics = load_available_metrics(get_state().project_paths.metrics, metric_scope) - current_df = df.copy() - - if df.empty: + for metric in metrics: + original_df = load_metric_dataframe(metric, normalize=False) + res = get_iqr_outliers(original_df) + if not res: continue - moderate_iqr_scale = 1.5 - severe_iqr_scale = 2.5 - Q1 = current_df["score"].quantile(0.25) - Q3 = current_df["score"].quantile(0.75) - IQR = Q3 - Q1 - - current_df["dist_to_iqr"] = 0 - current_df.loc[current_df["score"] > Q3, "dist_to_iqr"] = (current_df["score"] - Q3).abs() - current_df.loc[current_df["score"] < Q1, "dist_to_iqr"] = (current_df["score"] - Q1).abs() - current_df.sort_values(by="dist_to_iqr", inplace=True, ascending=False) - - moderate_lb, moderate_ub = Q1 - moderate_iqr_scale * IQR, Q3 + moderate_iqr_scale * IQR - severe_lb, severe_ub = Q1 - severe_iqr_scale * IQR, Q3 + severe_iqr_scale * IQR - - n_moderate_outliers = ( - ((severe_lb <= current_df["score"]) & (current_df["score"] < moderate_lb)) - | ((severe_ub >= current_df["score"]) & (current_df["score"] > moderate_ub)) - ).sum() - - n_severe_outliers = ((current_df["score"] < severe_lb) | (current_df["score"] > severe_ub)).sum() - - with st.expander(label=f"{idx.name} Outliers - {n_severe_outliers} severe, {n_moderate_outliers} moderate"): - st.markdown(idx.meta["long_description"]) - - if n_severe_outliers + n_moderate_outliers == 0: - st.success("No outliers found!") - continue - - st.error(f"Number of severe outliers: {n_severe_outliers}/{len(current_df)}") - st.warning(f"Number of moderate outliers: {n_moderate_outliers}/{len(current_df)}") - value = st.slider( - "distance to IQR", - min_value=float(current_df["dist_to_iqr"].min()), - max_value=float(current_df["dist_to_iqr"].max()), - step=max( - 0.1, - float( - (current_df["dist_to_iqr"].max() - current_df["dist_to_iqr"].min()) - / (len(current_df) / n_items_in_page) - ), - ), - value=float(current_df["dist_to_iqr"].max()), - key=f"dist_to_iqr{idx.name}", - ) - - selected_df = current_df[current_df["dist_to_iqr"] <= value][:n_items_in_page] - - cols: List = [] - for i, (row_no, row) in enumerate(selected_df.iterrows()): - if not cols: - cols = list(st.columns(n_cols)) - - with cols.pop(0): - image = show_image_and_draw_polygons(row) - st.image(image) - - multiselect_tag(row, f"{idx.name}_summary", metric_type_selection) - - # === 
Write scores and link to editor === # - - tags_row = row.copy() - if row["score"] > severe_ub or row["score"] < severe_lb: - tags_row["outlier"] = "Severe" - elif row["score"] > moderate_ub or row["score"] < moderate_lb: - tags_row["outlier"] = "Moderate" - else: - tags_row["outlier"] = "Low" - - if "object_class" in tags_row and not pd.isna(tags_row["object_class"]): - tags_row["label_class_name"] = tags_row["object_class"] - tags_row.drop("object_class") - tags_row[idx.name] = tags_row["score"] - build_data_tags(tags_row, idx.name) + df, iqr_outliers = res - if not pd.isnull(row["description"]): - st.write(f"Description: {row['description']}. ") + with st.expander( + label=f"{metric.name} Outliers - {iqr_outliers.n_severe_outliers} severe, {iqr_outliers.n_moderate_outliers} moderate" + ): + render_metric_summary(metric, df, iqr_outliers, metric_scope) diff --git a/src/encord_active/app/db/connection.py b/src/encord_active/app/db/connection.py deleted file mode 100644 index cc179d136..000000000 --- a/src/encord_active/app/db/connection.py +++ /dev/null @@ -1,19 +0,0 @@ -import sqlite3 -from typing import Optional - -import streamlit as st - - -class DBConnection: - def __init__(self, file: Optional[str] = None): - if file: - self.file = file - else: - self.file = st.session_state.db_path - - def __enter__(self): - self.conn = sqlite3.connect(self.file) - return self.conn - - def __exit__(self, type, value, traceback): - self.conn.__exit__(type, value, traceback) diff --git a/src/encord_active/app/model_quality/components/__init__.py b/src/encord_active/app/model_quality/components/__init__.py deleted file mode 100644 index 1965cc738..000000000 --- a/src/encord_active/app/model_quality/components/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .false_negative_view import false_negative_view -from .metric_view import metric_view diff --git a/src/encord_active/app/model_quality/components/false_negative_view.py b/src/encord_active/app/model_quality/components/false_negative_view.py deleted file mode 100644 index e1c533362..000000000 --- a/src/encord_active/app/model_quality/components/false_negative_view.py +++ /dev/null @@ -1,97 +0,0 @@ -from typing import List - -import cv2 -import pandas as pd -import streamlit as st -from streamlit.delta_generator import DeltaGenerator - -import encord_active.app.model_quality.components.utils as cutils -from encord_active.app.common import state -from encord_active.app.common.colors import Color, hex_to_rgb -from encord_active.app.common.components import build_data_tags -from encord_active.app.common.components.bulk_tagging_form import ( - BulkLevel, - action_bulk_tags, - bulk_tagging_form, -) -from encord_active.app.common.components.individual_tagging import multiselect_tag -from encord_active.app.common.utils import ( - build_pagination, - get_df_subset, - load_or_fill_image, -) -from encord_active.app.data_quality.common import MetricType - - -def __show_image_and_fn( - label: pd.Series, - predictions: pd.DataFrame, - box_color: Color = Color.RED, - mask_opacity=0.5, -): - """ - :param label: The csv row of the false-negative label to display. - :param predictions: All the predictions on the same image with the samme predicted class. - :param box_color: The hex color to use when drawing the prediction. 
- """ - isClosed = True - thickness = 5 - - image = load_or_fill_image(label) - - # Draw predictions - for _, pred in predictions.iterrows(): - pred_color: str = st.session_state.full_class_idx[str(pred["class_id"])]["color"] - image = cutils.draw_mask(pred, image, mask_opacity=mask_opacity, color=pred_color) - - # Draw label - label_color = hex_to_rgb(box_color.value) - image = cv2.polylines(image, [cutils.get_bbox(label)], isClosed, label_color, thickness, lineType=cv2.LINE_8) - image = cutils.draw_mask(label, image, mask_opacity=mask_opacity, color=box_color) - - st.image(image) - - -def __build_card( - label: pd.Series, - predictions: pd.DataFrame, - st_col: DeltaGenerator, - box_color: Color = Color.RED, -): - with st_col: - __show_image_and_fn(label, predictions, box_color=box_color) - multiselect_tag(label, "false_negatives", MetricType.MODEL_QUALITY) - - cls = st.session_state.full_class_idx[str(label["class_id"])]["name"] - label = label.copy() - label["label_class_name"] = cls - # === Write scores and link to editor === # - build_data_tags(label, st.session_state[state.PREDICTIONS_LABEL_METRIC]) - - -def false_negative_view(false_negatives, model_predictions, color: Color): - if state.FALSE_NEGATIVE_VIEW_PAGE_NUMBER not in st.session_state: - st.session_state[state.FALSE_NEGATIVE_VIEW_PAGE_NUMBER] = 1 - - n_cols, n_rows = int(st.session_state[state.MAIN_VIEW_COLUMN_NUM]), int(st.session_state[state.MAIN_VIEW_ROW_NUM]) - selected_metric = st.session_state.get(state.PREDICTIONS_LABEL_METRIC) - subset = get_df_subset(false_negatives, selected_metric) - paginated_subset = build_pagination(subset, n_cols, n_rows, selected_metric) - - form = bulk_tagging_form(MetricType.MODEL_QUALITY) - - if form and form.submitted: - df = paginated_subset if form.level == BulkLevel.PAGE else subset - action_bulk_tags(df, form.tags, form.action) - - if len(paginated_subset) == 0: - st.error("No data in selected quality interval") - else: - # === Fill in the container === # - cols: List = [] - for i, label in paginated_subset.iterrows(): - if not cols: - cols = list(st.columns(n_cols)) - - frame_preds = model_predictions[model_predictions["img_id"] == label["img_id"]] - __build_card(label, frame_preds, cols.pop(0), box_color=color) diff --git a/src/encord_active/app/model_quality/components/index_view.py b/src/encord_active/app/model_quality/components/index_view.py deleted file mode 100644 index 41b699dc8..000000000 --- a/src/encord_active/app/model_quality/components/index_view.py +++ /dev/null @@ -1,79 +0,0 @@ -from typing import List - -import cv2 -import numpy as np -import pandas as pd -import streamlit as st -from streamlit.delta_generator import DeltaGenerator - -import encord_active.app.model_quality.components.utils as cutils -from encord_active.app.common import state -from encord_active.app.common.colors import Color, hex_to_rgb -from encord_active.app.common.components import build_data_tags -from encord_active.app.common.utils import ( - build_pagination, - get_df_subset, - get_geometries, - load_or_fill_image, -) - - -def __show_image_and_draw_polygons_plus_prediction(row: pd.Series, box_color: Color = Color.GREEN) -> np.ndarray: - isClosed = True - thickness = 5 - - # Label polygons - image = load_or_fill_image(row) - img_h, img_w = image.shape[:2] - for color, geometry in get_geometries(row, img_h=img_h, img_w=img_w, skip_object_hash=True): - image = cv2.polylines(image, [geometry], isClosed, hex_to_rgb(color), thickness) - - # Prediction polygon - box = cutils.get_bbox(row) - 
pred_color = hex_to_rgb(box_color.value) - image = cv2.polylines(image, [box], isClosed, pred_color, thickness // 2, lineType=cv2.LINE_8) - image = cutils.draw_mask(row, image, color=box_color) - - return image - - -def __build_card(row: pd.Series, st_col: DeltaGenerator, box_color: Color = Color.GREEN): - with st_col: - image = __show_image_and_draw_polygons_plus_prediction(row, box_color=box_color) - st.image(image) - - # === Write scores and link to editor === # - build_data_tags(row, st.session_state.predictions_metric) - - if row["fp_reason"] and not row["tps"] == 1.0: - st.write(f"Reason: {row['fp_reason']}") - - with st.expander("Details"): - tmp_df = pd.DataFrame(row) - if row["tps"] == 1.0: - tmp_df = tmp_df.drop("fp_reason", axis=0) - tmp_df = tmp_df.drop(["url", "x1", "y1", "x2", "y2", "rle", "Unnamed: 0"], axis=0) - st.dataframe(tmp_df) - - -def metric_view( - df: pd.DataFrame, - box_color: Color = Color.GREEN, -): - if state.METRIC_VIEW_PAGE_NUMBER not in st.session_state: - st.session_state[state.METRIC_VIEW_PAGE_NUMBER] = 1 - - n_cols, n_rows = int(st.session_state[state.MAIN_VIEW_COLUMN_NUM]), int(st.session_state[state.MAIN_VIEW_ROW_NUM]) - selected_metric = st.session_state.get(state.PREDICTIONS_METRIC, "") - subset = get_df_subset(df, selected_metric) - paginated_subset = build_pagination(subset, n_cols, n_rows, selected_metric) - - if len(paginated_subset) == 0: - st.error("No data in selected quality interval") - else: - # === Fill in the container === # - cols: List = [] - for i, row in paginated_subset.iterrows(): - if not cols: - cols = list(st.columns(n_cols)) - __build_card(row, cols.pop(0), box_color=box_color) diff --git a/src/encord_active/app/model_quality/components/metric_view.py b/src/encord_active/app/model_quality/components/metric_view.py deleted file mode 100644 index 3a1be599c..000000000 --- a/src/encord_active/app/model_quality/components/metric_view.py +++ /dev/null @@ -1,85 +0,0 @@ -from typing import List - -import cv2 -import numpy as np -import pandas as pd -import streamlit as st -from streamlit.delta_generator import DeltaGenerator - -import encord_active.app.model_quality.components.utils as cutils -from encord_active.app.common import state -from encord_active.app.common.colors import Color, hex_to_rgb -from encord_active.app.common.components import build_data_tags -from encord_active.app.common.components.bulk_tagging_form import ( - BulkLevel, - action_bulk_tags, - bulk_tagging_form, -) -from encord_active.app.common.components.individual_tagging import multiselect_tag -from encord_active.app.common.utils import ( - build_pagination, - get_df_subset, - get_geometries, - load_or_fill_image, -) -from encord_active.app.data_quality.common import MetricType - - -def __show_image_and_draw_polygons_plus_prediction(row: pd.Series, box_color: Color = Color.GREEN) -> np.ndarray: - isClosed = True - thickness = 5 - - # Label polygons - image = load_or_fill_image(row) - img_h, img_w = image.shape[:2] - for color, geometry in get_geometries(row, img_h=img_h, img_w=img_w, skip_object_hash=True): - image = cv2.polylines(image, [geometry], isClosed, hex_to_rgb(color), thickness) - - # Prediction polygon - box = cutils.get_bbox(row) - pred_color = hex_to_rgb(box_color.value) - image = cv2.polylines(image, [box], isClosed, pred_color, thickness // 2, lineType=cv2.LINE_8) - image = cutils.draw_mask(row, image, color=box_color) - - return image - - -def __build_card(row: pd.Series, st_col: DeltaGenerator, box_color: Color = Color.GREEN): - with st_col: - 
image = __show_image_and_draw_polygons_plus_prediction(row, box_color=box_color) - st.image(image) - multiselect_tag(row, "metric_view", MetricType.MODEL_QUALITY) - - # === Write scores and link to editor === # - build_data_tags(row, st.session_state.predictions_metric) - - if row["fp_reason"] and not row["tps"] == 1.0: - st.write(f"Reason: {row['fp_reason']}") - - -def metric_view( - df: pd.DataFrame, - box_color: Color = Color.GREEN, -): - if state.METRIC_VIEW_PAGE_NUMBER not in st.session_state: - st.session_state[state.METRIC_VIEW_PAGE_NUMBER] = 1 - - n_cols, n_rows = int(st.session_state[state.MAIN_VIEW_COLUMN_NUM]), int(st.session_state[state.MAIN_VIEW_ROW_NUM]) - selected_metric = st.session_state.get(state.PREDICTIONS_METRIC, "") - subset = get_df_subset(df, selected_metric) - paginated_subset = build_pagination(subset, n_cols, n_rows, selected_metric) - - form = bulk_tagging_form(MetricType.MODEL_QUALITY) - if form and form.submitted: - df = paginated_subset if form.level == BulkLevel.PAGE else subset - action_bulk_tags(df, form.tags, form.action) - - if len(paginated_subset) == 0: - st.error("No data in selected quality interval") - else: - # === Fill in the container === # - cols: List = [] - for i, row in paginated_subset.iterrows(): - if not cols: - cols = list(st.columns(n_cols)) - __build_card(row, cols.pop(0), box_color=box_color) diff --git a/src/encord_active/app/model_quality/components/utils.py b/src/encord_active/app/model_quality/components/utils.py deleted file mode 100644 index 201e34a4c..000000000 --- a/src/encord_active/app/model_quality/components/utils.py +++ /dev/null @@ -1,35 +0,0 @@ -from typing import Tuple, Union - -import cv2 -import numpy as np -import pandas as pd - -from encord_active.app.common.colors import Color, hex_to_rgb -from encord_active.lib.common.utils import rle_to_binary_mask - - -def get_bbox(row: pd.Series): - x1, y1, x2, y2 = row["x1"], row["y1"], row["x2"], row["y2"] - return np.array([[x1, y1], [x2, y1], [x2, y2], [x1, y2]]).reshape((-1, 1, 2)).astype(int) - - -def draw_mask(row: pd.Series, image: np.ndarray, mask_opacity: float = 0.5, color: Union[Color, str] = Color.PURPLE): - isClosed = True - thickness = 2 - - hex_color = color.value if isinstance(color, Color) else color - _color: Tuple[int, ...] = hex_to_rgb(hex_color) - _color_outline: Tuple[int, ...] 
= hex_to_rgb(hex_color, lighten=-0.5) - if isinstance(row["rle"], str): - mask = rle_to_binary_mask(eval(row["rle"])) - - # Draw contour line - contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0] - image = cv2.polylines(image, contours, isClosed, _color_outline, thickness, lineType=cv2.LINE_8) - - # Fill polygon with opacity - patch = np.zeros_like(image) - mask_select = mask == 1 - patch[mask_select] = _color - image[mask_select] = cv2.addWeighted(image, (1 - mask_opacity), patch, mask_opacity, 0)[mask_select] - return image diff --git a/src/encord_active/app/model_quality/data.py b/src/encord_active/app/model_quality/data.py deleted file mode 100644 index 2bd12f578..000000000 --- a/src/encord_active/app/model_quality/data.py +++ /dev/null @@ -1,218 +0,0 @@ -import json -from pathlib import Path -from typing import Dict, List, Optional, Tuple, cast - -import pandas as pd -import streamlit as st -from natsort import natsorted - -from encord_active.app.common.utils import load_json - - -def check_model_prediction_availability(): - predictions_path = st.session_state.predictions_dir / "predictions.csv" - return predictions_path.is_file() - - -def merge_objects_and_scores( - object_df: pd.DataFrame, metric_pth: Optional[Path] = None, ignore_object_scores=True -) -> Tuple[pd.DataFrame, List[str]]: - metric_names: List[str] = [] - object_df["identifier_no_oh"] = object_df["identifier"].str.replace(r"^(\S{73}_\d+)(.*)", r"\1", regex=True) - - if metric_pth is not None: - # Import prediction scores - for metric in metric_pth.iterdir(): - if not metric.suffix == ".csv": - continue - - meta_pth = metric.with_suffix(".meta.json") - if not meta_pth.is_file(): - continue - - with meta_pth.open("r", encoding="utf-8") as f: - meta = json.load(f) - - metric_scores = pd.read_csv(metric, index_col="identifier") - # Ignore empty data frames. - if metric_scores.shape[0] == 0: - continue - - title = f"{meta['title']} (P)" - metric_names.append(title) - - has_object_level_keys = len(metric_scores.index[0].split("_")) > 3 - metric_column = "identifier" if has_object_level_keys else "identifier_no_oh" - # Join data and rename column to metric name. - object_df = object_df.join(metric_scores["score"], on=[metric_column]) - object_df[title] = object_df["score"] - object_df.drop("score", axis=1, inplace=True) - - # Import frame level scores - for metric_file in st.session_state.metric_dir.iterdir(): - if metric_file.is_dir() or metric_file.suffix != ".csv": - continue - - # Read first row to see if metric has frame level scores - with metric_file.open("r", encoding="utf-8") as f: - f.readline() # header - key, *_ = f.readline().split(",") - - if not key: # Empty metric - continue - - label_hash, du_hash, frame, *rest = key.split("_") - type_indicator = "F" # Frame level - join_column = "identifier_no_oh" - if rest and ignore_object_scores: - # There are object hashes included in the key, so ignore. - continue - elif rest: - type_indicator = "O" - join_column = "identifier" - - meta_pth = metric_file.with_suffix(".meta.json") - if not meta_pth.is_file(): - continue - - with meta_pth.open("r", encoding="utf-8") as f: - meta = json.load(f) - - metric_scores = pd.read_csv(metric_file, index_col="identifier") - - title = f"{meta['title']} ({type_indicator})" - metric_names.append(title) - - # Join data and rename column to metric name. 
- object_df = object_df.join(metric_scores["score"], on=join_column) - object_df[title] = object_df["score"] - object_df.drop("score", axis=1, inplace=True) - object_df.drop("identifier_no_oh", axis=1, inplace=True) - metric_names = cast(List[str], natsorted(metric_names, key=lambda x: x[-3:] + x[:-3])) - return object_df, metric_names - - -@st.cache(allow_output_mutation=True) -def get_model_predictions() -> Optional[Tuple[pd.DataFrame, List[str]]]: - """ - Loads predictions and their associated metric scores. - :param predictions_path: - :return: - - predictions: The predictions with their metric scores - - column names: Unit test column names. - """ - predictions_path = st.session_state.predictions_dir / "predictions.csv" - if not predictions_path.is_file(): - st.error(f"Labels file `{st.session_state.predictions_dir / 'predictions.csv'} is missing.") - return None - - predictions_df = pd.read_csv(predictions_path) - - # Extract label_hash, du_hash, frame - identifiers = predictions_df["identifier"].str.split("_", expand=True) - identifiers.columns = ["label_hash", "du_hash", "frame", "object_hash"][: len(identifiers.columns)] - identifiers["frame"] = pd.to_numeric(identifiers["frame"]) - predictions_df = pd.concat([predictions_df, identifiers], axis=1) - - # Load predictions scores (metrics) - pred_idx_pth = st.session_state.predictions_dir / "metrics" - if not pred_idx_pth.exists(): - return predictions_df, [] - - predictions_df, metric_names = merge_objects_and_scores(predictions_df, metric_pth=pred_idx_pth) - - return predictions_df, metric_names - - -@st.cache(allow_output_mutation=True) -def get_labels() -> Optional[Tuple[pd.DataFrame, List[str]]]: - labels_path = st.session_state.predictions_dir / "labels.csv" - if not labels_path.is_file(): - st.error(f"Labels file `{st.session_state.predictions_dir / 'labels.csv'} is missing") - return None - - labels_df = pd.read_csv(labels_path) - - # Extract label_hash, du_hash, frame - identifiers = labels_df["identifier"].str.split("_", expand=True) - identifiers = identifiers.iloc[:, :3] - identifiers.columns = ["label_hash", "du_hash", "frame"] - identifiers["frame"] = pd.to_numeric(identifiers["frame"]) - - labels_df = pd.concat([labels_df, identifiers], axis=1) - labels_df, label_metric_names = merge_objects_and_scores(labels_df, ignore_object_scores=False) - return labels_df, label_metric_names - - -@st.cache(allow_output_mutation=True) -def get_gt_matched() -> Optional[dict]: - gt_path = st.session_state.predictions_dir / "ground_truths_matched.json" - return load_json(gt_path) - - -@st.cache() -def get_class_idx() -> Optional[dict]: - class_idx_pth = st.session_state.predictions_dir / "class_idx.json" - return load_json(class_idx_pth) - - -@st.cache() -def get_metadata_files() -> Dict[str, dict]: - out: Dict[str, dict] = {} - # Look in data metrics - data_metrics = out.setdefault("data", {}) - if st.session_state.metric_dir.is_dir(): - for f in st.session_state.metric_dir.iterdir(): - if not f.name.endswith(".meta.json"): - continue - - meta = load_json(f) - if meta is None: - continue - - data_metrics[meta["title"]] = meta - - prediction_metrics = out.setdefault("prediction", {}) - if (st.session_state.predictions_dir / "metrics").is_dir(): - for f in st.session_state.metric_dir.iterdir(): - if not f.name.endswith(".meta.json"): - continue - - meta = load_json(f) - if meta is None: - continue - - prediction_metrics[meta["title"]] = meta - - return out - - -def filter_labels_for_frames_wo_predictions(): - """ - Note: data_root is 
not used in the code, but utilized by `st` to determine what - to cache, so please don't remove. - """ - _predictions = st.session_state.model_predictions - pred_keys = _predictions["img_id"].unique() - _labels = st.session_state.sorted_labels - return _labels[_labels["img_id"].isin(pred_keys)] - - -def prediction_and_label_filtering(labels, metrics, model_pred, precisions): - # Filtering based on selection - # In the following a "_" prefix means the the data has been filtered according to selected classes. - # Predictions - class_idx = st.session_state.selected_class_idx - row_selection = model_pred["class_id"].isin(set(map(int, class_idx.keys()))) - _model_pred = model_pred[row_selection].copy() - # Labels - row_selection = labels["class_id"].isin(set(map(int, class_idx.keys()))) - _labels = labels[row_selection] - - chosen_name_set = set(map(lambda x: x["name"], class_idx.values())).union({"Mean"}) - _metrics = metrics[metrics["class_name"].isin(chosen_name_set)] - _precisions = precisions[precisions["class_name"].isin(chosen_name_set)] - name_map = {int(k): v["name"] for k, v in class_idx.items()} - _model_pred["class_name"] = _model_pred["class_id"].map(name_map) - _labels["class_name"] = _labels["class_id"].map(name_map) - return _labels, _metrics, _model_pred, _precisions diff --git a/src/encord_active/app/model_quality/settings.py b/src/encord_active/app/model_quality/settings.py index 97f9ab6d3..ef62775e0 100644 --- a/src/encord_active/app/model_quality/settings.py +++ b/src/encord_active/app/model_quality/settings.py @@ -2,29 +2,33 @@ import streamlit as st -import encord_active.app.common.components as cst -import encord_active.app.common.state as state -from encord_active.app.common.components.tag_creator import tag_creator +# import encord_active.app.common.state as state +from encord_active.app.common.components.tags.tag_creator import tag_creator +from encord_active.app.common.state import get_state +from encord_active.lib.model_predictions.reader import get_class_idx def common_settings(): tag_creator() - class_idx = st.session_state.full_class_idx + if not get_state().predictions.all_classes: + get_state().predictions.all_classes = get_class_idx(get_state().project_paths.predictions) + + all_classes = get_state().predictions.all_classes col1, col2, col3 = st.columns([4, 4, 3]) with col1: - selected_classes = cst.multiselect_with_all_option( - "Select classes to include", - list(map(lambda x: x["name"], class_idx.values())), - key=state.CLASS_SELECTION, - help="With this selection, you can choose which classes to include in the main page.", + selected_classes = st.multiselect( + "Filter by class", + list(all_classes.items()), + format_func=lambda x: x[1]["name"], + help=""" + With this selection, you can choose which classes to include in the main page.\n + This acts as a filter, i.e. when nothing is selected all classes are included. + """, ) - if "All" in selected_classes: - st.session_state.selected_class_idx = deepcopy(class_idx) - else: - st.session_state.selected_class_idx = {k: v for k, v in class_idx.items() if v["name"] in selected_classes} + get_state().predictions.selected_classes = dict(selected_classes) or deepcopy(all_classes) with col2: # IOU @@ -33,20 +37,19 @@ def common_settings(): min_value=0, max_value=100, value=50, - key=state.IOU_THRESHOLD_, help="The mean average precision (mAP) score is based on true positives and false positives. 
" "The IOU threshold determines how closely predictions need to match labels to be considered " "as true positives.", ) - st.session_state[state.IOU_THRESHOLD] = iou_threshold / 100 + get_state().iou_threshold = iou_threshold / 100 with col3: st.write("") st.write("") # Ignore unmatched frames - st.checkbox( + get_state().ignore_frames_without_predictions = st.checkbox( "Ignore frames without predictions", - key=state.IGNORE_FRAMES_WO_PREDICTIONS, + value=get_state().ignore_frames_without_predictions, help="Scores like mAP and mAR are effected negatively if there are frames in the dataset for which there " "exist no predictions. With this flag, you can ignore those.", ) diff --git a/src/encord_active/app/model_quality/sub_pages/__init__.py b/src/encord_active/app/model_quality/sub_pages/__init__.py index 5e2f562be..abc93dee1 100644 --- a/src/encord_active/app/model_quality/sub_pages/__init__.py +++ b/src/encord_active/app/model_quality/sub_pages/__init__.py @@ -1,11 +1,20 @@ from abc import abstractmethod +from typing import List, Optional -import altair as alt -import pandas as pd import streamlit as st +from pandera.typing import DataFrame import encord_active.app.common.state as state from encord_active.app.common.page import Page +from encord_active.lib.metrics.utils import MetricData +from encord_active.lib.model_predictions.map_mar import ( + PerformanceMetricSchema, + PrecisionRecallSchema, +) +from encord_active.lib.model_predictions.reader import ( + LabelMatchSchema, + PredictionMatchSchema, +) class ModelQualityPage(Page): @@ -24,19 +33,19 @@ def sidebar_options(self): @abstractmethod def build( self, - model_predictions: pd.DataFrame, - labels: pd.DataFrame, - metrics: pd.DataFrame, - precisions: pd.DataFrame, + model_predictions: DataFrame[PredictionMatchSchema], + labels: DataFrame[LabelMatchSchema], + metrics: DataFrame[PerformanceMetricSchema], + precisions: DataFrame[PrecisionRecallSchema], ): pass def __call__( self, - model_predictions: pd.DataFrame, - labels: pd.DataFrame, - metrics: pd.DataFrame, - precisions: pd.DataFrame, + model_predictions: DataFrame[PredictionMatchSchema], + labels: DataFrame[LabelMatchSchema], + metrics: DataFrame[PerformanceMetricSchema], + precisions: DataFrame[PrecisionRecallSchema], ): return self.build(model_predictions, labels, metrics, precisions) @@ -52,42 +61,26 @@ def prediction_metric_in_sidebar(): `st.session_state.model_predictions` data frame. """ fixed_options = {"confidence": "Model Confidence", "iou": "IOU"} - st.selectbox( + column_names = list(state.get_state().predictions.metric_datas.predictions.keys()) + state.get_state().predictions.metric_datas.selected_predicion = st.selectbox( "Select metric for your predictions", - st.session_state.prediction_metric_names + list(fixed_options.keys()), - key=state.PREDICTIONS_METRIC, + column_names + list(fixed_options.keys()), format_func=lambda s: fixed_options.get(s, s), help="The data in the main view will be sorted by the selected metric. 
" "(F) := frame scores, (P) := prediction scores.", ) @staticmethod - def metric_details_description(metric_name: str = ""): + def metric_details_description(): + metric_name = state.get_state().predictions.metric_datas.selected_predicion if not metric_name: - metric_name = st.session_state[state.PREDICTIONS_METRIC] - metric_meta = st.session_state.metric_meta["prediction"].get(metric_name[:-4], {}) # Remove " (P)" - if not metric_meta: - metric_meta = st.session_state.metric_meta["data"].get(metric_name[:-4], {}) # Remove " (P)" - if metric_meta: - st.markdown(f"### The {metric_meta['title']} metric") - st.markdown(metric_meta["long_description"]) + return + metric_data = state.get_state().predictions.metric_datas.predictions.get(metric_name) -class HistogramMixin: - @staticmethod - def get_histogram(data_frame: pd.DataFrame, metric_column: str): - title_suffix = f" - {metric_column}" - bar_chart = ( - alt.Chart(data_frame, title=f"Data distribution{title_suffix}") - .mark_bar() - .encode( - alt.X(f"{metric_column}:Q", bin=alt.Bin(maxbins=100), title=metric_column), - alt.Y("count()", title="Num. samples"), - tooltip=[ - alt.Tooltip(f"{metric_column}:Q", title=metric_column, format=",.3f", bin=True), - alt.Tooltip("count():Q", title="Num. samples", format="d"), - ], - ) - .properties(height=200) - ) - return bar_chart + if not metric_data: + metric_data = state.get_state().predictions.metric_datas.labels.get(metric_name) + + if metric_data: + st.markdown(f"### The {metric_data.name[:-4]} metric") + st.markdown(metric_data.meta["long_description"]) diff --git a/src/encord_active/app/model_quality/sub_pages/false_negatives.py b/src/encord_active/app/model_quality/sub_pages/false_negatives.py index 20212935f..67480ec7e 100644 --- a/src/encord_active/app/model_quality/sub_pages/false_negatives.py +++ b/src/encord_active/app/model_quality/sub_pages/false_negatives.py @@ -1,21 +1,30 @@ -import pandas as pd import streamlit as st +from pandera.typing import DataFrame -import encord_active.app.common.state as state -from encord_active.app.common.colors import Color -from encord_active.app.model_quality.components import false_negative_view +from encord_active.app.common.components.prediction_grid import prediction_grid +from encord_active.app.common.state import get_state +from encord_active.lib.charts.histogram import get_histogram +from encord_active.lib.common.colors import Color +from encord_active.lib.model_predictions.map_mar import ( + PerformanceMetricSchema, + PrecisionRecallSchema, +) +from encord_active.lib.model_predictions.reader import ( + LabelMatchSchema, + PredictionMatchSchema, +) -from . import HistogramMixin, ModelQualityPage +from . import ModelQualityPage -class FalseNegativesPage(ModelQualityPage, HistogramMixin): +class FalseNegativesPage(ModelQualityPage): title = "πŸ” False Negatives" def sidebar_options(self): - st.selectbox( + metric_columns = list(get_state().predictions.metric_datas.labels.keys()) + get_state().predictions.metric_datas.selected_label = st.selectbox( "Select metric for your labels", - st.session_state.label_metric_names, - key=state.PREDICTIONS_LABEL_METRIC, + metric_columns, help="The data in the main view will be sorted by the selected metric. 
" "(F) := frame scores, (O) := object scores.", ) @@ -23,14 +32,18 @@ def sidebar_options(self): def build( self, - model_predictions: pd.DataFrame, - labels: pd.DataFrame, - metrics: pd.DataFrame, - precisions: pd.DataFrame, + model_predictions: DataFrame[PredictionMatchSchema], + labels: DataFrame[LabelMatchSchema], + metrics: DataFrame[PerformanceMetricSchema], + precisions: DataFrame[PrecisionRecallSchema], ): st.markdown(f"# {self.title}") st.header("False Negatives") - metric_name = st.session_state[state.PREDICTIONS_LABEL_METRIC] + metric_name = get_state().predictions.metric_datas.selected_label + if not metric_name: + st.error("Prediction label not selected") + return + with st.expander("Details"): color = Color.PURPLE st.markdown( @@ -45,11 +58,13 @@ def build( """, unsafe_allow_html=True, ) - self.metric_details_description(metric_name) - fns_df = labels[labels["fns"]].dropna(subset=[metric_name]) + self.metric_details_description() + fns_df = labels[labels[LabelMatchSchema.is_false_negative]].dropna(subset=[metric_name]) if fns_df.shape[0] == 0: st.write("No false negatives") else: - histogram = self.get_histogram(fns_df, metric_name) + histogram = get_histogram(fns_df, metric_name) st.altair_chart(histogram, use_container_width=True) - false_negative_view(fns_df, model_predictions, color=color) + prediction_grid( + get_state().project_paths.data, labels=fns_df, model_predictions=model_predictions, box_color=color + ) diff --git a/src/encord_active/app/model_quality/sub_pages/false_positives.py b/src/encord_active/app/model_quality/sub_pages/false_positives.py index 8d5871604..1f605832a 100644 --- a/src/encord_active/app/model_quality/sub_pages/false_positives.py +++ b/src/encord_active/app/model_quality/sub_pages/false_positives.py @@ -1,13 +1,23 @@ -import pandas as pd import streamlit as st +from pandera.typing import DataFrame -from encord_active.app.common.colors import Color -from encord_active.app.model_quality.components import metric_view +from encord_active.app.common.components.prediction_grid import prediction_grid +from encord_active.app.common.state import get_state +from encord_active.lib.charts.histogram import get_histogram +from encord_active.lib.common.colors import Color +from encord_active.lib.model_predictions.map_mar import ( + PerformanceMetricSchema, + PrecisionRecallSchema, +) +from encord_active.lib.model_predictions.reader import ( + LabelMatchSchema, + PredictionMatchSchema, +) -from . import HistogramMixin, ModelQualityPage +from . 
import ModelQualityPage -class FalsePositivesPage(ModelQualityPage, HistogramMixin): +class FalsePositivesPage(ModelQualityPage): title = "🌑 False Positives" def sidebar_options(self): @@ -16,11 +26,16 @@ def sidebar_options(self): def build( self, - model_predictions: pd.DataFrame, - labels: pd.DataFrame, - metrics: pd.DataFrame, - precisions: pd.DataFrame, + model_predictions: DataFrame[PredictionMatchSchema], + labels: DataFrame[LabelMatchSchema], + metrics: DataFrame[PerformanceMetricSchema], + precisions: DataFrame[PrecisionRecallSchema], ): + metric_name = get_state().predictions.metric_datas.selected_predicion + if not metric_name: + st.error("No prediction metric selected") + return + st.markdown(f"# {self.title}") color = Color.RED with st.expander("Details"): @@ -41,11 +56,12 @@ def build( ) self.metric_details_description() - metric_name = st.session_state.predictions_metric - fp_df = model_predictions[model_predictions["tps"] == 0.0].dropna(subset=[metric_name]) + fp_df = model_predictions[model_predictions[PredictionMatchSchema.is_true_positive] == 0.0].dropna( + subset=[metric_name] + ) if fp_df.shape[0] == 0: st.write("No false positives") else: - histogram = self.get_histogram(fp_df, metric_name) + histogram = get_histogram(fp_df, metric_name) st.altair_chart(histogram, use_container_width=True) - metric_view(fp_df, box_color=color) + prediction_grid(get_state().project_paths.data, model_predictions=fp_df, box_color=color) diff --git a/src/encord_active/app/model_quality/sub_pages/metrics.py b/src/encord_active/app/model_quality/sub_pages/metrics.py index 44047d069..1c82b2b04 100644 --- a/src/encord_active/app/model_quality/sub_pages/metrics.py +++ b/src/encord_active/app/model_quality/sub_pages/metrics.py @@ -1,153 +1,40 @@ -import altair as alt -import pandas as pd import streamlit as st -from sklearn.feature_selection import mutual_info_regression +from pandera.typing import DataFrame + +from encord_active.app.common.state import get_state +from encord_active.lib.charts.metric_importance import create_metric_importance_charts +from encord_active.lib.charts.precision_recall import create_pr_charts +from encord_active.lib.model_predictions.map_mar import ( + PerformanceMetricSchema, + PrecisionRecallSchema, +) +from encord_active.lib.model_predictions.reader import ( + LabelMatchSchema, + PredictionMatchSchema, +) from . 
import ModelQualityPage +_M_COLS = PerformanceMetricSchema + class MetricsPage(ModelQualityPage): title = "πŸ“ˆ Metrics" - @staticmethod - def create_pr_charts(metrics, precisions): - _metrics = metrics[~metrics["metric"].isin({"mAR", "mAP"})].copy() - - tmp = "m" + _metrics["metric"].str.split("_", n=1, expand=True) - tmp.columns = ["group", "_"] - _metrics["group"] = tmp["group"] - _metrics["average"] = "average" # Legend title - - class_selection = alt.selection_multi(fields=["class_name"]) - - class_bars = ( - alt.Chart(_metrics, title="Mean scores") - .mark_bar() - .encode( - alt.X("value", title="", scale=alt.Scale(domain=[0.0, 1.0])), - alt.Y("metric", title=""), - alt.Color("class_name"), - tooltip=[alt.Tooltip("metric", title="Metric"), alt.Tooltip("value", title="Value", format=",.3f")], - opacity=alt.condition(class_selection, alt.value(1), alt.value(0.1)), - ) - .properties(height=300) - ) - # Average - mean_bars = class_bars.encode( - alt.X("mean(value):Q", title="", scale=alt.Scale(domain=[0.0, 1.0])), - alt.Y("group:N", title=""), - alt.Color("average:N"), - tooltip=[ - alt.Tooltip("group:N", title="Metric"), - alt.Tooltip("mean(value):Q", title="Value", format=",.3f"), - ], - ) - bar_chart = (class_bars + mean_bars).add_selection(class_selection) - - class_precisions = ( - alt.Chart(precisions, title="Precision-Recall Curve") - .mark_line(point=True) - .encode( - alt.X("rc_threshold", title="Recall", scale=alt.Scale(domain=[0.0, 1.0])), - alt.Y("precision", scale=alt.Scale(domain=[0.0, 1.0])), - alt.Color("class_name"), - tooltip=[ - alt.Tooltip("class_name", title="Class"), - alt.Tooltip("rc_threshold", title="Recall"), - alt.Tooltip("precision", title="Precision", format=",.3f"), - ], - opacity=alt.condition(class_selection, alt.value(1.0), alt.value(0.2)), - ) - .properties(height=300) - ) - mean_precisions = ( - class_precisions.transform_calculate(average="'average'") - .mark_line(point=True) - .encode( - alt.X("rc_threshold"), - alt.Y("average(precision):Q"), - alt.Color("average:N"), - tooltip=[ - alt.Tooltip("average:N", title="Aggregate"), - alt.Tooltip("rc_threshold", title="Recall"), - alt.Tooltip("average(precision)", title="Avg. precision", format=",.3f"), - ], - ) - ) - precision_chart = (class_precisions + mean_precisions).add_selection(class_selection) - return bar_chart | precision_chart - - @staticmethod - def create_metric_importance_charts(model_predictions, num_samples): - if num_samples < model_predictions.shape[0]: - _predictions = model_predictions.sample(num_samples, axis=0, random_state=42) - else: - _predictions = model_predictions - - num_tps = (_predictions["tps"].iloc[:num_samples] != 0).sum() - if num_tps < 50: - st.warning( - f"Not enough true positives ({num_tps}) to calculate reliable metric importance. " - "Try increasing the number of samples or lower the IoU threshold in the side bar." 
- ) - - scores = _predictions["iou"] * _predictions["tps"] - metrics = _predictions[st.session_state.prediction_metric_names] - - correlations = metrics.fillna(0).corrwith(scores, axis=0).to_frame("correlation") - correlations["index"] = correlations.index.T - - mi = pd.DataFrame.from_dict( - {"index": metrics.columns, "importance": mutual_info_regression(metrics.fillna(0), scores)} - ) - sorted_metrics = mi.sort_values("importance", ascending=False, inplace=False)["index"].to_list() - - mutual_info_bars = ( - alt.Chart(mi, title="Metric Importance") - .mark_bar() - .encode( - alt.X("importance", title="Importance", scale=alt.Scale(domain=[0.0, 1.0])), - alt.Y("index", title="Metric", sort=sorted_metrics), - alt.Color("importance", scale=alt.Scale(scheme="blues")), - tooltip=[ - alt.Tooltip("index", title="Metric"), - alt.Tooltip("importance", title="Importance", format=",.3f"), - ], - ) - .properties(height=400) - ) - - correlation_bars = ( - alt.Chart(correlations, title="Metric Correlations") - .mark_bar() - .encode( - alt.X("correlation", title="Correlation", scale=alt.Scale(domain=[-1.0, 1.0])), - alt.Y("index", title="Metric", sort=sorted_metrics), - alt.Color("correlation", scale=alt.Scale(scheme="redyellowgreen", align=0.0)), - tooltip=[ - alt.Tooltip("index", title="Metric"), - alt.Tooltip("correlation", title="Correlation", format=",.3f"), - ], - ) - .properties(height=400) - ) - - return alt.hconcat(mutual_info_bars, correlation_bars).resolve_scale(color="independent") - def sidebar_options(self): pass def build( self, - model_predictions: pd.DataFrame, - labels: pd.DataFrame, - metrics: pd.DataFrame, - precisions: pd.DataFrame, + model_predictions: DataFrame[PredictionMatchSchema], + labels: DataFrame[LabelMatchSchema], + metrics: DataFrame[PerformanceMetricSchema], + precisions: DataFrame[PrecisionRecallSchema], ): st.markdown(f"# {self.title}") - _map = metrics[metrics["metric"] == "mAP"]["value"].item() - _mar = metrics[metrics["metric"] == "mAR"]["value"].item() + _map = metrics[metrics[_M_COLS.metric] == "mAP"]["value"].item() + _mar = metrics[metrics[_M_COLS.metric] == "mAR"]["value"].item() col1, col2 = st.columns(2) col1.metric("mAP", f"{_map:.3f}") col2.metric("mAR", f"{_mar:.3f}") @@ -199,14 +86,16 @@ def build( else: num_samples = model_predictions.shape[0] + metric_columns = list(get_state().predictions.metric_datas.predictions.keys()) with st.spinner("Computing index importance..."): - chart = self.create_metric_importance_charts( + chart = create_metric_importance_charts( model_predictions, + metric_columns=metric_columns, num_samples=num_samples, ) st.altair_chart(chart, use_container_width=True) st.subheader("Subset selection scores") with st.container(): - chart = self.create_pr_charts(metrics, precisions) + chart = create_pr_charts(metrics, precisions) st.altair_chart(chart, use_container_width=True) diff --git a/src/encord_active/app/model_quality/sub_pages/performance_by_metric.py b/src/encord_active/app/model_quality/sub_pages/performance_by_metric.py index 3b9013dc3..0999cc278 100644 --- a/src/encord_active/app/model_quality/sub_pages/performance_by_metric.py +++ b/src/encord_active/app/model_quality/sub_pages/performance_by_metric.py @@ -1,13 +1,21 @@ import re -from enum import Enum -from typing import Optional import altair as alt -import altair.vegalite.v4.api as alt_api -import pandas as pd import streamlit as st +from pandera.typing import DataFrame import encord_active.app.common.state as state +from encord_active.app.common.state import get_state 
+from encord_active.lib.charts.performance_by_metric import performance_rate_by_metric +from encord_active.lib.charts.scopes import PredictionMatchScope +from encord_active.lib.model_predictions.map_mar import ( + PerformanceMetricSchema, + PrecisionRecallSchema, +) +from encord_active.lib.model_predictions.reader import ( + LabelMatchSchema, + PredictionMatchSchema, +) from . import ModelQualityPage @@ -16,136 +24,36 @@ COUNT_FMT = ",d" -class ChartSubject(Enum): - TPR = "true positive rate" - FNR = "false negative rate" +CHART_TITLES = { + PredictionMatchScope.TRUE_POSITIVES: "True Positive Rate", + PredictionMatchScope.FALSE_POSITIVES: "False Positive Rate", + PredictionMatchScope.FALSE_NEGATIVES: "False Negative Rate", +} class PerformanceMetric(ModelQualityPage): title = "⚑️ Performance by Metric" - def build_bar_chart( - self, - sorted_predictions: pd.DataFrame, - metric_name: str, - show_decomposition: bool, - title: str, - subject: ChartSubject, - ) -> alt_api.Chart: - str_type = "predictions" if subject == ChartSubject.TPR else "labels" - largest_bin_count = sorted_predictions["bin"].value_counts().max() - chart = ( - alt.Chart(sorted_predictions, title=title) - .transform_joinaggregate(total="count(*)") - .transform_calculate( - pctf=f"1 / {largest_bin_count}", - pct="100 / datum.total", - ) - .mark_bar(align="center", opacity=0.2) - ) - if show_decomposition: - # Aggregate over each class - return chart.encode( - alt.X("bin:Q"), - alt.Y("sum(pctf):Q", stack="zero"), - alt.Color("class_name:N", scale=self.class_scale, legend=alt.Legend(symbolOpacity=1)), - tooltip=[ - alt.Tooltip("bin", title=metric_name, format=FLOAT_FMT), - alt.Tooltip("count():Q", title=f"Num. {str_type}", format=COUNT_FMT), - alt.Tooltip("sum(pct):Q", title=f"% of total {str_type}", format=PCT_FMT), - alt.Tooltip("class_name:N", title="Class name"), - ], - ) - else: - # Only use aggregate over all classes - return chart.encode( - alt.X("bin:Q"), - alt.Y("sum(pctf):Q", stack="zero"), - tooltip=[ - alt.Tooltip("bin", title=metric_name, format=FLOAT_FMT), - alt.Tooltip("count():Q", title=f"Num. 
{str_type}", format=COUNT_FMT), - alt.Tooltip("sum(pct):Q", title=f"% of total {str_type}", format=PCT_FMT), - ], - ) - - def build_line_chart( - self, bar_chart: alt.Chart, metric_name: str, show_decomposition: bool, title: str - ) -> alt_api.Chart: - legend = alt.Legend(title="class name".title()) - title_shorthand = "".join(w[0].upper() for w in title.split()) - - line_chart = bar_chart.mark_line(point=True, opacity=0.5 if show_decomposition else 1.0).encode( - alt.X("bin:Q"), - alt.Y("mean(indicator):Q"), - alt.Color("average:N", legend=legend, scale=self.class_scale), - tooltip=[ - alt.Tooltip("bin", title=metric_name, format=FLOAT_FMT), - alt.Tooltip("mean(indicator):Q", title=title_shorthand, format=FLOAT_FMT), - alt.Tooltip("average:N", title="Class name"), - ], - strokeDash=alt.value([5, 5]), - ) - - if show_decomposition: - line_chart += line_chart.mark_line(point=True).encode( - alt.Color("class_name:N", legend=legend, scale=self.class_scale), - tooltip=[ - alt.Tooltip("bin", title=metric_name, format=FLOAT_FMT), - alt.Tooltip("mean(indicator):Q", title=title_shorthand, format=FLOAT_FMT), - alt.Tooltip("class_name:N", title="Class name"), - ], - strokeDash=alt.value([10, 0]), - ) - return line_chart - - def build_average_rule(self, indicator_mean: float, title: str) -> alt_api.Chart: - title_shorthand = "".join(w[0].upper() for w in title.split()) - return ( - alt.Chart(pd.DataFrame({"y": [indicator_mean], "average": ["Average"]})) - .mark_rule() - .encode( - alt.Y("y"), - alt.Color("average:N", scale=self.class_scale), - strokeDash=alt.value([5, 5]), - tooltip=[alt.Tooltip("y", title=f"Average {title_shorthand}", format=FLOAT_FMT)], - ) - ) + def description_expander(self): + with st.expander("Details", expanded=False): + st.markdown( + """### The View - def make_composite_chart( - self, df: pd.DataFrame, title: str, metric_name: str, subject: ChartSubject - ) -> Optional[alt_api.LayerChart]: - # Avoid over-shooting number of bins. - if metric_name not in df.columns: - return None - - num_unique_values = df[metric_name].unique().shape[0] - n_bins = min(st.session_state.get(state.PREDICTIONS_NBINS, 100), num_unique_values) - - # Avoid nans - df = df[[metric_name, "indicator", "class_name"]].copy().dropna(subset=[metric_name]) - - if df.empty: - st.info(f"{title}: No values for the selected metric: {metric_name}") - return None - - # Bin the data - df["bin_info"] = pd.qcut(df[metric_name], q=n_bins, duplicates="drop") - df["bin"] = df["bin_info"].map(lambda x: x.mid) - - if df["bin"].dropna().empty: - st.info(f"No scores for the selected metric: {metric_name}") - return None +On this page, your model scores are displayed as a function of the metric that you selected in the top bar. +Samples are discritized into $n$ equally sized buckets and the middle point of each bucket is displayed as the x-value in the plots. +Bars indicate the number of samples in each bucket, while lines indicate the true positive and false negative rates of each bucket. - df["average"] = "Average" # Indicator for altair charts - show_decomposition = st.session_state[state.PREDICTIONS_DECOMPOSE_CLASSES] - bar_chart = self.build_bar_chart(df, metric_name, show_decomposition, title=title, subject=subject) - line_chart = self.build_line_chart(bar_chart, metric_name, show_decomposition, title=title) - mean_rule = self.build_average_rule(df["indicator"].mean(), title=title) +Metrics marked with (P) are metrics computed on your predictions. 
+Metrics marked with (F) are frame level metrics, which depends on the frame that each prediction is associated +with. In the "False Negative Rate" plot, (O) means metrics compoted on Object labels. - chart_composition: alt_api.LayerChart = bar_chart + line_chart + mean_rule - chart_composition = chart_composition.encode(alt.X(title=metric_name.title()), alt.Y(title=title.title())) - return chart_composition +For metrics that are computed on predictions (P) in the "True Positive Rate" plot, the corresponding "label metrics" (O/F) computed +on your labels are used for the "False Negative Rate" plot. +""", + unsafe_allow_html=True, + ) + self.metric_details_description() def sidebar_options(self): c1, c2, c3 = st.columns([4, 4, 3]) @@ -153,35 +61,39 @@ def sidebar_options(self): self.prediction_metric_in_sidebar() with c2: - st.number_input( - "Number of buckets (n)", - min_value=5, - max_value=200, - value=50, - help="Choose the number of bins to discritize the prediction metric values into.", - key=state.PREDICTIONS_NBINS, + get_state().predictions.nbins = int( + st.number_input( + "Number of buckets (n)", + min_value=5, + max_value=200, + value=get_state().predictions.nbins, + help="Choose the number of bins to discritize the prediction metric values into.", + ) ) with c3: st.write("") # Make some spacing. st.write("") - st.checkbox( + get_state().predictions.decompose_classes = st.checkbox( "Show class decomposition", - key=state.PREDICTIONS_DECOMPOSE_CLASSES, + value=get_state().predictions.decompose_classes, help="When checked, every plot will have a separate component for each class.", ) def build( self, - model_predictions: pd.DataFrame, - labels: pd.DataFrame, - metrics: pd.DataFrame, - precisions: pd.DataFrame, + model_predictions: DataFrame[PredictionMatchSchema], + labels: DataFrame[LabelMatchSchema], + metrics: DataFrame[PerformanceMetricSchema], + precisions: DataFrame[PrecisionRecallSchema], ): st.markdown(f"# {self.title}") if model_predictions.shape[0] == 0: st.write("No predictions of the given class(es).") - elif state.PREDICTIONS_METRIC not in st.session_state: + return + + metric_name = state.get_state().predictions.metric_datas.selected_predicion + if not metric_name: # This shouldn't happen with the current flow. The only way a user can do this # is if he/she write custom code to bypass running the metrics. In this case, # I think that it is fair to not give more information than this. @@ -190,64 +102,44 @@ def build( "With `encord-active import predictions /path/to/predictions.pkl`, " "Encord Active will automatically run compute the metrics." ) - else: - with st.expander("Details", expanded=False): - st.markdown( - """### The View - -On this page, your model scores are displayed as a function of the metric that you selected in the top bar. -Samples are discritized into $n$ equally sized buckets and the middle point of each bucket is displayed as the x-value in the plots. -Bars indicate the number of samples in each bucket, while lines indicate the true positive and false negative rates of each bucket. - + return + + self.description_expander() + + label_metric_name = metric_name + if metric_name[-3:] == "(P)": # Replace the P with O: "Metric (P)" -> "Metric (O)" + label_metric_name = re.sub(r"(.*?\()P(\))", r"\1O\2", metric_name) + + if not label_metric_name in labels.columns: + label_metric_name = re.sub( + r"(.*?\()O(\))", r"\1F\2", label_metric_name + ) # Look for it in frame label metrics. 
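# A minimal, self-contained sketch of the column-name fallback applied just
# above: a prediction metric "(P)" is mapped to its object-level "(O)" label
# counterpart and, when that column is absent, to the frame-level "(F)" one.
# The metric names in the usage examples below are hypothetical.
import re
from typing import Sequence


def to_label_metric_name(metric_name: str, label_columns: Sequence[str]) -> str:
    label_name = metric_name
    if metric_name.endswith("(P)"):
        # Replace the P with O: "Metric (P)" -> "Metric (O)"
        label_name = re.sub(r"(.*?\()P(\))", r"\1O\2", metric_name)
    if label_name not in label_columns:
        # Fall back to the frame-level metric: "Metric (O)" -> "Metric (F)"
        label_name = re.sub(r"(.*?\()O(\))", r"\1F\2", label_name)
    return label_name


# to_label_metric_name("Polygon Points (P)", ["Polygon Points (O)"])  -> "Polygon Points (O)"
# to_label_metric_name("Brightness (P)", ["Brightness (F)"])          -> "Brightness (F)"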
+ + classes_for_coloring = ["Average"] + decompose_classes = get_state().predictions.decompose_classes + if decompose_classes: + unique_classes = set(model_predictions["class_name"].unique()).union(labels["class_name"].unique()) + classes_for_coloring += sorted(list(unique_classes)) + + # Ensure same colors between plots + chart_args = dict( + color_params={"scale": alt.Scale(domain=classes_for_coloring)}, + bins=get_state().predictions.nbins, + show_decomposition=decompose_classes, + ) -Metrics marked with (P) are metrics computed on your predictions. -Metrics marked with (F) are frame level metrics, which depends on the frame that each prediction is associated -with. In the "False Negative Rate" plot, (O) means metrics compoted on Object labels. + tpr = performance_rate_by_metric( + model_predictions, metric_name, scope=PredictionMatchScope.TRUE_POSITIVES, **chart_args + ) + if tpr is None: + st.stop() + st.altair_chart(tpr.interactive(), use_container_width=True) -For metrics that are computed on predictions (P) in the "True Positive Rate" plot, the corresponding "label metrics" (O/F) computed -on your labels are used for the "False Negative Rate" plot. -""", - unsafe_allow_html=True, - ) - self.metric_details_description() - - metric_name = st.session_state[state.PREDICTIONS_METRIC] - - label_metric_name = metric_name - if metric_name[-3:] == "(P)": # Replace the P with O: "Metric (P)" -> "Metric (O)" - label_metric_name = re.sub(r"(.*?\()P(\))", r"\1O\2", metric_name) - - if not label_metric_name in labels.columns: - label_metric_name = re.sub( - r"(.*?\()O(\))", r"\1F\2", label_metric_name - ) # Look for it in frame label metrics. - - classes_for_coloring = ["Average"] - if st.session_state.get(state.PREDICTIONS_DECOMPOSE_CLASSES, False): - unique_classes = set(model_predictions["class_name"].unique()).union(labels["class_name"].unique()) - classes_for_coloring += sorted(list(unique_classes)) - - # Ensure same colors between plots - self.class_scale = alt.Scale( - domain=classes_for_coloring, - ) # Used to sync colors between plots. 
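# A small illustrative sketch of the shared-color-scale trick used above:
# passing the same alt.Scale domain to every chart keeps each class mapped to
# the same color across the TPR and FNR plots. The data frame and class names
# below are made up for illustration.
import altair as alt
import pandas as pd

df = pd.DataFrame(
    {"bin": [0.1, 0.5, 0.9], "rate": [0.2, 0.6, 0.4], "class_name": ["cat", "dog", "cat"]}
)
shared_scale = alt.Scale(domain=["Average", "cat", "dog"])  # one domain reused by every chart

tpr_chart = alt.Chart(df).mark_line(point=True).encode(
    x="bin:Q", y="rate:Q", color=alt.Color("class_name:N", scale=shared_scale)
)
fnr_chart = alt.Chart(df).mark_bar().encode(
    x="bin:Q", y="rate:Q", color=alt.Color("class_name:N", scale=shared_scale)
)
tpr_chart | fnr_chart  # both plots now use identical class colors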
- - # TPR - predictions = model_predictions.rename(columns={"tps": "indicator"}) - tpr = self.make_composite_chart(predictions, "True Positive Rate", metric_name, subject=ChartSubject.TPR) - if tpr is None: - st.stop() - st.altair_chart(tpr.interactive(), use_container_width=True) - - # FNR - fnr = self.make_composite_chart( - labels.rename(columns={"fns": "indicator"}), - "False Negative Rate", - label_metric_name, - subject=ChartSubject.FNR, - ) + fnr = performance_rate_by_metric( + labels, label_metric_name, scope=PredictionMatchScope.FALSE_NEGATIVES, **chart_args + ) - if fnr is None: # Label metric couldn't be matched to - st.stop() + if fnr is None: # Label metric couldn't be matched to + st.stop() - st.altair_chart(fnr.interactive(), use_container_width=True) + st.altair_chart(fnr.interactive(), use_container_width=True) diff --git a/src/encord_active/app/model_quality/sub_pages/true_positives.py b/src/encord_active/app/model_quality/sub_pages/true_positives.py index 568b76257..d74a9aa9c 100644 --- a/src/encord_active/app/model_quality/sub_pages/true_positives.py +++ b/src/encord_active/app/model_quality/sub_pages/true_positives.py @@ -1,13 +1,23 @@ -import pandas as pd import streamlit as st +from pandera.typing import DataFrame -from encord_active.app.common.colors import Color -from encord_active.app.model_quality.components import metric_view +from encord_active.app.common.components.prediction_grid import prediction_grid +from encord_active.app.common.state import get_state +from encord_active.lib.charts.histogram import get_histogram +from encord_active.lib.common.colors import Color +from encord_active.lib.model_predictions.map_mar import ( + PerformanceMetricSchema, + PrecisionRecallSchema, +) +from encord_active.lib.model_predictions.reader import ( + LabelMatchSchema, + PredictionMatchSchema, +) -from . import HistogramMixin, ModelQualityPage +from . 
import ModelQualityPage -class TruePositivesPage(ModelQualityPage, HistogramMixin): +class TruePositivesPage(ModelQualityPage): title = "βœ… True Positives" def sidebar_options(self): @@ -16,10 +26,10 @@ def sidebar_options(self): def build( self, - model_predictions: pd.DataFrame, - labels: pd.DataFrame, - metrics: pd.DataFrame, - precisions: pd.DataFrame, + model_predictions: DataFrame[PredictionMatchSchema], + labels: DataFrame[LabelMatchSchema], + metrics: DataFrame[PerformanceMetricSchema], + precisions: DataFrame[PrecisionRecallSchema], ): st.markdown(f"# {self.title}") @@ -39,11 +49,18 @@ def build( unsafe_allow_html=True, ) self.metric_details_description() - metric_name = st.session_state.predictions_metric - tp_df = model_predictions[model_predictions["tps"] == 1.0].dropna(subset=[metric_name]) + + metric_name = get_state().predictions.metric_datas.selected_predicion + if not metric_name: + st.error("No prediction metric selected") + return + + tp_df = model_predictions[model_predictions[PredictionMatchSchema.is_true_positive] == 1.0].dropna( + subset=[metric_name] + ) if tp_df.shape[0] == 0: st.write("No true positives") else: - histogram = self.get_histogram(tp_df, metric_name) + histogram = get_histogram(tp_df, metric_name) st.altair_chart(histogram, use_container_width=True) - metric_view(tp_df, box_color=color) + prediction_grid(get_state().project_paths.data, model_predictions=tp_df, box_color=color) diff --git a/src/encord_active/app/streamlit_entrypoint.py b/src/encord_active/app/streamlit_entrypoint.py index f8b585dd1..97057f567 100644 --- a/src/encord_active/app/streamlit_entrypoint.py +++ b/src/encord_active/app/streamlit_entrypoint.py @@ -1,14 +1,14 @@ import argparse from functools import reduce +from pathlib import Path from typing import Callable, Dict, Optional, Union import streamlit as st from encord_active.app.actions_page.export_balance import export_balance from encord_active.app.actions_page.export_filter import export_filter -from encord_active.app.common import state +from encord_active.app.common.state import State from encord_active.app.common.utils import set_page_config -from encord_active.app.data_quality.common import MetricType from encord_active.app.frontend_components import pages_menu from encord_active.app.model_quality.sub_pages.false_negatives import FalseNegativesPage from encord_active.app.model_quality.sub_pages.false_positives import FalsePositivesPage @@ -20,13 +20,15 @@ from encord_active.app.views.landing_page import landing_page from encord_active.app.views.metrics import explorer, summary from encord_active.app.views.model_quality import model_quality +from encord_active.lib.db.connection import DBConnection +from encord_active.lib.metrics.utils import MetricScope Pages = Dict[str, Union[Callable, "Pages"]] # type: ignore pages: Pages = { "Encord Active": landing_page, - "Data Quality": {"Summary": summary(MetricType.DATA_QUALITY), "Explorer": explorer(MetricType.DATA_QUALITY)}, - "Label Quality": {"Summary": summary(MetricType.LABEL_QUALITY), "Explorer": explorer(MetricType.LABEL_QUALITY)}, + "Data Quality": {"Summary": summary(MetricScope.DATA_QUALITY), "Explorer": explorer(MetricScope.DATA_QUALITY)}, + "Label Quality": {"Summary": summary(MetricScope.LABEL_QUALITY), "Explorer": explorer(MetricScope.LABEL_QUALITY)}, "Model Quality": { "Metrics": model_quality(MetricsPage()), "Performance By Metric": model_quality(PerformanceMetric()), @@ -56,8 +58,14 @@ def to_items(d: dict, parent_key: Optional[str] = None): def main(project_path: 
str): set_page_config() - if not state.set_project_dir(project_path): + project_dir = Path(project_path).expanduser().absolute() + st.session_state.project_dir = project_dir + if not st.session_state.project_dir.is_dir(): st.error(f"Project not found for directory `{project_path}`.") + st.stop() + + DBConnection.set_project_path(project_dir) + State.init(project_dir) with st.sidebar: items = to_items(pages) diff --git a/src/encord_active/app/views/landing_page.py b/src/encord_active/app/views/landing_page.py index 6817b9a4a..cf5cfc079 100644 --- a/src/encord_active/app/views/landing_page.py +++ b/src/encord_active/app/views/landing_page.py @@ -2,7 +2,7 @@ import streamlit as st -from encord_active.app.common.utils import load_merged_df, set_page_config, setup_page +from encord_active.app.common.utils import set_page_config, setup_page def landing_page(): @@ -24,7 +24,7 @@ def landing_page(): with right_col: st.markdown((Path(__file__).parent.parent / "static" / "landing-page-documentation.md").read_text()) - load_merged_df() + # load_merged_df() if __name__ == "__main__": diff --git a/src/encord_active/app/views/metrics.py b/src/encord_active/app/views/metrics.py index b2960ccef..962fcf8d9 100644 --- a/src/encord_active/app/views/metrics.py +++ b/src/encord_active/app/views/metrics.py @@ -1,11 +1,12 @@ from encord_active.app.common.components import sticky_header +from encord_active.app.common.state import get_state from encord_active.app.common.utils import setup_page -from encord_active.app.data_quality.common import MetricType, load_available_metrics from encord_active.app.data_quality.sub_pages.explorer import ExplorerPage from encord_active.app.data_quality.sub_pages.summary import SummaryPage +from encord_active.lib.metrics.utils import MetricScope, load_available_metrics -def summary(metric_type: MetricType): +def summary(metric_type: MetricScope): def render(): setup_page() page = SummaryPage() @@ -18,11 +19,11 @@ def render(): return render -def explorer(metric_type: MetricType): +def explorer(metric_type: MetricScope): def render(): setup_page() page = ExplorerPage() - available_metrics = load_available_metrics(metric_type.value) + available_metrics = load_available_metrics(get_state().project_paths.metrics, metric_type) with sticky_header(): selected_df = page.sidebar_options(available_metrics) diff --git a/src/encord_active/app/views/model_quality.py b/src/encord_active/app/views/model_quality.py index 8a2846a0e..feeb5d70f 100644 --- a/src/encord_active/app/views/model_quality.py +++ b/src/encord_active/app/views/model_quality.py @@ -1,21 +1,24 @@ -from copy import deepcopy - import streamlit as st -import encord_active.app.common.state as state -import encord_active.app.model_quality.data as pred_data +import encord_active.lib.model_predictions.reader as reader from encord_active.app.common.components import sticky_header +from encord_active.app.common.state import MetricNames, get_state +from encord_active.app.common.state_hooks import use_memo from encord_active.app.common.utils import setup_page -from encord_active.app.model_quality.map_mar import compute_mAP_and_mAR from encord_active.app.model_quality.settings import common_settings from encord_active.app.model_quality.sub_pages import Page +from encord_active.lib.model_predictions.filters import ( + filter_labels_for_frames_wo_predictions, + prediction_and_label_filtering, +) +from encord_active.lib.model_predictions.map_mar import compute_mAP_and_mAR def model_quality(page: Page): def render(): setup_page() - if not 
pred_data.check_model_prediction_availability(): + if not reader.check_model_prediction_availability(get_state().project_paths.predictions): st.markdown( "# Missing Model Predictions\n" "This project does not have any imported predictions. " @@ -25,51 +28,59 @@ def render(): ) return - st.session_state.selected_class_idx = pred_data.get_class_idx() - st.session_state.full_class_idx = deepcopy(pred_data.get_class_idx()) - ( - st.session_state.model_predictions, - st.session_state.prediction_metric_names, - ) = pred_data.get_model_predictions() or (None, None) - if st.session_state.selected_class_idx is None: + predictions_dir = get_state().project_paths.predictions + metrics_dir = get_state().project_paths.metrics + + predictions_metric_datas = use_memo(lambda: reader.get_prediction_metric_data(predictions_dir, metrics_dir)) + label_metric_datas = use_memo(lambda: reader.get_prediction_metric_data(predictions_dir, metrics_dir)) + model_predictions = use_memo(lambda: reader.get_model_predictions(predictions_dir, predictions_metric_datas)) + labels = use_memo(lambda: reader.get_labels(predictions_dir, label_metric_datas)) + + if model_predictions is None: st.error("Couldn't load model predictions") return - st.session_state.labels, st.session_state.label_metric_names = pred_data.get_labels() or (None, None) - if st.session_state.labels is None: + if labels is None: + st.error("Couldn't load labels properly") return - st.session_state.gt_matched = pred_data.get_gt_matched() - st.session_state.metric_meta = pred_data.get_metadata_files() + matched_gt = use_memo(lambda: reader.get_gt_matched(predictions_dir)) + get_state().predictions.metric_datas = MetricNames( + predictions={m.name: m for m in predictions_metric_datas}, + labels={m.name: m for m in label_metric_datas}, + ) + + if not matched_gt: + st.error("Couldn't match groung truths") + return with sticky_header(): common_settings() page.sidebar_options() - metrics, precisions, tps, fp_reasons, fns = compute_mAP_and_mAR( - iou_threshold=st.session_state.get(state.IOU_THRESHOLD), - ignore_unmatched_frames=st.session_state[state.IGNORE_FRAMES_WO_PREDICTIONS], + (matched_predictions, matched_labels, metrics, precisions,) = compute_mAP_and_mAR( + model_predictions, + labels, + matched_gt, + get_state().predictions.all_classes, + iou_threshold=get_state().iou_threshold, + ignore_unmatched_frames=get_state().ignore_frames_without_predictions, ) - st.session_state.model_predictions["tps"] = tps.astype(float) - st.session_state.model_predictions["fp_reason"] = fp_reasons["fp_reason"] - st.session_state.labels["fns"] = fns # Sort predictions and labels according to selected metrics. 
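# The use_memo calls above are assumed to cache each loader's result so the
# prediction and label readers only hit disk once per session; a rough,
# hypothetical stand-in built directly on st.session_state could look like
# this (the real hook may differ).
from typing import Callable, TypeVar

import streamlit as st

T = TypeVar("T")


def memo(factory: Callable[[], T], key: str) -> T:
    """Run `factory` once per session and cache its value under `key`."""
    if key not in st.session_state:
        st.session_state[key] = factory()
    return st.session_state[key]


# e.g. predictions_metric_datas = memo(
#     lambda: reader.get_prediction_metric_data(predictions_dir, metrics_dir),
#     key="prediction_metric_data",
# )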
- pred_sort_column = st.session_state.get(state.PREDICTIONS_METRIC, st.session_state.prediction_metric_names[0]) - st.session_state.sorted_model_predictions = st.session_state.model_predictions.sort_values( - [pred_sort_column], axis=0 - ) + pred_sort_column = get_state().predictions.metric_datas.selected_predicion or predictions_metric_datas[0].name + sorted_model_predictions = matched_predictions.sort_values([pred_sort_column], axis=0) - label_sort_column = st.session_state.get(state.PREDICTIONS_LABEL_METRIC, st.session_state.label_metric_names[0]) - st.session_state.sorted_labels = st.session_state.labels.sort_values([label_sort_column], axis=0) + label_sort_column = get_state().predictions.metric_datas.selected_label or label_metric_datas[0].name + sorted_labels = matched_labels.sort_values([label_sort_column], axis=0) - if st.session_state[state.IGNORE_FRAMES_WO_PREDICTIONS]: - labels = pred_data.filter_labels_for_frames_wo_predictions() + if get_state().ignore_frames_without_predictions: + matched_labels = filter_labels_for_frames_wo_predictions(matched_predictions, sorted_labels) else: - labels = st.session_state.sorted_labels + matched_labels = sorted_labels - _labels, _metrics, _model_pred, _precisions = pred_data.prediction_and_label_filtering( - labels, metrics, st.session_state.sorted_model_predictions, precisions + _labels, _metrics, _model_pred, _precisions = prediction_and_label_filtering( + get_state().predictions.selected_classes, matched_labels, metrics, sorted_model_predictions, precisions ) page.build(model_predictions=_model_pred, labels=_labels, metrics=_metrics, precisions=_precisions) diff --git a/src/encord_active/cli/imports.py b/src/encord_active/cli/imports.py index f355fb1e2..b8847a911 100644 --- a/src/encord_active/cli/imports.py +++ b/src/encord_active/cli/imports.py @@ -25,7 +25,7 @@ def import_predictions( [bold]Imports[/bold] a predictions file. The predictions should be using the `Prediction` model and be stored in a pkl file. If `--coco` option is specified the file should be a json following the coco results format. :brain: """ - from encord_active.lib.common.project import Project + from encord_active.lib.project.project import Project project = Project(target) @@ -37,7 +37,7 @@ def import_predictions( with open(predictions_path, "rb") as f: predictions = pickle.load(f) - from encord_active.app.db.predictions import ( + from encord_active.lib.db.predictions import ( import_predictions as app_import_predictions, ) diff --git a/src/encord_active/cli/main.py b/src/encord_active/cli/main.py index 93e79608e..9c2d3b612 100644 --- a/src/encord_active/cli/main.py +++ b/src/encord_active/cli/main.py @@ -40,7 +40,7 @@ def download( * If --project_name is not given as an argument, available prebuilt projects will be listed and the user can select one from the menu. 
""" - from encord_active.lib.metrics.fetch_prebuilt_metrics import ( + from encord_active.lib.project.sandbox_projects import ( PREBUILT_PROJECTS, fetch_prebuilt_project_size, ) @@ -68,7 +68,7 @@ def download( project_dir = target / project_name project_dir.mkdir(exist_ok=True) - from encord_active.lib.metrics.fetch_prebuilt_metrics import fetch_prebuilt_project + from encord_active.lib.project.sandbox_projects import fetch_prebuilt_project fetch_prebuilt_project(project_name, project_dir) @@ -100,7 +100,7 @@ def quickstart( Take the shortcut and start the application straight away πŸƒπŸ’¨ """ from encord_active.cli.utils.streamlit import launch_streamlit_app - from encord_active.lib.metrics.fetch_prebuilt_metrics import fetch_prebuilt_project + from encord_active.lib.project.sandbox_projects import fetch_prebuilt_project project_name = "quickstart" project_dir = target / project_name diff --git a/src/encord_active/cli/print.py b/src/encord_active/cli/print.py index 27994e1a0..3964bb998 100644 --- a/src/encord_active/cli/print.py +++ b/src/encord_active/cli/print.py @@ -27,7 +27,7 @@ def print_encord_projects( > encord-active print encord-projects --query "%validation%" """ - from encord_active.lib.encord.project import get_projects_json + from encord_active.lib.encord.utils import get_projects_json json_projects = get_projects_json(app_config.get_or_query_ssh_key(), query) if state.get("json_output"): @@ -44,7 +44,7 @@ def print_ontology( """ [bold]Prints[/bold] an ontology mapping between the class name to the `featureNodeHash` JSON format. """ - from encord_active.app.common.cli_helpers import get_local_project + from encord_active.lib.common.utils import get_local_project project = get_local_project(target) objects = project.ontology["objects"] diff --git a/src/encord_active/cli/utils/coco.py b/src/encord_active/cli/utils/coco.py index 89184eb00..46c723444 100644 --- a/src/encord_active/cli/utils/coco.py +++ b/src/encord_active/cli/utils/coco.py @@ -3,10 +3,10 @@ import numpy as np -from encord_active.app.db.predictions import BoundingBox, Format, Prediction from encord_active.lib.coco.importer import IMAGE_DATA_UNIT_FILENAME, CocoImporter from encord_active.lib.coco.parsers import parse_results -from encord_active.lib.metrics.run_all import run_metrics +from encord_active.lib.db.predictions import BoundingBox, Format, Prediction +from encord_active.lib.metrics.execute import run_metrics def import_coco_predictions( diff --git a/src/encord_active/cli/utils/encord.py b/src/encord_active/cli/utils/encord.py index b695bd9ae..e86b3498e 100644 --- a/src/encord_active/cli/utils/encord.py +++ b/src/encord_active/cli/utils/encord.py @@ -10,8 +10,8 @@ from rich.markup import escape from rich.panel import Panel -from encord_active.lib.encord.project import get_projects_json -from encord_active.lib.metrics.run_all import run_metrics +from encord_active.lib.encord.utils import get_projects_json +from encord_active.lib.metrics.execute import run_metrics PROJECT_HASH_REGEX = r"([0-9a-f]{8})-([0-9a-f]{4})-([0-9a-f]{4})-([0-9a-f]{4})-([0-9a-f]{12})" diff --git a/src/encord_active/cli/utils/streamlit.py b/src/encord_active/cli/utils/streamlit.py index ef8a0fed7..05363d425 100644 --- a/src/encord_active/cli/utils/streamlit.py +++ b/src/encord_active/cli/utils/streamlit.py @@ -1,4 +1,5 @@ import sys +from os import environ from pathlib import Path from streamlit.web import cli as stcli @@ -9,4 +10,6 @@ def launch_streamlit_app(target: Path): data_dir = target.expanduser().absolute().as_posix() sys.argv 
= ["streamlit", "run", streamlit_page.as_posix(), data_dir] + # NOTE: we need to set PYTHONPATH for file watching + environ["PYTHONPATH"] = (Path(__file__).parents[2]).as_posix() sys.exit(stcli.main()) # pylint: disable=no-value-for-parameter diff --git a/src/encord_active/lib/charts/__init__.py b/src/encord_active/lib/charts/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/encord_active/lib/charts/histogram.py b/src/encord_active/lib/charts/histogram.py new file mode 100644 index 000000000..dafb78d64 --- /dev/null +++ b/src/encord_active/lib/charts/histogram.py @@ -0,0 +1,28 @@ +from typing import Optional + +import altair as alt +import pandas as pd + + +def get_histogram(df: pd.DataFrame, column_name: str, metric_name: Optional[str] = None): + chart_title = "Data distribution" + + if metric_name: + chart_title += f" - {metric_name}" + else: + metric_name = column_name + + bar_chart = ( + alt.Chart(df, title=chart_title) + .mark_bar() + .encode( + alt.X(f"{column_name}:Q", bin=alt.Bin(maxbins=100), title=metric_name), + alt.Y("count()", title="Num. samples"), + tooltip=[ + alt.Tooltip(f"{column_name}:Q", title=metric_name, format=",.3f", bin=True), + alt.Tooltip("count():Q", title="Num. samples", format="d"), + ], + ) + .properties(height=200) + ) + return bar_chart diff --git a/src/encord_active/lib/charts/metric_importance.py b/src/encord_active/lib/charts/metric_importance.py new file mode 100644 index 000000000..9dced5f74 --- /dev/null +++ b/src/encord_active/lib/charts/metric_importance.py @@ -0,0 +1,70 @@ +from typing import List + +import altair as alt +import pandas as pd +from pandera.typing import DataFrame +from sklearn.feature_selection import mutual_info_regression + +from encord_active.lib.model_predictions.reader import PredictionMatchSchema + +_P_COLS = PredictionMatchSchema + + +def create_metric_importance_charts( + model_predictions: DataFrame[PredictionMatchSchema], metric_columns: List[str], num_samples: int +): + if num_samples < model_predictions.shape[0]: + _predictions = model_predictions.sample(num_samples, axis=0, random_state=42) + else: + _predictions = model_predictions + + num_tps = (_predictions[_P_COLS.is_true_positive].iloc[:num_samples] != 0).sum() + if num_tps < 50: + raise ValueError( + f"Not enough true positives ({num_tps}) to calculate reliable metric importance. " + "Try increasing the number of samples or lower the IoU threshold in the side bar." 
+ ) + + scores = _predictions[_P_COLS.iou] * _predictions[_P_COLS.is_true_positive] + metrics = _predictions[metric_columns] + + correlations = metrics.fillna(0).corrwith(scores, axis=0).to_frame("correlation") + correlations["index"] = correlations.index.T + + mi = pd.DataFrame.from_dict( + {"index": metrics.columns, "importance": mutual_info_regression(metrics.fillna(0), scores, random_state=42)} + ) + # pylint: disable=unsubscriptable-object + sorted_metrics: List[str] = mi.sort_values("importance", ascending=False, inplace=False)["index"].to_list() + + mutual_info_bars = ( + alt.Chart(mi, title="Metric Importance") + .mark_bar() + .encode( + alt.X("importance", title="Importance", scale=alt.Scale(domain=[0.0, 1.0])), + alt.Y("index", title="Metric", sort=sorted_metrics), + alt.Color("importance", scale=alt.Scale(scheme="blues")), + tooltip=[ + alt.Tooltip("index", title="Metric"), + alt.Tooltip("importance", title="Importance", format=",.3f"), + ], + ) + .properties(height=400) + ) + + correlation_bars = ( + alt.Chart(correlations, title="Metric Correlations") + .mark_bar() + .encode( + alt.X("correlation", title="Correlation", scale=alt.Scale(domain=[-1.0, 1.0])), + alt.Y("index", title="Metric", sort=sorted_metrics), + alt.Color("correlation", scale=alt.Scale(scheme="redyellowgreen", align=0.0)), + tooltip=[ + alt.Tooltip("index", title="Metric"), + alt.Tooltip("correlation", title="Correlation", format=",.3f"), + ], + ) + .properties(height=400) + ) + + return alt.hconcat(mutual_info_bars, correlation_bars).resolve_scale(color="independent") diff --git a/src/encord_active/lib/charts/partition_histogram.py b/src/encord_active/lib/charts/partition_histogram.py new file mode 100644 index 000000000..20fee1c8b --- /dev/null +++ b/src/encord_active/lib/charts/partition_histogram.py @@ -0,0 +1,17 @@ +import altair as alt +import pandas as pd + + +def get_partition_histogram(balanced_df: pd.DataFrame, metric: str): + return ( + alt.Chart(balanced_df) + .mark_bar( + binSpacing=0, + ) + .encode( + x=alt.X(f"{metric}:Q", bin=alt.Bin(maxbins=50)), + y="count()", + color="partition:N", + tooltip=["partition", "count()"], + ) + ) diff --git a/src/encord_active/lib/charts/performance_by_metric.py b/src/encord_active/lib/charts/performance_by_metric.py new file mode 100644 index 000000000..62ba7e639 --- /dev/null +++ b/src/encord_active/lib/charts/performance_by_metric.py @@ -0,0 +1,195 @@ +from typing import Optional + +import altair as alt +import altair.vegalite.v4.api as alt_api +import pandas as pd +import pandera as pa +import pandera.dtypes as dtypes +from pandera.typing import DataFrame, Series + +from encord_active.lib.charts.scopes import PredictionMatchScope +from encord_active.lib.model_predictions.reader import ( + LabelMatchSchema, + PredictionMatchSchema, +) + +FLOAT_FMT = ",.4f" +PCT_FMT = ",.2f" +COUNT_FMT = ",d" + + +CHART_TITLES = { + PredictionMatchScope.TRUE_POSITIVES: "True Positive Rate", + PredictionMatchScope.FALSE_POSITIVES: "False Positive Rate", + PredictionMatchScope.FALSE_NEGATIVES: "False Negative Rate", +} + + +class BinSchema(pa.SchemaModel): + metric_value: Series[float] = pa.Field() + class_name: Series[str] = pa.Field() + indicator: Series[float] = pa.Field() + bin_info: Series[dtypes.Category] = pa.Field() + bin: Series[float] = pa.Field(coerce=True) + average: Series[str] = pa.Field() + + +def bin_data(df: pd.DataFrame, metric_name: str, scope: PredictionMatchScope, bins: int = 100) -> DataFrame[BinSchema]: + # Data validation + indicator_column = ( + 
PredictionMatchSchema.is_true_positive + if scope == PredictionMatchScope.TRUE_POSITIVES + else LabelMatchSchema.is_false_negative + ) + + schema = pa.DataFrameSchema( + { + metric_name: pa.Column( + float, nullable=True, coerce=True, checks=[pa.Check(lambda x: x.count() > 0, element_wise=False)] + ), + "class_name": pa.Column(str), + indicator_column: pa.Column(float, coerce=True), + } + ) + df = schema.validate(df) + + df = df[[metric_name, indicator_column, "class_name"]].copy().dropna(subset=[metric_name]) + df.rename(columns={metric_name: BinSchema.metric_value, indicator_column: BinSchema.indicator}, inplace=True) + + # Avoid over-shooting number of bins. + num_unique_values = df[BinSchema.metric_value].unique().shape[0] + n_bins = min(bins, num_unique_values) + + # Bin the data + df["bin_info"] = pd.qcut(df[BinSchema.metric_value], q=n_bins, duplicates="drop") + df["bin"] = df["bin_info"].map(lambda x: x.mid) + df["average"] = "Average" # Indicator for altair charts + return df.pipe(DataFrame[BinSchema]) + + +def bin_bar_chart( + binned_df: DataFrame[BinSchema], + metric_name: str, + scope: PredictionMatchScope, + show_decomposition: bool = False, + color_params: Optional[dict] = None, +) -> alt_api.Chart: + str_type = "predictions" if scope == PredictionMatchScope.TRUE_POSITIVES else "labels" + largest_bin_count = binned_df["bin"].value_counts().max() + color_params = color_params or {} + chart = ( + alt.Chart(binned_df) + .transform_joinaggregate(total="count(*)") + .transform_calculate( + pctf=f"1 / {largest_bin_count}", + pct="100 / datum.total", + ) + .mark_bar(align="center", opacity=0.2) + ) + if show_decomposition: + # Aggregate over each class + return chart.encode( + alt.X("bin:Q"), + alt.Y("sum(pctf):Q", stack="zero"), + alt.Color(f"{BinSchema.class_name}:N", legend=alt.Legend(symbolOpacity=1), **color_params), + tooltip=[ + alt.Tooltip(BinSchema.bin, title=metric_name, format=FLOAT_FMT), + alt.Tooltip("count():Q", title=f"Num. {str_type}", format=COUNT_FMT), + alt.Tooltip("sum(pct):Q", title=f"% of total {str_type}", format=PCT_FMT), + alt.Tooltip(f"{BinSchema.class_name}:N", title="Class name"), + ], + ) + else: + # Only use aggregate over all classes + return chart.encode( + alt.X(f"{BinSchema.bin}:Q"), + alt.Y("sum(pctf):Q", stack="zero"), + tooltip=[ + alt.Tooltip(BinSchema.bin, title=metric_name, format=FLOAT_FMT), + alt.Tooltip("count():Q", title=f"Num. 
{str_type}", format=COUNT_FMT), + alt.Tooltip("sum(pct):Q", title=f"% of total {str_type}", format=PCT_FMT), + ], + ) + + +def performance_rate_line_chart( + bar_chart: alt.Chart, + metric_name: str, + scope: PredictionMatchScope, + show_decomposition: bool = False, + color_params: Optional[dict] = None, +) -> alt_api.Chart: + legend = alt.Legend(title="class name".title()) + title_shorthand = "".join(w[0].upper() for w in CHART_TITLES[scope].split()) + color_params = color_params or {} + + line_chart = bar_chart.mark_line(point=True, opacity=0.5 if show_decomposition else 1.0).encode( + alt.X(f"{BinSchema.bin}:Q"), + alt.Y(f"mean({BinSchema.indicator}):Q"), + alt.Color(f"{BinSchema.average}:N", legend=legend, **color_params), + tooltip=[ + alt.Tooltip("bin", title=metric_name, format=FLOAT_FMT), + alt.Tooltip(f"mean({BinSchema.indicator}):Q", title=title_shorthand, format=FLOAT_FMT), + alt.Tooltip(f"{BinSchema.average}:N", title="Class name"), + ], + strokeDash=alt.value([5, 5]), + ) + + if show_decomposition: + line_chart += line_chart.mark_line(point=True).encode( + alt.Color(f"{BinSchema.class_name}:N", legend=legend, **color_params), + tooltip=[ + alt.Tooltip("bin", title=metric_name, format=FLOAT_FMT), + alt.Tooltip(f"mean({BinSchema.indicator}):Q", title=title_shorthand, format=FLOAT_FMT), + alt.Tooltip(f"{BinSchema.class_name}:N", title="Class name"), + ], + strokeDash=alt.value([10, 0]), + ) + return line_chart + + +def performance_average_rule( + indicator_mean: float, scope: PredictionMatchScope, color_params: Optional[dict] = None +) -> alt_api.Chart: + title = CHART_TITLES[scope] + title_shorthand = "".join(w[0].upper() for w in title.split()) + color_params = color_params or {} + return ( + alt.Chart(pd.DataFrame({"y": [indicator_mean], "average": ["Average"]})) + .mark_rule() + .encode( + alt.Y("y"), + alt.Color("average:N", **color_params), + strokeDash=alt.value([5, 5]), + tooltip=[alt.Tooltip("y", title=f"Average {title_shorthand}", format=FLOAT_FMT)], + ) + ) + + +def performance_rate_by_metric( + df: pd.DataFrame, + metric_name: str, + scope: PredictionMatchScope, + bins: int = 100, + show_decomposition: bool = False, + color_params: Optional[dict] = None, +) -> Optional[alt_api.LayerChart]: + binned_df = bin_data(df, metric_name, scope, bins=bins) + if binned_df.empty: + raise ValueError(f"No scores for the selected metric: {metric_name}") + + bar_chart = bin_bar_chart( + binned_df, metric_name, scope, show_decomposition=show_decomposition, color_params=color_params + ) + line_chart = performance_rate_line_chart( + bar_chart, metric_name, scope, show_decomposition=show_decomposition, color_params=color_params + ) + mean_rule = performance_average_rule(binned_df[BinSchema.indicator].mean(), scope, color_params=color_params) + + chart_composition: alt_api.LayerChart = bar_chart + line_chart + mean_rule + + title = CHART_TITLES[scope] + chart_composition = chart_composition.encode(alt.X(title=metric_name.title()), alt.Y(title=title)).properties( + title=title + ) + return chart_composition diff --git a/src/encord_active/lib/charts/precision_recall.py b/src/encord_active/lib/charts/precision_recall.py new file mode 100644 index 000000000..db4e8d11e --- /dev/null +++ b/src/encord_active/lib/charts/precision_recall.py @@ -0,0 +1,82 @@ +import altair as alt +from pandera.typing import DataFrame + +from encord_active.lib.model_predictions.map_mar import ( + PerformanceMetricSchema, + PrecisionRecallSchema, +) + +_PR_COLS = PrecisionRecallSchema +_M_COLS = 
PerformanceMetricSchema + + +def create_pr_charts(metrics: DataFrame[PerformanceMetricSchema], precisions: DataFrame[PrecisionRecallSchema]): + _metrics: DataFrame[PerformanceMetricSchema] = metrics[~metrics[_M_COLS.metric].isin({"mAR", "mAP"})].copy() + + tmp = "m" + _metrics["metric"].str.split("_", n=1, expand=True) + tmp.columns = ["group", "_"] + _metrics["group"] = tmp["group"] + _metrics["average"] = "average" # Legend title + + class_selection = alt.selection_multi(fields=[_M_COLS.class_name]) + + class_bars = ( + alt.Chart(_metrics, title="Mean scores") + .mark_bar() + .encode( + alt.X(_M_COLS.value, title="", scale=alt.Scale(domain=[0.0, 1.0])), + alt.Y(_M_COLS.metric, title=""), + alt.Color(_M_COLS.class_name), + tooltip=[ + alt.Tooltip(_M_COLS.metric, title="Metric"), + alt.Tooltip(_M_COLS.value, title="Value", format=",.3f"), + ], + opacity=alt.condition(class_selection, alt.value(1), alt.value(0.1)), + ) + .properties(height=300) + ) + # Average + mean_bars = class_bars.encode( + alt.X(f"mean({_M_COLS.value}):Q", title="", scale=alt.Scale(domain=[0.0, 1.0])), + alt.Y("group:N", title=""), + alt.Color("average:N"), + tooltip=[ + alt.Tooltip("group:N", title="Metric"), + alt.Tooltip(f"mean({_M_COLS.value}):Q", title="Value", format=",.3f"), + ], + ) + bar_chart = (class_bars + mean_bars).add_selection(class_selection) + + class_precisions = ( + alt.Chart(precisions, title="Precision-Recall Curve") + .mark_line(point=True) + .encode( + alt.X(_PR_COLS.recall, title="Recall", scale=alt.Scale(domain=[0.0, 1.0])), + alt.Y(_PR_COLS.precision, scale=alt.Scale(domain=[0.0, 1.0])), + alt.Color(_PR_COLS.class_name), + tooltip=[ + alt.Tooltip(_PR_COLS.class_name), + alt.Tooltip(_PR_COLS.recall, title="Recall"), + alt.Tooltip(_PR_COLS.precision, title="Precision", format=",.3f"), + ], + opacity=alt.condition(class_selection, alt.value(1.0), alt.value(0.2)), + ) + .properties(height=300) + ) + + mean_precisions = ( + class_precisions.transform_calculate(average="'average'") + .mark_line(point=True) + .encode( + alt.X(_PR_COLS.recall), + alt.Y(f"average({_PR_COLS.precision}):Q"), + alt.Color("average:N"), + tooltip=[ + alt.Tooltip("average:N", title="Aggregate"), + alt.Tooltip(_PR_COLS.recall, title="Recall"), + alt.Tooltip(f"average({_PR_COLS.precision})", title="Avg. 
precision", format=",.3f"), + ], + ) + ) + precision_chart = (class_precisions + mean_precisions).add_selection(class_selection) + return bar_chart | precision_chart diff --git a/src/encord_active/lib/charts/scopes.py b/src/encord_active/lib/charts/scopes.py new file mode 100644 index 000000000..8f9013648 --- /dev/null +++ b/src/encord_active/lib/charts/scopes.py @@ -0,0 +1,7 @@ +from enum import Enum + + +class PredictionMatchScope(Enum): + TRUE_POSITIVES = "true positives" + FALSE_POSITIVES = "false positives" + FALSE_NEGATIVES = "false negatives" diff --git a/src/encord_active/app/common/colors.py b/src/encord_active/lib/common/colors.py similarity index 100% rename from src/encord_active/app/common/colors.py rename to src/encord_active/lib/common/colors.py diff --git a/src/encord_active/lib/common/image_utils.py b/src/encord_active/lib/common/image_utils.py new file mode 100644 index 000000000..a020aeae6 --- /dev/null +++ b/src/encord_active/lib/common/image_utils.py @@ -0,0 +1,252 @@ +import json +from pathlib import Path +from typing import Dict, List, Optional, Tuple, Union + +import cv2 +import numpy as np +import pandas as pd +from pandas import Series +from pandera.typing import DataFrame + +from encord_active.lib.common.colors import Color, hex_to_rgb +from encord_active.lib.common.utils import get_du_size, rle_to_binary_mask +from encord_active.lib.model_predictions.reader import PredictionMatchSchema + + +def get_polygon_thickness(img_w: int): + t = max(1, int(img_w / 100)) + return t + + +def get_bbox_csv(row: pd.Series) -> np.ndarray: + """ + Used to get a bounding box "polygon" for plotting. + The input should be a row from a LabelSchema (or descendants thereof). + """ + x1, y1, x2, y2 = row["x1"], row["y1"], row["x2"], row["y2"] + return np.array([[x1, y1], [x2, y1], [x2, y2], [x1, y2]]).reshape((-1, 1, 2)).astype(int) + + +def draw_object( + image: np.ndarray, + row: pd.Series, + mask_opacity: float = 0.5, + color: Union[Color, str] = Color.PURPLE, + with_box: bool = False, +): + """ + The input should be a row from a LabelSchema (or descendants thereof). + """ + isClosed = True + thickness = get_polygon_thickness(image.shape[1]) + + hex_color = color.value if isinstance(color, Color) else color + _color: Tuple[int, ...] = hex_to_rgb(hex_color) + _color_outline: Tuple[int, ...] = hex_to_rgb(hex_color, lighten=-0.5) + if isinstance(row["rle"], str): + if with_box: + box = get_bbox_csv(row) + image = cv2.polylines(image, [box], isClosed, _color, thickness // 2, lineType=cv2.LINE_8) + + mask = rle_to_binary_mask(eval(row["rle"])) + + # Draw contour line + contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0] + image = cv2.polylines(image, contours, isClosed, _color_outline, thickness, lineType=cv2.LINE_8) + + # Fill polygon with opacity + patch = np.zeros_like(image) + mask_select = mask == 1 + patch[mask_select] = _color + image[mask_select] = cv2.addWeighted(image, (1 - mask_opacity), patch, mask_opacity, 0)[mask_select] + return image + + +def show_image_with_predictions_and_label( + label: pd.Series, + predictions: DataFrame[PredictionMatchSchema], + data_dir: Path, + label_color: Color = Color.RED, + mask_opacity=0.5, + class_colors: Optional[Dict[int, str]] = None, +): + """ + Displays all predictions in the frame and the one label specified by the `label` + argument. The label will be colored with the `label_color` provided and the + predictions with a `purple` color unless `class_colors` are provided as a dict of + (`class_id`, ``) pairs. 
+ + :param label: The csv row of the false-negative label to display (from a LabelSchema). + :param predictions: All the predictions on the same image with the samme predicted class (from a PredictionSchema). + :param data_dir: The data directory of the project + :param label_color: The hex color to use when drawing the prediction. + :param class_colors: Dict of [class_id, hex_color] pairs. + """ + class_colors = class_colors or {} + image = load_or_fill_image(label, data_dir) + + for _, pred in predictions.iterrows(): + color = class_colors.get(pred["class_id"], Color.PURPLE) + image = draw_object(image, pred, mask_opacity=mask_opacity, color=color) + + return draw_object(image, label, mask_opacity=mask_opacity, color=label_color, with_box=True) + + +def show_image_and_draw_polygons( + row: Union[Series, str], data_dir: Path, draw_polygons: bool = True, skip_object_hash: bool = False +) -> np.ndarray: + image = load_or_fill_image(row, data_dir) + is_closed = True + thickness = get_polygon_thickness(image.shape[1]) + + img_h, img_w = image.shape[:2] + if draw_polygons: + for color, geometry in get_geometries(row, img_h, img_w, data_dir, skip_object_hash=skip_object_hash): + image = cv2.polylines(image, [geometry], is_closed, hex_to_rgb(color), thickness) + + return image + + +def load_or_fill_image(row: Union[pd.Series, str], data_dir: Path) -> np.ndarray: + """ + Tries to read the infered image path. If not possible, generates a white image + and indicates what the error seemd to be embedded in the image. + :param row: A csv row from either a metric, a prediction, or a label csv file. + :return: Numpy / cv2 image. + """ + read_error = False + key = __get_key(row) + + img_pth: Optional[Path] = key_to_image_path(key, data_dir) + + if img_pth and img_pth.is_file(): + try: + image = cv2.imread(img_pth.as_posix()) + return cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + except cv2.error: + pass + + # Read not successful, so tell the user why + error_text = "Image not found" if not img_pth else "File seems broken" + + _, du_hash, *_ = key.split("_") + lr = json.loads(key_to_lr_path(key, data_dir).read_text(encoding="utf-8")) + + h, w = get_du_size(lr["data_units"].get(du_hash, {}), None) or (600, 900) + + image = np.ones((h, w, 3), dtype=np.uint8) * 255 + image[:4, :] = [255, 0, 0] + image[-4:, :] = [255, 0, 0] + image[:, :4] = [255, 0, 0] + image[:, -4:] = [255, 0, 0] + font = cv2.FONT_HERSHEY_SIMPLEX + pos = int(0.05 * min(w, h)) + cv2.putText(image, error_text, (pos, 2 * pos), font, w / 900, hex_to_rgb("#999999"), 2, cv2.LINE_AA) + + return image + + +def __get_key(row: Union[pd.Series, str]): + if isinstance(row, pd.Series): + if "identifier" not in row: + raise ValueError("A Series passed but the series doesn't contain 'identifier'") + return str(row["identifier"]) + elif isinstance(row, str): + return row + else: + raise Exception(f"Undefined row type {row}") + + +def __get_geometry(obj: dict, img_h: int, img_w: int) -> Optional[Tuple[str, np.ndarray]]: + """ + Convert Encord object dictionary to polygon coordinates used to draw geometries + with opencv. 
+ + :param obj: the encord object dict + :param w: the image width + :param h: the image height + :return: The polygon coordinates + """ + + if obj["shape"] == "polygon": + p = obj["polygon"] + polygon = np.array([[p[str(i)]["x"] * img_w, p[str(i)]["y"] * img_h] for i in range(len(p))]) + elif obj["shape"] == "bounding_box": + b = obj["boundingBox"] + polygon = np.array( + [ + [b["x"] * img_w, b["y"] * img_h], + [(b["x"] + b["w"]) * img_w, b["y"] * img_h], + [(b["x"] + b["w"]) * img_w, (b["y"] + b["h"]) * img_h], + [b["x"] * img_w, (b["y"] + b["h"]) * img_h], + ] + ) + else: + return None + + polygon = polygon.reshape((-1, 1, 2)).astype(int) + return obj.get("color", Color.PURPLE.value), polygon + + +def get_geometries( + row: Union[pd.Series, str], img_h: int, img_w: int, data_dir: Path, skip_object_hash: bool = False +) -> List[Tuple[str, np.ndarray]]: + """ + Loads cached label row and computes geometries from the label row. + If the ``identifier`` in the ``row`` contains an object hash, only that object will + be drawn. If no object hash exists, all polygons / bboxes will be drawn. + :param row: the pandas row of the selected csv file. + :return: List of tuples of (hex color, polygon: [[x, y], ...]) + """ + key = __get_key(row) + _, du_hash, frame, *remainder = key.split("_") + + lr_pth = key_to_lr_path(key, data_dir) + with lr_pth.open("r") as f: + label_row = json.load(f) + + du = label_row["data_units"][du_hash] + + geometries = [] + objects = ( + du["labels"].get("objects", []) + if "video" not in du["data_type"] + else du["labels"][str(int(frame))].get("objects", []) + ) + + if remainder and not skip_object_hash: + # Return specific geometries + geometry_object_hashes = set(remainder) + for obj in objects: + if obj["objectHash"] in geometry_object_hashes: + geometries.append(__get_geometry(obj, img_h=img_h, img_w=img_w)) + else: + # Get all geometries + for obj in objects: + if obj["shape"] not in {"polygon", "bounding_box"}: + continue + geometries.append(__get_geometry(obj, img_h=img_h, img_w=img_w)) + + valid_geometries = list(filter(None, geometries)) + return valid_geometries + + +def key_to_lr_path(key: str, data_dir: Path) -> Path: + label_hash, *_ = key.split("_") + return data_dir / label_hash / "label_row.json" + + +def key_to_image_path(key: str, data_dir: Path) -> Optional[Path]: + """ + Infer image path from the identifier stored in the csv files. 
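+    For illustration (made-up hashes): a key like ``"<label_hash>_<du_hash>_00012"`` resolves to
+    ``data_dir/<label_hash>/images/<du_hash>_12.*`` for video frames, with
+    ``data_dir/<label_hash>/images/<du_hash>.*`` as the fallback for image-group images.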
+ :param key: the row["identifier"] from a csv row + :return: The associated image path if it exists or a path to a placeholder otherwise + """ + label_hash, du_hash, frame, *_ = key.split("_") + img_folder = data_dir / label_hash / "images" + + # check if it is a video frame + frame_pth = next(img_folder.glob(f"{du_hash}_{int(frame)}.*"), None) + if frame_pth is not None: + return frame_pth + return next(img_folder.glob(f"{du_hash}.*"), None) # So this is an img_group image diff --git a/src/encord_active/lib/common/iterator.py b/src/encord_active/lib/common/iterator.py index d8b10d8a1..bb5d12bfb 100644 --- a/src/encord_active/lib/common/iterator.py +++ b/src/encord_active/lib/common/iterator.py @@ -14,7 +14,7 @@ from loguru import logger from tqdm import tqdm -from encord_active.lib.common.project import Project +from encord_active.lib.project.project import Project class Iterator(Sized): diff --git a/src/encord_active/lib/common/tester.py b/src/encord_active/lib/common/tester.py deleted file mode 100644 index 65cbdda93..000000000 --- a/src/encord_active/lib/common/tester.py +++ /dev/null @@ -1,71 +0,0 @@ -import json -import logging -from enum import Enum -from pathlib import Path -from typing import Any, List, Type, Union - -from loguru import logger - -from encord_active.lib.common.iterator import DatasetIterator, Iterator -from encord_active.lib.common.metric import Metric -from encord_active.lib.common.utils import fetch_project_info -from encord_active.lib.common.writer import CSVMetricWriter, StatisticsObserver - -logger = logger.opt(colors=True) - - -def __get_value(o): - if isinstance(o, (float, int, str)): - return o - if isinstance(o, Enum): - return __get_value(o.value) - if isinstance(o, (list, tuple)): - return [__get_value(v) for v in o] - return None - - -def __get_object_attributes(obj: Any): - metric_properties = {v.lower(): __get_value(getattr(obj, v)) for v in dir(obj)} - metric_properties = {k: v for k, v in metric_properties.items() if (v is not None or k == "annotation_type")} - return metric_properties - - -@logger.catch() -def perform_test( - metrics: Union[Metric, List[Metric]], - data_dir: Path, - iterator_cls: Type[Iterator] = DatasetIterator, - use_cache_only: bool = False, - **kwargs, -): - all_tests: List[Metric] = metrics if isinstance(metrics, list) else [metrics] - - project = None if use_cache_only else fetch_project_info(data_dir) - iterator = iterator_cls(data_dir, project=project, **kwargs) - cache_dir = iterator.update_cache_dir(data_dir) - - for metric in all_tests: - logger.info(f"Running Metric {metric.TITLE.title()}") - unique_metric_name = metric.get_unique_name() - - stats = StatisticsObserver() - with CSVMetricWriter(cache_dir, iterator, prefix=unique_metric_name) as writer: - writer.attach(stats) - - try: - metric.test(iterator, writer) - except Exception as e: - logging.critical(e, exc_info=True) - - # Store meta-data about the scores. 
- meta_file = (cache_dir / "metrics" / f"{unique_metric_name}.meta.json").expanduser() - - with meta_file.open("w") as f: - json.dump( - { - **__get_object_attributes(metric), - **__get_object_attributes(stats), - }, - f, - indent=2, - ) diff --git a/src/encord_active/lib/common/utils.py b/src/encord_active/lib/common/utils.py index d1674a092..cc27385e2 100644 --- a/src/encord_active/lib/common/utils.py +++ b/src/encord_active/lib/common/utils.py @@ -1,3 +1,4 @@ +import json import os import shutil import warnings @@ -5,7 +6,7 @@ from concurrent.futures import as_completed from itertools import product from pathlib import Path -from typing import Any, Collection, Dict, List, Optional, Tuple, Union +from typing import Any, Collection, Dict, List, Optional, Tuple, TypedDict, Union import av import cv2 @@ -23,10 +24,51 @@ from encord_active.lib.coco.datastructure import CocoBbox -def fetch_project_meta(data_dir: Path) -> dict: +def load_json(json_file: Path) -> Optional[dict]: + if not json_file.exists(): + return None + + with json_file.open("r", encoding="utf-8") as f: + try: + return json.load(f) + except json.JSONDecodeError: + return None + + +class ProjectMeta(TypedDict): + project_description: str + project_hash: str + project_title: str + ssh_key_path: str + + +class ProjectNotFound(Exception): + """Exception raised when a path doesn't contain a valid project. + + Attributes: + project_dir -- path to a project directory + """ + + def __init__(self, project_dir): + self.project_dir = project_dir + super().__init__(f"Couldn't find meta file for project in `{project_dir}`") + + +def get_local_project(project_dir: Path) -> Project: + project_meta = fetch_project_meta(project_dir) + + ssh_key_path = Path(project_meta["ssh_key_path"]) + with open(ssh_key_path.expanduser(), "r", encoding="utf-8") as f: + key = f.read() + + client = EncordUserClient.create_with_ssh_private_key(key) + return client.get_project(project_meta.get("project_hash")) + + +def fetch_project_meta(data_dir: Path) -> ProjectMeta: meta_file = data_dir / "project_meta.yaml" if not meta_file.is_file(): - raise FileNotFoundError(f"Couldn't find meta file for project in {meta_file}") + raise ProjectNotFound(data_dir) with meta_file.open("r", encoding="utf-8") as f: return yaml.safe_load(f) diff --git a/src/encord_active/lib/common/writer.py b/src/encord_active/lib/common/writer.py index cea3ec71e..6b6f38a2b 100644 --- a/src/encord_active/lib/common/writer.py +++ b/src/encord_active/lib/common/writer.py @@ -1,4 +1,3 @@ -import csv import math from abc import ABC, abstractmethod from itertools import chain @@ -9,9 +8,6 @@ from encord_active.lib.common.iterator import Iterator -SCORE_CSV_FIELDS = ["identifier", "score", "description", "object_class", "annotator", "frame", "url"] -EMBEDDING_CSV_FIELDS = ["identifier", "embedding", "description", "object_class", "frame", "url"] - class MetricObserver(ABC): @abstractmethod @@ -104,117 +100,3 @@ def get_identifier( hashes = [lbl["objectHash"] if "objectHash" in lbl else lbl["featureHash"] for lbl in labels] return "_".join(chain([identifier], hashes)) return identifier - - -class CSVMetricWriter(CSVWriter): - def __init__(self, data_path: Path, iterator: Iterator, prefix: str): - filename = (data_path / "metrics" / f"{prefix}.csv").expanduser() - super(CSVMetricWriter, self).__init__(filename=filename, iterator=iterator) - - self.writer = csv.DictWriter(self.csv_file, fieldnames=SCORE_CSV_FIELDS) - self.writer.writeheader() - - def write( - self, - score: Union[float, int], - 
labels: Union[list[dict], dict, None] = None, - description: str = "", - label_class: Optional[str] = None, - label_hash: Optional[str] = None, - du_hash: Optional[str] = None, - frame: Optional[int] = None, - url: Optional[str] = None, - annotator: Optional[str] = None, - key: Optional[str] = None, # TODO obsolete parameter, remove from metrics first - ): - if not isinstance(score, (float, int)): - raise TypeError("score must be a float or int") - if isinstance(labels, list) and len(labels) == 0: - labels = None - elif isinstance(labels, dict): - labels = [labels] - - label_hash = self.iterator.label_hash if label_hash is None else label_hash - du_hash = self.iterator.du_hash if du_hash is None else du_hash - frame = self.iterator.frame if frame is None else frame - url = self.iterator.get_data_url() if url is None else url - - if labels is None: - label_class = "" if label_class is None else label_class - annotator = "" if annotator is None else annotator - else: - label_class = labels[0]["name"] if label_class is None else label_class - annotator = labels[0]["lastEditedBy"] if "lastEditedBy" in labels[0] else labels[0]["createdBy"] - - # remember to remove if clause (not its content) when writer's format (obj, score) is enforced on all metrics - # start hack - if key is None: - key = self.get_identifier(labels, label_hash, du_hash, frame) - # end hack - - row = { - "identifier": key, - "score": score, - "description": description, - "object_class": label_class, - "frame": frame, - "url": url, - "annotator": annotator, - } - - self.writer.writerow(row) - self.csv_file.flush() - - super().write(score) - - -class CSVEmbeddingWriter(CSVWriter): - def __init__(self, data_path: Path, iterator: Iterator, prefix: str): - filename = (data_path / "embeddings" / f"{prefix}.csv").expanduser() - super(CSVEmbeddingWriter, self).__init__(filename=filename, iterator=iterator) - - self.writer = csv.DictWriter(self.csv_file, fieldnames=EMBEDDING_CSV_FIELDS) - self.writer.writeheader() - - def write( - self, - value: Union[float, list], - labels: Union[list[dict], dict, None] = None, - description: str = "", - label_class: Optional[str] = None, - label_hash: Optional[str] = None, - du_hash: Optional[str] = None, - frame: Optional[int] = None, - url: Optional[str] = None, - ): - if not isinstance(value, (float, list)): - raise TypeError("value must be a float or list") - - if isinstance(labels, list) and len(labels) == 0: - labels = None - elif isinstance(labels, dict): - labels = [labels] - - label_hash = self.iterator.label_hash if label_hash is None else label_hash - du_hash = self.iterator.du_hash if du_hash is None else du_hash - frame = self.iterator.frame if frame is None else frame - url = self.iterator.get_data_url() if url is None else url - - if labels is None: - label_class = "" if label_class is None else label_class - else: - label_class = labels[0]["name"] if label_class is None else label_class - - row = { - "identifier": self.get_identifier(labels, label_hash, du_hash, frame), - "embedding": value, - "description": description, - "object_class": label_class, - "frame": frame, - "url": url, - } - - self.writer.writerow(row) - self.csv_file.flush() - - super().write(value) diff --git a/src/encord_active/lib/dataset/__init__.py b/src/encord_active/lib/dataset/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/encord_active/lib/dataset/balance.py b/src/encord_active/lib/dataset/balance.py new file mode 100644 index 000000000..63b2c2c7f --- /dev/null +++ 
b/src/encord_active/lib/dataset/balance.py
@@ -0,0 +1,79 @@
+import io
+import json
+from encodings import utf_8
+from typing import Dict, List
+from zipfile import ZipFile
+
+import numpy as np
+import pandas as pd
+from tqdm import tqdm
+
+from encord_active.lib.coco.encoder import generate_coco_file
+from encord_active.lib.metrics.utils import MetricData, load_metric_dataframe
+from encord_active.lib.project.project_file_structure import ProjectFileStructure
+
+
+def balance_dataframe(selected_metrics: List[MetricData], partition_sizes: Dict[str, int], seed: int) -> pd.DataFrame:
+    """
+    Balances the dataset over the selected metrics and partition sizes.
+    Currently, this is done by random sampling.
+
+    Args:
+        selected_metrics (List[MetricData]): The metrics to balance over.
+        partition_sizes (Dict[str, int]): A dictionary mapping partition names to partition sizes (in percent).
+        seed (int): The seed for the random sampling.
+
+    Returns:
+        pd.DataFrame: A dataframe with the following columns: sample identifiers, metric values and allocated partition.
+    """
+    # Collect metric dataframes
+    merged_df_list = []
+    for i, m in enumerate(selected_metrics):
+        df = load_metric_dataframe(m, normalize=False).copy()
+        merged_df_list.append(df[["identifier", "score"]].rename(columns={"score": m.name}))
+
+    # Merge all dataframes by identifier
+    merged_df = merged_df_list.pop()
+    for df_tmp in merged_df_list:
+        merged_df = merged_df.merge(df_tmp, on="identifier", how="outer")
+
+    # Randomly sample from each partition and add column to merged_df
+    n_samples = len(merged_df)
+    selection_df = merged_df.copy()
+    merged_df["partition"] = ""
+    for partition_name, partition_size in [(k, v) for k, v in partition_sizes.items()][:-1]:
+        n_partition = int(np.floor(n_samples * partition_size / 100))
+        partition_df = selection_df.sample(n=n_partition, replace=False, random_state=seed)
+        # Remove samples from selection_df
+        selection_df = selection_df[~selection_df["identifier"].isin(partition_df["identifier"])]
+        # Add partition column to merged_df
+        merged_df.loc[partition_df.index, "partition"] = partition_name
+
+    # Assign the remaining samples to the last partition
+    merged_df.loc[merged_df["partition"] == "", "partition"] = list(partition_sizes.keys())[-1]
+    return merged_df
+
+
+def get_partitions_zip(partition_dict: Dict[str, pd.DataFrame], project_file_structure: ProjectFileStructure) -> bytes:
+    """
+    Creates a zip file with a COCO json object for each partition.
+
+    Args:
+        partition_dict (Dict[str, pd.DataFrame]): A dictionary mapping partition names to partition dataframes.
+        project_file_structure (ProjectFileStructure): The project's file structure, used for the project directory and ontology passed to the COCO encoder.
+
+    Returns:
+        bytes: The zip file as a byte array.
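+
+    Example (an illustrative sketch only; the metric list, partition names and sizes are
+    assumptions, not values defined in this module):
+
+        project_file_structure = ProjectFileStructure(Path("/path/to/project"))
+        df = balance_dataframe(selected_metrics, {"train": 70, "val": 20, "test": 10}, seed=42)
+        partitions = {name: group for name, group in df.groupby("partition")}
+        zip_bytes = get_partitions_zip(partitions, project_file_structure)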
+ """ + zip_io = io.BytesIO() + with ZipFile(zip_io, mode="w") as zf: + partition_dict.pop("Unassigned", None) + for partition_name, partition in tqdm(partition_dict.items(), desc="Generating COCO files"): + coco_json = generate_coco_file( + partition, project_file_structure.project_dir, project_file_structure.ontology + ) + with zf.open(partition_name.replace(" ", "_").lower() + ".json", "w") as zip_file: + writer = utf_8.StreamWriter(zip_file) + json.dump(coco_json, writer) # type: ignore + zip_io.seek(0) + + return zip_io.read() diff --git a/src/encord_active/lib/dataset/outliers.py b/src/encord_active/lib/dataset/outliers.py new file mode 100644 index 000000000..7899ac253 --- /dev/null +++ b/src/encord_active/lib/dataset/outliers.py @@ -0,0 +1,54 @@ +from typing import NamedTuple, Optional, Tuple + +import pandera as pa +from pandera.typing import DataFrame, Series + +from encord_active.lib.metrics.utils import MetricSchema + + +class MetricWithDistanceSchema(MetricSchema): + dist_to_iqr: Optional[Series[float]] = pa.Field() + + +class IqrOutliers(NamedTuple): + n_moderate_outliers: int + n_severe_outliers: int + moderate_lb: float + moderate_ub: float + severe_lb: float + severe_ub: float + + +_COLUMNS = MetricWithDistanceSchema + + +def get_iqr_outliers( + input: DataFrame[MetricSchema], +) -> Optional[Tuple[DataFrame[MetricWithDistanceSchema], IqrOutliers]]: + if input.empty: + return None + + df = DataFrame[MetricWithDistanceSchema](input) + + moderate_iqr_scale = 1.5 + severe_iqr_scale = 2.5 + Q1 = df[_COLUMNS.score].quantile(0.25) + Q3 = df[_COLUMNS.score].quantile(0.75) + IQR = Q3 - Q1 + + df[_COLUMNS.dist_to_iqr] = 0 + df.loc[df[_COLUMNS.score] > Q3, _COLUMNS.dist_to_iqr] = (df[_COLUMNS.score] - Q3).abs() + df.loc[df[_COLUMNS.score] < Q1, _COLUMNS.dist_to_iqr] = (df[_COLUMNS.score] - Q1).abs() + df.sort_values(by=_COLUMNS.dist_to_iqr, inplace=True, ascending=False) + + moderate_lb, moderate_ub = Q1 - moderate_iqr_scale * IQR, Q3 + moderate_iqr_scale * IQR + severe_lb, severe_ub = Q1 - severe_iqr_scale * IQR, Q3 + severe_iqr_scale * IQR + + n_moderate_outliers = ( + ((severe_lb <= df[_COLUMNS.score]) & (df[_COLUMNS.score] < moderate_lb)) + | ((severe_ub >= df[_COLUMNS.score]) & (df[_COLUMNS.score] > moderate_ub)) + ).sum() + + n_severe_outliers = ((df[_COLUMNS.score] < severe_lb) | (df[_COLUMNS.score] > severe_ub)).sum() + + return (df, IqrOutliers(n_moderate_outliers, n_severe_outliers, moderate_lb, moderate_ub, severe_lb, severe_ub)) diff --git a/src/encord_active/lib/db/__init__.py b/src/encord_active/lib/db/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/encord_active/lib/db/connection.py b/src/encord_active/lib/db/connection.py new file mode 100644 index 000000000..9bd95e295 --- /dev/null +++ b/src/encord_active/lib/db/connection.py @@ -0,0 +1,26 @@ +import sqlite3 +from pathlib import Path +from typing import Optional + +from encord_active.lib.project.project_file_structure import ProjectFileStructure + + +class DBConnection: + _project_file_structure: Optional[ProjectFileStructure] = None + + def __enter__(self): + self.conn = sqlite3.connect(self.project_file_structure().db) + return self.conn + + def __exit__(self, type, value, traceback): + self.conn.__exit__(type, value, traceback) + + @classmethod + def set_project_path(cls, project_path: Path): + cls._project_file_structure = ProjectFileStructure(project_path) + + @classmethod + def project_file_structure(cls): + if not cls._project_file_structure: + raise ConnectionError("`project_path` 
was not set, call `DBConnection.set_project_path('path/to/project')`") + return cls._project_file_structure diff --git a/src/encord_active/lib/db/helpers/__init__.py b/src/encord_active/lib/db/helpers/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/encord_active/lib/db/helpers/tags.py b/src/encord_active/lib/db/helpers/tags.py new file mode 100644 index 000000000..5789c2112 --- /dev/null +++ b/src/encord_active/lib/db/helpers/tags.py @@ -0,0 +1,20 @@ +from typing import Dict + +import pandas as pd + +from encord_active.lib.db.tags import Tags + + +def count_of_tags(df: pd.DataFrame) -> Dict[str, int]: + tag_list = Tags().all() + if not tag_list: + return {} + + tag_counts = df["tags"].value_counts() + + total_tags_count: Dict[str, int] = {tag.name: 0 for tag in tag_list} + for unique_list, count in tag_counts.items(): + for tag in unique_list: + total_tags_count[tag.name] += count + + return total_tags_count diff --git a/src/encord_active/app/db/merged_metrics.py b/src/encord_active/lib/db/merged_metrics.py similarity index 89% rename from src/encord_active/app/db/merged_metrics.py rename to src/encord_active/lib/db/merged_metrics.py index 804afea5d..e210d6fa6 100644 --- a/src/encord_active/app/db/merged_metrics.py +++ b/src/encord_active/lib/db/merged_metrics.py @@ -1,23 +1,22 @@ import json +from pathlib import Path from typing import List import pandas as pd -import streamlit as st from encord.project_ontology.classification_type import ClassificationType -from encord_active.app.common.state import MERGED_DATAFRAME -from encord_active.app.db.connection import DBConnection -from encord_active.app.db.tags import Tag, TagScope +from encord_active.lib.db.connection import DBConnection +from encord_active.lib.db.tags import Tag, TagScope TABLE_NAME = "merged_metrics" -def build_merged_metrics() -> pd.DataFrame: +def build_merged_metrics(metrics_path: Path) -> pd.DataFrame: main_df_images = pd.DataFrame(columns=["identifier"]) main_df_objects = pd.DataFrame(columns=["identifier"]) main_df_image_quality = pd.DataFrame() - for index in st.session_state.metric_dir.glob("*.csv"): + for index in metrics_path.glob("*.csv"): meta_pth = index.with_suffix(".meta.json") if not meta_pth.is_file(): continue @@ -75,9 +74,7 @@ def wrapper(*args, **kwargs): try: return fn(*args, **kwargs) except: - merged_metrics = st.session_state.get(MERGED_DATAFRAME) - if merged_metrics is None: - merged_metrics = build_merged_metrics() + merged_metrics = build_merged_metrics(DBConnection.project_file_structure().metrics) MergedMetrics().replace_all(merged_metrics) return fn(*args, **kwargs) diff --git a/src/encord_active/app/db/predictions.py b/src/encord_active/lib/db/predictions.py similarity index 93% rename from src/encord_active/app/db/predictions.py rename to src/encord_active/lib/db/predictions.py index 8c4e9564f..dcbc9afae 100644 --- a/src/encord_active/app/db/predictions.py +++ b/src/encord_active/lib/db/predictions.py @@ -6,10 +6,10 @@ import numpy as np from pydantic import BaseModel, Field, validator -from encord_active.lib.common.project import Project -from encord_active.lib.metrics.run_all import run_all_prediction_metrics +from encord_active.lib.metrics.execute import run_all_prediction_metrics from encord_active.lib.model_predictions.iterator import PredictionIterator -from encord_active.lib.model_predictions.prediction_writer import PredictionWriter +from encord_active.lib.model_predictions.writer import PredictionWriter +from encord_active.lib.project.project import 
Project RelativeFloat = Annotated[float, Field(ge=0, le=1)] diff --git a/src/encord_active/app/db/tags.py b/src/encord_active/lib/db/tags.py similarity index 70% rename from src/encord_active/app/db/tags.py rename to src/encord_active/lib/db/tags.py index 9b2c7fd74..36bb5e24a 100644 --- a/src/encord_active/app/db/tags.py +++ b/src/encord_active/lib/db/tags.py @@ -2,10 +2,8 @@ from sqlite3 import OperationalError from typing import Callable, List, NamedTuple -import streamlit as st - -from encord_active.app.common.state import ALL_TAGS -from encord_active.app.db.connection import DBConnection +from encord_active.lib.db.connection import DBConnection +from encord_active.lib.metrics.utils import MetricScope TABLE_NAME = "tags" @@ -20,6 +18,18 @@ class Tag(NamedTuple): scope: TagScope +SCOPE_EMOJI = { + TagScope.DATA.value: "πŸ–ΌοΈ", + TagScope.LABEL.value: "✏️", +} + +METRIC_SCOPE_TAG_SCOPES = { + MetricScope.DATA_QUALITY: {TagScope.DATA}, + MetricScope.LABEL_QUALITY: {TagScope.DATA, TagScope.LABEL}, + MetricScope.MODEL_QUALITY: {TagScope.DATA}, +} + + def ensure_existence(fn: Callable): def wrapper(*args, **kwargs): try: @@ -35,9 +45,6 @@ def wrapper(*args, **kwargs): ) """ ) - all_tags = st.session_state.get(ALL_TAGS) - if all_tags: - conn.executemany(f"INSERT INTO {TABLE_NAME} (name, scope) VALUES(?, ?) ", all_tags) return fn(*args, **kwargs) @@ -60,5 +67,12 @@ def all(self) -> List[Tag]: @ensure_existence def create_tag(self, tag: Tag): + stripped = tag.name.strip() + if not stripped: + raise ValueError("Empty tags are not allowed") + + if tag in self.all(): + raise ValueError("Tag already exists") + with DBConnection() as conn: return conn.execute(f"INSERT INTO {TABLE_NAME} (name, scope) VALUES(?, ?) ", tag) diff --git a/src/encord_active/lib/embeddings/cnn_embed.py b/src/encord_active/lib/embeddings/cnn.py similarity index 93% rename from src/encord_active/lib/embeddings/cnn_embed.py rename to src/encord_active/lib/embeddings/cnn.py index e8547880b..3d4e6da61 100644 --- a/src/encord_active/lib/embeddings/cnn_embed.py +++ b/src/encord_active/lib/embeddings/cnn.py @@ -16,6 +16,7 @@ from encord_active.lib.common.iterator import Iterator from encord_active.lib.common.utils import get_bbox_from_encord_label_object +from encord_active.lib.embeddings.utils import LabelEmbedding logger = logging.getLogger(__name__) DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -100,17 +101,17 @@ def generate_cnn_embeddings(iterator: Iterator, filepath: str) -> None: last_edited_by = obj["lastEditedBy"] if "lastEditedBy" in obj.keys() else obj["createdBy"] - entry = { - "label_row": iterator.label_hash, - "data_unit": data_unit["data_hash"], - "frame": iterator.frame, - "objectHash": obj["objectHash"], - "lastEditedBy": last_edited_by, - "featureHash": obj["featureHash"], - "name": obj["name"], - "dataset_title": iterator.dataset_title, - "embedding": emb, - } + entry = LabelEmbedding( + label_row=iterator.label_hash, + data_unit=data_unit["data_hash"], + frame=iterator.frame, + objectHash=obj["objectHash"], + lastEditedBy=last_edited_by, + featureHash=obj["featureHash"], + name=obj["name"], + dataset_title=iterator.dataset_title, + embedding=emb, + ) collections.append(entry) @@ -178,7 +179,7 @@ def generate_cnn_classification_embeddings(iterator: Iterator, filepath: str) -> temp_entry["embedding"] = embedding - collections.append(temp_entry) + collections.append(LabelEmbedding(temp_entry)) # type: ignore with open(filepath, "wb") as f: pickle.dump(collections, f) diff --git 
a/src/encord_active/lib/embeddings/hu_embed.py b/src/encord_active/lib/embeddings/hu_moments.py similarity index 96% rename from src/encord_active/lib/embeddings/hu_embed.py rename to src/encord_active/lib/embeddings/hu_moments.py index 4a7e29a13..de61ae63b 100644 --- a/src/encord_active/lib/embeddings/hu_embed.py +++ b/src/encord_active/lib/embeddings/hu_moments.py @@ -6,7 +6,7 @@ from encord_active.lib.common.iterator import Iterator from encord_active.lib.common.utils import get_du_size, get_object_coordinates -from encord_active.lib.common.writer import CSVEmbeddingWriter +from encord_active.lib.embeddings.writer import CSVEmbeddingWriter logger = logging.getLogger(__name__) HU_FILENAME = "hu_moments-embeddings" diff --git a/src/encord_active/lib/embeddings/utils.py b/src/encord_active/lib/embeddings/utils.py new file mode 100644 index 000000000..d6461254a --- /dev/null +++ b/src/encord_active/lib/embeddings/utils.py @@ -0,0 +1,150 @@ +import os +import pickle +from pathlib import Path +from typing import List, Optional, Tuple, TypedDict + +import faiss +import numpy as np +from faiss import IndexFlatL2 + + +class LabelEmbedding(TypedDict): + label_row: str + data_unit: str + frame: int + objectHash: Optional[str] + lastEditedBy: str + featureHash: str + name: str + dataset_title: str + embedding: np.ndarray + + +def get_collections(embedding_name: str, embeddings_dir: Path) -> list[LabelEmbedding]: + embedding_path = embeddings_dir / embedding_name + collections = [] + if os.path.isfile(embedding_path): + with open(embedding_path, "rb") as f: + collections = pickle.load(f) + return collections + + +def get_collections_and_metadata(embedding_name: str, embeddings_dir: Path) -> Tuple[list[LabelEmbedding], dict]: + try: + collections = get_collections(embedding_name, embeddings_dir) + + embedding_metadata_file_name = "embedding_classifications_metadata.pkl" + embedding_metadata_path = embeddings_dir / embedding_metadata_file_name + if os.path.isfile(embedding_metadata_path): + with open(embedding_metadata_path, "rb") as f: + question_hash_to_collection_indexes_local = pickle.load(f) + else: + question_hash_to_collection_indexes_local = {} + + return collections, question_hash_to_collection_indexes_local + except Exception: + return [], {} + + +def get_key_from_index(collection: LabelEmbedding, question_hash: Optional[str] = None, has_annotation=True) -> str: + label_hash = collection["label_row"] + du_hash = collection["data_unit"] + frame_idx = int(collection["frame"]) + + if not has_annotation: + key = f"{label_hash}_{du_hash}_{frame_idx:05d}" + else: + if question_hash: + key = f"{label_hash}_{du_hash}_{frame_idx:05d}_{question_hash}" + else: + object_hash = collection["objectHash"] + key = f"{label_hash}_{du_hash}_{frame_idx:05d}_{object_hash}" + + return key + + +# TODO: remove if unused +def get_identifier_to_neighbors( + collections: list[LabelEmbedding], nearest_indexes: np.ndarray, has_annotation=True +) -> dict[str, list]: + nearest_neighbors = {} + n, k = nearest_indexes.shape + for i in range(n): + key = get_key_from_index(collections[i], has_annotation=has_annotation) + temp_list = [] + for j in range(1, k): + temp_list.append( + { + "key": get_key_from_index(collections[nearest_indexes[i, j]], has_annotation=has_annotation), + "name": collections[nearest_indexes[i, j]].get("name", "Does not have a label"), + } + ) + nearest_neighbors[key] = temp_list + return nearest_neighbors + + +# TODO: remove if unused +def convert_to_indexes(collections, 
question_hash_to_collection_indexes): + embedding_databases, indexes = {}, {} + + for question_hash in question_hash_to_collection_indexes: + selected_collections = [collections[i] for i in question_hash_to_collection_indexes[question_hash]] + + if len(selected_collections) > 10: + embedding_database = np.stack(list(map(lambda x: x["embedding"], selected_collections))) + + index = faiss.IndexFlatL2(embedding_database.shape[1]) + index.add(embedding_database) # pylint: disable=no-value-for-parameter + + embedding_databases[question_hash] = embedding_database + indexes[question_hash] = index + + return embedding_databases, indexes + + +def get_faiss_index_image( + collections: list[LabelEmbedding], question_hash_to_collection_indexes: dict +) -> dict[str, IndexFlatL2]: + indexes = {} + + for question_hash in question_hash_to_collection_indexes: + selected_collections = [collections[i] for i in question_hash_to_collection_indexes[question_hash]] + + if len(selected_collections) > 10: + embedding_database = np.stack(list(map(lambda x: x["embedding"], selected_collections))) + + index = faiss.IndexFlatL2(embedding_database.shape[1]) + index.add(embedding_database) # pylint: disable=no-value-for-parameter + + indexes[question_hash] = index + + return indexes + + +def get_faiss_index_object(collections: list[LabelEmbedding]) -> IndexFlatL2: + """ + + Args: + _collections: Underscore is used to skip hashing for this object to make this function faster. + faiss_index_name: Since we skip hashing for collections, we need another parameter for memoization. + + Returns: Faiss Index object for searching embeddings. + + """ + embeddings_list: List[np.ndarray] = [x["embedding"] for x in collections] + embeddings = np.array(embeddings_list).astype(np.float32) + + if len(embeddings.shape) != 2: + return + + db_index = faiss.IndexFlatL2(embeddings.shape[1]) + db_index.add(embeddings) # pylint: disable=no-value-for-parameter + return db_index + + +def get_object_keys_having_similarities(collections: list[LabelEmbedding]) -> dict: + return {get_key_from_index(collection): i for i, collection in enumerate(collections)} + + +def get_image_keys_having_similarities(collections: list[LabelEmbedding]) -> dict: + return {get_key_from_index(collection, has_annotation=False): i for i, collection in enumerate(collections)} diff --git a/src/encord_active/lib/embeddings/writer.py b/src/encord_active/lib/embeddings/writer.py new file mode 100644 index 000000000..fed3c80ce --- /dev/null +++ b/src/encord_active/lib/embeddings/writer.py @@ -0,0 +1,60 @@ +import csv +from pathlib import Path +from typing import Optional, Union + +from encord_active.lib.common.iterator import Iterator +from encord_active.lib.common.writer import CSVWriter + +EMBEDDING_CSV_FIELDS = ["identifier", "embedding", "description", "object_class", "frame", "url"] + + +class CSVEmbeddingWriter(CSVWriter): + def __init__(self, data_path: Path, iterator: Iterator, prefix: str): + filename = (data_path / "embeddings" / f"{prefix}.csv").expanduser() + super(CSVEmbeddingWriter, self).__init__(filename=filename, iterator=iterator) + + self.writer = csv.DictWriter(self.csv_file, fieldnames=EMBEDDING_CSV_FIELDS) + self.writer.writeheader() + + def write( + self, + value: Union[float, list], + labels: Union[list[dict], dict, None] = None, + description: str = "", + label_class: Optional[str] = None, + label_hash: Optional[str] = None, + du_hash: Optional[str] = None, + frame: Optional[int] = None, + url: Optional[str] = None, + ): + if not isinstance(value, 
(float, list)): + raise TypeError("value must be a float or list") + + if isinstance(labels, list) and len(labels) == 0: + labels = None + elif isinstance(labels, dict): + labels = [labels] + + label_hash = self.iterator.label_hash if label_hash is None else label_hash + du_hash = self.iterator.du_hash if du_hash is None else du_hash + frame = self.iterator.frame if frame is None else frame + url = self.iterator.get_data_url() if url is None else url + + if labels is None: + label_class = "" if label_class is None else label_class + else: + label_class = labels[0]["name"] if label_class is None else label_class + + row = { + "identifier": self.get_identifier(labels, label_hash, du_hash, frame), + "embedding": value, + "description": description, + "object_class": label_class, + "frame": frame, + "url": url, + } + + self.writer.writerow(row) + self.csv_file.flush() + + super().write(value) diff --git a/src/encord_active/lib/encord/__init__.py b/src/encord_active/lib/encord/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/encord_active/lib/encord/actions.py b/src/encord_active/lib/encord/actions.py new file mode 100644 index 000000000..0504506f2 --- /dev/null +++ b/src/encord_active/lib/encord/actions.py @@ -0,0 +1,250 @@ +import json +from pathlib import Path +from typing import Callable, NamedTuple, Optional + +import pandas as pd +from encord import Dataset, EncordUserClient, Project +from encord.constants.enums import DataType +from encord.exceptions import AuthorisationError +from encord.orm.dataset import Image, StorageLocation +from encord.utilities.label_utilities import construct_answer_dictionaries +from tqdm import tqdm + +from encord_active.lib.common.utils import fetch_project_meta +from encord_active.lib.project.project_file_structure import ProjectFileStructure + + +class DatasetCreationResult(NamedTuple): + hash: str + du_original_mapping: dict[str, dict] + + +class ProjectCreationResult(NamedTuple): + hash: str + + +class EncordActions: + def __init__(self, project_dir: Path): + self.project_meta = fetch_project_meta(project_dir) + self.project_file_structure = ProjectFileStructure(project_dir) + + try: + ssh_key_path = Path(self.project_meta["ssh_key_path"]).resolve() + original_project_hash = self.project_meta["project_hash"] + except Exception as e: + raise MissingProjectMetaAttribure(e.args[0], self.project_file_structure.project_meta) + + if not ssh_key_path.is_file(): + raise FileNotFoundError(f"No SSH file in location: {ssh_key_path}") + + self.user_client = EncordUserClient.create_with_ssh_private_key( + Path(ssh_key_path).expanduser().read_text(encoding="utf-8"), + ) + + self.original_project = self.user_client.get_project(original_project_hash) + try: + if self.original_project.project_hash == original_project_hash: + pass + except AuthorisationError: + raise AuthorisationError( + f'The user associated to the ssh key `{ssh_key_path}` does not have access to the project with project hash `{original_project_hash}`. Run "encord-active config set ssh_key_path /path/to/your/key_file" to set it.' 
+ ) + + def create_dataset( + self, + dataset_title: str, + dataset_description: str, + filtered_dataset: pd.DataFrame, + progress_callback: Optional[Callable] = None, + ): + datasets_with_same_title = self.user_client.get_datasets(title_eq=dataset_title) + if len(datasets_with_same_title) > 0: + raise DatasetUniquenessError(dataset_title) + + new_du_to_original: dict[str, dict] = {} + self.user_client.create_dataset( + dataset_title=dataset_title, + dataset_type=StorageLocation.CORD_STORAGE, + dataset_description=dataset_description, + ) + dataset_hash: str = self.user_client.get_datasets(title_eq=dataset_title)[0]["dataset"].dataset_hash + dataset: Dataset = self.user_client.get_dataset(dataset_hash) + + # The following operation is for image groups (to upload them efficiently) + label_hash_to_data_units: dict[str, list] = {} + for _, item in tqdm(filtered_dataset.iterrows(), total=filtered_dataset.shape[0]): + label_row_hash, data_unit_hash, *_ = str(item["identifier"]).split("_") + label_hash_to_data_units.setdefault(label_row_hash, []).append(data_unit_hash) + + uploaded_label_rows: set = set() + for counter, (_, item) in enumerate(filtered_dataset.iterrows()): + label_row_hash, data_unit_hash, *_ = str(item["identifier"]).split("_") + label_row_structure = self.project_file_structure.label_row_structure(label_row_hash) + label_row = json.loads(label_row_structure.label_row_file.expanduser().read_text()) + + if label_row_hash not in uploaded_label_rows: + if label_row["data_type"] == DataType.IMAGE.value: + image_path = list(label_row_structure.images.glob(f"{data_unit_hash}.*"))[0] + uploaded_image: Image = dataset.upload_image( + file_path=image_path, title=label_row["data_units"][data_unit_hash]["data_title"] + ) + + new_du_to_original[uploaded_image["data_hash"]] = { + "label_row_hash": label_row_hash, + "data_unit_hash": data_unit_hash, + } + + elif label_row["data_type"] == DataType.IMG_GROUP.value: + image_paths = [] + image_names = [] + if len(label_hash_to_data_units[label_row_hash]) > 0: + for data_unit in label_hash_to_data_units[label_row_hash]: + img_path = list(label_row_structure.images.glob(f"{data_unit}.*"))[0] + image_paths.append(img_path.as_posix()) + image_names.append(img_path.name) + + # Unfortunately the following function does not return metadata related to the uploaded items + dataset.create_image_group(file_paths=image_paths, title=label_row["data_title"]) + + # Since create_image_group does not return info related to the uploaded images, we should find its + # data_hash in a hacky way + _update_mapping( + self.user_client, dataset_hash, label_row_hash, data_unit_hash, new_du_to_original + ) + + elif label_row["data_type"] == DataType.VIDEO.value: + video_path = list(label_row_structure.images.glob(f"{data_unit_hash}.*"))[0].as_posix() + + # Unfortunately the following function does not return metadata related to the uploaded items + dataset.upload_video( + file_path=video_path, title=label_row["data_units"][data_unit_hash]["data_title"] + ) + + # Since upload_video does not return info related to the uploaded video, we should find its data_hash + # in a hacky way + _update_mapping(self.user_client, dataset_hash, label_row_hash, data_unit_hash, new_du_to_original) + + else: + raise Exception( + f'Undefined data type {label_row["data_type"]} for label_row={label_row["label_hash"]}' + ) + + uploaded_label_rows.add(label_row_hash) + + if progress_callback: + progress_callback((counter + 1) / filtered_dataset.shape[0]) + + return 
DatasetCreationResult(dataset_hash, new_du_to_original) + + def create_project( + self, + dataset_creation_result: DatasetCreationResult, + project_title: str, + project_description: str, + progress_callback: Optional[Callable] = None, + ): + new_project_hash: str = self.user_client.create_project( + project_title=project_title, + dataset_hashes=[dataset_creation_result.hash], + project_description=project_description, + ontology_hash=self.original_project.get_project().ontology_hash, + ) + + new_project: Project = self.user_client.get_project(new_project_hash) + + # Copy labels from old project to new project + # Three things to copy: labels, object_answers, classification_answers + + all_new_label_rows = new_project.label_rows + for counter, new_label_row in enumerate(all_new_label_rows): + initiated_label_row: dict = new_project.create_label_row(new_label_row["data_hash"]) + original_data = dataset_creation_result.du_original_mapping[new_label_row["data_hash"]] + original_label_row = json.loads( + self.project_file_structure.label_row_structure( + original_data["label_row_hash"] + ).label_row_file.read_text( + encoding="utf-8", + ) + ) + + if initiated_label_row["data_type"] in [DataType.IMAGE.value, DataType.VIDEO.value]: + original_labels = original_label_row["data_units"][original_data["data_unit_hash"]]["labels"] + initiated_label_row["data_units"][new_label_row["data_hash"]]["labels"] = original_labels + initiated_label_row["object_answers"] = original_label_row["object_answers"] + initiated_label_row["classification_answers"] = original_label_row["classification_answers"] + + if original_labels != {}: + initiated_label_row = construct_answer_dictionaries(initiated_label_row) + new_project.save_label_row(initiated_label_row["label_hash"], initiated_label_row) + + elif initiated_label_row["data_type"] == DataType.IMG_GROUP.value: + object_hashes: set = set() + classification_hashes: set = set() + + # Currently img_groups are matched using data_title, it should be fixed after SDK update + for data_unit in initiated_label_row["data_units"].values(): + for original_data in original_label_row["data_units"].values(): + if original_data["data_hash"] == data_unit["data_title"].split(".")[0]: + data_unit["labels"] = original_data["labels"] + for obj in data_unit["labels"].get("objects", []): + object_hashes.add(obj["objectHash"]) + for classification in data_unit["labels"].get("classifications", []): + classification_hashes.add(classification["classificationHash"]) + + initiated_label_row["object_answers"] = original_label_row["object_answers"] + initiated_label_row["classification_answers"] = original_label_row["classification_answers"] + + # Remove unused object/classification answers + for object_hash in object_hashes: + initiated_label_row["object_answers"].pop(object_hash) + + for classification_hash in classification_hashes: + initiated_label_row["classification_answers"].pop(classification_hash) + + initiated_label_row = construct_answer_dictionaries(initiated_label_row) + new_project.save_label_row(initiated_label_row["label_hash"], initiated_label_row) + + # remove unused object and classification answers + + if progress_callback: + progress_callback((counter + 1) / len(all_new_label_rows)) + + return new_project + + +def _update_mapping( + user_client: EncordUserClient, new_dataset_hash: str, label_row_hash: str, data_unit_hash: str, out_mapping: dict +): + updated_dataset = user_client.get_dataset(new_dataset_hash) + for new_data_row in updated_dataset.data_rows: + if 
new_data_row["data_hash"] not in out_mapping: + out_mapping[new_data_row["data_hash"]] = { + "label_row_hash": label_row_hash, + "data_unit_hash": data_unit_hash, + } + return + + +class MissingProjectMetaAttribure(Exception): + """Exception raised when project metadata doesn't contain an attribute + + Attributes: + project_dir -- path to a project directory + """ + + def __init__(self, attribute: str, project_meta_file: Path): + self.attribute = attribute + self.project_meta_file = project_meta_file + super().__init__( + f"`{attribute}` not specified in the project meta data file `{project_meta_file.resolve().as_posix()}`" + ) + + +class DatasetUniquenessError(Exception): + """Exception raised when a dataset with the same title already exists""" + + def __init__(self, dataset_title: str): + self.dataset_title = dataset_title + super().__init__( + f"Dataset title '{dataset_title}' already exists in your list of datasets at Encord. Please use a different title." + ) diff --git a/src/encord_active/lib/encord/project.py b/src/encord_active/lib/encord/utils.py similarity index 100% rename from src/encord_active/lib/encord/project.py rename to src/encord_active/lib/encord/utils.py diff --git a/src/encord_active/lib/metrics/__init__.py b/src/encord_active/lib/metrics/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/encord_active/lib/metrics/example.py b/src/encord_active/lib/metrics/example.py index 7f2d33683..448051aed 100644 --- a/src/encord_active/lib/metrics/example.py +++ b/src/encord_active/lib/metrics/example.py @@ -1,8 +1,13 @@ from loguru import logger from encord_active.lib.common.iterator import Iterator -from encord_active.lib.common.metric import AnnotationType, DataType, Metric, MetricType -from encord_active.lib.common.writer import CSVMetricWriter +from encord_active.lib.metrics.metric import ( + AnnotationType, + DataType, + Metric, + MetricType, +) +from encord_active.lib.metrics.writer import CSVMetricWriter logger = logger.opt(colors=True) @@ -13,17 +18,17 @@ class ExampleMetric(Metric): DATA_TYPE = DataType.IMAGE ANNOTATION_TYPE = [AnnotationType.OBJECT.BOUNDING_BOX, AnnotationType.OBJECT.POLYGON] SHORT_DESCRIPTION = "Assigns same value and description to all objects." - LONG_DESCRIPTION = r"""For long descriptions, I can use Markdown to _format_ the text. - -I can, e.g., make a -[hyperlink](https://memegenerator.net/instance/74454868/europe-its-the-final-markdown) + LONG_DESCRIPTION = r"""For long descriptions, I can use Markdown to _format_ the text. + +I can, e.g., make a +[hyperlink](https://memegenerator.net/instance/74454868/europe-its-the-final-markdown) to the awesome paper that proposed the method. 
-Or use math to better explain the method:
+Or use math to better explain the method:
 $$h_{\lambda}(x) = \frac{1}{x^\intercal x}$$
 """
 
-    def test(self, iterator: Iterator, writer: CSVMetricWriter):
+    def execute(self, iterator: Iterator, writer: CSVMetricWriter):
         valid_annotation_types = {annotation_type.value for annotation_type in self.ANNOTATION_TYPE}
 
         logger.info("My custom logging")
@@ -59,7 +64,7 @@ def test(self, iterator: Iterator, writer: CSVMetricWriter):
     import sys
     from pathlib import Path
 
-    from encord_active.lib.common.tester import perform_test
+    from encord_active.lib.metrics.execute import execute_metric
 
     path = sys.argv[1]
-    perform_test(ExampleMetric(), data_dir=Path(path))
+    execute_metric(ExampleMetric(), data_dir=Path(path))
diff --git a/src/encord_active/lib/metrics/execute.py b/src/encord_active/lib/metrics/execute.py
new file mode 100644
index 000000000..c4301c22e
--- /dev/null
+++ b/src/encord_active/lib/metrics/execute.py
@@ -0,0 +1,149 @@
+import inspect
+import json
+import logging
+import os
+from enum import Enum
+from importlib import import_module
+from pathlib import Path
+from typing import Any, Callable, List, Optional, Type, Union
+
+from encord.project_ontology.object_type import ObjectShape
+from loguru import logger
+
+from encord_active.lib.common.iterator import DatasetIterator, Iterator
+from encord_active.lib.common.utils import fetch_project_info
+from encord_active.lib.common.writer import StatisticsObserver
+from encord_active.lib.metrics.metric import (
+    AnnotationType,
+    DataType,
+    Metric,
+    MetricType,
+)
+from encord_active.lib.metrics.writer import CSVMetricWriter
+
+
+def get_metrics(module: Optional[Union[str, list[str]]] = None, filter_func=lambda x: True):
+    if module is None:
+        module = ["geometric", "heuristic", "semantic"]
+    elif isinstance(module, str):
+        module = [module]
+
+    return [i for m in module for i in get_module_metrics(m, filter_func)]
+
+
+def get_module_metrics(module_name: str, filter_func: Callable) -> List:
+    if __name__ == "__main__":
+        base_module_name = ""
+    else:
+        base_module_name = __name__[: __name__.rindex(".")] + "."  # Remove this module's own name ("execute")
+
+    metrics = []
+    path = os.path.join(os.path.dirname(__file__), *module_name.split("."))
+    for file in os.listdir(path):
+        if file.endswith(".py") and not file.startswith("_") and not file.split(".")[0].endswith("_"):
+            logging.debug("Importing %s", file)
+            clsmembers = inspect.getmembers(
+                import_module(f"{base_module_name}{module_name}.{file.split('.')[0]}"), inspect.isclass
+            )
+            for cls in clsmembers:
+                if issubclass(cls[1], Metric) and cls[1] != Metric and filter_func(cls[1]):
+                    metrics.append((f"{base_module_name}{module_name}.{file.split('.')[0]}", f"{cls[0]}"))
+
+    return metrics
+
+
+def run_all_heuristic_metrics():
+    run_metrics(filter_func=lambda x: x.METRIC_TYPE == MetricType.HEURISTIC)
+
+
+def run_all_image_metrics():
+    run_metrics(filter_func=lambda x: x.DATA_TYPE == DataType.IMAGE)
+
+
+def run_all_polygon_metrics():
+    run_metrics(filter_func=lambda x: x.ANNOTATION_TYPE in [AnnotationType.OBJECT.POLYGON, AnnotationType.ALL])
+
+
+def run_all_prediction_metrics(**kwargs):
+    # Return all metrics that apply to objects.
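+    # A metric qualifies when its ANNOTATION_TYPE is an ObjectShape, or a list that
+    # contains at least one ObjectShape, i.e. the metric scores object annotations.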
+ def filter(m: Metric): + at = m.ANNOTATION_TYPE + if isinstance(at, list): + for t in at: + if isinstance(t, ObjectShape): + return True + return False + else: + return isinstance(at, ObjectShape) + + run_metrics(filter_func=filter, **kwargs) + + +def run_metrics(filter_func: Callable = lambda x: True, **kwargs): + metrics: List[Metric] = list( + map( + lambda mod_cls: import_module(mod_cls[0]).__getattribute__(mod_cls[1])(), + get_metrics(filter_func=filter_func), + ) + ) + execute_metric(metrics, **kwargs) + + +def __get_value(o): + if isinstance(o, (float, int, str)): + return o + if isinstance(o, Enum): + return __get_value(o.value) + if isinstance(o, (list, tuple)): + return [__get_value(v) for v in o] + return None + + +def __get_object_attributes(obj: Any): + metric_properties = {v.lower(): __get_value(getattr(obj, v)) for v in dir(obj)} + metric_properties = {k: v for k, v in metric_properties.items() if (v is not None or k == "annotation_type")} + return metric_properties + + +logger = logger.opt(colors=True) + + +@logger.catch() +def execute_metric( + metrics: Union[Metric, List[Metric]], + data_dir: Path, + iterator_cls: Type[Iterator] = DatasetIterator, + use_cache_only: bool = False, + **kwargs, +): + all_tests: List[Metric] = metrics if isinstance(metrics, list) else [metrics] + + project = None if use_cache_only else fetch_project_info(data_dir) + iterator = iterator_cls(data_dir, project=project, **kwargs) + cache_dir = iterator.update_cache_dir(data_dir) + + for metric in all_tests: + logger.info(f"Running Metric {metric.TITLE.title()}") + unique_metric_name = metric.get_unique_name() + + stats = StatisticsObserver() + with CSVMetricWriter(cache_dir, iterator, prefix=unique_metric_name) as writer: + writer.attach(stats) + + try: + metric.execute(iterator, writer) + except Exception as e: + logging.critical(e, exc_info=True) + + # Store meta-data about the scores. + meta_file = (cache_dir / "metrics" / f"{unique_metric_name}.meta.json").expanduser() + + with meta_file.open("w") as f: + json.dump( + { + **__get_object_attributes(metric), + **__get_object_attributes(stats), + }, + f, + indent=2, + ) diff --git a/src/encord_active/lib/metrics/geometric/annotation_duplicates.py b/src/encord_active/lib/metrics/geometric/annotation_duplicates.py index d233c34bb..6d82b56ae 100644 --- a/src/encord_active/lib/metrics/geometric/annotation_duplicates.py +++ b/src/encord_active/lib/metrics/geometric/annotation_duplicates.py @@ -1,9 +1,14 @@ from loguru import logger from encord_active.lib.common.iterator import Iterator -from encord_active.lib.common.metric import AnnotationType, DataType, Metric, MetricType from encord_active.lib.common.utils import get_iou, get_polygon -from encord_active.lib.common.writer import CSVMetricWriter +from encord_active.lib.metrics.metric import ( + AnnotationType, + DataType, + Metric, + MetricType, +) +from encord_active.lib.metrics.writer import CSVMetricWriter logger = logger.opt(colors=True) @@ -11,8 +16,8 @@ class AnnotationDuplicates(Metric): TITLE = "Annotation Duplicates" SHORT_DESCRIPTION = "Ranks annotations by how likely they are to represent the same object" - LONG_DESCRIPTION = r"""Ranks annotations by how likely they are to represent the same object. -> [Jaccard similarity coefficient](https://en.wikipedia.org/wiki/Jaccard_index) + LONG_DESCRIPTION = r"""Ranks annotations by how likely they are to represent the same object. 
+> [Jaccard similarity coefficient](https://en.wikipedia.org/wiki/Jaccard_index) is used to measure closeness of two annotations.""" METRIC_TYPE = MetricType.GEOMETRIC DATA_TYPE = DataType.IMAGE @@ -27,7 +32,7 @@ def __init__(self, threshold: float = 0.6): super().__init__() self.threshold = threshold - def test(self, iterator: Iterator, writer: CSVMetricWriter): + def execute(self, iterator: Iterator, writer: CSVMetricWriter): valid_annotation_types = {annotation_type.value for annotation_type in self.ANNOTATION_TYPE} found_any = False diff --git a/src/encord_active/lib/metrics/geometric/hu_static.py b/src/encord_active/lib/metrics/geometric/hu_static.py index c2ab74b49..ef311bdfb 100644 --- a/src/encord_active/lib/metrics/geometric/hu_static.py +++ b/src/encord_active/lib/metrics/geometric/hu_static.py @@ -5,10 +5,16 @@ from sklearn.decomposition import PCA from encord_active.lib.common.iterator import Iterator -from encord_active.lib.common.metric import AnnotationType, DataType, Metric, MetricType from encord_active.lib.common.utils import get_object_coordinates -from encord_active.lib.common.writer import CSVEmbeddingWriter, CSVMetricWriter -from encord_active.lib.embeddings.hu_embed import get_hu_embeddings +from encord_active.lib.embeddings.hu_moments import get_hu_embeddings +from encord_active.lib.embeddings.writer import CSVEmbeddingWriter +from encord_active.lib.metrics.metric import ( + AnnotationType, + DataType, + Metric, + MetricType, +) +from encord_active.lib.metrics.writer import CSVMetricWriter logger = logger.opt(colors=True) @@ -29,15 +35,15 @@ def compute_cls_distances(embeddings: np.ndarray, labels: np.ndarray) -> np.ndar class HuMomentsStatic(Metric): TITLE = "Shape outlier detection" SHORT_DESCRIPTION = "Calculates potential outliers by polygon shape." 
- LONG_DESCRIPTION = r"""Computes the Euclidean distance between the polygons' - [Hu moments](https://en.wikipedia.org/wiki/Image_moment) for each class and + LONG_DESCRIPTION = r"""Computes the Euclidean distance between the polygons' + [Hu moments](https://en.wikipedia.org/wiki/Image_moment) for each class and the prototypical class moments.""" SCORE_NORMALIZATION = True METRIC_TYPE = MetricType.GEOMETRIC DATA_TYPE = DataType.IMAGE ANNOTATION_TYPE = [AnnotationType.OBJECT.POLYGON] - def test(self, iterator: Iterator, writer: CSVMetricWriter): + def execute(self, iterator: Iterator, writer: CSVMetricWriter): valid_annotation_types = {annotation_type.value for annotation_type in self.ANNOTATION_TYPE} hu_moments_df = get_hu_embeddings(iterator, force=True) diff --git a/src/encord_active/lib/metrics/geometric/hu_temporal.py b/src/encord_active/lib/metrics/geometric/hu_temporal.py index 138ba3b91..9cd8332a6 100644 --- a/src/encord_active/lib/metrics/geometric/hu_temporal.py +++ b/src/encord_active/lib/metrics/geometric/hu_temporal.py @@ -4,9 +4,14 @@ from loguru import logger from encord_active.lib.common.iterator import Iterator -from encord_active.lib.common.metric import AnnotationType, DataType, Metric, MetricType -from encord_active.lib.common.writer import CSVMetricWriter -from encord_active.lib.embeddings.hu_embed import get_hu_embeddings +from encord_active.lib.embeddings.hu_moments import get_hu_embeddings +from encord_active.lib.metrics.metric import ( + AnnotationType, + DataType, + Metric, + MetricType, +) +from encord_active.lib.metrics.writer import CSVMetricWriter logger = logger.opt(colors=True) @@ -43,15 +48,15 @@ def score(self, key: str, moments: np.ndarray): class HuMomentsTemporalMetric(Metric): TITLE = "Polygon Shape Similarity" SHORT_DESCRIPTION = "Ranks objects by how similar they are to their instances in previous frames." - LONG_DESCRIPTION = r"""Ranks objects by how similar they are to their instances in previous frames -based on [Hu moments](https://en.wikipedia.org/wiki/Image_moment). The more an object's shape changes, + LONG_DESCRIPTION = r"""Ranks objects by how similar they are to their instances in previous frames +based on [Hu moments](https://en.wikipedia.org/wiki/Image_moment). 
The more an object's shape changes, the lower its score will be.""" SCORE_NORMALIZATION = True METRIC_TYPE = MetricType.GEOMETRIC DATA_TYPE = DataType.SEQUENCE ANNOTATION_TYPE = [AnnotationType.OBJECT.POLYGON] - def test(self, iterator: Iterator, writer: CSVMetricWriter): + def execute(self, iterator: Iterator, writer: CSVMetricWriter): valid_annotation_types = {annotation_type.value for annotation_type in self.ANNOTATION_TYPE} found_any = False diff --git a/src/encord_active/lib/metrics/geometric/image_border_closeness.py b/src/encord_active/lib/metrics/geometric/image_border_closeness.py index e36b4de30..4618c03a1 100644 --- a/src/encord_active/lib/metrics/geometric/image_border_closeness.py +++ b/src/encord_active/lib/metrics/geometric/image_border_closeness.py @@ -1,9 +1,14 @@ from loguru import logger from encord_active.lib.common.iterator import Iterator -from encord_active.lib.common.metric import AnnotationType, DataType, Metric, MetricType from encord_active.lib.common.utils import get_object_coordinates -from encord_active.lib.common.writer import CSVMetricWriter +from encord_active.lib.metrics.metric import ( + AnnotationType, + DataType, + Metric, + MetricType, +) +from encord_active.lib.metrics.writer import CSVMetricWriter logger = logger.opt(colors=True) @@ -22,7 +27,7 @@ class ImageBorderCloseness(Metric): AnnotationType.OBJECT.SKELETON, ] - def test(self, iterator: Iterator, writer: CSVMetricWriter): + def execute(self, iterator: Iterator, writer: CSVMetricWriter): valid_annotation_types = {annotation_type.value for annotation_type in self.ANNOTATION_TYPE} found_any = False diff --git a/src/encord_active/lib/metrics/geometric/object_size.py b/src/encord_active/lib/metrics/geometric/object_size.py index 6eff7b8ba..38bdc9d6f 100644 --- a/src/encord_active/lib/metrics/geometric/object_size.py +++ b/src/encord_active/lib/metrics/geometric/object_size.py @@ -2,13 +2,18 @@ from shapely.ops import unary_union from encord_active.lib.common.iterator import Iterator -from encord_active.lib.common.metric import AnnotationType, DataType, Metric, MetricType from encord_active.lib.common.utils import ( get_bbox_from_encord_label_object, get_du_size, get_polygon, ) -from encord_active.lib.common.writer import CSVMetricWriter +from encord_active.lib.metrics.metric import ( + AnnotationType, + DataType, + Metric, + MetricType, +) +from encord_active.lib.metrics.writer import CSVMetricWriter logger = logger.opt(colors=True) @@ -35,7 +40,7 @@ class RelativeObjectAreaMetric(Metric): AnnotationType.OBJECT.POLYGON, ] - def test(self, iterator: Iterator, writer: CSVMetricWriter): + def execute(self, iterator: Iterator, writer: CSVMetricWriter): valid_annotation_types = {annotation_type.value for annotation_type in self.ANNOTATION_TYPE} found_any = False @@ -65,7 +70,7 @@ class OccupiedTotalAreaMetric(Metric): AnnotationType.OBJECT.POLYGON, ] - def test(self, iterator: Iterator, writer: CSVMetricWriter): + def execute(self, iterator: Iterator, writer: CSVMetricWriter): valid_annotation_types = {annotation_type.value for annotation_type in self.ANNOTATION_TYPE} found_any = False @@ -101,7 +106,7 @@ class AbsoluteObjectAreaMetric(Metric): AnnotationType.OBJECT.POLYGON, ] - def test(self, iterator: Iterator, writer: CSVMetricWriter): + def execute(self, iterator: Iterator, writer: CSVMetricWriter): valid_annotation_types = {annotation_type.value for annotation_type in self.ANNOTATION_TYPE} found_any = False @@ -137,7 +142,7 @@ class ObjectAspectRatioMetric(Metric): AnnotationType.OBJECT.POLYGON, ] - 
def test(self, iterator: Iterator, writer: CSVMetricWriter): + def execute(self, iterator: Iterator, writer: CSVMetricWriter): valid_annotation_types = {annotation_type.value for annotation_type in self.ANNOTATION_TYPE} found_any = False diff --git a/src/encord_active/lib/metrics/geometric/occlusion_detection_video.py b/src/encord_active/lib/metrics/geometric/occlusion_detection_video.py index 361cbdf95..a67c12af5 100644 --- a/src/encord_active/lib/metrics/geometric/occlusion_detection_video.py +++ b/src/encord_active/lib/metrics/geometric/occlusion_detection_video.py @@ -4,8 +4,13 @@ from tqdm import tqdm from encord_active.lib.common.iterator import Iterator -from encord_active.lib.common.metric import AnnotationType, DataType, Metric, MetricType -from encord_active.lib.common.writer import CSVMetricWriter +from encord_active.lib.metrics.metric import ( + AnnotationType, + DataType, + Metric, + MetricType, +) +from encord_active.lib.metrics.writer import CSVMetricWriter logger = logger.opt(colors=True) @@ -44,7 +49,7 @@ def get_description_from_occlusion(self, distance: float) -> str: else: return "There is no occlusion" - def test(self, iterator: Iterator, writer: CSVMetricWriter): + def execute(self, iterator: Iterator, writer: CSVMetricWriter): valid_annotation_types = {annotation_type.value for annotation_type in self.ANNOTATION_TYPE} videos: dict[str, dict[str, dict]] = {} diff --git a/src/encord_active/lib/metrics/heuristic/_annotation_time.py b/src/encord_active/lib/metrics/heuristic/_annotation_time.py index f13aaaf6e..ebb285bcb 100644 --- a/src/encord_active/lib/metrics/heuristic/_annotation_time.py +++ b/src/encord_active/lib/metrics/heuristic/_annotation_time.py @@ -1,8 +1,13 @@ from loguru import logger from encord_active.lib.common.iterator import Iterator -from encord_active.lib.common.metric import AnnotationType, DataType, Metric, MetricType -from encord_active.lib.common.writer import CSVMetricWriter +from encord_active.lib.metrics.metric import ( + AnnotationType, + DataType, + Metric, + MetricType, +) +from encord_active.lib.metrics.writer import CSVMetricWriter logger = logger.opt(colors=True) @@ -18,7 +23,7 @@ class AnnotationTimeMetric(Metric): ANNOTATION_TYPE = AnnotationType.ALL SCORE_NORMALIZATION = True - def test(self, iterator: Iterator, writer: CSVMetricWriter): + def execute(self, iterator: Iterator, writer: CSVMetricWriter): found_any = False for data_unit, img_pth in iterator.iterate(desc="Computing annotation times"): diff --git a/src/encord_active/lib/metrics/heuristic/high_iou_changing_classes.py b/src/encord_active/lib/metrics/heuristic/high_iou_changing_classes.py index b0d369d70..4e698ac4f 100644 --- a/src/encord_active/lib/metrics/heuristic/high_iou_changing_classes.py +++ b/src/encord_active/lib/metrics/heuristic/high_iou_changing_classes.py @@ -1,9 +1,14 @@ from loguru import logger from encord_active.lib.common.iterator import Iterator -from encord_active.lib.common.metric import AnnotationType, DataType, Metric, MetricType from encord_active.lib.common.utils import get_iou, get_polygon -from encord_active.lib.common.writer import CSVMetricWriter +from encord_active.lib.metrics.metric import ( + AnnotationType, + DataType, + Metric, + MetricType, +) +from encord_active.lib.metrics.writer import CSVMetricWriter logger = logger.opt(colors=True) @@ -14,12 +19,12 @@ class HighIOUChangingClasses(Metric): DATA_TYPE = DataType.SEQUENCE ANNOTATION_TYPE = [AnnotationType.OBJECT.BOUNDING_BOX, AnnotationType.OBJECT.POLYGON] SHORT_DESCRIPTION = "Looks for 
overlapping objects with different classes (across frames)." - LONG_DESCRIPTION = r"""This algorithm looks for overlapping objects in consecutive -frames that have different classes. Furthermore, if classes are the same for overlapping objects but have different + LONG_DESCRIPTION = r"""This algorithm looks for overlapping objects in consecutive +frames that have different classes. Furthermore, if classes are the same for overlapping objects but have different track-ids, they will be flagged as potential inconsistencies in tracks. -**Example 1:** +**Example 1:** ``` Frame 1 Frame 2 β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” @@ -34,7 +39,7 @@ class HighIOUChangingClasses(Metric): ``` `Dog:1` will be flagged as potentially wrong class, because it overlaps with `CAT:1`. -**Example 2:** +**Example 2:** ``` Frame 1 Frame 2 β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” @@ -55,7 +60,7 @@ def __init__(self, threshold: float = 0.8): super(HighIOUChangingClasses, self).__init__() self.threshold = threshold - def test(self, iterator: Iterator, writer: CSVMetricWriter): + def execute(self, iterator: Iterator, writer: CSVMetricWriter): valid_annotation_types = {annotation_type.value for annotation_type in self.ANNOTATION_TYPE} found_any = False found_valid = False diff --git a/src/encord_active/lib/metrics/heuristic/img_features.py b/src/encord_active/lib/metrics/heuristic/img_features.py index e15b01ca2..5a79f7ab2 100644 --- a/src/encord_active/lib/metrics/heuristic/img_features.py +++ b/src/encord_active/lib/metrics/heuristic/img_features.py @@ -4,8 +4,13 @@ import numpy as np from encord_active.lib.common.iterator import Iterator -from encord_active.lib.common.metric import AnnotationType, DataType, Metric, MetricType -from encord_active.lib.common.writer import CSVMetricWriter +from encord_active.lib.metrics.metric import ( + AnnotationType, + DataType, + Metric, + MetricType, +) +from encord_active.lib.metrics.writer import CSVMetricWriter def iterate_with_rank_fn( @@ -29,9 +34,9 @@ def iterate_with_rank_fn( class ContrastMetric(Metric): TITLE = "Contrast" SHORT_DESCRIPTION = "Ranks images by their contrast." - LONG_DESCRIPTION = r"""Ranks images by their contrast. + LONG_DESCRIPTION = r"""Ranks images by their contrast. -Contrast is computed as the standard deviation of the pixel values. +Contrast is computed as the standard deviation of the pixel values. 
""" SCORE_NORMALIZATION = True METRIC_TYPE = MetricType.HEURISTIC @@ -42,7 +47,7 @@ class ContrastMetric(Metric): def rank_by_contrast(image): return image.std() / 255 - def test(self, iterator: Iterator, writer: CSVMetricWriter): + def execute(self, iterator: Iterator, writer: CSVMetricWriter): return iterate_with_rank_fn(iterator, writer, self.rank_by_contrast, self.TITLE) @@ -130,7 +135,7 @@ def rank_by_hsv_filtering(self, image): return ratio - def test(self, iterator: Iterator, writer: CSVMetricWriter): + def execute(self, iterator: Iterator, writer: CSVMetricWriter): return iterate_with_rank_fn( iterator, writer, self.rank_by_hsv_filtering, self.TITLE, color_space=cv2.COLOR_BGR2HSV ) @@ -171,7 +176,7 @@ class BrightnessMetric(Metric): def rank_by_brightness(image): return image.mean() / 255 - def test(self, iterator: Iterator, writer: CSVMetricWriter): + def execute(self, iterator: Iterator, writer: CSVMetricWriter): return iterate_with_rank_fn(iterator, writer, self.rank_by_brightness, self.TITLE) @@ -181,7 +186,7 @@ class SharpnessMetric(Metric): LONG_DESCRIPTION = r"""Ranks images by their sharpness. Sharpness is computed by applying a Laplacian filter to each image and computing the -variance of the output. In short, the score computes "the amount of edges" in each +variance of the output. In short, the score computes "the amount of edges" in each image. ```python @@ -197,7 +202,7 @@ class SharpnessMetric(Metric): def rank_by_sharpness(image): return cv2.Laplacian(image, cv2.CV_64F).var() - def test(self, iterator: Iterator, writer: CSVMetricWriter): + def execute(self, iterator: Iterator, writer: CSVMetricWriter): return iterate_with_rank_fn(iterator, writer, self.rank_by_sharpness, self.TITLE) @@ -207,7 +212,7 @@ class BlurMetric(Metric): LONG_DESCRIPTION = r"""Ranks images by their blurriness. Blurriness is computed by applying a Laplacian filter to each image and computing the -variance of the output. In short, the score computes "the amount of edges" in each +variance of the output. In short, the score computes "the amount of edges" in each image. Note that this is $1 - \text{sharpness}$. ```python @@ -223,16 +228,16 @@ class BlurMetric(Metric): def rank_by_blur(image): return 1 - cv2.Laplacian(image, cv2.CV_64F).var() - def test(self, iterator: Iterator, writer: CSVMetricWriter): + def execute(self, iterator: Iterator, writer: CSVMetricWriter): return iterate_with_rank_fn(iterator, writer, self.rank_by_blur, self.TITLE) class AspectRatioMetric(Metric): TITLE = "Aspect Ratio" SHORT_DESCRIPTION = "Ranks images by their aspect ratio (width/height)." - LONG_DESCRIPTION = r"""Ranks images by their aspect ratio (width/height). + LONG_DESCRIPTION = r"""Ranks images by their aspect ratio (width/height). -Aspect ratio is computed as the ratio of image width to image height. +Aspect ratio is computed as the ratio of image width to image height. """ METRIC_TYPE = MetricType.HEURISTIC DATA_TYPE = DataType.IMAGE @@ -242,16 +247,16 @@ class AspectRatioMetric(Metric): def rank_by_aspect_ratio(image): return image.shape[1] / image.shape[0] - def test(self, iterator: Iterator, writer: CSVMetricWriter): + def execute(self, iterator: Iterator, writer: CSVMetricWriter): return iterate_with_rank_fn(iterator, writer, self.rank_by_aspect_ratio, self.TITLE) class AreaMetric(Metric): TITLE = "Area" SHORT_DESCRIPTION = "Ranks images by their area (width*height)." - LONG_DESCRIPTION = r"""Ranks images by their area (width*height). + LONG_DESCRIPTION = r"""Ranks images by their area (width*height). 
-Area is computed as the product of image width and image height. +Area is computed as the product of image width and image height. """ METRIC_TYPE = MetricType.HEURISTIC DATA_TYPE = DataType.IMAGE @@ -261,5 +266,5 @@ class AreaMetric(Metric): def rank_by_area(image): return image.shape[0] * image.shape[1] - def test(self, iterator: Iterator, writer: CSVMetricWriter): + def execute(self, iterator: Iterator, writer: CSVMetricWriter): return iterate_with_rank_fn(iterator, writer, self.rank_by_area, self.TITLE) diff --git a/src/encord_active/lib/metrics/heuristic/missing_objects_and_wrong_tracks.py b/src/encord_active/lib/metrics/heuristic/missing_objects_and_wrong_tracks.py index 15c25b3ef..e629a3cfa 100644 --- a/src/encord_active/lib/metrics/heuristic/missing_objects_and_wrong_tracks.py +++ b/src/encord_active/lib/metrics/heuristic/missing_objects_and_wrong_tracks.py @@ -8,9 +8,14 @@ from shapely.geometry import Polygon from encord_active.lib.common.iterator import Iterator -from encord_active.lib.common.metric import AnnotationType, DataType, Metric, MetricType from encord_active.lib.common.utils import get_iou, get_polygon -from encord_active.lib.common.writer import CSVMetricWriter +from encord_active.lib.metrics.metric import ( + AnnotationType, + DataType, + Metric, + MetricType, +) +from encord_active.lib.metrics.writer import CSVMetricWriter logger = logger.opt(colors=True) @@ -36,11 +41,11 @@ class MissingObjectsMetric(Metric): DATA_TYPE = DataType.SEQUENCE ANNOTATION_TYPE = [AnnotationType.OBJECT.BOUNDING_BOX, AnnotationType.OBJECT.POLYGON] SHORT_DESCRIPTION = "Identifies missing objects and broken tracks based on object overlaps." - LONG_DESCRIPTION = r"""Identifies missing objects by comparing object overlaps based -on a running window. + LONG_DESCRIPTION = r"""Identifies missing objects by comparing object overlaps based +on a running window. -**Case 1:** -If an intermediate frame (frame $i$) doesn't include an object in the same +**Case 1:** +If an intermediate frame (frame $i$) doesn't include an object in the same region, as the two surrounding framge ($i-1$ and $i+1$), it is flagged. ``` @@ -58,7 +63,7 @@ class MissingObjectsMetric(Metric): ``` Frame $i$ will be flagged as potentially missing an object. -**Case 2:** +**Case 2:** If objects of the same class overlap in three consecutive frames ($i-1$, $i$, and $i+1$) but do not share object hash, they will be flagged as a potentially broken track. 
@@ -82,7 +87,7 @@ def __init__(self, threshold: float = 0.5): super().__init__() self.threshold = threshold - def test(self, iterator: Iterator, writer: CSVMetricWriter): + def execute(self, iterator: Iterator, writer: CSVMetricWriter): valid_annotation_types = {annotation_type.value for annotation_type in self.ANNOTATION_TYPE} found_any = False found_valid = False diff --git a/src/encord_active/lib/metrics/heuristic/object_counting.py b/src/encord_active/lib/metrics/heuristic/object_counting.py index 6c71eff7d..c5e01246d 100644 --- a/src/encord_active/lib/metrics/heuristic/object_counting.py +++ b/src/encord_active/lib/metrics/heuristic/object_counting.py @@ -1,6 +1,11 @@ from encord_active.lib.common.iterator import Iterator -from encord_active.lib.common.metric import AnnotationType, DataType, Metric, MetricType -from encord_active.lib.common.writer import CSVMetricWriter +from encord_active.lib.metrics.metric import ( + AnnotationType, + DataType, + Metric, + MetricType, +) +from encord_active.lib.metrics.writer import CSVMetricWriter class ObjectsCountMetric(Metric): @@ -11,7 +16,7 @@ class ObjectsCountMetric(Metric): DATA_TYPE = DataType.IMAGE ANNOTATION_TYPE = AnnotationType.ALL - def test(self, iterator: Iterator, writer: CSVMetricWriter): + def execute(self, iterator: Iterator, writer: CSVMetricWriter): for data_unit, img_pth in iterator.iterate(desc="Counting objects"): score = len(data_unit["labels"]["objects"]) if "objects" in data_unit["labels"] else 0 writer.write(score) diff --git a/src/encord_active/lib/common/metric.py b/src/encord_active/lib/metrics/metric.py similarity index 90% rename from src/encord_active/lib/common/metric.py rename to src/encord_active/lib/metrics/metric.py index 6c257b9c8..d0b382bf6 100644 --- a/src/encord_active/lib/common/metric.py +++ b/src/encord_active/lib/metrics/metric.py @@ -1,13 +1,13 @@ from abc import ABC, abstractmethod from enum import Enum from hashlib import md5 -from typing import List, Optional, Union +from typing import List, Optional, TypedDict, Union from encord.project_ontology.classification_type import ClassificationType from encord.project_ontology.object_type import ObjectShape from encord_active.lib.common.iterator import Iterator -from encord_active.lib.common.writer import CSVMetricWriter +from encord_active.lib.metrics.writer import CSVMetricWriter class MetricType(Enum): @@ -35,9 +35,25 @@ class EmbeddingType(Enum): NONE = "none" +class MetricMetadata(TypedDict): + annotation_type: Optional[List[Union[ObjectShape, ClassificationType]]] + data_type: DataType + long_description: str + metric_type: MetricType + needs_images: bool + score_normalization: bool + short_description: str + title: str + threshold: float + max_value: Optional[float] + mean_value: float + min_value: Optional[float] + num_rows: int + + class Metric(ABC): @abstractmethod - def test(self, iterator: Iterator, writer: CSVMetricWriter): + def execute(self, iterator: Iterator, writer: CSVMetricWriter): """ This is where you should perform your data indexing. 
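With the abstract method renamed from `test` to `execute` and the base class now living under `encord_active.lib.metrics.metric` (with the writer under `encord_active.lib.metrics.writer`), any custom metric written against the old `encord_active.lib.common` paths needs the same two changes. A minimal sketch of what an updated metric might look like; the class name, title and scoring rule are hypothetical and only mirror the attribute / `iterate` / `write` pattern used by the metrics touched in this patch:

```python
from encord_active.lib.common.iterator import Iterator
from encord_active.lib.metrics.metric import (
    AnnotationType,
    DataType,
    Metric,
    MetricType,
)
from encord_active.lib.metrics.writer import CSVMetricWriter


class ExampleLabelCountMetric(Metric):
    # Hypothetical metric: scores each frame by how many objects it contains.
    TITLE = "Example label count"
    SHORT_DESCRIPTION = "Counts the objects in each frame."
    LONG_DESCRIPTION = "Counts the objects in each frame."
    METRIC_TYPE = MetricType.HEURISTIC
    DATA_TYPE = DataType.IMAGE
    ANNOTATION_TYPE = AnnotationType.ALL

    def execute(self, iterator: Iterator, writer: CSVMetricWriter):  # was `test` before this patch
        for data_unit, _ in iterator.iterate(desc="Counting objects"):
            objects = data_unit.get("labels", {}).get("objects", [])
            writer.write(len(objects))
```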
diff --git a/src/encord_active/lib/metrics/run_all.py b/src/encord_active/lib/metrics/run_all.py deleted file mode 100644 index b1166c7e3..000000000 --- a/src/encord_active/lib/metrics/run_all.py +++ /dev/null @@ -1,77 +0,0 @@ -import inspect -import logging -import os -from importlib import import_module -from typing import Callable, List, Optional, Union - -from encord.project_ontology.object_type import ObjectShape - -from encord_active.lib.common.metric import AnnotationType, DataType, Metric, MetricType -from encord_active.lib.common.tester import perform_test - - -def get_metrics(module: Optional[Union[str, list[str]]] = None, filter_func=lambda x: True): - if module is None: - module = ["geometric", "heuristic", "semantic"] - elif isinstance(module, str): - module = [module] - - return [i for m in module for i in get_module_metrics(m, filter_func)] - - -def get_module_metrics(module_name: str, filter_func: Callable) -> List: - if __name__ == "__main__": - base_module_name = "" - else: - base_module_name = __name__[: __name__.rindex(".")] + "." # Remove "run_all" - - metrics = [] - path = os.path.join(os.path.dirname(__file__), *module_name.split(".")) - for file in os.listdir(path): - if file.endswith(".py") and not file.startswith("_") and not file.split(".")[0].endswith("_"): - logging.debug("Importing %s", file) - clsmembers = inspect.getmembers( - import_module(f"{base_module_name}{module_name}.{file.split('.')[0]}"), inspect.isclass - ) - for cls in clsmembers: - if issubclass(cls[1], Metric) and cls[1] != Metric and filter_func(cls[1]): - metrics.append((f"{base_module_name}{module_name}.{file.split('.')[0]}", f"{cls[0]}")) - - return metrics - - -def run_all_heuristic_metrics(): - run_metrics(filter_func=lambda x: x.METRIC_TYPE == MetricType.HEURISTIC) - - -def run_all_image_metrics(): - run_metrics(filter_func=lambda x: x.DATA_TYPE == DataType.IMAGE) - - -def run_all_polygon_metrics(): - run_metrics(filter_func=lambda x: x.ANNOTATION_TYPE in [AnnotationType.OBJECT.POLYGON, AnnotationType.ALL]) - - -def run_all_prediction_metrics(**kwargs): - # Return all metrics that apply to objects. 
- def filter(m: Metric): - at = m.ANNOTATION_TYPE - if isinstance(at, list): - for t in at: - if isinstance(t, ObjectShape): - return True - return False - else: - return isinstance(at, ObjectShape) - - run_metrics(filter_func=filter, **kwargs) - - -def run_metrics(filter_func: Callable = lambda x: True, **kwargs): - metrics: List[Metric] = list( - map( - lambda mod_cls: import_module(mod_cls[0]).__getattribute__(mod_cls[1])(), - get_metrics(filter_func=filter_func), - ) - ) - perform_test(metrics, **kwargs) diff --git a/src/encord_active/lib/metrics/semantic/_class_uncertainty.py b/src/encord_active/lib/metrics/semantic/_class_uncertainty.py index d1713c7f7..aba59f5a6 100644 --- a/src/encord_active/lib/metrics/semantic/_class_uncertainty.py +++ b/src/encord_active/lib/metrics/semantic/_class_uncertainty.py @@ -10,9 +10,14 @@ from torch.nn import LeakyReLU from encord_active.lib.common.iterator import Iterator -from encord_active.lib.common.metric import AnnotationType, DataType, Metric, MetricType -from encord_active.lib.common.writer import CSVMetricWriter -from encord_active.lib.embeddings.cnn_embed import get_cnn_embeddings +from encord_active.lib.embeddings.cnn import get_cnn_embeddings +from encord_active.lib.metrics.metric import ( + AnnotationType, + DataType, + Metric, + MetricType, +) +from encord_active.lib.metrics.writer import CSVMetricWriter logger = logging.getLogger(__name__) DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -192,7 +197,7 @@ class EntropyMetric(Metric): LONG_DESCRIPTION = r"""Uses the entropy of the distribution over labels from a lightweight classifier neural network and Monte-Carlo Dropout to estimate the uncertainty of the label. """ - def test(self, iterator: Iterator, writer: CSVMetricWriter): + def execute(self, iterator: Iterator, writer: CSVMetricWriter): batches, classifier, name_to_idx, resnet_embeddings_df = preliminaries(iterator) with classifier.mc_eval() and torch.inference_mode(): pbar = tqdm.tqdm(total=len(resnet_embeddings_df), desc="Predicting uncertainty") @@ -220,7 +225,7 @@ class ConfidenceScoreMetric(Metric): SHORT_DESCRIPTION = "Estimates the confidence of the assigned label." 
LONG_DESCRIPTION = r"""Estimates the confidence of the assigned label as the probability of the assigned label.""" - def test(self, iterator: Iterator, writer: CSVMetricWriter): + def execute(self, iterator: Iterator, writer: CSVMetricWriter): batches, classifier, name_to_idx, resnet_embeddings_df = preliminaries(iterator) with classifier.mc_eval() and torch.inference_mode(): diff --git a/src/encord_active/lib/metrics/semantic/_heatmap_uncertainty.py b/src/encord_active/lib/metrics/semantic/_heatmap_uncertainty.py index 795dde682..b8486ace7 100644 --- a/src/encord_active/lib/metrics/semantic/_heatmap_uncertainty.py +++ b/src/encord_active/lib/metrics/semantic/_heatmap_uncertainty.py @@ -12,9 +12,14 @@ from torchvision.models.segmentation import DeepLabV3_MobileNet_V3_Large_Weights from encord_active.lib.common.iterator import Iterator -from encord_active.lib.common.metric import AnnotationType, DataType, Metric, MetricType -from encord_active.lib.common.writer import CSVMetricWriter +from encord_active.lib.metrics.metric import ( + AnnotationType, + DataType, + Metric, + MetricType, +) from encord_active.lib.metrics.semantic._class_uncertainty import train_test_split +from encord_active.lib.metrics.writer import CSVMetricWriter DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -201,7 +206,7 @@ class EntropyHeatmapMetric(Metric): ) LONG_DESCRIPTION = r"""""" - def test(self, iterator: Iterator, writer: CSVMetricWriter): + def execute(self, iterator: Iterator, writer: CSVMetricWriter): model_path = os.path.join(iterator.cache_dir, "models", f"{Path(__file__).stem}_model.pt") os.makedirs(os.path.dirname(model_path), exist_ok=True) diff --git a/src/encord_active/lib/metrics/semantic/img_classification_quality.py b/src/encord_active/lib/metrics/semantic/img_classification_quality.py index d143c98bb..97446b9fb 100644 --- a/src/encord_active/lib/metrics/semantic/img_classification_quality.py +++ b/src/encord_active/lib/metrics/semantic/img_classification_quality.py @@ -10,15 +10,15 @@ from loguru import logger from encord_active.lib.common.iterator import Iterator -from encord_active.lib.common.metric import ( +from encord_active.lib.embeddings.cnn import get_cnn_embeddings +from encord_active.lib.metrics.metric import ( AnnotationType, DataType, EmbeddingType, Metric, MetricType, ) -from encord_active.lib.common.writer import CSVMetricWriter -from encord_active.lib.embeddings.cnn_embed import get_cnn_embeddings +from encord_active.lib.metrics.writer import CSVMetricWriter logger = logger.opt(colors=True) DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -249,7 +249,7 @@ def setup(self, iterator: Iterator) -> bool: return found_any - def test(self, iterator: Iterator, writer: CSVMetricWriter): + def execute(self, iterator: Iterator, writer: CSVMetricWriter): project_has_classifications = self.setup(iterator) if not project_has_classifications: logger.info("[Skipping] No frame level classifications in the project ontology.") diff --git a/src/encord_active/lib/metrics/semantic/img_object_quality.py b/src/encord_active/lib/metrics/semantic/img_object_quality.py index bafa33c96..9f9707d93 100644 --- a/src/encord_active/lib/metrics/semantic/img_object_quality.py +++ b/src/encord_active/lib/metrics/semantic/img_object_quality.py @@ -7,18 +7,18 @@ from tqdm import tqdm from encord_active.lib.common.iterator import Iterator -from encord_active.lib.common.metric import ( +from encord_active.lib.common.utils import ( + 
fix_duplicate_image_orders_in_knn_graph_all_rows, +) +from encord_active.lib.embeddings.cnn import get_cnn_embeddings +from encord_active.lib.metrics.metric import ( AnnotationType, DataType, EmbeddingType, Metric, MetricType, ) -from encord_active.lib.common.utils import ( - fix_duplicate_image_orders_in_knn_graph_all_rows, -) -from encord_active.lib.common.writer import CSVMetricWriter -from encord_active.lib.embeddings.cnn_embed import get_cnn_embeddings +from encord_active.lib.metrics.writer import CSVMetricWriter logger = logging.getLogger(__name__) @@ -26,7 +26,7 @@ class ObjectEmbeddingSimilarityTest(Metric): TITLE = "Object Annotation Quality" SHORT_DESCRIPTION = "Compares object annotations against similar image crops" - LONG_DESCRIPTION = r"""This metric transforms polygons into bounding boxes + LONG_DESCRIPTION = r"""This metric transforms polygons into bounding boxes and an embedding for each bounding box is extracted. Then, these embeddings are compared with their neighbors. If the neighbors are annotated differently, a low score is given to it. """ @@ -102,7 +102,7 @@ def unpack_collections(self, collections: list) -> None: def get_identifier_from_collection_item(self, item): return f'{item["label_row"]}_{item["data_unit"]}_{item["frame"]:05d}_{item["objectHash"]}' - def test(self, iterator: Iterator, writer: CSVMetricWriter): + def execute(self, iterator: Iterator, writer: CSVMetricWriter): ontology_contains_objects = self.setup(iterator) if not ontology_contains_objects: logger.info("[Skipping] No objects in the project ontology.") diff --git a/src/encord_active/lib/metrics/utils.py b/src/encord_active/lib/metrics/utils.py new file mode 100644 index 000000000..0dc6753ef --- /dev/null +++ b/src/encord_active/lib/metrics/utils.py @@ -0,0 +1,139 @@ +from dataclasses import dataclass +from enum import Enum +from pathlib import Path +from typing import Dict, List, Optional, TypedDict, Union + +import pandas as pd +import pandera as pa +from natsort import natsorted +from pandera.typing import DataFrame, Series + +from encord_active.lib.common.utils import load_json +from encord_active.lib.metrics.metric import MetricMetadata + + +@dataclass +class MetricData: + name: str + path: Path + meta: MetricMetadata + level: str + + +class IdentifierSchema(pa.SchemaModel): + identifier: Series[str] = pa.Field() + + +class MetricSchema(IdentifierSchema): + score: Series[float] = pa.Field(coerce=True) + identifier: Series[str] = pa.Field() + description: Series[str] = pa.Field(nullable=True, coerce=True) + object_class: Series[str] = pa.Field(nullable=True, coerce=True) + annotator: Series[str] = pa.Field(nullable=True, coerce=True) + frame: Series[int] = pa.Field() + url: Series[str] = pa.Field(nullable=True, coerce=True) + + +def load_metric_dataframe(metric: MetricData, normalize: bool, *, sorting_key="score") -> DataFrame[MetricSchema]: + """ + Load and sort the selected csv file and cache it, so we don't need to perform this + heavy computation each time the slider in the UI is moved. + :param metric: The metric to load data from. + :param normalize: whether to apply normalisation to the scores or not. + :param sorting_key: key by which to sort dataframe (default: "score") + :return: a pandas data frame with all the scores. 
+ """ + df = pd.read_csv(metric.path).sort_values([sorting_key, "identifier"], ascending=True).reset_index() + + if normalize: + min_val = metric.meta.get("min_value") + max_val = metric.meta.get("max_value") + if min_val is None: + min_val = df["score"].min() + if max_val is None: + max_val = df["score"].max() + + diff = max_val - min_val + if diff == 0: # Avoid dividing by zero + diff = 1.0 + + df["score"] = (df["score"] - min_val) / diff + + return df.pipe(DataFrame[MetricSchema]) + + +class MetricScope(Enum): + DATA_QUALITY = "data_quality" + LABEL_QUALITY = "label_quality" + MODEL_QUALITY = "model_quality" + + +def get_metric_operation_level(pth: Path) -> str: + if not all([pth.exists(), pth.is_file(), pth.suffix == ".csv"]): + return "" + + with pth.open("r", encoding="utf-8") as f: + _ = f.readline() # Header, which we don't care about + csv_row = f.readline() # Content line + + if not csv_row: # Empty metric + return "" + + key, _ = csv_row.split(",", 1) + _, _, _, *object_hashes = key.split("_") + return "O" if object_hashes else "F" + + +def is_valid_annotaion_type(annotaion_type: Union[None, List[str]], metric_scope: Optional[MetricScope] = None) -> bool: + if metric_scope == MetricScope.DATA_QUALITY: + return annotaion_type is None + elif metric_scope == MetricScope.LABEL_QUALITY: + return isinstance(annotaion_type, list) + else: + return True + + +def load_available_metrics(metric_dir: Path, metric_scope: Optional[MetricScope] = None) -> List[MetricData]: + if not metric_dir.is_dir(): + return [] + + paths = natsorted([p for p in metric_dir.iterdir() if p.suffix == ".csv"], key=lambda x: x.stem.split("_", 1)[1]) + levels = list(map(get_metric_operation_level, paths)) + + make_name = lambda p: p.name.split("_", 1)[1].rsplit(".", 1)[0].replace("_", " ").title() + names = [f"{make_name(p)}" for p, l in zip(paths, levels)] + meta_data = [load_json(f.with_suffix(".meta.json")) for f in paths] + + out: List[MetricData] = [] + + if not meta_data: + return out + + for p, n, m, l in zip(paths, names, meta_data, levels): + if m is None or not l or not is_valid_annotaion_type(m.get("annotation_type"), metric_scope): + + continue + + out.append(MetricData(name=n, path=p, meta=MetricMetadata(**m), level=l)) # type: ignore + + out = natsorted(out, key=lambda i: (i.level, i.name)) # type: ignore + return out + + +class AnnotatorInfo(TypedDict): + name: str + total_annotations: int + mean_score: float + + +def get_annotator_level_info(df: DataFrame[MetricSchema]) -> dict[str, AnnotatorInfo]: + annotator_set: List[str] = natsorted(list(df[MetricSchema.annotator].unique())) + annotators: Dict[str, AnnotatorInfo] = {} + for annotator in annotator_set: + annotators[annotator] = AnnotatorInfo( + name=annotator, + total_annotations=df[df[MetricSchema.annotator] == annotator].shape[0], + mean_score=df[df[MetricSchema.annotator] == annotator]["score"].mean(), + ) + + return annotators diff --git a/src/encord_active/lib/metrics/writer.py b/src/encord_active/lib/metrics/writer.py new file mode 100644 index 000000000..3452fe47e --- /dev/null +++ b/src/encord_active/lib/metrics/writer.py @@ -0,0 +1,70 @@ +import csv +from pathlib import Path +from typing import Optional, Union + +from encord_active.lib.common.iterator import Iterator +from encord_active.lib.common.writer import CSVWriter + +SCORE_CSV_FIELDS = ["identifier", "score", "description", "object_class", "annotator", "frame", "url"] + + +class CSVMetricWriter(CSVWriter): + def __init__(self, data_path: Path, iterator: Iterator, prefix: str): + 
filename = (data_path / "metrics" / f"{prefix}.csv").expanduser() + super(CSVMetricWriter, self).__init__(filename=filename, iterator=iterator) + + self.writer = csv.DictWriter(self.csv_file, fieldnames=SCORE_CSV_FIELDS) + self.writer.writeheader() + + def write( + self, + score: Union[float, int], + labels: Union[list[dict], dict, None] = None, + description: str = "", + label_class: Optional[str] = None, + label_hash: Optional[str] = None, + du_hash: Optional[str] = None, + frame: Optional[int] = None, + url: Optional[str] = None, + annotator: Optional[str] = None, + key: Optional[str] = None, # TODO obsolete parameter, remove from metrics first + ): + if not isinstance(score, (float, int)): + raise TypeError("score must be a float or int") + if isinstance(labels, list) and len(labels) == 0: + labels = None + elif isinstance(labels, dict): + labels = [labels] + + label_hash = self.iterator.label_hash if label_hash is None else label_hash + du_hash = self.iterator.du_hash if du_hash is None else du_hash + frame = self.iterator.frame if frame is None else frame + url = self.iterator.get_data_url() if url is None else url + + if labels is None: + label_class = "" if label_class is None else label_class + annotator = "" if annotator is None else annotator + else: + label_class = labels[0]["name"] if label_class is None else label_class + annotator = labels[0]["lastEditedBy"] if "lastEditedBy" in labels[0] else labels[0]["createdBy"] + + # remember to remove if clause (not its content) when writer's format (obj, score) is enforced on all metrics + # start hack + if key is None: + key = self.get_identifier(labels, label_hash, du_hash, frame) + # end hack + + row = { + "identifier": key, + "score": score, + "description": description, + "object_class": label_class, + "frame": frame, + "url": url, + "annotator": annotator, + } + + self.writer.writerow(row) + self.csv_file.flush() + + super().write(score) diff --git a/src/encord_active/lib/model_predictions/filters.py b/src/encord_active/lib/model_predictions/filters.py new file mode 100644 index 000000000..8e34629cc --- /dev/null +++ b/src/encord_active/lib/model_predictions/filters.py @@ -0,0 +1,36 @@ +import pandas as pd +from pandera.typing import DataFrame + +from encord_active.lib.model_predictions.reader import PredictionMatchSchema + + +def filter_labels_for_frames_wo_predictions( + model_predictions: DataFrame[PredictionMatchSchema], sorted_labels: pd.DataFrame +): + pred_keys = model_predictions["img_id"].unique() + return sorted_labels[sorted_labels["img_id"].isin(pred_keys)] + + +def prediction_and_label_filtering( + selected_class_idx: dict, + labels: pd.DataFrame, + metrics: pd.DataFrame, + model_pred: pd.DataFrame, + precisions: pd.DataFrame, +): + # Predictions + class_idx = selected_class_idx + row_selection = model_pred["class_id"].isin(set(map(int, class_idx.keys()))) + _model_pred = model_pred[row_selection].copy() + + # Labels + row_selection = labels["class_id"].isin(set(map(int, class_idx.keys()))) + _labels = labels[row_selection] + + chosen_name_set = set(map(lambda x: x["name"], class_idx.values())).union({"Mean"}) + _metrics = metrics[metrics["class_name"].isin(chosen_name_set)] + _precisions = precisions[precisions["class_name"].isin(chosen_name_set)] + name_map = {int(k): v["name"] for k, v in class_idx.items()} + _model_pred["class_name"] = _model_pred["class_id"].map(name_map) + _labels["class_name"] = _labels["class_id"].map(name_map) + return _labels, _metrics, _model_pred, _precisions diff --git 
a/src/encord_active/lib/model_predictions/importers.py b/src/encord_active/lib/model_predictions/importers.py index 9cf388507..ecf357054 100644 --- a/src/encord_active/lib/model_predictions/importers.py +++ b/src/encord_active/lib/model_predictions/importers.py @@ -17,8 +17,8 @@ from PIL import Image from tqdm import tqdm -from encord_active.lib.common.project import download_all_label_rows -from encord_active.lib.model_predictions.prediction_writer import PredictionWriter +from encord_active.lib.model_predictions.writer import PredictionWriter +from encord_active.lib.project.project import download_all_label_rows logger = logging.getLogger(__name__) KITTI_COLUMNS = [ diff --git a/src/encord_active/app/model_quality/map_mar.py b/src/encord_active/lib/model_predictions/map_mar.py similarity index 55% rename from src/encord_active/app/model_quality/map_mar.py rename to src/encord_active/lib/model_predictions/map_mar.py index e22e84be3..4e3872592 100644 --- a/src/encord_active/app/model_quality/map_mar.py +++ b/src/encord_active/lib/model_predictions/map_mar.py @@ -1,16 +1,68 @@ -from typing import Dict, Optional, Tuple +from typing import Dict, List, Optional, Tuple, TypedDict import numpy as np import pandas as pd -import streamlit as st +import pandera as pa +from pandera.typing import DataFrame, Series + +from encord_active.lib.model_predictions.reader import ( + LabelMatchSchema, + LabelSchema, + PredictionMatchSchema, + PredictionSchema, +) + + +class GtMatchEntry(TypedDict): + lidx: str + pidxs: List[int] + + +GtMatchCollection = Dict[str, Dict[str, List[GtMatchEntry]]] +""" +Collection of lists of labels and what predictions of the same class they match with. +First-level key is `class_id`, second-level is `img_id` +""" + + +class PerformanceMetricSchema(pa.SchemaModel): + metric: Series[str] = pa.Field() + value: Series[float] = pa.Field(coerce=True) + class_name: Series[str] = pa.Field() + + +class PrecisionRecallSchema(pa.SchemaModel): + precision: Series[float] = pa.Field() + recall: Series[float] = pa.Field() + class_name: Series[str] = pa.Field() + + +class ClassMapEntry(TypedDict): + name: str + featureHash: str + color: str + + +ClassMapCollection = Dict[str, ClassMapEntry] +""" +Key is the `class_id` +""" -@st.experimental_memo def compute_mAP_and_mAR( + model_predictions: DataFrame[PredictionSchema], + labels: DataFrame[LabelSchema], + gt_matched: GtMatchCollection, + class_map: Dict[str, ClassMapEntry], iou_threshold: float, rec_thresholds: Optional[np.ndarray] = None, - ignore_unmatched_frames: bool = False, # pylint: disable=unused-argument -) -> Tuple[pd.DataFrame, pd.DataFrame, np.ndarray, pd.DataFrame, np.ndarray]: + ignore_unmatched_frames: bool = False, +) -> Tuple[ + DataFrame[PredictionMatchSchema], + DataFrame[LabelMatchSchema], + DataFrame[PerformanceMetricSchema], + DataFrame[PrecisionRecallSchema], +]: """ Computes a bunch of quantities used to display results in the UI. The main purpose of this function is to filter true positives from false positives @@ -20,103 +72,94 @@ def compute_mAP_and_mAR( means that the prediction's best match (highest iou) has the given value. - :param iou_threshold: The IOU threshold to compute scores for. - :param rec_thresholds: The recall thresholds to compute the scores for. - Default here is the same as `torchmetrics`. - :param ignore_unmatched_labels: If set to true, will not normalize by - all class labels, but only those associated with images for which - there exist predictions. 
- :return: - - metrics_df: A df with AP_{class_name} and AR_{class_name} for every class - name in the `_class_map`. This is used with `altair` to plot scores. - - prec_df: Precision-recall data used with `altair` as well. Again grouped - by `class_name`. - - tps: An indicator array for which `tps[i] == True` if `_predictions.iloc[i]` - was a true positive. Otherwise False. - - reasons: A string for each entry in `_predictions` giving the reason for - why this point was a false negative. - (`reasons[i]` == ""` if `tps[i] == True`). - - fns: An indicator array for which `fns[j] == True` if `_labels.iloc[j]` - was not matched by any prediction. - """ - """ - :param _predictions: A df with the predictions ordered (DESC) by the + :param model_predictions: A df with the predictions ordered (DESC) by the models confidence scores. - """ - _predictions = st.session_state.model_predictions - """ - :param _labels: The df with the labels. The labels are only used in + :param labels: The df with the labels. The labels are only used in this function to build a list of false negatives, as the matching between labels and predictions were already done during import. - """ - _labels = st.session_state.labels - """ - :param _gt_matched: A dictionary with format:: + :param gt_matched: A dictionary with format:: - _gt_matched[(img_id: int, class_id: int)] = \ + gt_matched[class_id: str][img_id: str] = \ [{"lidx": lidx: int, "pidxs": [pidx1: int, pidx2, ... ]}] - each entry in the `_gt_matched` is thus a list of all the labels of a + each entry in the `gt_matched` is thus a list of all the labels of a particular class (`class_id`) from a particular image (`img_id`). Each entry in the list contains the index to which it belongs in the - `_labels` data frame (`lidx`) and the indices of all the predictions + `labels` data frame (`lidx`) and the indices of all the predictions that matches the given label best sorted by confidence score. That is, if `lidx == 512` and `pidx1 == 256`, it - means that prediction 256 (`_predictions.iloc[256]`) matched label - 512 (`_labels.iloc[512]`) with the highest iou of all the prediction's + means that prediction 256 (`predictions.iloc[256]`) matched label + 512 (`labels.iloc[512]`) with the highest iou of all the prediction's matches. Furthermore, as `pidx1` comes before `pidx2`, it means that:: - _predictions.iloc[pidx1]["confidence"] >= _predictions.iloc[pidx2]["confidence"] + predictions.iloc[pidx1]["confidence"] >= predictions.iloc[pidx2]["confidence"] - """ - _gt_matched = st.session_state.gt_matched - """ - :param _class_map: This is a mapping between class indices and essential + :param class_map: This is a mapping between class indices and essential metadata. The dict has the structure:: { - "": { # <-- Note string type here <-- + "": { # <-- Note string type here, e.g., '"1"' <-- "name": "The name of the object class", - "heatureHash": "e2f0be6c", + "featureHash": "e2f0be6c", "color": "#001122", } } + + :param iou_threshold: The IOU threshold to compute scores for. + :param rec_thresholds: The recall thresholds to compute the scores for. + Default here is the same as `torchmetrics`. + :param ignore_unmatched_labels: If set to true, will not normalize by + all class labels, but only those associated with images for which + there exist predictions. + :return: + - metrics_df: A df with AP_{class_name} and AR_{class_name} for every class + name in the `_class_map`. This is used with `altair` to plot scores. + - prec_df: Precision-recall data used with `altair` as well. 
Again grouped + by `class_name`. + - tps: An indicator array for which `tps[i] == True` if `_predictions.iloc[i]` + was a true positive. Otherwise False. + - reasons: A string for each entry in `_predictions` giving the reason for + why this point was a false negative. + (`reasons[i]` == ""` if `tps[i] == True`). + - fns: An indicator array for which `fns[j] == True` if `_labels.iloc[j]` + was not matched by any prediction. """ - _class_map = st.session_state.full_class_idx + model_predictions = model_predictions.copy() + labels = labels.copy() + rec_thresholds = rec_thresholds or np.linspace(0.0, 1.00, round(1.00 / 0.01) + 1) - full_index_list = np.arange(_predictions.shape[0]) - pred_class_list = _predictions["class_id"].to_numpy(dtype=int) - ious = _predictions["iou"].to_numpy(dtype=float) + full_index_list = np.arange(model_predictions.shape[0]) + pred_class_list = model_predictions["class_id"].to_numpy(dtype=int) + ious = model_predictions["iou"].to_numpy(dtype=float) # == Output containers == # - # Stores the mapping between class_idx and list_idx of the following two lists + # Stores the mapping between class_id and list_idx of the following two lists class_idx_map = {} precisions = [] recalls = [] - _tps = np.zeros((_predictions.shape[0],), dtype=bool) - reasons = [f"No overlapping label of class `{_class_map[str(i)]['name']}`." for i in pred_class_list] - _fns = np.zeros((_labels.shape[0],), dtype=bool) + _tps = np.zeros((model_predictions.shape[0],), dtype=bool) + reasons = [f"No overlapping label of class `{class_map[str(i)]['name']}`." for i in pred_class_list] + _fns = np.zeros((labels.shape[0],), dtype=bool) - pred_img_ids = set(_predictions["img_id"]) - label_include_ids = set(_labels.index[_labels["img_id"].isin(pred_img_ids)]) + pred_img_ids = set(model_predictions["img_id"]) + label_include_ids = set(labels.index[labels["img_id"].isin(pred_img_ids)]) cidx = 0 - for class_idx in _class_map: + for class_id in class_map: if ignore_unmatched_frames: - # nb_labels = sum([len([c for c in l if c in pred_img_ids]) for l in _gt_matched[class_idx].values()]) nb_labels = sum( - [len([t for t in l if t["lidx"] in label_include_ids]) for l in _gt_matched.get(class_idx, {}).values()] + [len([t for t in l if t["lidx"] in label_include_ids]) for l in gt_matched.get(class_id, {}).values()] ) else: - nb_labels = sum([len(l) for l in _gt_matched.get(class_idx, {}).values()]) + nb_labels = sum([len(l) for l in gt_matched.get(class_id, {}).values()]) if nb_labels == 0: continue - class_idx_map[cidx] = class_idx # Keep track of the order of the output lists - pred_select = pred_class_list == int(class_idx) + class_idx_map[cidx] = class_id # Keep track of the order of the output lists + pred_select = pred_class_list == int(class_id) if pred_select.sum() == 0: precisions.append(np.zeros(rec_thresholds.shape)) recalls.append(np.array(0.0)) @@ -130,7 +173,7 @@ def compute_mAP_and_mAR( TP_candidates = set(class_level_to_full_list_idx[_ious >= iou_threshold].astype(int).tolist()) TP = np.zeros(_ious.shape[0]) - for img_idx, img_label_matches in _gt_matched[class_idx].items(): + for img_label_matches in gt_matched[class_id].values(): for label_match in img_label_matches: found_one = False for tp_idx in label_match["pidxs"]: @@ -178,21 +221,30 @@ def compute_mAP_and_mAR( ["mAP", np.mean(_precisions.mean(axis=1)).item(), "Mean"], ] metrics += [ - [f"AP_{_class_map[class_idx]['name']}", _precisions[cidx].mean().item(), _class_map[class_idx]["name"]] - for cidx, class_idx in class_idx_map.items() + 
[f"AP_{class_map[class_id]['name']}", _precisions[cidx].mean().item(), class_map[class_id]["name"]] + for cidx, class_id in class_idx_map.items() ] metrics += [["mAR", np.mean(recalls).item(), "Mean"]] metrics += [ - [f"AR_{_class_map[class_idx]['name']}", recalls[cidx].item(), _class_map[class_idx]["name"]] - for cidx, class_idx in class_idx_map.items() + [f"AR_{class_map[class_id]['name']}", recalls[cidx].item(), class_map[class_id]["name"]] + for cidx, class_id in class_idx_map.items() ] - metrics_df = pd.DataFrame(metrics, columns=["metric", "value", "class_name"]) + metrics_df = pd.DataFrame(metrics, columns=["metric", "value", "class_name"]).pipe( + DataFrame[PerformanceMetricSchema] + ) prec_data = [] - columns = ["rc_threshold", "class_name", "precision"] - for cidx, class_idx in class_idx_map.items(): + columns = ["precision", "recall", "class_name"] + for cidx, class_id in class_idx_map.items(): for rc_idx, rc_threshold in enumerate(rec_thresholds): - prec_data.append([rc_threshold, _class_map[class_idx]["name"], _precisions[cidx, rc_idx]]) + prec_data.append([_precisions[cidx, rc_idx], rc_threshold, class_map[class_id]["name"]]) + pr_df = pd.DataFrame(prec_data, columns=columns).pipe(DataFrame[PrecisionRecallSchema]) + + model_predictions[PredictionMatchSchema.is_true_positive] = _tps.astype(float) + model_predictions[PredictionMatchSchema.false_positive_reason] = reasons + out_predictions = model_predictions.pipe(DataFrame[PredictionMatchSchema]) + + labels[LabelMatchSchema.is_false_negative] = _fns + out_labels = labels.pipe(DataFrame[LabelMatchSchema]) - prec_df = pd.DataFrame(prec_data, columns=columns) - return metrics_df, prec_df, _tps, pd.DataFrame(reasons, columns=["fp_reason"]), _fns + return out_predictions, out_labels, metrics_df, pr_df diff --git a/src/encord_active/lib/model_predictions/reader.py b/src/encord_active/lib/model_predictions/reader.py new file mode 100644 index 000000000..c90c5c8df --- /dev/null +++ b/src/encord_active/lib/model_predictions/reader.py @@ -0,0 +1,169 @@ +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Callable, Iterable, List, Optional, TypedDict, cast + +import pandas as pd +import pandera as pa +import pandera.dtypes as padt +from natsort import natsorted +from pandera.typing import DataFrame, Series + +from encord_active.lib.common.utils import load_json +from encord_active.lib.metrics.utils import ( + IdentifierSchema, + MetricData, + load_available_metrics, + load_metric_dataframe, +) + + +class OntologyObjectJSON(TypedDict): + featureHash: str + name: str + color: str + + +@dataclass +class MetricEntryPoint: + metric_path: Path + is_predictions: bool + filter_fn: Optional[Callable[[MetricData], Any]] = None + + +class LabelSchema(IdentifierSchema): + url: Series[str] = pa.Field() + img_id: Series[padt.Int64] = pa.Field(coerce=True) + class_id: Series[padt.Int64] = pa.Field(coerce=True) + x1: Series[padt.Int64] = pa.Field(coerce=True) + y1: Series[padt.Int64] = pa.Field(coerce=True) + x2: Series[padt.Int64] = pa.Field(coerce=True) + y2: Series[padt.Int64] = pa.Field(coerce=True) + rle: Series[object] = pa.Field() + + +class PredictionSchema(LabelSchema): + confidence: Series[padt.Float64] = pa.Field(coerce=True) + iou: Series[padt.Float64] = pa.Field(coerce=True) + + +class PredictionMatchSchema(PredictionSchema): + is_true_positive: Series[float] = pa.Field() + false_positive_reason: Series[str] = pa.Field() + + +class LabelMatchSchema(LabelSchema): + is_false_negative: Series[bool] = pa.Field() 
+ + +def check_model_prediction_availability(predictions_dir): + predictions_path = predictions_dir / "predictions.csv" + return predictions_path.is_file() + + +def filter_none_empty_metrics(metric: MetricData) -> str: + with metric.path.open("r", encoding="utf-8") as f: + f.readline() # header + key, *_ = f.readline().split(",") + return key + + +def filter_label_metrics_for_predictions(metric: MetricData) -> bool: + key = filter_none_empty_metrics(metric) + if not key: + return False + _, _, _, *rest = key.split("_") + return not rest + + +def get_metric_data(entry_points: List[MetricEntryPoint], append_level_to_titles: bool = True) -> List[MetricData]: + all_metrics: List[MetricData] = [] + + for entry in entry_points: + available_metrics: Iterable[MetricData] = list( + filter(entry.filter_fn, load_available_metrics(entry.metric_path, None)) + ) + if not append_level_to_titles: + all_metrics += natsorted(available_metrics, key=lambda x: x.name) + continue + + for metric in available_metrics: + metric.name += " (P)" if entry.is_predictions else f" ({metric.level})" + all_metrics += natsorted(available_metrics, key=lambda x: x.name[-3:] + x.name[:-3]) + + return all_metrics + + +def append_metric_columns(df: pd.DataFrame, metric_entries: List[MetricData]) -> pd.DataFrame: + IdentifierSchema.validate(df) + df = df.copy() + df["identifier_no_oh"] = df[IdentifierSchema.identifier].str.replace(r"^(\S{73}_\d+)(.*)", r"\1", regex=True) + + for metric in metric_entries: + metric_scores = load_metric_dataframe(metric, normalize=False) + metric_scores = metric_scores.set_index(keys=["identifier"]) + + has_object_level_keys = len(cast(str, metric_scores.index[0]).split("_")) > 3 + metric_column = "identifier" if has_object_level_keys else "identifier_no_oh" + + # Join data and rename column to metric name. 
+ df = df.join(metric_scores["score"], on=[metric_column]) + df[metric.name] = df["score"] + df.drop("score", axis=1, inplace=True) + + df.drop("identifier_no_oh", axis=1, inplace=True) + return df + + +def _load_csv_and_merge_metrics(path: Path, metric_data: List[MetricData]) -> Optional[pd.DataFrame]: + if not path.suffix == ".csv": + return None + + if not path.is_file(): + return None + + df = cast(pd.DataFrame, pd.read_csv(path, iterator=False)) + return append_metric_columns(df, metric_data) + + +def get_prediction_metric_data(predictions_dir: Path, metrics_dir: Path) -> List[MetricData]: + predictions_metric_dir = predictions_dir / "metrics" + if not predictions_metric_dir.is_dir(): + return [] + + entry_points = [ + MetricEntryPoint(metrics_dir, is_predictions=False, filter_fn=filter_label_metrics_for_predictions), + MetricEntryPoint(predictions_metric_dir, is_predictions=True, filter_fn=filter_none_empty_metrics), + ] + return get_metric_data(entry_points) + + +def get_model_predictions( + predictions_dir: Path, metric_data: List[MetricData] +) -> Optional[DataFrame[PredictionSchema]]: + df = _load_csv_and_merge_metrics(predictions_dir / "predictions.csv", metric_data) + return df.pipe(DataFrame[PredictionSchema]) if df is not None else None + + +def get_label_metric_data(metrics_dir: Path) -> List[MetricData]: + if not metrics_dir.is_dir(): + return [] + + entry_points = [ + MetricEntryPoint(metrics_dir, is_predictions=False, filter_fn=filter_none_empty_metrics), + ] + return get_metric_data(entry_points) + + +def get_labels(predictions_dir: Path, metric_data: List[MetricData]) -> Optional[DataFrame[LabelSchema]]: + df = _load_csv_and_merge_metrics(predictions_dir / "labels.csv", metric_data) + return df.pipe(DataFrame[LabelSchema]) if df is not None else None + + +def get_gt_matched(predictions_dir: Path) -> Optional[dict]: + gt_path = predictions_dir / "ground_truths_matched.json" + return load_json(gt_path) + + +def get_class_idx(predictions_dir: Path) -> dict[str, OntologyObjectJSON]: + class_idx_pth = predictions_dir / "class_idx.json" + return load_json(class_idx_pth) or {} diff --git a/src/encord_active/lib/model_predictions/prediction_writer.py b/src/encord_active/lib/model_predictions/writer.py similarity index 98% rename from src/encord_active/lib/model_predictions/prediction_writer.py rename to src/encord_active/lib/model_predictions/writer.py index 24a176ce9..35eca2edb 100644 --- a/src/encord_active/lib/model_predictions/prediction_writer.py +++ b/src/encord_active/lib/model_predictions/writer.py @@ -13,8 +13,8 @@ from torchvision.ops import box_iou from tqdm import tqdm -from encord_active.lib.common.project import Project from encord_active.lib.common.utils import binary_mask_to_rle, rle_iou +from encord_active.lib.project.project import Project logger = logging.getLogger(__name__) BBOX_KEYS = {"x", "y", "w", "h"} @@ -65,7 +65,7 @@ def polyobj_to_nparray(o: dict, width: int, height: int) -> np.ndarray: def points_to_mask(points: np.ndarray, width: int, height: int): mask = np.zeros((height, width), dtype=np.uint8) - mask = cv2.fillPoly(mask, [points.astype(int)], 1) # type: ignore + mask = cv2.fillPoly(mask, [(points * np.array([[width, height]])).astype(int)], 1) # type: ignore return mask @@ -425,7 +425,7 @@ def add_prediction( polygon = polygon.astype(float) / np.array([[width, height]]) np_mask = points_to_mask(polygon, width=width, height=height) # type: ignore - x1, y1, w, h = cv2.boundingRect(polygon.reshape(-1, 1, 2).astype(int)) # type: ignore + x1, y1, w, 
h = cv2.boundingRect((polygon * np.array([[width, height]])).reshape(-1, 1, 2).astype(int)) # type: ignore x2, y2 = x1 + w, y1 + h points = [x1, y1, x2, y2] mask = binary_mask_to_rle(np_mask) diff --git a/src/encord_active/lib/project/__init__.py b/src/encord_active/lib/project/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/encord_active/lib/common/project.py b/src/encord_active/lib/project/project.py similarity index 86% rename from src/encord_active/lib/common/project.py rename to src/encord_active/lib/project/project.py index 74603f845..275c30c5a 100644 --- a/src/encord_active/lib/common/project.py +++ b/src/encord_active/lib/project/project.py @@ -21,6 +21,7 @@ fetch_project_meta, slice_video_into_frames, ) +from encord_active.lib.project.project_file_structure import ProjectFileStructure logger = logger.opt(colors=True) encord_logger = logging.getLogger("encord") @@ -30,7 +31,8 @@ class Project: def __init__(self, project_dir: Path): self.project_dir: Path = project_dir - self.project_meta: Dict[str, str] = {} + self.project_file_structure = ProjectFileStructure(project_dir) + self.project_meta = fetch_project_meta(self.project_file_structure.project_dir) self.project_hash: str = "" self.ontology: OntologyStructure = OntologyStructure.from_dict(dict(objects=[], classifications=[])) self.label_row_meta: Dict[str, LabelRowMetadata] = {} @@ -89,22 +91,6 @@ def from_encord_project(self, encord_project: EncordProject) -> Project: return self.load() - @property - def data_dir_path(self) -> Path: - return self.project_dir / "data" - - @property - def label_row_meta_file_path(self) -> Path: - return self.project_dir / "label_row_meta.json" - - @property - def ontology_file_path(self) -> Path: - return self.project_dir / "ontology.json" - - @property - def project_meta_file_path(self) -> Path: - return self.project_dir / "project_meta.yaml" - @property def is_loaded(self) -> bool: return all( @@ -119,35 +105,34 @@ def is_loaded(self) -> bool: ) ) - def get_label_row_file_path(self, label_hash: str) -> Path: - return self.data_dir_path / label_hash / "label_row.json" - def __save_project_meta(self, encord_project: EncordProject): - project_meta = { - "project_title": encord_project.title, - "project_description": encord_project.description, - "project_hash": encord_project.project_hash, - } - project_meta_file_path = self.project_meta_file_path - project_meta_file_path.write_text(yaml.dump(project_meta), encoding="utf-8") + project_meta_file_path = self.project_file_structure.project_meta + self.project_meta.update( + { + "project_title": encord_project.title, + "project_description": encord_project.description, + "project_hash": encord_project.project_hash, + } + ) + project_meta_file_path.write_text(yaml.safe_dump(self.project_meta), encoding="utf-8") def __save_ontology(self, encord_project: EncordProject): - ontology_file_path = self.ontology_file_path + ontology_file_path = self.project_file_structure.ontology ontology_file_path.write_text(json.dumps(encord_project.ontology, indent=2), encoding="utf-8") def __load_ontology(self): - ontology_file_path = self.ontology_file_path + ontology_file_path = self.project_file_structure.ontology if not ontology_file_path.exists(): raise FileNotFoundError(f"Expected file `ontology.json` at {ontology_file_path.parent}") self.ontology = OntologyStructure.from_dict(json.loads(ontology_file_path.read_text(encoding="utf-8"))) def __save_label_row_meta(self, encord_project: EncordProject): label_row_meta = {lr["label_hash"]: lr 
for lr in encord_project.label_rows if lr["label_hash"] is not None} - label_row_meta_file_path = self.label_row_meta_file_path + label_row_meta_file_path = self.project_file_structure.label_row_meta label_row_meta_file_path.write_text(json.dumps(label_row_meta, indent=2), encoding="utf-8") def __load_label_row_meta(self, subset_size: Optional[int]): - label_row_meta_file_path = self.label_row_meta_file_path + label_row_meta_file_path = self.project_file_structure.label_row_meta if not label_row_meta_file_path.exists(): raise FileNotFoundError(f"Expected file `label_row_meta.json` at {label_row_meta_file_path.parent}") self.label_row_meta = { @@ -165,8 +150,8 @@ def __load_label_rows(self): self.label_rows = {} self.image_paths = {} for lr_hash in self.label_row_meta.keys(): - lr_file_path = self.get_label_row_file_path(lr_hash) - lr_images_dir = self.data_dir_path / lr_hash / "images" + lr_file_path = self.project_file_structure.label_row_structure(lr_hash).label_row_file + lr_images_dir = self.project_file_structure.data / lr_hash / "images" if not lr_file_path.is_file() or not lr_images_dir.is_dir(): logger.warning( f"Skipping label row `{lr_hash}` as no stored content was found for the label row." diff --git a/src/encord_active/lib/project/project_file_structure.py b/src/encord_active/lib/project/project_file_structure.py new file mode 100644 index 000000000..1b56f1104 --- /dev/null +++ b/src/encord_active/lib/project/project_file_structure.py @@ -0,0 +1,49 @@ +from pathlib import Path +from typing import NamedTuple + + +class LabelRowStructure(NamedTuple): + path: Path + images: Path + label_row_file: Path + + +class ProjectFileStructure: + def __init__(self, project_dir: Path): + self.project_dir: Path = project_dir + + @property + def data(self) -> Path: + return self.project_dir / "data" + + @property + def metrics(self) -> Path: + return self.project_dir / "metrics" + + @property + def embeddings(self) -> Path: + return self.project_dir / "embeddings" + + @property + def predictions(self) -> Path: + return self.project_dir / "predictions" + + @property + def db(self) -> Path: + return self.project_dir / "sqlite.db" + + @property + def label_row_meta(self) -> Path: + return self.project_dir / "label_row_meta.json" + + @property + def ontology(self) -> Path: + return self.project_dir / "ontology.json" + + @property + def project_meta(self) -> Path: + return self.project_dir / "project_meta.yaml" + + def label_row_structure(self, label_hash: str) -> LabelRowStructure: + path = self.data / label_hash + return LabelRowStructure(path, path / "images", path / "label_row.json") diff --git a/src/encord_active/lib/metrics/fetch_prebuilt_metrics.py b/src/encord_active/lib/project/sandbox_projects.py similarity index 100% rename from src/encord_active/lib/metrics/fetch_prebuilt_metrics.py rename to src/encord_active/lib/project/sandbox_projects.py
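The new `ProjectFileStructure` centralises the path lookups that `Project` previously exposed as individual properties (`data_dir_path`, `ontology_file_path`, and so on). A small usage sketch, with a made-up project directory and label hash purely for illustration:

```python
from pathlib import Path

from encord_active.lib.project.project_file_structure import ProjectFileStructure

# Both the directory and the label hash below are placeholders.
pfs = ProjectFileStructure(Path("~/encord-active/my-project").expanduser())

print(pfs.metrics)   # <project_dir>/metrics
print(pfs.ontology)  # <project_dir>/ontology.json
print(pfs.db)        # <project_dir>/sqlite.db

lr = pfs.label_row_structure("0fb1f2d1")  # placeholder label hash
print(lr.images)          # <project_dir>/data/0fb1f2d1/images
print(lr.label_row_file)  # <project_dir>/data/0fb1f2d1/label_row.json
```

This mirrors how `Project.__load_label_rows` and the save helpers resolve their paths after this refactor, so callers no longer hard-code joins against the project directory.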