Skip to content

Commit 0eedb1a

Browse files
committed
fix stacking classifiers
1 parent 2c25333 commit 0eedb1a

File tree

6 files changed

+39
-30
lines changed

6 files changed

+39
-30
lines changed

docs/sources/CHANGELOG.md

+14
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,20 @@ The CHANGELOG for the current development version is available at
77

88
---
99

10+
### Version 0.23.2 (TBD)
11+
12+
##### Downloads
13+
14+
- [Source code (zip)](https://github.com/rasbt/mlxtend/archive/v0.23.2.zip)
15+
16+
- [Source code (tar.gz)](https://github.com/rasbt/mlxtend/archive/v0.23.2.tar.gz)
17+
18+
##### Changes
19+
20+
- Add `n_classes_` attribute to stacking classifiers for compatibility with scikit-learn 1.3 ([#1091](https://github.com/rasbt/mlxtend/issues/1091)
21+
22+
23+
1024
### Version 0.23.1 (5 Jan 2024)
1125

1226
##### Downloads

mlxtend/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,4 @@
44
#
55
# License: BSD 3 clause
66

7-
__version__ = "0.23.1"
7+
__version__ = "0.23.2dev"

mlxtend/classifier/stacking_classification.py

+11
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import numpy as np
1414
from scipy import sparse
1515
from sklearn.base import TransformerMixin, clone
16+
from sklearn.preprocessing import LabelEncoder
1617

1718
from ..externals.estimator_checks import check_is_fitted
1819
from ..externals.name_estimators import _name_estimators
@@ -95,6 +96,9 @@ class StackingClassifier(_BaseXComposition, _BaseStackingClassifier, Transformer
9596
Fitted classifiers (clones of the original classifiers)
9697
meta_clf_ : estimator
9798
Fitted meta-classifier (clone of the original meta-estimator)
99+
classes_ : ndarray of shape (n_classes,) or list of ndarray if `y` \
100+
is of type `"multilabel-indicator"`.
101+
Class labels.
98102
train_meta_features : numpy array, shape = [n_samples, n_classifiers]
99103
meta-features for training data, where n_samples is the
100104
number of samples
@@ -175,6 +179,13 @@ def fit(self, X, y, sample_weight=None):
175179
self.clfs_ = self.classifiers
176180
self.meta_clf_ = self.meta_classifier
177181

182+
if y.ndim > 1:
183+
self._label_encoder = [LabelEncoder().fit(yk) for yk in y.T]
184+
self.classes_ = [le.classes_ for le in self._label_encoder]
185+
else:
186+
self._label_encoder = LabelEncoder().fit(y)
187+
self.classes_ = self._label_encoder.classes_
188+
178189
if self.fit_base_estimators:
179190
if self.verbose > 0:
180191
print("Fitting %d classifiers..." % (len(self.classifiers)))

mlxtend/classifier/stacking_cv_classification.py

+11
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from sklearn.base import TransformerMixin, clone
1515
from sklearn.model_selection import cross_val_predict
1616
from sklearn.model_selection._split import check_cv
17+
from sklearn.preprocessing import LabelEncoder
1718

1819
from ..externals.estimator_checks import check_is_fitted
1920
from ..externals.name_estimators import _name_estimators
@@ -129,6 +130,9 @@ class StackingCVClassifier(
129130
Fitted classifiers (clones of the original classifiers)
130131
meta_clf_ : estimator
131132
Fitted meta-classifier (clone of the original meta-estimator)
133+
classes_ : ndarray of shape (n_classes,) or list of ndarray if `y` \
134+
is of type `"multilabel-indicator"`.
135+
Class labels.
132136
train_meta_features : numpy array, shape = [n_samples, n_classifiers]
133137
meta-features for training data, where n_samples is the
134138
number of samples
@@ -220,6 +224,13 @@ def fit(self, X, y, groups=None, sample_weight=None):
220224
if self.verbose > 0:
221225
print("Fitting %d classifiers..." % (len(self.classifiers)))
222226

227+
if y.ndim > 1:
228+
self._label_encoder = [LabelEncoder().fit(yk) for yk in y.T]
229+
self.classes_ = [le.classes_ for le in self._label_encoder]
230+
else:
231+
self._label_encoder = LabelEncoder().fit(y)
232+
self.classes_ = self._label_encoder.classes_
233+
223234
final_cv = check_cv(self.cv, y, classifier=self.stratify)
224235
if isinstance(self.cv, int):
225236
# Override shuffle parameter in case of self generated

mlxtend/classifier/tests/test_stacking_classifier.py

+1-8
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333
y2 = np.c_[y, y]
3434

3535

36-
@pytest.mark.skip(reason="scikit-learn implemented a StackingClassifier in 0.22.")
3736
def test_StackingClassifier():
3837
np.random.seed(123)
3938
meta = LogisticRegression(
@@ -162,7 +161,6 @@ def test_weight_unsupported_no_weight():
162161
sclf.fit(X, y)
163162

164163

165-
@pytest.mark.skip(reason="scikit-learn implemented a StackingClassifier in 0.22.")
166164
def test_StackingClassifier_proba_avg_1():
167165
np.random.seed(123)
168166
meta = LogisticRegression(solver="liblinear", multi_class="ovr", random_state=1)
@@ -177,7 +175,6 @@ def test_StackingClassifier_proba_avg_1():
177175
assert scores_mean == 0.93, scores_mean
178176

179177

180-
@pytest.mark.skip(reason="scikit-learn implemented a StackingClassifier in 0.22.")
181178
def test_StackingClassifier_proba_concat_1():
182179
np.random.seed(123)
183180
meta = LogisticRegression(solver="liblinear", multi_class="ovr")
@@ -325,7 +322,6 @@ def test_gridsearch_enumerate_names():
325322
grid = grid.fit(X, y)
326323

327324

328-
@pytest.mark.skip(reason="scikit-learn implemented a StackingClassifier in 0.22.")
329325
def test_use_probas():
330326
np.random.seed(123)
331327
meta = LogisticRegression(solver="liblinear", multi_class="ovr")
@@ -391,7 +387,6 @@ def test_verbose():
391387
sclf.fit(X, y)
392388

393389

394-
@pytest.mark.skip(reason="scikit-learn implemented a StackingClassifier in 0.22.")
395390
def test_use_features_in_secondary_predict():
396391
np.random.seed(123)
397392
X, y = iris_data()
@@ -424,7 +419,6 @@ def test_use_features_in_secondary_predict_proba():
424419
np.testing.assert_almost_equal(y_pred, expect, 3)
425420

426421

427-
@pytest.mark.skip(reason="scikit-learn implemented a StackingClassifier in 0.22.")
428422
def test_use_features_in_secondary_sparse_input_predict():
429423
np.random.seed(123)
430424
X, y = iris_data()
@@ -537,7 +531,6 @@ def test_clone():
537531
clone(stclf)
538532

539533

540-
@pytest.mark.skip(reason="scikit-learn implemented a StackingClassifier in 0.22.")
541534
def test_decision_function():
542535
np.random.seed(123)
543536

@@ -572,7 +565,7 @@ def test_decision_function():
572565
if Version(sklearn_version) < Version("0.22"):
573566
assert scores_mean == 0.95, scores_mean
574567
else:
575-
assert scores_mean == 0.94, scores_mean
568+
assert scores_mean == 0.93, scores_mean
576569

577570

578571
def test_drop_col_unsupported():

mlxtend/classifier/tests/test_stacking_cv_classifier.py

+1-21
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,6 @@
4040
X_breast, y_breast = breast_cancer.data[:, 1:3], breast_cancer.target
4141

4242

43-
@pytest.mark.skip(
44-
reason="scikit-learn implemented a StackingClassifier in 0.22. It has built-in cross-validation."
45-
)
4643
def test_StackingCVClassifier():
4744
np.random.seed(123)
4845
meta = LogisticRegression(multi_class="ovr", solver="liblinear")
@@ -174,9 +171,7 @@ def test_no_weight_support_with_no_weight():
174171
sclf.fit(X_iris, y_iris)
175172

176173

177-
@pytest.mark.skip(
178-
reason="scikit-learn implemented a StackingClassifier in 0.22. It has built-in cross-validation."
179-
)
174+
180175
def test_StackingClassifier_proba():
181176
np.random.seed(12)
182177
meta = LogisticRegression(multi_class="ovr", solver="liblinear")
@@ -245,9 +240,6 @@ def test_gridsearch_enumerate_names():
245240
grid = grid.fit(X, y)
246241

247242

248-
@pytest.mark.skip(
249-
reason="scikit-learn implemented a StackingClassifier in 0.22. It has built-in cross-validation."
250-
)
251243
def test_use_probas():
252244
np.random.seed(123)
253245
meta = LogisticRegression(multi_class="ovr", solver="liblinear")
@@ -262,9 +254,6 @@ def test_use_probas():
262254
assert scores_mean == 0.94, scores_mean
263255

264256

265-
@pytest.mark.skip(
266-
reason="scikit-learn implemented a StackingClassifier in 0.22. It has built-in cross-validation."
267-
)
268257
def test_use_features_in_secondary():
269258
np.random.seed(123)
270259
meta = LogisticRegression(multi_class="ovr", solver="liblinear")
@@ -282,9 +271,6 @@ def test_use_features_in_secondary():
282271
assert scores_mean == 0.93, scores_mean
283272

284273

285-
@pytest.mark.skip(
286-
reason="scikit-learn implemented a StackingClassifier in 0.22. It has built-in cross-validation."
287-
)
288274
def test_do_not_stratify():
289275
meta = LogisticRegression(multi_class="ovr", solver="liblinear")
290276
clf1 = RandomForestClassifier(n_estimators=10)
@@ -298,9 +284,6 @@ def test_do_not_stratify():
298284
assert scores_mean == 0.93, scores.mean()
299285

300286

301-
@pytest.mark.skip(
302-
reason="scikit-learn implemented a StackingClassifier in 0.22. It has built-in cross-validation."
303-
)
304287
def test_cross_validation_technique():
305288
# This is like the `test_do_not_stratify` but instead
306289
# autogenerating the cross validation strategy it provides
@@ -640,9 +623,6 @@ def test_works_with_df_if_fold_indexes_missing():
640623
)
641624

642625

643-
@pytest.mark.skip(
644-
reason="scikit-learn implemented a StackingClassifier in 0.22. It has built-in cross-validation."
645-
)
646626
def test_decision_function():
647627
np.random.seed(123)
648628

0 commit comments

Comments
 (0)