diff --git a/docs/source/api_reference/regression.rst b/docs/source/api_reference/regression.rst index 927f148d1..2855e6b9d 100644 --- a/docs/source/api_reference/regression.rst +++ b/docs/source/api_reference/regression.rst @@ -44,6 +44,18 @@ Model selection and tuning evaluate +Online learning +--------------- + +.. currentmodule:: skpro.regression.online + +.. autosummary:: + :toctree: auto_generated/ + :template: class.rst + + OnlineRefit + OnlineDontRefit + Reduction - adding ``predict_proba`` ------------------------------------ diff --git a/skpro/registry/_tags.py b/skpro/registry/_tags.py index 53fd2f4e0..d35470ca0 100644 --- a/skpro/registry/_tags.py +++ b/skpro/registry/_tags.py @@ -122,6 +122,12 @@ "bool", "whether estimator supports missing values", ), + ( + "capability:online", + "regressor_proba", + "bool", + "whether estimator supports online updates via update", + ), ( "X_inner_mtype", "regressor_proba", diff --git a/skpro/regression/base/_base.py b/skpro/regression/base/_base.py index bbd205ad8..91def4db3 100644 --- a/skpro/regression/base/_base.py +++ b/skpro/regression/base/_base.py @@ -33,6 +33,7 @@ class BaseProbaRegressor(BaseEstimator): "capability:survival": False, "capability:multioutput": False, "capability:missing": True, + "capability:online": False, "X_inner_mtype": "pd_DataFrame_Table", "y_inner_mtype": "pd_DataFrame_Table", "C_inner_mtype": "pd_DataFrame_Table", @@ -136,6 +137,74 @@ def _fit(self, X, y, C=None): """ raise NotImplementedError + def update(self, X, y, C=None): + """Update regressor with a new batch of training data. + + Only estimators with the ``capability:online`` tag (value ``True``) + provide this method, otherwise the method ignores the call and + discards the data passed. + + State required: + Requires state to be "fitted". + + Writes to self: + Updates fitted model attributes ending in "_". + + Parameters + ---------- + X : pandas DataFrame + feature instances to fit regressor to + y : pd.DataFrame, must be same length as X + labels to fit regressor to + C : ignored, optional (default=None) + censoring information for survival analysis + All probabilistic regressors assume data to be uncensored + + Returns + ------- + self : reference to self + """ + capa_online = self.get_tag("capability:online") + capa_surv = self.get_tag("capability:survival") + + if not capa_online: + return self + + check_ret = self._check_X_y(X, y, C, return_metadata=True) + + # get inner X, y, C + X_inner = check_ret["X_inner"] + y_inner = check_ret["y_inner"] + if capa_surv: + C_inner = check_ret["C_inner"] + + if not capa_surv: + return self._update(X_inner, y_inner) + else: + return self._update(X_inner, y_inner, C=C_inner) + + def _update(self, X, y, C=None): + """Update regressor with a new batch of training data. + + State required: + Requires state to be "fitted". + + Writes to self: + Updates fitted model attributes ending in "_". + + Parameters + ---------- + X : pandas DataFrame + feature instances to fit regressor to + y : pandas DataFrame, must be same length as X + labels to fit regressor to + + Returns + ------- + self : reference to self + """ + raise NotImplementedError + def predict(self, X): """Predict labels for data from features. diff --git a/skpro/regression/base/_delegate.py b/skpro/regression/base/_delegate.py index 52d31ab1d..2cd581fc2 100644 --- a/skpro/regression/base/_delegate.py +++ b/skpro/regression/base/_delegate.py @@ -67,6 +67,38 @@ def _fit(self, X, y, C=None): estimator.fit(X=X, y=y, C=C) return self + def _update(self, X, y, C=None): + """Update regressor with a new batch of training data. + + State required: + Requires state to be "fitted" = self.is_fitted=True + + Writes to self: + Updates fitted model attributes ending in "_". + + Parameters + ---------- + X : pandas DataFrame + feature instances to fit regressor to + y : pd.DataFrame, must be same length as X + labels to fit regressor to + C : pd.DataFrame, optional (default=None) + censoring information for survival analysis, + should have same column name as y, same length as X and y + should have entries 0 and 1 (float or int) + 0 = uncensored, 1 = (right) censored + if None, all observations are assumed to be uncensored + Can be passed to any probabilistic regressor, + but is ignored if capability:survival tag is False. + + Returns + ------- + self : reference to self + """ + estimator = self._get_delegate() + estimator.update(X=X, y=y, C=C) + return self + def _predict(self, X): """Predict labels for data from features. diff --git a/skpro/regression/compose/_pipeline.py b/skpro/regression/compose/_pipeline.py index 85466d876..c06be454d 100644 --- a/skpro/regression/compose/_pipeline.py +++ b/skpro/regression/compose/_pipeline.py @@ -336,7 +336,11 @@ def __init__(self, steps): super().__init__() - tags_to_clone = ["capability:multioutput", "capability:survival"] + tags_to_clone = [ + "capability:multioutput", + "capability:survival", + "capability:online", + ] self.clone_tags(self.regressor_, tags_to_clone) @property @@ -427,6 +431,38 @@ def _fit(self, X, y, C=None): return self + def _update(self, X, y, C=None): + """Update regressor with a new batch of training data. + + State required: + Requires state to be "fitted" = self.is_fitted=True + + Writes to self: + Updates fitted model attributes ending in "_". + + Parameters + ---------- + X : pandas DataFrame + feature instances to fit regressor to + y : pd.DataFrame, must be same length as X + labels to fit regressor to + C : pd.DataFrame, optional (default=None) + censoring information for survival analysis, + should have same column name as y, same length as X and y + should have entries 0 and 1 (float or int) + 0 = uncensored, 1 = (right) censored + if None, all observations are assumed to be uncensored + Can be passed to any probabilistic regressor, + but is ignored if capability:survival tag is False. + + Returns + ------- + self : reference to self + """ + X = self._transform(X) + self.regressor_.update(X=X, y=y, C=C) + return self + def _predict(self, X): """Predict labels for data from features. diff --git a/skpro/regression/online/__init__.py b/skpro/regression/online/__init__.py new file mode 100644 index 000000000..90d20a346 --- /dev/null +++ b/skpro/regression/online/__init__.py @@ -0,0 +1,7 @@ +"""Meta-algorithms to build online regression models.""" +# copyright: skpro developers, BSD-3-Clause License (see LICENSE file) + +from skpro.regression.online._dont_refit import OnlineDontRefit +from skpro.regression.online._refit import OnlineRefit + +__all__ = ["OnlineDontRefit", "OnlineRefit"] diff --git a/skpro/regression/online/_dont_refit.py b/skpro/regression/online/_dont_refit.py new file mode 100644 index 000000000..6ae70c4fd --- /dev/null +++ b/skpro/regression/online/_dont_refit.py @@ -0,0 +1,106 @@ +"""Meta-strategy for online learning: turn off online update.""" + +__author__ = ["fkiraly"] +__all__ = ["OnlineDontRefit"] + +from skpro.regression.base import _DelegatedProbaRegressor + + +class OnlineDontRefit(_DelegatedProbaRegressor): + """Simple online regression strategy, turns off any refitting. + + In ``fit``, behaves like the wrapped regressor. + In ``update``, does nothing, overriding any other logic. + + This strategy is useful when the wrapped regressor is already an online regressor, + to create a "no-op" online regressor for comparison. + + Parameters + ---------- + estimator : skpro regressor, descendant of BaseProbaRegressor + regressor to be update-refitted on all data, blueprint + + Attributes + ---------- + estimator_ : skpro regressor, descendant of BaseProbaRegressor + clone of the regressor passed in the constructor, fitted on all data + """ + + _tags = {"capability:online": False} + + def __init__(self, estimator): + self.estimator = estimator + + super().__init__() + + tags_to_clone = [ + "capability:missing", + "capability:survival", + ] + self.clone_tags(estimator, tags_to_clone) + + self.estimator_ = self.estimator.clone() + + def _update(self, X, y, C=None): + """Update regressor with new batch of training data. + + State required: + Requires state to be "fitted". + + Writes to self: + Updates fitted model attributes ending in "_". + + Parameters + ---------- + X : pandas DataFrame + feature instances to fit regressor to + y : pandas DataFrame, must be same length as X + labels to fit regressor to + C : pd.DataFrame, optional (default=None) + censoring information for survival analysis, + should have same column name as y, same length as X and y + should have entries 0 and 1 (float or int) + 0 = uncensored, 1 = (right) censored + if None, all observations are assumed to be uncensored + Can be passed to any probabilistic regressor, + but is ignored if capability:survival tag is False. + + Returns + ------- + self : reference to self + """ + return self + + @classmethod + def get_test_params(cls, parameter_set="default"): + """Return testing parameter settings for the estimator. + + Parameters + ---------- + parameter_set : str, default="default" + Name of the set of test parameters to return, for use in tests. If no + special parameters are defined for a value, will return `"default"` set. + + Returns + ------- + params : dict or list of dict, default = {} + Parameters to create testing instances of the class + Each dict are parameters to construct an "interesting" test instance, i.e., + `MyClass(**params)` or `MyClass(**params[i])` creates a valid test instance. + `create_test_instance` uses the first (or only) dictionary in `params` + """ + from sklearn.linear_model import LinearRegression + + from skpro.regression.residual import ResidualDouble + from skpro.survival.coxph import CoxPH + from skpro.utils.validation._dependencies import _check_estimator_deps + + regressor = ResidualDouble(LinearRegression()) + + params = [{"estimator": regressor}] + + if _check_estimator_deps(CoxPH, severity="none"): + coxph = CoxPH() + params.append({"estimator": coxph}) + + return params diff --git a/skpro/regression/online/_refit.py b/skpro/regression/online/_refit.py new file mode 100644 index 000000000..8685ec17e --- /dev/null +++ b/skpro/regression/online/_refit.py @@ -0,0 +1,181 @@ +"""Meta-strategy for online learning: refit on full data.""" + +__author__ = ["fkiraly"] +__all__ = ["OnlineRefit"] + +import pandas as pd + +from skpro.regression.base import _DelegatedProbaRegressor + + +class OnlineRefit(_DelegatedProbaRegressor): + """Simple online regression strategy, by refitting the regressor on all data. + + In ``fit`` and ``update``, remembers all data. + In ``update``, refits the regressor on all data seen so far. + + Caveat: data indices are reset to RangeIndex internally, even if some indices + passed in ``fit`` and ``update`` overlap. + + Parameters + ---------- + estimator : skpro regressor, descendant of BaseProbaRegressor + regressor to be update-refitted on all data, blueprint + + Attributes + ---------- + estimator_ : skpro regressor, descendant of BaseProbaRegressor + clone of the regressor passed in the constructor, fitted on all data + """ + + _tags = {"capability:online": True} + + def __init__(self, estimator): + self.estimator = estimator + + super().__init__() + + tags_to_clone = [ + "capability:missing", + "capability:survival", + ] + self.clone_tags(estimator, tags_to_clone) + + def _fit(self, X, y, C=None): + """Fit regressor to training data. + + Writes to self: + Sets fitted model attributes ending in "_". + + Parameters + ---------- + X : pandas DataFrame + feature instances to fit regressor to + y : pandas DataFrame, must be same length as X + labels to fit regressor to + C : pd.DataFrame, optional (default=None) + censoring information for survival analysis, + should have same column name as y, same length as X and y + should have entries 0 and 1 (float or int) + 0 = uncensored, 1 = (right) censored + if None, all observations are assumed to be uncensored + Can be passed to any probabilistic regressor, + but is ignored if capability:survival tag is False. + + Returns + ------- + self : reference to self + """ + estimator = self.estimator.clone() + + estimator.fit(X=X, y=y, C=C) + self.estimator_ = estimator + + # remember data + self._X = X + self._y = y + self._C = C + + return self + + def _update(self, X, y, C=None): + """Update regressor with new batch of training data. + + State required: + Requires state to be "fitted". + + Writes to self: + Updates fitted model attributes ending in "_". + + Parameters + ---------- + X : pandas DataFrame + feature instances to fit regressor to + y : pandas DataFrame, must be same length as X + labels to fit regressor to + C : pd.DataFrame, optional (default=None) + censoring information for survival analysis, + should have same column name as y, same length as X and y + should have entries 0 and 1 (float or int) + 0 = uncensored, 1 = (right) censored + if None, all observations are assumed to be uncensored + Can be passed to any probabilistic regressor, + but is ignored if capability:survival tag is False. + + Returns + ------- + self : reference to self + """ + X_pool = self._update_data(self._X, X) + y_pool = self._update_data(self._y, y) + C_pool = self._update_data(self._C, C) + + estimator = self.estimator.clone() + estimator.fit(X=X_pool, y=y_pool, C=C_pool) + self.estimator_ = estimator + + # remember data + self._X = X_pool + self._y = y_pool + self._C = C_pool + + return self + + def _update_data(self, X, X_new): + """Update data with new batch of training data. + + Treats X_new as data with new indices, even if some indices overlap with X. + + Parameters + ---------- + X : pandas DataFrame + X_new : pandas DataFrame + + Returns + ------- + X_updated : pandas DataFrame + concatenated data, with reset index + """ + if X is None and X_new is None: + return None + if X is None and X_new is not None: + return X_new.reset_index(drop=True) + if X is not None and X_new is None: + return X.reset_index(drop=True) + # else, both X and X_new are not None + X_updated = pd.concat([X, X_new], ignore_index=True) + return X_updated + + @classmethod + def get_test_params(cls, parameter_set="default"): + """Return testing parameter settings for the estimator. + + Parameters + ---------- + parameter_set : str, default="default" + Name of the set of test parameters to return, for use in tests. If no + special parameters are defined for a value, will return `"default"` set. + + Returns + ------- + params : dict or list of dict, default = {} + Parameters to create testing instances of the class + Each dict are parameters to construct an "interesting" test instance, i.e., + `MyClass(**params)` or `MyClass(**params[i])` creates a valid test instance. + `create_test_instance` uses the first (or only) dictionary in `params` + """ + from sklearn.linear_model import LinearRegression + + from skpro.regression.residual import ResidualDouble + from skpro.survival.coxph import CoxPH + from skpro.utils.validation._dependencies import _check_estimator_deps + + regressor = ResidualDouble(LinearRegression()) + + params = [{"estimator": regressor}] + + if _check_estimator_deps(CoxPH, severity="none"): + coxph = CoxPH() + params.append({"estimator": coxph}) + + return params diff --git a/skpro/regression/tests/test_all_regressors.py b/skpro/regression/tests/test_all_regressors.py index 06d991be3..e39eb33ef 100644 --- a/skpro/regression/tests/test_all_regressors.py +++ b/skpro/regression/tests/test_all_regressors.py @@ -166,3 +166,40 @@ def test_pred_quantiles_interval(self, object_instance, alpha): # check predict_quantiles output contract pred_q = regressor.predict_quantiles(X_test, alpha) self._check_predict_quantiles(pred_q, X_test, y_train, alpha) + + def test_online_update(self, object_instance): + """Test online update of regressor.""" + import pandas as pd + from sklearn.datasets import load_diabetes + from sklearn.model_selection import train_test_split + + X, y = load_diabetes(return_X_y=True, as_frame=True) + X = X.iloc[:70] + y = y.iloc[:70] + y = pd.DataFrame(y) + + X_train, X_test, y_train, _ = train_test_split(X, y) + X_fit, X_update, y_fit, y_update = train_test_split(X_train, y_train) + X_upd1, X_upd2, y_upd1, y_upd2 = train_test_split(X_update, y_update) + + regressor = object_instance + regressor.fit(X_fit, y_fit) + + regressor.update(X_upd1, y_upd1) + y_pred1 = regressor.predict(X_upd1) + y_pred2 = regressor.predict(X_upd2) + + # check predict output contract + assert isinstance(y_pred2, pd.DataFrame) + assert (y_pred1.index == X_upd1.index).all() + assert (y_pred1.columns == y_fit.columns).all() + assert (y_pred2.index == X_upd2.index).all() + assert (y_pred2.columns == y_fit.columns).all() + + regressor.update(X_upd2, y_upd2) + y_pred_test = regressor.predict(X_test) + + # check predict output contract + assert isinstance(y_pred_test, pd.DataFrame) + assert (y_pred_test.index == X_test.index).all() + assert (y_pred_test.columns == y_fit.columns).all() diff --git a/skpro/survival/base.py b/skpro/survival/base.py index 403d43808..a083d08e1 100644 --- a/skpro/survival/base.py +++ b/skpro/survival/base.py @@ -47,3 +47,36 @@ def fit(self, X, y, C=None): """ super().fit(X=X, y=y, C=C) return self + + def update(self, X, y, C=None): + """Update regressor with a new batch of training data. + + Only estimators with the ``capability:online`` tag (value ``True``) + provide this method, otherwise the method ignores the call and + discards the data passed. + + State required: + Requires state to be "fitted". + + Writes to self: + Updates fitted model attributes ending in "_". + + Parameters + ---------- + X : pandas DataFrame + feature instances to fit regressor to + y : pd.DataFrame, must be same length as X + labels to fit regressor to + C : pd.DataFrame, optional (default=None) + censoring information for survival analysis, + should have same column name as y, same length as X and y + should have entries 0 and 1 (float or int) + 0 = uncensored, 1 = (right) censored + if None, all observations are assumed to be uncensored + + Returns + ------- + self : reference to self + """ + super().update(X=X, y=y, C=C) + return self