Skip to content

Commit

Permalink
Update serialization unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
reidjohnson committed Sep 10, 2024
1 parent a988bae commit 6c8a288
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions quantile_forest/tests/test_quantile_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -1334,7 +1334,7 @@ def test_monotonic_constraints(name, max_samples_leaf):
check_monotonic_constraints(name, max_samples_leaf)


def check_serialization(name, sparse_pickle, multi_target):
def check_serialization(name, sparse_pickle, monotonic_cst, multi_target):
# Check model serialization/deserialization.

X = X_california
Expand All @@ -1344,9 +1344,14 @@ def check_serialization(name, sparse_pickle, multi_target):
else:
y = y_california

if monotonic_cst and not multi_target:
monotonic_cst = [1] * X.shape[1]
else:
monotonic_cst = None

ForestRegressor = FOREST_REGRESSORS[name]

est = ForestRegressor(n_estimators=10, random_state=0)
est = ForestRegressor(n_estimators=10, monotonic_cst=monotonic_cst, random_state=0)
est.fit(X, y, sparse_pickle=sparse_pickle)

dumped = pickle.dumps(est)
Expand All @@ -1358,9 +1363,10 @@ def check_serialization(name, sparse_pickle, multi_target):

@pytest.mark.parametrize("name", FOREST_REGRESSORS)
@pytest.mark.parametrize("sparse_pickle", [False, True])
@pytest.mark.parametrize("monotonic_cst", [False, True])
@pytest.mark.parametrize("multi_target", [False, True])
def test_serialization(name, sparse_pickle, multi_target):
check_serialization(name, sparse_pickle, multi_target)
def test_serialization(name, sparse_pickle, monotonic_cst, multi_target):
check_serialization(name, sparse_pickle, monotonic_cst, multi_target)


def test_calc_quantile():
Expand Down

0 comments on commit 6c8a288

Please sign in to comment.