File: steps/25_pretrain/step.md (new file, 244 additions)
# API for pre-training and fine-tuning

Contributors: @fkiraly

## High-level summary

### The Aim

`sktime` now has a number of estimators that can do pre-training, fine-tuning, global learning, and cross-learning.

However, the current API design for these use cases has a few problems, and issues have repeatedly been opened asking for a rework.

This STEP is about finalizing a good interface for:

* global forecasting
* pre-training
* fine-tuning of foundation models
* zero-shot use of foundation models

References:

* conceptual design issue: https://github.com/sktime/sktime/issues/6580
* umbrella issue foundation models: https://github.com/sktime/sktime/issues/6177
* newer issue: https://github.com/sktime/sktime/issues/7838

### requirements

* design covers the above use cases with a simple interface
* composability: use of sensible pipelines, tuning, etc. should be simple and not require major surgery in current compositors
* backward and forward compatibility: the design should not impact current extension contracts
* maintainability: maintaining the framework and estimators with the above capabilities should be simple

### The proposed solution

Our proposed solution adds a new state, and a simple switch.

No new public methods are added beyond this, and signatures of methods are not modified.

Estimators get a third state, "pretraining phase", besides unfitted and fitted.

The solution is best illustrated in the basic vignette below.

### Discussion of current solutions

There are multiple current solutions; all have problems:

#### Global forecasting

Forecasters inheriting from `_BaseGlobalForecaster`.

A `y` is added in the `predict` methods. If this is passed, the `fit` is interpreted
as a pretraining pass.
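
To make the current pattern concrete, a minimal sketch is given below. The forecaster class and the data loaders are hypothetical placeholders, not concrete sktime estimators; only the shape of the calls follows the description above.

```python
# sketch of the current _BaseGlobalForecaster pattern; SomeGlobalForecaster and
# the load_* helpers are hypothetical placeholders
y_pretrain_panel = load_panel_of_series()  # many series, used as the pretraining pass
y_new = load_new_series()                  # the series we actually want to forecast

f = SomeGlobalForecaster()
f.fit(y=y_pretrain_panel, fh=[1, 2, 3])    # interpreted as a pretraining pass

# the additional y in predict carries the "local" data / context
y_pred = f.predict(fh=[1, 2, 3], y=y_new)
```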

Problems:

* some models need to know at `fit` time whether the data is for pretraining.
Examples: global reduction approaches. Broadcasting.
* as a general design principle, all `predict` methods would need to be changed to
  allow for addition of the `fit` arguments. This clashes in cases where arguments
  of the same name are present both in `fit` and `predict`, e.g., the `X` in forecasting,
  or all arguments for transformations.

Review thread:

> **Review comment:** I do not understand this. Why do predict methods need to get the fit arguments?
>
> **@fkiraly (Contributor Author), Feb 17, 2025:** This is subject to the current approach of adding the local batch to predict-like methods. A training batch is described by the full set of data-like arguments in `fit`.
>
> In forecasting, for instance, we add the `y` to each and every predict-like method.
>
> Also, consider the case of classification, where training happens on pairs of instances, one instance in `X` pairing with a label in `y`. To turn time series classifiers into "global models" - or, equivalently, fine-tunable ones - we would need to be able to pass the entire "local" training batch, `X` and `y`, to `predict` and `predict_proba`. But there already is an `X` in `predict`...

#### Pre-training foundation models

Foundation models currently come in two different, contradictory forms:

* those that carry out fine-tuning in `fit`
* those that pass the context in `fit`, e.g., zero-shot models
Review thread:

> **Review comment:** Which models do this? As far as I know, in most zero-shot capable scenarios the data passed to fit are completely ignored. Thus, I created the following PR: sktime/sktime#7838 (comment)
>
> **@fkiraly (Contributor Author), Feb 17, 2025:** Not correct! For instance, see TinyTimeMixerForecaster.
>
> If no y is passed in predict, then the y passed in fit is the context.
>
> Further, if we changed that to your suggestion - y being ignored in fit - we would break API uniformity! Because there would be no way to forecast the "normal" way with ttm.
>
> In particular, this is, for me, a prime example that shows that the y in predict is not a good idea:
>
> * you need to remember `_y` from fit, because you do not know whether the user will pass one in predict
> * the "pure" design where we do not do this remembering is no longer in a uniform API with all other forecasters!
>
> **@benHeid, Feb 17, 2025:** Ok, I was not exact in my statement above. It has to be: "data passed in y might completely be ignored by the forecast in case predict receives a new y."
>
> But I now understand better your concerns regarding y in predict.
>
> The difficulty for me is that we now have at least four different kinds of fit, which have to be distinguishable:
>
> * fit for setting the context
> * fit for training a model in a global setting (I think this is referred to as pretraining)
> * fit for fine-tuning a model on a single/few time series (either with full fine-tuning or PEFT approaches)
> * fit for training a model in the broadcasting setting
>
> (Not sure if there are more kinds of fit when using other forecasters in sktime.) So it would be helpful for me to collect the different kinds of fit, how we distinguish between them, and whether we have to.
>
> **@fkiraly (Contributor Author):** Agreed that we need to collect and analyze the different kinds of "fitting". I would start qualitatively, and then try to match them in a "quantitative" way, by which I mean typing, or an API with mathematical formalism.
>
> In my current internal model, from a typing perspective there are two pairs in the above that can be identified:
>
> * setting the context, and training a model in the broadcasting setting (with VectorizedDF, the current sktime vanilla case)
> * training a model in a global setting, and fine-tuning
>
> Fine-tuning is a curious case imo: it can be done with different series, but also with the same series that are later predicted. I think this is a kind of reduction - "in fit, pretrain the inner model with the partial series passed, then fit the inner model".

Problems: this is inconsistent, and it does not seem possible - without an `__init__` arg that acts as a switch, or without separate classes - to have the same weights in the same class be used by both a zero-shot and a fine-tuning algorithm.
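
For illustration, a hedged sketch of the two contradictory forms, and of the `__init__`-switch workaround described above, is given below; the class name and the `fine_tune` parameter are hypothetical, not an existing sktime estimator.

```python
# hypothetical foundation model forecaster with an __init__ switch; this is the
# workaround the text above argues against, not a recommended design
f_finetune = SomeFoundationForecaster(fine_tune=True)
f_finetune.fit(y_train, fh=[1, 2, 3])    # form 1: fit carries out fine-tuning

f_zeroshot = SomeFoundationForecaster(fine_tune=False)
f_zeroshot.fit(y_context, fh=[1, 2, 3])  # form 2: fit only passes the context
y_pred = f_zeroshot.predict()            # zero-shot forecast from the stored context
```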


## Design: pretraining vignette

Presenting the user-facing API. For delineation against current designs:

* no new arguments are added to `predict`-like methods
* a flag is added before or at `fit` to determine whether usage is normal fitting or pre-training
* several vignettes that pass this information on in different ways are presented below, for discussion

### basic usage vignette

Illustrated for global forecasting.

```python

y_pretrain, X_pretrain = load_pretrain_data()

f = MyForecasterCls(params)

f.pretrain()

f.fit(y=y_pretrain, X=X_pretrain)

# fh is optional, but some models require this

f.pretrain("off")

# usual vignette starts here
y, X = load_data()

f.fit(y, X, fh=[1, 2, 3])

f.predict()
f.predict_interval()
```

Review thread on the `f.pretrain()` call:

> **Review comment:** I think what makes it confusing for me is the word pretrain. I think it should be just train. This would make it much clearer that the simple fit does not train.
>
> **@fkiraly (Contributor Author):** I am not precious about the name - happy to call it train if this is more consistent with, say, torch.
>
> **@fkiraly (Contributor Author):**
>
> > much clearer that the simple fit does not train.
>
> Though, sometimes it does, and in a very general sense it always does, except if the `fit_is_empty` tag is True.
>
> **Review comment:** Puh, this is very philosophical :D and probably depends on the perspective.
>
> **@fkiraly (Contributor Author):** I do think one of the key secrets in API design is to discard philosophical considerations and purely consider the typing of operations.
>
> At least, considering input/output, parametrization, and the action carried out tends to lead to better APIs than considering the hard-to-quantify "meaning" assigned to it.
>
> Which, I suppose, is again a bit of a philosophical statement...

With optional serialization after pre-training:

```python

# optional: serialize

f.save(checkpoint_name)

# restart

f = load(checkpoint_name)
```


### Alternative vignette 1

An alternative idea would be adding an arg to `fit`:

```python

y_pretrain, X_pretrain = load_pretrain_data()

f = MyForecasterCls(params)

f.fit(y=y_pretrain, X=X_pretrain, pretrain=True)

# usual vignette starts here

y, X = load_data()

f.fit(y, X, fh=[1, 2, 3])

f.predict()
f.predict_interval()
```

### Alternative vignette 2

An alternative idea would be adding a new method:

```python

y_pretrain, X_pretrain = load_pretrain_data()

f = MyForecasterCls(params)

f.pretrain(y=y_pretrain, X=X_pretrain)

# usual vignette starts here

y, X = load_data()

f.fit(y, X, fh=[1, 2, 3])

f.predict()
f.predict_interval()
```


### Mapping use cases on the vignette

The following map onto the "pre-train" phase (see the sketch after this list):

* Training for global forecasting
* fine-tuning
* pre-training of any other kind
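
As a sketch of the fine-tuning case, reusing the names from the basic vignette above (`y_finetune`, `X_finetune` are hypothetical data sets the user fine-tunes on), fine-tuning would go through the same pre-train phase:

```python
f = MyForecasterCls(params)

f.pretrain()
f.fit(y=y_finetune, X=X_finetune)  # fine-tuning pass, mapped onto the pre-train phase
f.pretrain("off")

f.fit(y, X, fh=[1, 2, 3])          # normal fit on the series to be forecast
f.predict()
```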

Zero-shot models do not have pre-training, but `fit` needs to be called;
it is used only to read in the context (there is no `y` in `predict`).
A sketch follows the review thread below.

Review thread:

> **Review comment:** I think this is confusing, since the call of fit would not be required if we allow passing values to predict. For me, fit is associated with actually fitting a model and not with just passing context.
> Furthermore, it is also possible to fit zero-shot models, e.g., for fine-tuning them on own data. Would this mean switching pretraining on before starting the training?
>
> **@fkiraly (Contributor Author):** imo we should not merge two issues here:
>
> * how to map models that "need no fit" onto the current interface
> * how to map fine-tuning
>
> The representation of "need no fit" is out of scope for this PR, imo that is not something we should reopen (but we could). The current API requires:
>
> * all models to have a fit, and it to be executed before predict.
> * some index connection between data seen in fit and predict, for most estimators (in the consolidated interfaces)
>
> **@fkiraly (Contributor Author):**
>
> > Furthermore, it is also possible to fit zero-shot models, e.g., for fine-tuning them on own data. Would this mean switching pretraining on before starting the training?
>
> In my opinion and my current conceptual model, fine-tuning is pre-training and not fitting. Fitting is done on the same data we later predict on.
>
> Therefore, in the current API - before we add any y to predict in forecasting, and also for all other learning tasks - zero-shot models pass the context in fit.
>
> **@fkiraly (Contributor Author):** @benHeid, I would recommend you think of examples other than foundation models before continuing the discussion; I think having a larger example space is important.
>
> Two important examples:
>
> * recursive reduction with a global regression model that has an update. We fit this on a number of pre-training batches (global per batch), and then for application we update it on the local data set.
>   * a special case is the current recursive reducer, namely without pre-training.
> * the naive strategy of predicting the mean over pre-training time series.
>
> **Review comment:** What does update mean here? Calling the update method? And what exactly would it do, training or just setting the context?
>
> Let me add the following example, which I think is realistic. Imagine you have a large dataset. You use this for pretraining. Afterwards, you have some specific time series. Now there are two options:
>
> * fine-tune the model further on one specific time series
> * directly forecast on one specific time series
>
> How would this look from a user perspective, and what is the interplay between update/pretrain/fit/predict?
>
> **@fkiraly (Contributor Author):**
>
> > What does update mean here? Calling the update method? And what exactly would it do, training or just setting the context?
>
> No, I meant the concept and not the API method "update". Simply a model update of the inner regressor, for instance, a low-rank update to a regularized linear regression model.
>
> > How would this look from a user perspective, and what is the interplay between update/pretrain/fit/predict?
>
> That is precisely the answer that we need to settle on.
>
> One question: is fine-tuning not a form of pre-training a composite model? Namely, applied to, for instance, the LoRA-adapted model architecture?
>
> In vanilla, we could map this to a "pretrain", or to fit. The key question becomes, if we want to fine-tune on a different instance vs the same instance: should these be in different modes, i.e., pretrain vs fit? Or should this be consistently in the same mode? Then it would have to be pretrain, if we do not want to add new arguments to predict.
>
> > Directly forecast on one specific time series.
>
> I think this one, at least, maps clearly to: "fit passes context, predict produces forecast".
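
A minimal sketch of the zero-shot case under the proposed design; the forecaster name is a placeholder:

```python
f = SomeZeroShotForecaster()  # hypothetical zero-shot capable forecaster

# no pre-train phase; fit only reads in the context
f.fit(y_context, fh=[1, 2, 3])

# predict produces the zero-shot forecast from the stored context
f.predict()
```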


## Design: concepts and internals

### Conceptual model, state diagram

Estimators get a third state, in addition to the existing two:

* blueprint/pristine
  * definition: directly after `__init__`
  * even if a pretrained neural network is constructed with a checkpoint reference, we consider the `sktime` model a blueprint.
* pretrained (new)
  * definition: pretrained attributes are present, after at least one call of `fit` in pretrain mode
* fitted
  * definition: after at least one call of `fit` in normal mode.

The defining property of the pretrained state is the presence of pretrained attributes, as defined below.

Blueprint transitions to pretrained or directly to fitted.

Fitted cannot transition back to pretrained.
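
A rough sketch of the resulting state machine, for illustration only; the state names follow the list above, while the transition table and the helper are assumptions, not a prescribed implementation.

```python
# allowed transitions of the conceptual model above
ALLOWED_TRANSITIONS = {
    "blueprint": {"pretrained", "fitted"},   # blueprint may go either way
    "pretrained": {"pretrained", "fitted"},  # further pretraining passes assumed possible
    "fitted": {"fitted"},                    # fitted cannot go back to pretrained
}


def check_transition(current, target):
    """Raise if a transition is not allowed in the conceptual state model."""
    if target not in ALLOWED_TRANSITIONS[current]:
        raise RuntimeError(f"cannot transition from {current!r} to {target!r}")
```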

### Pretrained attributes and state attributes

Pretrained attributes, by convention, start with an underscore and end in an underscore.

They should not be present after `__init__`.

A `fit` call in pretraining mode (or a `pretrain` call) may write only to pretrained attributes.

An attribute `_is_pretrained` is added; it tracks whether the model is pretrained.
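
A minimal sketch of these conventions, reusing the names from the basic vignette; `_pretrained_weights_` is a hypothetical example of a pretrained attribute.

```python
f = MyForecasterCls(params)
f._is_pretrained          # False directly after __init__ (blueprint state)

f.pretrain()
f.fit(y=y_pretrain, X=X_pretrain)  # pretraining pass, writes pretrained attributes only

f._is_pretrained          # now True: at least one fit in pretrain mode has happened
f._pretrained_weights_    # example pretrained attribute, following the _name_ convention
```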

### Tags

A tag `capability:pretrain` is introduced; it signifies models with non-trivial pretrained state.

The default behaviour of pre-training, for estimators without this capability, is not to raise an error but to carry out the empty operation (a `pass`).

### Extension contract

An optional extender method `_pretrain` is added. This method returns `self`.
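
A hedged sketch of what the extension contract could look like for a concrete forecaster; the `capability:pretrain` tag and `_pretrain` are taken from the text above, while the base class import, the method signature, and the helper are assumptions for illustration.

```python
from sktime.forecasting.base import BaseForecaster  # assumed base class


class MyPretrainableForecaster(BaseForecaster):
    """Illustrative forecaster implementing the proposed pretraining extension contract."""

    _tags = {"capability:pretrain": True}  # signals non-trivial pretrained state

    def _pretrain(self, y, X=None, fh=None):
        # signature assumed for illustration; writes pretrained attributes only,
        # following the _name_ convention, and returns self per the contract
        self._pretrained_weights_ = self._train_on_batch(y, X)  # hypothetical helper
        return self
```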

### Optional: checkpoint serialization for deep learning models

Some neural network models may have a `save_checkpoint` method.

This allows serializing checkpoints directly, for later use in `__init__`.

Not all models will have this method.

usage:

```python

f = MyDLmodel(checkpoint=my_ckpt_path)

f.pretrain()

f.fit(y)

f.save_checkpoint(my_new_ckpt_path)

# later, it can be loaded in a new kernel:

f = MyDLmodel(checkpoint=my_new_ckpt_path)
```