Add feature importance (#852)
* Add feature importance

* Linters

* add spark support

---------

Co-authored-by: 0lgaF <[email protected]>
Co-authored-by: mike0sv <[email protected]>
3 people authored Nov 14, 2023
1 parent 2b7a0af commit 4258133
Showing 6 changed files with 228 additions and 60 deletions.
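A minimal usage sketch of the new flag, assuming the standard Report API: a DataDriftTable metric with feature_importance=True inside a Report, run against reference and current frames whose column mapping defines a target (the importance calculation fits a random forest against it). The file names, column names, and task value below are illustrative assumptions, not part of the commit.

import pandas as pd

from evidently import ColumnMapping
from evidently.metrics import DataDriftTable
from evidently.report import Report

# Illustrative data: any two frames with a shared schema and a target column.
reference = pd.read_csv("reference.csv")
current = pd.read_csv("current.csv")

report = Report(metrics=[DataDriftTable(feature_importance=True)])
report.run(
    reference_data=reference,
    current_data=current,
    column_mapping=ColumnMapping(target="target", task="classification"),
)
report.save_html("data_drift_with_feature_importance.html")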
152 changes: 94 additions & 58 deletions src/evidently/metrics/data_drift/data_drift_table.py
@@ -10,6 +10,7 @@
from evidently.metric_results import DatasetColumns
from evidently.metric_results import HistogramData
from evidently.metrics.data_drift.base import WithDriftOptions
from evidently.metrics.data_drift.feature_importance import FeatureImportanceMetric
from evidently.model.widget import BaseWidgetInfo
from evidently.options.base import AnyOptions
from evidently.options.data_drift import DataDriftOptions
@@ -39,10 +40,14 @@ class Config:
dataset_drift: bool
drift_by_columns: Dict[str, ColumnDataDriftMetrics]
dataset_columns: DatasetColumns
current_fi: Optional[Dict[str, float]] = None
reference_fi: Optional[Dict[str, float]] = None


class DataDriftTable(WithDriftOptions[DataDriftTableResults]):
columns: Optional[List[str]]
feature_importance: Optional[bool]
_feature_importance_metric: Optional[FeatureImportanceMetric]

def __init__(
self,
@@ -58,6 +63,7 @@ def __init__(
text_stattest_threshold: Optional[float] = None,
per_column_stattest_threshold: Optional[Dict[str, float]] = None,
options: AnyOptions = None,
feature_importance: Optional[bool] = False,
):
self.columns = columns
super().__init__(
@@ -72,6 +78,7 @@ def __init__(
text_stattest_threshold=text_stattest_threshold,
per_column_stattest_threshold=per_column_stattest_threshold,
options=options,
feature_importance=feature_importance,
)
self._drift_options = DataDriftOptions(
all_features_stattest=stattest,
@@ -85,6 +92,10 @@ def __init__(
text_features_threshold=text_stattest_threshold,
per_feature_threshold=per_column_stattest_threshold,
)
if feature_importance:
self._feature_importance_metric = FeatureImportanceMetric()
else:
self._feature_importance_metric = None

def get_parameters(self) -> tuple:
return None if self.columns is None else tuple(self.columns), self.drift_options
@@ -107,20 +118,35 @@ def calculate(self, data: InputData) -> DataDriftTableResults:
columns=self.columns,
agg_data=agg_data,
)
current_fi: Optional[Dict[str, float]] = None
reference_fi: Optional[Dict[str, float]] = None

if self._feature_importance_metric is not None:
res = self._feature_importance_metric.get_result()
current_fi = res.current
reference_fi = res.reference

return DataDriftTableResults(
number_of_columns=result.number_of_columns,
number_of_drifted_columns=result.number_of_drifted_columns,
share_of_drifted_columns=result.share_of_drifted_columns,
dataset_drift=result.dataset_drift,
drift_by_columns=result.drift_by_columns,
dataset_columns=result.dataset_columns,
current_fi=current_fi,
reference_fi=reference_fi,
)


@default_renderer(wrap_type=DataDriftTable)
class DataDriftTableRenderer(MetricRenderer):
def _generate_column_params(
self, column_name: str, data: ColumnDataDriftMetrics, agg_data: bool
self,
column_name: str,
data: ColumnDataDriftMetrics,
agg_data: bool,
current_fi: Optional[Dict[str, float]] = None,
reference_fi: Optional[Dict[str, float]] = None,
) -> Optional[RichTableDataRow]:
details = RowDetails()
if data.column_type == "text":
@@ -157,18 +183,13 @@ def _generate_column_params(

data_drift = "Detected" if data.drift_detected else "Not Detected"

return RichTableDataRow(
details=details,
fields={
"column_name": column_name,
"column_type": data.column_type,
"stattest_name": data.stattest_name,
# "reference_distribution": {},
# "current_distribution": {},
"data_drift": data_drift,
"drift_score": round(data.drift_score, 6),
},
)
fields = {
"column_name": column_name,
"column_type": data.column_type,
"stattest_name": data.stattest_name,
"data_drift": data_drift,
"drift_score": round(data.drift_score, 6),
}

else:
if (
@@ -221,24 +242,26 @@ def _generate_column_params(
)
distribution = plotly_figure(title="", figure=fig)
details.with_part("DATA DISTRIBUTION", info=distribution)
return RichTableDataRow(
details=details,
fields={
"column_name": column_name,
"column_type": data.column_type,
"stattest_name": data.stattest_name,
"reference_distribution": {
"x": list(ref_small_hist.x),
"y": list(ref_small_hist.y),
},
"current_distribution": {
"x": list(current_small_hist.x),
"y": list(current_small_hist.y),
},
"data_drift": data_drift,
"drift_score": round(data.drift_score, 6),
fields = {
"column_name": column_name,
"column_type": data.column_type,
"stattest_name": data.stattest_name,
"reference_distribution": {
"x": list(ref_small_hist.x),
"y": list(ref_small_hist.y),
},
)
"current_distribution": {
"x": list(current_small_hist.x),
"y": list(current_small_hist.y),
},
"data_drift": data_drift,
"drift_score": round(data.drift_score, 6),
}
if current_fi is not None:
fields["current_feature_importance"] = current_fi.get(column_name, "")
if reference_fi is not None:
fields["reference_feature_importance"] = reference_fi.get(column_name, "")
return RichTableDataRow(details=details, fields=fields)

def render_html(self, obj: DataDriftTable) -> List[BaseWidgetInfo]:
results = obj.get_result()
@@ -268,45 +291,58 @@ def render_html(self, obj: DataDriftTable) -> List[BaseWidgetInfo]:
columns = columns + all_columns

for column_name in columns:
column_params = self._generate_column_params(column_name, results.drift_by_columns[column_name], agg_data)
column_params = self._generate_column_params(
column_name,
results.drift_by_columns[column_name],
agg_data,
results.current_fi,
results.reference_fi,
)

if column_params is not None:
params_data.append(column_params)

drift_percents = round(results.share_of_drifted_columns * 100, 3)
table_columns = [
ColumnDefinition("Column", "column_name"),
ColumnDefinition("Type", "column_type"),
]
if results.current_fi is not None:
table_columns.append(ColumnDefinition("Current feature importance", "current_feature_importance"))
if results.reference_fi is not None:
table_columns.append(ColumnDefinition("Reference feature importance", "reference_feature_importance"))
table_columns = table_columns + [
ColumnDefinition(
"Reference Distribution",
"reference_distribution",
ColumnType.HISTOGRAM,
options={
"xField": "x",
"yField": "y",
"color": color_options.primary_color,
},
),
ColumnDefinition(
"Current Distribution",
"current_distribution",
ColumnType.HISTOGRAM,
options={
"xField": "x",
"yField": "y",
"color": color_options.primary_color,
},
),
ColumnDefinition("Data Drift", "data_drift"),
ColumnDefinition("Stat Test", "stattest_name"),
ColumnDefinition("Drift Score", "drift_score"),
]

return [
header_text(label="Data Drift Summary"),
rich_table_data(
title=f"Drift is detected for {drift_percents}% of columns "
f"({results.number_of_drifted_columns} out of {results.number_of_columns}).",
columns=[
ColumnDefinition("Column", "column_name"),
ColumnDefinition("Type", "column_type"),
ColumnDefinition(
"Reference Distribution",
"reference_distribution",
ColumnType.HISTOGRAM,
options={
"xField": "x",
"yField": "y",
"color": color_options.primary_color,
},
),
ColumnDefinition(
"Current Distribution",
"current_distribution",
ColumnType.HISTOGRAM,
options={
"xField": "x",
"yField": "y",
"color": color_options.primary_color,
},
),
ColumnDefinition("Data Drift", "data_drift"),
ColumnDefinition("Stat Test", "stattest_name"),
ColumnDefinition("Drift Score", "drift_score"),
],
columns=table_columns,
data=params_data,
),
]
83 changes: 83 additions & 0 deletions src/evidently/metrics/data_drift/feature_importance.py
@@ -0,0 +1,83 @@
from typing import Dict
from typing import List
from typing import Optional

import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.ensemble import RandomForestRegressor
from sklearn.preprocessing import OrdinalEncoder

from evidently.base_metric import InputData
from evidently.base_metric import Metric
from evidently.base_metric import MetricResult
from evidently.core import ColumnType
from evidently.model.widget import BaseWidgetInfo
from evidently.renderers.base_renderer import MetricRenderer
from evidently.renderers.base_renderer import default_renderer
from evidently.utils.data_preprocessing import DataDefinition

SAMPLE_SIZE = 5000


class FeatureImportanceMetricResult(MetricResult):
current: Optional[Dict[str, float]] = None
reference: Optional[Dict[str, float]] = None


class FeatureImportanceMetric(Metric[FeatureImportanceMetricResult]):
def calculate(self, data: InputData) -> FeatureImportanceMetricResult:
if data.additional_datasets.get("current_feature_importance") is not None:
return FeatureImportanceMetricResult(
current=data.additional_datasets.get("current_feature_importance"),
reference=data.additional_datasets.get("reference_feature_importance"),
)

curr_sampled_data = data.current_data.sample(min(SAMPLE_SIZE, data.current_data.shape[0]), random_state=0)
ref_sampled_data: Optional[pd.DataFrame] = None
if data.reference_data is not None:
ref_sampled_data = data.reference_data.sample(
min(SAMPLE_SIZE, data.reference_data.shape[0]), random_state=0
)

return get_feature_importance_from_samples(data.data_definition, curr_sampled_data, ref_sampled_data)


def get_feature_importance_from_samples(
data_definition: DataDefinition, curr_sampled_data: pd.DataFrame, ref_sampled_data: Optional[pd.DataFrame]
):
num_cols = data_definition.get_columns(filter_def=ColumnType.Numerical, features_only=True)
cat_cols = data_definition.get_columns(filter_def=ColumnType.Categorical, features_only=True)

columns = [x.column_name for x in num_cols] + [x.column_name for x in cat_cols]

for col in [x.column_name for x in cat_cols]:
enc = OrdinalEncoder(handle_unknown="use_encoded_value", unknown_value=np.nan)
curr_sampled_data[col] = enc.fit_transform(curr_sampled_data[col].values.reshape(-1, 1))
if ref_sampled_data is not None:
ref_sampled_data[col] = enc.fit_transform(ref_sampled_data[col].values.reshape(-1, 1))

task = data_definition.task
target_column = data_definition.get_target_column()
if target_column is None:
return FeatureImportanceMetricResult(current=None, reference=None)
target_name = target_column.column_name
if task == "regression":
model = RandomForestRegressor(min_samples_leaf=10)
else:
model = RandomForestClassifier(min_samples_leaf=10)

model.fit(curr_sampled_data[columns], curr_sampled_data[target_name])
current_fi = {x: np.round(y, 3) for x, y in zip(columns, model.feature_importances_)}

reference_fi: Optional[Dict[str, float]] = None
if ref_sampled_data is not None:
model.fit(ref_sampled_data[columns], ref_sampled_data[target_name])
reference_fi = {x: np.round(y, 3) for x, y in zip(columns, model.feature_importances_)}
return FeatureImportanceMetricResult(current=current_fi, reference=reference_fi)


@default_renderer(wrap_type=FeatureImportanceMetric)
class FeatureImportanceRenderer(MetricRenderer):
def render_html(self, obj: FeatureImportanceMetric) -> List[BaseWidgetInfo]:
return []
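For intuition, a standalone sketch of the same approach outside Evidently: sample the frame, ordinal-encode categorical features, fit a random forest against the target, and read feature_importances_. The toy frame and column names below are made up for illustration; this is not the library's API.

import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import OrdinalEncoder

SAMPLE_SIZE = 5000

# Toy frame: two numerical features, one categorical feature, and a binary target.
rng = np.random.default_rng(0)
df = pd.DataFrame(
    {
        "num_a": rng.random(10_000),
        "num_b": rng.random(10_000),
        "cat_c": rng.choice(["x", "y", "z"], size=10_000),
        "target": rng.integers(0, 2, size=10_000),
    }
)

# Sample, as the metric does, to keep the forest fit cheap.
sampled = df.sample(min(SAMPLE_SIZE, len(df)), random_state=0)

# Ordinal-encode the categorical column so the forest can consume it.
enc = OrdinalEncoder(handle_unknown="use_encoded_value", unknown_value=np.nan)
sampled["cat_c"] = enc.fit_transform(sampled[["cat_c"]]).ravel()

features = ["num_a", "num_b", "cat_c"]
model = RandomForestClassifier(min_samples_leaf=10)
model.fit(sampled[features], sampled["target"])

importances = {name: round(float(score), 3) for name, score in zip(features, model.feature_importances_)}
print(importances)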
3 changes: 2 additions & 1 deletion src/evidently/spark/__init__.py
@@ -1,4 +1,5 @@
from .engine import SparkEngine
from .metrics import data_drift
from .metrics import feature_importance

__all__ = ["SparkEngine", "data_drift"]
__all__ = ["SparkEngine", "data_drift", "feature_importance"]
38 changes: 38 additions & 0 deletions src/evidently/spark/metrics/feature_importance.py
@@ -0,0 +1,38 @@
from typing import Optional

import pandas as pd

from evidently.calculation_engine.engine import metric_implementation
from evidently.metrics.data_drift.feature_importance import SAMPLE_SIZE
from evidently.metrics.data_drift.feature_importance import FeatureImportanceMetric
from evidently.metrics.data_drift.feature_importance import FeatureImportanceMetricResult
from evidently.metrics.data_drift.feature_importance import get_feature_importance_from_samples
from evidently.spark.engine import SparkInputData
from evidently.spark.engine import SparkMetricImplementation


@metric_implementation(FeatureImportanceMetric)
class SparkFeatureImportanceMetric(SparkMetricImplementation[FeatureImportanceMetric]):
def calculate(self, context, data: SparkInputData) -> FeatureImportanceMetricResult:
if data.additional_datasets.get("current_feature_importance") is not None:
return FeatureImportanceMetricResult(
current=data.additional_datasets.get("current_feature_importance"),
reference=data.additional_datasets.get("reference_feature_importance"),
)

cur_count = data.current_data.count()
curr_sampled_data: pd.DataFrame = (
data.current_data.toPandas()
if cur_count < SAMPLE_SIZE
else data.current_data.sample(SAMPLE_SIZE / cur_count, seed=0).toPandas()
)
ref_sampled_data: Optional[pd.DataFrame] = None
if data.reference_data is not None:
ref_count = data.reference_data.count()
ref_sampled_data = (
data.reference_data.toPandas()
if ref_count < SAMPLE_SIZE
else data.reference_data.sample(SAMPLE_SIZE / ref_count, seed=0).toPandas()
)

return get_feature_importance_from_samples(data.data_definition, curr_sampled_data, ref_sampled_data)
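A sketch of how the Spark path might be exercised end to end, assuming Report.run accepts SparkEngine via its engine parameter; the parquet paths, target column, and task value are illustrative assumptions.

from pyspark.sql import SparkSession

from evidently import ColumnMapping
from evidently.metrics import DataDriftTable
from evidently.report import Report
from evidently.spark import SparkEngine

spark = SparkSession.builder.getOrCreate()

# Illustrative sources: any two Spark DataFrames with a shared schema and a target column.
reference = spark.read.parquet("reference.parquet")
current = spark.read.parquet("current.parquet")

report = Report(metrics=[DataDriftTable(feature_importance=True)])
report.run(
    reference_data=reference,
    current_data=current,
    column_mapping=ColumnMapping(target="target", task="classification"),
    engine=SparkEngine,
)
report.save_html("spark_data_drift_with_feature_importance.html")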
