
[Feature] Variational Bayesian last layer models as surrogate models #2754

Open · wants to merge 24 commits into base: main
Conversation

@brunzema (Contributor)

Motivation

This PR adds variational Bayesian last layers (VBLLs) [1], which showed very promising results in the context of BO in our recent paper [2], to BoTorch. The goal is to provide a BoTorch-compatible implementation of VBLL surrogates for standard use cases (single-output models), making them accessible to the community as quickly as possible. This PR does not yet contain all the features discussed in [2], such as continual learning. If there is interest in adding continual learning as well, I am happy to do so down the line!

VBLLs can be used with standard acquisition functions such as (log)EI, but they are especially well suited to Thompson sampling: a Thompson sample of a Bayesian last layer model is a deterministic, differentiable feed-forward neural network, which makes (almost) global optimization of the sample for the next query location straightforward.
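To illustrate why this is convenient, here is a minimal numpy/scipy sketch (all names and values below are illustrative, not the PR's API): a fixed random-feature map stands in for the trained backbone, and a single draw of last-layer weights yields a deterministic sample path that a gradient-based optimizer can maximize.

```python
import numpy as np
from scipy.optimize import minimize

rng = np.random.default_rng(0)

# Illustrative stand-in for a trained backbone: a fixed random-feature map.
W = rng.normal(size=(16, 1))
b = rng.uniform(0, 2 * np.pi, size=16)

def features(x):
    # Maps inputs to the 16-dim "last layer" feature space.
    return np.cos(W @ np.atleast_2d(x).T + b[:, None]).T

# Illustrative posterior over last-layer weights (mean and diagonal std).
w_mean, w_std = np.zeros(16), 0.1 * np.ones(16)

# One Thompson sample of the weights fixes the function completely ...
w_sample = w_mean + w_std * rng.standard_normal(16)

# ... so the sample path is an ordinary deterministic function of x and
# can be handed (negated, for minimization) to a gradient-based optimizer.
def neg_sample_path(x):
    return -(features(x) @ w_sample).item()

res = minimize(
    neg_sample_path,
    x0=rng.uniform(-1.0, 1.0, size=1),
    bounds=[(-1.0, 1.0)],
    method="L-BFGS-B",
)
x_next = res.x  # candidate query location for the next evaluation
```

In the actual model the features come from the trained network, so the sample path is differentiable end to end and can be optimized with multiple restarts.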

Implementation details

This PR adds the implementation to the community folders. Here too, if there is broad interest in the model, I am happy to help merge it into the main part of the repo. The files added in this PR are the following:

botorch_community
|-- acquisition
|   |-- bll_thompson_sampling.py # TS for Bayesian last layer models
|-- models
|   |-- vblls.py # BoTorch wrapper for VBLLs
|-- posteriors
|   |-- bll_posterior.py # Posterior class for Bayesian last layer models
notebooks_community
|-- vbll_thompson_sampling.ipynb # Tutorial on how to use the VBLL model
test_community
|-- models
|   |-- test_vblls.py # Tests for the VBLL model functionality (backbone freezing for feature reuse, etc.)

The current implementation builds directly on the VBLL repo, which is actively maintained and depends only on PyTorch. Using this repo allows upstream improvements, e.g., better variational posterior initialization, to directly benefit BO.

Have you read the Contributing Guidelines on pull requests?

Yes.

Test Plan

The PR does not change any functionality of the current code base. The core functionality of the VBLLs should be covered by test_vblls.py. Let me know if further tests are required.

Related PRs

This PR does not change existing functionality, and I did not see any PRs regarding last layer models in BoTorch. This implementation may also be useful for other BLLs.

References

[1] J. Harrison, J. Willes, J. Snoek. Variational Bayesian Last Layers. International Conference on Learning Representations (ICLR), 2024.

[2] P. Brunzema, M. Jordahn, J. Willes, S. Trimpe, J. Snoek, J. Harrison. Bayesian Optimization via Continual Variational Last Layer Training. International Conference on Learning Representations (ICLR), 2025.

@facebook-github-bot added the CLA Signed label Feb 21, 2025
@brunzema brunzema marked this pull request as ready for review February 25, 2025 08:27
@eytan (Contributor) commented Feb 25, 2025 via email

@Balandat (Contributor)

Thanks for putting this up, @brunzema. The notebook looks great, and I plan to review this PR in more detail over the next day or two.

Regarding the dependency on vbll: Right now it looks like the code only uses ~120 lines of pure torch code from the vbll repo (namely https://github.com/VectorInstitute/vbll/blob/main/vbll/layers/regression.py#L34-L149 plus the minimal function https://github.com/VectorInstitute/vbll/blob/main/vbll/utils/distributions.py#L94-L98). It seems questionable to me to take that dependency, especially since the vbll repo doesn't include unit tests or CI. We spend a nontrivial amount of time fixing issues with downstream dependencies, so we are really careful about adding them only when truly necessary.

My preference would be to move the relevant pieces of the vbll code mentioned above into a helper module (and clearly attribute the source there, of course) so we can avoid the dependency for now. If we do end up expanding the functionality and use additional features from vbll, then I'd be happy to reconsider (provided the vbll repo adds proper unit tests and a CI setup).

@Balandat (Contributor) left a comment

Unit tests are failing because vbll is not installed in the CI. This will not be necessary if we move the minimal required code into a helper module that we include here.

@Balandat (Contributor) left a comment

oops, submitted prematurely. Here is the rest of the review.

"\n",
" ax.plot(x_test, mean, label=\"Posterior predictive\", color=\"tab:blue\")\n",
"\n",
" # Posterior samples\n",
Contributor

The TS samples shown are not the samples that were optimized to obtain the next candidate point. This initially caused some confusion; can you make the plots consistent, i.e., show the set of TSs that were optimized to obtain the next candidate?

Contributor Author

Mhm, yeah, that's true. I have now removed the samples from the plot to avoid this. I could also have BLLMaxPosteriorSampling return the sampled functions, but this use case seemed too specific. Do you think this is also relevant beyond visualization? Then this might be nice, imo.

@brunzema (Contributor Author)

@eytan Thank you for the nice comment and @Balandat thank you for the detailed review!

Just wanted to quickly give an update on the vblls: I talked to my collaborators, and we are happy to move the relevant part (regression layer + some utils) into botorch. I will update this PR to incorporate all suggested changes within the next few days and include the vbll code 👍

@brunzema (Contributor Author)

@Balandat Thank you again for the detailed review--I really appreciate it! I’ve updated the PR to address all your points, but please let me know if any further changes are needed. I’m happy to make any additional updates!

The biggest change is of course the added code from the vbll package; let me know if the way I have now included it is ok (additional file + credit at the top).

@Balandat (Contributor)

@brunzema thanks a lot for the updates - will review this within the next couple of days!


codecov bot commented Mar 10, 2025

Codecov Report

Attention: Patch coverage is 92.52874% with 39 lines in your changes missing coverage. Please review.

Project coverage is 99.79%. Comparing base (39dd171) to head (ae1dca7).
Report is 2 commits behind head on main.

Files with missing lines Patch % Lines
botorch_community/models/vbll_helper.py 85.82% 37 Missing ⚠️
...rch_community/acquisition/bll_thompson_sampling.py 96.66% 2 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##              main    #2754      +/-   ##
===========================================
- Coverage   100.00%   99.79%   -0.21%     
===========================================
  Files          206      211       +5     
  Lines        18599    19197     +598     
===========================================
+ Hits         18599    19158     +559     
- Misses           0       39      +39     


@Balandat (Contributor) left a comment
Thanks for the updates and in particular for avoiding another external dependency! I left a number of inline comments, here are some higher level ones:

  1. Please rebase this onto a recent version of main (the base commit of this PR is pretty old).
  2. Please address the flake8 (incl. line length) and import sorting errors - you can install the pre-commit hooks to make sure that your code conforms to the standard (see https://github.com/pytorch/botorch/blob/main/CONTRIBUTING.md#pre-commit-hooks).
  3. I spotted some rather problematic numerical code in the vbll helpers; let's update that (highlighted inline), shall we?

@Balandat (Contributor)

@brunzema checking in here, anything needed to get this over the finish line? Seems like we're very close.

@brunzema (Contributor Author)

@Balandat no, the delay is fully on me, sorry! Everything is pretty much done: flake8 is addressed, and with the import sorting plus `from __future__ import annotations` I no longer need the not-so-nice TYPE_CHECKING condition 👍 I will do a fresh rebase and a final pass tomorrow, so hopefully by morning your time I will have submitted the updated PR :)

@Balandat (Contributor)

Excellent!

@brunzema brunzema requested a review from Balandat March 27, 2025 14:12
@brunzema
@brunzema (Contributor Author)

@Balandat I updated the PR. Not sure why I thought the importing was fixed yesterday—I still had the type checking in place. I noticed that some files in the main repo also use it, so it should be fine?

That said, I think a cleaner approach (which I’ve implemented in the updated PR) is to extract the abstract Bayesian last-layer (BLL) class and use it as a parent class/interface. This should also make it more extensible for future BLL models.

botorch_community
|-- acquisition
|   |-- bll_thompson_sampling.py # AbstractBLL to specify "interface"
|-- models
|   |-- blls.py # Define AbstractBLL model
|   |-- vblls.py # VBLL inherit from AbstractBLL
|-- posteriors
|   |-- bll_posterior.py # AbstractBLL to specify "interface"
...
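A minimal sketch of this interface pattern (class and method names here are hypothetical, not the actual PR API) could look like the following: an abstract base class declares what any Bayesian last layer model must provide, and concrete models such as the VBLL subclass it.

```python
from abc import ABC, abstractmethod

# Hypothetical sketch of an abstract BLL interface; the names are
# illustrative, not the PR's actual class or method names.
class AbstractBLL(ABC):
    """Interface that concrete Bayesian last layer models implement."""

    @abstractmethod
    def posterior(self, X):
        """Return the posterior at the query points X."""

    @abstractmethod
    def sample_function(self):
        """Return one Thompson sample as a deterministic callable."""

class ToyBLL(AbstractBLL):
    # Trivial concrete model, just to show the subclassing contract.
    def posterior(self, X):
        return [0.0 for _ in X]

    def sample_function(self):
        return lambda x: x
```

Acquisition and posterior code can then type against the abstract class, so future BLL variants only need to implement the declared methods.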

@Balandat (Contributor) left a comment

Thanks! Overall this looks great.

Before I merge this in, could you please increase the unit test coverage? It looks like quite a lot of things are not covered by tests, including some important parts such as BLLPosterior.rsample().

The tutorial failure is unrelated, we'll fix this on our end.

@brunzema (Contributor Author)

brunzema commented Apr 9, 2025

hey @Balandat, this again took a while, sorry about that! I tried to push the test coverage to 100%. The only places I am unsure about are the following:

  1. After our short discussion, I have now added an error to the numerical optimization of the posterior sample paths (see #2754 (comment)):
def _optimize_sample_path():
    ...

    optimization_successful = False
    for j in range(num_restarts):
        # map to bounds
        x0 = lb + (ub - lb) * x0s[j]

        # optimize sample path
        res = scipy.optimize.minimize(
            func, x0, jac=grad_func, bounds=bounds, method="L-BFGS-B"
        )

        # check if optimization was successful
        if res.success:
            optimization_successful = True
        else:
            logger.warning(f"Optimization failed with message: {res.message}")

        # store the candidate
        X_cand[j, :] = torch.from_numpy(res.x).to(dtype=torch.float64)
        Y_cand[j] = torch.tensor([-res.fun], dtype=torch.float64)

    if not optimization_successful:
        raise RuntimeError("All optimization attempts on the sample path failed.")
        

Here I am unsure how to actually test this, as it has never happened to me before. I am inclined to put a # pragma: no cover here. What do you think?

  2. Should I write a separate test for the vbll helpers? The uncovered lines are essentially all properties, and the important parts (the different parameterizations of the VBLL head) are now all covered in the VBLL test. What is your opinion? Would the best place be in utils?

@Balandat (Contributor)

Balandat commented Apr 9, 2025

Awesome, thanks for pushing this along. Once we cover the last few lines with tests this can go in.

Here I am unsure how to actually test this as this never happened to be before. I am inclined to put a # pragma: no cover here. What do you think?

I'd recommend having a lightweight mocked test that mocks out scipy.optimize.minimize to return a mocked OptimizeResult where success = False.
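The suggested mocking pattern might look like the following sketch; `optimize_with_restarts` is a stand-in for the actual sampling routine, which the real test would import and call instead.

```python
from types import SimpleNamespace
from unittest import mock

import numpy as np
import scipy.optimize

# Stand-in for the code under test: restarts the optimizer and raises
# if no restart succeeds (mirrors the snippet quoted above).
def optimize_with_restarts(func, x0s):
    success = False
    for x0 in x0s:
        res = scipy.optimize.minimize(func, x0, method="L-BFGS-B")
        success = success or res.success
    if not success:
        raise RuntimeError("All optimization attempts on the sample path failed.")

# A fake OptimizeResult that always reports failure.
failed = SimpleNamespace(
    success=False, message="mocked failure", x=np.zeros(1), fun=0.0
)

# Patch minimize so every restart "fails", then assert the error path fires.
with mock.patch.object(scipy.optimize, "minimize", return_value=failed):
    try:
        optimize_with_restarts(lambda x: float(x[0] ** 2), [np.zeros(1)] * 3)
        raised = False
    except RuntimeError:
        raised = True

assert raised  # the failure branch is now exercised by the test
```

Because `minimize` is replaced entirely, the test is fast and deterministic, and the objective is never actually evaluated.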

Should I write a separate test for the vbll helpers? The not covered lines are essentially all properties and the important parts (different parameterization of the VBLL head) are now all covered in the VBLL test. What is your opinion? Best place would be in utils?

Yes, having those tests separately would be great. Looks like the misses are mainly in botorch_community/models/vbll_helper.py, so I would just test these in a new test_community/models/test_vbll_helper.py module.
