Skip to content

Commit

Permalink
fix errors
Browse files Browse the repository at this point in the history
  • Loading branch information
sunxd3 committed Oct 20, 2024
1 parent 57576b0 commit f3707e3
Show file tree
Hide file tree
Showing 5 changed files with 117 additions and 57 deletions.
15 changes: 15 additions & 0 deletions ext/JuliaBUGSMCMCChainsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module JuliaBUGSMCMCChainsExt
using JuliaBUGS
using JuliaBUGS:
AbstractBUGSModel, find_generated_quantities_variables, LogDensityContext, evaluate!!
using JuliaBUGS: AbstractMCMC
using JuliaBUGS.AbstractPPL
using JuliaBUGS.BUGSPrimitives
using JuliaBUGS.LogDensityProblems
Expand All @@ -11,6 +12,20 @@ using DynamicPPL
using AbstractMCMC
using MCMCChains: Chains

"""
    AbstractMCMC.bundle_samples(transitions, logdensitymodel, sampler::JuliaBUGS.Gibbs,
                                state, ::Type{Chains}; discard_initial=0, kwargs...)

Assemble raw Gibbs transitions into an `MCMCChains.Chains` object by delegating
to `JuliaBUGS.gen_chains`. The two empty vectors are passed through as-is
(their role is defined by `gen_chains`; presumably extra variable-name lists —
confirm against its signature). `discard_initial` is forwarded so burn-in
draws can be dropped.
"""
function AbstractMCMC.bundle_samples(
    transitions,
    logdensitymodel::AbstractMCMC.LogDensityModel{<:JuliaBUGS.BUGSModel},
    sampler::JuliaBUGS.Gibbs,
    state,
    ::Type{Chains};
    discard_initial=0,
    kwargs...,
)
    return JuliaBUGS.gen_chains(
        logdensitymodel,
        transitions,
        [],
        [];
        discard_initial=discard_initial,
        kwargs...,
    )
end

