-
Notifications
You must be signed in to change notification settings - Fork 45
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
119 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,7 @@ | ||
"""Meta-algorithms to build online regression models.""" | ||
# copyright: skpro developers, BSD-3-Clause License (see LICENSE file) | ||
|
||
from skpro.regression.online._dont_refit import OnlineDontRefit | ||
from skpro.regression.online._refit import OnlineRefit | ||
|
||
__all__ = ["OnlineRefit"] | ||
__all__ = ["OnlineDontRefit", "OnlineRefit"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
"""Meta-strategy for online learning: turn off online update.""" | ||
|
||
__author__ = ["fkiraly"] | ||
__all__ = ["OnlineDontRefit"] | ||
|
||
from skpro.regression.base import _DelegatedProbaRegressor | ||
|
||
|
||
class OnlineDontRefit(_DelegatedProbaRegressor): | ||
"""Simple online regression strategy, turns off any refitting. | ||
In ``fit``, behaves like the wrapped regressor. | ||
In ``update``, does nothing, overriding any other logic. | ||
This strategy is useful when the wrapped regressor is already an online regressor, | ||
to create a "no-op" online regressor for comparison. | ||
Parameters | ||
---------- | ||
estimator : skpro regressor, descendant of BaseProbaRegressor | ||
regressor to be update-refitted on all data, blueprint | ||
Attributes | ||
---------- | ||
estimator_ : skpro regressor, descendant of BaseProbaRegressor | ||
clone of the regressor passed in the constructor, fitted on all data | ||
""" | ||
|
||
_tags = {"capability:online": False} | ||
|
||
def __init__(self, estimator): | ||
self.estimator = estimator | ||
|
||
super().__init__() | ||
|
||
tags_to_clone = [ | ||
"capability:missing", | ||
"capability:survival", | ||
] | ||
self.clone_tags(estimator, tags_to_clone) | ||
|
||
def _update(self, X, y, C=None): | ||
"""Update regressor with new batch of training data. | ||
State required: | ||
Requires state to be "fitted". | ||
Writes to self: | ||
Updates fitted model attributes ending in "_". | ||
Parameters | ||
---------- | ||
X : pandas DataFrame | ||
feature instances to fit regressor to | ||
y : pandas DataFrame, must be same length as X | ||
labels to fit regressor to | ||
C : pd.DataFrame, optional (default=None) | ||
censoring information for survival analysis, | ||
should have same column name as y, same length as X and y | ||
should have entries 0 and 1 (float or int) | ||
0 = uncensored, 1 = (right) censored | ||
if None, all observations are assumed to be uncensored | ||
Can be passed to any probabilistic regressor, | ||
but is ignored if capability:survival tag is False. | ||
Returns | ||
------- | ||
self : reference to self | ||
""" | ||
return self | ||
|
||
@classmethod | ||
def get_test_params(cls, parameter_set="default"): | ||
"""Return testing parameter settings for the estimator. | ||
Parameters | ||
---------- | ||
parameter_set : str, default="default" | ||
Name of the set of test parameters to return, for use in tests. If no | ||
special parameters are defined for a value, will return `"default"` set. | ||
Returns | ||
------- | ||
params : dict or list of dict, default = {} | ||
Parameters to create testing instances of the class | ||
Each dict are parameters to construct an "interesting" test instance, i.e., | ||
`MyClass(**params)` or `MyClass(**params[i])` creates a valid test instance. | ||
`create_test_instance` uses the first (or only) dictionary in `params` | ||
""" | ||
from sklearn.linear_model import LinearRegression | ||
|
||
from skpro.regression.residual import ResidualDouble | ||
from skpro.survival.coxph import CoxPH | ||
from skpro.utils.validation._dependencies import _check_estimator_deps | ||
|
||
regressor = ResidualDouble(LinearRegression()) | ||
|
||
params = [{"estimator": regressor}] | ||
|
||
if _check_estimator_deps(CoxPH, severity="none"): | ||
coxph = CoxPH() | ||
params.append({"estimator": coxph}) | ||
|
||
return params |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters