Skip to content

Commit

Permalink
dont refit
Browse files Browse the repository at this point in the history
  • Loading branch information
fkiraly committed Sep 27, 2024
1 parent 871b711 commit c3d9fa1
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 10 deletions.
20 changes: 12 additions & 8 deletions docs/source/api_reference/regression.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,6 @@ Composition

Pipeline

.. currentmodule:: skpro.regression.online

.. autosummary::
:toctree: auto_generated/
:template: class.rst

OnlineRefit

Model selection and tuning
--------------------------

Expand All @@ -52,6 +44,18 @@ Model selection and tuning

evaluate

Online learning
---------------

.. currentmodule:: skpro.regression.online

.. autosummary::
:toctree: auto_generated/
:template: class.rst

OnlineRefit
OnlineDontRefit

Reduction - adding ``predict_proba``
------------------------------------

Expand Down
3 changes: 2 additions & 1 deletion skpro/regression/online/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Meta-algorithms to build online regression models."""
# copyright: skpro developers, BSD-3-Clause License (see LICENSE file)

from skpro.regression.online._dont_refit import OnlineDontRefit
from skpro.regression.online._refit import OnlineRefit

__all__ = ["OnlineRefit"]
__all__ = ["OnlineDontRefit", "OnlineRefit"]
104 changes: 104 additions & 0 deletions skpro/regression/online/_dont_refit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
"""Meta-strategy for online learning: turn off online update."""

__author__ = ["fkiraly"]
__all__ = ["OnlineDontRefit"]

from skpro.regression.base import _DelegatedProbaRegressor


class OnlineDontRefit(_DelegatedProbaRegressor):
    """Simple online regression strategy, turns off any refitting.

    In ``fit``, behaves like the wrapped regressor.
    In ``update``, does nothing, overriding any other logic.

    This strategy is useful when the wrapped regressor is already an online
    regressor, to create a "no-op" online regressor for comparison.

    Parameters
    ----------
    estimator : skpro regressor, descendant of BaseProbaRegressor
        regressor to wrap, blueprint; it is fitted in ``fit``,
        and its own update logic is never invoked by this wrapper

    Attributes
    ----------
    estimator_ : skpro regressor, descendant of BaseProbaRegressor
        clone of the regressor passed in the constructor,
        all delegated methods (``fit``, ``predict`` etc) act on this clone
    """

    _tags = {"capability:online": False}

    def __init__(self, estimator):
        self.estimator = estimator

        super().__init__()

        # the wrapper can handle missing data / censoring iff the wrapped one can
        tags_to_clone = [
            "capability:missing",
            "capability:survival",
        ]
        self.clone_tags(estimator, tags_to_clone)

        # The delegator base class forwards calls to ``estimator_``; this class
        # has no ``_fit`` of its own that would create it, so the delegate must
        # exist before ``fit`` is called. A clone keeps the blueprint
        # ``estimator`` untouched.
        self.estimator_ = estimator.clone()

    def _update(self, X, y, C=None):
        """Update regressor with new batch of training data - here, a no-op.

        All arguments are ignored, the fitted state is left unchanged.

        State required:
            Requires state to be "fitted".

        Writes to self:
            Nothing.

        Parameters
        ----------
        X : pandas DataFrame
            feature instances of the new batch, ignored
        y : pandas DataFrame, must be same length as X
            labels of the new batch, ignored
        C : pd.DataFrame, optional (default=None)
            censoring information for survival analysis, ignored
            should have same column name as y, same length as X and y
            should have entries 0 and 1 (float or int)
            0 = uncensored, 1 = (right) censored
            if None, all observations are assumed to be uncensored

        Returns
        -------
        self : reference to self
        """
        return self

    @classmethod
    def get_test_params(cls, parameter_set="default"):
        """Return testing parameter settings for the estimator.

        Parameters
        ----------
        parameter_set : str, default="default"
            Name of the set of test parameters to return, for use in tests.
            If no special parameters are defined for a value,
            will return ``"default"`` set.

        Returns
        -------
        params : dict or list of dict, default = {}
            Parameters to create testing instances of the class
            Each dict are parameters to construct an "interesting" test
            instance, i.e., ``MyClass(**params)`` or ``MyClass(**params[i])``
            creates a valid test instance.
            ``create_test_instance`` uses the first (or only) dictionary
            in ``params``
        """
        # imports are local so that importing this module stays lightweight
        from sklearn.linear_model import LinearRegression

        from skpro.regression.residual import ResidualDouble
        from skpro.survival.coxph import CoxPH
        from skpro.utils.validation._dependencies import _check_estimator_deps

        regressor = ResidualDouble(LinearRegression())

        params = [{"estimator": regressor}]

        # the survival test case is only added if CoxPH's dependencies are met
        if _check_estimator_deps(CoxPH, severity="none"):
            coxph = CoxPH()
            params.append({"estimator": coxph})

        return params
2 changes: 1 addition & 1 deletion skpro/regression/online/_refit.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Bagging probabilistic regressors."""
"""Meta-strategy for online learning: refit on full data."""

__author__ = ["fkiraly"]
__all__ = ["OnlineRefit"]
Expand Down

0 comments on commit c3d9fa1

Please sign in to comment.