Skip to content

Commit

Permalink
add styler to evaluate
Browse files Browse the repository at this point in the history
  • Loading branch information
tvdboom committed Feb 2, 2024
1 parent 070983c commit 8b8d9e5
Show file tree
Hide file tree
Showing 13 changed files with 125 additions and 2,599 deletions.
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ or via `conda`:
⚡ Usage
-------

[![SageMaker Studio Lab](https://studiolab.sagemaker.aws/studiolab.svg)](https://studiolab.sagemaker.aws/import/github/tvdboom/ATOM/blob/master/examples/getting_started.ipynb)
[![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1H8pL-iAICeaKqWQxWsb6fN9zPNZK722s#scrollTo=LrtjgDQFvU2z&forceEdit=true&sandboxMode=true)
[![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/tvdboom/ATOM/HEAD)

ATOM contains a variety of classes and functions to perform data cleaning,
Expand Down Expand Up @@ -186,9 +186,9 @@ atom.run(models=["LDA", "AdaB"], metric="auc", n_trials=10)
And lastly, analyze the results.

```python
atom.evaluate()
print(atom.evaluate())

atom.plot_lift()
atom.plot_roc()
```

<br><br>
Expand Down
6 changes: 3 additions & 3 deletions atom/basemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,11 +322,11 @@ def _est_class(self) -> type[Predictor]:

# Try engine, else import from the default module
try:
module = import_module(f"{self.engine.estimator}.{module.split('.', 1)[1:]}")
mod = import_module(f"{self.engine.estimator}.{module.split('.', 1)[1]}")
except (ModuleNotFoundError, AttributeError):
module = import_module(module)
mod = import_module(module)

return getattr(module, est_name)
return self._wrap_class(getattr(mod, est_name))

@property
def _shap(self) -> ShapExplanation:
Expand Down
19 changes: 16 additions & 3 deletions atom/baserunner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1022,7 +1022,8 @@ def evaluate(
rows: RowSelector = "test",
*,
threshold: FloatZeroToOneExc | Sequence[FloatZeroToOneExc] = 0.5,
) -> pd.DataFrame:
as_frame: Bool = False,
) -> Any | pd.DataFrame:
"""Get all models' scores for the provided metrics.
Parameters
Expand All @@ -1048,15 +1049,27 @@ def evaluate(
The same threshold per target column is applied to all
models.
as_frame: bool, default=False
Whether to return the scores as a pd.DataFrame. If False, a
`pandas.io.formats.style.Styler` object is returned, which
has a `_repr_html_` method defined, so it is rendered
automatically in a notebook. The highest score per metric
is highlighted. If there is only one model, a pd.DataFrame
is always returned, regardless of this parameter.
Returns
-------
pd.DataFrame
[Styler][] or pd.DataFrame
Scores of the models.
"""
check_is_fitted(self)

return pd.DataFrame([m.evaluate(metric, rows, threshold=threshold) for m in self._models])
df = pd.DataFrame([m.evaluate(metric, rows, threshold=threshold) for m in self._models])

if len(self._models) == 1 or as_frame:
return df
else:
return df.style.highlight_max(props="background-color: lightgreen")

@composed(crash, beartype)
def export_pipeline(self, model: str | Model | None = None) -> Pipeline:
Expand Down
9 changes: 5 additions & 4 deletions atom/models/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"""

from functools import cached_property
from typing import Any

from atom.basemodel import BaseModel
Expand All @@ -17,10 +18,10 @@ class CustomModel(BaseModel):
def __init__(self, **kwargs):
# Assign the estimator and store the provided parameters
if callable(est := kwargs.pop("estimator")):
self._est = self._wrap_class(est)
self._est = est
self._params = {}
else:
self._est = self._wrap_class(est.__class__)
self._est = est.__class__
self._params = est.get_params()

if hasattr(est, "name"):
Expand Down Expand Up @@ -55,10 +56,10 @@ def fullname(self) -> str:
"""Return the estimator's class name."""
return self._est_class.__name__

@property
@cached_property
def _est_class(self) -> type[Predictor]:
"""Return the estimator's class."""
return self._est
return self._wrap_class(self._est)

def _get_est(self, params: dict[str, Any]) -> Predictor:
"""Get the model's estimator with unpacked parameters.
Expand Down
34 changes: 18 additions & 16 deletions atom/plots/dataplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def plot_ccf(
plot_interval: Bool = False,
title: str | dict[str, Any] | None = None,
legend: Legend | dict[str, Any] | None = "upper right",
figsize: tuple[IntLargerZero, IntLargerZero] = (900, 600),
figsize: tuple[IntLargerZero, IntLargerZero] | None = None,
filename: str | Path | None = None,
display: Bool | None = True,
) -> go.Figure | None:
Expand Down Expand Up @@ -1010,20 +1010,21 @@ def plot_fft(
) -> go.Figure | None:
"""Plot the fourier transformation of a time series.
A Fast Fourier Transformper (FFT) plot visualizes the frequency
A Fast Fourier Transform (FFT) plot visualizes the frequency
domain representation of a signal by transforming it from the
time domain to the frequency domain using the FFT algorithm.
The x-axis shows the frequencies, normalized to the
[Nyquist frequency][], and the y-axis shows the power spectral
density or squared amplitude per frequency unit on a logarithmic
scale. This plot is only available for [forecast][time-series]
tasks.
[Nyquist frequency][nyquist], and the y-axis shows the power
spectral density or squared amplitude per frequency unit on a
logarithmic scale. This plot is only available for
[forecast][time-series] tasks.
!!! tip
- If the plot peaks at f~0, it can indicate the wandering
behavior characteristic of a [random walk][] that needs
to be differentiated. It could also be indicative of a
stationary [ARMA][] process with a high positive phi value.
behavior characteristic of a [random walk][random_walk]
that needs to be differenced. It could also be indicative
of a stationary [ARMA][] process with a high positive phi
value.
- Peaking at a frequency and its multiples is indicative of
seasonality. The lowest frequency in this case is called
the fundamental frequency, and the inverse of this
Expand Down Expand Up @@ -1631,16 +1632,17 @@ def plot_periodogram(
series analysis for identifying dominant frequencies, periodic
patterns, and overall spectral characteristics of the data.
The x-axis shows the frequencies, normalized to the
[Nyquist frequency][], and the y-axis shows the power spectral
density or squared amplitude per frequency unit on a logarithmic
scale. This plot is only available for [forecast][time-series]
tasks.
[Nyquist frequency][nyquist], and the y-axis shows the power
spectral density or squared amplitude per frequency unit on a
logarithmic scale. This plot is only available for
[forecast][time-series] tasks.
!!! tip
- If the plot peaks at f~0, it can indicate the wandering
behavior characteristic of a [random walk][] that needs
to be differentiated. It could also be indicative of a
stationary [ARMA][] process with a high positive phi value.
behavior characteristic of a [random walk][random_walk]
that needs to be differenced. It could also be indicative
of a stationary [ARMA][] process with a high positive phi
value.
- Peaking at a frequency and its multiples is indicative of
seasonality. The lowest frequency in this case is called
the fundamental frequency, and the inverse of this
Expand Down
7 changes: 3 additions & 4 deletions atom/plots/predictionplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -738,10 +738,9 @@ def plot_errors(
Plot the actual targets from a set against the predicted values
generated by the regressor. A linear fit is made on the data.
The gray, intersected line shows the identity line. This plot
can be useful to detect noise or heteroscedasticity along a
range of the target domain. This plot is unavailable for
classification tasks.
This plot can be useful to detect noise or heteroscedasticity
along a range of the target domain. This plot is unavailable
for classification tasks.
Parameters
----------
Expand Down
2 changes: 2 additions & 0 deletions docs_sources/changelog/v6.x.x.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
* Transformations only on `y` are now accepted, e.g., `atom.scale(columns=-1)`.
* The [Imputer][] class has many more strategies for numerical columns designed
for time series.
* The [evaluate][atomclassifier-evaluate] method highlights the highest score
per metric.
* Full support for [pandas nullable dtypes](https://pandas.pydata.org/docs/user_guide/integer_na.html).
* The dataset can now be provided as callable.
* The [FeatureExtractor][] class can extract features from the dataset's index.
Expand Down
4 changes: 2 additions & 2 deletions docs_sources/getting_started.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ package files for all versions published on PyPI.

## Usage

[![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tvdboom/ATOM/blob/master/examples/getting_started.ipynb)
[![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1H8pL-iAICeaKqWQxWsb6fN9zPNZK722s#scrollTo=LrtjgDQFvU2z&forceEdit=true&sandboxMode=true)
[![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/tvdboom/ATOM/HEAD)

ATOM contains a variety of classes and functions to perform data cleaning,
Expand Down Expand Up @@ -152,5 +152,5 @@ atom.run(models=["LR", "LDA"], metric="auc", n_trials=6) # hide

print(atom.evaluate())

atom.plot_lift()
atom.plot_roc()
```
3 changes: 2 additions & 1 deletion docs_sources/scripts/autodocs.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@
registry="https://www.mlflow.org/docs/latest/model-registry.html",
ray="https://docs.ray.io/en/latest/cluster/getting-started.html",
# BaseRunner
styler="https://pandas.pydata.org/docs/reference/api/pandas.io.formats.style.Styler.html",
stackingclassifier="https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.StackingClassifier.html",
stackingregressor="https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.StackingRegressor.html",
stackingforecaster="https://www.sktime.net/en/latest/api_reference/auto_generated/sktime.forecasting.compose.StackingForecaster.html",
Expand Down Expand Up @@ -254,7 +255,7 @@
update_traces="https://plotly.com/python-api-reference/generated/plotly.graph_objects.Figure.html#plotly.graph_objects.Figure.update_traces",
fanova="https://optuna.readthedocs.io/en/stable/reference/generated/optuna.importance.FanovaImportanceEvaluator.html",
kde="https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.gaussian_kde.html",
nyquist_frequency="https://en.wikipedia.org/wiki/Nyquist_frequency",
nyquist="https://en.wikipedia.org/wiki/Nyquist_frequency",
random_walk="https://en.wikipedia.org/wiki/Random_walk",
arma="https://en.wikipedia.org/wiki/Autoregressive_moving-average_model",
wordcloud="https://amueller.github.io/word_cloud/generated/wordcloud.WordCloud.html",
Expand Down
2 changes: 1 addition & 1 deletion docs_sources/user_guide/accelerating.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ regardless of the engine parameter.
one to use, the first one is used by default.

!!! example
[![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tvdboom/ATOM/blob/master/examples/accelerating_cuml.ipynb)
[![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1gbTMqTt5sDuP3kBLy1-_U6Z2uZaSm43O?authuser=0#scrollTo=FEB9_7R7Wq4h&forceEdit=true&sandboxMode=true)

Train a model on a GPU yourself using Google Colab. Just click on the
badge above and run the notebook! Make sure to choose the GPU runtime
Expand Down
2,619 changes: 59 additions & 2,560 deletions examples/multivariate_forecast.ipynb

Large diffs are not rendered by default.

11 changes: 9 additions & 2 deletions tests/test_baserunner.py
Original file line number Diff line number Diff line change
Expand Up @@ -864,8 +864,15 @@ def test_evaluate(metric):
"""Assert that the evaluate method works."""
atom = ATOMClassifier(X_bin, y_bin, random_state=1)
pytest.raises(NotFittedError, atom.evaluate)
atom.run(["Tree", "SVM"])
assert isinstance(atom.evaluate(metric), pd.DataFrame)
atom.run("Tree")
assert isinstance(atom.evaluate(metric, as_frame=True), pd.DataFrame)


def test_evaluate_returns_styler():
"""Assert that the evaluate method returns a pandas styler."""
atom = ATOMClassifier(X_bin, y_bin, random_state=1)
atom.run(["Tree", "LR"])
assert isinstance(atom.evaluate(), pd.io.formats.style.Styler)


def test_export_pipeline_same_transformer():
Expand Down
2 changes: 2 additions & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ def test_models_sklearnex_classification():
n_trials=2,
est_params={"LR": {"max_iter": 5}, "RF": {"n_estimators": 5}},
)
assert all(m.estimator.__module__.startswith(("daal4py", "sklearnex")) for m in atom._models)


@pytest.mark.skipif(machine() not in ("x86_64", "AMD64"), reason="Only x86 support.")
Expand All @@ -156,6 +157,7 @@ def test_models_sklearnex_regression():
n_trials=2,
est_params={"RF": {"n_estimators": 5}},
)
assert all(m.estimator.__module__.startswith(("daal4py", "sklearnex")) for m in atom._models)


@patch.dict(
Expand Down

0 comments on commit 8b8d9e5

Please sign in to comment.