Skip to content

Commit

Permalink
Update test_all_regressors.py
Browse files Browse the repository at this point in the history
  • Loading branch information
fkiraly committed Sep 14, 2024
1 parent 8275d72 commit 871b711
Showing 1 changed file with 9 additions and 6 deletions.
15 changes: 9 additions & 6 deletions skpro/regression/tests/test_all_regressors.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,17 +186,20 @@ def test_online_update(self, object_instance):
regressor.fit(X_fit, y_fit)

regressor.update(X_upd1, y_upd1)
y_pred1 = regressor.predict(X_upd2)
y_pred1 = regressor.predict(X_upd1)
y_pred2 = regressor.predict(X_upd2)

# check predict output contract
assert isinstance(y_pred1, pd.DataFrame)
assert isinstance(y_pred2, pd.DataFrame)
assert (y_pred1.index == X_upd1.index).all()
assert (y_pred1.columns == y_fit.columns).all()
assert (y_pred2.index == X_upd2.index).all()
assert (y_pred2.columns == y_fit.columns).all()

regressor.update(X_upd2, y_upd2)
y_pred2 = regressor.predict(X_test)
y_pred_test = regressor.predict(X_test)

# check predict output contract
assert isinstance(y_pred2, pd.DataFrame)
assert (y_pred2.index == X_test.index).all()
assert (y_pred2.columns == y_fit.columns).all()
assert isinstance(y_pred_test, pd.DataFrame)
assert (y_pred_test.index == X_test.index).all()
assert (y_pred_test.columns == y_fit.columns).all()

0 comments on commit 871b711

Please sign in to comment.