Skip to content

Commit

Permalink
scikit-learn 1.6 Compatibility (#105)
Browse files Browse the repository at this point in the history
* Fixes for scikit-learn 1.6
  • Loading branch information
reidjohnson authored Nov 12, 2024
1 parent 184a478 commit 7a8fae4
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 48 deletions.
47 changes: 43 additions & 4 deletions quantile_forest/_quantile_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class calls the ``fit`` method of the ``ForestRegressor`` and creates a

import joblib
import numpy as np
import sklearn
from sklearn.ensemble._forest import (
ForestRegressor,
_generate_sample_indices,
Expand All @@ -38,11 +39,19 @@ class calls the ``fit`` method of the ``ForestRegressor`` and creates a
from sklearn.tree import DecisionTreeRegressor, ExtraTreeRegressor
from sklearn.tree._tree import DTYPE
from sklearn.utils._param_validation import Interval, RealNotInt
from sklearn.utils.fixes import parse_version
from sklearn.utils.validation import check_is_fitted

try:
from sklearn.utils.validation import validate_data
except ImportError:
validate_data = None

from ._quantile_forest_fast import QuantileForest
from ._utils import generate_unsampled_indices, group_indices_by_value, map_indices_to_leaves

# Parsed version of the installed scikit-learn. It is compared against
# "1.6.dev0" when building validation keyword arguments, to pick between the
# `force_all_finite` and `ensure_all_finite` parameter names.
sklearn_version = parse_version(sklearn.__version__)


class BaseForestQuantileRegressor(ForestRegressor):
"""Base class for quantile regression forests.
Expand Down Expand Up @@ -132,9 +141,23 @@ def fit(self, X, y, sample_weight=None, sparse_pickle=False):
)

super(BaseForestQuantileRegressor, self).fit(X, y, sample_weight=sample_weight)
X, y = self._validate_data(
X, y, multi_output=True, accept_sparse="csc", dtype=DTYPE, force_all_finite=False
)

validation_params = {
"X": X,
"y": y,
"multi_output": True,
"accept_sparse": "csc",
"dtype": DTYPE,
(
"force_all_finite"
if sklearn_version < parse_version("1.6.dev0")
else "ensure_all_finite"
): False,
}
if validate_data is None:
X, y = self._validate_data(**validation_params)
else:
X, y = validate_data(self, **validation_params)

if y.ndim == 1:
y = np.expand_dims(y, axis=1)
Expand Down Expand Up @@ -816,7 +839,23 @@ def quantile_ranks(
Quantile ranks in range [0, 1].
"""
check_is_fitted(self)
X, y = self._validate_data(X, y, multi_output=True, accept_sparse="csc", dtype=DTYPE)

validation_params = {
"X": X,
"y": y,
"multi_output": True,
"accept_sparse": "csc",
"dtype": DTYPE,
(
"force_all_finite"
if sklearn_version < parse_version("1.6.dev0")
else "ensure_all_finite"
): False,
}
if validate_data is None:
X, y = self._validate_data(**validation_params)
else:
X, y = validate_data(self, **validation_params)

if not isinstance(kind, (bytes, bytearray)):
kind = kind.encode()
Expand Down
112 changes: 68 additions & 44 deletions quantile_forest/tests/test_quantile_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
assert_almost_equal,
assert_array_almost_equal,
assert_array_equal,
assert_raises,
)
from sklearn.utils.validation import check_is_fitted, check_random_state

Expand Down Expand Up @@ -264,7 +263,8 @@ def check_predict_quantiles_toy(name):
weighted_leaves=False,
oob_score=oob_score,
)
assert_raises(AssertionError, assert_allclose, y_pred1, y_pred2)
with pytest.raises(AssertionError):
assert_allclose(y_pred1, y_pred2)

