Skip to content

Commit 29fa8e7

Browse files
committed
Fix unit tests
1 parent 1a99de9 commit 29fa8e7

File tree

2 files changed

+70
-46
lines changed

2 files changed

+70
-46
lines changed

quantile_forest/_quantile_forest.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ def fit(self, X, y, sample_weight=None, sparse_pickle=False):
155155
): False,
156156
}
157157
if validate_data is None:
158-
self._validate_data(**validation_params)
158+
X, y = self._validate_data(**validation_params)
159159
else:
160160
X, y = validate_data(self, **validation_params)
161161

@@ -853,7 +853,7 @@ def quantile_ranks(
853853
): False,
854854
}
855855
if validate_data is None:
856-
self._validate_data(**validation_params)
856+
X, y = self._validate_data(**validation_params)
857857
else:
858858
X, y = validate_data(self, **validation_params)
859859

quantile_forest/tests/test_quantile_forest.py

+68-44
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
assert_almost_equal,
1919
assert_array_almost_equal,
2020
assert_array_equal,
21-
assert_raises,
2221
)
2322
from sklearn.utils.validation import check_is_fitted, check_random_state
2423

@@ -264,7 +263,8 @@ def check_predict_quantiles_toy(name):
264263
weighted_leaves=False,
265264
oob_score=oob_score,
266265
)
267-
assert_raises(AssertionError, assert_allclose, y_pred1, y_pred2)
266+
with pytest.raises(AssertionError):
267+
assert_allclose(y_pred1, y_pred2)
268268

