Use "differentiable optimization trick" for backpropagation through tangent vector field calculation #74

Open
mfschubert opened this issue Jan 10, 2024 · 5 comments

Comments

@mfschubert
Collaborator

Currently, we directly backpropagate through the tangent vector field calculation, which involves a Newton solve to find the minimum of a convex quadratic objective. It may be more efficient to define a custom gradient for this operation, in a manner similar to what is done for differentiable optimization.
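A minimal sketch of what such a custom gradient could look like in JAX, using implicit differentiation of the optimality condition. The `objective`, `minimize`, and helper names below are hypothetical stand-ins, not the repository's actual vector field code, and the quadratic is chosen only so that a single Newton step reaches the minimum.

```python
import jax
import jax.numpy as jnp


def objective(x, theta):
    # Hypothetical convex quadratic in x, parameterized by theta.
    a = jnp.diag(1.0 + theta**2)  # positive-definite Hessian
    return 0.5 * x @ a @ x - jnp.sin(theta) @ x


def _newton_solve(theta):
    # One Newton step from zero is exact for a quadratic objective.
    x0 = jnp.zeros_like(theta)
    g = jax.grad(objective)(x0, theta)
    h = jax.hessian(objective)(x0, theta)
    return x0 - jnp.linalg.solve(h, g)


@jax.custom_vjp
def minimize(theta):
    return _newton_solve(theta)


def _minimize_fwd(theta):
    x_star = _newton_solve(theta)
    return x_star, (x_star, theta)


def _minimize_bwd(res, x_bar):
    # At the minimum, grad_x objective(x_star, theta) = 0. Differentiating
    # this condition gives dx*/dtheta = -H^{-1} d(grad_x objective)/dtheta,
    # so the backward pass needs one linear solve instead of the solver graph.
    x_star, theta = res
    h = jax.hessian(objective)(x_star, theta)
    u = jnp.linalg.solve(h, x_bar)  # H is symmetric, so H^T = H
    _, vjp_theta = jax.vjp(lambda t: jax.grad(objective)(x_star, t), theta)
    (theta_bar,) = vjp_theta(-u)
    return (theta_bar,)


minimize.defvjp(_minimize_fwd, _minimize_bwd)

theta = jnp.array([0.3, -1.2, 2.0])
grad_custom = jax.grad(lambda t: jnp.sum(minimize(t) ** 2))(theta)
grad_direct = jax.grad(lambda t: jnp.sum(_newton_solve(t) ** 2))(theta)
```

For this toy problem, `grad_custom` and `grad_direct` should agree to within floating-point tolerance; the difference is that the custom rule does not trace through the solver during the backward pass.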

@mfschubert
Collaborator Author

I am seeing some issues with super-long compile times in the optimization context, which are eliminated when we use a stop_gradient before the vector field calculation. I am thinking we should just add this stop_gradient, and then restore the ability to backpropagate through vector field generation via the method mentioned above. This might be fairly involved, and would take time. fyi @smartalecH
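For illustration, the `stop_gradient` workaround might look roughly like the sketch below. Here `vector_field` and `loss` are hypothetical placeholders rather than functions from this project; the point is only that the field is treated as a constant during backpropagation.

```python
import jax
import jax.numpy as jnp


def vector_field(x):
    # Hypothetical stand-in for the tangent vector field calculation.
    return jnp.tanh(x)


def loss(x):
    # No gradient flows back through the vector field calculation.
    field = vector_field(jax.lax.stop_gradient(x))
    return jnp.sum(field * x)


x = jnp.array([0.5, -1.0, 2.0])
grad = jax.grad(loss)(x)  # only the explicit dependence on x contributes
```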

@smartalecH
Contributor

Yep this sounds like a good plan to me. How hard do we anticipate the manual adjoint will be?

@mfschubert
Collaborator Author

I am looking at it a bit. It might actually be relatively straightforward. Here's a reference that seems nice; it even includes JAX code: https://implicit-layers-tutorial.org/implicit_functions/

@mfschubert
Collaborator Author

@smartalecH @Luochenghuang I have things working here---all it needed was a bit of regularization.

https://github.com/mfschubert/mewtax

@mfschubert
Collaborator Author

I think we may want to put this on hold for now: the potential accuracy improvement is small, and there is a speed penalty.

  • I added a test in #94 ("Test AD gradient against finite difference gradient"), which checks the FD gradient against the AD gradient. They are very close as-is, i.e. even with the stop_gradient in the vector field calculation.
  • I tested using mewtax to solve for the vector fields, but this seems to make the tests much slower (2x the time to complete all tests). I suspect there is a significant compile time penalty.
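
A check of the kind described in the first bullet could be sketched as follows. This is not the test from #94; `loss` and `finite_difference_grad` are hypothetical, and the step size and tolerances are assumptions chosen for single-precision arithmetic.

```python
import jax
import jax.numpy as jnp


def loss(x):
    # Hypothetical smooth scalar function standing in for the real loss.
    return jnp.sum(jnp.sin(x) * x**2)


def finite_difference_grad(fn, x, eps=1e-2):
    # Central differences along each coordinate direction.
    basis = jnp.eye(x.size, dtype=x.dtype)
    return jnp.stack(
        [(fn(x + eps * e) - fn(x - eps * e)) / (2 * eps) for e in basis]
    )


x = jnp.array([0.5, -1.0, 2.0])
ad_grad = jax.grad(loss)(x)
fd_grad = finite_difference_grad(loss, x)
matches = bool(jnp.allclose(ad_grad, fd_grad, rtol=1e-2, atol=1e-2))
```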
