
ADTypes + ADgradient Performance #727

Merged: 21 commits merged into master from wct/mooncake-perf on Dec 7, 2024

Conversation

@willtebbutt (Member) commented Nov 28, 2024

The way to use Mooncake with DPPL is to make use of the generic DifferentiationInterface.jl interface that was added to LogDensityProblemsAD.jl, i.e. to write something like

ADgradient(ADTypes.AutoMooncake(; config=nothing), log_density_function)

where log_density_function is a DPPL.LogDensityFunction.

By default, this will hit this method in LogDensityProblemsAD.

This leaves DifferentiationInterface without sufficient information to construct its prep object, in which various things are pre-allocated and, in the case of Mooncake, the rule is constructed. This means that this method of logdensity_and_gradient gets hit, in which the prep object is reconstructed each and every time the function is called. This is moderately bad for Mooncake's performance, because it includes fetching the rule on every call.
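To illustrate the difference, here is a minimal sketch using DifferentiationInterface directly on a toy objective (not a DPPL.LogDensityFunction; the function and values are purely illustrative):

using ADTypes, DifferentiationInterface
import Mooncake

backend = ADTypes.AutoMooncake(; config=nothing)
f(x) = sum(abs2, x)   # stand-in for the log density
x = randn(10)

# With a representative `x` available, preparation (including building the
# Mooncake rule) happens once, up front:
prep = prepare_gradient(f, backend, x)
gradient(f, prep, backend, x)   # reuses `prep` on every call

# Without `x` there is nothing to prepare against, so each call has to redo
# the preparation work internally:
gradient(f, backend, x)

Passing the optional x kwarg to ADgradient is what puts us on the first path.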

This PR adds a method of ADgradient, specialised to LogDensityFunction and AbstractADType, which ensures that the optional x kwarg is always passed in. This is enough to ensure good performance with Mooncake.

Questions:

  1. Is this the optimal way to implement this? Another option might be to modify setmodel to always do this every time ADgradient is called (see the sketch after this list).
  2. Where / how should I test this? Should I just add Mooncake to the test suite and verify that ADgradient runs correctly?
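For reference, the setmodel option from question 1 would amount to something like the following at the call site (a sketch only; adtype, ℓ, and the exact location are placeholders):

# Hypothetical call site (e.g. inside `setmodel`): always supply `x` so that
# DifferentiationInterface can do its preparation once, at construction time.
# Assumes `adtype::ADTypes.AbstractADType` and `ℓ::DynamicPPL.LogDensityFunction`.
∇ℓ = LogDensityProblemsAD.ADgradient(
    adtype, ℓ; x=map(identity, DynamicPPL.getparams(ℓ))
)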

Misc:

  1. I've removed the [extras] block in the primary Project.toml because we use the test/Project.toml for our test deps.
  2. I've bumped the patch version so that we can tag a release ASAP after this is merged.

@coveralls commented Nov 28, 2024

Pull Request Test Coverage Report for Build 12068201796

Details

  • 0 of 3 (0.0%) changed or added relevant lines in 1 file are covered.
  • No unchanged relevant lines lost coverage.
  • Overall coverage decreased (-0.06%) to 84.294%

File with changes missing coverage: src/logdensityfunction.jl (0 of 3 changed/added lines covered, 0.0%)
Totals: change from base Build 12056044639: -0.06%; covered lines: 3553; relevant lines: 4215

💛 - Coveralls

codecov bot commented Nov 28, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 86.34%. Comparing base (5a58571) to head (5b7ab97).
Report is 1 commit behind head on master.

Additional details and impacted files
@@            Coverage Diff             @@
##           master     #727      +/-   ##
==========================================
+ Coverage   86.32%   86.34%   +0.01%     
==========================================
  Files          35       34       -1     
  Lines        4249     4254       +5     
==========================================
+ Hits         3668     3673       +5     
  Misses        581      581              


@coveralls commented Nov 28, 2024

Pull Request Test Coverage Report for Build 12206812251

Details

  • 7 of 7 (100.0%) changed or added relevant lines in 1 file are covered.
  • No unchanged relevant lines lost coverage.
  • Overall coverage increased (+0.02%) to 86.342%

Totals: change from base Build 12200518816: +0.02%; covered lines: 3673; relevant lines: 4254

💛 - Coveralls

@sunxd3 (Member) commented Nov 28, 2024

Maybe you have seen this: in the ReverseDiff extension there is similar code, for a different reason:

function LogDensityProblemsAD.ADgradient(
    ad::ADTypes.AutoReverseDiff{Tcompile}, ℓ::DynamicPPL.LogDensityFunction
) where {Tcompile}
    return LogDensityProblemsAD.ADgradient(
        Val(:ReverseDiff),
        ℓ;
        compile=Val(Tcompile),
        # `getparams` can return `Vector{Real}`, in which case `ReverseDiff` will initialize the gradients to Integer 0
        # because at https://github.com/JuliaDiff/ReverseDiff.jl/blob/c982cde5494fc166965a9d04691f390d9e3073fd/src/tracked.jl#L473
        # `zero(D)` will return 0 when D is Real.
        # Here we use `identity` to possibly concretize the type to `Vector{Float64}` in the case of `Vector{Real}`.
        x=map(identity, DynamicPPL.getparams(ℓ)),
    )
end

Can the above code be removed, then?

@willtebbutt (Member, Author)

Ooooo I think we might be able to. This is because there is this code in the DI.jl extension of LogDensityProblemsAD.

Probably good to do a quick performance check though.
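Something along these lines should do as a quick check (a sketch; ℓ stands for whichever DPPL.LogDensityFunction we want to benchmark):

using ADTypes, BenchmarkTools, LogDensityProblems, LogDensityProblemsAD
import ReverseDiff

# Construct the gradient wrapper as users would, then time repeated evaluation.
# Timings before and after removing the extension method should be comparable.
∇ℓ = LogDensityProblemsAD.ADgradient(ADTypes.AutoReverseDiff(), ℓ)
x = map(identity, DynamicPPL.getparams(ℓ))
@benchmark LogDensityProblems.logdensity_and_gradient($∇ℓ, $x)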

I wonder whether the ForwardDiff code found here could also be removed...

@torfjelde (Member)

> Where / how should I test this? Should I just add Mooncake to the test suite and verify that ADgradient runs correctly?

Yes please:) We have one for ForwardDiff.jl (see test/ext/... and the AD tests in runtests.jl). Maybe just add a similar one?
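Something roughly like this (a sketch only; the model, values, and file layout are illustrative rather than the actual test):

using ADTypes, Distributions, DynamicPPL, LogDensityProblems, LogDensityProblemsAD, Test
import DifferentiationInterface, Mooncake

# Toy model with a single standard-normal parameter.
@model demo() = x ~ Normal()

ℓ = DynamicPPL.LogDensityFunction(demo())
∇ℓ = LogDensityProblemsAD.ADgradient(ADTypes.AutoMooncake(; config=nothing), ℓ)

# logp(x) = logpdf(Normal(), x), so the gradient at 0.5 is -0.5.
lp, grad = LogDensityProblems.logdensity_and_gradient(∇ℓ, [0.5])
@test lp ≈ logpdf(Normal(), 0.5)
@test grad ≈ [-0.5]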

@willtebbutt (Member, Author)

Actually, does anyone know whether here is the only place that we interact with ADgradient? It's the only place I can find it in DPPL. I wonder whether we should just change the call here to pass in x, rather than adding random extra methods of ADgradient?

@penelopeysm (Member)

AFAIK that's the only place where LogDensityProblems is used, so yes, it seems we could just pass in an x there.

@torfjelde (Member)

Not sure I fully follow

@willtebbutt (Member, Author) commented Nov 29, 2024

I could have been clearer in my explanation. Here's a better one.

The Problem

My issue with the current implementation is method ambiguities. I've defined a method with signature

Tuple{typeof(ADgradient), AbstractADType, LogDensityFunction}

but there exist other methods in LogDensityProblemsAD.jl, located around here, with signatures such as

Tuple{typeof(ADgradient), AutoEnzyme, Any}
Tuple{typeof(ADgradient), AutoForwardDiff, Any}
Tuple{typeof(ADgradient), AutoReverseDiff, Any}

etc. Now, we currently have methods in DynamicPPL.jl (defined in extensions) which have signatures

Tuple{typeof(ADgradient), AutoForwardDiff, LogDensityFunction}
Tuple{typeof(ADgradient), AutoReverseDiff, LogDensityFunction}

which resolve the ambiguity discussed above for AutoForwardDiff and AutoReverseDiff, but I imagine we'll encounter problems for AutoEnzyme and AutoZygote. Also, we would quite like to remove these methods, so they don't constitute a solution to the problem.
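To make the ambiguity concrete, here is a self-contained toy version of the dispatch pattern (stand-in types, not the real packages):

abstract type MyADType end            # stands in for ADTypes.AbstractADType
struct MyAutoEnzyme <: MyADType end   # stands in for ADTypes.AutoEnzyme
struct MyLogDensityFunction end       # stands in for DynamicPPL.LogDensityFunction

# Like the upstream method in LogDensityProblemsAD: specific backend, generic target.
adgradient(ad::MyAutoEnzyme, ℓ) = "upstream"
# Like the method added in this PR: generic backend, specific target.
adgradient(ad::MyADType, ℓ::MyLogDensityFunction) = "DynamicPPL"

adgradient(MyAutoEnzyme(), MyLogDensityFunction())  # MethodError: ambiguous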

Potential Solutions

My initial proposal above was to avoid this method ambiguity entirely by just not defining any new methods of ADgradient, and simply ensuring that we always pass in the x kwarg when calling ADgradient with an AbstractADType.

This seems like a fine solution if we only ever call it in a single place (i.e. in setmodel), but if we call ADgradient in many places, it's a pain to ensure that we do the (somewhat arcane) thing required to get x in all of the places.

Another option would be to introduce another function to the DPPL interface, which has two methods, with signatures

Tuple{typeof(make_ad_gradient), ADType, LogDensityFunction} # ADType interface
Tuple{typeof(make_ad_gradient), ::Val, LogDensityFunction} # old LogDensityProblemsAD with `Val` interface

Both of these would construct an ADgradient in whatever the correct manner is.

This function would need to be documented as part of the public DynamicPPL interface, and linked to from the docstring for LogDensityFunction.
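In sketch form (illustrative only, written as if inside DynamicPPL, and assuming the ADTypes-based ADgradient constructor accepts the optional x keyword, as the DifferentiationInterface-based one does), the two methods might look like:

# ADTypes interface: always pass a concrete `x` so that DifferentiationInterface
# can build its prep object at construction time.
function make_ad_gradient(ad::ADTypes.AbstractADType, ℓ::LogDensityFunction)
    return LogDensityProblemsAD.ADgradient(ad, ℓ; x=map(identity, getparams(ℓ)))
end

# Old LogDensityProblemsAD `Val`-based interface, e.g. `Val(:ForwardDiff)`.
function make_ad_gradient(ad::Val, ℓ::LogDensityFunction)
    return LogDensityProblemsAD.ADgradient(ad, ℓ)
end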

Thoughts @penelopeysm @torfjelde @sunxd3 ?

@sunxd3 (Member) commented Nov 29, 2024

I think all of you have better-informed opinions on interface design than I do, so this is more of a discussion point than a strong suggestion.

I think

Tuple{typeof(ADgradient), AutoForwardDiff, LogDensityFunction}

can cause confusion for (potential) maintainers (us), but is straightforward for users who are familiar with LogDensityProblemsAD.

I like the idea of make_ad_gradient to avoid ambiguity. But it might be somewhat unavoidable that someone would try to call ADgradient with a LogDensityFunction just because they think: "okay, LogDensityFunction conforms to the LogDensityProblems interface, so it should just work with LogDensityProblemsAD." Then we would need to make ADgradient work regardless.

@torfjelde (Member)

> My issue with the current implementation is method ambiguities. I've defined a method with signature

Ah, damn 😕 Yeah this ain't great.

But @willtebbutt, why do we need to define this for the abstract AD backend type? Why don't we just do this on a case-by-case basis? Sure, that is a bit annoying, but there aren't that many AD backends we need to do it for.

> Another option would be to introduce another function to the DPPL interface, which has two methods, with signatures

We had this before, but a big part of the motivation for moving to LogDensityProblemsAD.jl was to not diverge from the ecosystem by defining our own make_ad functions, so this goes quite counter to that. If we make a new method, then the selling point that "you can also just treat a model as a LogDensityProblems.jl problem!" sort of isn't true anymore, no?

@willtebbutt (Member, Author) commented Nov 29, 2024

Hmmm yes, I agree that it would be a great shame to do something that users aren't expecting here.

Okay, I propose the following:

  1. we define an internal function called _make_ad_gradient,
  2. for each ADType we care about, we add a method of ADgradient to an extension which just defers the call to _make_ad_gradient, i.e. it should just be a one-liner.

I'm going to implement this now to see what it looks like.
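Roughly along these lines (a sketch only, using ReverseDiff as the example backend because its Val-based constructor is shown above; the Mooncake method would instead need to reach the DifferentiationInterface-based constructor with x supplied):

# Internal helper in DynamicPPL: one method per backend we care about.
function _make_ad_gradient(
    ad::ADTypes.AutoReverseDiff{Tcompile}, ℓ::DynamicPPL.LogDensityFunction
) where {Tcompile}
    return LogDensityProblemsAD.ADgradient(
        Val(:ReverseDiff),
        ℓ;
        compile=Val(Tcompile),
        x=map(identity, DynamicPPL.getparams(ℓ)),  # concretise e.g. Vector{Real}
    )
end

# In the corresponding extension: a one-liner ADgradient method deferring to the helper.
function LogDensityProblemsAD.ADgradient(
    ad::ADTypes.AutoReverseDiff, ℓ::DynamicPPL.LogDensityFunction
)
    return DynamicPPL._make_ad_gradient(ad, ℓ)
end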

@torfjelde (Member)

Happy with the internal _make_ad_gradient:)

@willtebbutt (Member, Author)

Is it often the case that CI times out, or should I look into why this might be happening?

@willtebbutt willtebbutt reopened this Nov 30, 2024
@willtebbutt willtebbutt closed this Dec 2, 2024
@willtebbutt willtebbutt reopened this Dec 2, 2024
@willtebbutt willtebbutt closed this Dec 2, 2024
@willtebbutt willtebbutt reopened this Dec 2, 2024
Project.toml: outdated review thread (resolved)
@willtebbutt (Member, Author)

Update: we're not going to be able to get this merged until JuliaRegistries/General#120562 is resolved.

@willtebbutt willtebbutt closed this Dec 5, 2024
@willtebbutt willtebbutt reopened this Dec 5, 2024
@willtebbutt (Member, Author) commented Dec 5, 2024

@mhauru @torfjelde any idea what this OOM error is about? Have we seen it anywhere else? It looks x86-specific, and like it's happening in a part of the pipeline which isn't AD-related, but it does scare me a little bit.

@mhauru (Member) commented Dec 5, 2024

Seems like a case of this: #725

@penelopeysm (Member) commented Dec 5, 2024

> #725

I thought I'd look into it, but those tests will be removed in #733 anyway, so I couldn't be bothered to track it down. Personally I'd be OK with ignoring the error.

@willtebbutt (Member, Author) commented Dec 5, 2024

Cool. As far as I'm concerned, this is ready to go then (I've bumped the patch version). @penelopeysm, could you approve if you're happy, and we'll get it merged.

test/ad.jl: review thread (resolved)
@torfjelde (Member) left a comment


Dopey stuff @willtebbutt :)

@penelopeysm merged commit f0c31f0 into master on Dec 7, 2024
11 of 13 checks passed
@penelopeysm deleted the wct/mooncake-perf branch on December 7, 2024 at 01:52
7 participants