diff --git a/skpro/regression/tests/test_all_regressors.py b/skpro/regression/tests/test_all_regressors.py index 9b971866..e39eb33e 100644 --- a/skpro/regression/tests/test_all_regressors.py +++ b/skpro/regression/tests/test_all_regressors.py @@ -186,17 +186,20 @@ def test_online_update(self, object_instance): regressor.fit(X_fit, y_fit) regressor.update(X_upd1, y_upd1) - y_pred1 = regressor.predict(X_upd2) + y_pred1 = regressor.predict(X_upd1) + y_pred2 = regressor.predict(X_upd2) # check predict output contract - assert isinstance(y_pred1, pd.DataFrame) + assert isinstance(y_pred2, pd.DataFrame) assert (y_pred1.index == X_upd1.index).all() assert (y_pred1.columns == y_fit.columns).all() + assert (y_pred2.index == X_upd2.index).all() + assert (y_pred2.columns == y_fit.columns).all() regressor.update(X_upd2, y_upd2) - y_pred2 = regressor.predict(X_test) + y_pred_test = regressor.predict(X_test) # check predict output contract - assert isinstance(y_pred2, pd.DataFrame) - assert (y_pred2.index == X_test.index).all() - assert (y_pred2.columns == y_fit.columns).all() + assert isinstance(y_pred_test, pd.DataFrame) + assert (y_pred_test.index == X_test.index).all() + assert (y_pred_test.columns == y_fit.columns).all()