ADTypes + ADgradient Performance #727

Open · willtebbutt wants to merge 15 commits into master

Conversation

@willtebbutt (Member) commented Nov 28, 2024

The way to use Mooncake with DPPL is to make use of the generic DifferentiationInterface.jl interface that was added to LogDensityProblemsAD.jl, i.e. to write something like

ADgradient(ADTypes.AutoMooncake(; config=nothing), log_density_function)

where log_density_function is a DPPL.LogDensityFunction.
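
For concreteness, a minimal end-to-end sketch (the model here is made up, purely for illustration):

using ADTypes, Distributions, DynamicPPL, LogDensityProblems, LogDensityProblemsAD
import Mooncake

@model function demo(y)
    μ ~ Normal(0, 1)
    y ~ Normal(μ, 1)
end

log_density_function = DynamicPPL.LogDensityFunction(demo(1.0))
∇ℓ = ADgradient(ADTypes.AutoMooncake(; config=nothing), log_density_function)

# Evaluate the log density and its gradient at some parameter vector.
θ = randn(LogDensityProblems.dimension(∇ℓ))
LogDensityProblems.logdensity_and_gradient(∇ℓ, θ)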

By default, this will hit this method in LogDensityProblemsAD.

This leads to DifferentiationInterface not having sufficient information to construct its prep object, in which various things are pre-allocated and, in the case of Mooncake, the rule is constructed. This means that this method of logdensity_and_gradient gets hit, in which the prep object is reconstructed on every call. This is moderately bad for Mooncake's performance, because it means the rule is fetched each and every time the function is called.

This PR adds a method of ADgradient, specialised to LogDensityFunction and AbstractADType, which ensures that the optional x kwarg is always passed in. This is enough to ensure good performance with Mooncake.
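
To make the difference concrete, here is roughly what always passing x amounts to at the call site (a sketch only, continuing the example above):

# Without `x`, DifferentiationInterface cannot build its prep object up front,
# so the rule is re-fetched on every gradient evaluation:
∇ℓ_slow = LogDensityProblemsAD.ADgradient(
    ADTypes.AutoMooncake(; config=nothing), log_density_function
)

# Supplying `x` lets the prep object (and Mooncake's rule) be built once at construction:
x = map(identity, DynamicPPL.getparams(log_density_function))
∇ℓ_fast = LogDensityProblemsAD.ADgradient(
    ADTypes.AutoMooncake(; config=nothing), log_density_function; x=x
)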

Questions:

  1. Is this the optimal way to implement this? Another option might be to modify setmodel to always do this every time that ADgradient is called.
  2. Where / how should I test this? Should I just add Mooncake to the test suite and verify that ADgradient runs correctly?

Misc:

  1. I've removed the [extras] block in the primary Project.toml because we use the test/Project.toml for our test deps.
  2. I've bumped the patch version so that we can tag a release ASAP after this is merged.

@coveralls commented Nov 28, 2024

Pull Request Test Coverage Report for Build 12068201796

Details

  • 0 of 3 (0.0%) changed or added relevant lines in 1 file are covered.
  • No unchanged relevant lines lost coverage.
  • Overall coverage decreased (-0.06%) to 84.294%

Changes Missing Coverage        Covered Lines    Changed/Added Lines    %
src/logdensityfunction.jl       0                3                      0.0%

Totals (Coverage Status)
  • Change from base Build 12056044639: -0.06%
  • Covered Lines: 3553
  • Relevant Lines: 4215


codecov bot commented Nov 28, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 86.39%. Comparing base (48921d3) to head (6fb7f9b).
Report is 4 commits behind head on master.

Additional details and impacted files
@@            Coverage Diff             @@
##           master     #727      +/-   ##
==========================================
+ Coverage   84.35%   86.39%   +2.04%     
==========================================
  Files          35       36       +1     
  Lines        4212     4183      -29     
==========================================
+ Hits         3553     3614      +61     
+ Misses        659      569      -90     


@coveralls commented Nov 28, 2024

Pull Request Test Coverage Report for Build 12083712337

Warning: This coverage report may be inaccurate.

This pull request's base commit is no longer the HEAD commit of its target branch. This means it includes changes from outside the original pull request, including, potentially, unrelated coverage changes.

Details

  • 5 of 5 (100.0%) changed or added relevant lines in 3 files are covered.
  • 71 unchanged lines in 6 files lost coverage.
  • Overall coverage increased (+1.5%) to 85.847%

Files with Coverage Reduction     New Missed Lines    %
src/model.jl                      2                   91.58%
src/distribution_wrappers.jl      4                   63.89%
src/debug_utils.jl                5                   94.74%
src/test_utils/contexts.jl        6                   86.36%
src/context_implementations.jl    17                  90.72%
src/contexts.jl                   37                  78.06%

Totals (Coverage Status)
  • Change from base Build 12056044639: 1.5%
  • Covered Lines: 3591
  • Relevant Lines: 4183


@sunxd3 (Member) commented Nov 28, 2024

Maybe you have already seen this, but in ReverseDiffExt there is similar code, for a different reason:

function LogDensityProblemsAD.ADgradient(
    ad::ADTypes.AutoReverseDiff{Tcompile}, ℓ::DynamicPPL.LogDensityFunction
) where {Tcompile}
    return LogDensityProblemsAD.ADgradient(
        Val(:ReverseDiff),
        ℓ;
        compile=Val(Tcompile),
        # `getparams` can return `Vector{Real}`, in which case `ReverseDiff` will initialize the gradients to Integer 0,
        # because at https://github.com/JuliaDiff/ReverseDiff.jl/blob/c982cde5494fc166965a9d04691f390d9e3073fd/src/tracked.jl#L473
        # `zero(D)` will return 0 when D is Real.
        # Here we use `identity` to possibly concretize the type to `Vector{Float64}` in the case of `Vector{Real}`.
        x=map(identity, DynamicPPL.getparams(ℓ)),
    )
end

Can the above code be removed, then?

@willtebbutt (Member, Author)

Ooooo I think we might be able to. This is because there is this code in the DI.jl extension of LogDensityProblemsAD.

Probably good to do a quick performance check though.

I wonder whether the ForwardDiff code found here could also be removed...

@torfjelde (Member)

Where / how should I test this? Should I just add Mooncake to the test suite and verify that ADgradient runs correctly?

Yes please:) We have one for ForwardDiff.jl (see test/ext/... and the AD tests in runtests.jl). Maybe just add a similar one?
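
Something along these lines, perhaps (just a sketch; names and details are made up, modelled on the ForwardDiff one):

using Test, ADTypes, Distributions, DynamicPPL, LogDensityProblems, LogDensityProblemsAD
import Mooncake

@testset "AutoMooncake + ADgradient" begin
    @model demo() = x ~ Normal()
    ℓ = DynamicPPL.LogDensityFunction(demo())
    ∇ℓ = LogDensityProblemsAD.ADgradient(ADTypes.AutoMooncake(; config=nothing), ℓ)

    θ = randn(LogDensityProblems.dimension(∇ℓ))
    logp, grad = LogDensityProblems.logdensity_and_gradient(∇ℓ, θ)
    @test logp isa Real
    @test grad isa AbstractVector{<:Real}
    @test length(grad) == length(θ)
end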

@willtebbutt (Member, Author)

Actually, does anyone know whether here is the only place that we interact with ADgradient? It's the only place I can find it in DPPL. I wonder whether we should just change the call here to pass in x, rather than adding random extra methods of ADgradient?

@penelopeysm (Member)

AFAIK that's the only place where LogDensityProblems is used, so yes, it seems we could just pass in an x there.

@torfjelde (Member)

Not sure I fully follow

@willtebbutt (Member, Author) commented Nov 29, 2024

I could have been clearer in my explanation. Here's a better one.

The Problem

My issue with the current implementation is method ambiguities. I've defined a method with signature

Tuple{typeof(ADgradient), AbstractADType, LogDensityFunction}

but there exist other methods in LogDensityProblemsAD.jl, located around here, with signatures such as

Tuple{typeof(ADgradient), AutoEnzyme, Any}
Tuple{typeof(ADgradient), AutoForwardDiff, Any}
Tuple{typeof(ADgradient), AutoReverseDiff, Any}

etc. Now, we currently have methods in DynamicPPL.jl (defined in extensions) which have signatures

Tuple{typeof(ADgradient), AutoForwardDiff, LogDensityFunction}
Tuple{typeof(ADgradient), AutoReverseDiff, LogDensityFunction}

which resolve the ambiguity discussed above for AutoForwardDiff and AutoReverseDiff, but I imagine we'll encounter problems for AutoEnzyme and AutoZygote. Also, we would quite like to remove these methods, so they don't constitute a solution to the problem.
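
To see the ambiguity in isolation, here is a toy reproduction with stand-in types (not the real packages):

abstract type MyADType end           # plays the role of ADTypes.AbstractADType
struct MyAutoEnzyme <: MyADType end  # plays the role of ADTypes.AutoEnzyme
struct MyLogDensityFunction end      # plays the role of DynamicPPL.LogDensityFunction

# "LogDensityProblemsAD"-style method: specific in the backend, generic in the target.
adgradient(ad::MyAutoEnzyme, ℓ) = "backend-specific"
# "DynamicPPL"-style method: generic in the backend, specific in the target.
adgradient(ad::MyADType, ℓ::MyLogDensityFunction) = "target-specific"

adgradient(MyAutoEnzyme(), MyLogDensityFunction())
# ERROR: MethodError: adgradient(::MyAutoEnzyme, ::MyLogDensityFunction) is ambiguous.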

Potential Solutions

My initial proposal above was to avoid this method ambiguity entirely by just not defining any new methods of ADgradient, and simply ensuring that we always make sure to pass in the x kwarg when calling ADgradient with an AbstractADType.

This seems like a fine solution if we only ever call it in a single place (i.e. in setmodel), but if we call ADgradient in many places, it's a pain to ensure that we do the (somewhat arcane) thing required to get x in all of the places.

Another option would be to introduce another function to the DPPL interface, which has two methods, with signatures

Tuple{typeof(make_ad_gradient), ADType, LogDensityFunction} # ADType interface
Tuple{typeof(make_ad_gradient), Val, LogDensityFunction}    # old `Val`-based LogDensityProblemsAD interface

Both methods would construct an ADgradient in whatever the correct manner is.

This function would need to be documented as part of the public DynamicPPL interface, and linked to from the docstring for LogDensityFunction.

Thoughts @penelopeysm @torfjelde @sunxd3 ?

@sunxd3 (Member) commented Nov 29, 2024

I think you all have better opinions on interface design than I do, so this is more of a discussion point than a strong suggestion.

I think

Tuple{typeof(ADgradient), AutoForwardDiff, LogDensityFunction}

can cause confusion for (potential) maintainers (us), but is straightforward for users who are familiar with LogDensityProblemsAD.

I like the idea of make_ad_gradient to avoid ambiguity. But it might be somewhat unavoidable that someone would try to call ADgradient with a LogDensityFunction, just because they think: "okay, LogDensityFunction conforms to the LogDensityProblems interface, so it should just work with LogDensityProblemsAD." Then we would need to make ADgradient work regardless.

@torfjelde (Member)

My issue with the current implementation is method ambiguities. I've defined a method with signature

Ah, damn 😕 Yeah this ain't great.

But @willtebbutt, why do we need to define this method for the abstract AD type? Why don't we just do this on a case-by-case basis? Sure, that is a bit annoying, but there aren't that many AD backends we need to do it for.

Another option would be to introduce another function to the DPPL interface, which has two methods, with signatures

We had this before, but a big part of the motivation for moving to LogDensityProblemsAD.jl was to not diverge from the ecosystem by defining our own make_ad functions, so this goes quite counter to that. If we make a new method, then the selling point that "you can also just treat a model as a LogDensityProblems.jl problem!" sort of isn't true anymore, no?

@willtebbutt (Member, Author) commented Nov 29, 2024

Hmmm yes, I agree that it would be a great shame to do something that users aren't expecting here.

Okay, I propose the following:

  1. we define an internal function called _make_ad_gradient,
  2. for each ADType we care about, we add a method of ADgradient to an extension which just defers the call to _make_ad_gradient, i.e. it should just be a one-liner.

I'm going to implement this now to see what it looks like.
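
Roughly, I have something like the following in mind (just a sketch; the real code may end up looking different, and invoke is used here purely as one way to reach the generic keyword-accepting ADgradient method in LogDensityProblemsAD without re-dispatching onto the one-liner defined below):

# In DynamicPPL itself:
function _make_ad_gradient(ad::ADTypes.AbstractADType, ℓ::LogDensityFunction)
    x = map(identity, getparams(ℓ))  # concretise e.g. Vector{Real} to Vector{Float64}
    # Skip the backend-specific one-liners and hit the generic DifferentiationInterface
    # method in LogDensityProblemsAD, which accepts the `x` keyword argument.
    return invoke(
        LogDensityProblemsAD.ADgradient, Tuple{ADTypes.AbstractADType,Any}, ad, ℓ; x=x
    )
end

# In e.g. a Mooncake extension, a one-liner per ADType we care about:
function LogDensityProblemsAD.ADgradient(
    ad::ADTypes.AutoMooncake, ℓ::DynamicPPL.LogDensityFunction
)
    return DynamicPPL._make_ad_gradient(ad, ℓ)
end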

@torfjelde (Member)

Happy with the internal _make_ad_gradient:)

@willtebbutt (Member, Author)

Is it often the case that CI times out, or should I look into why this might be happening?

@penelopeysm (Member)

The CI matrix is set to fail-fast, so in this case the coveralls app failed and that shut everything else down 😅

Rerunning it will probably fix it. Usually, DPPL CI should complete within half an hour on Ubuntu.

@willtebbutt reopened this Nov 30, 2024