Refactor: model quality (#275)
* refactor: separate prediction rendering

* refactor: add classification metric

* refactor: merge _load_data duplicates

* refactor: delete old subpages

* refactor: merge duplicates

* refactor: merge duplicates

* refactor: reorganize method order

* refactor: merge folders

* refactor: reorder methods

* refactor: merge outcome pages

* refactor: base class structure

* fix: type constraining

* fix: remove old outcome pages

* refactor: styling

* fix: use_memo keyed per project

* fix: typos and single tab issue

* fix: adapt to new style

* fix: tag_creator in sticky header
Gorkem-Encord authored Mar 28, 2023
1 parent 4fe2e53 commit 49cdeac
Showing 14 changed files with 974 additions and 1,282 deletions.
2 changes: 1 addition & 1 deletion src/encord_active/app/common/state.py
@@ -43,7 +43,7 @@ class PredictionsState:
     all_classes_objects: Dict[str, OntologyObjectJSON] = field(default_factory=dict)
     all_classes_classifications: Dict[str, OntologyClassificationJSON] = field(default_factory=dict)
     selected_classes_objects: Dict[str, OntologyObjectJSON] = field(default_factory=dict)
-    selected_classes_classifications: Dict[str, OntologyObjectJSON] = field(default_factory=dict)
+    selected_classes_classifications: Dict[str, OntologyClassificationJSON] = field(default_factory=dict)
     labels: Optional[DataFrame[LabelSchema]] = None
     nbins: int = 50

297 changes: 297 additions & 0 deletions src/encord_active/app/model_quality/prediction_type_builder.py
@@ -0,0 +1,297 @@
from abc import abstractmethod
from enum import Enum
from typing import Dict, List, Tuple, Union

import streamlit as st
from pandera.typing import DataFrame

import encord_active.lib.model_predictions.reader as reader
from encord_active.app.common.page import Page
from encord_active.app.common.state import MetricNames, PredictionsState, get_state
from encord_active.app.common.state_hooks import use_memo
from encord_active.lib.charts.metric_importance import create_metric_importance_charts
from encord_active.lib.metrics.utils import MetricData
from encord_active.lib.model_predictions.reader import (
ClassificationLabelSchema,
ClassificationPredictionMatchSchemaWithClassNames,
ClassificationPredictionSchema,
LabelSchema,
OntologyObjectJSON,
PredictionMatchSchema,
PredictionSchema,
)
from encord_active.lib.model_predictions.writer import (
MainPredictionType,
OntologyClassificationJSON,
)


class ModelQualityPage(str, Enum):
METRICS = "metrics"
PERFORMANCE_BY_METRIC = "performance by metric"
EXPLORER = "explorer"


class MetricType(str, Enum):
PREDICTION = "prediction"
LABEL = "label"


class PredictionTypeBuilder(Page):
def sidebar_options(self, *args, **kwargs):
pass

def _render_class_filtering_component(
self, all_classes: Union[Dict[str, OntologyObjectJSON], Dict[str, OntologyClassificationJSON]]
):
return st.multiselect(
"Filter by class",
list(all_classes.items()),
format_func=lambda x: x[1]["name"],
help="""
With this selection, you can choose which classes to include in the performance metrics calculations.
This acts as a filter, i.e. when nothing is selected all classes are included.
Performance metrics will be automatically updated according to the chosen classes.
""",
)

def _description_expander(self, metric_datas: MetricNames):
with st.expander("Details", expanded=False):
st.markdown(
"""### The View
On this page, your model scores are displayed as a function of the metric that you selected in the top bar.
Samples are discretized into $n$ equally sized buckets and the midpoint of each bucket is displayed as the x-value
in the plots. Bars indicate the number of samples in each bucket, while lines indicate the true positive and false
negative rates of each bucket.
Metrics marked with (P) are metrics computed on your predictions.
Metrics marked with (F) are frame-level metrics, which depend on the frame that each prediction is associated
with. In the "False Negative Rate" plot, (O) means metrics computed on Object labels.
For metrics that are computed on predictions (P) in the "True Positive Rate" plot, the corresponding "label metrics"
(O/F) computed on your labels are used for the "False Negative Rate" plot.
""",
unsafe_allow_html=True,
)
self._metric_details_description(metric_datas)

@staticmethod
def _metric_details_description(metric_datas: MetricNames):
metric_name = metric_datas.selected_prediction
if not metric_name:
return

metric_data = metric_datas.predictions.get(metric_name)

if not metric_data:
metric_data = metric_datas.labels.get(metric_name)

if metric_data:
st.markdown(f"### The {metric_data.name[:-4]} metric")
st.markdown(metric_data.meta.long_description)

def _set_binning(self):
get_state().predictions.nbins = int(
st.number_input(
"Number of buckets (n)",
min_value=5,
max_value=200,
value=PredictionsState.nbins,
help="Choose the number of bins to discritize the prediction metric values into.",
)
)

def _set_sampling_for_metric_importance(
self,
model_predictions: Union[
DataFrame[ClassificationPredictionMatchSchemaWithClassNames], DataFrame[PredictionMatchSchema]
],
) -> int:
num_samples = st.slider(
"Number of samples",
min_value=1,
max_value=len(model_predictions),
step=max(1, (len(model_predictions) - 1) // 100),
value=max((len(model_predictions) - 1) // 2, 1),
help="To avoid too heavy computations, we subsample the data at random to the selected size, "
"computing importance values.",
)
if num_samples < 100:
st.warning(
"Number of samples is too low to compute reliable index importance. "
"We recommend using at least 100 samples.",
)

return num_samples

def _class_decomposition(self):
st.write("") # Make some spacing.
st.write("")
get_state().predictions.decompose_classes = st.checkbox(
"Show class decomposition",
value=PredictionsState.decompose_classes,
help="When checked, every plot will have a separate component for each class.",
)

def _topbar_metric_selection_component(self, metric_type: MetricType, metric_datas: MetricNames):
"""
Note: Adding the fixed options "confidence" and "iou" works here because
confidence is required on import and IOU is computed during prediction
import. So the two columns are already available in the
`st.session_state.model_predictions` data frame.
"""
if metric_type == MetricType.LABEL:
column_names = list(metric_datas.predictions.keys())
metric_datas.selected_label = st.selectbox(
"Select metric for your labels",
column_names,
help="The data in the main view will be sorted by the selected metric. "
"(F) := frame scores, (O) := object scores.",
)
elif metric_type == MetricType.PREDICTION:
fixed_options = {"confidence": "Model Confidence"}
column_names = list(metric_datas.predictions.keys())
metric_datas.selected_prediction = st.selectbox(
"Select metric for your predictions",
column_names + list(fixed_options.keys()),
format_func=lambda s: fixed_options.get(s, s),
help="The data in the main view will be sorted by the selected metric. "
"(F) := frame scores, (P) := prediction scores.",
)

def _read_prediction_files(
self, prediction_type: MainPredictionType, project_path: str
) -> Tuple[
List[MetricData],
List[MetricData],
Union[DataFrame[PredictionSchema], DataFrame[ClassificationPredictionSchema], None],
Union[DataFrame[LabelSchema], DataFrame[ClassificationLabelSchema], None],
]:
metrics_dir = get_state().project_paths.metrics
predictions_dir = get_state().project_paths.predictions / prediction_type.value

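# Memoization keys include the project path and prediction type so cached results are not reused across projects.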
predictions_metric_datas, _ = use_memo(
lambda: reader.get_prediction_metric_data(predictions_dir, metrics_dir),
key=f"predictions_metrics_data_{project_path}_{prediction_type.value}",
)

label_metric_datas, _ = use_memo(
lambda: reader.get_label_metric_data(metrics_dir),
key=f"label_metric_datas_{project_path}_{prediction_type.value}",
)
model_predictions, _ = use_memo(
lambda: reader.get_model_predictions(predictions_dir, predictions_metric_datas, prediction_type),
key=f"model_predictions_{project_path}_{prediction_type.value}",
)
labels, _ = use_memo(
lambda: reader.get_labels(predictions_dir, label_metric_datas, prediction_type),
key=f"labels_{project_path}_{prediction_type.value}",
)

return predictions_metric_datas, label_metric_datas, model_predictions, labels

def _render_performance_by_metric_description(
self,
model_predictions: Union[
DataFrame[ClassificationPredictionMatchSchemaWithClassNames], DataFrame[PredictionMatchSchema]
],
metric_datas: MetricNames,
):
if model_predictions.shape[0] == 0:
st.write("No predictions of the given class(es).")
return

metric_name = metric_datas.selected_prediction
if not metric_name:
# This shouldn't happen with the current flow. The only way a user can end up here
# is by writing custom code to bypass running the metrics. In that case, it is
# fair not to give more information than this.
st.write(
"No metrics computed for the your model predictions. "
"With `encord-active import predictions /path/to/predictions.pkl`, "
"Encord Active will automatically run compute the metrics."
)
return

self._description_expander(metric_datas)

def _get_metric_importance(
self,
model_predictions: Union[
DataFrame[ClassificationPredictionMatchSchemaWithClassNames], DataFrame[PredictionMatchSchema]
],
metric_columns: List[str],
):
with st.container():
with st.expander("Description"):
st.write(
"The following charts show the dependency between model performance and each index. "
"In other words, these charts answer the question of how much is model "
"performance affected by each index. This relationship can be decomposed into two metrics:"
)
st.markdown(
"- **Metric importance**: measures the *strength* of the dependency between and metric and model "
"performance. A high value means that the model performance would be strongly affected by "
"a change in the index. For example, a high importance in 'Brightness' implies that a change "
"in that quantity would strongly affect model performance. Values range from 0 (no dependency) "
"to 1 (perfect dependency, one can completely predict model performance simply by looking "
"at this index)."
)
st.markdown(
"- **Metric [correlation](https://en.wikipedia.org/wiki/Correlation)**: measures the *linearity "
"and direction* of the dependency between an index and model performance. "
"Crucially, this metric tells us whether a positive change in an index "
"will lead to a positive change (positive correlation) or a negative change (negative correlation) "
"in model performance . Values range from -1 to 1."
)
st.write(
"Finally, you can also select how many samples are included in the computation "
"with the slider, as well as filter by class with the dropdown in the side bar."
)

if model_predictions.shape[0] > 60_000:  # Computations are heavy, so allow computing for only a subset.
num_samples = self._set_sampling_for_metric_importance(model_predictions)
else:
num_samples = model_predictions.shape[0]

with st.spinner("Computing index importance..."):
try:
metric_importance_chart = create_metric_importance_charts(
model_predictions,
metric_columns=metric_columns,
num_samples=num_samples,
prediction_type=MainPredictionType.CLASSIFICATION,
)
st.altair_chart(metric_importance_chart, use_container_width=True)
except ValueError as e:
if e.args:
st.info(e.args[0])
else:
st.info("Failed to compute metric importance")

def build(self, page_mode: ModelQualityPage):
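# Render the requested subpage only when the prediction data has been loaded successfully.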
if self.load_data(page_mode):
if page_mode == ModelQualityPage.METRICS:
self.render_metrics()
elif page_mode == ModelQualityPage.PERFORMANCE_BY_METRIC:
self.render_performance_by_metric()
elif page_mode == ModelQualityPage.EXPLORER:
self.render_explorer()

def render_metrics(self):
st.markdown("### This page is not implemented")

def render_performance_by_metric(self):
st.markdown("### This page is not implemented")

def render_explorer(self):
st.markdown("### This page is not implemented")

@abstractmethod
def load_data(self, page_mode: ModelQualityPage) -> bool:
pass

@abstractmethod
def is_available(self) -> bool:
pass
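
For context, a concrete builder would subclass PredictionTypeBuilder, implement the two abstract hooks, and override the render methods it supports. The sketch below is illustrative only and not part of this commit; the class name ObjectTypeBuilder, MainPredictionType.OBJECT, and get_state().project_paths.project_dir are assumptions used purely for the example, and it reuses the imports from the module above.

class ObjectTypeBuilder(PredictionTypeBuilder):  # hypothetical subclass for illustration
    def __init__(self):
        self._model_predictions: Union[DataFrame[PredictionSchema], None] = None

    def is_available(self) -> bool:
        # Assumption: MainPredictionType.OBJECT exists; availability just means object predictions were imported.
        return (get_state().project_paths.predictions / MainPredictionType.OBJECT.value).is_dir()

    def load_data(self, page_mode: ModelQualityPage) -> bool:
        # Assumption: project_paths exposes the project root; any stable string works as the memoization key.
        project_path = get_state().project_paths.project_dir.as_posix()
        predictions_metric_datas, label_metric_datas, model_predictions, labels = self._read_prediction_files(
            MainPredictionType.OBJECT, project_path
        )
        if model_predictions is None or labels is None:
            st.error("Couldn't load model predictions or labels for this project.")
            return False
        self._model_predictions = model_predictions
        return True

    def render_metrics(self):
        st.markdown("### Object-level metrics would be rendered here")

A page would then instantiate the builder and call, for example, ObjectTypeBuilder().build(ModelQualityPage.METRICS).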