
Linesearch (and lbfgs) support #4351

Merged · 4 commits into google:main on Nov 1, 2024

Conversation

@jlperla (Contributor) commented on Oct 31, 2024

What does this PR do?

This PR modifies .update() to accept the additional arguments required by Optax's lbfgs and related algorithms that rely on linesearch methods.

Fixes #4144

  • Additional tests were added to exercise lbfgs() (where I believe the additional arguments are required for linesearch methods more generally).
  • The key feature needed is a callback that evaluates the objective function at different "state" values. This requires an nnx.split and nnx.merge inside the optimization step; see the sketch after this list.
    • In cases where model_static, model_state = nnx.split(state.model, self.wrt) has already been done, the model_static can be passed into .update() to avoid repeating the nnx.split on each iteration. See test_jit_linesearch for more there.
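
For concreteness, here is a minimal sketch (not the PR's actual test code) of how the split/merge callback fits together. It assumes the update() behavior added by this PR, i.e. that extra keyword arguments are forwarded to the underlying Optax transformation; the model, data, and loss names are illustrative.

```python
import jax.numpy as jnp
import optax
from flax import nnx

# Illustrative model and data; any nnx.Module and objective would do.
model = nnx.Linear(2, 1, rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.lbfgs())
x, y = jnp.ones((8, 2)), jnp.zeros((8, 1))

def loss_fn(model):
    return jnp.mean((model(x) - y) ** 2)

# The linesearch re-evaluates the objective at trial parameter values,
# so the callback takes a pure state pytree and merges it back into a
# model. Splitting once up front avoids an nnx.split per iteration.
graphdef, _ = nnx.split(model, nnx.Param)

def value_fn(state):
    return loss_fn(nnx.merge(graphdef, state))

value, grads = nnx.value_and_grad(loss_fn)(model)
# The extra keyword arguments are forwarded to optax.lbfgs's update,
# whose zoom linesearch requires value, grad, and value_fn.
optimizer.update(grads, grad=grads, value=value, value_fn=value_fn)
```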

One feature still missing here is support for value_and_grad_from_state, as described in https://optax.readthedocs.io/en/stable/_collections/examples/lbfgs.html#linesearches-in-practice. Implementing it would bring a performance advantage, since it lets the optimizer reuse the value and gradients already computed during the linesearch instead of recomputing them.
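
For reference, this is roughly what that pattern looks like in plain Optax (a sketch following the linked docs; the toy objective and loop are illustrative): value_and_grad_from_state pulls the value and gradient cached in the linesearch state rather than recomputing them.

```python
import jax.numpy as jnp
import optax

def loss_fn(params):
    # Toy quadratic objective, minimized at params == 1.
    return jnp.mean((params - 1.0) ** 2)

opt = optax.lbfgs()
params = jnp.zeros(3)
opt_state = opt.init(params)

# Reuses the value/gradient already computed by the previous linesearch
# step when available, instead of re-evaluating loss_fn from scratch.
value_and_grad = optax.value_and_grad_from_state(loss_fn)

for _ in range(10):
    value, grad = value_and_grad(params, state=opt_state)
    updates, opt_state = opt.update(
        grad, opt_state, params, value=value, grad=grad, value_fn=loss_fn
    )
    params = optax.apply_updates(params, updates)
```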

google-cla bot commented on Oct 31, 2024

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@cgarciae (Collaborator) commented on Oct 31, 2024

Hey @jlperla! I left a longer comment in #4144, but the TL;DR is that I think we should either just add **kwargs and have the user implement the definition of value_fn, or implement an entirely new optimizer class for this family of algorithms (or one specific to lbfgs).

@jlperla (Contributor, Author) commented on Oct 31, 2024

@cgarciae See if this is what you had in mind. If so, it seems to solve the general GradientTransformationExtraArgs challenge for future features as well, since there are other extra arguments there too.

@cgarciae (Collaborator) commented on Nov 1, 2024

@jlperla sounds reasonable. Approved.

copybara-service bot merged commit d8b1a92 into google:main on Nov 1, 2024 · 17 checks passed
@jlperla (Contributor, Author) commented on Nov 1, 2024

Amazing, thanks @cgarciae! When is the next release expected? I'd love to publicize this in some sample code for a paper.

jlperla deleted the lbfgs_support branch on November 1, 2024