Commit 1813e96
reduce test time
tvdboom committed Dec 3, 2023
1 parent 1aa41b7 commit 1813e96
Showing 10 changed files with 102 additions and 70 deletions.
7 changes: 3 additions & 4 deletions atom/data_cleaning.py
@@ -38,12 +38,11 @@
from scipy.stats import zscore
from sklearn.base import BaseEstimator, _clone_parametrized

[Qodana notice, atom/data_cleaning.py:39: access to a protected member _clone_parametrized of a class]
from sklearn.compose import ColumnTransformer
-from sklearn.experimental import enable_iterative_imputer
+from sklearn.experimental import enable_iterative_imputer  # noqa: F401
from sklearn.impute import IterativeImputer, KNNImputer
from typing_extensions import Self

from atom.basetransformer import BaseTransformer
from atom.pipeline import Pipeline
from atom.utils.constants import CAT_TYPES, DEFAULT_MISSING
from atom.utils.types import (
Bins, Bool, CategoricalStrats, DataFrame, DiscretizerStrats, Engine,
@@ -53,8 +52,8 @@
dataframe_t, sequence_t, series_t,
)
from atom.utils.utils import (
-    bk, check_is_fitted, composed, crash, get_col_order, get_cols, it, lst,
-    merge, method_to_log, n_cols, replace_missing, sign, to_df, to_series,
+    bk, composed, crash, get_col_order, get_cols, it, lst, merge,
+    method_to_log, n_cols, replace_missing, sign, to_df, to_series,
variable_return, wrap_methods,
)

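Context for the noqa change above: scikit-learn gates IterativeImputer behind an experimental flag, so the enable_iterative_imputer import has to run before the import from sklearn.impute even though linters flag it as unused (F401). The inline noqa scopes the suppression to that one line, which is why the file-wide ignore is dropped from pyproject.toml further down. A minimal sketch of the pattern:

# The flag import must come first: importing it registers IterativeImputer.
# Linters report it as unused, hence the inline noqa.
from sklearn.experimental import enable_iterative_imputer  # noqa: F401
from sklearn.impute import IterativeImputer

imputer = IterativeImputer(max_iter=10, random_state=1)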
15 changes: 5 additions & 10 deletions atom/plots/predictionplot.py
@@ -1699,21 +1699,16 @@ class is always the positive one.
for ds in ("train", "test"):
# Calculating shap values is computationally expensive,
# therefore, select a random subsample for large data sets
-            if len(data := getattr(m, ds)) > 500:
+            if len(data := getattr(m, f"X_{ds}")) > 500:
data = data.sample(500, random_state=self.random_state)

-            # Replace data with the calculated shap values
-            explanation = m._shap.get_explanation(data[m.branch.features], target_c)
-            data[m.branch.features] = explanation.values
+            explanation = m._shap.get_explanation(data, target_c)

[Qodana notice, atom/plots/predictionplot.py:1705: access to a protected member _shap of a class]
+            shap = bk.DataFrame(explanation.values, columns=m.branch.features)

parshap[ds] = pd.Series(index=fxs, dtype=float)
for fx in fxs:
-                # All other features are covariates
-                covariates = [f for f in data.columns[:-1] if f != fx]
-                cols = [fx, data.columns[-1], *covariates]
-
-                # Compute covariance
-                V = data[cols].cov()
+                # Compute covariance (other variables are covariates)
+                V = shap[[c for c in shap if c != fx]].cov()

[Qodana notice, atom/plots/predictionplot.py:1711: PEP 8 naming convention violation, variable in function should be lowercase]
# Inverse covariance matrix
Vi = np.linalg.pinv(V, hermitian=True)

[Qodana notice, atom/plots/predictionplot.py:1714: PEP 8 naming convention violation, variable in function should be lowercase]
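The Vi computed above is the precision (inverse covariance) matrix, from which partial correlations follow as rho_ij = -Vi_ij / sqrt(Vi_ii * Vi_jj). A minimal sketch of that identity, where partial_corr is a hypothetical helper for illustration, not ATOM's actual code:

import numpy as np
import pandas as pd

def partial_corr(df: pd.DataFrame) -> pd.DataFrame:
    """Partial correlation of every column pair, controlling for all others."""
    V = df.cov()
    Vi = np.linalg.pinv(V, hermitian=True)  # precision matrix via pseudo-inverse
    d = np.sqrt(np.diag(Vi))
    pcorr = -Vi / np.outer(d, d)  # rho_ij = -Vi_ij / sqrt(Vi_ii * Vi_jj)
    np.fill_diagonal(pcorr, 1.0)
    return pd.DataFrame(pcorr, index=df.columns, columns=df.columns)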
2 changes: 1 addition & 1 deletion atom/utils/utils.py
@@ -1008,7 +1008,7 @@ def get_explanation(
)

# Remember shap values in the _shap_values attribute
-        self._shap_values = pd.concat(
+        self._shap_values = bk.concat(
[
self._shap_values,
bk.Series(list(self._explanation.values), index=calculate.index),
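The pd.concat to bk.concat switch presumably routes the operation through ATOM's dataframe-backend alias instead of hard-coding pandas. A hypothetical sketch of that alias pattern, with names assumed for illustration:

# With a backend alias, the same code runs on pandas or a drop-in
# replacement such as modin without further changes.
import pandas as bk  # or: import modin.pandas as bk

s1 = bk.Series([1, 2])
s2 = bk.Series([3, 4])
combined = bk.concat([s1, s2], ignore_index=True)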
1 change: 0 additions & 1 deletion pyproject.toml
@@ -118,7 +118,6 @@ ignore = [
]
per-file-ignores = [
"__init__.py: F401", # Imported but unused
"data_cleaning.py: F401", # Imported but unused (import experimental)
]

[tool.isort]
8 changes: 8 additions & 0 deletions tests/conftest.py
@@ -11,6 +11,7 @@

from pathlib import Path
from typing import Any
+from unittest.mock import patch

import numpy as np
import pandas as pd
@@ -89,6 +90,13 @@ def change_current_dir(tmp_path: Path, monkeypatch: MonkeyPatch):
monkeypatch.chdir(tmp_path)


+@pytest.fixture(autouse=True)
+def mock_mlflow_log_model():
+    """Mock mlflow's log_model function."""
+    with patch("mlflow.sklearn.log_model"):
+        yield


def get_train_test(
X: XSelector | None,
y: Sequence[Any] | DataFrame,
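Because the fixture is autouse, every test in the suite runs with mlflow.sklearn.log_model patched out, skipping the expensive model serialization; that is one of the main levers behind the commit's title. A sketch of a variant that also exposes the mock, should a test want to assert on it (hypothetical, not part of this commit):

from unittest.mock import patch

import pytest

@pytest.fixture(autouse=True)
def mock_mlflow_log_model():
    """Run every test with mlflow.sklearn.log_model patched out."""
    with patch("mlflow.sklearn.log_model") as mocked:
        yield mocked  # tests may assert on mocked.call_count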
4 changes: 2 additions & 2 deletions tests/test_basemodel.py
@@ -291,8 +291,8 @@ def test_nested_runs_to_mlflow(mlflow):
"""Assert that the trials are logged to mlflow as nested runs."""
atom = ATOMClassifier(X_bin, y_bin, experiment="test", random_state=1)
atom.log_ht = True
atom.run("Tree", n_trials=3)
assert mlflow.call_count == 4 # n_trials + fit
atom.run("Tree", n_trials=1, errors='raise')
assert mlflow.call_count == 2 # n_trials + fit


@patch("mlflow.log_params")
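The expected call_count follows from the comment in the test: one logging call per hyperparameter-tuning trial plus one for the final fit, so 1 + 1 = 2 after the change (previously 3 + 1 = 4). A toy illustration of that arithmetic with a plain Mock:

from unittest.mock import Mock

log = Mock()
for _ in range(1):  # one tuning trial logs once
    log()
log()  # the final fit logs once more
assert log.call_count == 2  # n_trials + fit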
35 changes: 14 additions & 21 deletions tests/test_basetrainer.py
@@ -71,16 +71,16 @@ def test_invalid_model_name():

def test_multiple_models_with_add():
"""Assert that you can add model names to select them."""
trainer = DirectClassifier("gnb+lr+lr_2", random_state=1)
trainer = DirectClassifier("Dummy+tree+tree_2", random_state=1)
trainer.run(bin_train, bin_test)
assert trainer.models == ["GNB", "LR", "LR_2"]
assert trainer.models == ["Dummy", "Tree", "Tree_2"]


def test_multiple_same_models():
"""Assert that the same model can used with different names."""
trainer = DirectClassifier(["lr", "lr_2", "lr_3"], random_state=1)
trainer = DirectClassifier(["Tree", "Tree_2", "Tree_3"], random_state=1)
trainer.run(bin_train, bin_test)
assert trainer.models == ["LR", "LR_2", "LR_3"]
assert trainer.models == ["Tree", "Tree_2", "Tree_3"]


def test_only_task_models():
@@ -378,39 +378,32 @@ def test_errors_keep():
assert trainer._models == [trainer.lda]


-def test_parallel_with_ray():
+@patch("atom.basetransformer.ray")
+@patch("atom.basetrainer.ray")
+def test_parallel_with_ray(_, __):
"""Assert that parallel runs successfully with ray backend."""
trainer = DirectClassifier(
models=["LR", "LDA"],
parallel=True,
n_jobs=2,
-        n_jobs=2,
+        n_jobs=1,
backend="ray",
random_state=1,
)
-    trainer.run(bin_train, bin_test)
-    assert trainer._models == [trainer.lr, trainer.lda]
-    ray.shutdown()
+    # Fails because Mock returns empty list
+    with pytest.raises(RuntimeError, match=".*All models failed.*"):
+        trainer.run(bin_train, bin_test)


-def test_parallel():
+@patch("atom.basetrainer.Parallel")
+def test_parallel(_):
"""Assert that parallel runs successfully."""
trainer = DirectClassifier(
models=["LR", "LDA"],
parallel=True,
n_jobs=2,
random_state=1,
)
-    trainer.run(bin_train, bin_test)
-    assert trainer._models == [trainer.lr, trainer.lda]
-
-
-def test_all_models_failed():
-    """Assert that an error is raised when all models failed."""
-    trainer = DirectClassifier(
-        models=["LR", "RF"],
-        n_trials=1,
-        ht_params={"distributions": "test"},
-        random_state=1,
-    )
+    # Fails because Mock returns empty list
    with pytest.raises(RuntimeError, match=".*All models failed.*"):
        trainer.run(bin_train, bin_test)
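Both tests now patch out the parallel backends (ray and joblib's Parallel), so no real workers start; the mocked backend yields an empty model list, which is why the tests assert the "All models failed" error instead of a successful run. Note that stacked @patch decorators inject their mocks bottom-up, which is what the placeholder arguments (_, __) receive. A small self-contained sketch of that mechanic, using os functions only as stand-in patch targets:

import os
from unittest.mock import patch

@patch("os.path.exists")  # outer decorator, injected second
@patch("os.getcwd")       # decorator closest to the function, injected first
def test_patch_order(mock_getcwd, mock_exists):
    mock_getcwd.return_value = "/tmp"
    mock_exists.return_value = True
    assert os.getcwd() == "/tmp"
    assert os.path.exists("anything")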
14 changes: 7 additions & 7 deletions tests/test_models.py
@@ -8,7 +8,7 @@
"""

from platform import machine
-from unittest.mock import MagicMock, patch
+from unittest.mock import Mock, patch

import numpy as np
import pandas as pd
@@ -130,9 +130,9 @@ def test_models_sklearnex_regression():

@patch.dict(
"sys.modules", {
"cuml": MagicMock(spec=["__spec__"]),
"cuml.common.device_selection": MagicMock(spec=["set_global_device_type"]),
"cuml.internals.memory_utils": MagicMock(spec=["set_global_output_type"]),
"cuml": Mock(spec=["__spec__"]),
"cuml.common.device_selection": Mock(spec=["set_global_device_type"]),
"cuml.internals.memory_utils": Mock(spec=["set_global_output_type"]),
}
)
def test_models_cuml_classification():
@@ -159,9 +159,9 @@ def test_models_cuml_classification():

@patch.dict(
"sys.modules", {
"cuml": MagicMock(spec=["__spec__"]),
"cuml.common.device_selection": MagicMock(spec=["set_global_device_type"]),
"cuml.internals.memory_utils": MagicMock(spec=["set_global_output_type"]),
"cuml": Mock(spec=["__spec__"]),
"cuml.common.device_selection": Mock(spec=["set_global_device_type"]),
"cuml.internals.memory_utils": Mock(spec=["set_global_output_type"]),
}
)
def test_models_cuml_regression():
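Swapping MagicMock for Mock is a light touch: both honor spec, which whitelists the attributes the fake module exposes, but Mock skips the magic-method support these tests do not need. A quick sketch of the spec behavior:

from unittest.mock import Mock

fake = Mock(spec=["set_global_device_type"])
fake.set_global_device_type("gpu")  # fine: the name is in the spec
# fake.missing_attr  # would raise AttributeError: Mock has no attribute 'missing_attr'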
(remaining 2 changed files not shown)
