diff --git a/ext/JuliaBUGSMCMCChainsExt.jl b/ext/JuliaBUGSMCMCChainsExt.jl index fcb0e029b..2b9628226 100644 --- a/ext/JuliaBUGSMCMCChainsExt.jl +++ b/ext/JuliaBUGSMCMCChainsExt.jl @@ -3,6 +3,7 @@ module JuliaBUGSMCMCChainsExt using JuliaBUGS using JuliaBUGS: AbstractBUGSModel, find_generated_quantities_variables, LogDensityContext, evaluate!! +using JuliaBUGS: AbstractMCMC using JuliaBUGS.AbstractPPL using JuliaBUGS.BUGSPrimitives using JuliaBUGS.LogDensityProblems @@ -11,6 +12,20 @@ using DynamicPPL using AbstractMCMC using MCMCChains: Chains +function AbstractMCMC.bundle_samples( + ts, + logdensitymodel::AbstractMCMC.LogDensityModel{<:JuliaBUGS.BUGSModel}, + sampler::JuliaBUGS.Gibbs, + state, + ::Type{Chains}; + discard_initial=0, + kwargs..., +) + return JuliaBUGS.gen_chains( + logdensitymodel, ts, [], []; discard_initial=discard_initial, kwargs... + ) +end + function JuliaBUGS.gen_chains( model::AbstractMCMC.LogDensityModel{<:JuliaBUGS.BUGSModel}, samples, diff --git a/src/gibbs.jl b/src/gibbs.jl index bcc0a0abd..70f85934a 100644 --- a/src/gibbs.jl +++ b/src/gibbs.jl @@ -41,8 +41,9 @@ Internal function to create a conditioned model for Gibbs sampling. This is diff only marks a model parameter as observation, while the function effectively creates a sub-model with only the variables in the Markov blanket of the variables that are being updated. """ -_create_submodel_for_gibbs_sampling(model::BUGSModel, variables_to_update::VarName) = - _create_submodel_for_gibbs_sampling(model, [variables_to_update]) +function _create_submodel_for_gibbs_sampling(model::BUGSModel, variables_to_update::VarName) + return _create_submodel_for_gibbs_sampling(model, [variables_to_update]) +end function _create_submodel_for_gibbs_sampling( model::BUGSModel, variables_to_update::NTuple{N,<:VarName} ) where {N} @@ -105,7 +106,7 @@ function AbstractMCMC.step( push!(sub_states, state) end - return getparams(model), + return getparams(settrans(model, false)), GibbsState(model.evaluation_env, submodel_cache, map(identity, sub_states)) end @@ -119,15 +120,16 @@ function AbstractMCMC.step( ) evaluation_env = state.evaluation_env for (i, vs) in enumerate(keys(sampler.sampler_map)) - sub_model = state.sub_model_cache[i] - sub_model = BangBang.setproperty!!(sub_model, :evaluation_env, evaluation_env) + sub_model = BangBang.setproperty!!( + state.sub_model_cache[i], :evaluation_env, evaluation_env + ) evaluation_env, new_sub_state = gibbs_internal( rng, sub_model, sampler.sampler_map[vs], state.sub_states[i], sampler.adtype ) state.sub_states[i] = new_sub_state end model = BangBang.setproperty!!(model, :evaluation_env, evaluation_env) - return getparams(model), + return getparams(settrans(model, false)), GibbsState(evaluation_env, state.sub_model_cache, state.sub_states) end @@ -156,16 +158,3 @@ function gibbs_internal( return evaluation_env, MHState(evaluation_env, logp) end -function AbstractMCMC.bundle_samples( - ts, - logdensitymodel::AbstractMCMC.LogDensityModel{<:JuliaBUGS.BUGSModel}, - sampler::Gibbs, - state, - ::Type{T}; - discard_initial=0, - kwargs..., -) where {T} - return JuliaBUGS.gen_chains( - logdensitymodel, ts, [], []; discard_initial=discard_initial, kwargs... - ) -end diff --git a/src/model.jl b/src/model.jl index f03b1217b..f3380e1f5 100644 --- a/src/model.jl +++ b/src/model.jl @@ -181,34 +181,29 @@ Initialize the model with a NamedTuple of initial values, the values are expecte """ function initialize!(model::BUGSModel, initial_params::NamedTuple) check_input(initial_params) + evaluation_env = model.evaluation_env for vn in model.sorted_nodes - (; is_stochastic, is_observed, node_function, node_args, loop_vars) = model.g[vn] + (; node_function, node_args, loop_vars) = model.g[vn] args = prepare_arg_values(node_args, model.evaluation_env, loop_vars) - if !is_stochastic + if is_deterministic(model.g, vn) value = Base.invokelatest(node_function; args...) - BangBang.@set!! model.evaluation_env = setindex!!( - model.evaluation_env, value, vn - ) - elseif !is_observed + evaluation_env = BangBang.setindex!!(evaluation_env, value, vn) + elseif is_model_parameter(model.g, vn) initialization = try AbstractPPL.get(initial_params, vn) catch _ missing end if !ismissing(initialization) - BangBang.@set!! model.evaluation_env = setindex!!( - model.evaluation_env, initialization, vn - ) + evaluation_env = BangBang.setindex!!(evaluation_env, initialization, vn) else - BangBang.@set!! model.evaluation_env = setindex!!( - model.evaluation_env, - rand(Base.invokelatest(node_function; args...)), - vn, + evaluation_env = BangBang.setindex!!( + evaluation_env, rand(Base.invokelatest(node_function; args...)), vn ) end end end - return model + return BangBang.setproperty!!(model, :evaluation_env, evaluation_env) end """ @@ -312,9 +307,9 @@ function AbstractPPL.condition( elseif model.g[vn].is_observed @warn "$vn is already an observed variable, conditioning on it won't have any effect" else - old_node_info = model.g[vn] - new_node_info = BangBang.setproperty!!(old_node_info, :is_observed, true) - model.g[vn] = new_node_info + new_g = copy(model.g) + new_g[vn] = BangBang.setproperty!!(model.g[vn], :is_observed, true) + model = BangBang.setproperty!!(model, :g, new_g) end end return model @@ -331,9 +326,9 @@ function AbstractPPL.decondition(model::BUGSModel, var_group::Vector{<:VarName}) elseif !model.g[vn].is_observed @warn "$vn is already treated as model parameter, deconditioning it won't have any effect" else - BangBang.@set!! model.g[vn] = BangBang.setproperty!!( - model.g[vn], :is_observed, false - ) + new_g = copy(model.g) + new_g[vn] = BangBang.setproperty!!(model.g[vn], :is_observed, false) + model = BangBang.setproperty!!(model, :g, new_g) end end return model @@ -366,19 +361,31 @@ function AbstractPPL.evaluate!!(model::BUGSModel, rng::Random.AbstractRNG) return evaluate!!(model, SamplingContext(rng)) end function AbstractPPL.evaluate!!(model::BUGSModel, ctx::SamplingContext) - (; evaluation_env, g, sorted_nodes) = model + evaluation_env = deepcopy(model.evaluation_env) # TODO: a lot of the arrays are not modified logp = 0.0 - for vn in sorted_nodes - (; is_stochastic, node_function, node_args, loop_vars) = g[vn] + for vn in model.sorted_nodes + (; node_function, node_args, loop_vars) = model.g[vn] args = prepare_arg_values(node_args, evaluation_env, loop_vars) - if !is_stochastic + if !is_stochastic(model.g, vn) value = node_function(; args...) - evaluation_env = setindex!!(evaluation_env, value, vn; prefer_mutation=false) + evaluation_env = BangBang.setindex!!(evaluation_env, value, vn) else dist = node_function(; args...) - value = rand(ctx.rng, dist) - logp += logpdf(dist, value) - evaluation_env = setindex!!(evaluation_env, value, vn; prefer_mutation=false) + if is_observation(model.g, vn) + value = AbstractPPL.get(evaluation_env, vn) + else + value = rand(ctx.rng, dist) + evaluation_env = BangBang.setindex!!(evaluation_env, value, vn) + end + if model.transformed + value_transformed = Bijectors.transform(Bijectors.bijector(dist), value) + logp += + Distributions.logpdf(dist, value) + Bijectors.logabsdetjac( + Bijectors.inverse(Bijectors.bijector(dist)), value_transformed + ) + else + logp += Distributions.logpdf(dist, value) + end end end return evaluation_env, logp diff --git a/src/utils.jl b/src/utils.jl index a98bbfbf2..9e35ddc26 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -553,14 +553,10 @@ function BangBang.NoBang._setindex(xs::AbstractArray, v::AbstractArray, I...) return ys end -function BangBang.setindex!!( - nt::NamedTuple, val, vn::VarName{sym}; prefer_mutation::Bool=true -) where {sym} - optic = if prefer_mutation - BangBang.prefermutation(AbstractPPL.getoptic(vn) ∘ Accessors.PropertyLens{sym}()) - else +function BangBang.setindex!!(nt::NamedTuple, val, vn::VarName{sym}) where {sym} + optic = BangBang.prefermutation( AbstractPPL.getoptic(vn) ∘ Accessors.PropertyLens{sym}() - end + ) return Accessors.set(nt, optic, val) end diff --git a/test/gibbs.jl b/test/gibbs.jl index 9a6adff58..c8cb5da53 100644 --- a/test/gibbs.jl +++ b/test/gibbs.jl @@ -1,6 +1,51 @@ -using JuliaBUGS: MHFromPrior, Gibbs +using JuliaBUGS: MHFromPrior, Gibbs, OrderedDict -@testset "Simple gibbs" begin +model_def = @bugs begin + μ ~ Normal(0, 4) + σ ~ Gamma(1, 1) + for i in 1:100 + y[i] ~ Normal(μ, σ) + end +end + +μ_true = 2 +σ_true = 4 + +y = rand(Normal(μ_true, σ_true), 100) + +model = compile(model_def, (;y=y)) +model = initialize!(model, (μ=4.0, σ=6.0)) + +splr_map = OrderedDict(@varname(μ) => MHFromPrior(), @varname(σ) => MHFromPrior()) +splr = Gibbs(splr_map) + +p_s, st_init = AbstractMCMC.step( + Random.default_rng(), + AbstractMCMC.LogDensityModel(model), + splr, +) + +p_s, st = AbstractMCMC.step( + Random.default_rng(), + AbstractMCMC.LogDensityModel(model), + splr, + st_init, +) + +chn = AbstractMCMC.sample( + Random.default_rng(), + model, + splr, + 1000 +) + +σ_samples = [v[1] for v in chn[300:end]] +μ_samples = [v[2] for v in chn[300:end]] + +@test mean(μ_samples) ≈ μ_true rtol = 0.2 +@test mean(σ_samples) ≈ σ_true rtol = 0.2 + +# @testset "Simple gibbs" begin model_def = @bugs begin # Likelihood for i in 1:N @@ -29,11 +74,19 @@ using JuliaBUGS: MHFromPrior, Gibbs model = compile(model_def, data, (;)) + sampler = Gibbs( + OrderedDict( + @varname(alpha) => MHFromPrior(), + @varname(beta) => MHFromPrior(), + @varname(sigma) => MHFromPrior(), + ), + ) + # single step p_s, st_init = AbstractMCMC.step( Random.default_rng(), AbstractMCMC.LogDensityModel(model), - Gibbs(model, MHFromPrior()), + sampler, ) # following step