diff --git a/Project.toml b/Project.toml index a4ec7fcbd..fd8c62a92 100644 --- a/Project.toml +++ b/Project.toml @@ -30,8 +30,6 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" -Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" -ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [extensions] @@ -39,8 +37,6 @@ DynamicPPLChainRulesCoreExt = ["ChainRulesCore"] DynamicPPLEnzymeCoreExt = ["EnzymeCore"] DynamicPPLForwardDiffExt = ["ForwardDiff"] DynamicPPLMCMCChainsExt = ["MCMCChains"] -DynamicPPLMooncakeExt = ["Mooncake"] -DynamicPPLReverseDiffExt = ["ReverseDiff"] DynamicPPLZygoteRulesExt = ["ZygoteRules"] [compat] @@ -62,10 +58,8 @@ LogDensityProblems = "2" LogDensityProblemsAD = "1.7.0" MCMCChains = "6" MacroTools = "0.5.6" -Mooncake = "0.4.54" OrderedCollections = "1" Random = "1.6" -ReverseDiff = "1" Requires = "1" Test = "1.6" ZygoteRules = "0.2" diff --git a/ext/DynamicPPLMooncakeExt.jl b/ext/DynamicPPLMooncakeExt.jl deleted file mode 100644 index 400da3b20..000000000 --- a/ext/DynamicPPLMooncakeExt.jl +++ /dev/null @@ -1,8 +0,0 @@ -module DynamicPPLMooncakeExt - -import LogDensityProblemsAD: ADgradient -using DynamicPPL: ADTypes, _make_ad_gradient, LogDensityFunction - -ADgradient(ad::ADTypes.AutoMooncake, f::LogDensityFunction) = _make_ad_gradient(ad, f) - -end # module diff --git a/ext/DynamicPPLReverseDiffExt.jl b/ext/DynamicPPLReverseDiffExt.jl deleted file mode 100644 index 3728068ce..000000000 --- a/ext/DynamicPPLReverseDiffExt.jl +++ /dev/null @@ -1,8 +0,0 @@ -module DynamicPPLReverseDiffExt - -import LogDensityProblemsAD: ADgradient -using DynamicPPL: ADTypes, _make_ad_gradient, LogDensityFunction - -ADgradient(ad::ADTypes.AutoReverseDiff, f::LogDensityFunction) = _make_ad_gradient(ad, f) - -end # module diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index d47c6ccd4..214369ab0 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -153,3 +153,10 @@ function _make_ad_gradient(ad::ADTypes.AbstractADType, ℓ::LogDensityFunction) x = map(identity, getparams(ℓ)) # ensure we concretise the elements of the params return LogDensityProblemsAD.ADgradient(ad, ℓ; x) end + +function LogDensityProblemsAD.ADgradient(ad::ADTypes.AutoMooncake, f::LogDensityFunction) + return _make_ad_gradient(ad, f) +end +function LogDensityProblemsAD.ADgradient(ad::ADTypes.AutoReverseDiff, f::LogDensityFunction) + return _make_ad_gradient(ad, f) +end