Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor interface for projections/proximal operators #147

Open
wants to merge 7 commits into
base: master
Choose a base branch
from

Conversation

Red-Portal
Copy link
Member

@Red-Portal Red-Portal commented Nov 17, 2024

This PR refactors how post-hoc modifications are applied to the iterates after performing a gradient descent step. For instance, before, updating the parameters of LocationScale always silently applied a projection step. Now, everything needs to be made into its own OptimisationRule to make it more modular and explicit.

More concretely, this PR changes the following:

  • The scale matrix of a LocationScale distribution is no longer projected by default.
  • A new rule object, ProjectScale, which wraps around an actual optimizer like Adam, will apply it instead.

Red-Portal and others added 4 commits November 17, 2024 01:06
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Copy link
Contributor

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmark Results

Benchmark suite Current: 074218a Previous: f754796 Ratio
normal/RepGradELBO + STL/meanfield/Zygote 15006704304 ns 14763506565 ns 1.02
normal/RepGradELBO + STL/meanfield/ForwardDiff 3348126895 ns 3447158973 ns 0.97
normal/RepGradELBO + STL/meanfield/ReverseDiff 3251766175 ns 3287816987 ns 0.99
normal/RepGradELBO + STL/fullrank/Zygote 14839392486 ns 14709910827 ns 1.01
normal/RepGradELBO + STL/fullrank/ForwardDiff 3681133932 ns 3731464808 ns 0.99
normal/RepGradELBO + STL/fullrank/ReverseDiff 5922342652 ns 5901083378 ns 1.00
normal/RepGradELBO/meanfield/Zygote 7220171802 ns 7152569391 ns 1.01
normal/RepGradELBO/meanfield/ForwardDiff 2475323654 ns 2590582460 ns 0.96
normal/RepGradELBO/meanfield/ReverseDiff 1518493197 ns 1537937469 ns 0.99
normal/RepGradELBO/fullrank/Zygote 7251201005 ns 7236156742 ns 1.00
normal/RepGradELBO/fullrank/ForwardDiff 2684751297 ns 2831282755 ns 0.95
normal/RepGradELBO/fullrank/ReverseDiff 2688543940 ns 2665713222 ns 1.01
normal + bijector/RepGradELBO + STL/meanfield/Zygote 22580741833 ns 22923473993 ns 0.99
normal + bijector/RepGradELBO + STL/meanfield/ForwardDiff 9239954469 ns 9104303562 ns 1.01
normal + bijector/RepGradELBO + STL/meanfield/ReverseDiff 4700620404 ns 4734080494 ns 0.99
normal + bijector/RepGradELBO + STL/fullrank/Zygote 22561930054 ns 22834259529 ns 0.99
normal + bijector/RepGradELBO + STL/fullrank/ForwardDiff 9456450373 ns 9877151525 ns 0.96
normal + bijector/RepGradELBO + STL/fullrank/ReverseDiff 7951126990 ns 7892624197 ns 1.01
normal + bijector/RepGradELBO/meanfield/Zygote 14279623416 ns 14509912160 ns 0.98
normal + bijector/RepGradELBO/meanfield/ForwardDiff 7938907881 ns 7940646347 ns 1.00
normal + bijector/RepGradELBO/meanfield/ReverseDiff 2757542395 ns 2763150626 ns 1.00
normal + bijector/RepGradELBO/fullrank/Zygote 14378045996 ns 14495005765 ns 0.99
normal + bijector/RepGradELBO/fullrank/ForwardDiff 8384985802 ns 9451953758 ns 0.89
normal + bijector/RepGradELBO/fullrank/ReverseDiff 4215836426 ns 4263995527 ns 0.99

This comment was automatically generated by workflow using github-action-benchmark.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant