diff --git a/Project.toml b/Project.toml index df7357d3c..909be870f 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.31.2" +version = "0.31.3" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -30,7 +30,6 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" -ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [extensions] @@ -38,7 +37,6 @@ DynamicPPLChainRulesCoreExt = ["ChainRulesCore"] DynamicPPLEnzymeCoreExt = ["EnzymeCore"] DynamicPPLForwardDiffExt = ["ForwardDiff"] DynamicPPLMCMCChainsExt = ["MCMCChains"] -DynamicPPLReverseDiffExt = ["ReverseDiff"] DynamicPPLZygoteRulesExt = ["ZygoteRules"] [compat] @@ -63,15 +61,6 @@ MacroTools = "0.5.6" OrderedCollections = "1" Random = "1.6" Requires = "1" -ReverseDiff = "1" Test = "1.6" ZygoteRules = "0.2" julia = "1.10" - -[extras] -ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" -ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" -MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" -ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" -ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" diff --git a/ext/DynamicPPLReverseDiffExt.jl b/ext/DynamicPPLReverseDiffExt.jl deleted file mode 100644 index 3fd174ed1..000000000 --- a/ext/DynamicPPLReverseDiffExt.jl +++ /dev/null @@ -1,26 +0,0 @@ -module DynamicPPLReverseDiffExt - -if isdefined(Base, :get_extension) - using DynamicPPL: ADTypes, DynamicPPL, LogDensityProblems, LogDensityProblemsAD - using ReverseDiff -else - using ..DynamicPPL: ADTypes, DynamicPPL, LogDensityProblems, LogDensityProblemsAD - using ..ReverseDiff -end - -function LogDensityProblemsAD.ADgradient( - ad::ADTypes.AutoReverseDiff{Tcompile}, ℓ::DynamicPPL.LogDensityFunction -) where {Tcompile} - return LogDensityProblemsAD.ADgradient( - Val(:ReverseDiff), - ℓ; - compile=Val(Tcompile), - # `getparams` can return `Vector{Real}`, in which case, `ReverseDiff` will initialize the gradients to Integer 0 - # because at https://github.com/JuliaDiff/ReverseDiff.jl/blob/c982cde5494fc166965a9d04691f390d9e3073fd/src/tracked.jl#L473 - # `zero(D)` will return 0 when D is Real. - # here we use `identity` to possibly concretize the type to `Vector{Float64}` in the case of `Vector{Real}`. - x=map(identity, DynamicPPL.getparams(ℓ)), - ) -end - -end # module diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index 9e86590fa..214369ab0 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -144,3 +144,19 @@ function LogDensityProblems.capabilities(::Type{<:LogDensityFunction}) end # TODO: should we instead implement and call on `length(f.varinfo)` (at least in the cases where no sampler is involved)? LogDensityProblems.dimension(f::LogDensityFunction) = length(getparams(f)) + +# This is important for performance -- one needs to provide `ADGradient` with a vector of +# parameters, or DifferentiationInterface will not have sufficient information to e.g. +# compile a rule for Mooncake (because it won't know the type of the input), or pre-allocate +# a tape when using ReverseDiff.jl. +function _make_ad_gradient(ad::ADTypes.AbstractADType, ℓ::LogDensityFunction) + x = map(identity, getparams(ℓ)) # ensure we concretise the elements of the params + return LogDensityProblemsAD.ADgradient(ad, ℓ; x) +end + +function LogDensityProblemsAD.ADgradient(ad::ADTypes.AutoMooncake, f::LogDensityFunction) + return _make_ad_gradient(ad, f) +end +function LogDensityProblemsAD.ADgradient(ad::ADTypes.AutoReverseDiff, f::LogDensityFunction) + return _make_ad_gradient(ad, f) +end diff --git a/test/Project.toml b/test/Project.toml index 686475ebd..0d247c3ec 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -6,6 +6,7 @@ Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" +DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" @@ -17,6 +18,7 @@ LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" @@ -43,6 +45,7 @@ LogDensityProblems = "2" LogDensityProblemsAD = "1.7.0" MCMCChains = "6.0.4" MacroTools = "0.5.6" +Mooncake = "0.4.50" ReverseDiff = "1" StableRNGs = "1" Tracker = "0.2.23" diff --git a/test/ad.jl b/test/ad.jl index 6046cfda4..768a55ad3 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -1,4 +1,4 @@ -@testset "AD: ForwardDiff and ReverseDiff" begin +@testset "AD: ForwardDiff, ReverseDiff, and Mooncake" begin @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS f = DynamicPPL.LogDensityFunction(m) rand_param_values = DynamicPPL.TestUtils.rand_prior_true(m) @@ -17,11 +17,20 @@ θ = convert(Vector{Float64}, varinfo[:]) logp, ref_grad = LogDensityProblems.logdensity_and_gradient(ad_forwarddiff_f, θ) - @testset "ReverseDiff with compile=$compile" for compile in (false, true) - adtype = ADTypes.AutoReverseDiff(; compile=compile) - ad_f = LogDensityProblemsAD.ADgradient(adtype, f) - _, grad = LogDensityProblems.logdensity_and_gradient(ad_f, θ) - @test grad ≈ ref_grad + @testset "$adtype" for adtype in [ + ADTypes.AutoReverseDiff(; compile=false), + ADTypes.AutoReverseDiff(; compile=true), + ADTypes.AutoMooncake(; config=nothing), + ] + # Mooncake can't currently handle something that is going on in + # SimpleVarInfo{<:VarNamedVector}. Disable all SimpleVarInfo tests for now. + if adtype isa ADTypes.AutoMooncake && varinfo isa DynamicPPL.SimpleVarInfo + @test_broken 1 == 0 + else + ad_f = LogDensityProblemsAD.ADgradient(adtype, f) + _, grad = LogDensityProblems.logdensity_and_gradient(ad_f, θ) + @test grad ≈ ref_grad + end end end end diff --git a/test/runtests.jl b/test/runtests.jl index a832a0f08..dbfa319b0 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,6 +4,7 @@ using DynamicPPL using AbstractMCMC using AbstractPPL using Bijectors +using DifferentiationInterface using Distributions using DistributionsAD using Documenter @@ -11,6 +12,7 @@ using ForwardDiff using LogDensityProblems, LogDensityProblemsAD using MacroTools using MCMCChains +using Mooncake: Mooncake using Tracker using ReverseDiff using Zygote