Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement AD testing and benchmarking (with DITest) #883

Open
wants to merge 4 commits into
base: main
Choose a base branch
from

Conversation

penelopeysm
Copy link
Member

@penelopeysm penelopeysm commented Apr 4, 2025

Part 2 of two options. The other one at #882.

Closes #869

Why am I not in favour of this one?

I think some exposition is required here, and I didn't have time to explain this super clearly during the meeting.

The API of DITest is like this:

  1. You construct a scenario, which includes the function f, the value at which to evaluate it / the gradient x, and a bunch of other things. Crucially, the scenario does not include the adtype.

  2. You then run the scenario with an adtype (or an array thereof).

From the perspective of generic functions f, this is quite a nice interface. The tricky bit with DynamicPPL, as I briefly mentioned, is that when you pass LogDensityFunction a model, varinfo, etc. it does a bunch of things that not only changes the function f being differentiated, but also potentially modifies the adtype that is actually used. See, especially, this constructor:

function LogDensityFunction(
model::Model,
varinfo::AbstractVarInfo=VarInfo(model),
context::AbstractContext=leafcontext(model.context);
adtype::Union{ADTypes.AbstractADType,Nothing}=nothing,
)
.

(Note that LogDensityFunctionsAD.jl used to do this stuff for us; #806 effectively removed it and inlined its optimisations into that inner constructor.)

What this means is that, to be completely consistent with the way DynamicPPL behaves, one has to:

  1. Reproduce the code inside src/logdensityfunctions.jl that generates the function f, so that the scenario can use the correct f.
  2. Because the above depends on the adtype, we have to make sure that scenarios generated with one adtype are later run with the same adtype.
    • In fact, the preparation in the LogDensityFunction doesn't only depend on the adtype; it potentially also modifies the adtype.
    • That's why this PR doesn't just include make_scenario; it also includes a run_ad function below, which ensures that the scenario is run with the appropriately modified adtype.

If we adopt this PR, then we have to choose between either:

  1. Duplicating the code inside src/logdensityfunctions.jl, as I've done in this PR; or
  2. Cutting this duplicated code out, which means that the results obtained when using this test/benchmark function will differ from the results when actually sampling a Turing model;
  3. Removing the extra prep work inside src/logdensityfunctions.jl

(3) is a no-go as it would have noticeable impacts on performance, and even though I think it'd be very nice if we could just export a list of scenarios, I'm not really comfortable with either (1) or (2), and I don't think it's a good enough reason to do either.

The alternative to this, #882, already makes the API very straightforward (it's just one function with a very thorough docstring) and so I don't think it's unfair to define that as our interface - especially considering that it's most likely that we will actually be the ones writing the integration tests for other people.

@penelopeysm penelopeysm changed the title Implement AD testing (with DITest) Implement AD testing and benchmarking (with DITest) Apr 4, 2025
Copy link
Contributor

github-actions bot commented Apr 4, 2025

Benchmark Report for Commit e1a34e1

Computer Information

Julia Version 1.11.4
Commit 8561cc3d68d (2025-03-10 11:36 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 4 × AMD EPYC 7763 64-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver3)
Threads: 1 default, 0 interactive, 1 GC (on 4 virtual cores)

Benchmark Results

|                 Model | Dimension |  AD Backend |      VarInfo Type | Linked | Eval Time / Ref Time | AD Time / Eval Time |
|-----------------------|-----------|-------------|-------------------|--------|----------------------|---------------------|
| Simple assume observe |         1 | forwarddiff |             typed |  false |                  9.9 |                 1.5 |
|           Smorgasbord |       201 | forwarddiff |             typed |  false |                617.9 |                42.6 |
|           Smorgasbord |       201 | forwarddiff | simple_namedtuple |   true |                419.8 |                48.4 |
|           Smorgasbord |       201 | forwarddiff |           untyped |   true |               1243.3 |                27.5 |
|           Smorgasbord |       201 | forwarddiff |       simple_dict |   true |               3937.0 |                20.4 |
|           Smorgasbord |       201 | reversediff |             typed |   true |               1459.6 |                29.8 |
|           Smorgasbord |       201 |    mooncake |             typed |   true |                944.3 |                 5.4 |
|    Loop univariate 1k |      1000 |    mooncake |             typed |   true |               5567.2 |                 4.1 |
|       Multivariate 1k |      1000 |    mooncake |             typed |   true |               1123.4 |                 8.2 |
|   Loop univariate 10k |     10000 |    mooncake |             typed |   true |              61969.3 |                 3.7 |
|      Multivariate 10k |     10000 |    mooncake |             typed |   true |               8946.0 |                 9.6 |
|               Dynamic |        10 |    mooncake |             typed |   true |                136.5 |                11.9 |
|              Submodel |         1 |    mooncake |             typed |   true |                 25.7 |                 7.7 |
|                   LDA |        12 | reversediff |             typed |   true |                479.8 |                 5.2 |

Copy link

codecov bot commented Apr 4, 2025

Codecov Report

Attention: Patch coverage is 88.88889% with 2 lines in your changes missing coverage. Please review.

Project coverage is 84.89%. Comparing base (eed80e5) to head (e1a34e1).
Report is 1 commits behind head on main.

Files with missing lines Patch % Lines
ext/DynamicPPLDifferentiationInterfaceTestExt.jl 88.88% 2 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #883      +/-   ##
==========================================
+ Coverage   84.87%   84.89%   +0.01%     
==========================================
  Files          34       35       +1     
  Lines        3815     3833      +18     
==========================================
+ Hits         3238     3254      +16     
- Misses        577      579       +2     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@coveralls
Copy link

coveralls commented Apr 4, 2025

Pull Request Test Coverage Report for Build 14256574630

Details

  • 0 of 14 (0.0%) changed or added relevant lines in 1 file are covered.
  • No unchanged relevant lines lost coverage.
  • Overall coverage decreased (-3.5%) to 81.418%

Changes Missing Coverage Covered Lines Changed/Added Lines %
ext/DynamicPPLDifferentiationInterfaceTestExt.jl 0 14 0.0%
Totals Coverage Status
Change from base Build 14127923718: -3.5%
Covered Lines: 3111
Relevant Lines: 3821

💛 - Coveralls

@coveralls
Copy link

coveralls commented Apr 4, 2025

Pull Request Test Coverage Report for Build 14263072728

Details

  • 0 of 18 (0.0%) changed or added relevant lines in 1 file are covered.
  • 20 unchanged lines in 3 files lost coverage.
  • Overall coverage increased (+0.02%) to 84.983%

Changes Missing Coverage Covered Lines Changed/Added Lines %
ext/DynamicPPLDifferentiationInterfaceTestExt.jl 0 18 0.0%
Files with Coverage Reduction New Missed Lines %
src/model.jl 1 85.83%
src/varinfo.jl 3 84.51%
src/threadsafe.jl 16 55.05%
Totals Coverage Status
Change from base Build 14127923718: 0.02%
Covered Lines: 3254
Relevant Lines: 3829

💛 - Coveralls

@sunxd3
Copy link
Member

sunxd3 commented Apr 8, 2025

The reasons for preference are super valid. I also think that since the hand-rolled version is not too complicated, it's worth to maintain it ourselves. Otherwise for new contributors to be able to contribute to this, they need to know what a test scenario is for DIT.

@penelopeysm penelopeysm mentioned this pull request Apr 8, 2025
4 tasks
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

AD testing
3 participants