Skip to content

Commit

Permalink
Refactor Usage of scikit-learn Utilities (#95)
Browse files Browse the repository at this point in the history
Updates for sklearn 1.5+
  • Loading branch information
reidjohnson authored Sep 28, 2024
1 parent 6ad7ffb commit cd5208c
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 61 deletions.
74 changes: 13 additions & 61 deletions quantile_forest/_quantile_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,32 +30,19 @@ class calls the ``fit`` method of the ``ForestRegressor`` and creates a

import joblib
import numpy as np
import sklearn
from sklearn.ensemble._forest import (
ForestRegressor,
_generate_sample_indices,
_get_n_samples_bootstrap,
)
from sklearn.tree import DecisionTreeRegressor, ExtraTreeRegressor
from sklearn.tree._tree import DTYPE

try:
from sklearn.utils.fixes import parse_version
except ImportError:
from sklearn.utils import parse_version

param_validation = True
try:
from sklearn.utils._param_validation import Interval, RealNotInt
except ImportError:
param_validation = False
from sklearn.utils._param_validation import Interval, RealNotInt
from sklearn.utils.validation import check_is_fitted

from ._quantile_forest_fast import QuantileForest
from ._utils import generate_unsampled_indices, group_indices_by_value, map_indices_to_leaves

sklearn_version = parse_version(sklearn.__version__)


class BaseForestQuantileRegressor(ForestRegressor):
"""Base class for quantile regression forests.
Expand All @@ -64,17 +51,16 @@ class BaseForestQuantileRegressor(ForestRegressor):
instead.
"""

if param_validation:
_parameter_constraints: dict = {
**ForestRegressor._parameter_constraints,
**DecisionTreeRegressor._parameter_constraints,
"max_samples_leaf": [
None,
Interval(RealNotInt, 0, 1, closed="right"),
Interval(Integral, 1, None, closed="left"),
],
}
_parameter_constraints.pop("splitter")
_parameter_constraints: dict = {
**ForestRegressor._parameter_constraints,
**DecisionTreeRegressor._parameter_constraints,
"max_samples_leaf": [
None,
Interval(RealNotInt, 0, 1, closed="right"),
Interval(Integral, 1, None, closed="left"),
],
}
_parameter_constraints.pop("splitter")

@abstractmethod
def __init__(
Expand Down Expand Up @@ -107,8 +93,6 @@ def __init__(
}
super().__init__(**init_dict)

self.param_validation = hasattr(self, "_parameter_constraints")

def fit(self, X, y, sample_weight=None, sparse_pickle=False):
"""Build a forest from the training set (X, y).
Expand All @@ -135,26 +119,8 @@ def fit(self, X, y, sample_weight=None, sparse_pickle=False):
self : object
Fitted estimator.
"""
if self.param_validation:
self._validate_params()
else:
if isinstance(self.max_samples_leaf, (Integral, np.integer)):
if self.max_samples_leaf < 1:
raise ValueError(
"If max_samples_leaf is an integer, it must be be >= 1, "
f"got {self.max_samples_leaf}."
)
elif isinstance(self.max_samples_leaf, Real):
if not 0.0 < self.max_samples_leaf <= 1.0:
raise ValueError(
"If max_samples_leaf is a float, it must be in range (0, 1], "
f"got {self.max_samples_leaf}."
)
elif self.max_samples_leaf is not None:
raise ValueError(
"max_samples_leaf must be of integer, float, or None type, got "
f"{self.max_samples_leaf}."
)
self._validate_params()

if self.monotonic_cst is not None:
if (
not isinstance(self.max_samples_leaf, (Integral, np.integer))
Expand Down Expand Up @@ -1210,17 +1176,12 @@ class RandomForestQuantileRegressor(BaseForestQuantileRegressor):
- regressions trained on data with missing values,
- trees with multi-sample leaves (i.e. when `max_samples_leaf > 1`).
.. sklearn-versionadded:: 1.4
Attributes
----------
estimator_ : :class:`~sklearn.tree.DecisionTreeRegressor`
The child estimator template used to create the collection of fitted
sub-estimators.
.. sklearn-versionadded:: 1.2
`base_estimator_` was renamed to `estimator_`.
estimators_ : list of DecisionTreeRegressor
The collection of fitted sub-estimators.
Expand Down Expand Up @@ -1257,8 +1218,6 @@ class RandomForestQuantileRegressor(BaseForestQuantileRegressor):
The subset of drawn samples (i.e., the in-bag samples) for each base
estimator. Each subset is defined by an array of the indices selected.
.. sklearn-versionadded:: 1.4
See Also
--------
ExtraTreesQuantileRegressor : Quantile ensemble of extremely randomized
Expand Down Expand Up @@ -1556,17 +1515,12 @@ class ExtraTreesQuantileRegressor(BaseForestQuantileRegressor):
- regressions trained on data with missing values,
- trees with multi-sample leaves (i.e. when `max_samples_leaf > 1`).
.. sklearn-versionadded:: 1.4
Attributes
----------
estimator_ : :class:`~sklearn.tree.ExtraTreeRegressor`
The child estimator template used to create the collection of fitted
sub-estimators.
.. sklearn-versionadded:: 1.2
`base_estimator_` was renamed to `estimator_`.
estimators_ : list of DecisionTreeRegressor
The collection of fitted sub-estimators.
Expand Down Expand Up @@ -1603,8 +1557,6 @@ class ExtraTreesQuantileRegressor(BaseForestQuantileRegressor):
The subset of drawn samples (i.e., the in-bag samples) for each base
estimator. Each subset is defined by an array of the indices selected.
.. sklearn-versionadded:: 1.4
See Also
--------
RandomForestQuantileRegressor : Quantile ensemble regressor using trees.
Expand Down
29 changes: 29 additions & 0 deletions quantile_forest/tests/test_quantile_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,35 @@ def test_regression_toy(name, weighted_quantile):
check_regression_toy(name, weighted_quantile)


def check_regression_params(name):
    """Check that tree-level parameters propagate to the fitted sub-estimators.

    Fits the quantile forest named by `name` with an explicit set of tree
    parameters and verifies each one is mirrored on a fitted tree.
    """
    tree_params = {
        "criterion": "squared_error",
        "max_depth": 2,
        "min_samples_split": 2,
        "min_samples_leaf": 1,
        "min_weight_fraction_leaf": 0.0,
        "max_features": 1.0,
        "max_leaf_nodes": 16,
        "min_impurity_decrease": 0.0,
        "ccp_alpha": 0.0,
        "monotonic_cst": [0, 1, -1, 0],
    }

    X, y = datasets.make_regression(n_features=4, n_informative=2, shuffle=True, random_state=0)

    forest_cls = FOREST_REGRESSORS[name]
    estimator = forest_cls(**tree_params, random_state=0).fit(X, y)

    # Every parameter set on the forest should match the first fitted tree.
    first_tree = estimator.estimators_[0]
    for key in tree_params:
        assert getattr(estimator, key) == getattr(first_tree, key)


@pytest.mark.parametrize("name", FOREST_REGRESSORS)
def test_regression_params(name):
    """Check tree parameters propagate to sub-estimators for each quantile forest."""
    check_regression_params(name)


def check_california_criterion(name, criterion):
"""Check for consistency on the California Housing dataset."""
ForestRegressor = FOREST_REGRESSORS[name]
Expand Down

0 comments on commit cd5208c

Please sign in to comment.