# Check that leaf weighting without weighted quantiles does nothing.
y_pred1 = est.predict(
Expand Down Expand Up @@ -579,8 +579,10 @@ def check_predict_quantiles(
assert np.any(y_pred_1 != y_pred_2)

# Check error if invalid quantiles.
assert_raises(ValueError, est.predict, X_test, -0.01)
assert_raises(ValueError, est.predict, X_test, 1.01)
with pytest.raises(ValueError):
est.predict(X_test, -0.01)
with pytest.raises(ValueError):
est.predict(X_test, 1.01)


@pytest.mark.parametrize("name", FOREST_REGRESSORS)
Expand Down Expand Up @@ -720,7 +722,8 @@ def check_quantile_ranks(name):

# Check error if training and test number of targets are not equal.
est.fit(X_train, y_train[:, 0]) # training target size = 1
assert_raises(ValueError, est.quantile_ranks, X_test, y_test[:, :2]) # test target size = 2
with pytest.raises(ValueError):
est.quantile_ranks(X_test, y_test[:, :2]) # test target size = 2


@pytest.mark.parametrize("name", FOREST_REGRESSORS)
Expand Down Expand Up @@ -773,10 +776,12 @@ def check_proximity_counts(name):
assert_array_equal([len(p) for p in proximities], [len(e) for e in expected])

# Check error if `max_proximities` < 1.
assert_raises(ValueError, est.proximity_counts, X, max_proximities=0)
with pytest.raises(ValueError):
est.proximity_counts(X, max_proximities=0)

# Check error if `max_proximities` is a float.
assert_raises(ValueError, est.proximity_counts, X, max_proximities=1.5)
with pytest.raises(ValueError):
est.proximity_counts(X, max_proximities=1.5)

# Check that proximity counts match expected counts without splits.
est = ForestRegressor(
Expand Down Expand Up @@ -869,14 +874,25 @@ def check_max_samples_leaf(name):
for param_validation in [True, False]:
est = ForestRegressor(n_estimators=1, max_samples_leaf=max_samples_leaf)
est.param_validation = param_validation
assert_raises(ValueError, est.fit, X, y)
with pytest.raises(ValueError):
est.fit(X, y)
est.max_samples_leaf = max_samples_leaf
assert_raises(ValueError, est._get_y_train_leaves, X, y)
with pytest.raises(ValueError):
est._get_y_train_leaves(X, y)


@pytest.mark.parametrize("name", FOREST_REGRESSORS)
def test_max_samples_leaf(name):
    """Run the `max_samples_leaf` checks for a single forest regressor.

    Delegates to `check_max_samples_leaf`, which holds the actual assertions.

    Parameters
    ----------
    name : str
        The name of the forest regressor to test.
    """
    # The docstring must be the first statement in the body; placed after the
    # call it would be a discarded string expression and `__doc__` would be None.
    check_max_samples_leaf(name)


def check_oob_samples(name):
Expand Down Expand Up @@ -1065,16 +1081,16 @@ def check_predict_oob(
assert_allclose(y_pred_oob1, y_pred_oob2)

# Check error if OOB score without `indices` do not match training count.
assert_raises(ValueError, est.predict, X[:1], oob_score=True)
with pytest.raises(ValueError):
est.predict(X[:1], oob_score=True)

# Check error if OOB score with `indices` do not match samples count.
assert_raises(
ValueError,
est.predict,
X,
oob_score=True,
indices=-np.ones(len(X) - 1),
)
with pytest.raises(ValueError):
est.predict(
X,
oob_score=True,
indices=-np.ones(len(X) - 1),
)

# Check warning if not enough estimators.
with np.errstate(divide="ignore", invalid="ignore"):
Expand Down Expand Up @@ -1106,30 +1122,28 @@ def check_predict_oob(
# Check error if no bootstrapping.
est = ForestRegressor(n_estimators=1, bootstrap=False)
est.fit(X, y)
assert_raises(
ValueError,
est.predict,
X,
weighted_quantile=weighted_quantile,
aggregate_leaves_first=aggregate_leaves_first,
oob_score=True,
)
with pytest.raises(ValueError):
est.predict(
X,
weighted_quantile=weighted_quantile,
aggregate_leaves_first=aggregate_leaves_first,
oob_score=True,
)
with warnings.catch_warnings():
warnings.simplefilter("ignore", UserWarning)
assert np.all(est._get_unsampled_indices(est.estimators_[0]) == np.array([]))

# Check error if number of scoring and training samples are different.
est = ForestRegressor(n_estimators=1, bootstrap=True)
est.fit(X, y)
assert_raises(
ValueError,
est.predict,
X[:1],
y[:1],
weighted_quantile=weighted_quantile,
aggregate_leaves_first=aggregate_leaves_first,
oob_score=True,
)
with pytest.raises(ValueError):
est.predict(
X[:1],
y[:1],
weighted_quantile=weighted_quantile,
aggregate_leaves_first=aggregate_leaves_first,
oob_score=True,
)


@pytest.mark.parametrize("name", FOREST_REGRESSORS)
Expand Down Expand Up @@ -1200,12 +1214,14 @@ def check_quantile_ranks_oob(name):
# Check error if no bootstrapping.
est = ForestRegressor(n_estimators=1, bootstrap=False)
est.fit(X, y)
assert_raises(ValueError, est.quantile_ranks, X, y, oob_score=True)
with pytest.raises(ValueError):
est.quantile_ranks(X, y, oob_score=True)

# Check error if number of scoring and training samples are different.
est = ForestRegressor(n_estimators=1, bootstrap=True)
est.fit(X, y)
assert_raises(ValueError, est.quantile_ranks, X[:1], y[:1], oob_score=True)
with pytest.raises(ValueError):
est.quantile_ranks(X[:1], y[:1], oob_score=True)


@pytest.mark.parametrize("name", FOREST_REGRESSORS)
Expand Down Expand Up @@ -1284,7 +1300,8 @@ def check_proximity_counts_oob(name):
# Check error if no bootstrapping.
est = ForestRegressor(n_estimators=1, max_samples_leaf=None, bootstrap=False)
est.fit(X, y)
assert_raises(ValueError, est.proximity_counts, X, oob_score=True)
with pytest.raises(ValueError):
est.proximity_counts(X, oob_score=True)


@pytest.mark.parametrize("name", FOREST_REGRESSORS)
Expand Down Expand Up @@ -1357,7 +1374,8 @@ def check_monotonic_constraints(name, max_samples_leaf):
max_leaf_nodes=n_samples_train,
bootstrap=True,
)
assert_raises(ValueError, est.fit, X_train, y_train)
with pytest.raises(ValueError):
est.fit(X_train, y_train)


@pytest.mark.parametrize("name", FOREST_REGRESSORS)
Expand Down Expand Up @@ -1466,8 +1484,10 @@ def test_calc_quantile():
assert actual1 != actual2

# Check error if invalid parameters.
assert_raises(TypeError, calc_quantile, [1, 2], 0.5)
assert_raises(TypeError, calc_quantile, [1, 2], [0.5], interpolation=None)
with pytest.raises(TypeError):
calc_quantile([1, 2], 0.5)
with pytest.raises(TypeError):
calc_quantile([1, 2], [0.5], interpolation=None)


def test_calc_weighted_quantile():
Expand Down Expand Up @@ -1585,8 +1605,10 @@ def _dicts_to_input_pairs(input_dicts):
assert actual1 != actual2

# Check error if invalid parameters.
assert_raises(TypeError, calc_weighted_quantile, [1, 2], [1, 1], 0.5)
assert_raises(TypeError, calc_weighted_quantile, [1, 2], [1, 1], [0.5], interpolation=None)
with pytest.raises(TypeError):
calc_weighted_quantile([1, 2], [1, 1], 0.5)
with pytest.raises(TypeError):
calc_weighted_quantile([1, 2], [1, 1], [0.5], interpolation=None)


def test_calc_quantile_rank():
Expand Down Expand Up @@ -1635,5 +1657,7 @@ def test_calc_quantile_rank():
assert actual1 != actual2

# Check error if invalid parameters.
assert_raises(TypeError, calc_quantile_rank, [1, 2], [1])
assert_raises(TypeError, calc_quantile_rank, [1, 2], float(1), kind=None)
with pytest.raises(TypeError):
calc_quantile_rank([1, 2], [1])
with pytest.raises(TypeError):
calc_quantile_rank([1, 2], float(1), kind=None)

0 comments on commit 7a8fae4

Please sign in to comment.