diff --git a/docs/user_guide.rst b/docs/user_guide.rst index 5144fb9..1337f74 100755 --- a/docs/user_guide.rst +++ b/docs/user_guide.rst @@ -104,9 +104,9 @@ The predictions of a standard random forest can also be recovered from a quantil >>> X, y = datasets.load_diabetes(return_X_y=True) >>> X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25) >>> rf = RandomForestRegressor(random_state=0) - >>> qrf = RandomForestQuantileRegressor(random_state=0) + >>> qrf = RandomForestQuantileRegressor(max_samples_leaf=None, random_state=0) >>> rf.fit(X_train, y_train), qrf.fit(X_train, y_train) - (RandomForestRegressor(random_state=0), RandomForestQuantileRegressor(random_state=0)) + (RandomForestRegressor(random_state=0), RandomForestQuantileRegressor(max_samples_leaf=None, random_state=0)) >>> y_pred_rf = rf.predict(X_test) >>> y_pred_qrf = qrf.predict(X_test, quantiles=None, aggregate_leaves_first=False) >>> np.allclose(y_pred_rf, y_pred_qrf) diff --git a/quantile_forest/_quantile_forest.py b/quantile_forest/_quantile_forest.py index cf6824b..89a737e 100755 --- a/quantile_forest/_quantile_forest.py +++ b/quantile_forest/_quantile_forest.py @@ -30,6 +30,7 @@ class calls the ``fit`` method of the ``ForestRegressor`` and creates a import joblib import numpy as np +import sklearn from sklearn.ensemble._forest import ForestRegressor from sklearn.ensemble._forest import _generate_sample_indices @@ -42,6 +43,8 @@ class calls the ``fit`` method of the ``ForestRegressor`` and creates a from ._quantile_forest_fast import QuantileForest from ._quantile_forest_fast import generate_unsampled_indices +sklearn_version = tuple(map(int, (sklearn.__version__.split('.')))) + def _generate_unsampled_indices(sample_indices, duplicates=None): """Private function used by forest._get_unsampled_indices function.""" @@ -980,10 +983,10 @@ def __init__( ccp_alpha=0.0, max_samples=None, ): - super(RandomForestQuantileRegressor, self).__init__( - base_estimator=DecisionTreeRegressor(), - n_estimators=n_estimators, - estimator_params=( + init_dict = { + 'base_estimator' if sklearn_version < (1, 2) else 'estimator': DecisionTreeRegressor(), + 'n_estimators': n_estimators, + 'estimator_params': ( "criterion", "max_depth", "min_samples_split", @@ -995,14 +998,15 @@ def __init__( "random_state", "ccp_alpha", ), - bootstrap=bootstrap, - oob_score=oob_score, - n_jobs=n_jobs, - random_state=random_state, - verbose=verbose, - warm_start=warm_start, - max_samples=max_samples, - ) + 'bootstrap': bootstrap, + 'oob_score': oob_score, + 'n_jobs': n_jobs, + 'random_state': random_state, + 'verbose': verbose, + 'warm_start': warm_start, + 'max_samples': max_samples, + } + super(RandomForestQuantileRegressor, self).__init__(**init_dict) self.criterion = criterion self.max_depth = max_depth @@ -1253,10 +1257,10 @@ def __init__( ccp_alpha=0.0, max_samples=None, ): - super(ExtraTreesQuantileRegressor, self).__init__( - base_estimator=ExtraTreeRegressor(), - n_estimators=n_estimators, - estimator_params=( + init_dict = { + 'base_estimator' if sklearn_version < (1, 2) else 'estimator': ExtraTreeRegressor(), + 'n_estimators': n_estimators, + 'estimator_params': ( "criterion", "max_depth", "min_samples_split", @@ -1268,14 +1272,15 @@ def __init__( "random_state", "ccp_alpha", ), - bootstrap=bootstrap, - oob_score=oob_score, - n_jobs=n_jobs, - random_state=random_state, - verbose=verbose, - warm_start=warm_start, - max_samples=max_samples, - ) + 'bootstrap': bootstrap, + 'oob_score': oob_score, + 'n_jobs': n_jobs, + 'random_state': random_state, + 'verbose': verbose, + 'warm_start': warm_start, + 'max_samples': max_samples, + } + super(ExtraTreesQuantileRegressor, self).__init__(**init_dict) self.criterion = criterion self.max_depth = max_depth diff --git a/quantile_forest/version.txt b/quantile_forest/version.txt index 9084fa2..524cb55 100644 --- a/quantile_forest/version.txt +++ b/quantile_forest/version.txt @@ -1 +1 @@ -1.1.0 +1.1.1 diff --git a/requirements.txt b/requirements.txt index ed863fe..513b7ad 100755 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ cython >= 3.0a4 -numpy -scipy +numpy >= 1.23 +scipy >= 1.4 scikit-learn >= 1.0