function JuliaBUGS.gen_chains(
model::AbstractMCMC.LogDensityModel{<:JuliaBUGS.BUGSModel},
samples,
Expand Down
27 changes: 8 additions & 19 deletions src/gibbs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,9 @@ Internal function to create a conditioned model for Gibbs sampling. This is diff
only marks a model parameter as observation, while the function effectively creates a sub-model with only the variables in the
Markov blanket of the variables that are being updated.
"""
_create_submodel_for_gibbs_sampling(model::BUGSModel, variables_to_update::VarName) =
_create_submodel_for_gibbs_sampling(model, [variables_to_update])
# Single-variable convenience method: wrap the lone `VarName` in a vector and
# defer to the collection-based method.
function _create_submodel_for_gibbs_sampling(model::BUGSModel, variables_to_update::VarName)
    singleton = [variables_to_update]
    return _create_submodel_for_gibbs_sampling(model, singleton)
end
function _create_submodel_for_gibbs_sampling(
model::BUGSModel, variables_to_update::NTuple{N,<:VarName}
) where {N}
Expand Down Expand Up @@ -105,7 +106,7 @@ function AbstractMCMC.step(
push!(sub_states, state)
end

return getparams(model),
return getparams(settrans(model, false)),
GibbsState(model.evaluation_env, submodel_cache, map(identity, sub_states))
end

Expand All @@ -119,15 +120,16 @@ function AbstractMCMC.step(
)
evaluation_env = state.evaluation_env
for (i, vs) in enumerate(keys(sampler.sampler_map))
sub_model = state.sub_model_cache[i]
sub_model = BangBang.setproperty!!(sub_model, :evaluation_env, evaluation_env)
sub_model = BangBang.setproperty!!(
state.sub_model_cache[i], :evaluation_env, evaluation_env
)
evaluation_env, new_sub_state = gibbs_internal(
rng, sub_model, sampler.sampler_map[vs], state.sub_states[i], sampler.adtype
)
state.sub_states[i] = new_sub_state
end
model = BangBang.setproperty!!(model, :evaluation_env, evaluation_env)
return getparams(model),
return getparams(settrans(model, false)),
GibbsState(evaluation_env, state.sub_model_cache, state.sub_states)
end

Expand Down Expand Up @@ -156,16 +158,3 @@ function gibbs_internal(
return evaluation_env, MHState(evaluation_env, logp)
end

# Collect Gibbs transitions into the container type requested by `sample`,
# delegating to `JuliaBUGS.gen_chains` and forwarding `discard_initial`.
# NOTE(review): `::Type{T}` with an unconstrained `where {T}` matches *any*
# requested chain type, not just `MCMCChains.Chains`; confirm this does not
# shadow more specific `bundle_samples` methods elsewhere in the package.
function AbstractMCMC.bundle_samples(
    ts,
    logdensitymodel::AbstractMCMC.LogDensityModel{<:JuliaBUGS.BUGSModel},
    sampler::Gibbs,
    state,
    ::Type{T};
    discard_initial=0,
    kwargs...,
) where {T}
    return JuliaBUGS.gen_chains(
        logdensitymodel, ts, [], []; discard_initial=discard_initial, kwargs...
    )
end
63 changes: 35 additions & 28 deletions src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -181,34 +181,29 @@ Initialize the model with a NamedTuple of initial values, the values are expecte
"""
"""
    initialize!(model::BUGSModel, initial_params::NamedTuple)

Initialize the model's evaluation environment from `initial_params`.

Walks the nodes in topological order: deterministic nodes are recomputed from
their node functions; model parameters are taken from `initial_params` when an
entry for the variable exists, and otherwise drawn from the node's
distribution. Returns the model with the updated evaluation environment.

Fix applied: the extracted block had old and new diff lines interleaved
(duplicate destructurings, duplicate branch conditions, and one missing
`end`), making it syntactically invalid; this is the reconstructed
post-change version.
"""
function initialize!(model::BUGSModel, initial_params::NamedTuple)
    check_input(initial_params)
    evaluation_env = model.evaluation_env
    for vn in model.sorted_nodes
        (; node_function, node_args, loop_vars) = model.g[vn]
        # NOTE(review): arguments are read from `model.evaluation_env`, not the
        # locally rebound `evaluation_env` — this is only correct if the two
        # alias the same underlying arrays after `setindex!!`; confirm.
        args = prepare_arg_values(node_args, model.evaluation_env, loop_vars)
        if is_deterministic(model.g, vn)
            value = Base.invokelatest(node_function; args...)
            evaluation_env = BangBang.setindex!!(evaluation_env, value, vn)
        elseif is_model_parameter(model.g, vn)
            # Use the user-supplied value when present; `AbstractPPL.get`
            # throws when `initial_params` has no entry for `vn`, so a miss is
            # converted to `missing`.
            initialization = try
                AbstractPPL.get(initial_params, vn)
            catch _
                missing
            end
            if !ismissing(initialization)
                evaluation_env = BangBang.setindex!!(evaluation_env, initialization, vn)
            else
                # No initial value supplied: draw one from the node's distribution.
                evaluation_env = BangBang.setindex!!(
                    evaluation_env, rand(Base.invokelatest(node_function; args...)), vn
                )
            end
        end
    end
    return BangBang.setproperty!!(model, :evaluation_env, evaluation_env)
end

"""
Expand Down Expand Up @@ -312,9 +307,9 @@ function AbstractPPL.condition(
elseif model.g[vn].is_observed
@warn "$vn is already an observed variable, conditioning on it won't have any effect"
else
old_node_info = model.g[vn]
new_node_info = BangBang.setproperty!!(old_node_info, :is_observed, true)
model.g[vn] = new_node_info
new_g = copy(model.g)
new_g[vn] = BangBang.setproperty!!(model.g[vn], :is_observed, true)
model = BangBang.setproperty!!(model, :g, new_g)
end
end
return model
Expand All @@ -331,9 +326,9 @@ function AbstractPPL.decondition(model::BUGSModel, var_group::Vector{<:VarName})
elseif !model.g[vn].is_observed
@warn "$vn is already treated as model parameter, deconditioning it won't have any effect"
else
BangBang.@set!! model.g[vn] = BangBang.setproperty!!(
model.g[vn], :is_observed, false
)
new_g = copy(model.g)
new_g[vn] = BangBang.setproperty!!(model.g[vn], :is_observed, false)
model = BangBang.setproperty!!(model, :g, new_g)
end
end
return model
Expand Down Expand Up @@ -366,19 +361,31 @@ function AbstractPPL.evaluate!!(model::BUGSModel, rng::Random.AbstractRNG)
return evaluate!!(model, SamplingContext(rng))
end
function AbstractPPL.evaluate!!(model::BUGSModel, ctx::SamplingContext)
(; evaluation_env, g, sorted_nodes) = model
evaluation_env = deepcopy(model.evaluation_env) # TODO: a lot of the arrays are not modified
logp = 0.0
for vn in sorted_nodes
(; is_stochastic, node_function, node_args, loop_vars) = g[vn]
for vn in model.sorted_nodes
(; node_function, node_args, loop_vars) = model.g[vn]
args = prepare_arg_values(node_args, evaluation_env, loop_vars)
if !is_stochastic
if !is_stochastic(model.g, vn)
value = node_function(; args...)
evaluation_env = setindex!!(evaluation_env, value, vn; prefer_mutation=false)
evaluation_env = BangBang.setindex!!(evaluation_env, value, vn)
else
dist = node_function(; args...)
value = rand(ctx.rng, dist)
logp += logpdf(dist, value)
evaluation_env = setindex!!(evaluation_env, value, vn; prefer_mutation=false)
if is_observation(model.g, vn)
value = AbstractPPL.get(evaluation_env, vn)
else
value = rand(ctx.rng, dist)
evaluation_env = BangBang.setindex!!(evaluation_env, value, vn)
end
if model.transformed
value_transformed = Bijectors.transform(Bijectors.bijector(dist), value)
logp +=
Distributions.logpdf(dist, value) + Bijectors.logabsdetjac(
Bijectors.inverse(Bijectors.bijector(dist)), value_transformed
)
else
logp += Distributions.logpdf(dist, value)
end
end
end
return evaluation_env, logp
Expand Down
10 changes: 3 additions & 7 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -553,14 +553,10 @@ function BangBang.NoBang._setindex(xs::AbstractArray, v::AbstractArray, I...)
return ys
end

"""
    BangBang.setindex!!(nt::NamedTuple, val, vn::VarName{sym}) where {sym}

Set the value addressed by `vn` inside the NamedTuple `nt`, composing the
variable's optic with a property lens for `sym` and applying it via
`Accessors.set`. The optic is wrapped in `BangBang.prefermutation`, so the
update mutates in place when the target supports it and otherwise returns an
updated copy.

Fix applied: the extracted block interleaved the old (`prefer_mutation`
keyword) and new signatures, leaving a stray `end` before the closing `)`
(invalid syntax), and dropped the `∘` composition operator; this is the
reconstructed final version.
"""
function BangBang.setindex!!(nt::NamedTuple, val, vn::VarName{sym}) where {sym}
    optic = BangBang.prefermutation(
        AbstractPPL.getoptic(vn) ∘ Accessors.PropertyLens{sym}()
    )
    return Accessors.set(nt, optic, val)
end

Expand Down
59 changes: 56 additions & 3 deletions test/gibbs.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,51 @@
using JuliaBUGS: MHFromPrior, Gibbs
using JuliaBUGS: MHFromPrior, Gibbs, OrderedDict

@testset "Simple gibbs" begin
model_def = @bugs begin
μ ~ Normal(0, 4)
σ ~ Gamma(1, 1)
for i in 1:100
y[i] ~ Normal(μ, σ)
end
end

μ_true = 2
σ_true = 4

y = rand(Normal(μ_true, σ_true), 100)

model = compile(model_def, (;y=y))
model = initialize!(model, (μ=4.0, σ=6.0))

splr_map = OrderedDict(@varname(μ) => MHFromPrior(), @varname(σ) => MHFromPrior())
splr = Gibbs(splr_map)

p_s, st_init = AbstractMCMC.step(
Random.default_rng(),
AbstractMCMC.LogDensityModel(model),
splr,
)

p_s, st = AbstractMCMC.step(
Random.default_rng(),
AbstractMCMC.LogDensityModel(model),
splr,
st_init,
)

chn = AbstractMCMC.sample(
Random.default_rng(),
model,
splr,
1000
)

σ_samples = [v[1] for v in chn[300:end]]
μ_samples = [v[2] for v in chn[300:end]]

@test mean(μ_samples) μ_true rtol = 0.2
@test mean(σ_samples) σ_true rtol = 0.2

# @testset "Simple gibbs" begin
model_def = @bugs begin
# Likelihood
for i in 1:N
Expand Down Expand Up @@ -29,11 +74,19 @@ using JuliaBUGS: MHFromPrior, Gibbs

model = compile(model_def, data, (;))

sampler = Gibbs(
OrderedDict(
@varname(alpha) => MHFromPrior(),
@varname(beta) => MHFromPrior(),
@varname(sigma) => MHFromPrior(),
),
)

# single step
p_s, st_init = AbstractMCMC.step(
Random.default_rng(),
AbstractMCMC.LogDensityModel(model),
Gibbs(model, MHFromPrior()),
sampler,
)

# following step
Expand Down

0 comments on commit f3707e3

Please sign in to comment.