Replace internal AD backend types with ADTypes#2047

Merged
yebai merged 15 commits into master from dw/adtypes on Nov 16, 2023
Conversation

@devmotion
Member

This PR is a draft proposal for replacing our internal AD backend types with ADTypes.

Needs support of ADTypes in LogDensityProblemsAD: tpapp/LogDensityProblemsAD.jl#17

@torfjelde
Member

Love it!

@torfjelde
Member

We'll also need to make equivalent changes in AdvancedVI, I believe.

@yebai
Member

yebai commented Jul 19, 2023

We'll also need to make equivalent changes in AdvancedVI, I believe.

I think @Red-Portal already did it in the AdvancedVI rewrite PR.

@Red-Portal
Member

Hi, yes that's already up to date!

Comment thread docs/src/using-turing/autodiff.md Outdated
Comment thread src/essential/ad.jl Outdated
Comment thread src/essential/ad.jl Outdated
Comment thread Project.toml Outdated
Comment thread src/essential/Essential.jl Outdated
Comment thread Project.toml Outdated
@devmotion
Member Author

I opened tpapp/LogDensityProblemsAD.jl#21.

Comment on lines +49 to +52
AutoForwardDiff,
AutoTracker,
AutoZygote,
AutoReverseDiff,
Member Author


Do we actually want to export these types? Or should we tell users to use ADTypes.AutoForwardDiff etc. (in particular since these types would also be used in other packages, e.g. AdvancedVI)?
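For context, the two options being weighed would look roughly like this (a sketch; model, the step size, and the leapfrog count are placeholders, and the final export list is exactly what this comment asks about):

```julia
# Option 1: Turing re-exports the ADTypes structs
using Turing
chain = sample(model, HMC(0.05, 10; adtype = AutoForwardDiff()), 1000)

# Option 2: users qualify the types via ADTypes themselves
import ADTypes
chain = sample(model, HMC(0.05, 10; adtype = ADTypes.AutoForwardDiff()), 1000)
```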

Comment thread src/essential/ad.jl Outdated
Comment thread test/essential/ad.jl Outdated
f_rd = LogDensityProblemsAD.ADgradient(Turing.Essential.ReverseDiffAD{false}(), f)
f_rd_compiled = LogDensityProblemsAD.ADgradient(Turing.Essential.ReverseDiffAD{true}(), f)
f_rd = LogDensityProblemsAD.ADgradient(Turing.Essential.ReverseDiffAD(false), f)
f_rd_compiled = LogDensityProblemsAD.ADgradient(Val(:ReverseDiff), f; compile=Val(true), x=θ) # need to compile with non-zero inputs
Member


@devmotion I had to pass in θ for the result to be correct, started a PR tpapp/LogDensityProblemsAD.jl#22

Member Author


I think we should not add a keyword argument but just continue overloading ADgradient for TuringLogDensityFunction. I wouldn't want users to have to deal with such a keyword argument.
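A minimal sketch of the overloading approach described here (the method body and the getparams helper are assumptions for illustration, not the actual implementation):

```julia
import LogDensityProblemsAD
import ADTypes

# Overload ADgradient for Turing's log-density type so that details such as
# ReverseDiff tape-compilation inputs stay internal, with no user-facing keyword.
function LogDensityProblemsAD.ADgradient(ad::ADTypes.AutoReverseDiff, ℓ::TuringLogDensityFunction)
    θ = getparams(ℓ)  # hypothetical helper returning non-zero initial parameters
    return LogDensityProblemsAD.ADgradient(Val(:ReverseDiff), ℓ; compile = Val(ad.compile), x = θ)
end
```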

Comment thread Project.toml Outdated
Comment thread src/mcmc/hmc.jl
adtype::ADTypes.AbstractADType = ADBackend(),
) where AD
return HMC{AD}(ϵ, n_leapfrog, metricT, space)
return HMC(ϵ, n_leapfrog, metricT, space; adtype = adtype)
Member


@devmotion, we can probably consider removing the global AD flag ADBACKEND, and always specify the autodiff backend in inference algorithms. What are your thoughts?

Member Author


And set the default AD type for each algorithm/sampler to a specific default such as AutoForwardDiff, you mean? Or would you like users to always specify the AD type explicitly?

Member


Yes, we can set the default AD type to AutoForwardDiff for all algorithms and allow users to override them via keyword arguments. That way, we don't need to maintain a global AD flag and can remove the messy code around it. But that should probably be done in a separate PR.
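Under that scheme, usage would look roughly like this (a sketch assuming the adtype keyword is adopted as discussed; model is a placeholder):

```julia
# Default: AutoForwardDiff, with no global flag consulted
chain = sample(model, HMC(0.05, 10), 1000)

# Explicit per-sampler override via the keyword argument
chain = sample(model, HMC(0.05, 10; adtype = AutoReverseDiff(true)), 1000)
```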

Member Author


I would like to get rid of the global flag 👍 If we do that in a separate PR, we should perhaps hold off on tagging a breaking release until the follow-up is merged as well, to avoid two breaking releases in a row.

@github-actions
Contributor

Pull Request Test Coverage Report for Build 6887921533

  • 0 of 41 (0.0%) changed or added relevant lines in 5 files are covered.
  • 1 unchanged line in 1 file lost coverage.
  • Overall coverage remained the same at 0.0%

Changes missing coverage:

  File                         Covered Lines   Changed/Added Lines   %
  src/mcmc/Inference.jl        0               1                     0.0%
  ext/TuringDynamicHMCExt.jl   0               2                     0.0%
  src/mcmc/sghmc.jl            0               4                     0.0%
  src/essential/ad.jl          0               12                    0.0%
  src/mcmc/hmc.jl              0               22                    0.0%

Files with coverage reduction:

  File                         New Missed Lines   %
  ext/TuringDynamicHMCExt.jl   1                  0.0%
Totals Coverage Status
Change from base Build 6829598411: 0.0%
Covered Lines: 0
Relevant Lines: 1421

💛 - Coveralls

@codecov

codecov Bot commented Nov 16, 2023

Codecov Report

Attention: 41 lines in your changes are missing coverage. Please review.

Comparison is base (d4a7975) 0.00% compared to head (da44611) 0.00%.

Files Patch % Lines
src/mcmc/hmc.jl 0.00% 22 Missing ⚠️
src/essential/ad.jl 0.00% 12 Missing ⚠️
src/mcmc/sghmc.jl 0.00% 4 Missing ⚠️
ext/TuringDynamicHMCExt.jl 0.00% 2 Missing ⚠️
src/mcmc/Inference.jl 0.00% 1 Missing ⚠️
Additional details and impacted files
@@          Coverage Diff           @@
##           master   #2047   +/-   ##
======================================
  Coverage    0.00%   0.00%           
======================================
  Files          21      21           
  Lines        1435    1421   -14     
======================================
+ Misses       1435    1421   -14     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Comment thread src/mcmc/hmc.jl
kwargs...
) where AD
return HMC{AD}(ϵ, n_leapfrog; kwargs...)
function HMC(ϵ::Float64, n_leapfrog::Int, ::Type{metricT}, space::Tuple; adtype::ADTypes.AbstractADType = ADBackend()) where {metricT <: AHMC.AbstractMetric}
Member Author


Should we remove this constructor and only support metricT as keyword argument? Or make all arguments keyword arguments?

Generally, these HMC constructors are quite messy...
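A keyword-only variant might look like this (purely illustrative; the defaults shown are assumptions, not the constructor this PR ships):

```julia
function HMC(ϵ::Float64, n_leapfrog::Int;
             metricT::Type{<:AHMC.AbstractMetric} = AHMC.UnitEuclideanMetric,
             space::Tuple = (),
             adtype::ADTypes.AbstractADType = ADTypes.AutoForwardDiff())
    # Delegate to the existing positional constructor
    return HMC(ϵ, n_leapfrog, metricT, space; adtype = adtype)
end
```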

Member

@yebai yebai Nov 16, 2023


We plan to deprecate and then remove this old interface once the AbstractMCMC-based externalsampler interface works for Gibbs.

@yebai
Member

yebai commented Nov 16, 2023

CI errors about chain resume/save are unrelated to this PR.

@yebai yebai marked this pull request as ready for review November 16, 2023 11:39
Member

@yebai yebai left a comment


Thanks @sunxd3 @devmotion -- good work!

@yebai yebai merged commit 6649f10 into master Nov 16, 2023
@yebai yebai deleted the dw/adtypes branch November 16, 2023 12:23
yebai added a commit to TuringLang/docs that referenced this pull request Dec 19, 2023
* Update autodiff.jmd following adaptation of `ADTypes`

TuringLang/Turing.jl#2047

* Update tutorials/docs-10-using-turing-autodiff/autodiff.jmd

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Using `ADTypes` for ad doc

* Update tutorials/docs-10-using-turing-autodiff/autodiff.jmd

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update tutorials/docs-10-using-turing-autodiff/autodiff.jmd

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com>

* Update tutorials/docs-10-using-turing-autodiff/autodiff.jmd

* Update tutorials/docs-10-using-turing-autodiff/autodiff.jmd

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Apply suggestions from code review

* Update tutorials/docs-10-using-turing-autodiff/autodiff.jmd

* Update tutorials/docs-10-using-turing-autodiff/autodiff.jmd

* Update tutorials/docs-10-using-turing-autodiff/autodiff.jmd

* Update tutorials/docs-10-using-turing-autodiff/autodiff.jmd

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com>
@storopoli
Member

So the tape gets compiled by default now when using ReverseDiff?
Is Turing.setrdcache(true) no longer needed?

@yebai
Member

yebai commented Dec 20, 2023

You need to pass AutoReverseDiff(true) to individual sampling algorithms (true enables the compiled tape, false disables it); see https://github.com/TuringLang/Turing.jl/blob/master/HISTORY.md for details.
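Concretely, that looks like the following (model is a placeholder; HISTORY.md has the authoritative examples):

```julia
using Turing

# Compiled tape: fast, but requires the model's execution path to be static
sample(model, HMC(0.05, 10; adtype = AutoReverseDiff(true)), 1000)

# Uncompiled tape
sample(model, HMC(0.05, 10; adtype = AutoReverseDiff(false)), 1000)
```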

@storopoli
Member

Yes, but it is still not clear. The example only shows adtype=AutoForwardDiff(; chunksize) in the sampler constructor.
I had no idea that AutoReverseDiff took positional arguments, and why doesn't AutoForwardDiff take any positional arguments?

Also, https://github.com/SciML/ADTypes.jl doesn't have docs, so it is even harder to figure it out.
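For reference, the constructors in question looked roughly like this at the time (based on a reading of the ADTypes source; check the current release, since the API may have changed):

```julia
AutoForwardDiff(; chunksize = nothing)  # keyword-only chunk size
AutoReverseDiff(true)                   # positional compile flag (defaults to false)
AutoZygote()                            # no options
AutoTracker()                           # no options
```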

@yebai
Member

yebai commented Dec 20, 2023

@sunxd3, maybe add a few concrete examples for each popular autodiff backend to HISTORY.md and docstrings?

@sunxd3
Member

sunxd3 commented Dec 20, 2023

Yeah, it is confusing; I'll open a PR.

@storopoli
Member

storopoli commented Dec 20, 2023

I can do that if you want to, at least the docs part. I am hours away from my holiday break.

EDIT: here (https://turinglang.org/v0.30/docs/using-turing/autodiff)
EDIT 2: already done here (TuringLang/docs#430) huh? Oh that's great then.

@sunxd3
Member

sunxd3 commented Dec 20, 2023

@storopoli yep, the tutorial is updated. I just updated the release information too

@ya0 ya0 mentioned this pull request Jan 18, 2024

6 participants