Skip to content

Commit

Permalink
Updates
Browse files Browse the repository at this point in the history
  • Loading branch information
willtebbutt committed Dec 4, 2024
1 parent 21c2a0a commit 73fbf34
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 22 deletions.
6 changes: 0 additions & 6 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,13 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[extensions]
DynamicPPLChainRulesCoreExt = ["ChainRulesCore"]
DynamicPPLEnzymeCoreExt = ["EnzymeCore"]
DynamicPPLForwardDiffExt = ["ForwardDiff"]
DynamicPPLMCMCChainsExt = ["MCMCChains"]
DynamicPPLMooncakeExt = ["Mooncake"]
DynamicPPLReverseDiffExt = ["ReverseDiff"]
DynamicPPLZygoteRulesExt = ["ZygoteRules"]

[compat]
Expand All @@ -62,10 +58,8 @@ LogDensityProblems = "2"
LogDensityProblemsAD = "1.7.0"
MCMCChains = "6"
MacroTools = "0.5.6"
Mooncake = "0.4.54"
OrderedCollections = "1"
Random = "1.6"
ReverseDiff = "1"
Requires = "1"
Test = "1.6"
ZygoteRules = "0.2"
Expand Down
8 changes: 0 additions & 8 deletions ext/DynamicPPLMooncakeExt.jl

This file was deleted.

8 changes: 0 additions & 8 deletions ext/DynamicPPLReverseDiffExt.jl

This file was deleted.

7 changes: 7 additions & 0 deletions src/logdensityfunction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -153,3 +153,10 @@ function _make_ad_gradient(ad::ADTypes.AbstractADType, ℓ::LogDensityFunction)
x = map(identity, getparams(ℓ)) # ensure we concretise the elements of the params
return LogDensityProblemsAD.ADgradient(ad, ℓ; x)
end

function LogDensityProblemsAD.ADgradient(ad::ADTypes.AutoMooncake, f::LogDensityFunction)
return _make_ad_gradient(ad, f)
end
function LogDensityProblemsAD.ADgradient(ad::ADTypes.AutoReverseDiff, f::LogDensityFunction)
return _make_ad_gradient(ad, f)
end

0 comments on commit 73fbf34

Please sign in to comment.