Skip to content

Commit

Permalink
Fix multi-target serialization
Browse files Browse the repository at this point in the history
  • Loading branch information
reidjohnson committed Sep 10, 2024
1 parent dfb9a6f commit 19caeb7
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 6 deletions.
2 changes: 1 addition & 1 deletion quantile_forest/_quantile_forest_fast.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -540,7 +540,7 @@ cdef class QuantileForest:
d = {}
if self.sparse_pickle:
matrix1 = kwargs["y_train_leaves"]
reshape1 = (matrix1.shape[2], matrix1.shape[0] * matrix1.shape[1] * matrix1.shape[2])
reshape1 = (matrix1.shape[3], matrix1.shape[0] * matrix1.shape[1] * matrix1.shape[2])
d["shape1"] = matrix1.shape
d["matrix1"] = sparse.csc_matrix(matrix1.reshape(reshape1))

Expand Down
21 changes: 16 additions & 5 deletions quantile_forest/tests/test_quantile_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -1334,15 +1334,24 @@ def test_monotonic_constraints(name, max_samples_leaf):
check_monotonic_constraints(name, max_samples_leaf)


def check_serialization(name, sparse_pickle):
def check_serialization(name, sparse_pickle, monotonic_cst, multi_target):
# Check model serialization/deserialization.

X = X_california
y = y_california

if multi_target:
y = np.vstack([y_california, y_california]).T
else:
y = y_california

if monotonic_cst and not multi_target:
monotonic_cst = [1] * X.shape[1]
else:
monotonic_cst = None

ForestRegressor = FOREST_REGRESSORS[name]

est = ForestRegressor(n_estimators=10, random_state=0)
est = ForestRegressor(n_estimators=10, monotonic_cst=monotonic_cst, random_state=0)
est.fit(X, y, sparse_pickle=sparse_pickle)

dumped = pickle.dumps(est)
Expand All @@ -1354,8 +1363,10 @@ def check_serialization(name, sparse_pickle):

@pytest.mark.parametrize("name", FOREST_REGRESSORS)
@pytest.mark.parametrize("sparse_pickle", [False, True])
def test_serialization(name, sparse_pickle):
check_serialization(name, sparse_pickle)
@pytest.mark.parametrize("monotonic_cst", [False, True])
@pytest.mark.parametrize("multi_target", [False, True])
def test_serialization(name, sparse_pickle, monotonic_cst, multi_target):
check_serialization(name, sparse_pickle, monotonic_cst, multi_target)


def test_calc_quantile():
Expand Down

0 comments on commit 19caeb7

Please sign in to comment.