# API for pre-training and fine-tuning

Contributors: @fkiraly

## High-level summary
### The Aim

`sktime` now has a number of estimators that can do pre-training, fine-tuning, global learning, and cross-learning.

However, the current API design for these use cases has a few problems, and issues have repeatedly been opened asking for a rework.
This STEP is about finalizing a good interface for:

* global forecasting
* pre-training
* fine-tuning of foundation models
* zero-shot use of foundation models

References:

* conceptual design issue: https://github.com/sktime/sktime/issues/6580
* umbrella issue on foundation models: https://github.com/sktime/sktime/issues/6177
* newer issue: https://github.com/sktime/sktime/issues/7838
### Requirements

* the design covers the above use cases with a simple interface
* composability: use of sensible pipelines, tuning, etc. should be simple and not require major surgery in current compositors
* downwards and upwards compatibility: the design should not impact current extension contracts
* maintainability: maintaining the framework and estimators with the above capabilities should be simple
### The proposed solution

Our proposed solution adds a new state, and a simple switch.

No new public methods are added beyond this, and signatures of existing methods are not modified.

Estimators get a third state, "pretraining phase", besides unfitted and fitted.

The solution is best illustrated in the basic vignette below.
### Discussion of current solutions

There are multiple current solutions, all of which have problems:

#### Global forecasting

Forecasters inheriting from `_BaseGlobalForecaster`.

A `y` argument is added to the `predict` methods. If it is passed, the `fit` call is interpreted as a pretraining pass.
Problems:

* some models need to know at `fit` time whether the data is for pretraining.
  Examples: global reduction approaches, broadcasting.
* as a general design principle, all `predict` methods would need to be changed to
  allow for addition of the `fit` arguments. This clashes in cases where arguments
  of the same name are present both in `fit` and `predict`, e.g., the `X` in forecasting,
  or all arguments for transformations.
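The argument-name clash can be illustrated with a minimal sketch; the class and argument semantics below are hypothetical stand-ins, not the actual `_BaseGlobalForecaster` signatures:

```python
# Minimal sketch of the signature clash: in forecasting, `X` means
# training-period exogeneous data in `fit`, but horizon-period exogeneous
# data in `predict`. Forwarding the `fit` arguments into `predict` would
# therefore give `X` two meanings in one signature.

class GlobalForecastSketch:
    """Illustrative stand-in, not the actual sktime class."""

    def fit(self, y, X=None, fh=None):
        # X: exogeneous data aligned with the training series y
        self._n_train = len(y)
        return self

    def predict(self, fh, X=None, y=None):
        # X: exogeneous data for the forecasting horizon fh;
        # y: optional new series for global/zero-shot use.
        # Adding the fit-time y and X here collides with these names.
        return [float(self._n_train) for _ in fh]
```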
#### Pre-training foundation models

Foundation models currently come in two different, contradictory forms:

* those that carry out fine-tuning in `fit`
* those that pass the context in `fit`, e.g., zero-shot models

Problems: this is inconsistent, and it does not seem to be possible - without an `__init__` arg that acts as a switch, or without separate classes - to have the same weights in the same class be part of a zero-shot or a fine-tuning algorithm.
## Design: pretraining vignette

This section presents the user facing API. For delineation against current designs:

* no new arguments are added to `predict`-like methods
* a flag is added before or at `fit` to determine whether usage is normal fitting or pre-training
* vignettes that pass this information on are presented below, for discussion

### Basic usage vignette

Illustrated for global forecasting.
```python
y_pretrain, X_pretrain = load_pretrain_data()

f = MyForecasterCls(params)

f.pretrain()

f.fit(y=y_pretrain, X=X_pretrain)
# fh is optional in the pretraining fit, but some models require it

f.pretrain("off")

# usual vignette starts here
y, X = load_data()

f.fit(y, X, fh=[1, 2, 3])

f.predict()
f.predict_interval()
```
|
|
||
| With optional serialization after pre-training: | ||
|
|
||
| ```python | ||
|
|
||
| # optional: serialize | ||
|
|
||
| f.save(checkpoint_name) | ||
|
|
||
| # restart | ||
|
|
||
| f = load(checkpoint_name) | ||
| ``` | ||
### Alternative vignette 1

An alternative idea would be adding an arg to `fit`:

```python
y_pretrain, X_pretrain = load_pretrain_data()

f = MyForecasterCls(params)

f.fit(y=y_pretrain, X=X_pretrain, pretrain=True)

# usual vignette starts here
y, X = load_data()

f.fit(y, X, fh=[1, 2, 3])

f.predict()
f.predict_interval()
```
### Alternative vignette 2

An alternative idea would be adding a new method:

```python
y_pretrain, X_pretrain = load_pretrain_data()

f = MyForecasterCls(params)

f.pretrain(y=y_pretrain, X=X_pretrain)

# usual vignette starts here
y, X = load_data()

f.fit(y, X, fh=[1, 2, 3])

f.predict()
f.predict_interval()
```
### Mapping use cases on the vignette

The following map onto the "pre-train" phase:

* training for global forecasting
* fine-tuning
* pre-training of any other kind

Zero-shot models do not have pre-training, but `fit` needs to be called;
it is used only to read in the context (there is no `y` in `predict`).
## Design: concepts and internals

### Conceptual model, state diagram

Estimators get a third state, in addition to the current two:

* blueprint/pristine
  * definition: directly after `__init__`
  * even if a pretrained neural network is constructed with a checkpoint reference, we consider the `sktime` model a blueprint
* pretrained (new)
  * definition: pretrained attributes are present, after at least one call of `fit` in pretrain mode
* fitted
  * definition: after at least one call of `fit` in normal mode

The definition of "pretrained" is that pretrained attributes are present, as defined in the next section.

A blueprint transitions to pretrained, or directly to fitted.

Fitted cannot transition back to pretrained.
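The allowed transitions above can be sketched as a small state machine; the names and structure here are illustrative, not the internal sktime representation:

```python
from enum import Enum, auto

class EstimatorState(Enum):
    BLUEPRINT = auto()   # directly after __init__
    PRETRAINED = auto()  # at least one fit in pretrain mode
    FITTED = auto()      # at least one fit in normal mode

# allowed transitions per the state diagram above;
# note that FITTED cannot go back to PRETRAINED
_ALLOWED = {
    EstimatorState.BLUEPRINT: {EstimatorState.PRETRAINED, EstimatorState.FITTED},
    EstimatorState.PRETRAINED: {EstimatorState.PRETRAINED, EstimatorState.FITTED},
    EstimatorState.FITTED: {EstimatorState.FITTED},
}

def transition(current: EstimatorState, target: EstimatorState) -> EstimatorState:
    """Return target if the transition is legal, else raise ValueError."""
    if target not in _ALLOWED[current]:
        raise ValueError(f"illegal transition: {current.name} -> {target.name}")
    return target
```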
### Pretrained attributes and state attributes

Pretrained attributes, by convention, start with an underscore and end in an underscore.

They should not be present after `__init__`.

A `fit` (or `pretrain`) call in the pretraining phase may write only to pretrained attributes.

An attribute `_is_pretrained` is added; this tracks whether the model is pretrained.
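A minimal sketch of these conventions, assuming the `pretrain` switch from the basic vignette; the attribute contents are made up for illustration:

```python
class PretrainConventionSketch:
    """Illustrates the attribute conventions; not the sktime base class."""

    def __init__(self):
        # no pretrained attributes exist after __init__
        self._is_pretrained = False
        self._pretrain_mode = False

    def pretrain(self, mode="on"):
        # the simple switch from the basic vignette
        self._pretrain_mode = mode != "off"
        return self

    def fit(self, y):
        if self._pretrain_mode:
            # pretrained attributes: leading AND trailing underscore
            self._pretrained_params_ = {"n_obs_seen": len(y)}
            self._is_pretrained = True
        else:
            # ordinary fitted attributes: trailing underscore only,
            # following the sklearn-style convention
            self.mean_ = sum(y) / len(y)
        return self
```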
### Tags

A tag `capability:pretrain` is introduced; it signifies models with non-trivial pretrained state.

For estimators without this capability, the default pretraining behaviour is not to raise an error, but to perform the empty operation (a `pass`).
### Extension contract

An optional extender method `_pretrain` is added. This method returns `self`.
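A sketch of how a base class might dispatch to `_pretrain`, combining the `capability:pretrain` tag default (a no-op) with the extender method. All internals here are assumptions, not the actual sktime base-class code, and the flag style of alternative vignette 1 is used for brevity:

```python
class BasePretrainDispatchSketch:
    """Hypothetical base-class skeleton for the _pretrain contract."""

    _tags = {"capability:pretrain": False}

    def __init__(self):
        self._is_pretrained = False

    def fit(self, y, pretrain=False):
        if pretrain:
            # default for models without the capability:
            # empty operation, no error raised
            if self._tags.get("capability:pretrain", False):
                self._pretrain(y)
                self._is_pretrained = True
            return self
        return self._fit(y)

    def _pretrain(self, y):
        # optional extender method; must return self
        return self

    def _fit(self, y):
        return self


class MyPretrainableSketch(BasePretrainDispatchSketch):
    """Example extender overriding _pretrain."""

    _tags = {"capability:pretrain": True}

    def _pretrain(self, y):
        # count pretraining passes in a pretrained attribute
        self._pretrain_count_ = getattr(self, "_pretrain_count_", 0) + 1
        return self
```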
### Optional: checkpoint serialization for deep learning models

Some neural network models may have a `save_checkpoint` method.

This allows serializing checkpoints directly for later use in `__init__`.

Not all models will have this method.

Usage:

```python
f = MyDLmodel(checkpoint=my_ckpt_path)

f.pretrain()
f.fit(y)

f.save_checkpoint(my_new_ckpt_path)

# later, it can be loaded in a new kernel:
f = MyDLmodel(checkpoint=my_new_ckpt_path)
```