diff --git a/src/encord_active/app/common/state.py b/src/encord_active/app/common/state.py
index 3d4e91f68..c574e5e85 100644
--- a/src/encord_active/app/common/state.py
+++ b/src/encord_active/app/common/state.py
@@ -43,7 +43,7 @@ class PredictionsState:
all_classes_objects: Dict[str, OntologyObjectJSON] = field(default_factory=dict)
all_classes_classifications: Dict[str, OntologyClassificationJSON] = field(default_factory=dict)
selected_classes_objects: Dict[str, OntologyObjectJSON] = field(default_factory=dict)
- selected_classes_classifications: Dict[str, OntologyObjectJSON] = field(default_factory=dict)
+ selected_classes_classifications: Dict[str, OntologyClassificationJSON] = field(default_factory=dict)
labels: Optional[DataFrame[LabelSchema]] = None
nbins: int = 50
diff --git a/src/encord_active/app/model_quality/prediction_type_builder.py b/src/encord_active/app/model_quality/prediction_type_builder.py
new file mode 100644
index 000000000..b82dd98b4
--- /dev/null
+++ b/src/encord_active/app/model_quality/prediction_type_builder.py
@@ -0,0 +1,297 @@
+from abc import abstractmethod
+from enum import Enum
+from typing import Dict, List, Tuple, Union
+
+import streamlit as st
+from pandera.typing import DataFrame
+
+import encord_active.lib.model_predictions.reader as reader
+from encord_active.app.common.page import Page
+from encord_active.app.common.state import MetricNames, PredictionsState, get_state
+from encord_active.app.common.state_hooks import use_memo
+from encord_active.lib.charts.metric_importance import create_metric_importance_charts
+from encord_active.lib.metrics.utils import MetricData
+from encord_active.lib.model_predictions.reader import (
+ ClassificationLabelSchema,
+ ClassificationPredictionMatchSchemaWithClassNames,
+ ClassificationPredictionSchema,
+ LabelSchema,
+ OntologyObjectJSON,
+ PredictionMatchSchema,
+ PredictionSchema,
+)
+from encord_active.lib.model_predictions.writer import (
+ MainPredictionType,
+ OntologyClassificationJSON,
+)
+
+
+class ModelQualityPage(str, Enum):
+ METRICS = "metrics"
+ PERFORMANCE_BY_METRIC = "performance by metric"
+ EXPLORER = "explorer"
+
+
+class MetricType(str, Enum):
+ PREDICTION = "prediction"
+ LABEL = "label"
+
+
+class PredictionTypeBuilder(Page):
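+    """
+    Shared base class for the model-quality pages of a single prediction type.
+
+    Subclasses (the object and classification builders) implement ``load_data`` and
+    ``is_available`` and override the ``render_*`` methods; ``build`` loads the data and
+    dispatches to the renderer matching the selected ``ModelQualityPage``.
+    """
+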
+ def sidebar_options(self, *args, **kwargs):
+ pass
+
+ def _render_class_filtering_component(
+ self, all_classes: Union[Dict[str, OntologyObjectJSON], Dict[str, OntologyClassificationJSON]]
+ ):
+ return st.multiselect(
+ "Filter by class",
+ list(all_classes.items()),
+ format_func=lambda x: x[1]["name"],
+ help="""
+With this selection, you can choose which classes to include in the performance metrics calculations.
+This acts as a filter, i.e., when nothing is selected, all classes are included.
+Performance metrics will be automatically updated according to the chosen classes.
+ """,
+ )
+
+ def _description_expander(self, metric_datas: MetricNames):
+ with st.expander("Details", expanded=False):
+ st.markdown(
+ """### The View
+
+On this page, your model scores are displayed as a function of the metric that you selected in the top bar.
+Samples are discretized into $n$ equally sized buckets and the midpoint of each bucket is displayed as the x-value
+in the plots. Bars indicate the number of samples in each bucket, while lines indicate the true positive and false
+negative rates of each bucket.
+
+Metrics marked with (P) are metrics computed on your predictions.
+Metrics marked with (F) are frame-level metrics, which depend on the frame that each prediction is associated
+with. In the "False Negative Rate" plot, (O) means metrics computed on Object labels.
+
+For metrics that are computed on predictions (P) in the "True Positive Rate" plot, the corresponding "label metrics"
+(O/F) computed on your labels are used for the "False Negative Rate" plot.
+""",
+ unsafe_allow_html=True,
+ )
+ self._metric_details_description(metric_datas)
+
+ @staticmethod
+ def _metric_details_description(metric_datas: MetricNames):
+ metric_name = metric_datas.selected_prediction
+ if not metric_name:
+ return
+
+ metric_data = metric_datas.predictions.get(metric_name)
+
+ if not metric_data:
+ metric_data = metric_datas.labels.get(metric_name)
+
+ if metric_data:
+ st.markdown(f"### The {metric_data.name[:-4]} metric")
+ st.markdown(metric_data.meta.long_description)
+
+ def _set_binning(self):
+ get_state().predictions.nbins = int(
+ st.number_input(
+ "Number of buckets (n)",
+ min_value=5,
+ max_value=200,
+ value=PredictionsState.nbins,
+                help="Choose the number of bins to discretize the prediction metric values into.",
+ )
+ )
+
+ def _set_sampling_for_metric_importance(
+ self,
+ model_predictions: Union[
+ DataFrame[ClassificationPredictionMatchSchemaWithClassNames], DataFrame[PredictionMatchSchema]
+ ],
+ ) -> int:
+ num_samples = st.slider(
+ "Number of samples",
+ min_value=1,
+ max_value=len(model_predictions),
+ step=max(1, (len(model_predictions) - 1) // 100),
+ value=max((len(model_predictions) - 1) // 2, 1),
+            help="To avoid overly heavy computations, we subsample the data at random to the selected size "
+            "before computing importance values.",
+ )
+ if num_samples < 100:
+ st.warning(
+ "Number of samples is too low to compute reliable index importance. "
+ "We recommend using at least 100 samples.",
+ )
+
+ return num_samples
+
+ def _class_decomposition(self):
+ st.write("") # Make some spacing.
+ st.write("")
+ get_state().predictions.decompose_classes = st.checkbox(
+ "Show class decomposition",
+ value=PredictionsState.decompose_classes,
+ help="When checked, every plot will have a separate component for each class.",
+ )
+
+ def _topbar_metric_selection_component(self, metric_type: MetricType, metric_datas: MetricNames):
+ """
+ Note: Adding the fixed options "confidence" and "iou" works here because
+ confidence is required on import and IOU is computed during prediction
+ import. So the two columns are already available in the
+ `st.session_state.model_predictions` data frame.
+ """
+ if metric_type == MetricType.LABEL:
+            column_names = list(metric_datas.labels.keys())
+ metric_datas.selected_label = st.selectbox(
+ "Select metric for your labels",
+ column_names,
+ help="The data in the main view will be sorted by the selected metric. "
+ "(F) := frame scores, (O) := object scores.",
+ )
+ elif metric_type == MetricType.PREDICTION:
+ fixed_options = {"confidence": "Model Confidence"}
+ column_names = list(metric_datas.predictions.keys())
+ metric_datas.selected_prediction = st.selectbox(
+ "Select metric for your predictions",
+ column_names + list(fixed_options.keys()),
+ format_func=lambda s: fixed_options.get(s, s),
+ help="The data in the main view will be sorted by the selected metric. "
+ "(F) := frame scores, (P) := prediction scores.",
+ )
+
+ def _read_prediction_files(
+ self, prediction_type: MainPredictionType, project_path: str
+ ) -> Tuple[
+ List[MetricData],
+ List[MetricData],
+ Union[DataFrame[PredictionSchema], DataFrame[ClassificationPredictionSchema], None],
+ Union[DataFrame[LabelSchema], DataFrame[ClassificationLabelSchema], None],
+ ]:
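+        """
+        Read the prediction and label metric data together with the model predictions and labels
+        for the given prediction type. Results are memoized per project path and prediction type.
+        """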
+ metrics_dir = get_state().project_paths.metrics
+ predictions_dir = get_state().project_paths.predictions / prediction_type.value
+
+ predictions_metric_datas, _ = use_memo(
+ lambda: reader.get_prediction_metric_data(predictions_dir, metrics_dir),
+ key=f"predictions_metrics_data_{project_path}_{prediction_type.value}",
+ )
+
+ label_metric_datas, _ = use_memo(
+ lambda: reader.get_label_metric_data(metrics_dir),
+ key=f"label_metric_datas_{project_path}_{prediction_type.value}",
+ )
+ model_predictions, _ = use_memo(
+ lambda: reader.get_model_predictions(predictions_dir, predictions_metric_datas, prediction_type),
+ key=f"model_predictions_{project_path}_{prediction_type.value}",
+ )
+ labels, _ = use_memo(
+ lambda: reader.get_labels(predictions_dir, label_metric_datas, prediction_type),
+ key=f"labels_{project_path}_{prediction_type.value}",
+ )
+
+ return predictions_metric_datas, label_metric_datas, model_predictions, labels
+
+ def _render_performance_by_metric_description(
+ self,
+ model_predictions: Union[
+ DataFrame[ClassificationPredictionMatchSchemaWithClassNames], DataFrame[PredictionMatchSchema]
+ ],
+ metric_datas: MetricNames,
+ ):
+ if model_predictions.shape[0] == 0:
+ st.write("No predictions of the given class(es).")
+ return
+
+ metric_name = metric_datas.selected_prediction
+ if not metric_name:
+            # This shouldn't happen with the current flow. The only way a user can end up here
+            # is by writing custom code that bypasses running the metrics, so it seems fair not
+            # to give more information than this.
+            st.write(
+                "No metrics computed for your model predictions. "
+                "With `encord-active import predictions /path/to/predictions.pkl`, "
+                "Encord Active will automatically compute the metrics."
+ )
+ return
+
+ self._description_expander(metric_datas)
+
+ def _get_metric_importance(
+ self,
+ model_predictions: Union[
+ DataFrame[ClassificationPredictionMatchSchemaWithClassNames], DataFrame[PredictionMatchSchema]
+ ],
+        metric_columns: List[str],
+        prediction_type: MainPredictionType,
+    ):
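+        """Render the metric-importance description and chart (subsampling large prediction sets first)."""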
+ with st.container():
+ with st.expander("Description"):
+ st.write(
+ "The following charts show the dependency between model performance and each index. "
+                    "In other words, these charts answer the question of how much model "
+                    "performance is affected by each index. This relationship can be decomposed into two metrics:"
+ )
+ st.markdown(
+                    "- **Metric importance**: measures the *strength* of the dependency between a metric and model "
+ "performance. A high value means that the model performance would be strongly affected by "
+ "a change in the index. For example, a high importance in 'Brightness' implies that a change "
+ "in that quantity would strongly affect model performance. Values range from 0 (no dependency) "
+ "to 1 (perfect dependency, one can completely predict model performance simply by looking "
+ "at this index)."
+ )
+ st.markdown(
+ "- **Metric [correlation](https://en.wikipedia.org/wiki/Correlation)**: measures the *linearity "
+ "and direction* of the dependency between an index and model performance. "
+ "Crucially, this metric tells us whether a positive change in an index "
+ "will lead to a positive change (positive correlation) or a negative change (negative correlation) "
+                    "in model performance. Values range from -1 to 1."
+ )
+ st.write(
+ "Finally, you can also select how many samples are included in the computation "
+ "with the slider, as well as filter by class with the dropdown in the side bar."
+ )
+
+        if model_predictions.shape[0] > 60_000:  # Computations are heavy, so only allow computing on a subset.
+ num_samples = self._set_sampling_for_metric_importance(model_predictions)
+ else:
+ num_samples = model_predictions.shape[0]
+
+ with st.spinner("Computing index importance..."):
+ try:
+ metric_importance_chart = create_metric_importance_charts(
+ model_predictions,
+ metric_columns=metric_columns,
+ num_samples=num_samples,
+                    prediction_type=prediction_type,
+ )
+ st.altair_chart(metric_importance_chart, use_container_width=True)
+ except ValueError as e:
+ if e.args:
+ st.info(e.args[0])
+ else:
+ st.info("Failed to compute metric importance")
+
+ def build(self, page_mode: ModelQualityPage):
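+        """Load the data for the requested page via the subclass's ``load_data`` and render that page."""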
+ if self.load_data(page_mode):
+ if page_mode == ModelQualityPage.METRICS:
+ self.render_metrics()
+ elif page_mode == ModelQualityPage.PERFORMANCE_BY_METRIC:
+ self.render_performance_by_metric()
+ elif page_mode == ModelQualityPage.EXPLORER:
+ self.render_explorer()
+
+ def render_metrics(self):
+ st.markdown("### This page is not implemented")
+
+ def render_performance_by_metric(self):
+ st.markdown("### This page is not implemented")
+
+ def render_explorer(self):
+ st.markdown("### This page is not implemented")
+
+ @abstractmethod
+ def load_data(self, page_mode: ModelQualityPage) -> bool:
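+        """Load everything the given page needs; return True on success and False otherwise."""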
+ pass
+
+ @abstractmethod
+ def is_available(self) -> bool:
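+        """Return True if predictions of this builder's type have been imported for the project."""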
+ pass
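+
+
+# Illustrative usage sketch (hypothetical; the page wiring that instantiates these builders is not
+# shown in this diff): a concrete builder is selected, checked for availability, and then built.
+#
+#     builder = ObjectTypeBuilder()  # or ClassificationTypeBuilder()
+#     if builder.is_available():
+#         builder.build(ModelQualityPage.EXPLORER)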
diff --git a/src/encord_active/app/model_quality/prediction_types/__init__.py b/src/encord_active/app/model_quality/prediction_types/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/encord_active/app/model_quality/prediction_types/classification_type_builder.py b/src/encord_active/app/model_quality/prediction_types/classification_type_builder.py
new file mode 100644
index 000000000..1549f5e22
--- /dev/null
+++ b/src/encord_active/app/model_quality/prediction_types/classification_type_builder.py
@@ -0,0 +1,264 @@
+from copy import deepcopy
+from enum import Enum
+from typing import List, Optional
+
+import altair as alt
+import streamlit as st
+from loguru import logger
+from pandera.typing import DataFrame
+
+import encord_active.lib.model_predictions.reader as reader
+from encord_active.app.common.components import sticky_header
+from encord_active.app.common.components.prediction_grid import (
+ prediction_grid_classifications,
+)
+from encord_active.app.common.state import MetricNames, get_state
+from encord_active.app.model_quality.prediction_type_builder import (
+ MetricType,
+ ModelQualityPage,
+ PredictionTypeBuilder,
+)
+from encord_active.lib.charts.classification_metrics import (
+ get_accuracy,
+ get_confusion_matrix,
+ get_precision_recall_f1,
+ get_precision_recall_graph,
+)
+from encord_active.lib.charts.histogram import get_histogram
+from encord_active.lib.charts.performance_by_metric import performance_rate_by_metric
+from encord_active.lib.charts.scopes import PredictionMatchScope
+from encord_active.lib.metrics.utils import MetricScope
+from encord_active.lib.model_predictions.classification_metrics import (
+ match_predictions_and_labels,
+)
+from encord_active.lib.model_predictions.filters import (
+ prediction_and_label_filtering_classification,
+)
+from encord_active.lib.model_predictions.reader import (
+ ClassificationLabelSchema,
+ ClassificationPredictionMatchSchemaWithClassNames,
+ ClassificationPredictionSchema,
+ get_class_idx,
+)
+from encord_active.lib.model_predictions.writer import MainPredictionType
+
+
+class ClassificationTypeBuilder(PredictionTypeBuilder):
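+    """Model-quality pages (metrics, performance by metric, explorer) for classification predictions."""
+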
+ title = "Classification"
+
+ class OutcomeType(str, Enum):
+ CORRECT_CLASSIFICATIONS = "Correct Classifications"
+ MISCLASSIFICATIONS = "Misclassifications"
+
+ def __init__(self):
+ self._explorer_outcome_type = self.OutcomeType.CORRECT_CLASSIFICATIONS
+ self._labels: Optional[List] = None
+ self._predictions: Optional[List] = None
+ self._model_predictions: Optional[DataFrame[ClassificationPredictionMatchSchemaWithClassNames]] = None
+
+ def load_data(self, page_mode: ModelQualityPage) -> bool:
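+        """
+        Load classification predictions and labels, render the common top-bar settings, match
+        predictions with labels, filter by the selected classes, and keep only the images that
+        appear in both the filtered labels and predictions. Returns False if loading fails.
+        """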
+ predictions_metric_datas, label_metric_datas, model_predictions, labels = self._read_prediction_files(
+ MainPredictionType.CLASSIFICATION, project_path=get_state().project_paths.project_dir.as_posix()
+ )
+
+ if model_predictions is None:
+ st.error("Couldn't load model predictions")
+ return False
+
+ if labels is None:
+ st.error("Couldn't load labels properly")
+ return False
+
+ get_state().predictions.metric_datas_classification = MetricNames(
+ predictions={m.name: m for m in predictions_metric_datas},
+ )
+
+ with sticky_header():
+ self._common_settings()
+ self._topbar_additional_settings(page_mode)
+
+ model_predictions_matched = match_predictions_and_labels(model_predictions, labels)
+
+ (
+ labels_filtered,
+ predictions_filtered,
+ model_predictions_matched_filtered,
+ ) = prediction_and_label_filtering_classification(
+ get_state().predictions.selected_classes_classifications,
+ get_state().predictions.all_classes_classifications,
+ labels,
+ model_predictions,
+ model_predictions_matched,
+ )
+
+ img_id_intersection = list(
+ set(labels_filtered[ClassificationLabelSchema.img_id]).intersection(
+ set(predictions_filtered[ClassificationPredictionSchema.img_id])
+ )
+ )
+ labels_filtered_intersection = labels_filtered[
+ labels_filtered[ClassificationLabelSchema.img_id].isin(img_id_intersection)
+ ]
+ predictions_filtered_intersection = predictions_filtered[
+ predictions_filtered[ClassificationPredictionSchema.img_id].isin(img_id_intersection)
+ ]
+
+ self._labels, self._predictions = (
+ list(labels_filtered_intersection[ClassificationLabelSchema.class_id]),
+ list(predictions_filtered_intersection[ClassificationPredictionSchema.class_id]),
+ )
+
+ self._model_predictions = model_predictions_matched_filtered.copy()[
+ model_predictions_matched_filtered[ClassificationPredictionMatchSchemaWithClassNames.img_id].isin(
+ img_id_intersection
+ )
+ ]
+
+ return True
+
+ def _common_settings(self):
+ if not get_state().predictions.all_classes_classifications:
+ get_state().predictions.all_classes_classifications = get_class_idx(
+ get_state().project_paths.predictions / MainPredictionType.CLASSIFICATION.value
+ )
+
+ all_classes = get_state().predictions.all_classes_classifications
+ selected_classes = self._render_class_filtering_component(all_classes)
+
+ get_state().predictions.selected_classes_classifications = dict(selected_classes) or deepcopy(all_classes)
+
+ def _topbar_additional_settings(self, page_mode: ModelQualityPage):
+ if page_mode == ModelQualityPage.METRICS:
+ return
+ elif page_mode == ModelQualityPage.PERFORMANCE_BY_METRIC:
+ c1, c2, c3 = st.columns([4, 4, 3])
+ with c1:
+ self._topbar_metric_selection_component(
+ MetricType.PREDICTION, get_state().predictions.metric_datas_classification
+ )
+ with c2:
+ self._set_binning()
+ with c3:
+ self._class_decomposition()
+ elif page_mode == ModelQualityPage.EXPLORER:
+ c1, c2 = st.columns([4, 4])
+
+ with c1:
+ self._explorer_outcome_type = st.selectbox(
+ "Outcome",
+ [x for x in self.OutcomeType],
+ format_func=lambda x: x.value,
+ help="Only the samples with this outcome will be shown",
+ )
+ with c2:
+ self._topbar_metric_selection_component(
+ MetricType.PREDICTION, get_state().predictions.metric_datas_classification
+ )
+
+ self.display_settings(MetricScope.MODEL_QUALITY)
+
+ def is_available(self) -> bool:
+ return reader.check_model_prediction_availability(
+ get_state().project_paths.predictions / MainPredictionType.CLASSIFICATION.value
+ )
+
+ def render_metrics(self):
+ class_names = sorted(list(set(self._labels).union(self._predictions)))
+
+ precision, recall, f1, support = get_precision_recall_f1(self._labels, self._predictions)
+ accuracy = get_accuracy(self._labels, self._predictions)
+
+ # PERFORMANCE METRICS SUMMARY
+
+ col_acc, col_prec, col_rec, col_f1 = st.columns(4)
+ col_acc.metric("Accuracy", f"{float(accuracy):.2f}")
+ col_prec.metric(
+ "Mean Precision",
+ f"{float(precision.mean()):.2f}",
+ help="Average of precision scores of all classes",
+ )
+ col_rec.metric("Mean Recall", f"{float(recall.mean()):.2f}", help="Average of recall scores of all classes")
+ col_f1.metric("Mean F1", f"{float(f1.mean()):.2f}", help="Average of F1 scores of all classes")
+
+ # METRIC IMPORTANCE
+        self._get_metric_importance(
+            self._model_predictions,
+            list(get_state().predictions.metric_datas_classification.predictions.keys()),
+            MainPredictionType.CLASSIFICATION,
+        )
+
+ col1, col2 = st.columns(2)
+
+ # CONFUSION MATRIX
+ confusion_matrix = get_confusion_matrix(self._labels, self._predictions, class_names)
+ col1.plotly_chart(confusion_matrix, use_container_width=True)
+
+ # PRECISION_RECALL BARS
+ pr_graph = get_precision_recall_graph(precision, recall, class_names)
+ col2.plotly_chart(pr_graph, use_container_width=True)
+
+ def render_performance_by_metric(self):
+ self._render_performance_by_metric_description(
+ self._model_predictions, get_state().predictions.metric_datas_classification
+ )
+        metric_name = get_state().predictions.metric_datas_classification.selected_prediction
+        if not metric_name:
+            return
+
+ classes_for_coloring = ["Average"]
+ decompose_classes = get_state().predictions.decompose_classes
+ if decompose_classes:
+ unique_classes = set(self._model_predictions["class_name"].unique())
+ classes_for_coloring += sorted(list(unique_classes))
+
+ # Ensure same colors between plots
+ chart_args = dict(
+ color_params={"scale": alt.Scale(domain=classes_for_coloring)},
+ bins=get_state().predictions.nbins,
+ show_decomposition=decompose_classes,
+ )
+
+ try:
+ tpr = performance_rate_by_metric(
+ self._model_predictions,
+ metric_name,
+ scope=PredictionMatchScope.TRUE_POSITIVES,
+ **chart_args,
+ )
+ if tpr is not None:
+ st.altair_chart(tpr.interactive(), use_container_width=True)
+ except Exception as e:
+ logger.warning(e)
+
+ def render_explorer(self):
+ with st.expander("Details"):
+ if self._explorer_outcome_type == self.OutcomeType.CORRECT_CLASSIFICATIONS:
+ view_text = "These are the predictions where the model correctly predicts the true class."
+ else:
+                view_text = "These are the predictions where the model predicts a class different from the true class."
+ st.markdown(
+ f"""### The view
+{view_text}
+ """,
+ unsafe_allow_html=True,
+ )
+
+ self._metric_details_description(get_state().predictions.metric_datas_classification)
+
+ metric_name = get_state().predictions.metric_datas_classification.selected_prediction
+ if not metric_name:
+ st.error("No prediction metric selected")
+ return
+
+ if self._explorer_outcome_type == self.OutcomeType.CORRECT_CLASSIFICATIONS:
+ view_df = self._model_predictions[
+ self._model_predictions[ClassificationPredictionMatchSchemaWithClassNames.is_true_positive] == 1.0
+ ].dropna(subset=[metric_name])
+ else:
+ view_df = self._model_predictions[
+ self._model_predictions[ClassificationPredictionMatchSchemaWithClassNames.is_true_positive] == 0.0
+ ].dropna(subset=[metric_name])
+
+ if view_df.shape[0] == 0:
+ st.write(f"No {self._explorer_outcome_type}")
+ else:
+ histogram = get_histogram(view_df, metric_name)
+ st.altair_chart(histogram, use_container_width=True)
+ prediction_grid_classifications(get_state().project_paths, model_predictions=view_df)
diff --git a/src/encord_active/app/model_quality/prediction_types/object_type_builder.py b/src/encord_active/app/model_quality/prediction_types/object_type_builder.py
new file mode 100644
index 000000000..a6476607b
--- /dev/null
+++ b/src/encord_active/app/model_quality/prediction_types/object_type_builder.py
@@ -0,0 +1,382 @@
+import json
+import re
+from copy import deepcopy
+from enum import Enum
+from typing import Optional, Union
+
+import altair as alt
+import streamlit as st
+from loguru import logger
+from pandera.typing import DataFrame
+
+import encord_active.lib.model_predictions.reader as reader
+from encord_active.app.common.components import sticky_header
+from encord_active.app.common.components.prediction_grid import prediction_grid
+from encord_active.app.common.state import MetricNames, State, get_state
+from encord_active.app.common.state_hooks import use_memo
+from encord_active.app.model_quality.prediction_type_builder import (
+ MetricType,
+ ModelQualityPage,
+ PredictionTypeBuilder,
+)
+from encord_active.lib.charts.histogram import get_histogram
+from encord_active.lib.charts.performance_by_metric import performance_rate_by_metric
+from encord_active.lib.charts.precision_recall import create_pr_chart_plotly
+from encord_active.lib.charts.scopes import PredictionMatchScope
+from encord_active.lib.common.colors import Color
+from encord_active.lib.metrics.utils import MetricScope
+from encord_active.lib.model_predictions.filters import (
+ filter_labels_for_frames_wo_predictions,
+ prediction_and_label_filtering,
+)
+from encord_active.lib.model_predictions.map_mar import (
+ PerformanceMetricSchema,
+ PrecisionRecallSchema,
+ compute_mAP_and_mAR,
+)
+from encord_active.lib.model_predictions.reader import (
+ LabelMatchSchema,
+ PredictionMatchSchema,
+ get_class_idx,
+)
+from encord_active.lib.model_predictions.writer import MainPredictionType
+
+
+class ObjectTypeBuilder(PredictionTypeBuilder):
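+    """Model-quality pages (metrics, performance by metric, explorer) for object predictions."""
+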
+ title = "Object"
+
+ class OutcomeType(str, Enum):
+        TRUE_POSITIVES = "True Positives"
+        FALSE_POSITIVES = "False Positives"
+        FALSE_NEGATIVES = "False Negatives"
+
+ def __init__(self):
+ self._explorer_outcome_type = self.OutcomeType.TRUE_POSITIVES
+ self._model_predictions: Optional[DataFrame[PredictionMatchSchema]] = None
+ self._labels: Optional[DataFrame[LabelMatchSchema]] = None
+ self._metrics: Optional[DataFrame[PerformanceMetricSchema]] = None
+ self._precisions: Optional[DataFrame[PrecisionRecallSchema]] = None
+
+ def load_data(self, page_mode: ModelQualityPage) -> bool:
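+        """
+        Load object predictions and labels, render the common top-bar settings, compute mAP/mAR
+        at the selected IOU threshold, and sort and filter the results by the selected metrics
+        and classes. Returns False if loading fails.
+        """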
+ predictions_dir = get_state().project_paths.predictions / MainPredictionType.OBJECT.value
+ predictions_metric_datas, label_metric_datas, model_predictions, labels = self._read_prediction_files(
+ MainPredictionType.OBJECT, project_path=get_state().project_paths.project_dir.as_posix()
+ )
+
+ if model_predictions is None:
+ st.error("Couldn't load model predictions")
+ return False
+
+ if labels is None:
+ st.error("Couldn't load labels properly")
+ return False
+
+ get_state().predictions.metric_datas = MetricNames(
+ predictions={m.name: m for m in predictions_metric_datas},
+ labels={m.name: m for m in label_metric_datas},
+ )
+
+ with sticky_header():
+ self._common_settings()
+ self._topbar_additional_settings(page_mode)
+
+ matched_gt, _ = use_memo(
+ lambda: reader.get_gt_matched(predictions_dir),
+ key=f"matched_gt_{get_state().project_paths.project_dir.as_posix()}",
+ )
+ if not matched_gt:
+ st.error("Couldn't match ground truths")
+ return False
+
+ (predictions_filtered, labels_filtered, metrics, precisions,) = compute_mAP_and_mAR(
+ model_predictions,
+ labels,
+ matched_gt,
+ get_state().predictions.all_classes_objects,
+ iou_threshold=get_state().iou_threshold,
+ ignore_unmatched_frames=get_state().ignore_frames_without_predictions,
+ )
+
+ # Sort predictions and labels according to selected metrics.
+ pred_sort_column = get_state().predictions.metric_datas.selected_prediction or predictions_metric_datas[0].name
+ sorted_model_predictions = predictions_filtered.sort_values([pred_sort_column], axis=0)
+
+ label_sort_column = get_state().predictions.metric_datas.selected_label or label_metric_datas[0].name
+ sorted_labels = labels_filtered.sort_values([label_sort_column], axis=0)
+
+ if get_state().ignore_frames_without_predictions:
+ labels_filtered = filter_labels_for_frames_wo_predictions(predictions_filtered, sorted_labels)
+ else:
+ labels_filtered = sorted_labels
+
+ self._labels, self._metrics, self._model_predictions, self._precisions = prediction_and_label_filtering(
+ get_state().predictions.selected_classes_objects,
+ labels_filtered,
+ metrics,
+ sorted_model_predictions,
+ precisions,
+ )
+
+ return True
+
+ def _common_settings(self):
+ if not get_state().predictions.all_classes_objects:
+ get_state().predictions.all_classes_objects = get_class_idx(
+ get_state().project_paths.predictions / MainPredictionType.OBJECT.value
+ )
+
+ all_classes = get_state().predictions.all_classes_objects
+ col1, col2, col3 = st.columns([4, 4, 3])
+
+ with col1:
+ selected_classes = self._render_class_filtering_component(all_classes)
+
+ get_state().predictions.selected_classes_objects = dict(selected_classes) or deepcopy(all_classes)
+
+ with col2:
+ # IOU
+ get_state().iou_threshold = st.slider(
+ "Select an IOU threshold",
+ min_value=0.0,
+ max_value=1.0,
+ value=State.iou_threshold,
+ help="The mean average precision (mAP) score is based on true positives and false positives. "
+ "The IOU threshold determines how closely predictions need to match labels to be considered "
+ "as true positives.",
+ )
+
+ with col3:
+ st.write("")
+ st.write("")
+ # Ignore unmatched frames
+ get_state().ignore_frames_without_predictions = st.checkbox(
+ "Ignore frames without predictions",
+ value=State.ignore_frames_without_predictions,
+                help="Scores like mAP and mAR are negatively affected if there are frames in the dataset "
+                "for which there are no predictions. With this flag, you can ignore those.",
+ )
+
+ def _topbar_additional_settings(self, page_mode: ModelQualityPage):
+ if page_mode == ModelQualityPage.METRICS:
+ return
+ elif page_mode == ModelQualityPage.PERFORMANCE_BY_METRIC:
+ c1, c2, c3 = st.columns([4, 4, 3])
+ with c1:
+ self._topbar_metric_selection_component(MetricType.PREDICTION, get_state().predictions.metric_datas)
+ with c2:
+ self._set_binning()
+ with c3:
+ self._class_decomposition()
+ elif page_mode == ModelQualityPage.EXPLORER:
+ c1, c2 = st.columns([4, 4])
+ with c1:
+ self._explorer_outcome_type = st.selectbox(
+ "Outcome",
+ [x for x in self.OutcomeType],
+ format_func=lambda x: x.value,
+ help="Only the samples with this outcome will be shown",
+ )
+ with c2:
+ explorer_metric_type = (
+ MetricType.PREDICTION
+ if self._explorer_outcome_type
+ in [self.OutcomeType.TRUE_POSITIVES, self.OutcomeType.FALSE_POSITIVES]
+ else MetricType.LABEL
+ )
+
+ self._topbar_metric_selection_component(explorer_metric_type, get_state().predictions.metric_datas)
+
+ self.display_settings(MetricScope.MODEL_QUALITY)
+
+ def _get_metric_name(self) -> Optional[str]:
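+        """Return the selected prediction metric for TP/FP outcomes, or the selected label metric for FNs."""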
+ if self._explorer_outcome_type in [self.OutcomeType.TRUE_POSITIVES, self.OutcomeType.FALSE_POSITIVES]:
+ return get_state().predictions.metric_datas.selected_prediction
+ else:
+ return get_state().predictions.metric_datas.selected_label
+
+ def _render_explorer_details(self) -> Optional[Color]:
+ color: Optional[Color] = None
+ with st.expander("Details"):
+ if self._explorer_outcome_type == self.OutcomeType.TRUE_POSITIVES:
+ color = Color.PURPLE
+
+ st.markdown(
+ f"""### The view
+These are the predictions for which the IOU was sufficiently high and the confidence score was
+the highest amongst predictions that overlap with the label.
+
+---
+
+**Color**:
+The {color.name.lower()} boxes mark the true positive predictions.
+The remaining colors correspond to the dataset labels with the colors you are used to from the label editor.
+ """,
+ unsafe_allow_html=True,
+ )
+
+ elif self._explorer_outcome_type == self.OutcomeType.FALSE_POSITIVES:
+ color = Color.RED
+
+ st.markdown(
+ f"""### The view
+These are the predictions for which at least one of the following is true:
+1. The IOU between the prediction and the best matching label was too low
+2. There was another prediction with higher model confidence which matched the label already
+3. The predicted class didn't match
+
+---
+
+**Color**:
+The {color.name.lower()} boxes mark the false positive predictions.
+The remaining colors correspond to the dataset labels with the colors you are used to from the label editor.
+ """,
+ unsafe_allow_html=True,
+ )
+
+ elif self._explorer_outcome_type == self.OutcomeType.FALSE_NEGATIVES:
+ color = Color.PURPLE
+
+ st.markdown(
+ f"""### The view
+These are the labels that were not matched with any predictions.
+
+---
+**Color**:
+The {color.name.lower()} boxes mark the false negatives. That is, the labels that were not
+matched to any predictions. The remaining objects are predictions, where colors correspond to their predicted class
+(identical colors to the label objects in the editor).
+ """,
+ unsafe_allow_html=True,
+ )
+ self._metric_details_description(get_state().predictions.metric_datas)
+ return color
+
+ def _get_target_df(
+ self, metric_name: str
+ ) -> Union[DataFrame[PredictionMatchSchema], DataFrame[LabelMatchSchema], None]:
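+        """Return the rows for the selected outcome (TP, FP, or FN) with rows missing ``metric_name`` dropped."""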
+
+ if self._model_predictions is not None and self._labels is not None:
+ if self._explorer_outcome_type == self.OutcomeType.TRUE_POSITIVES:
+ return self._model_predictions[
+ self._model_predictions[PredictionMatchSchema.is_true_positive] == 1.0
+ ].dropna(subset=[metric_name])
+ elif self._explorer_outcome_type == self.OutcomeType.FALSE_POSITIVES:
+ return self._model_predictions[
+ self._model_predictions[PredictionMatchSchema.is_true_positive] == 0.0
+ ].dropna(subset=[metric_name])
+ elif self._explorer_outcome_type == self.OutcomeType.FALSE_NEGATIVES:
+ return self._labels[self._labels[LabelMatchSchema.is_false_negative]].dropna(subset=[metric_name])
+ else:
+ return None
+ else:
+ return None
+
+ def is_available(self) -> bool:
+ return reader.check_model_prediction_availability(
+ get_state().project_paths.predictions / MainPredictionType.OBJECT.value
+ )
+
+ def render_metrics(self):
+ _map = self._metrics[self._metrics[PerformanceMetricSchema.metric] == "mAP"]["value"].item()
+ _mar = self._metrics[self._metrics[PerformanceMetricSchema.metric] == "mAR"]["value"].item()
+ col1, col2 = st.columns(2)
+ col1.metric("mAP", f"{_map:.3f}")
+ col2.metric("mAR", f"{_mar:.3f}")
+
+ # METRIC IMPORTANCE
+        self._get_metric_importance(
+            self._model_predictions,
+            list(get_state().predictions.metric_datas.predictions.keys()),
+            MainPredictionType.OBJECT,
+        )
+
+ st.subheader("Subset selection scores")
+ with st.container():
+ project_ontology = json.loads(get_state().project_paths.ontology.read_text(encoding="utf-8"))
+ chart = create_pr_chart_plotly(self._metrics, self._precisions, project_ontology["objects"])
+ st.plotly_chart(chart, use_container_width=True)
+
+ def render_performance_by_metric(self):
+ self._render_performance_by_metric_description(self._model_predictions, get_state().predictions.metric_datas)
+        metric_name = get_state().predictions.metric_datas.selected_prediction
+        if not metric_name:
+            return
+
+ label_metric_name = metric_name
+ if metric_name[-3:] == "(P)": # Replace the P with O: "Metric (P)" -> "Metric (O)"
+ label_metric_name = re.sub(r"(.*?\()P(\))", r"\1O\2", metric_name)
+
+ if label_metric_name not in self._labels.columns:
+ label_metric_name = re.sub(
+ r"(.*?\()O(\))", r"\1F\2", label_metric_name
+ ) # Look for it in frame label metrics.
+
+ classes_for_coloring = ["Average"]
+ decompose_classes = get_state().predictions.decompose_classes
+ if decompose_classes:
+ unique_classes = set(self._model_predictions["class_name"].unique()).union(
+ self._labels["class_name"].unique()
+ )
+ classes_for_coloring += sorted(list(unique_classes))
+
+ # Ensure same colors between plots
+ chart_args = dict(
+ color_params={"scale": alt.Scale(domain=classes_for_coloring)},
+ bins=get_state().predictions.nbins,
+ show_decomposition=decompose_classes,
+ )
+
+ try:
+ if metric_name in self._model_predictions.columns:
+ tpr = performance_rate_by_metric(
+ self._model_predictions, metric_name, scope=PredictionMatchScope.TRUE_POSITIVES, **chart_args
+ )
+ if tpr is not None:
+ st.altair_chart(tpr.interactive(), use_container_width=True)
+ else:
+                st.info(f"True Positive Rate is not available for the `{metric_name}` metric")
+ except Exception as e:
+ logger.warning(e)
+
+ try:
+ if label_metric_name in self._labels.columns:
+ fnr = performance_rate_by_metric(
+ self._labels, label_metric_name, scope=PredictionMatchScope.FALSE_NEGATIVES, **chart_args
+ )
+ if fnr is not None:
+ st.altair_chart(fnr.interactive(), use_container_width=True)
+ else:
+                st.info(f"False Negative Rate is not available for the `{label_metric_name}` metric")
+ except Exception as e:
+ logger.warning(e)
+
+ def render_explorer(self):
+ metric_name = self._get_metric_name()
+ if not metric_name:
+ st.error("No metric selected")
+ return
+
+ color = self._render_explorer_details()
+ if color is None:
+ st.warning("An error occurred while rendering the explorer page")
+
+ view_df = self._get_target_df(metric_name)
+ if view_df is None:
+            st.error(f"An error occurred while retrieving data for the {metric_name} metric")
+ return
+
+ if view_df.shape[0] == 0:
+ st.write(f"No {self._explorer_outcome_type}")
+ else:
+ histogram = get_histogram(view_df, metric_name)
+ st.altair_chart(histogram, use_container_width=True)
+ if self._explorer_outcome_type in [self.OutcomeType.TRUE_POSITIVES, self.OutcomeType.FALSE_POSITIVES]:
+ prediction_grid(get_state().project_paths, model_predictions=view_df, box_color=color)
+ else:
+ prediction_grid(
+ get_state().project_paths,
+ model_predictions=self._model_predictions,
+ labels=view_df,
+ box_color=color,
+ )
diff --git a/src/encord_active/app/model_quality/sub_pages/__init__.py b/src/encord_active/app/model_quality/sub_pages/__init__.py
deleted file mode 100644
index f6870db1c..000000000
--- a/src/encord_active/app/model_quality/sub_pages/__init__.py
+++ /dev/null
@@ -1,181 +0,0 @@
-from abc import abstractmethod
-from typing import Optional
-
-import streamlit as st
-from loguru import logger
-from pandera.typing import DataFrame
-from streamlit.delta_generator import DeltaGenerator
-
-import encord_active.app.common.state as state
-from encord_active.app.common.page import Page
-from encord_active.app.common.state import MetricNames
-from encord_active.lib.constants import DOCS_URL
-from encord_active.lib.model_predictions.map_mar import (
- PerformanceMetricSchema,
- PrecisionRecallSchema,
-)
-from encord_active.lib.model_predictions.reader import (
- ClassificationPredictionMatchSchema,
- LabelMatchSchema,
- PredictionMatchSchema,
-)
-
-
-class ModelQualityPage(Page):
- @property
- @abstractmethod
- def title(self) -> str:
- pass
-
- @abstractmethod
- def sidebar_options(self):
- """
- Used to append options to the sidebar.
- """
- pass
-
- def sidebar_options_classifications(self):
- pass
-
- def check_building_object_quality(
- self,
- object_predictions_exist: bool,
- object_model_predictions: Optional[DataFrame[PredictionMatchSchema]] = None,
- object_labels: Optional[DataFrame[LabelMatchSchema]] = None,
- object_metrics: Optional[DataFrame[PerformanceMetricSchema]] = None,
- object_precisions: Optional[DataFrame[PrecisionRecallSchema]] = None,
- ) -> bool:
- if not object_predictions_exist:
- st.markdown(
- "## Missing model predictions for the objects\n"
- "This project does not have any imported predictions for the objects. "
- "Please refer to the "
- f"[Importing Model Predictions]({DOCS_URL}/sdk/importing-model-predictions) "
- "section of the documentation to learn how to import your predictions."
- )
- return False
- elif not (
- (object_model_predictions is not None)
- and (object_labels is not None)
- and (object_metrics is not None)
- and (object_precisions is not None)
- ):
- logger.error(
- "If object_prediction_exist is True, the followings should be provided: object_model_predictions, \
- object_labels, object_metrics, object_precisions"
- )
- return False
-
- return True
-
- def check_building_classification_quality(
- self,
- classification_predictions_exist: bool,
- classification_labels: Optional[list] = None,
- classification_pred: Optional[list] = None,
- classification_model_predictions_matched: Optional[DataFrame[ClassificationPredictionMatchSchema]] = None,
- ) -> bool:
- if not classification_predictions_exist:
- st.markdown(
- "## Missing model predictions for the classifications\n"
- "This project does not have any imported predictions for the classifications. "
- "Please refer to the "
- f"[Importing Model Predictions]({DOCS_URL}/sdk/importing-model-predictions) "
- "section of the documentation to learn how to import your predictions."
- )
- return False
- elif not (
- (classification_labels is not None)
- and (classification_pred is not None)
- and (classification_model_predictions_matched is not None),
- ):
- logger.error(
- "If classification_predictions_exist is True, the followings should be provided: classification_labels, \
- classification_pred, classification_model_predictions_matched_filtered"
- )
- return False
-
- return True
-
- @abstractmethod
- def build(
- self,
- object_predictions_exist: bool,
- classification_predictions_exist: bool,
- object_tab: DeltaGenerator,
- classification_tab: DeltaGenerator,
- object_model_predictions: Optional[DataFrame[PredictionMatchSchema]] = None,
- object_labels: Optional[DataFrame[LabelMatchSchema]] = None,
- object_metrics: Optional[DataFrame[PerformanceMetricSchema]] = None,
- object_precisions: Optional[DataFrame[PrecisionRecallSchema]] = None,
- classification_labels: Optional[list] = None,
- classification_pred: Optional[list] = None,
- classification_model_predictions_matched_filtered: Optional[
- DataFrame[ClassificationPredictionMatchSchema]
- ] = None,
- ):
- """
- If object_predictions_exist is True, the followings should be provided: object_model_predictions, \
- object_labels, object_metrics, object_precisions
- If classification_predictions_exist is True, the followings should be provided: classification_labels, \
- classification_pred, classification_model_predictions_matched_filtered
- """
-
- pass
-
- def __call__(
- self,
- model_predictions: DataFrame[PredictionMatchSchema],
- labels: DataFrame[LabelMatchSchema],
- metrics: DataFrame[PerformanceMetricSchema],
- precisions: DataFrame[PrecisionRecallSchema],
- ):
- return self.build(model_predictions, labels, metrics, precisions)
-
- def __repr__(self):
- return f"{type(self).__name__}()"
-
- @staticmethod
- def prediction_metric_in_sidebar_objects():
- """
- Note: Adding the fixed options "confidence" and "iou" works here because
- confidence is required on import and IOU is computed during prediction
- import. So the two columns are already available in the
- `st.session_state.model_predictions` data frame.
- """
- fixed_options = {"confidence": "Model Confidence", "iou": "IOU"}
- column_names = list(state.get_state().predictions.metric_datas.predictions.keys())
- state.get_state().predictions.metric_datas.selected_prediction = st.selectbox(
- "Select metric for your predictions",
- column_names + list(fixed_options.keys()),
- format_func=lambda s: fixed_options.get(s, s),
- help="The data in the main view will be sorted by the selected metric. "
- "(F) := frame scores, (P) := prediction scores.",
- )
-
- @staticmethod
- def prediction_metric_in_sidebar_classifications():
- fixed_options = {"confidence": "Model Confidence"}
- column_names = list(state.get_state().predictions.metric_datas_classification.predictions.keys())
- state.get_state().predictions.metric_datas_classification.selected_prediction = st.selectbox(
- "Select metric for your predictions",
- column_names + list(fixed_options.keys()),
- format_func=lambda s: fixed_options.get(s, s),
- help="The data in the main view will be sorted by the selected metric. "
- "(F) := frame scores, (P) := prediction scores.",
- )
-
- @staticmethod
- def metric_details_description(metric_datas: MetricNames):
- metric_name = metric_datas.selected_prediction
- if not metric_name:
- return
-
- metric_data = metric_datas.predictions.get(metric_name)
-
- if not metric_data:
- metric_data = metric_datas.labels.get(metric_name)
-
- if metric_data:
- st.markdown(f"### The {metric_data.name[:-4]} metric")
- st.markdown(metric_data.meta.long_description)
diff --git a/src/encord_active/app/model_quality/sub_pages/false_negatives.py b/src/encord_active/app/model_quality/sub_pages/false_negatives.py
deleted file mode 100644
index f95d2ef62..000000000
--- a/src/encord_active/app/model_quality/sub_pages/false_negatives.py
+++ /dev/null
@@ -1,109 +0,0 @@
-from typing import Optional, cast
-
-import streamlit as st
-from pandera.typing import DataFrame
-from streamlit.delta_generator import DeltaGenerator
-
-from encord_active.app.common.components.prediction_grid import prediction_grid
-from encord_active.app.common.state import get_state
-from encord_active.lib.charts.histogram import get_histogram
-from encord_active.lib.common.colors import Color
-from encord_active.lib.metrics.utils import MetricScope
-from encord_active.lib.model_predictions.map_mar import (
- PerformanceMetricSchema,
- PrecisionRecallSchema,
-)
-from encord_active.lib.model_predictions.reader import (
- ClassificationPredictionMatchSchemaWithClassNames,
- LabelMatchSchema,
- PredictionMatchSchema,
-)
-
-from . import ModelQualityPage
-
-
-class FalseNegativesPage(ModelQualityPage):
- title = "🔍 False Negatives"
-
- def sidebar_options(self):
- metric_columns = list(get_state().predictions.metric_datas.labels.keys())
- get_state().predictions.metric_datas.selected_label = st.selectbox(
- "Select metric for your labels",
- metric_columns,
- help="The data in the main view will be sorted by the selected metric. "
- "(F) := frame scores, (O) := object scores.",
- )
- self.display_settings(MetricScope.MODEL_QUALITY)
-
- def sidebar_options_classifications(self):
- pass
-
- def _build_objects(
- self,
- object_model_predictions: DataFrame[PredictionMatchSchema],
- object_labels: DataFrame[LabelMatchSchema],
- ):
- metric_name = get_state().predictions.metric_datas.selected_label
- if not metric_name:
- st.error("Prediction label not selected")
- return
-
- with st.expander("Details"):
- color = Color.PURPLE
- st.markdown(
- f"""### The view
-These are the labels that were not matched with any predictions.
-
----
-**Color**:
-The {color.name.lower()} boxes mark the false negatives.
-That is, the labels that were not matched to any predictions.
-The remaining objects are predictions, where colors correspond to their predicted class (identical colors to labels objects in the editor).
- """,
- unsafe_allow_html=True,
- )
- self.metric_details_description(get_state().predictions.metric_datas)
- fns_df = object_labels[object_labels[LabelMatchSchema.is_false_negative]].dropna(subset=[metric_name])
- if fns_df.shape[0] == 0:
- st.write("No false negatives")
- else:
- histogram = get_histogram(fns_df, metric_name)
- st.altair_chart(histogram, use_container_width=True)
- prediction_grid(
- get_state().project_paths,
- labels=fns_df,
- model_predictions=object_model_predictions,
- box_color=color,
- )
-
- def build(
- self,
- object_predictions_exist: bool,
- classification_predictions_exist: bool,
- object_tab: DeltaGenerator,
- classification_tab: DeltaGenerator,
- object_model_predictions: Optional[DataFrame[PredictionMatchSchema]] = None,
- object_labels: Optional[DataFrame[LabelMatchSchema]] = None,
- object_metrics: Optional[DataFrame[PerformanceMetricSchema]] = None,
- object_precisions: Optional[DataFrame[PrecisionRecallSchema]] = None,
- classification_labels: Optional[list] = None,
- classification_pred: Optional[list] = None,
- classification_model_predictions_matched: Optional[
- DataFrame[ClassificationPredictionMatchSchemaWithClassNames]
- ] = None,
- ):
-
- with object_tab:
- if self.check_building_object_quality(
- object_predictions_exist, object_model_predictions, object_labels, object_metrics, object_precisions
- ):
- self._build_objects(
- cast(DataFrame[PredictionMatchSchema], object_model_predictions),
- cast(DataFrame[LabelMatchSchema], object_labels),
- )
-
- with classification_tab:
- st.markdown(
- "## False Negatives view for the classification predictions is not available\n"
- "Please use **Filter by class** field in True Positives page to inspect different classes."
- )
diff --git a/src/encord_active/app/model_quality/sub_pages/false_positives.py b/src/encord_active/app/model_quality/sub_pages/false_positives.py
deleted file mode 100644
index 7153f5711..000000000
--- a/src/encord_active/app/model_quality/sub_pages/false_positives.py
+++ /dev/null
@@ -1,142 +0,0 @@
-from typing import Optional, cast
-
-import streamlit as st
-from pandera.typing import DataFrame
-from streamlit.delta_generator import DeltaGenerator
-
-from encord_active.app.common.components.prediction_grid import (
- prediction_grid,
- prediction_grid_classifications,
-)
-from encord_active.app.common.state import get_state
-from encord_active.lib.charts.histogram import get_histogram
-from encord_active.lib.common.colors import Color
-from encord_active.lib.metrics.utils import MetricScope
-from encord_active.lib.model_predictions.map_mar import (
- PerformanceMetricSchema,
- PrecisionRecallSchema,
-)
-from encord_active.lib.model_predictions.reader import (
- ClassificationPredictionMatchSchemaWithClassNames,
- LabelMatchSchema,
- PredictionMatchSchema,
-)
-
-from . import ModelQualityPage
-
-
-class FalsePositivesPage(ModelQualityPage):
- title = "🌡 False Positives"
-
- def sidebar_options(self):
- self.prediction_metric_in_sidebar_objects()
- self.display_settings(MetricScope.MODEL_QUALITY)
-
- def sidebar_options_classifications(self):
- self.prediction_metric_in_sidebar_classifications()
-
- def _build_objects(
- self,
- object_model_predictions: DataFrame[PredictionMatchSchema],
- ):
- metric_name = get_state().predictions.metric_datas.selected_prediction
- if not metric_name:
- st.error("No prediction metric selected")
- return
-
- st.markdown(f"# {self.title}")
- color = Color.RED
- with st.expander("Details"):
- st.markdown(
- f"""### The view
-These are the predictions for which either of the following is true
-1. The IOU between the prediction and the best matching label was too low
-2. There was another prediction with higher model confidence which matched the label already
-3. The predicted class didn't match
-
----
-
-**Color**:
-The {color.name.lower()} boxes marks the false positive predictions.
-The remaining colors correspond to the dataset labels with the colors you are used to from the label editor.
- """,
- unsafe_allow_html=True,
- )
- self.metric_details_description(get_state().predictions.metric_datas)
-
- fp_df = object_model_predictions[
- object_model_predictions[PredictionMatchSchema.is_true_positive] == 0.0
- ].dropna(subset=[metric_name])
- if fp_df.shape[0] == 0:
- st.write("No false positives")
- else:
- histogram = get_histogram(fp_df, metric_name)
- st.altair_chart(histogram, use_container_width=True)
- prediction_grid(get_state().project_paths, model_predictions=fp_df, box_color=color)
-
- def _build_classifications(
- self,
- classification_model_predictions_matched: DataFrame[ClassificationPredictionMatchSchemaWithClassNames],
- ):
- with st.expander("Details"):
- st.markdown(
- """### The view
-These are the predictions where the model incorrectly predicts the positive class.
- """,
- unsafe_allow_html=True,
- )
- self.metric_details_description(get_state().predictions.metric_datas_classification)
-
- metric_name = get_state().predictions.metric_datas_classification.selected_prediction
- if not metric_name:
- st.error("No prediction metric selected")
- return
-
- fp_df = classification_model_predictions_matched[
- classification_model_predictions_matched[ClassificationPredictionMatchSchemaWithClassNames.is_true_positive]
- == 0.0
- ].dropna(subset=[metric_name])
-
- if fp_df.shape[0] == 0:
- st.write("No false positives")
- else:
- histogram = get_histogram(fp_df, metric_name)
- st.altair_chart(histogram, use_container_width=True)
- prediction_grid_classifications(get_state().project_paths, model_predictions=fp_df)
-
- def build(
- self,
- object_predictions_exist: bool,
- classification_predictions_exist: bool,
- object_tab: DeltaGenerator,
- classification_tab: DeltaGenerator,
- object_model_predictions: Optional[DataFrame[PredictionMatchSchema]] = None,
- object_labels: Optional[DataFrame[LabelMatchSchema]] = None,
- object_metrics: Optional[DataFrame[PerformanceMetricSchema]] = None,
- object_precisions: Optional[DataFrame[PrecisionRecallSchema]] = None,
- classification_labels: Optional[list] = None,
- classification_pred: Optional[list] = None,
- classification_model_predictions_matched: Optional[
- DataFrame[ClassificationPredictionMatchSchemaWithClassNames]
- ] = None,
- ):
-
- with object_tab:
- if self.check_building_object_quality(
- object_predictions_exist, object_model_predictions, object_labels, object_metrics, object_precisions
- ):
- self._build_objects(cast(DataFrame[PredictionMatchSchema], object_model_predictions))
-
- with classification_tab:
- if self.check_building_classification_quality(
- classification_predictions_exist,
- classification_labels,
- classification_pred,
- classification_model_predictions_matched,
- ):
- self._build_classifications(
- cast(
- DataFrame[ClassificationPredictionMatchSchemaWithClassNames],
- classification_model_predictions_matched,
- )
- )
diff --git a/src/encord_active/app/model_quality/sub_pages/metrics.py b/src/encord_active/app/model_quality/sub_pages/metrics.py
deleted file mode 100644
index 920551b7f..000000000
--- a/src/encord_active/app/model_quality/sub_pages/metrics.py
+++ /dev/null
@@ -1,240 +0,0 @@
-import json
-from typing import Optional, cast
-
-import streamlit as st
-from pandera.typing import DataFrame
-from streamlit.delta_generator import DeltaGenerator
-
-from encord_active.app.common.state import get_state
-from encord_active.lib.charts.classification_metrics import (
- get_accuracy,
- get_confusion_matrix,
- get_precision_recall_f1,
- get_precision_recall_graph,
-)
-from encord_active.lib.charts.metric_importance import create_metric_importance_charts
-from encord_active.lib.charts.precision_recall import create_pr_chart_plotly
-from encord_active.lib.model_predictions.map_mar import (
- PerformanceMetricSchema,
- PrecisionRecallSchema,
-)
-from encord_active.lib.model_predictions.reader import (
- ClassificationPredictionMatchSchemaWithClassNames,
- LabelMatchSchema,
- PredictionMatchSchema,
-)
-from encord_active.lib.model_predictions.writer import MainPredictionType
-
-from . import ModelQualityPage
-
-_M_COLS = PerformanceMetricSchema
-
-
-class MetricsPage(ModelQualityPage):
- title = "📈 Metrics"
-
- def sidebar_options(self):
- pass
-
- def sidebar_options_classifications(self):
- pass
-
- def _build_objects(
- self,
- object_model_predictions: DataFrame[PredictionMatchSchema],
- object_metrics: DataFrame[PerformanceMetricSchema],
- object_precisions: DataFrame[PrecisionRecallSchema],
- ):
-
- _map = object_metrics[object_metrics[_M_COLS.metric] == "mAP"]["value"].item()
- _mar = object_metrics[object_metrics[_M_COLS.metric] == "mAR"]["value"].item()
- col1, col2 = st.columns(2)
- col1.metric("mAP", f"{_map:.3f}")
- col2.metric("mAR", f"{_mar:.3f}")
- st.markdown("""---""")
-
- st.subheader("Metric Importance")
- with st.container():
- with st.expander("Description"):
- st.write(
- "The following charts show the dependency between model performance and each index. "
- "In other words, these charts answer the question of how much is model "
- "performance affected by each index. This relationship can be decomposed into two metrics:"
- )
- st.markdown(
- "- **Metric importance**: measures the *strength* of the dependency between and metric and model "
- "performance. A high value means that the model performance would be strongly affected by "
- "a change in the index. For example, a high importance in 'Brightness' implies that a change "
- "in that quantity would strongly affect model performance. Values range from 0 (no dependency) "
- "to 1 (perfect dependency, one can completely predict model performance simply by looking "
- "at this index)."
- )
- st.markdown(
- "- **Metric [correlation](https://en.wikipedia.org/wiki/Correlation)**: measures the *linearity "
- "and direction* of the dependency between an index and model performance. "
- "Crucially, this metric tells us whether a positive change in an index "
- "will lead to a positive change (positive correlation) or a negative change (negative correlation) "
- "in model performance . Values range from -1 to 1."
- )
- st.write(
- "Finally, you can also select how many samples are included in the computation "
- "with the slider, as well as filter by class with the dropdown in the side bar."
- )
-
- if (
- object_model_predictions.shape[0] > 60_000
- ): # Computation are heavy so allow computing for only a subset.
- num_samples = st.slider(
- "Number of samples",
- min_value=1,
- max_value=len(object_model_predictions),
- step=max(1, (len(object_model_predictions) - 1) // 100),
- value=max((len(object_model_predictions) - 1) // 2, 1),
- help="To avoid too heavy computations, we subsample the data at random to the selected size, "
- "computing importance values.",
- )
- if num_samples < 100:
- st.warning(
- "Number of samples is too low to compute reliable index importances. "
- "We recommend using at least 100 samples.",
- )
- else:
- num_samples = object_model_predictions.shape[0]
-
- metric_columns = list(get_state().predictions.metric_datas.predictions.keys())
- with st.spinner("Computing index importance..."):
- try:
- chart = create_metric_importance_charts(
- object_model_predictions,
- metric_columns=metric_columns,
- num_samples=num_samples,
- prediction_type=MainPredictionType.OBJECT,
- )
- st.altair_chart(chart, use_container_width=True)
- except ValueError as e:
- if e.args:
- st.info(e.args[0])
- else:
- st.info("Failed to compute metric importance")
-
- st.subheader("Subset selection scores")
- with st.container():
- project_ontology = json.loads(get_state().project_paths.ontology.read_text(encoding="utf-8"))
- chart = create_pr_chart_plotly(object_metrics, object_precisions, project_ontology["objects"])
- st.plotly_chart(chart, use_container_width=True)
-
- def _build_classifications(
- self,
- classification_labels: list,
- classification_pred: list,
- classification_model_predictions_matched: DataFrame[ClassificationPredictionMatchSchemaWithClassNames],
- ):
- class_names = sorted(list(set(classification_labels).union(classification_pred)))
-
- precision, recall, f1, support = get_precision_recall_f1(classification_labels, classification_pred)
- accuracy = get_accuracy(classification_labels, classification_pred)
-
- # PERFORMANCE METRICS SUMMARY
-
- col_acc, col_prec, col_rec, col_f1 = st.columns(4)
- col_acc.metric("Accuracy", f"{float(accuracy):.2f}")
- col_prec.metric(
- "Mean Precision",
- f"{float(precision.mean()):.2f}",
- help="Average of precision scores of all classes",
- )
- col_rec.metric("Mean Recall", f"{float(recall.mean()):.2f}", help="Average of recall scores of all classes")
- col_f1.metric("Mean F1", f"{float(f1.mean()):.2f}", help="Average of F1 scores of all classes")
-
- # METRIC IMPORTANCE
- if (
- classification_model_predictions_matched.shape[0] > 60_000
-        ):  # Computations are heavy, so allow computing on only a subset.
- num_samples = st.slider(
- "Number of samples",
- min_value=1,
- max_value=len(classification_model_predictions_matched),
- step=max(1, (len(classification_model_predictions_matched) - 1) // 100),
- value=max((len(classification_model_predictions_matched) - 1) // 2, 1),
-                help="To avoid overly heavy computations, we subsample the data at random to the selected size "
-                "before computing importance values.",
- )
- if num_samples < 100:
- st.warning(
- "Number of samples is too low to compute reliable index importances. "
- "We recommend using at least 100 samples.",
- )
- else:
- num_samples = classification_model_predictions_matched.shape[0]
-
- metric_columns = list(get_state().predictions.metric_datas_classification.predictions.keys())
- try:
- metric_importance_chart = create_metric_importance_charts(
- classification_model_predictions_matched,
- metric_columns=metric_columns,
- num_samples=num_samples,
- prediction_type=MainPredictionType.CLASSIFICATION,
- )
- st.altair_chart(metric_importance_chart, use_container_width=True)
- except ValueError as e:
- if e.args:
- st.info(e.args[0])
- else:
- st.info("Failed to compute metric importance")
-
- col1, col2 = st.columns(2)
-
- # CONFUSION MATRIX
- confusion_matrix = get_confusion_matrix(classification_labels, classification_pred, class_names)
- col1.plotly_chart(confusion_matrix, use_container_width=True)
-
- # PRECISION_RECALL BARS
- pr_graph = get_precision_recall_graph(precision, recall, class_names)
- col2.plotly_chart(pr_graph, use_container_width=True)
-
-        # In order to plot the ROC curve, we need confidences for the ground
-        # truth label. Currently, the predictions.pkl file only has a confidence
-        # value for the predicted class.
- # roc_graph = get_roc_curve(classification_labels, classification_pred)
-
- def build(
- self,
- object_predictions_exist: bool,
- classification_predictions_exist: bool,
- object_tab: DeltaGenerator,
- classification_tab: DeltaGenerator,
- object_model_predictions: Optional[DataFrame[PredictionMatchSchema]] = None,
- object_labels: Optional[DataFrame[LabelMatchSchema]] = None,
- object_metrics: Optional[DataFrame[PerformanceMetricSchema]] = None,
- object_precisions: Optional[DataFrame[PrecisionRecallSchema]] = None,
- classification_labels: Optional[list] = None,
- classification_pred: Optional[list] = None,
- classification_model_predictions_matched: Optional[
- DataFrame[ClassificationPredictionMatchSchemaWithClassNames]
- ] = None,
- ):
- with object_tab:
- if self.check_building_object_quality(
- object_predictions_exist, object_model_predictions, object_labels, object_metrics, object_precisions
- ):
- self._build_objects(
- cast(DataFrame[PredictionMatchSchema], object_model_predictions),
- cast(DataFrame[PerformanceMetricSchema], object_metrics),
- cast(DataFrame[PrecisionRecallSchema], object_precisions),
- )
-
- with classification_tab:
- if self.check_building_classification_quality(
- classification_predictions_exist,
- classification_labels,
- classification_pred,
- classification_model_predictions_matched,
- ):
- self._build_classifications(
- cast(list, classification_labels),
- cast(list, classification_pred),
- cast(
- DataFrame[ClassificationPredictionMatchSchemaWithClassNames],
- classification_model_predictions_matched,
- ),
- )
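The expander text removed above defines metric importance as the strength of the dependency between a metric and model performance (0 to 1) and metric correlation as its linearity and direction (-1 to 1). Purely as an illustration of those two quantities — not the actual `create_metric_importance_charts` implementation, and with hypothetical column names — a minimal sketch:

```python
import numpy as np
import pandas as pd
from sklearn.feature_selection import mutual_info_regression


def importance_and_correlation(df: pd.DataFrame, metric_columns, score_column: str) -> pd.DataFrame:
    """Per-metric importance (0..1) and Pearson correlation (-1..1) against a performance score."""
    rows = []
    y = df[score_column].to_numpy()
    for column in metric_columns:
        x = df[[column]].to_numpy()
        mi = mutual_info_regression(x, y, random_state=42)[0]  # strength of the (possibly non-linear) dependency
        importance = float(1.0 - np.exp(-2.0 * mi))  # squash mutual information into [0, 1)
        correlation = float(np.corrcoef(x[:, 0], y)[0, 1])  # linearity and direction of the dependency
        rows.append({"metric": column, "importance": importance, "correlation": correlation})
    return pd.DataFrame(rows)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    demo = pd.DataFrame({"Brightness (P)": rng.random(500), "Blur (P)": rng.random(500)})
    demo["model_score"] = 0.8 * demo["Brightness (P)"] + 0.2 * rng.random(500)
    print(importance_and_correlation(demo, ["Brightness (P)", "Blur (P)"], "model_score"))
```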
diff --git a/src/encord_active/app/model_quality/sub_pages/performance_by_metric.py b/src/encord_active/app/model_quality/sub_pages/performance_by_metric.py
deleted file mode 100644
index 180156acb..000000000
--- a/src/encord_active/app/model_quality/sub_pages/performance_by_metric.py
+++ /dev/null
@@ -1,239 +0,0 @@
-import re
-from typing import Optional, cast
-
-import altair as alt
-import streamlit as st
-from pandera.typing import DataFrame
-from streamlit.delta_generator import DeltaGenerator
-
-import encord_active.app.common.state as state
-from encord_active.app.common.state import MetricNames, PredictionsState, get_state
-from encord_active.lib.charts.performance_by_metric import performance_rate_by_metric
-from encord_active.lib.charts.scopes import PredictionMatchScope
-from encord_active.lib.model_predictions.map_mar import (
- PerformanceMetricSchema,
- PrecisionRecallSchema,
-)
-from encord_active.lib.model_predictions.reader import (
- ClassificationPredictionMatchSchemaWithClassNames,
- LabelMatchSchema,
- PredictionMatchSchema,
-)
-
-from . import ModelQualityPage
-
-FLOAT_FMT = ",.4f"
-PCT_FMT = ",.2f"
-COUNT_FMT = ",d"
-
-
-CHART_TITLES = {
- PredictionMatchScope.TRUE_POSITIVES: "True Positive Rate",
- PredictionMatchScope.FALSE_POSITIVES: "False Positive Rate",
- PredictionMatchScope.FALSE_NEGATIVES: "False Negative Rate",
-}
-
-
-class PerformanceMetric(ModelQualityPage):
- title = "⚡️ Performance by Metric"
-
- def description_expander(self, metric_datas: MetricNames):
- with st.expander("Details", expanded=False):
- st.markdown(
- """### The View
-
-On this page, your model scores are displayed as a function of the metric that you selected in the top bar.
-Samples are discretized into $n$ equally sized buckets and the middle point of each bucket is displayed as the x-value in the plots.
-Bars indicate the number of samples in each bucket, while lines indicate the true positive and false negative rates of each bucket.
-
-
-Metrics marked with (P) are metrics computed on your predictions.
-Metrics marked with (F) are frame-level metrics, which depend on the frame that each prediction is associated
-with. In the "False Negative Rate" plot, (O) means metrics computed on Object labels.
-
-For metrics that are computed on predictions (P) in the "True Positive Rate" plot, the corresponding "label metrics" (O/F) computed
-on your labels are used for the "False Negative Rate" plot.
-""",
- unsafe_allow_html=True,
- )
- self.metric_details_description(metric_datas)
-
- def sidebar_options(self):
- c1, c2, c3 = st.columns([4, 4, 3])
- with c1:
- self.prediction_metric_in_sidebar_objects()
-
- with c2:
- get_state().predictions.nbins = int(
- st.number_input(
- "Number of buckets (n)",
- min_value=5,
- max_value=200,
- value=PredictionsState.nbins,
-                    help="Choose the number of bins to discretize the prediction metric values into.",
- )
- )
- with c3:
- st.write("") # Make some spacing.
- st.write("")
- get_state().predictions.decompose_classes = st.checkbox(
- "Show class decomposition",
- value=PredictionsState.decompose_classes,
- help="When checked, every plot will have a separate component for each class.",
- )
-
- def sidebar_options_classifications(self):
- self.prediction_metric_in_sidebar_classifications()
-
- def _build_objects(
- self,
- object_model_predictions: DataFrame[PredictionMatchSchema],
- object_labels: DataFrame[LabelMatchSchema],
- ):
-
- if object_model_predictions.shape[0] == 0:
- st.write("No predictions of the given class(es).")
- return
-
- metric_name = state.get_state().predictions.metric_datas.selected_prediction
- if not metric_name:
-            # This shouldn't happen with the current flow. The only way a user can end up here
-            # is by writing custom code that bypasses running the metrics. In that case,
-            # it is fair not to give more information than this.
-            st.write(
-                "No metrics computed for your model predictions. "
-                "With `encord-active import predictions /path/to/predictions.pkl`, "
-                "Encord Active will automatically compute the metrics."
- )
- return
-
- self.description_expander(get_state().predictions.metric_datas)
-
- label_metric_name = metric_name
- if metric_name[-3:] == "(P)": # Replace the P with O: "Metric (P)" -> "Metric (O)"
- label_metric_name = re.sub(r"(.*?\()P(\))", r"\1O\2", metric_name)
-
- if not label_metric_name in object_labels.columns:
- label_metric_name = re.sub(
- r"(.*?\()O(\))", r"\1F\2", label_metric_name
- ) # Look for it in frame label metrics.
-
- classes_for_coloring = ["Average"]
- decompose_classes = get_state().predictions.decompose_classes
- if decompose_classes:
- unique_classes = set(object_model_predictions["class_name"].unique()).union(
- object_labels["class_name"].unique()
- )
- classes_for_coloring += sorted(list(unique_classes))
-
- # Ensure same colors between plots
- chart_args = dict(
- color_params={"scale": alt.Scale(domain=classes_for_coloring)},
- bins=get_state().predictions.nbins,
- show_decomposition=decompose_classes,
- )
-
- try:
- tpr = performance_rate_by_metric(
- object_model_predictions, metric_name, scope=PredictionMatchScope.TRUE_POSITIVES, **chart_args
- )
- if tpr is not None:
- st.altair_chart(tpr.interactive(), use_container_width=True)
-        except Exception:
-            pass  # Rendering this chart is best-effort; skip the plot on failure.
-
- try:
- fnr = performance_rate_by_metric(
- object_labels, label_metric_name, scope=PredictionMatchScope.FALSE_NEGATIVES, **chart_args
- )
- if fnr is not None:
- st.altair_chart(fnr.interactive(), use_container_width=True)
-        except Exception:
-            pass  # Rendering this chart is best-effort; skip the plot on failure.
-
- def _build_classifications(
- self,
- classification_model_predictions_matched: DataFrame[ClassificationPredictionMatchSchemaWithClassNames],
- ):
- if classification_model_predictions_matched.shape[0] == 0:
- st.write("No predictions of the given class(es).")
- return
-
- metric_name = state.get_state().predictions.metric_datas_classification.selected_prediction
- if not metric_name:
-            # This shouldn't happen with the current flow. The only way a user can end up here
-            # is by writing custom code that bypasses running the metrics. In that case,
-            # it is fair not to give more information than this.
-            st.write(
-                "No metrics computed for your model predictions. "
-                "With `encord-active import predictions /path/to/predictions.pkl`, "
-                "Encord Active will automatically compute the metrics."
- )
- return
-
- self.description_expander(get_state().predictions.metric_datas_classification)
-
- classes_for_coloring = ["Average"]
- decompose_classes = get_state().predictions.decompose_classes
- if decompose_classes:
- unique_classes = set(classification_model_predictions_matched["class_name"].unique())
- classes_for_coloring += sorted(list(unique_classes))
-
- # Ensure same colors between plots
- chart_args = dict(
- color_params={"scale": alt.Scale(domain=classes_for_coloring)},
- bins=get_state().predictions.nbins,
- show_decomposition=decompose_classes,
- )
-
- try:
- tpr = performance_rate_by_metric(
- classification_model_predictions_matched,
- metric_name,
- scope=PredictionMatchScope.TRUE_POSITIVES,
- **chart_args,
- )
- if tpr is not None:
- st.altair_chart(tpr.interactive(), use_container_width=True)
-        except Exception:
-            pass  # Rendering this chart is best-effort; skip the plot on failure.
-
- def build(
- self,
- object_predictions_exist: bool,
- classification_predictions_exist: bool,
- object_tab: DeltaGenerator,
- classification_tab: DeltaGenerator,
- object_model_predictions: Optional[DataFrame[PredictionMatchSchema]] = None,
- object_labels: Optional[DataFrame[LabelMatchSchema]] = None,
- object_metrics: Optional[DataFrame[PerformanceMetricSchema]] = None,
- object_precisions: Optional[DataFrame[PrecisionRecallSchema]] = None,
- classification_labels: Optional[list] = None,
- classification_pred: Optional[list] = None,
- classification_model_predictions_matched: Optional[
- DataFrame[ClassificationPredictionMatchSchemaWithClassNames]
- ] = None,
- ):
-
- with object_tab:
- if self.check_building_object_quality(
- object_predictions_exist, object_model_predictions, object_labels, object_metrics, object_precisions
- ):
- self._build_objects(
- cast(DataFrame[PredictionMatchSchema], object_model_predictions),
- cast(DataFrame[LabelMatchSchema], object_labels),
- )
-
- with classification_tab:
- if self.check_building_classification_quality(
- classification_predictions_exist,
- classification_labels,
- classification_pred,
- classification_model_predictions_matched,
- ):
- self._build_classifications(
- cast(
- DataFrame[ClassificationPredictionMatchSchemaWithClassNames],
- classification_model_predictions_matched,
- )
- )
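The page removed above buckets samples into $n$ equally sized ranges of the selected metric, plots the per-bucket true-positive and false-negative rates, and remaps prediction metric names from `(P)` to the matching label metrics `(O)`/`(F)`. A minimal sketch of that bucketing, assuming equal-width bins and a frame with an `is_true_positive` column as in `PredictionMatchSchema`; the real charts are produced by `performance_rate_by_metric`:

```python
import re
from typing import List

import pandas as pd


def true_positive_rate_by_bucket(df: pd.DataFrame, metric_name: str, nbins: int = 50) -> pd.DataFrame:
    """Discretize `metric_name` into `nbins` equal-width buckets and aggregate per bucket."""
    buckets = pd.cut(df[metric_name], bins=nbins)
    grouped = df.groupby(buckets, observed=True)["is_true_positive"]
    out = pd.DataFrame({"tpr": grouped.mean(), "count": grouped.size()})
    out["metric_mid"] = [interval.mid for interval in out.index]  # x-value: the middle point of each bucket
    return out.reset_index(drop=True)


def label_metric_name_for(prediction_metric: str, label_columns: List[str]) -> str:
    """Map 'Metric (P)' to the matching label metric 'Metric (O)', falling back to 'Metric (F)'."""
    label_metric = re.sub(r"(.*?\()P(\))", r"\1O\2", prediction_metric)
    if label_metric not in label_columns:
        label_metric = re.sub(r"(.*?\()O(\))", r"\1F\2", label_metric)
    return label_metric
```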
diff --git a/src/encord_active/app/model_quality/sub_pages/true_positives.py b/src/encord_active/app/model_quality/sub_pages/true_positives.py
deleted file mode 100644
index 9df9561f7..000000000
--- a/src/encord_active/app/model_quality/sub_pages/true_positives.py
+++ /dev/null
@@ -1,138 +0,0 @@
-from typing import Optional, cast
-
-import streamlit as st
-from pandera.typing import DataFrame
-from streamlit.delta_generator import DeltaGenerator
-
-from encord_active.app.common.components.prediction_grid import (
- prediction_grid,
- prediction_grid_classifications,
-)
-from encord_active.app.common.state import get_state
-from encord_active.lib.charts.histogram import get_histogram
-from encord_active.lib.common.colors import Color
-from encord_active.lib.metrics.utils import MetricScope
-from encord_active.lib.model_predictions.map_mar import (
- PerformanceMetricSchema,
- PrecisionRecallSchema,
-)
-from encord_active.lib.model_predictions.reader import (
- ClassificationPredictionMatchSchemaWithClassNames,
- LabelMatchSchema,
- PredictionMatchSchema,
-)
-
-from . import ModelQualityPage
-
-
-class TruePositivesPage(ModelQualityPage):
- title = "✅ True Positives"
-
- def sidebar_options(self):
- self.prediction_metric_in_sidebar_objects()
- self.display_settings(MetricScope.MODEL_QUALITY)
-
- def sidebar_options_classifications(self):
- self.prediction_metric_in_sidebar_classifications()
-
- def _build_objects(
- self,
- object_model_predictions: DataFrame[PredictionMatchSchema],
- ):
- with st.expander("Details"):
- color = Color.PURPLE
- st.markdown(
- f"""### The view
-These are the predictions for which the IOU was sufficiently high and the confidence score was
-the highest amongst predictions that overlap with the label.
-
----
-
-**Color**:
-The {color.name.lower()} boxes mark the true positive predictions.
-The remaining colors correspond to the dataset labels with the colors you are used to from the label editor.
- """,
- unsafe_allow_html=True,
- )
- self.metric_details_description(get_state().predictions.metric_datas)
-
- metric_name = get_state().predictions.metric_datas.selected_prediction
- if not metric_name:
- st.error("No prediction metric selected")
- return
-
- tp_df = object_model_predictions[
- object_model_predictions[PredictionMatchSchema.is_true_positive] == 1.0
- ].dropna(subset=[metric_name])
- if tp_df.shape[0] == 0:
- st.write("No true positives")
- else:
- histogram = get_histogram(tp_df, metric_name)
- st.altair_chart(histogram, use_container_width=True)
- prediction_grid(get_state().project_paths, model_predictions=tp_df, box_color=color)
-
- def _build_classifications(
- self,
- classification_model_predictions_matched: DataFrame[ClassificationPredictionMatchSchemaWithClassNames],
- ):
- with st.expander("Details"):
- st.markdown(
- """### The view
-These are the predictions where the model correctly predicts the true class.
- """,
- unsafe_allow_html=True,
- )
- self.metric_details_description(get_state().predictions.metric_datas_classification)
-
- metric_name = get_state().predictions.metric_datas_classification.selected_prediction
- if not metric_name:
- st.error("No prediction metric selected")
- return
-
- tp_df = classification_model_predictions_matched[
- classification_model_predictions_matched[ClassificationPredictionMatchSchemaWithClassNames.is_true_positive]
- == 1.0
- ].dropna(subset=[metric_name])
-
- if tp_df.shape[0] == 0:
- st.write("No true positives")
- else:
- histogram = get_histogram(tp_df, metric_name)
- st.altair_chart(histogram, use_container_width=True)
- prediction_grid_classifications(get_state().project_paths, model_predictions=tp_df)
-
- def build(
- self,
- object_predictions_exist: bool,
- classification_predictions_exist: bool,
- object_tab: DeltaGenerator,
- classification_tab: DeltaGenerator,
- object_model_predictions: Optional[DataFrame[PredictionMatchSchema]] = None,
- object_labels: Optional[DataFrame[LabelMatchSchema]] = None,
- object_metrics: Optional[DataFrame[PerformanceMetricSchema]] = None,
- object_precisions: Optional[DataFrame[PrecisionRecallSchema]] = None,
- classification_labels: Optional[list] = None,
- classification_pred: Optional[list] = None,
- classification_model_predictions_matched: Optional[
- DataFrame[ClassificationPredictionMatchSchemaWithClassNames]
- ] = None,
- ):
- with object_tab:
- if self.check_building_object_quality(
- object_predictions_exist, object_model_predictions, object_labels, object_metrics, object_precisions
- ):
- self._build_objects(cast(DataFrame[PredictionMatchSchema], object_model_predictions))
-
- with classification_tab:
- if self.check_building_classification_quality(
- classification_predictions_exist,
- classification_labels,
- classification_pred,
- classification_model_predictions_matched,
- ):
- self._build_classifications(
- cast(
- DataFrame[ClassificationPredictionMatchSchemaWithClassNames],
- classification_model_predictions_matched,
- )
- )
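The "Details" text removed above describes a true positive as a prediction whose IoU with a label is sufficiently high and whose confidence is the highest among the predictions overlapping that label. A generic greedy matcher illustrating that rule — not Encord Active's own matcher; the `Box` dataclass and the single-class assumption are simplifications:

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class Box:
    x1: float
    y1: float
    x2: float
    y2: float


def iou(a: Box, b: Box) -> float:
    ix1, iy1 = max(a.x1, b.x1), max(a.y1, b.y1)
    ix2, iy2 = min(a.x2, b.x2), min(a.y2, b.y2)
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    union = (a.x2 - a.x1) * (a.y2 - a.y1) + (b.x2 - b.x1) * (b.y2 - b.y1) - inter
    return inter / union if union > 0 else 0.0


def match_true_positives(
    predictions: List[Tuple[Box, float]], labels: List[Box], iou_threshold: float = 0.5
) -> List[bool]:
    """Flag each (box, confidence) prediction as a true positive; every label can be claimed at most once."""
    claimed = [False] * len(labels)
    flags = [False] * len(predictions)
    # Highest-confidence predictions get first pick of the labels.
    for pred_idx, (box, _confidence) in sorted(enumerate(predictions), key=lambda item: item[1][1], reverse=True):
        best_iou, best_label = 0.0, None
        for label_idx, label_box in enumerate(labels):
            if claimed[label_idx]:
                continue
            overlap = iou(box, label_box)
            if overlap > best_iou:
                best_iou, best_label = overlap, label_idx
        if best_label is not None and best_iou >= iou_threshold:
            claimed[best_label] = True
            flags[pred_idx] = True
    return flags
```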
diff --git a/src/encord_active/app/page_menu.py b/src/encord_active/app/page_menu.py
index a341bf939..42f8f43cc 100644
--- a/src/encord_active/app/page_menu.py
+++ b/src/encord_active/app/page_menu.py
@@ -14,13 +14,7 @@
from encord_active.app.actions_page.versioning import is_latest, version_form
from encord_active.app.common.state import get_state, refresh
from encord_active.app.common.state_hooks import UseState
-from encord_active.app.model_quality.sub_pages.false_negatives import FalseNegativesPage
-from encord_active.app.model_quality.sub_pages.false_positives import FalsePositivesPage
-from encord_active.app.model_quality.sub_pages.metrics import MetricsPage
-from encord_active.app.model_quality.sub_pages.performance_by_metric import (
- PerformanceMetric,
-)
-from encord_active.app.model_quality.sub_pages.true_positives import TruePositivesPage
+from encord_active.app.model_quality.prediction_type_builder import ModelQualityPage
from encord_active.app.views.metrics import explorer, summary
from encord_active.app.views.model_quality import model_quality
from encord_active.lib.metrics.utils import MetricScope
@@ -29,11 +23,9 @@
"Data Quality": {"Summary": summary(MetricScope.DATA_QUALITY), "Explorer": explorer(MetricScope.DATA_QUALITY)},
"Label Quality": {"Summary": summary(MetricScope.LABEL_QUALITY), "Explorer": explorer(MetricScope.LABEL_QUALITY)},
"Model Quality": {
- "Metrics": model_quality(MetricsPage()),
- "Performance By Metric": model_quality(PerformanceMetric()),
- "True Positives": model_quality(TruePositivesPage()),
- "False Positives": model_quality(FalsePositivesPage()),
- "False Negatives": model_quality(FalseNegativesPage()),
+ "Metrics": model_quality(ModelQualityPage.METRICS),
+ "Performance By Metric": model_quality(ModelQualityPage.PERFORMANCE_BY_METRIC),
+ "Explorer": model_quality(ModelQualityPage.EXPLORER),
},
"Actions": {"Filter & Export": export_filter, "Balance & Export": export_balance, "Versioning": version_form},
}
diff --git a/src/encord_active/app/views/model_quality.py b/src/encord_active/app/views/model_quality.py
index f2daae51b..468cd91b6 100644
--- a/src/encord_active/app/views/model_quality.py
+++ b/src/encord_active/app/views/model_quality.py
@@ -1,237 +1,43 @@
-from pathlib import Path
+from typing import List
import streamlit as st
-import encord_active.lib.model_predictions.reader as reader
from encord_active.app.common.components import sticky_header
from encord_active.app.common.components.tags.tag_creator import tag_creator
-from encord_active.app.common.state import MetricNames, get_state
-from encord_active.app.common.state_hooks import use_memo
from encord_active.app.common.utils import setup_page
-from encord_active.app.model_quality.settings import (
- common_settings_classifications,
- common_settings_objects,
+from encord_active.app.model_quality.prediction_type_builder import (
+ ModelQualityPage,
+ PredictionTypeBuilder,
)
-from encord_active.app.model_quality.sub_pages import Page
-from encord_active.lib.model_predictions.classification_metrics import (
- match_predictions_and_labels,
+from encord_active.app.model_quality.prediction_types.classification_type_builder import (
+ ClassificationTypeBuilder,
)
-from encord_active.lib.model_predictions.filters import (
- filter_labels_for_frames_wo_predictions,
- prediction_and_label_filtering,
- prediction_and_label_filtering_classification,
+from encord_active.app.model_quality.prediction_types.object_type_builder import (
+ ObjectTypeBuilder,
)
-from encord_active.lib.model_predictions.map_mar import compute_mAP_and_mAR
-from encord_active.lib.model_predictions.reader import (
- ClassificationLabelSchema,
- ClassificationPredictionMatchSchemaWithClassNames,
- ClassificationPredictionSchema,
-)
-from encord_active.lib.model_predictions.writer import MainPredictionType
-
-
-def model_quality(page: Page):
- def _get_object_model_quality_data(metrics_dir: Path):
- if not reader.check_model_prediction_availability(
- get_state().project_paths.predictions / MainPredictionType.OBJECT.value
- ):
- return False, None, None, None, None
-
- predictions_dir = get_state().project_paths.predictions / MainPredictionType.OBJECT.value
-
- predictions_metric_datas, _ = use_memo(lambda: reader.get_prediction_metric_data(predictions_dir, metrics_dir))
- label_metric_datas, _ = use_memo(lambda: reader.get_label_metric_data(metrics_dir))
- model_predictions, _ = use_memo(
- lambda: reader.get_model_predictions(predictions_dir, predictions_metric_datas, MainPredictionType.OBJECT)
- )
- labels, _ = use_memo(lambda: reader.get_labels(predictions_dir, label_metric_datas, MainPredictionType.OBJECT))
-
- if model_predictions is None:
- st.error("Couldn't load model predictions")
- return
-
- if labels is None:
- st.error("Couldn't load labels properly")
- return
-
- matched_gt, _ = use_memo(lambda: reader.get_gt_matched(predictions_dir))
- get_state().predictions.metric_datas = MetricNames(
- predictions={m.name: m for m in predictions_metric_datas},
- labels={m.name: m for m in label_metric_datas},
- )
-
- if not matched_gt:
- st.error("Couldn't match ground truths")
- return
-
- with sticky_header():
- common_settings_objects()
- page.sidebar_options()
-
- (predictions_filtered, labels_filtered, metrics, precisions,) = compute_mAP_and_mAR(
- model_predictions,
- labels,
- matched_gt,
- get_state().predictions.all_classes_objects,
- iou_threshold=get_state().iou_threshold,
- ignore_unmatched_frames=get_state().ignore_frames_without_predictions,
- )
-
- # Sort predictions and labels according to selected metrics.
- pred_sort_column = get_state().predictions.metric_datas.selected_prediction or predictions_metric_datas[0].name
- sorted_model_predictions = predictions_filtered.sort_values([pred_sort_column], axis=0)
-
- label_sort_column = get_state().predictions.metric_datas.selected_label or label_metric_datas[0].name
- sorted_labels = labels_filtered.sort_values([label_sort_column], axis=0)
-
- if get_state().ignore_frames_without_predictions:
- labels_filtered = filter_labels_for_frames_wo_predictions(predictions_filtered, sorted_labels)
- else:
- labels_filtered = sorted_labels
-
- object_labels, object_metrics, object_model_pred, object_precisions = prediction_and_label_filtering(
- get_state().predictions.selected_classes_objects,
- labels_filtered,
- metrics,
- sorted_model_predictions,
- precisions,
- )
- return True, object_labels, object_metrics, object_model_pred, object_precisions
- def _get_classification_model_quality_data(metrics_dir: Path):
- if not reader.check_model_prediction_availability(
- get_state().project_paths.predictions / MainPredictionType.CLASSIFICATION.value
- ):
- return False, None, None, None
-
- predictions_dir_classification = get_state().project_paths.predictions / MainPredictionType.CLASSIFICATION.value
-
- predictions = reader.get_classification_predictions(
- get_state().project_paths.predictions / MainPredictionType.CLASSIFICATION.value
- )
- labels = reader.get_classification_labels(
- get_state().project_paths.predictions / MainPredictionType.CLASSIFICATION.value
- )
-
- predictions_metric_datas, _ = use_memo(
- lambda: reader.get_prediction_metric_data(predictions_dir_classification, metrics_dir)
- )
- label_metric_datas, _ = use_memo(lambda: reader.get_label_metric_data(metrics_dir))
- model_predictions, _ = use_memo(
- lambda: reader.get_model_predictions(
- predictions_dir_classification, predictions_metric_datas, MainPredictionType.CLASSIFICATION
- )
- )
- labels_all, _ = use_memo(
- lambda: reader.get_labels(
- predictions_dir_classification, label_metric_datas, MainPredictionType.CLASSIFICATION
- )
- )
-
- get_state().predictions.metric_datas_classification = MetricNames(
- predictions={m.name: m for m in predictions_metric_datas},
- )
-
- if model_predictions is None:
- st.error("Couldn't load model predictions")
- return
-
- if labels_all is None:
- st.error("Couldn't load labels properly")
- return
-
- with sticky_header():
- common_settings_classifications()
- page.sidebar_options_classifications()
-
- model_predictions_matched = match_predictions_and_labels(model_predictions, labels_all)
-
- (
- labels_filtered,
- predictions_filtered,
- model_predictions_matched_filtered,
- ) = prediction_and_label_filtering_classification(
- get_state().predictions.selected_classes_classifications,
- get_state().predictions.all_classes_classifications,
- labels,
- predictions,
- model_predictions_matched,
- )
-
- img_id_intersection = list(
- set(labels_filtered[ClassificationLabelSchema.img_id]).intersection(
- set(predictions_filtered[ClassificationPredictionSchema.img_id])
- )
- )
- labels_filtered_intersection = labels_filtered[
- labels_filtered[ClassificationLabelSchema.img_id].isin(img_id_intersection)
- ]
- predictions_filtered_intersection = predictions_filtered[
- predictions_filtered[ClassificationPredictionSchema.img_id].isin(img_id_intersection)
- ]
-
- y_true, y_pred = (
- list(labels_filtered_intersection[ClassificationLabelSchema.class_id]),
- list(predictions_filtered_intersection[ClassificationPredictionSchema.class_id]),
- )
-
- return (
- True,
- y_true,
- y_pred,
- model_predictions_matched_filtered.copy()[
- model_predictions_matched_filtered[ClassificationPredictionMatchSchemaWithClassNames.img_id].isin(
- img_id_intersection
- )
- ],
- )
+def model_quality(page_mode: ModelQualityPage):
+ def get_available_predictions() -> List[PredictionTypeBuilder]:
+ builders: List[PredictionTypeBuilder] = [ClassificationTypeBuilder(), ObjectTypeBuilder()]
+ return [b for b in builders if b.is_available()]
def render():
setup_page()
- tag_creator()
-
- """
- Note: Streamlit tabs should be initialized here for two reasons:
-        1. The selected tab stays the same across the Model Quality pages. If tabs were created inside the
-        individual pages, the selected tab would be reset on every click.
-        2. The common sticky top bar includes filtering, so its results should be reflected on every page. Otherwise,
-        it would be hard to get filters from other pages.
- """
-
- object_tab, classification_tab = st.tabs(["Objects", "Classifications"])
- metrics_dir = get_state().project_paths.metrics
+ with sticky_header():
+ tag_creator()
- with object_tab:
- (
- object_predictions_exist,
- object_labels,
- object_metrics,
- object_model_pred,
- object_precisions,
- ) = _get_object_model_quality_data(metrics_dir)
+ available_predictions: List[PredictionTypeBuilder] = get_available_predictions()
- with classification_tab:
+ if not available_predictions:
+ st.markdown("## No predictions imported into this project.")
+ return
- (
- classification_predictions_exist,
- classification_labels,
- classification_pred,
- classification_model_predictions_matched,
- ) = _get_classification_model_quality_data(metrics_dir)
+ tab_names = [m.title for m in available_predictions]
+ tabs = st.tabs(tab_names) if len(available_predictions) > 1 else [st.container()]
- page.build(
- object_predictions_exist,
- classification_predictions_exist,
- object_tab,
- classification_tab,
- object_model_predictions=object_model_pred,
- object_labels=object_labels,
- object_metrics=object_metrics,
- object_precisions=object_precisions,
- classification_labels=classification_labels,
- classification_pred=classification_pred,
- classification_model_predictions_matched=classification_model_predictions_matched,
- )
+ for tab, builder in zip(tabs, available_predictions):
+ with tab:
+ builder.build(page_mode)
return render
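For reference, `render()` above relies on just three things from each builder: a `title`, an `is_available()` check, and a `build(page_mode)` method. A standalone stub sketching that contract — it does not reproduce the real `ClassificationTypeBuilder`/`ObjectTypeBuilder`, and the enum values are illustrative:

```python
from enum import Enum, auto

import streamlit as st


class ModelQualityPage(Enum):  # member names as wired up in page_menu.py; the values here are illustrative
    METRICS = auto()
    PERFORMANCE_BY_METRIC = auto()
    EXPLORER = auto()


class DummyTypeBuilder:
    """Stand-in for a prediction type builder; mirrors only the calls made by model_quality()."""

    title = "Dummy predictions"

    def is_available(self) -> bool:
        # The real builders would check whether predictions of this type were imported into the project.
        return True

    def build(self, page_mode: ModelQualityPage) -> None:
        if page_mode is ModelQualityPage.METRICS:
            st.write("Summary metrics for this prediction type")
        elif page_mode is ModelQualityPage.PERFORMANCE_BY_METRIC:
            st.write("Performance-by-metric charts for this prediction type")
        else:
            st.write("Explorer view for this prediction type")
```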
diff --git a/src/encord_active/lib/model_predictions/reader.py b/src/encord_active/lib/model_predictions/reader.py
index a37c21a77..34c4f0346 100644
--- a/src/encord_active/lib/model_predictions/reader.py
+++ b/src/encord_active/lib/model_predictions/reader.py
@@ -1,6 +1,6 @@
from dataclasses import dataclass
from pathlib import Path
-from typing import Any, Callable, Iterable, List, Optional, TypedDict, cast
+from typing import Any, Callable, Iterable, List, Optional, TypedDict, Union, cast
import pandas as pd
import pandera as pa
@@ -170,7 +170,7 @@ def get_prediction_metric_data(predictions_dir: Path, metrics_dir: Path) -> List
def get_model_predictions(
predictions_dir: Path, metric_data: List[MetricData], prediction_type: MainPredictionType
-) -> Optional[DataFrame[PredictionSchema]]:
+) -> Union[DataFrame[PredictionSchema], DataFrame[ClassificationPredictionSchema], None]:
df = _load_csv_and_merge_metrics(predictions_dir / "predictions.csv", metric_data)
if prediction_type == MainPredictionType.CLASSIFICATION:
@@ -194,7 +194,7 @@ def get_label_metric_data(metrics_dir: Path) -> List[MetricData]:
def get_labels(
predictions_dir: Path, metric_data: List[MetricData], prediction_type: MainPredictionType
-) -> Optional[DataFrame[LabelSchema]]:
+) -> Union[DataFrame[LabelSchema], DataFrame[ClassificationLabelSchema], None]:
df = _load_csv_and_merge_metrics(predictions_dir / "labels.csv", metric_data)
if prediction_type == MainPredictionType.OBJECT: