diff --git a/src/encord_active/app/common/state.py b/src/encord_active/app/common/state.py index 3d4e91f68..c574e5e85 100644 --- a/src/encord_active/app/common/state.py +++ b/src/encord_active/app/common/state.py @@ -43,7 +43,7 @@ class PredictionsState: all_classes_objects: Dict[str, OntologyObjectJSON] = field(default_factory=dict) all_classes_classifications: Dict[str, OntologyClassificationJSON] = field(default_factory=dict) selected_classes_objects: Dict[str, OntologyObjectJSON] = field(default_factory=dict) - selected_classes_classifications: Dict[str, OntologyObjectJSON] = field(default_factory=dict) + selected_classes_classifications: Dict[str, OntologyClassificationJSON] = field(default_factory=dict) labels: Optional[DataFrame[LabelSchema]] = None nbins: int = 50 diff --git a/src/encord_active/app/model_quality/prediction_type_builder.py b/src/encord_active/app/model_quality/prediction_type_builder.py new file mode 100644 index 000000000..b82dd98b4 --- /dev/null +++ b/src/encord_active/app/model_quality/prediction_type_builder.py @@ -0,0 +1,297 @@ +from abc import abstractmethod +from enum import Enum +from typing import Dict, List, Tuple, Union + +import streamlit as st +from pandera.typing import DataFrame + +import encord_active.lib.model_predictions.reader as reader +from encord_active.app.common.page import Page +from encord_active.app.common.state import MetricNames, PredictionsState, get_state +from encord_active.app.common.state_hooks import use_memo +from encord_active.lib.charts.metric_importance import create_metric_importance_charts +from encord_active.lib.metrics.utils import MetricData +from encord_active.lib.model_predictions.reader import ( + ClassificationLabelSchema, + ClassificationPredictionMatchSchemaWithClassNames, + ClassificationPredictionSchema, + LabelSchema, + OntologyObjectJSON, + PredictionMatchSchema, + PredictionSchema, +) +from encord_active.lib.model_predictions.writer import ( + MainPredictionType, + OntologyClassificationJSON, +) + + +class ModelQualityPage(str, Enum): + METRICS = "metrics" + PERFORMANCE_BY_METRIC = "performance by metric" + EXPLORER = "explorer" + + +class MetricType(str, Enum): + PREDICTION = "prediction" + LABEL = "label" + + +class PredictionTypeBuilder(Page): + def sidebar_options(self, *args, **kwargs): + pass + + def _render_class_filtering_component( + self, all_classes: Union[Dict[str, OntologyObjectJSON], Dict[str, OntologyClassificationJSON]] + ): + return st.multiselect( + "Filter by class", + list(all_classes.items()), + format_func=lambda x: x[1]["name"], + help=""" +With this selection, you can choose which classes to include in the performance metrics calculations. +This acts as a filter, i.e. when nothing is selected all classes are included. +Performance metrics will be automatically updated according to the chosen classes. + """, + ) + + def _description_expander(self, metric_datas: MetricNames): + with st.expander("Details", expanded=False): + st.markdown( + """### The View + +On this page, your model scores are displayed as a function of the metric that you selected in the top bar. +Samples are discritized into $n$ equally sized buckets and the middle point of each bucket is displayed as the x-value +in the plots. Bars indicate the number of samples in each bucket, while lines indicate the true positive and false +negative rates of each bucket. + +Metrics marked with (P) are metrics computed on your predictions. 
+Metrics marked with (F) are frame level metrics, which depends on the frame that each prediction is associated +with. In the "False Negative Rate" plot, (O) means metrics computed on Object labels. + +For metrics that are computed on predictions (P) in the "True Positive Rate" plot, the corresponding "label metrics" +(O/F) computed on your labels are used for the "False Negative Rate" plot. +""", + unsafe_allow_html=True, + ) + self._metric_details_description(metric_datas) + + @staticmethod + def _metric_details_description(metric_datas: MetricNames): + metric_name = metric_datas.selected_prediction + if not metric_name: + return + + metric_data = metric_datas.predictions.get(metric_name) + + if not metric_data: + metric_data = metric_datas.labels.get(metric_name) + + if metric_data: + st.markdown(f"### The {metric_data.name[:-4]} metric") + st.markdown(metric_data.meta.long_description) + + def _set_binning(self): + get_state().predictions.nbins = int( + st.number_input( + "Number of buckets (n)", + min_value=5, + max_value=200, + value=PredictionsState.nbins, + help="Choose the number of bins to discritize the prediction metric values into.", + ) + ) + + def _set_sampling_for_metric_importance( + self, + model_predictions: Union[ + DataFrame[ClassificationPredictionMatchSchemaWithClassNames], DataFrame[PredictionMatchSchema] + ], + ) -> int: + num_samples = st.slider( + "Number of samples", + min_value=1, + max_value=len(model_predictions), + step=max(1, (len(model_predictions) - 1) // 100), + value=max((len(model_predictions) - 1) // 2, 1), + help="To avoid too heavy computations, we subsample the data at random to the selected size, " + "computing importance values.", + ) + if num_samples < 100: + st.warning( + "Number of samples is too low to compute reliable index importance. " + "We recommend using at least 100 samples.", + ) + + return num_samples + + def _class_decomposition(self): + st.write("") # Make some spacing. + st.write("") + get_state().predictions.decompose_classes = st.checkbox( + "Show class decomposition", + value=PredictionsState.decompose_classes, + help="When checked, every plot will have a separate component for each class.", + ) + + def _topbar_metric_selection_component(self, metric_type: MetricType, metric_datas: MetricNames): + """ + Note: Adding the fixed options "confidence" and "iou" works here because + confidence is required on import and IOU is computed during prediction + import. So the two columns are already available in the + `st.session_state.model_predictions` data frame. + """ + if metric_type == MetricType.LABEL: + column_names = list(metric_datas.predictions.keys()) + metric_datas.selected_label = st.selectbox( + "Select metric for your labels", + column_names, + help="The data in the main view will be sorted by the selected metric. " + "(F) := frame scores, (O) := object scores.", + ) + elif metric_type == MetricType.PREDICTION: + fixed_options = {"confidence": "Model Confidence"} + column_names = list(metric_datas.predictions.keys()) + metric_datas.selected_prediction = st.selectbox( + "Select metric for your predictions", + column_names + list(fixed_options.keys()), + format_func=lambda s: fixed_options.get(s, s), + help="The data in the main view will be sorted by the selected metric. 
" + "(F) := frame scores, (P) := prediction scores.", + ) + + def _read_prediction_files( + self, prediction_type: MainPredictionType, project_path: str + ) -> Tuple[ + List[MetricData], + List[MetricData], + Union[DataFrame[PredictionSchema], DataFrame[ClassificationPredictionSchema], None], + Union[DataFrame[LabelSchema], DataFrame[ClassificationLabelSchema], None], + ]: + metrics_dir = get_state().project_paths.metrics + predictions_dir = get_state().project_paths.predictions / prediction_type.value + + predictions_metric_datas, _ = use_memo( + lambda: reader.get_prediction_metric_data(predictions_dir, metrics_dir), + key=f"predictions_metrics_data_{project_path}_{prediction_type.value}", + ) + + label_metric_datas, _ = use_memo( + lambda: reader.get_label_metric_data(metrics_dir), + key=f"label_metric_datas_{project_path}_{prediction_type.value}", + ) + model_predictions, _ = use_memo( + lambda: reader.get_model_predictions(predictions_dir, predictions_metric_datas, prediction_type), + key=f"model_predictions_{project_path}_{prediction_type.value}", + ) + labels, _ = use_memo( + lambda: reader.get_labels(predictions_dir, label_metric_datas, prediction_type), + key=f"labels_{project_path}_{prediction_type.value}", + ) + + return predictions_metric_datas, label_metric_datas, model_predictions, labels + + def _render_performance_by_metric_description( + self, + model_predictions: Union[ + DataFrame[ClassificationPredictionMatchSchemaWithClassNames], DataFrame[PredictionMatchSchema] + ], + metric_datas: MetricNames, + ): + if model_predictions.shape[0] == 0: + st.write("No predictions of the given class(es).") + return + + metric_name = metric_datas.selected_prediction + if not metric_name: + # This shouldn't happen with the current flow. The only way a user can do this + # is if he/she write custom code to bypass running the metrics. In this case, + # I think that it is fair to not give more information than this. + st.write( + "No metrics computed for the your model predictions. " + "With `encord-active import predictions /path/to/predictions.pkl`, " + "Encord Active will automatically run compute the metrics." + ) + return + + self._description_expander(metric_datas) + + def _get_metric_importance( + self, + model_predictions: Union[ + DataFrame[ClassificationPredictionMatchSchemaWithClassNames], DataFrame[PredictionMatchSchema] + ], + metric_columns: List[str], + ): + with st.container(): + with st.expander("Description"): + st.write( + "The following charts show the dependency between model performance and each index. " + "In other words, these charts answer the question of how much is model " + "performance affected by each index. This relationship can be decomposed into two metrics:" + ) + st.markdown( + "- **Metric importance**: measures the *strength* of the dependency between and metric and model " + "performance. A high value means that the model performance would be strongly affected by " + "a change in the index. For example, a high importance in 'Brightness' implies that a change " + "in that quantity would strongly affect model performance. Values range from 0 (no dependency) " + "to 1 (perfect dependency, one can completely predict model performance simply by looking " + "at this index)." + ) + st.markdown( + "- **Metric [correlation](https://en.wikipedia.org/wiki/Correlation)**: measures the *linearity " + "and direction* of the dependency between an index and model performance. 
" + "Crucially, this metric tells us whether a positive change in an index " + "will lead to a positive change (positive correlation) or a negative change (negative correlation) " + "in model performance . Values range from -1 to 1." + ) + st.write( + "Finally, you can also select how many samples are included in the computation " + "with the slider, as well as filter by class with the dropdown in the side bar." + ) + + if model_predictions.shape[0] > 60_000: # Computation are heavy so allow computing for only a subset. + num_samples = self._set_sampling_for_metric_importance(model_predictions) + else: + num_samples = model_predictions.shape[0] + + with st.spinner("Computing index importance..."): + try: + metric_importance_chart = create_metric_importance_charts( + model_predictions, + metric_columns=metric_columns, + num_samples=num_samples, + prediction_type=MainPredictionType.CLASSIFICATION, + ) + st.altair_chart(metric_importance_chart, use_container_width=True) + except ValueError as e: + if e.args: + st.info(e.args[0]) + else: + st.info("Failed to compute metric importance") + + def build(self, page_mode: ModelQualityPage): + if self.load_data(page_mode): + if page_mode == ModelQualityPage.METRICS: + self.render_metrics() + elif page_mode == ModelQualityPage.PERFORMANCE_BY_METRIC: + self.render_performance_by_metric() + elif page_mode == ModelQualityPage.EXPLORER: + self.render_explorer() + + def render_metrics(self): + st.markdown("### This page is not implemented") + + def render_performance_by_metric(self): + st.markdown("### This page is not implemented") + + def render_explorer(self): + st.markdown("### This page is not implemented") + + @abstractmethod + def load_data(self, page_mode: ModelQualityPage) -> bool: + pass + + @abstractmethod + def is_available(self) -> bool: + pass diff --git a/src/encord_active/app/model_quality/prediction_types/__init__.py b/src/encord_active/app/model_quality/prediction_types/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/encord_active/app/model_quality/prediction_types/classification_type_builder.py b/src/encord_active/app/model_quality/prediction_types/classification_type_builder.py new file mode 100644 index 000000000..1549f5e22 --- /dev/null +++ b/src/encord_active/app/model_quality/prediction_types/classification_type_builder.py @@ -0,0 +1,264 @@ +from copy import deepcopy +from enum import Enum +from typing import List, Optional + +import altair as alt +import streamlit as st +from loguru import logger +from pandera.typing import DataFrame + +import encord_active.lib.model_predictions.reader as reader +from encord_active.app.common.components import sticky_header +from encord_active.app.common.components.prediction_grid import ( + prediction_grid_classifications, +) +from encord_active.app.common.state import MetricNames, get_state +from encord_active.app.model_quality.prediction_type_builder import ( + MetricType, + ModelQualityPage, + PredictionTypeBuilder, +) +from encord_active.lib.charts.classification_metrics import ( + get_accuracy, + get_confusion_matrix, + get_precision_recall_f1, + get_precision_recall_graph, +) +from encord_active.lib.charts.histogram import get_histogram +from encord_active.lib.charts.performance_by_metric import performance_rate_by_metric +from encord_active.lib.charts.scopes import PredictionMatchScope +from encord_active.lib.metrics.utils import MetricScope +from encord_active.lib.model_predictions.classification_metrics import ( + match_predictions_and_labels, +) +from 
encord_active.lib.model_predictions.filters import ( + prediction_and_label_filtering_classification, +) +from encord_active.lib.model_predictions.reader import ( + ClassificationLabelSchema, + ClassificationPredictionMatchSchemaWithClassNames, + ClassificationPredictionSchema, + get_class_idx, +) +from encord_active.lib.model_predictions.writer import MainPredictionType + + +class ClassificationTypeBuilder(PredictionTypeBuilder): + title = "Classification" + + class OutcomeType(str, Enum): + CORRECT_CLASSIFICATIONS = "Correct Classifications" + MISCLASSIFICATIONS = "Misclassifications" + + def __init__(self): + self._explorer_outcome_type = self.OutcomeType.CORRECT_CLASSIFICATIONS + self._labels: Optional[List] = None + self._predictions: Optional[List] = None + self._model_predictions: Optional[DataFrame[ClassificationPredictionMatchSchemaWithClassNames]] = None + + def load_data(self, page_mode: ModelQualityPage) -> bool: + predictions_metric_datas, label_metric_datas, model_predictions, labels = self._read_prediction_files( + MainPredictionType.CLASSIFICATION, project_path=get_state().project_paths.project_dir.as_posix() + ) + + if model_predictions is None: + st.error("Couldn't load model predictions") + return False + + if labels is None: + st.error("Couldn't load labels properly") + return False + + get_state().predictions.metric_datas_classification = MetricNames( + predictions={m.name: m for m in predictions_metric_datas}, + ) + + with sticky_header(): + self._common_settings() + self._topbar_additional_settings(page_mode) + + model_predictions_matched = match_predictions_and_labels(model_predictions, labels) + + ( + labels_filtered, + predictions_filtered, + model_predictions_matched_filtered, + ) = prediction_and_label_filtering_classification( + get_state().predictions.selected_classes_classifications, + get_state().predictions.all_classes_classifications, + labels, + model_predictions, + model_predictions_matched, + ) + + img_id_intersection = list( + set(labels_filtered[ClassificationLabelSchema.img_id]).intersection( + set(predictions_filtered[ClassificationPredictionSchema.img_id]) + ) + ) + labels_filtered_intersection = labels_filtered[ + labels_filtered[ClassificationLabelSchema.img_id].isin(img_id_intersection) + ] + predictions_filtered_intersection = predictions_filtered[ + predictions_filtered[ClassificationPredictionSchema.img_id].isin(img_id_intersection) + ] + + self._labels, self._predictions = ( + list(labels_filtered_intersection[ClassificationLabelSchema.class_id]), + list(predictions_filtered_intersection[ClassificationPredictionSchema.class_id]), + ) + + self._model_predictions = model_predictions_matched_filtered.copy()[ + model_predictions_matched_filtered[ClassificationPredictionMatchSchemaWithClassNames.img_id].isin( + img_id_intersection + ) + ] + + return True + + def _common_settings(self): + if not get_state().predictions.all_classes_classifications: + get_state().predictions.all_classes_classifications = get_class_idx( + get_state().project_paths.predictions / MainPredictionType.CLASSIFICATION.value + ) + + all_classes = get_state().predictions.all_classes_classifications + selected_classes = self._render_class_filtering_component(all_classes) + + get_state().predictions.selected_classes_classifications = dict(selected_classes) or deepcopy(all_classes) + + def _topbar_additional_settings(self, page_mode: ModelQualityPage): + if page_mode == ModelQualityPage.METRICS: + return + elif page_mode == ModelQualityPage.PERFORMANCE_BY_METRIC: + c1, c2, c3 = 
st.columns([4, 4, 3]) + with c1: + self._topbar_metric_selection_component( + MetricType.PREDICTION, get_state().predictions.metric_datas_classification + ) + with c2: + self._set_binning() + with c3: + self._class_decomposition() + elif page_mode == ModelQualityPage.EXPLORER: + c1, c2 = st.columns([4, 4]) + + with c1: + self._explorer_outcome_type = st.selectbox( + "Outcome", + [x for x in self.OutcomeType], + format_func=lambda x: x.value, + help="Only the samples with this outcome will be shown", + ) + with c2: + self._topbar_metric_selection_component( + MetricType.PREDICTION, get_state().predictions.metric_datas_classification + ) + + self.display_settings(MetricScope.MODEL_QUALITY) + + def is_available(self) -> bool: + return reader.check_model_prediction_availability( + get_state().project_paths.predictions / MainPredictionType.CLASSIFICATION.value + ) + + def render_metrics(self): + class_names = sorted(list(set(self._labels).union(self._predictions))) + + precision, recall, f1, support = get_precision_recall_f1(self._labels, self._predictions) + accuracy = get_accuracy(self._labels, self._predictions) + + # PERFORMANCE METRICS SUMMARY + + col_acc, col_prec, col_rec, col_f1 = st.columns(4) + col_acc.metric("Accuracy", f"{float(accuracy):.2f}") + col_prec.metric( + "Mean Precision", + f"{float(precision.mean()):.2f}", + help="Average of precision scores of all classes", + ) + col_rec.metric("Mean Recall", f"{float(recall.mean()):.2f}", help="Average of recall scores of all classes") + col_f1.metric("Mean F1", f"{float(f1.mean()):.2f}", help="Average of F1 scores of all classes") + + # METRIC IMPORTANCE + self._get_metric_importance( + self._model_predictions, list(get_state().predictions.metric_datas_classification.predictions.keys()) + ) + + col1, col2 = st.columns(2) + + # CONFUSION MATRIX + confusion_matrix = get_confusion_matrix(self._labels, self._predictions, class_names) + col1.plotly_chart(confusion_matrix, use_container_width=True) + + # PRECISION_RECALL BARS + pr_graph = get_precision_recall_graph(precision, recall, class_names) + col2.plotly_chart(pr_graph, use_container_width=True) + + def render_performance_by_metric(self): + self._render_performance_by_metric_description( + self._model_predictions, get_state().predictions.metric_datas_classification + ) + metric_name = get_state().predictions.metric_datas_classification.selected_prediction + + classes_for_coloring = ["Average"] + decompose_classes = get_state().predictions.decompose_classes + if decompose_classes: + unique_classes = set(self._model_predictions["class_name"].unique()) + classes_for_coloring += sorted(list(unique_classes)) + + # Ensure same colors between plots + chart_args = dict( + color_params={"scale": alt.Scale(domain=classes_for_coloring)}, + bins=get_state().predictions.nbins, + show_decomposition=decompose_classes, + ) + + try: + tpr = performance_rate_by_metric( + self._model_predictions, + metric_name, + scope=PredictionMatchScope.TRUE_POSITIVES, + **chart_args, + ) + if tpr is not None: + st.altair_chart(tpr.interactive(), use_container_width=True) + except Exception as e: + logger.warning(e) + pass + + def render_explorer(self): + with st.expander("Details"): + if self._explorer_outcome_type == self.OutcomeType.CORRECT_CLASSIFICATIONS: + view_text = "These are the predictions where the model correctly predicts the true class." + else: + view_text = "These are the predictions where the model incorrectly predicts the positive class." 
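+            # The two outcome texts above mirror the filtering applied further down in
+            # this method: correct classifications are rows with is_true_positive == 1.0
+            # in the matched predictions frame, misclassifications rows with 0.0.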
+ st.markdown( + f"""### The view +{view_text} + """, + unsafe_allow_html=True, + ) + + self._metric_details_description(get_state().predictions.metric_datas_classification) + + metric_name = get_state().predictions.metric_datas_classification.selected_prediction + if not metric_name: + st.error("No prediction metric selected") + return + + if self._explorer_outcome_type == self.OutcomeType.CORRECT_CLASSIFICATIONS: + view_df = self._model_predictions[ + self._model_predictions[ClassificationPredictionMatchSchemaWithClassNames.is_true_positive] == 1.0 + ].dropna(subset=[metric_name]) + else: + view_df = self._model_predictions[ + self._model_predictions[ClassificationPredictionMatchSchemaWithClassNames.is_true_positive] == 0.0 + ].dropna(subset=[metric_name]) + + if view_df.shape[0] == 0: + st.write(f"No {self._explorer_outcome_type}") + else: + histogram = get_histogram(view_df, metric_name) + st.altair_chart(histogram, use_container_width=True) + prediction_grid_classifications(get_state().project_paths, model_predictions=view_df) diff --git a/src/encord_active/app/model_quality/prediction_types/object_type_builder.py b/src/encord_active/app/model_quality/prediction_types/object_type_builder.py new file mode 100644 index 000000000..a6476607b --- /dev/null +++ b/src/encord_active/app/model_quality/prediction_types/object_type_builder.py @@ -0,0 +1,382 @@ +import json +import re +from copy import deepcopy +from enum import Enum +from typing import Optional, Union + +import altair as alt +import streamlit as st +from loguru import logger +from pandera.typing import DataFrame + +import encord_active.lib.model_predictions.reader as reader +from encord_active.app.common.components import sticky_header +from encord_active.app.common.components.prediction_grid import prediction_grid +from encord_active.app.common.state import MetricNames, State, get_state +from encord_active.app.common.state_hooks import use_memo +from encord_active.app.model_quality.prediction_type_builder import ( + MetricType, + ModelQualityPage, + PredictionTypeBuilder, +) +from encord_active.lib.charts.histogram import get_histogram +from encord_active.lib.charts.performance_by_metric import performance_rate_by_metric +from encord_active.lib.charts.precision_recall import create_pr_chart_plotly +from encord_active.lib.charts.scopes import PredictionMatchScope +from encord_active.lib.common.colors import Color +from encord_active.lib.metrics.utils import MetricScope +from encord_active.lib.model_predictions.filters import ( + filter_labels_for_frames_wo_predictions, + prediction_and_label_filtering, +) +from encord_active.lib.model_predictions.map_mar import ( + PerformanceMetricSchema, + PrecisionRecallSchema, + compute_mAP_and_mAR, +) +from encord_active.lib.model_predictions.reader import ( + LabelMatchSchema, + PredictionMatchSchema, + get_class_idx, +) +from encord_active.lib.model_predictions.writer import MainPredictionType + + +class ObjectTypeBuilder(PredictionTypeBuilder): + title = "Object" + + class OutcomeType(str, Enum): + TRUE_POSITIVES = "True Positive" + FALSE_POSITIVES = "False Positive" + FALSE_NEGATIVES = "False Negative" + + def __init__(self): + self._explorer_outcome_type = self.OutcomeType.TRUE_POSITIVES + self._model_predictions: Optional[DataFrame[PredictionMatchSchema]] = None + self._labels: Optional[DataFrame[LabelMatchSchema]] = None + self._metrics: Optional[DataFrame[PerformanceMetricSchema]] = None + self._precisions: Optional[DataFrame[PrecisionRecallSchema]] = None + + def load_data(self, 
page_mode: ModelQualityPage) -> bool: + predictions_dir = get_state().project_paths.predictions / MainPredictionType.OBJECT.value + predictions_metric_datas, label_metric_datas, model_predictions, labels = self._read_prediction_files( + MainPredictionType.OBJECT, project_path=get_state().project_paths.project_dir.as_posix() + ) + + if model_predictions is None: + st.error("Couldn't load model predictions") + return False + + if labels is None: + st.error("Couldn't load labels properly") + return False + + get_state().predictions.metric_datas = MetricNames( + predictions={m.name: m for m in predictions_metric_datas}, + labels={m.name: m for m in label_metric_datas}, + ) + + with sticky_header(): + self._common_settings() + self._topbar_additional_settings(page_mode) + + matched_gt, _ = use_memo( + lambda: reader.get_gt_matched(predictions_dir), + key=f"matched_gt_{get_state().project_paths.project_dir.as_posix()}", + ) + if not matched_gt: + st.error("Couldn't match ground truths") + return False + + (predictions_filtered, labels_filtered, metrics, precisions,) = compute_mAP_and_mAR( + model_predictions, + labels, + matched_gt, + get_state().predictions.all_classes_objects, + iou_threshold=get_state().iou_threshold, + ignore_unmatched_frames=get_state().ignore_frames_without_predictions, + ) + + # Sort predictions and labels according to selected metrics. + pred_sort_column = get_state().predictions.metric_datas.selected_prediction or predictions_metric_datas[0].name + sorted_model_predictions = predictions_filtered.sort_values([pred_sort_column], axis=0) + + label_sort_column = get_state().predictions.metric_datas.selected_label or label_metric_datas[0].name + sorted_labels = labels_filtered.sort_values([label_sort_column], axis=0) + + if get_state().ignore_frames_without_predictions: + labels_filtered = filter_labels_for_frames_wo_predictions(predictions_filtered, sorted_labels) + else: + labels_filtered = sorted_labels + + self._labels, self._metrics, self._model_predictions, self._precisions = prediction_and_label_filtering( + get_state().predictions.selected_classes_objects, + labels_filtered, + metrics, + sorted_model_predictions, + precisions, + ) + + return True + + def _common_settings(self): + if not get_state().predictions.all_classes_objects: + get_state().predictions.all_classes_objects = get_class_idx( + get_state().project_paths.predictions / MainPredictionType.OBJECT.value + ) + + all_classes = get_state().predictions.all_classes_objects + col1, col2, col3 = st.columns([4, 4, 3]) + + with col1: + selected_classes = self._render_class_filtering_component(all_classes) + + get_state().predictions.selected_classes_objects = dict(selected_classes) or deepcopy(all_classes) + + with col2: + # IOU + get_state().iou_threshold = st.slider( + "Select an IOU threshold", + min_value=0.0, + max_value=1.0, + value=State.iou_threshold, + help="The mean average precision (mAP) score is based on true positives and false positives. " + "The IOU threshold determines how closely predictions need to match labels to be considered " + "as true positives.", + ) + + with col3: + st.write("") + st.write("") + # Ignore unmatched frames + get_state().ignore_frames_without_predictions = st.checkbox( + "Ignore frames without predictions", + value=State.ignore_frames_without_predictions, + help="Scores like mAP and mAR are effected negatively if there are frames in the dataset for /" + "which there exist no predictions. 
With this flag, you can ignore those.", + ) + + def _topbar_additional_settings(self, page_mode: ModelQualityPage): + if page_mode == ModelQualityPage.METRICS: + return + elif page_mode == ModelQualityPage.PERFORMANCE_BY_METRIC: + c1, c2, c3 = st.columns([4, 4, 3]) + with c1: + self._topbar_metric_selection_component(MetricType.PREDICTION, get_state().predictions.metric_datas) + with c2: + self._set_binning() + with c3: + self._class_decomposition() + elif page_mode == ModelQualityPage.EXPLORER: + c1, c2 = st.columns([4, 4]) + with c1: + self._explorer_outcome_type = st.selectbox( + "Outcome", + [x for x in self.OutcomeType], + format_func=lambda x: x.value, + help="Only the samples with this outcome will be shown", + ) + with c2: + explorer_metric_type = ( + MetricType.PREDICTION + if self._explorer_outcome_type + in [self.OutcomeType.TRUE_POSITIVES, self.OutcomeType.FALSE_POSITIVES] + else MetricType.LABEL + ) + + self._topbar_metric_selection_component(explorer_metric_type, get_state().predictions.metric_datas) + + self.display_settings(MetricScope.MODEL_QUALITY) + + def _get_metric_name(self) -> Optional[str]: + if self._explorer_outcome_type in [self.OutcomeType.TRUE_POSITIVES, self.OutcomeType.FALSE_POSITIVES]: + return get_state().predictions.metric_datas.selected_prediction + else: + return get_state().predictions.metric_datas.selected_label + + def _render_explorer_details(self) -> Optional[Color]: + color: Optional[Color] = None + with st.expander("Details"): + if self._explorer_outcome_type == self.OutcomeType.TRUE_POSITIVES: + color = Color.PURPLE + + st.markdown( + f"""### The view +These are the predictions for which the IOU was sufficiently high and the confidence score was +the highest amongst predictions that overlap with the label. + +--- + +**Color**: +The {color.name.lower()} boxes marks the true positive predictions. +The remaining colors correspond to the dataset labels with the colors you are used to from the label editor. + """, + unsafe_allow_html=True, + ) + + elif self._explorer_outcome_type == self.OutcomeType.FALSE_POSITIVES: + color = Color.RED + + st.markdown( + f"""### The view +These are the predictions for which either of the following is true +1. The IOU between the prediction and the best matching label was too low +2. There was another prediction with higher model confidence which matched the label already +3. The predicted class didn't match + +--- + +**Color**: +The {color.name.lower()} boxes marks the false positive predictions. +The remaining colors correspond to the dataset labels with the colors you are used to from the label editor. + """, + unsafe_allow_html=True, + ) + + elif self._explorer_outcome_type == self.OutcomeType.FALSE_NEGATIVES: + color = Color.PURPLE + + st.markdown( + f"""### The view +These are the labels that were not matched with any predictions. + +--- +**Color**: +The {color.name.lower()} boxes mark the false negatives. That is, the labels that were not +matched to any predictions. The remaining objects are predictions, where colors correspond to their predicted class +(identical colors to labels objects in the editor). 
+ """, + unsafe_allow_html=True, + ) + self._metric_details_description(get_state().predictions.metric_datas) + return color + + def _get_target_df( + self, metric_name: str + ) -> Union[DataFrame[PredictionMatchSchema], DataFrame[LabelMatchSchema], None]: + + if self._model_predictions is not None and self._labels is not None: + if self._explorer_outcome_type == self.OutcomeType.TRUE_POSITIVES: + return self._model_predictions[ + self._model_predictions[PredictionMatchSchema.is_true_positive] == 1.0 + ].dropna(subset=[metric_name]) + elif self._explorer_outcome_type == self.OutcomeType.FALSE_POSITIVES: + return self._model_predictions[ + self._model_predictions[PredictionMatchSchema.is_true_positive] == 0.0 + ].dropna(subset=[metric_name]) + elif self._explorer_outcome_type == self.OutcomeType.FALSE_NEGATIVES: + return self._labels[self._labels[LabelMatchSchema.is_false_negative]].dropna(subset=[metric_name]) + else: + return None + else: + return None + + def is_available(self) -> bool: + return reader.check_model_prediction_availability( + get_state().project_paths.predictions / MainPredictionType.OBJECT.value + ) + + def render_metrics(self): + _map = self._metrics[self._metrics[PerformanceMetricSchema.metric] == "mAP"]["value"].item() + _mar = self._metrics[self._metrics[PerformanceMetricSchema.metric] == "mAR"]["value"].item() + col1, col2 = st.columns(2) + col1.metric("mAP", f"{_map:.3f}") + col2.metric("mAR", f"{_mar:.3f}") + + # METRIC IMPORTANCE + self._get_metric_importance( + self._model_predictions, list(get_state().predictions.metric_datas.predictions.keys()) + ) + + st.subheader("Subset selection scores") + with st.container(): + project_ontology = json.loads(get_state().project_paths.ontology.read_text(encoding="utf-8")) + chart = create_pr_chart_plotly(self._metrics, self._precisions, project_ontology["objects"]) + st.plotly_chart(chart, use_container_width=True) + + def render_performance_by_metric(self): + self._render_performance_by_metric_description(self._model_predictions, get_state().predictions.metric_datas) + metric_name = get_state().predictions.metric_datas.selected_prediction + + label_metric_name = metric_name + if metric_name[-3:] == "(P)": # Replace the P with O: "Metric (P)" -> "Metric (O)" + label_metric_name = re.sub(r"(.*?\()P(\))", r"\1O\2", metric_name) + + if label_metric_name not in self._labels.columns: + label_metric_name = re.sub( + r"(.*?\()O(\))", r"\1F\2", label_metric_name + ) # Look for it in frame label metrics. 
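+        # For illustration: a prediction metric column named like "<metric> (P)" is first
+        # mapped to the object-level label column "<metric> (O)"; if that column is not
+        # present in self._labels, the frame-level "<metric> (F)" column is used instead.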
+ + classes_for_coloring = ["Average"] + decompose_classes = get_state().predictions.decompose_classes + if decompose_classes: + unique_classes = set(self._model_predictions["class_name"].unique()).union( + self._labels["class_name"].unique() + ) + classes_for_coloring += sorted(list(unique_classes)) + + # Ensure same colors between plots + chart_args = dict( + color_params={"scale": alt.Scale(domain=classes_for_coloring)}, + bins=get_state().predictions.nbins, + show_decomposition=decompose_classes, + ) + + try: + if metric_name in self._model_predictions.columns: + tpr = performance_rate_by_metric( + self._model_predictions, metric_name, scope=PredictionMatchScope.TRUE_POSITIVES, **chart_args + ) + if tpr is not None: + st.altair_chart(tpr.interactive(), use_container_width=True) + else: + st.info(f"True Positive Rate is not available for `{metric_name}` metric") + except Exception as e: + logger.warning(e) + pass + + try: + if label_metric_name in self._labels.columns: + fnr = performance_rate_by_metric( + self._labels, label_metric_name, scope=PredictionMatchScope.FALSE_NEGATIVES, **chart_args + ) + if fnr is not None: + st.altair_chart(fnr.interactive(), use_container_width=True) + else: + st.info(f"False Negative Rate is not available for `{label_metric_name}` metric") + except Exception as e: + logger.warning(e) + pass + + def render_explorer(self): + metric_name = self._get_metric_name() + if not metric_name: + st.error("No metric selected") + return + + color = self._render_explorer_details() + if color is None: + st.warning("An error occurred while rendering the explorer page") + + view_df = self._get_target_df(metric_name) + if view_df is None: + st.error(f"An error occurred during getting data according to the {metric_name} metric") + return + + if view_df.shape[0] == 0: + st.write(f"No {self._explorer_outcome_type}") + else: + histogram = get_histogram(view_df, metric_name) + st.altair_chart(histogram, use_container_width=True) + if self._explorer_outcome_type in [self.OutcomeType.TRUE_POSITIVES, self.OutcomeType.FALSE_POSITIVES]: + prediction_grid(get_state().project_paths, model_predictions=view_df, box_color=color) + else: + prediction_grid( + get_state().project_paths, + model_predictions=self._model_predictions, + labels=view_df, + box_color=color, + ) diff --git a/src/encord_active/app/model_quality/sub_pages/__init__.py b/src/encord_active/app/model_quality/sub_pages/__init__.py deleted file mode 100644 index f6870db1c..000000000 --- a/src/encord_active/app/model_quality/sub_pages/__init__.py +++ /dev/null @@ -1,181 +0,0 @@ -from abc import abstractmethod -from typing import Optional - -import streamlit as st -from loguru import logger -from pandera.typing import DataFrame -from streamlit.delta_generator import DeltaGenerator - -import encord_active.app.common.state as state -from encord_active.app.common.page import Page -from encord_active.app.common.state import MetricNames -from encord_active.lib.constants import DOCS_URL -from encord_active.lib.model_predictions.map_mar import ( - PerformanceMetricSchema, - PrecisionRecallSchema, -) -from encord_active.lib.model_predictions.reader import ( - ClassificationPredictionMatchSchema, - LabelMatchSchema, - PredictionMatchSchema, -) - - -class ModelQualityPage(Page): - @property - @abstractmethod - def title(self) -> str: - pass - - @abstractmethod - def sidebar_options(self): - """ - Used to append options to the sidebar. 
- """ - pass - - def sidebar_options_classifications(self): - pass - - def check_building_object_quality( - self, - object_predictions_exist: bool, - object_model_predictions: Optional[DataFrame[PredictionMatchSchema]] = None, - object_labels: Optional[DataFrame[LabelMatchSchema]] = None, - object_metrics: Optional[DataFrame[PerformanceMetricSchema]] = None, - object_precisions: Optional[DataFrame[PrecisionRecallSchema]] = None, - ) -> bool: - if not object_predictions_exist: - st.markdown( - "## Missing model predictions for the objects\n" - "This project does not have any imported predictions for the objects. " - "Please refer to the " - f"[Importing Model Predictions]({DOCS_URL}/sdk/importing-model-predictions) " - "section of the documentation to learn how to import your predictions." - ) - return False - elif not ( - (object_model_predictions is not None) - and (object_labels is not None) - and (object_metrics is not None) - and (object_precisions is not None) - ): - logger.error( - "If object_prediction_exist is True, the followings should be provided: object_model_predictions, \ - object_labels, object_metrics, object_precisions" - ) - return False - - return True - - def check_building_classification_quality( - self, - classification_predictions_exist: bool, - classification_labels: Optional[list] = None, - classification_pred: Optional[list] = None, - classification_model_predictions_matched: Optional[DataFrame[ClassificationPredictionMatchSchema]] = None, - ) -> bool: - if not classification_predictions_exist: - st.markdown( - "## Missing model predictions for the classifications\n" - "This project does not have any imported predictions for the classifications. " - "Please refer to the " - f"[Importing Model Predictions]({DOCS_URL}/sdk/importing-model-predictions) " - "section of the documentation to learn how to import your predictions." 
- ) - return False - elif not ( - (classification_labels is not None) - and (classification_pred is not None) - and (classification_model_predictions_matched is not None), - ): - logger.error( - "If classification_predictions_exist is True, the followings should be provided: classification_labels, \ - classification_pred, classification_model_predictions_matched_filtered" - ) - return False - - return True - - @abstractmethod - def build( - self, - object_predictions_exist: bool, - classification_predictions_exist: bool, - object_tab: DeltaGenerator, - classification_tab: DeltaGenerator, - object_model_predictions: Optional[DataFrame[PredictionMatchSchema]] = None, - object_labels: Optional[DataFrame[LabelMatchSchema]] = None, - object_metrics: Optional[DataFrame[PerformanceMetricSchema]] = None, - object_precisions: Optional[DataFrame[PrecisionRecallSchema]] = None, - classification_labels: Optional[list] = None, - classification_pred: Optional[list] = None, - classification_model_predictions_matched_filtered: Optional[ - DataFrame[ClassificationPredictionMatchSchema] - ] = None, - ): - """ - If object_predictions_exist is True, the followings should be provided: object_model_predictions, \ - object_labels, object_metrics, object_precisions - If classification_predictions_exist is True, the followings should be provided: classification_labels, \ - classification_pred, classification_model_predictions_matched_filtered - """ - - pass - - def __call__( - self, - model_predictions: DataFrame[PredictionMatchSchema], - labels: DataFrame[LabelMatchSchema], - metrics: DataFrame[PerformanceMetricSchema], - precisions: DataFrame[PrecisionRecallSchema], - ): - return self.build(model_predictions, labels, metrics, precisions) - - def __repr__(self): - return f"{type(self).__name__}()" - - @staticmethod - def prediction_metric_in_sidebar_objects(): - """ - Note: Adding the fixed options "confidence" and "iou" works here because - confidence is required on import and IOU is computed during prediction - import. So the two columns are already available in the - `st.session_state.model_predictions` data frame. - """ - fixed_options = {"confidence": "Model Confidence", "iou": "IOU"} - column_names = list(state.get_state().predictions.metric_datas.predictions.keys()) - state.get_state().predictions.metric_datas.selected_prediction = st.selectbox( - "Select metric for your predictions", - column_names + list(fixed_options.keys()), - format_func=lambda s: fixed_options.get(s, s), - help="The data in the main view will be sorted by the selected metric. " - "(F) := frame scores, (P) := prediction scores.", - ) - - @staticmethod - def prediction_metric_in_sidebar_classifications(): - fixed_options = {"confidence": "Model Confidence"} - column_names = list(state.get_state().predictions.metric_datas_classification.predictions.keys()) - state.get_state().predictions.metric_datas_classification.selected_prediction = st.selectbox( - "Select metric for your predictions", - column_names + list(fixed_options.keys()), - format_func=lambda s: fixed_options.get(s, s), - help="The data in the main view will be sorted by the selected metric. 
" - "(F) := frame scores, (P) := prediction scores.", - ) - - @staticmethod - def metric_details_description(metric_datas: MetricNames): - metric_name = metric_datas.selected_prediction - if not metric_name: - return - - metric_data = metric_datas.predictions.get(metric_name) - - if not metric_data: - metric_data = metric_datas.labels.get(metric_name) - - if metric_data: - st.markdown(f"### The {metric_data.name[:-4]} metric") - st.markdown(metric_data.meta.long_description) diff --git a/src/encord_active/app/model_quality/sub_pages/false_negatives.py b/src/encord_active/app/model_quality/sub_pages/false_negatives.py deleted file mode 100644 index f95d2ef62..000000000 --- a/src/encord_active/app/model_quality/sub_pages/false_negatives.py +++ /dev/null @@ -1,109 +0,0 @@ -from typing import Optional, cast - -import streamlit as st -from pandera.typing import DataFrame -from streamlit.delta_generator import DeltaGenerator - -from encord_active.app.common.components.prediction_grid import prediction_grid -from encord_active.app.common.state import get_state -from encord_active.lib.charts.histogram import get_histogram -from encord_active.lib.common.colors import Color -from encord_active.lib.metrics.utils import MetricScope -from encord_active.lib.model_predictions.map_mar import ( - PerformanceMetricSchema, - PrecisionRecallSchema, -) -from encord_active.lib.model_predictions.reader import ( - ClassificationPredictionMatchSchemaWithClassNames, - LabelMatchSchema, - PredictionMatchSchema, -) - -from . import ModelQualityPage - - -class FalseNegativesPage(ModelQualityPage): - title = "🔍 False Negatives" - - def sidebar_options(self): - metric_columns = list(get_state().predictions.metric_datas.labels.keys()) - get_state().predictions.metric_datas.selected_label = st.selectbox( - "Select metric for your labels", - metric_columns, - help="The data in the main view will be sorted by the selected metric. " - "(F) := frame scores, (O) := object scores.", - ) - self.display_settings(MetricScope.MODEL_QUALITY) - - def sidebar_options_classifications(self): - pass - - def _build_objects( - self, - object_model_predictions: DataFrame[PredictionMatchSchema], - object_labels: DataFrame[LabelMatchSchema], - ): - metric_name = get_state().predictions.metric_datas.selected_label - if not metric_name: - st.error("Prediction label not selected") - return - - with st.expander("Details"): - color = Color.PURPLE - st.markdown( - f"""### The view -These are the labels that were not matched with any predictions. - ---- -**Color**: -The {color.name.lower()} boxes mark the false negatives. -That is, the labels that were not matched to any predictions. -The remaining objects are predictions, where colors correspond to their predicted class (identical colors to labels objects in the editor). 
- """, - unsafe_allow_html=True, - ) - self.metric_details_description(get_state().predictions.metric_datas) - fns_df = object_labels[object_labels[LabelMatchSchema.is_false_negative]].dropna(subset=[metric_name]) - if fns_df.shape[0] == 0: - st.write("No false negatives") - else: - histogram = get_histogram(fns_df, metric_name) - st.altair_chart(histogram, use_container_width=True) - prediction_grid( - get_state().project_paths, - labels=fns_df, - model_predictions=object_model_predictions, - box_color=color, - ) - - def build( - self, - object_predictions_exist: bool, - classification_predictions_exist: bool, - object_tab: DeltaGenerator, - classification_tab: DeltaGenerator, - object_model_predictions: Optional[DataFrame[PredictionMatchSchema]] = None, - object_labels: Optional[DataFrame[LabelMatchSchema]] = None, - object_metrics: Optional[DataFrame[PerformanceMetricSchema]] = None, - object_precisions: Optional[DataFrame[PrecisionRecallSchema]] = None, - classification_labels: Optional[list] = None, - classification_pred: Optional[list] = None, - classification_model_predictions_matched: Optional[ - DataFrame[ClassificationPredictionMatchSchemaWithClassNames] - ] = None, - ): - - with object_tab: - if self.check_building_object_quality( - object_predictions_exist, object_model_predictions, object_labels, object_metrics, object_precisions - ): - self._build_objects( - cast(DataFrame[PredictionMatchSchema], object_model_predictions), - cast(DataFrame[LabelMatchSchema], object_labels), - ) - - with classification_tab: - st.markdown( - "## False Negatives view for the classification predictions is not available\n" - "Please use **Filter by class** field in True Positives page to inspect different classes." - ) diff --git a/src/encord_active/app/model_quality/sub_pages/false_positives.py b/src/encord_active/app/model_quality/sub_pages/false_positives.py deleted file mode 100644 index 7153f5711..000000000 --- a/src/encord_active/app/model_quality/sub_pages/false_positives.py +++ /dev/null @@ -1,142 +0,0 @@ -from typing import Optional, cast - -import streamlit as st -from pandera.typing import DataFrame -from streamlit.delta_generator import DeltaGenerator - -from encord_active.app.common.components.prediction_grid import ( - prediction_grid, - prediction_grid_classifications, -) -from encord_active.app.common.state import get_state -from encord_active.lib.charts.histogram import get_histogram -from encord_active.lib.common.colors import Color -from encord_active.lib.metrics.utils import MetricScope -from encord_active.lib.model_predictions.map_mar import ( - PerformanceMetricSchema, - PrecisionRecallSchema, -) -from encord_active.lib.model_predictions.reader import ( - ClassificationPredictionMatchSchemaWithClassNames, - LabelMatchSchema, - PredictionMatchSchema, -) - -from . 
import ModelQualityPage - - -class FalsePositivesPage(ModelQualityPage): - title = "🌡 False Positives" - - def sidebar_options(self): - self.prediction_metric_in_sidebar_objects() - self.display_settings(MetricScope.MODEL_QUALITY) - - def sidebar_options_classifications(self): - self.prediction_metric_in_sidebar_classifications() - - def _build_objects( - self, - object_model_predictions: DataFrame[PredictionMatchSchema], - ): - metric_name = get_state().predictions.metric_datas.selected_prediction - if not metric_name: - st.error("No prediction metric selected") - return - - st.markdown(f"# {self.title}") - color = Color.RED - with st.expander("Details"): - st.markdown( - f"""### The view -These are the predictions for which either of the following is true -1. The IOU between the prediction and the best matching label was too low -2. There was another prediction with higher model confidence which matched the label already -3. The predicted class didn't match - ---- - -**Color**: -The {color.name.lower()} boxes marks the false positive predictions. -The remaining colors correspond to the dataset labels with the colors you are used to from the label editor. - """, - unsafe_allow_html=True, - ) - self.metric_details_description(get_state().predictions.metric_datas) - - fp_df = object_model_predictions[ - object_model_predictions[PredictionMatchSchema.is_true_positive] == 0.0 - ].dropna(subset=[metric_name]) - if fp_df.shape[0] == 0: - st.write("No false positives") - else: - histogram = get_histogram(fp_df, metric_name) - st.altair_chart(histogram, use_container_width=True) - prediction_grid(get_state().project_paths, model_predictions=fp_df, box_color=color) - - def _build_classifications( - self, - classification_model_predictions_matched: DataFrame[ClassificationPredictionMatchSchemaWithClassNames], - ): - with st.expander("Details"): - st.markdown( - """### The view -These are the predictions where the model incorrectly predicts the positive class. 
- """, - unsafe_allow_html=True, - ) - self.metric_details_description(get_state().predictions.metric_datas_classification) - - metric_name = get_state().predictions.metric_datas_classification.selected_prediction - if not metric_name: - st.error("No prediction metric selected") - return - - fp_df = classification_model_predictions_matched[ - classification_model_predictions_matched[ClassificationPredictionMatchSchemaWithClassNames.is_true_positive] - == 0.0 - ].dropna(subset=[metric_name]) - - if fp_df.shape[0] == 0: - st.write("No false positives") - else: - histogram = get_histogram(fp_df, metric_name) - st.altair_chart(histogram, use_container_width=True) - prediction_grid_classifications(get_state().project_paths, model_predictions=fp_df) - - def build( - self, - object_predictions_exist: bool, - classification_predictions_exist: bool, - object_tab: DeltaGenerator, - classification_tab: DeltaGenerator, - object_model_predictions: Optional[DataFrame[PredictionMatchSchema]] = None, - object_labels: Optional[DataFrame[LabelMatchSchema]] = None, - object_metrics: Optional[DataFrame[PerformanceMetricSchema]] = None, - object_precisions: Optional[DataFrame[PrecisionRecallSchema]] = None, - classification_labels: Optional[list] = None, - classification_pred: Optional[list] = None, - classification_model_predictions_matched: Optional[ - DataFrame[ClassificationPredictionMatchSchemaWithClassNames] - ] = None, - ): - - with object_tab: - if self.check_building_object_quality( - object_predictions_exist, object_model_predictions, object_labels, object_metrics, object_precisions - ): - self._build_objects(cast(DataFrame[PredictionMatchSchema], object_model_predictions)) - - with classification_tab: - if self.check_building_classification_quality( - classification_predictions_exist, - classification_labels, - classification_pred, - classification_model_predictions_matched, - ): - self._build_classifications( - cast( - DataFrame[ClassificationPredictionMatchSchemaWithClassNames], - classification_model_predictions_matched, - ) - ) diff --git a/src/encord_active/app/model_quality/sub_pages/metrics.py b/src/encord_active/app/model_quality/sub_pages/metrics.py deleted file mode 100644 index 920551b7f..000000000 --- a/src/encord_active/app/model_quality/sub_pages/metrics.py +++ /dev/null @@ -1,240 +0,0 @@ -import json -from typing import Optional, cast - -import streamlit as st -from pandera.typing import DataFrame -from streamlit.delta_generator import DeltaGenerator - -from encord_active.app.common.state import get_state -from encord_active.lib.charts.classification_metrics import ( - get_accuracy, - get_confusion_matrix, - get_precision_recall_f1, - get_precision_recall_graph, -) -from encord_active.lib.charts.metric_importance import create_metric_importance_charts -from encord_active.lib.charts.precision_recall import create_pr_chart_plotly -from encord_active.lib.model_predictions.map_mar import ( - PerformanceMetricSchema, - PrecisionRecallSchema, -) -from encord_active.lib.model_predictions.reader import ( - ClassificationPredictionMatchSchemaWithClassNames, - LabelMatchSchema, - PredictionMatchSchema, -) -from encord_active.lib.model_predictions.writer import MainPredictionType - -from . 
import ModelQualityPage - -_M_COLS = PerformanceMetricSchema - - -class MetricsPage(ModelQualityPage): - title = "📈 Metrics" - - def sidebar_options(self): - pass - - def sidebar_options_classifications(self): - pass - - def _build_objects( - self, - object_model_predictions: DataFrame[PredictionMatchSchema], - object_metrics: DataFrame[PerformanceMetricSchema], - object_precisions: DataFrame[PrecisionRecallSchema], - ): - - _map = object_metrics[object_metrics[_M_COLS.metric] == "mAP"]["value"].item() - _mar = object_metrics[object_metrics[_M_COLS.metric] == "mAR"]["value"].item() - col1, col2 = st.columns(2) - col1.metric("mAP", f"{_map:.3f}") - col2.metric("mAR", f"{_mar:.3f}") - st.markdown("""---""") - - st.subheader("Metric Importance") - with st.container(): - with st.expander("Description"): - st.write( - "The following charts show the dependency between model performance and each index. " - "In other words, these charts answer the question of how much is model " - "performance affected by each index. This relationship can be decomposed into two metrics:" - ) - st.markdown( - "- **Metric importance**: measures the *strength* of the dependency between and metric and model " - "performance. A high value means that the model performance would be strongly affected by " - "a change in the index. For example, a high importance in 'Brightness' implies that a change " - "in that quantity would strongly affect model performance. Values range from 0 (no dependency) " - "to 1 (perfect dependency, one can completely predict model performance simply by looking " - "at this index)." - ) - st.markdown( - "- **Metric [correlation](https://en.wikipedia.org/wiki/Correlation)**: measures the *linearity " - "and direction* of the dependency between an index and model performance. " - "Crucially, this metric tells us whether a positive change in an index " - "will lead to a positive change (positive correlation) or a negative change (negative correlation) " - "in model performance . Values range from -1 to 1." - ) - st.write( - "Finally, you can also select how many samples are included in the computation " - "with the slider, as well as filter by class with the dropdown in the side bar." - ) - - if ( - object_model_predictions.shape[0] > 60_000 - ): # Computation are heavy so allow computing for only a subset. - num_samples = st.slider( - "Number of samples", - min_value=1, - max_value=len(object_model_predictions), - step=max(1, (len(object_model_predictions) - 1) // 100), - value=max((len(object_model_predictions) - 1) // 2, 1), - help="To avoid too heavy computations, we subsample the data at random to the selected size, " - "computing importance values.", - ) - if num_samples < 100: - st.warning( - "Number of samples is too low to compute reliable index importances. 
" - "We recommend using at least 100 samples.", - ) - else: - num_samples = object_model_predictions.shape[0] - - metric_columns = list(get_state().predictions.metric_datas.predictions.keys()) - with st.spinner("Computing index importance..."): - try: - chart = create_metric_importance_charts( - object_model_predictions, - metric_columns=metric_columns, - num_samples=num_samples, - prediction_type=MainPredictionType.OBJECT, - ) - st.altair_chart(chart, use_container_width=True) - except ValueError as e: - if e.args: - st.info(e.args[0]) - else: - st.info("Failed to compute metric importance") - - st.subheader("Subset selection scores") - with st.container(): - project_ontology = json.loads(get_state().project_paths.ontology.read_text(encoding="utf-8")) - chart = create_pr_chart_plotly(object_metrics, object_precisions, project_ontology["objects"]) - st.plotly_chart(chart, use_container_width=True) - - def _build_classifications( - self, - classification_labels: list, - classification_pred: list, - classification_model_predictions_matched: DataFrame[ClassificationPredictionMatchSchemaWithClassNames], - ): - class_names = sorted(list(set(classification_labels).union(classification_pred))) - - precision, recall, f1, support = get_precision_recall_f1(classification_labels, classification_pred) - accuracy = get_accuracy(classification_labels, classification_pred) - - # PERFORMANCE METRICS SUMMARY - - col_acc, col_prec, col_rec, col_f1 = st.columns(4) - col_acc.metric("Accuracy", f"{float(accuracy):.2f}") - col_prec.metric( - "Mean Precision", - f"{float(precision.mean()):.2f}", - help="Average of precision scores of all classes", - ) - col_rec.metric("Mean Recall", f"{float(recall.mean()):.2f}", help="Average of recall scores of all classes") - col_f1.metric("Mean F1", f"{float(f1.mean()):.2f}", help="Average of F1 scores of all classes") - - # METRIC IMPORTANCE - if ( - classification_model_predictions_matched.shape[0] > 60_000 - ): # Computation are heavy so allow computing for only a subset. - num_samples = st.slider( - "Number of samples", - min_value=1, - max_value=len(classification_model_predictions_matched), - step=max(1, (len(classification_model_predictions_matched) - 1) // 100), - value=max((len(classification_model_predictions_matched) - 1) // 2, 1), - help="To avoid too heavy computations, we subsample the data at random to the selected size, " - "computing importance values.", - ) - if num_samples < 100: - st.warning( - "Number of samples is too low to compute reliable index importances. 
" - "We recommend using at least 100 samples.", - ) - else: - num_samples = classification_model_predictions_matched.shape[0] - - metric_columns = list(get_state().predictions.metric_datas_classification.predictions.keys()) - try: - metric_importance_chart = create_metric_importance_charts( - classification_model_predictions_matched, - metric_columns=metric_columns, - num_samples=num_samples, - prediction_type=MainPredictionType.CLASSIFICATION, - ) - st.altair_chart(metric_importance_chart, use_container_width=True) - except ValueError as e: - if e.args: - st.info(e.args[0]) - else: - st.info("Failed to compute metric importance") - - col1, col2 = st.columns(2) - - # CONFUSION MATRIX - confusion_matrix = get_confusion_matrix(classification_labels, classification_pred, class_names) - col1.plotly_chart(confusion_matrix, use_container_width=True) - - # PRECISION_RECALL BARS - pr_graph = get_precision_recall_graph(precision, recall, class_names) - col2.plotly_chart(pr_graph, use_container_width=True) - - # In order to plot ROC curve, we need confidences for the ground - # truth label. Currently, predictions.pkl file only has confidence - # value for the predicted class. - # roc_graph = get_roc_curve(classification_labels, classification_pred) - - def build( - self, - object_predictions_exist: bool, - classification_predictions_exist: bool, - object_tab: DeltaGenerator, - classification_tab: DeltaGenerator, - object_model_predictions: Optional[DataFrame[PredictionMatchSchema]] = None, - object_labels: Optional[DataFrame[LabelMatchSchema]] = None, - object_metrics: Optional[DataFrame[PerformanceMetricSchema]] = None, - object_precisions: Optional[DataFrame[PrecisionRecallSchema]] = None, - classification_labels: Optional[list] = None, - classification_pred: Optional[list] = None, - classification_model_predictions_matched: Optional[ - DataFrame[ClassificationPredictionMatchSchemaWithClassNames] - ] = None, - ): - with object_tab: - if self.check_building_object_quality( - object_predictions_exist, object_model_predictions, object_labels, object_metrics, object_precisions - ): - self._build_objects( - cast(DataFrame[PredictionMatchSchema], object_model_predictions), - cast(DataFrame[PerformanceMetricSchema], object_metrics), - cast(DataFrame[PrecisionRecallSchema], object_precisions), - ) - - with classification_tab: - if self.check_building_classification_quality( - classification_predictions_exist, - classification_labels, - classification_pred, - classification_model_predictions_matched, - ): - self._build_classifications( - cast(list, classification_labels), - cast(list, classification_pred), - cast( - DataFrame[ClassificationPredictionMatchSchemaWithClassNames], - classification_model_predictions_matched, - ), - ) diff --git a/src/encord_active/app/model_quality/sub_pages/performance_by_metric.py b/src/encord_active/app/model_quality/sub_pages/performance_by_metric.py deleted file mode 100644 index 180156acb..000000000 --- a/src/encord_active/app/model_quality/sub_pages/performance_by_metric.py +++ /dev/null @@ -1,239 +0,0 @@ -import re -from typing import Optional, cast - -import altair as alt -import streamlit as st -from pandera.typing import DataFrame -from streamlit.delta_generator import DeltaGenerator - -import encord_active.app.common.state as state -from encord_active.app.common.state import MetricNames, PredictionsState, get_state -from encord_active.lib.charts.performance_by_metric import performance_rate_by_metric -from encord_active.lib.charts.scopes import PredictionMatchScope 
-from encord_active.lib.model_predictions.map_mar import ( - PerformanceMetricSchema, - PrecisionRecallSchema, -) -from encord_active.lib.model_predictions.reader import ( - ClassificationPredictionMatchSchemaWithClassNames, - LabelMatchSchema, - PredictionMatchSchema, -) - -from . import ModelQualityPage - -FLOAT_FMT = ",.4f" -PCT_FMT = ",.2f" -COUNT_FMT = ",d" - - -CHART_TITLES = { - PredictionMatchScope.TRUE_POSITIVES: "True Positive Rate", - PredictionMatchScope.FALSE_POSITIVES: "False Positive Rate", - PredictionMatchScope.FALSE_NEGATIVES: "False Negative Rate", -} - - -class PerformanceMetric(ModelQualityPage): - title = "⚡️ Performance by Metric" - - def description_expander(self, metric_datas: MetricNames): - with st.expander("Details", expanded=False): - st.markdown( - """### The View - -On this page, your model scores are displayed as a function of the metric that you selected in the top bar. -Samples are discritized into $n$ equally sized buckets and the middle point of each bucket is displayed as the x-value in the plots. -Bars indicate the number of samples in each bucket, while lines indicate the true positive and false negative rates of each bucket. - - -Metrics marked with (P) are metrics computed on your predictions. -Metrics marked with (F) are frame level metrics, which depends on the frame that each prediction is associated -with. In the "False Negative Rate" plot, (O) means metrics computed on Object labels. - -For metrics that are computed on predictions (P) in the "True Positive Rate" plot, the corresponding "label metrics" (O/F) computed -on your labels are used for the "False Negative Rate" plot. -""", - unsafe_allow_html=True, - ) - self.metric_details_description(metric_datas) - - def sidebar_options(self): - c1, c2, c3 = st.columns([4, 4, 3]) - with c1: - self.prediction_metric_in_sidebar_objects() - - with c2: - get_state().predictions.nbins = int( - st.number_input( - "Number of buckets (n)", - min_value=5, - max_value=200, - value=PredictionsState.nbins, - help="Choose the number of bins to discritize the prediction metric values into.", - ) - ) - with c3: - st.write("") # Make some spacing. - st.write("") - get_state().predictions.decompose_classes = st.checkbox( - "Show class decomposition", - value=PredictionsState.decompose_classes, - help="When checked, every plot will have a separate component for each class.", - ) - - def sidebar_options_classifications(self): - self.prediction_metric_in_sidebar_classifications() - - def _build_objects( - self, - object_model_predictions: DataFrame[PredictionMatchSchema], - object_labels: DataFrame[LabelMatchSchema], - ): - - if object_model_predictions.shape[0] == 0: - st.write("No predictions of the given class(es).") - return - - metric_name = state.get_state().predictions.metric_datas.selected_prediction - if not metric_name: - # This shouldn't happen with the current flow. The only way a user can do this - # is if he/she write custom code to bypass running the metrics. In this case, - # I think that it is fair to not give more information than this. - st.write( - "No metrics computed for the your model predictions. " - "With `encord-active import predictions /path/to/predictions.pkl`, " - "Encord Active will automatically run compute the metrics." 
- ) - return - - self.description_expander(get_state().predictions.metric_datas) - - label_metric_name = metric_name - if metric_name[-3:] == "(P)": # Replace the P with O: "Metric (P)" -> "Metric (O)" - label_metric_name = re.sub(r"(.*?\()P(\))", r"\1O\2", metric_name) - - if not label_metric_name in object_labels.columns: - label_metric_name = re.sub( - r"(.*?\()O(\))", r"\1F\2", label_metric_name - ) # Look for it in frame label metrics. - - classes_for_coloring = ["Average"] - decompose_classes = get_state().predictions.decompose_classes - if decompose_classes: - unique_classes = set(object_model_predictions["class_name"].unique()).union( - object_labels["class_name"].unique() - ) - classes_for_coloring += sorted(list(unique_classes)) - - # Ensure same colors between plots - chart_args = dict( - color_params={"scale": alt.Scale(domain=classes_for_coloring)}, - bins=get_state().predictions.nbins, - show_decomposition=decompose_classes, - ) - - try: - tpr = performance_rate_by_metric( - object_model_predictions, metric_name, scope=PredictionMatchScope.TRUE_POSITIVES, **chart_args - ) - if tpr is not None: - st.altair_chart(tpr.interactive(), use_container_width=True) - except: - pass - - try: - fnr = performance_rate_by_metric( - object_labels, label_metric_name, scope=PredictionMatchScope.FALSE_NEGATIVES, **chart_args - ) - if fnr is not None: - st.altair_chart(fnr.interactive(), use_container_width=True) - except: - pass - - def _build_classifications( - self, - classification_model_predictions_matched: DataFrame[ClassificationPredictionMatchSchemaWithClassNames], - ): - if classification_model_predictions_matched.shape[0] == 0: - st.write("No predictions of the given class(es).") - return - - metric_name = state.get_state().predictions.metric_datas_classification.selected_prediction - if not metric_name: - # This shouldn't happen with the current flow. The only way a user can do this - # is if he/she write custom code to bypass running the metrics. In this case, - # I think that it is fair to not give more information than this. - st.write( - "No metrics computed for the your model predictions. " - "With `encord-active import predictions /path/to/predictions.pkl`, " - "Encord Active will automatically run compute the metrics." 
- ) - return - - self.description_expander(get_state().predictions.metric_datas_classification) - - classes_for_coloring = ["Average"] - decompose_classes = get_state().predictions.decompose_classes - if decompose_classes: - unique_classes = set(classification_model_predictions_matched["class_name"].unique()) - classes_for_coloring += sorted(list(unique_classes)) - - # Ensure same colors between plots - chart_args = dict( - color_params={"scale": alt.Scale(domain=classes_for_coloring)}, - bins=get_state().predictions.nbins, - show_decomposition=decompose_classes, - ) - - try: - tpr = performance_rate_by_metric( - classification_model_predictions_matched, - metric_name, - scope=PredictionMatchScope.TRUE_POSITIVES, - **chart_args, - ) - if tpr is not None: - st.altair_chart(tpr.interactive(), use_container_width=True) - except: - pass - - def build( - self, - object_predictions_exist: bool, - classification_predictions_exist: bool, - object_tab: DeltaGenerator, - classification_tab: DeltaGenerator, - object_model_predictions: Optional[DataFrame[PredictionMatchSchema]] = None, - object_labels: Optional[DataFrame[LabelMatchSchema]] = None, - object_metrics: Optional[DataFrame[PerformanceMetricSchema]] = None, - object_precisions: Optional[DataFrame[PrecisionRecallSchema]] = None, - classification_labels: Optional[list] = None, - classification_pred: Optional[list] = None, - classification_model_predictions_matched: Optional[ - DataFrame[ClassificationPredictionMatchSchemaWithClassNames] - ] = None, - ): - - with object_tab: - if self.check_building_object_quality( - object_predictions_exist, object_model_predictions, object_labels, object_metrics, object_precisions - ): - self._build_objects( - cast(DataFrame[PredictionMatchSchema], object_model_predictions), - cast(DataFrame[LabelMatchSchema], object_labels), - ) - - with classification_tab: - if self.check_building_classification_quality( - classification_predictions_exist, - classification_labels, - classification_pred, - classification_model_predictions_matched, - ): - self._build_classifications( - cast( - DataFrame[ClassificationPredictionMatchSchemaWithClassNames], - classification_model_predictions_matched, - ) - ) diff --git a/src/encord_active/app/model_quality/sub_pages/true_positives.py b/src/encord_active/app/model_quality/sub_pages/true_positives.py deleted file mode 100644 index 9df9561f7..000000000 --- a/src/encord_active/app/model_quality/sub_pages/true_positives.py +++ /dev/null @@ -1,138 +0,0 @@ -from typing import Optional, cast - -import streamlit as st -from pandera.typing import DataFrame -from streamlit.delta_generator import DeltaGenerator - -from encord_active.app.common.components.prediction_grid import ( - prediction_grid, - prediction_grid_classifications, -) -from encord_active.app.common.state import get_state -from encord_active.lib.charts.histogram import get_histogram -from encord_active.lib.common.colors import Color -from encord_active.lib.metrics.utils import MetricScope -from encord_active.lib.model_predictions.map_mar import ( - PerformanceMetricSchema, - PrecisionRecallSchema, -) -from encord_active.lib.model_predictions.reader import ( - ClassificationPredictionMatchSchemaWithClassNames, - LabelMatchSchema, - PredictionMatchSchema, -) - -from . 
import ModelQualityPage - - -class TruePositivesPage(ModelQualityPage): - title = "✅ True Positives" - - def sidebar_options(self): - self.prediction_metric_in_sidebar_objects() - self.display_settings(MetricScope.MODEL_QUALITY) - - def sidebar_options_classifications(self): - self.prediction_metric_in_sidebar_classifications() - - def _build_objects( - self, - object_model_predictions: DataFrame[PredictionMatchSchema], - ): - with st.expander("Details"): - color = Color.PURPLE - st.markdown( - f"""### The view -These are the predictions for which the IOU was sufficiently high and the confidence score was -the highest amongst predictions that overlap with the label. - ---- - -**Color**: -The {color.name.lower()} boxes marks the true positive predictions. -The remaining colors correspond to the dataset labels with the colors you are used to from the label editor. - """, - unsafe_allow_html=True, - ) - self.metric_details_description(get_state().predictions.metric_datas) - - metric_name = get_state().predictions.metric_datas.selected_prediction - if not metric_name: - st.error("No prediction metric selected") - return - - tp_df = object_model_predictions[ - object_model_predictions[PredictionMatchSchema.is_true_positive] == 1.0 - ].dropna(subset=[metric_name]) - if tp_df.shape[0] == 0: - st.write("No true positives") - else: - histogram = get_histogram(tp_df, metric_name) - st.altair_chart(histogram, use_container_width=True) - prediction_grid(get_state().project_paths, model_predictions=tp_df, box_color=color) - - def _build_classifications( - self, - classification_model_predictions_matched: DataFrame[ClassificationPredictionMatchSchemaWithClassNames], - ): - with st.expander("Details"): - st.markdown( - """### The view -These are the predictions where the model correctly predicts the true class. 
- """, - unsafe_allow_html=True, - ) - self.metric_details_description(get_state().predictions.metric_datas_classification) - - metric_name = get_state().predictions.metric_datas_classification.selected_prediction - if not metric_name: - st.error("No prediction metric selected") - return - - tp_df = classification_model_predictions_matched[ - classification_model_predictions_matched[ClassificationPredictionMatchSchemaWithClassNames.is_true_positive] - == 1.0 - ].dropna(subset=[metric_name]) - - if tp_df.shape[0] == 0: - st.write("No true positives") - else: - histogram = get_histogram(tp_df, metric_name) - st.altair_chart(histogram, use_container_width=True) - prediction_grid_classifications(get_state().project_paths, model_predictions=tp_df) - - def build( - self, - object_predictions_exist: bool, - classification_predictions_exist: bool, - object_tab: DeltaGenerator, - classification_tab: DeltaGenerator, - object_model_predictions: Optional[DataFrame[PredictionMatchSchema]] = None, - object_labels: Optional[DataFrame[LabelMatchSchema]] = None, - object_metrics: Optional[DataFrame[PerformanceMetricSchema]] = None, - object_precisions: Optional[DataFrame[PrecisionRecallSchema]] = None, - classification_labels: Optional[list] = None, - classification_pred: Optional[list] = None, - classification_model_predictions_matched: Optional[ - DataFrame[ClassificationPredictionMatchSchemaWithClassNames] - ] = None, - ): - with object_tab: - if self.check_building_object_quality( - object_predictions_exist, object_model_predictions, object_labels, object_metrics, object_precisions - ): - self._build_objects(cast(DataFrame[PredictionMatchSchema], object_model_predictions)) - - with classification_tab: - if self.check_building_classification_quality( - classification_predictions_exist, - classification_labels, - classification_pred, - classification_model_predictions_matched, - ): - self._build_classifications( - cast( - DataFrame[ClassificationPredictionMatchSchemaWithClassNames], - classification_model_predictions_matched, - ) - ) diff --git a/src/encord_active/app/page_menu.py b/src/encord_active/app/page_menu.py index a341bf939..42f8f43cc 100644 --- a/src/encord_active/app/page_menu.py +++ b/src/encord_active/app/page_menu.py @@ -14,13 +14,7 @@ from encord_active.app.actions_page.versioning import is_latest, version_form from encord_active.app.common.state import get_state, refresh from encord_active.app.common.state_hooks import UseState -from encord_active.app.model_quality.sub_pages.false_negatives import FalseNegativesPage -from encord_active.app.model_quality.sub_pages.false_positives import FalsePositivesPage -from encord_active.app.model_quality.sub_pages.metrics import MetricsPage -from encord_active.app.model_quality.sub_pages.performance_by_metric import ( - PerformanceMetric, -) -from encord_active.app.model_quality.sub_pages.true_positives import TruePositivesPage +from encord_active.app.model_quality.prediction_type_builder import ModelQualityPage from encord_active.app.views.metrics import explorer, summary from encord_active.app.views.model_quality import model_quality from encord_active.lib.metrics.utils import MetricScope @@ -29,11 +23,9 @@ "Data Quality": {"Summary": summary(MetricScope.DATA_QUALITY), "Explorer": explorer(MetricScope.DATA_QUALITY)}, "Label Quality": {"Summary": summary(MetricScope.LABEL_QUALITY), "Explorer": explorer(MetricScope.LABEL_QUALITY)}, "Model Quality": { - "Metrics": model_quality(MetricsPage()), - "Performance By Metric": 
model_quality(PerformanceMetric()), - "True Positives": model_quality(TruePositivesPage()), - "False Positives": model_quality(FalsePositivesPage()), - "False Negatives": model_quality(FalseNegativesPage()), + "Metrics": model_quality(ModelQualityPage.METRICS), + "Performance By Metric": model_quality(ModelQualityPage.PERFORMANCE_BY_METRIC), + "Explorer": model_quality(ModelQualityPage.EXPLORER), }, "Actions": {"Filter & Export": export_filter, "Balance & Export": export_balance, "Versioning": version_form}, } diff --git a/src/encord_active/app/views/model_quality.py b/src/encord_active/app/views/model_quality.py index f2daae51b..468cd91b6 100644 --- a/src/encord_active/app/views/model_quality.py +++ b/src/encord_active/app/views/model_quality.py @@ -1,237 +1,43 @@ -from pathlib import Path +from typing import List import streamlit as st -import encord_active.lib.model_predictions.reader as reader from encord_active.app.common.components import sticky_header from encord_active.app.common.components.tags.tag_creator import tag_creator -from encord_active.app.common.state import MetricNames, get_state -from encord_active.app.common.state_hooks import use_memo from encord_active.app.common.utils import setup_page -from encord_active.app.model_quality.settings import ( - common_settings_classifications, - common_settings_objects, +from encord_active.app.model_quality.prediction_type_builder import ( + ModelQualityPage, + PredictionTypeBuilder, ) -from encord_active.app.model_quality.sub_pages import Page -from encord_active.lib.model_predictions.classification_metrics import ( - match_predictions_and_labels, +from encord_active.app.model_quality.prediction_types.classification_type_builder import ( + ClassificationTypeBuilder, ) -from encord_active.lib.model_predictions.filters import ( - filter_labels_for_frames_wo_predictions, - prediction_and_label_filtering, - prediction_and_label_filtering_classification, +from encord_active.app.model_quality.prediction_types.object_type_builder import ( + ObjectTypeBuilder, ) -from encord_active.lib.model_predictions.map_mar import compute_mAP_and_mAR -from encord_active.lib.model_predictions.reader import ( - ClassificationLabelSchema, - ClassificationPredictionMatchSchemaWithClassNames, - ClassificationPredictionSchema, -) -from encord_active.lib.model_predictions.writer import MainPredictionType - - -def model_quality(page: Page): - def _get_object_model_quality_data(metrics_dir: Path): - if not reader.check_model_prediction_availability( - get_state().project_paths.predictions / MainPredictionType.OBJECT.value - ): - return False, None, None, None, None - - predictions_dir = get_state().project_paths.predictions / MainPredictionType.OBJECT.value - - predictions_metric_datas, _ = use_memo(lambda: reader.get_prediction_metric_data(predictions_dir, metrics_dir)) - label_metric_datas, _ = use_memo(lambda: reader.get_label_metric_data(metrics_dir)) - model_predictions, _ = use_memo( - lambda: reader.get_model_predictions(predictions_dir, predictions_metric_datas, MainPredictionType.OBJECT) - ) - labels, _ = use_memo(lambda: reader.get_labels(predictions_dir, label_metric_datas, MainPredictionType.OBJECT)) - - if model_predictions is None: - st.error("Couldn't load model predictions") - return - - if labels is None: - st.error("Couldn't load labels properly") - return - - matched_gt, _ = use_memo(lambda: reader.get_gt_matched(predictions_dir)) - get_state().predictions.metric_datas = MetricNames( - predictions={m.name: m for m in predictions_metric_datas}, 
- labels={m.name: m for m in label_metric_datas}, - ) - - if not matched_gt: - st.error("Couldn't match ground truths") - return - - with sticky_header(): - common_settings_objects() - page.sidebar_options() - - (predictions_filtered, labels_filtered, metrics, precisions,) = compute_mAP_and_mAR( - model_predictions, - labels, - matched_gt, - get_state().predictions.all_classes_objects, - iou_threshold=get_state().iou_threshold, - ignore_unmatched_frames=get_state().ignore_frames_without_predictions, - ) - - # Sort predictions and labels according to selected metrics. - pred_sort_column = get_state().predictions.metric_datas.selected_prediction or predictions_metric_datas[0].name - sorted_model_predictions = predictions_filtered.sort_values([pred_sort_column], axis=0) - - label_sort_column = get_state().predictions.metric_datas.selected_label or label_metric_datas[0].name - sorted_labels = labels_filtered.sort_values([label_sort_column], axis=0) - - if get_state().ignore_frames_without_predictions: - labels_filtered = filter_labels_for_frames_wo_predictions(predictions_filtered, sorted_labels) - else: - labels_filtered = sorted_labels - - object_labels, object_metrics, object_model_pred, object_precisions = prediction_and_label_filtering( - get_state().predictions.selected_classes_objects, - labels_filtered, - metrics, - sorted_model_predictions, - precisions, - ) - return True, object_labels, object_metrics, object_model_pred, object_precisions - def _get_classification_model_quality_data(metrics_dir: Path): - if not reader.check_model_prediction_availability( - get_state().project_paths.predictions / MainPredictionType.CLASSIFICATION.value - ): - return False, None, None, None - - predictions_dir_classification = get_state().project_paths.predictions / MainPredictionType.CLASSIFICATION.value - - predictions = reader.get_classification_predictions( - get_state().project_paths.predictions / MainPredictionType.CLASSIFICATION.value - ) - labels = reader.get_classification_labels( - get_state().project_paths.predictions / MainPredictionType.CLASSIFICATION.value - ) - - predictions_metric_datas, _ = use_memo( - lambda: reader.get_prediction_metric_data(predictions_dir_classification, metrics_dir) - ) - label_metric_datas, _ = use_memo(lambda: reader.get_label_metric_data(metrics_dir)) - model_predictions, _ = use_memo( - lambda: reader.get_model_predictions( - predictions_dir_classification, predictions_metric_datas, MainPredictionType.CLASSIFICATION - ) - ) - labels_all, _ = use_memo( - lambda: reader.get_labels( - predictions_dir_classification, label_metric_datas, MainPredictionType.CLASSIFICATION - ) - ) - - get_state().predictions.metric_datas_classification = MetricNames( - predictions={m.name: m for m in predictions_metric_datas}, - ) - - if model_predictions is None: - st.error("Couldn't load model predictions") - return - - if labels_all is None: - st.error("Couldn't load labels properly") - return - - with sticky_header(): - common_settings_classifications() - page.sidebar_options_classifications() - - model_predictions_matched = match_predictions_and_labels(model_predictions, labels_all) - - ( - labels_filtered, - predictions_filtered, - model_predictions_matched_filtered, - ) = prediction_and_label_filtering_classification( - get_state().predictions.selected_classes_classifications, - get_state().predictions.all_classes_classifications, - labels, - predictions, - model_predictions_matched, - ) - - img_id_intersection = list( - 
set(labels_filtered[ClassificationLabelSchema.img_id]).intersection( - set(predictions_filtered[ClassificationPredictionSchema.img_id]) - ) - ) - labels_filtered_intersection = labels_filtered[ - labels_filtered[ClassificationLabelSchema.img_id].isin(img_id_intersection) - ] - predictions_filtered_intersection = predictions_filtered[ - predictions_filtered[ClassificationPredictionSchema.img_id].isin(img_id_intersection) - ] - - y_true, y_pred = ( - list(labels_filtered_intersection[ClassificationLabelSchema.class_id]), - list(predictions_filtered_intersection[ClassificationPredictionSchema.class_id]), - ) - - return ( - True, - y_true, - y_pred, - model_predictions_matched_filtered.copy()[ - model_predictions_matched_filtered[ClassificationPredictionMatchSchemaWithClassNames.img_id].isin( - img_id_intersection - ) - ], - ) +def model_quality(page_mode: ModelQualityPage): + def get_available_predictions() -> List[PredictionTypeBuilder]: + builders: List[PredictionTypeBuilder] = [ClassificationTypeBuilder(), ObjectTypeBuilder()] + return [b for b in builders if b.is_available()] def render(): setup_page() - tag_creator() - - """ - Note: Streamlit tabs should be initialized here for two reasons: - 1. Selected tab will be same for each page in the model quality tabs. Otherwise, if we create tabs inside the - pages, selected tab will be reset in each click. - 2. Common top bar sticker includes filtering, so their result should be reflected in each page. Otherwise, it - would be hard to get filters from other pages. - """ - - object_tab, classification_tab = st.tabs(["Objects", "Classifications"]) - metrics_dir = get_state().project_paths.metrics + with sticky_header(): + tag_creator() - with object_tab: - ( - object_predictions_exist, - object_labels, - object_metrics, - object_model_pred, - object_precisions, - ) = _get_object_model_quality_data(metrics_dir) + available_predictions: List[PredictionTypeBuilder] = get_available_predictions() - with classification_tab: + if not available_predictions: + st.markdown("## No predictions imported into this project.") + return - ( - classification_predictions_exist, - classification_labels, - classification_pred, - classification_model_predictions_matched, - ) = _get_classification_model_quality_data(metrics_dir) + tab_names = [m.title for m in available_predictions] + tabs = st.tabs(tab_names) if len(available_predictions) > 1 else [st.container()] - page.build( - object_predictions_exist, - classification_predictions_exist, - object_tab, - classification_tab, - object_model_predictions=object_model_pred, - object_labels=object_labels, - object_metrics=object_metrics, - object_precisions=object_precisions, - classification_labels=classification_labels, - classification_pred=classification_pred, - classification_model_predictions_matched=classification_model_predictions_matched, - ) + for tab, builder in zip(tabs, available_predictions): + with tab: + builder.build(page_mode) return render diff --git a/src/encord_active/lib/model_predictions/reader.py b/src/encord_active/lib/model_predictions/reader.py index a37c21a77..34c4f0346 100644 --- a/src/encord_active/lib/model_predictions/reader.py +++ b/src/encord_active/lib/model_predictions/reader.py @@ -1,6 +1,6 @@ from dataclasses import dataclass from pathlib import Path -from typing import Any, Callable, Iterable, List, Optional, TypedDict, cast +from typing import Any, Callable, Iterable, List, Optional, TypedDict, Union, cast import pandas as pd import pandera as pa @@ -170,7 +170,7 @@ def 
get_prediction_metric_data(predictions_dir: Path, metrics_dir: Path) -> List

 def get_model_predictions(
     predictions_dir: Path, metric_data: List[MetricData], prediction_type: MainPredictionType
-) -> Optional[DataFrame[PredictionSchema]]:
+) -> Union[DataFrame[PredictionSchema], DataFrame[ClassificationPredictionSchema], None]:
     df = _load_csv_and_merge_metrics(predictions_dir / "predictions.csv", metric_data)

     if prediction_type == MainPredictionType.CLASSIFICATION:
@@ -194,7 +194,7 @@ def get_label_metric_data(metrics_dir: Path) -> List[MetricData]:

 def get_labels(
     predictions_dir: Path, metric_data: List[MetricData], prediction_type: MainPredictionType
-) -> Optional[DataFrame[LabelSchema]]:
+) -> Union[DataFrame[LabelSchema], DataFrame[ClassificationLabelSchema], None]:
     df = _load_csv_and_merge_metrics(predictions_dir / "labels.csv", metric_data)

     if prediction_type == MainPredictionType.OBJECT:
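Note on the return-type widening above: get_model_predictions and get_labels can now hand back either the object schemas (PredictionSchema / LabelSchema) or the classification schemas (ClassificationPredictionSchema / ClassificationLabelSchema), depending on the MainPredictionType requested. A minimal caller sketch follows, assuming the predictions/metrics directory layout referenced elsewhere in this diff; the concrete paths are placeholders and not part of the change.

# Hypothetical usage sketch (not part of the diff): consuming the widened return type.
from pathlib import Path
from typing import List

import encord_active.lib.model_predictions.reader as reader
from encord_active.lib.metrics.utils import MetricData
from encord_active.lib.model_predictions.writer import MainPredictionType

# Placeholder project layout; real paths come from the project's ProjectPaths.
predictions_dir = Path("/path/to/project/predictions") / MainPredictionType.OBJECT.value
metrics_dir = Path("/path/to/project/metrics")

metric_data: List[MetricData] = reader.get_prediction_metric_data(predictions_dir, metrics_dir)
model_predictions = reader.get_model_predictions(predictions_dir, metric_data, MainPredictionType.OBJECT)

if model_predictions is None:
    raise RuntimeError("Couldn't load model predictions")

# For MainPredictionType.OBJECT the returned frame conforms to PredictionSchema;
# for MainPredictionType.CLASSIFICATION it conforms to ClassificationPredictionSchema.

Callers that previously assumed DataFrame[PredictionSchema] unconditionally should now branch on the prediction type, as the separate ClassificationTypeBuilder and ObjectTypeBuilder introduced in this diff do.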