Skip to content

Commit 7399c42

Browse files
committed
fix: adjust features to make shap happy
1 parent 33a9cfa commit 7399c42

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

pydra_ml/tests/test_classifier.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def test_classifier(tmpdir):
1717
csv_file = os.path.join(os.path.dirname(__file__), "data", "breast_cancer.csv")
1818
inputs = {
1919
"filename": csv_file,
20-
"x_indices": range(30),
20+
"x_indices": range(10),
2121
"target_vars": ("target",),
2222
"group_var": None,
2323
"n_splits": 2,
@@ -29,7 +29,7 @@ def test_classifier(tmpdir):
2929
"permutation_importance_n_repeats": 5,
3030
"permutation_importance_scoring": "accuracy",
3131
"gen_shap": True,
32-
"nsamples": 5,
32+
"nsamples": 15,
3333
"l1_reg": "aic",
3434
"plot_top_n_shap": 16,
3535
"metrics": ["roc_auc_score", "accuracy_score"],
@@ -40,20 +40,20 @@ def test_classifier(tmpdir):
4040
assert results[0][0]["ml_wf.permute"]
4141
assert results[0][1].output.score[0][0] < results[1][1].output.score[0][0]
4242
assert hasattr(results[2][1].output.model, "predict")
43-
assert isinstance(results[2][1].output.model.predict(np.ones((1, 30))), np.ndarray)
43+
assert isinstance(results[2][1].output.model.predict(np.ones((1, 10))), np.ndarray)
4444

4545

4646
def test_regressor(tmpdir):
4747
clfs = [
4848
[
4949
["sklearn.impute", "SimpleImputer"],
5050
["sklearn.preprocessing", "StandardScaler"],
51-
["sklearn.neural_network", "MLPRegressor", {"alpha": 1, "max_iter": 1000}],
51+
["sklearn.neural_network", "MLPRegressor", {"alpha": 1, "max_iter": 100}],
5252
],
5353
(
5454
"sklearn.linear_model",
5555
"LinearRegression",
56-
{"fit_intercept": True, "normalize": True},
56+
{"fit_intercept": True},
5757
),
5858
]
5959
csv_file = os.path.join(os.path.dirname(__file__), "data", "diabetes_table.csv")
@@ -71,7 +71,7 @@ def test_regressor(tmpdir):
7171
"permutation_importance_n_repeats": 5,
7272
"permutation_importance_scoring": "accuracy",
7373
"gen_shap": True,
74-
"nsamples": 5,
74+
"nsamples": 15,
7575
"l1_reg": "aic",
7676
"plot_top_n_shap": 10,
7777
"metrics": ["explained_variance_score"],

0 commit comments

Comments
 (0)