Skip to content

Commit

Permalink
Merge branch 'master' into rylin/bayesnet_implementations
Browse files Browse the repository at this point in the history
  • Loading branch information
sunxd3 authored Nov 20, 2024
2 parents 121d412 + 5dedc30 commit 0cff1db
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 1 deletion.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "JuliaBUGS"
uuid = "ba9fb4c0-828e-4473-b6a1-cd2560fee5bf"
version = "0.7.0"
version = "0.7.1"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down
1 change: 1 addition & 0 deletions ext/JuliaBUGSMCMCChainsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ using JuliaBUGS.AbstractPPL
using JuliaBUGS.BUGSPrimitives
using JuliaBUGS.LogDensityProblems
using JuliaBUGS.LogDensityProblemsAD
using JuliaBUGS: Accessors
using AbstractMCMC
using MCMCChains: Chains

Expand Down
32 changes: 32 additions & 0 deletions test/ext/mcmchains.jl
Original file line number Diff line number Diff line change
Expand Up @@ -91,4 +91,36 @@
@test means[:beta].nt.mean[1] 2.1 atol = 0.2
@test means[:sigma].nt.mean[1] 0.9 atol = 0.2
@test means[:gen_quant].nt.mean[1] 4.2 atol = 0.2

# test for more complicated varnames
model_def = @bugs begin
A[1, 1:3] ~ Dirichlet(ones(3))
A[2, 1:3] ~ Dirichlet(ones(3))
A[3, 1:3] ~ Dirichlet(ones(3))

mu[1:3] ~ MvNormal(zeros(3), 10 * Diagonal(ones(3)))
sigma[1] ~ InverseGamma(2, 3)
sigma[2] ~ InverseGamma(2, 3)
sigma[3] ~ InverseGamma(2, 3)
end
model = compile(model_def, (;))
ad_model = ADgradient(:ReverseDiff, model; compile=Val(true))
hmc_chain = AbstractMCMC.sample(ad_model, NUTS(0.8), 10; chain_type=Chains)
@test hmc_chain.name_map[:parameters] == [
Symbol("sigma[3]"),
Symbol("sigma[2]"),
Symbol("sigma[1]"),
Symbol("mu[1:3][1]"),
Symbol("mu[1:3][2]"),
Symbol("mu[1:3][3]"),
Symbol("A[3, 1:3][1]"),
Symbol("A[3, 1:3][2]"),
Symbol("A[3, 1:3][3]"),
Symbol("A[2, 1:3][1]"),
Symbol("A[2, 1:3][2]"),
Symbol("A[2, 1:3][3]"),
Symbol("A[1, 1:3][1]"),
Symbol("A[1, 1:3][2]"),
Symbol("A[1, 1:3][3]"),
]
end

0 comments on commit 0cff1db

Please sign in to comment.