269269
# Check that leaf weighting without weighted quantiles does nothing.
270270
y_pred1 = est.predict(
@@ -579,8 +579,10 @@ def check_predict_quantiles(
579579
assert np.any(y_pred_1 != y_pred_2)
580580

581581
# Check error if invalid quantiles.
582-
assert_raises(ValueError, est.predict, X_test, -0.01)
583-
assert_raises(ValueError, est.predict, X_test, 1.01)
582+
with pytest.raises(ValueError):
583+
est.predict(X_test, -0.01)
584+
with pytest.raises(ValueError):
585+
est.predict(X_test, 1.01)
584586

585587

586588
@pytest.mark.parametrize("name", FOREST_REGRESSORS)
@@ -720,7 +722,8 @@ def check_quantile_ranks(name):
720722

721723
# Check error if training and test number of targets are not equal.
722724
est.fit(X_train, y_train[:, 0]) # training target size = 1
723-
assert_raises(ValueError, est.quantile_ranks, X_test, y_test[:, :2]) # test target size = 2
725+
with pytest.raises(ValueError):
726+
est.quantile_ranks(X_test, y_test[:, :2]) # test target size = 2
724727

725728

726729
@pytest.mark.parametrize("name", FOREST_REGRESSORS)
@@ -773,10 +776,12 @@ def check_proximity_counts(name):
773776
assert_array_equal([len(p) for p in proximities], [len(e) for e in expected])
774777

775778
# Check error if `max_proximities` < 1.
776-
assert_raises(ValueError, est.proximity_counts, X, max_proximities=0)
779+
with pytest.raises(ValueError):
780+
est.proximity_counts(X, max_proximities=0)
777781

778782
# Check error if `max_proximities` is a float.
779-
assert_raises(ValueError, est.proximity_counts, X, max_proximities=1.5)
783+
with pytest.raises(ValueError):
784+
est.proximity_counts(X, max_proximities=1.5)
780785

781786
# Check that proximity counts match expected counts without splits.
782787
est = ForestRegressor(
@@ -869,14 +874,25 @@ def check_max_samples_leaf(name):
869874
for param_validation in [True, False]:
870875
est = ForestRegressor(n_estimators=1, max_samples_leaf=max_samples_leaf)
871876
est.param_validation = param_validation
872-
assert_raises(ValueError, est.fit, X, y)
877+
with pytest.raises(ValueError):
878+
est.fit(X, y)
873879
est.max_samples_leaf = max_samples_leaf
874-
assert_raises(ValueError, est._get_y_train_leaves, X, y)
880+
with pytest.raises(ValueError):
881+
est._get_y_train_leaves(X, y)
875882

876883

877884
@pytest.mark.parametrize("name", FOREST_REGRESSORS)
878885
def test_max_samples_leaf(name):
879886
check_max_samples_leaf(name)
887+
"""
888+
Test that `max_samples_leaf` is correctly passed to the `fit` method,
889+
and that it results in the correct maximum leaf size.
890+
891+
Parameters
892+
----------
893+
name : str
894+
The name of the forest regressor to test.
895+
"""
880896

881897

882898
def check_oob_samples(name):
@@ -1065,16 +1081,16 @@ def check_predict_oob(
10651081
assert_allclose(y_pred_oob1, y_pred_oob2)
10661082

10671083
# Check error if OOB score without `indices` do not match training count.
1068-
assert_raises(ValueError, est.predict, X[:1], oob_score=True)
1084+
with pytest.raises(ValueError):
1085+
est.predict(X[:1], oob_score=True)
10691086

10701087
# Check error if OOB score with `indices` do not match samples count.
1071-
assert_raises(
1072-
ValueError,
1073-
est.predict,
1074-
X,
1075-
oob_score=True,
1076-
indices=-np.ones(len(X) - 1),
1077-
)
1088+
with pytest.raises(ValueError):
1089+
est.predict(
1090+
X,
1091+
oob_score=True,
1092+
indices=-np.ones(len(X) - 1),
1093+
)
10781094

10791095
# Check warning if not enough estimators.
10801096
with np.errstate(divide="ignore", invalid="ignore"):
@@ -1106,30 +1122,28 @@ def check_predict_oob(
11061122
# Check error if no bootstrapping.
11071123
est = ForestRegressor(n_estimators=1, bootstrap=False)
11081124
est.fit(X, y)
1109-
assert_raises(
1110-
ValueError,
1111-
est.predict,
1112-
X,
1113-
weighted_quantile=weighted_quantile,
1114-
aggregate_leaves_first=aggregate_leaves_first,
1115-
oob_score=True,
1116-
)
1125+
with pytest.raises(ValueError):
1126+
est.predict(
1127+
X,
1128+
weighted_quantile=weighted_quantile,
1129+
aggregate_leaves_first=aggregate_leaves_first,
1130+
oob_score=True,
1131+
)
11171132
with warnings.catch_warnings():
11181133
warnings.simplefilter("ignore", UserWarning)
11191134
assert np.all(est._get_unsampled_indices(est.estimators_[0]) == np.array([]))
11201135

11211136
# Check error if number of scoring and training samples are different.
11221137
est = ForestRegressor(n_estimators=1, bootstrap=True)
11231138
est.fit(X, y)
1124-
assert_raises(
1125-
ValueError,
1126-
est.predict,
1127-
X[:1],
1128-
y[:1],
1129-
weighted_quantile=weighted_quantile,
1130-
aggregate_leaves_first=aggregate_leaves_first,
1131-
oob_score=True,
1132-
)
1139+
with pytest.raises(ValueError):
1140+
est.predict(
1141+
X[:1],
1142+
y[:1],
1143+
weighted_quantile=weighted_quantile,
1144+
aggregate_leaves_first=aggregate_leaves_first,
1145+
oob_score=True,
1146+
)
11331147

11341148

11351149
@pytest.mark.parametrize("name", FOREST_REGRESSORS)
@@ -1200,12 +1214,14 @@ def check_quantile_ranks_oob(name):
12001214
# Check error if no bootstrapping.
12011215
est = ForestRegressor(n_estimators=1, bootstrap=False)
12021216
est.fit(X, y)
1203-
assert_raises(ValueError, est.quantile_ranks, X, y, oob_score=True)
1217+
with pytest.raises(ValueError):
1218+
est.quantile_ranks(X, y, oob_score=True)
12041219

12051220
# Check error if number of scoring and training samples are different.
12061221
est = ForestRegressor(n_estimators=1, bootstrap=True)
12071222
est.fit(X, y)
1208-
assert_raises(ValueError, est.quantile_ranks, X[:1], y[:1], oob_score=True)
1223+
with pytest.raises(ValueError):
1224+
est.quantile_ranks(X[:1], y[:1], oob_score=True)
12091225

12101226

12111227
@pytest.mark.parametrize("name", FOREST_REGRESSORS)
@@ -1284,7 +1300,8 @@ def check_proximity_counts_oob(name):
12841300
# Check error if no bootstrapping.
12851301
est = ForestRegressor(n_estimators=1, max_samples_leaf=None, bootstrap=False)
12861302
est.fit(X, y)
1287-
assert_raises(ValueError, est.proximity_counts, X, oob_score=True)
1303+
with pytest.raises(ValueError):
1304+
est.proximity_counts(X, oob_score=True)
12881305

12891306

12901307
@pytest.mark.parametrize("name", FOREST_REGRESSORS)
@@ -1357,7 +1374,8 @@ def check_monotonic_constraints(name, max_samples_leaf):
13571374
max_leaf_nodes=n_samples_train,
13581375
bootstrap=True,
13591376
)
1360-
assert_raises(ValueError, est.fit, X_train, y_train)
1377+
with pytest.raises(ValueError):
1378+
est.fit(X_train, y_train)
13611379

13621380

13631381
@pytest.mark.parametrize("name", FOREST_REGRESSORS)
@@ -1466,8 +1484,10 @@ def test_calc_quantile():
14661484
assert actual1 != actual2
14671485

14681486
# Check error if invalid parameters.
1469-
assert_raises(TypeError, calc_quantile, [1, 2], 0.5)
1470-
assert_raises(TypeError, calc_quantile, [1, 2], [0.5], interpolation=None)
1487+
with pytest.raises(TypeError):
1488+
calc_quantile([1, 2], 0.5)
1489+
with pytest.raises(TypeError):
1490+
calc_quantile([1, 2], [0.5], interpolation=None)
14711491

14721492

14731493
def test_calc_weighted_quantile():
@@ -1585,8 +1605,10 @@ def _dicts_to_input_pairs(input_dicts):
15851605
assert actual1 != actual2
15861606

15871607
# Check error if invalid parameters.
1588-
assert_raises(TypeError, calc_weighted_quantile, [1, 2], [1, 1], 0.5)
1589-
assert_raises(TypeError, calc_weighted_quantile, [1, 2], [1, 1], [0.5], interpolation=None)
1608+
with pytest.raises(TypeError):
1609+
calc_weighted_quantile([1, 2], [1, 1], 0.5)
1610+
with pytest.raises(TypeError):
1611+
calc_weighted_quantile([1, 2], [1, 1], [0.5], interpolation=None)
15901612

15911613

15921614
def test_calc_quantile_rank():
@@ -1635,5 +1657,7 @@ def test_calc_quantile_rank():
16351657
assert actual1 != actual2
16361658

16371659
# Check error if invalid parameters.
1638-
assert_raises(TypeError, calc_quantile_rank, [1, 2], [1])
1639-
assert_raises(TypeError, calc_quantile_rank, [1, 2], float(1), kind=None)
1660+
with pytest.raises(TypeError):
1661+
calc_quantile_rank([1, 2], [1])
1662+
with pytest.raises(TypeError):
1663+
calc_quantile_rank([1, 2], float(1), kind=None)

0 commit comments

Comments
 (0)