Commit 7a8fae4

scikit-learn 1.6 Compatibility (#105)
* Fixes for scikit-learn 1.6
1 parent 184a478 commit 7a8fae4
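
The changes below target two scikit-learn 1.6 updates: the validation keyword `force_all_finite` was renamed to `ensure_all_finite`, and a public `validate_data` helper in `sklearn.utils.validation` replaced the private `BaseEstimator._validate_data` method. As a rough standalone sketch of the compatibility pattern used in this commit (the `TinyEstimator` class is made up for illustration and is not part of quantile-forest):

    import sklearn
    from sklearn.base import BaseEstimator
    from sklearn.utils.fixes import parse_version

    try:
        # Public helper introduced in scikit-learn 1.6; absent in older releases.
        from sklearn.utils.validation import validate_data
    except ImportError:
        validate_data = None

    sklearn_version = parse_version(sklearn.__version__)


    class TinyEstimator(BaseEstimator):
        def fit(self, X, y=None):
            params = {
                "accept_sparse": "csc",
                # scikit-learn 1.6 renamed `force_all_finite` to `ensure_all_finite`.
                (
                    "force_all_finite"
                    if sklearn_version < parse_version("1.6.dev0")
                    else "ensure_all_finite"
                ): False,
            }
            if validate_data is None:
                X = self._validate_data(X, **params)  # scikit-learn < 1.6
            else:
                X = validate_data(self, X, **params)  # scikit-learn >= 1.6
            return self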

File tree

quantile_forest/_quantile_forest.py
quantile_forest/tests/test_quantile_forest.py

2 files changed: +111 -48 lines changed


quantile_forest/_quantile_forest.py

+43 -4
@@ -30,6 +30,7 @@ class calls the ``fit`` method of the ``ForestRegressor`` and creates a
 
 import joblib
 import numpy as np
+import sklearn
 from sklearn.ensemble._forest import (
     ForestRegressor,
     _generate_sample_indices,
@@ -38,11 +39,19 @@ class calls the ``fit`` method of the ``ForestRegressor`` and creates a
 from sklearn.tree import DecisionTreeRegressor, ExtraTreeRegressor
 from sklearn.tree._tree import DTYPE
 from sklearn.utils._param_validation import Interval, RealNotInt
+from sklearn.utils.fixes import parse_version
 from sklearn.utils.validation import check_is_fitted
 
+try:
+    from sklearn.utils.validation import validate_data
+except ImportError:
+    validate_data = None
+
 from ._quantile_forest_fast import QuantileForest
 from ._utils import generate_unsampled_indices, group_indices_by_value, map_indices_to_leaves
 
+sklearn_version = parse_version(sklearn.__version__)
+
 
 class BaseForestQuantileRegressor(ForestRegressor):
     """Base class for quantile regression forests.
@@ -132,9 +141,23 @@ def fit(self, X, y, sample_weight=None, sparse_pickle=False):
         )
 
         super(BaseForestQuantileRegressor, self).fit(X, y, sample_weight=sample_weight)
-        X, y = self._validate_data(
-            X, y, multi_output=True, accept_sparse="csc", dtype=DTYPE, force_all_finite=False
-        )
+
+        validation_params = {
+            "X": X,
+            "y": y,
+            "multi_output": True,
+            "accept_sparse": "csc",
+            "dtype": DTYPE,
+            (
+                "force_all_finite"
+                if sklearn_version < parse_version("1.6.dev0")
+                else "ensure_all_finite"
+            ): False,
+        }
+        if validate_data is None:
+            X, y = self._validate_data(**validation_params)
+        else:
+            X, y = validate_data(self, **validation_params)
 
         if y.ndim == 1:
             y = np.expand_dims(y, axis=1)
@@ -816,7 +839,23 @@ def quantile_ranks(
             Quantile ranks in range [0, 1].
         """
         check_is_fitted(self)
-        X, y = self._validate_data(X, y, multi_output=True, accept_sparse="csc", dtype=DTYPE)
+
+        validation_params = {
+            "X": X,
+            "y": y,
+            "multi_output": True,
+            "accept_sparse": "csc",
+            "dtype": DTYPE,
+            (
+                "force_all_finite"
+                if sklearn_version < parse_version("1.6.dev0")
+                else "ensure_all_finite"
+            ): False,
+        }
+        if validate_data is None:
+            X, y = self._validate_data(**validation_params)
+        else:
+            X, y = validate_data(self, **validation_params)
 
         if not isinstance(kind, (bytes, bytearray)):
             kind = kind.encode()
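
A note on the version gate used above: the comparison is against "1.6.dev0" rather than "1.6" so that development and pre-release builds of 1.6 also take the new `ensure_all_finite` / `validate_data` path, since dev releases sort before the final release. A quick illustrative check, assuming the standard packaging ordering that `parse_version` follows:

    from sklearn.utils.fixes import parse_version

    assert parse_version("1.5.2") < parse_version("1.6.dev0")  # older releases keep `force_all_finite`
    assert not parse_version("1.6.dev0") < parse_version("1.6.dev0")  # 1.6 dev builds switch over
    assert not parse_version("1.6.0") < parse_version("1.6.dev0")  # 1.6 releases switch over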

quantile_forest/tests/test_quantile_forest.py

+68 -44
@@ -18,7 +18,6 @@
     assert_almost_equal,
     assert_array_almost_equal,
     assert_array_equal,
-    assert_raises,
 )
 from sklearn.utils.validation import check_is_fitted, check_random_state
 
@@ -264,7 +263,8 @@ def check_predict_quantiles_toy(name):
         weighted_leaves=False,
         oob_score=oob_score,
     )
-    assert_raises(AssertionError, assert_allclose, y_pred1, y_pred2)
+    with pytest.raises(AssertionError):
+        assert_allclose(y_pred1, y_pred2)
 
     # Check that leaf weighting without weighted quantiles does nothing.
     y_pred1 = est.predict(
@@ -579,8 +579,10 @@ def check_predict_quantiles(
     assert np.any(y_pred_1 != y_pred_2)
 
     # Check error if invalid quantiles.
-    assert_raises(ValueError, est.predict, X_test, -0.01)
-    assert_raises(ValueError, est.predict, X_test, 1.01)
+    with pytest.raises(ValueError):
+        est.predict(X_test, -0.01)
+    with pytest.raises(ValueError):
+        est.predict(X_test, 1.01)
 
 
 @pytest.mark.parametrize("name", FOREST_REGRESSORS)
@@ -720,7 +722,8 @@ def check_quantile_ranks(name):
 
     # Check error if training and test number of targets are not equal.
     est.fit(X_train, y_train[:, 0])  # training target size = 1
-    assert_raises(ValueError, est.quantile_ranks, X_test, y_test[:, :2])  # test target size = 2
+    with pytest.raises(ValueError):
+        est.quantile_ranks(X_test, y_test[:, :2])  # test target size = 2
 
 
 @pytest.mark.parametrize("name", FOREST_REGRESSORS)
@@ -773,10 +776,12 @@ def check_proximity_counts(name):
     assert_array_equal([len(p) for p in proximities], [len(e) for e in expected])
 
     # Check error if `max_proximities` < 1.
-    assert_raises(ValueError, est.proximity_counts, X, max_proximities=0)
+    with pytest.raises(ValueError):
+        est.proximity_counts(X, max_proximities=0)
 
     # Check error if `max_proximities` is a float.
-    assert_raises(ValueError, est.proximity_counts, X, max_proximities=1.5)
+    with pytest.raises(ValueError):
+        est.proximity_counts(X, max_proximities=1.5)
 
     # Check that proximity counts match expected counts without splits.
     est = ForestRegressor(
@@ -869,14 +874,25 @@ def check_max_samples_leaf(name):
     for param_validation in [True, False]:
         est = ForestRegressor(n_estimators=1, max_samples_leaf=max_samples_leaf)
         est.param_validation = param_validation
-        assert_raises(ValueError, est.fit, X, y)
+        with pytest.raises(ValueError):
+            est.fit(X, y)
         est.max_samples_leaf = max_samples_leaf
-        assert_raises(ValueError, est._get_y_train_leaves, X, y)
+        with pytest.raises(ValueError):
+            est._get_y_train_leaves(X, y)
 
 
 @pytest.mark.parametrize("name", FOREST_REGRESSORS)
 def test_max_samples_leaf(name):
     check_max_samples_leaf(name)
+    """
+    Test that `max_samples_leaf` is correctly passed to the `fit` method,
+    and that it results in the correct maximum leaf size.
+
+    Parameters
+    ----------
+    name : str
+        The name of the forest regressor to test.
+    """
 
 
 def check_oob_samples(name):
@@ -1065,16 +1081,16 @@ def check_predict_oob(
     assert_allclose(y_pred_oob1, y_pred_oob2)
 
     # Check error if OOB score without `indices` do not match training count.
-    assert_raises(ValueError, est.predict, X[:1], oob_score=True)
+    with pytest.raises(ValueError):
+        est.predict(X[:1], oob_score=True)
 
     # Check error if OOB score with `indices` do not match samples count.
-    assert_raises(
-        ValueError,
-        est.predict,
-        X,
-        oob_score=True,
-        indices=-np.ones(len(X) - 1),
-    )
+    with pytest.raises(ValueError):
+        est.predict(
+            X,
+            oob_score=True,
+            indices=-np.ones(len(X) - 1),
+        )
 
     # Check warning if not enough estimators.
     with np.errstate(divide="ignore", invalid="ignore"):
@@ -1106,30 +1122,28 @@ def check_predict_oob(
     # Check error if no bootstrapping.
     est = ForestRegressor(n_estimators=1, bootstrap=False)
     est.fit(X, y)
-    assert_raises(
-        ValueError,
-        est.predict,
-        X,
-        weighted_quantile=weighted_quantile,
-        aggregate_leaves_first=aggregate_leaves_first,
-        oob_score=True,
-    )
+    with pytest.raises(ValueError):
+        est.predict(
+            X,
+            weighted_quantile=weighted_quantile,
+            aggregate_leaves_first=aggregate_leaves_first,
+            oob_score=True,
+        )
     with warnings.catch_warnings():
         warnings.simplefilter("ignore", UserWarning)
         assert np.all(est._get_unsampled_indices(est.estimators_[0]) == np.array([]))
 
     # Check error if number of scoring and training samples are different.
     est = ForestRegressor(n_estimators=1, bootstrap=True)
     est.fit(X, y)
-    assert_raises(
-        ValueError,
-        est.predict,
-        X[:1],
-        y[:1],
-        weighted_quantile=weighted_quantile,
-        aggregate_leaves_first=aggregate_leaves_first,
-        oob_score=True,
-    )
+    with pytest.raises(ValueError):
+        est.predict(
+            X[:1],
+            y[:1],
+            weighted_quantile=weighted_quantile,
+            aggregate_leaves_first=aggregate_leaves_first,
+            oob_score=True,
+        )
 
 
 @pytest.mark.parametrize("name", FOREST_REGRESSORS)
@@ -1200,12 +1214,14 @@ def check_quantile_ranks_oob(name):
     # Check error if no bootstrapping.
    est = ForestRegressor(n_estimators=1, bootstrap=False)
     est.fit(X, y)
-    assert_raises(ValueError, est.quantile_ranks, X, y, oob_score=True)
+    with pytest.raises(ValueError):
+        est.quantile_ranks(X, y, oob_score=True)
 
     # Check error if number of scoring and training samples are different.
     est = ForestRegressor(n_estimators=1, bootstrap=True)
     est.fit(X, y)
-    assert_raises(ValueError, est.quantile_ranks, X[:1], y[:1], oob_score=True)
+    with pytest.raises(ValueError):
+        est.quantile_ranks(X[:1], y[:1], oob_score=True)
 
 
 @pytest.mark.parametrize("name", FOREST_REGRESSORS)
@@ -1284,7 +1300,8 @@ def check_proximity_counts_oob(name):
     # Check error if no bootstrapping.
     est = ForestRegressor(n_estimators=1, max_samples_leaf=None, bootstrap=False)
     est.fit(X, y)
-    assert_raises(ValueError, est.proximity_counts, X, oob_score=True)
+    with pytest.raises(ValueError):
+        est.proximity_counts(X, oob_score=True)
 
 
 @pytest.mark.parametrize("name", FOREST_REGRESSORS)
@@ -1357,7 +1374,8 @@ def check_monotonic_constraints(name, max_samples_leaf):
         max_leaf_nodes=n_samples_train,
         bootstrap=True,
     )
-    assert_raises(ValueError, est.fit, X_train, y_train)
+    with pytest.raises(ValueError):
+        est.fit(X_train, y_train)
 
 
 @pytest.mark.parametrize("name", FOREST_REGRESSORS)
@@ -1466,8 +1484,10 @@ def test_calc_quantile():
     assert actual1 != actual2
 
     # Check error if invalid parameters.
-    assert_raises(TypeError, calc_quantile, [1, 2], 0.5)
-    assert_raises(TypeError, calc_quantile, [1, 2], [0.5], interpolation=None)
+    with pytest.raises(TypeError):
+        calc_quantile([1, 2], 0.5)
+    with pytest.raises(TypeError):
+        calc_quantile([1, 2], [0.5], interpolation=None)
 
 
 def test_calc_weighted_quantile():
@@ -1585,8 +1605,10 @@ def _dicts_to_input_pairs(input_dicts):
     assert actual1 != actual2
 
     # Check error if invalid parameters.
-    assert_raises(TypeError, calc_weighted_quantile, [1, 2], [1, 1], 0.5)
-    assert_raises(TypeError, calc_weighted_quantile, [1, 2], [1, 1], [0.5], interpolation=None)
+    with pytest.raises(TypeError):
+        calc_weighted_quantile([1, 2], [1, 1], 0.5)
+    with pytest.raises(TypeError):
+        calc_weighted_quantile([1, 2], [1, 1], [0.5], interpolation=None)
 
 
 def test_calc_quantile_rank():
@@ -1635,5 +1657,7 @@ def test_calc_quantile_rank():
     assert actual1 != actual2
 
     # Check error if invalid parameters.
-    assert_raises(TypeError, calc_quantile_rank, [1, 2], [1])
-    assert_raises(TypeError, calc_quantile_rank, [1, 2], float(1), kind=None)
+    with pytest.raises(TypeError):
+        calc_quantile_rank([1, 2], [1])
+    with pytest.raises(TypeError):
+        calc_quantile_rank([1, 2], float(1), kind=None)
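
The test-suite changes above drop the `assert_raises` helper previously imported from `sklearn.utils._testing` and use pytest's built-in `pytest.raises` context manager instead, removing the dependency on scikit-learn's private testing utilities. A minimal self-contained example of the pattern (the `to_float` function is made up for illustration):

    import pytest


    def to_float(value):
        """Toy function used only to demonstrate the pattern."""
        return float(value)


    def test_to_float_rejects_none():
        # Equivalent to the old assert_raises(TypeError, to_float, None).
        with pytest.raises(TypeError):
            to_float(None)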
