From 798bb21644d4a3e80e13947268620018fd913f65 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Tue, 15 Oct 2024 08:43:22 +0100 Subject: [PATCH] updates, Gibbs still need updating --- Project.toml | 1 + ext/JuliaBUGSAdvancedHMCExt.jl | 18 ++-- src/JuliaBUGS.jl | 1 + src/gibbs.jl | 153 +++++++++++++++++++++++---------- src/model.jl | 30 ++----- src/utils.jl | 10 ++- test/utils.jl | 16 +++- 7 files changed, 146 insertions(+), 83 deletions(-) diff --git a/Project.toml b/Project.toml index bd7c6ec7b..ec2c92b4d 100644 --- a/Project.toml +++ b/Project.toml @@ -3,6 +3,7 @@ uuid = "ba9fb4c0-828e-4473-b6a1-cd2560fee5bf" version = "0.6.2" [deps] +ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" diff --git a/ext/JuliaBUGSAdvancedHMCExt.jl b/ext/JuliaBUGSAdvancedHMCExt.jl index 0d3503949..61074120d 100644 --- a/ext/JuliaBUGSAdvancedHMCExt.jl +++ b/ext/JuliaBUGSAdvancedHMCExt.jl @@ -2,17 +2,13 @@ module JuliaBUGSAdvancedHMCExt using AbstractMCMC using AdvancedHMC -using AdvancedHMC: Transition, stat using JuliaBUGS -using JuliaBUGS: - AbstractBUGSModel, BUGSModel, Gibbs, find_generated_vars, LogDensityContext, evaluate!! -using JuliaBUGS.BUGSPrimitives -using JuliaBUGS.BangBang -using JuliaBUGS.LogDensityProblems -using JuliaBUGS.LogDensityProblemsAD -using JuliaBUGS.Bijectors -using JuliaBUGS.Random using MCMCChains: Chains + +using AdvancedHMC: Transition, stat +using JuliaBUGS: AbstractBUGSModel, BUGSModel, Gibbs, find_generated_vars, LogDensityContext, evaluate!! +using JuliaBUGS: BUGSPrimitives, BangBang, LogDensityProblems, LogDensityProblemsAD, Bijectors, Random + import JuliaBUGS: gibbs_internal function AbstractMCMC.bundle_samples( @@ -43,10 +39,10 @@ function AbstractMCMC.bundle_samples( end function JuliaBUGS.gibbs_internal( - rng::Random.AbstractRNG, cond_model::BUGSModel, sampler::HMC + rng::Random.AbstractRNG, cond_model::BUGSModel, sampler::HMC, adtype::ADTypes.AbstractADType ) logdensitymodel = AbstractMCMC.LogDensityModel( - LogDensityProblemsAD.ADgradient(:ReverseDiff, cond_model) + LogDensityProblemsAD.ADgradient(adtype, cond_model) ) t, s = AbstractMCMC.step( rng, diff --git a/src/JuliaBUGS.jl b/src/JuliaBUGS.jl index 8071b5362..c860afaa5 100644 --- a/src/JuliaBUGS.jl +++ b/src/JuliaBUGS.jl @@ -3,6 +3,7 @@ module JuliaBUGS using AbstractMCMC using AbstractPPL using Accessors +using ADTypes using BangBang using Bijectors: Bijectors using Distributions diff --git a/src/gibbs.jl b/src/gibbs.jl index 32a5be305..3f97fb5ec 100644 --- a/src/gibbs.jl +++ b/src/gibbs.jl @@ -1,80 +1,142 @@ -struct Gibbs{N,S} <: AbstractMCMC.AbstractSampler +struct Gibbs{N,S,ADT<:ADTypes.AbstractADType} <: AbstractMCMC.AbstractSampler sampler_map::OrderedDict{N,S} + adtype::ADT end -function Gibbs(model::BUGSModel, s::AbstractMCMC.AbstractSampler) - return Gibbs(OrderedDict([v => s for v in model.parameters])) +function verify_sampler_map(model::BUGSModel, sampler_map::OrderedDict) + all_variables_in_keys = Set(Iterators.flatten(keys(sampler_map))) + model_parameters = Set(model.parameters) + + # Check for extra variables in sampler_map that are not in model parameters + extra_variables = setdiff(all_variables_in_keys, model_parameters) + if !isempty(extra_variables) + throw( + ArgumentError( + "Sampler map contains variables not in the model: $extra_variables" + ), + ) + end + + # Check for model parameters not covered by sampler_map + left_over_variables = setdiff(model_parameters, all_variables_in_keys) + if !isempty(left_over_variables) + throw( + ArgumentError( + "Some model parameters are not covered by the sampler map: $left_over_variables", + ), + ) + end + + return true end -abstract type AbstractGibbsState end +""" + _create_submodel_for_gibbs_sampling(model::BUGSModel, variables_to_update::Vector{<:VarName}) -# do the most basic thinkings right now -# - one `evaluation_env` throughout, +Internal function to create a conditioned model for Gibbs sampling. This is different from conditioning, because conditioning +only marks a model parameter as observation, while the function effectively creates a sub-model with only the variables in the +Markov blanket of the variables that are being updated. +""" +_create_submodel_for_gibbs_sampling(model::BUGSModel, variables_to_update::VarName) = + _create_submodel_for_gibbs_sampling(model, [variables_to_update]) +function _create_submodel_for_gibbs_sampling( + model::BUGSModel, variables_to_update::NTuple{N,<:VarName} +) where {N} + return _create_submodel_for_gibbs_sampling(model, collect(variables_to_update)) +end +function _create_submodel_for_gibbs_sampling( + model::BUGSModel, variables_to_update::Vector{<:VarName} +) + markov_blanket = markov_blanket(model.g, variables_to_update) + mb_without_variables_to_update = setdiff(markov_blanket, variables_to_update) + random_variables_in_mb = filter( + Base.Fix1(is_stochastic, model.g), mb_without_variables_to_update + ) + observed_random_variables_in_mb = filter( + Base.Fix1(is_observation, model.g), random_variables_in_mb + ) + model_parameters_in_mb = setdiff( + random_variables_in_mb, observed_random_variables_in_mb + ) + sub_model = BUGSModel( + model; parameters=variables_to_update, sorted_nodes=markov_blanket + ) + return condition(sub_model, model_parameters_in_mb) +end -struct GibbsState{T,S,C} <: AbstractGibbsState +struct GibbsState{T,S,C} evaluation_env::T - conditioning_schedule::S - sorted_nodes_cache::C + sub_model_cache::C + sub_states::S end -ensure_vector(x) = x isa Union{Number,VarName} ? [x] : x +function gibbs_internal end function AbstractMCMC.step( rng::Random.AbstractRNG, - l_model::AbstractMCMC.LogDensityModel{<:BUGSModel}, - sampler::Gibbs{N,S}; - model=l_model.logdensity, + logdensitymodel::AbstractMCMC.LogDensityModel{<:BUGSModel}, + sampler::Gibbs; + model=logdensitymodel.logdensity, kwargs..., -) where {N,S} - sorted_nodes_cache, conditioning_schedule = OrderedDict(), OrderedDict() - for variable_group in keys(sampler.sampler_map) - variable_to_condition_on = setdiff(model.parameters, ensure_vector(variable_group)) - conditioning_schedule[variable_to_condition_on] = sampler.sampler_map[variable_group] - conditioned_model = AbstractPPL.condition( - model, variable_to_condition_on, model.evaluation_env +) + verify_sampler_map(model, sampler.sampler_map) + + submodel_cache = Vector{BUGSModel}(undef, length(sampler.sampler_map)) + sub_states = Any[] + for (i, variable_group) in enumerate(keys(sampler.sampler_map)) + submodel = _create_submodel_for_gibbs_sampling(model, variable_group) + sublogdensitymodel = AbstractMCMC.LogDensityModel( + LogDensityProblemsAD.ADgradient(sampler.adtype, submodel) ) - sorted_nodes_cache[variable_to_condition_on] = conditioned_model.sorted_nodes + _, s = AbstractMCMC.step( + rng, sublogdensitymodel, sampler.sampler_map[variable_group] + ) + submodel_cache[i] = s + push!(sub_states, s) end - param_values = JuliaBUGS.getparams(model) - return param_values, GibbsState(param_values, conditioning_schedule, sorted_nodes_cache) + + return getparams(model), + GibbsState(model.evaluation_env, submodel_cache, map(identity, sub_states)) end function AbstractMCMC.step( rng::Random.AbstractRNG, - l_model::AbstractMCMC.LogDensityModel{<:BUGSModel}, + logdensitymodel::AbstractMCMC.LogDensityModel{<:BUGSModel}, sampler::Gibbs, - state::AbstractGibbsState; - model=l_model.logdensity, + state::GibbsState; + model=logdensitymodel.logdensity, kwargs..., ) - param_values = state.values - for vs in keys(state.conditioning_schedule) - model = initialize!(model, param_values) - cond_model = AbstractPPL.condition( - model, vs, model.evaluation_env, state.sorted_nodes_cache[vs] - ) - param_values = gibbs_internal(rng, cond_model, state.conditioning_schedule[vs]) + evaluation_env = state.evaluation_env + for (i, vs) in enumerate(keys(sampler.sampler_map)) + sub_model = BangBang.setproperty!!(model, :evaluation_env, evaluation_env) + evaluation_env = gibbs_internal(rng, sub_model, sampler.sampler_map[vs]) end - return param_values, - GibbsState(param_values, state.conditioning_schedule, state.sorted_nodes_cache) + return getparams(model), GibbsState(evaluation_env, state.sub_model_cache) end -function gibbs_internal end - struct MHFromPrior <: AbstractMCMC.AbstractSampler end -function gibbs_internal(rng::Random.AbstractRNG, cond_model::BUGSModel, ::MHFromPrior) - transformed_original = JuliaBUGS.getparams(cond_model) - values, logp = evaluate!!(cond_model, LogDensityContext(), transformed_original) - values_proposed, logp_proposed = evaluate!!(cond_model, SamplingContext()) +struct MHState{T} + evaluation_env::T + logp::Float64 +end + +function gibbs_internal( + rng::Random.AbstractRNG, + sub_model::BUGSModel, + ::MHFromPrior, + state::MHState, + adtype::ADTypes.AbstractADType, +) + evaluation_env, logp = evaluate!!(sub_model, DefaultContext()) + proposed_evaluation_env, logp_proposed = evaluate!!(sub_model, SamplingContext()) if logp_proposed - logp > log(rand(rng)) - values = values_proposed + evaluation_env = proposed_evaluation_env end - return JuliaBUGS.getparams( - BangBang.setproperty!!(cond_model.base_model, :evaluation_env, values) - ) + return MHState(evaluation_env, logp) end function AbstractMCMC.bundle_samples( @@ -90,4 +152,3 @@ function AbstractMCMC.bundle_samples( logdensitymodel, ts, [], []; discard_initial=discard_initial, kwargs... ) end - diff --git a/src/model.jl b/src/model.jl index 516d112d5..b1f0df08b 100644 --- a/src/model.jl +++ b/src/model.jl @@ -155,7 +155,7 @@ function BUGSModel( end function BUGSModel( - model::BUGSModel, + model::BUGSModel; parameters::Vector{<:VarName}, sorted_nodes::Vector{<:VarName}, evaluation_env::NamedTuple=model.evaluation_env, @@ -277,25 +277,11 @@ function settrans(model::BUGSModel, bool::Bool=!(model.transformed)) return BangBang.setproperty!!(model, :transformed, bool) end -""" - condition_on_complement(model::BUGSModel, variables_to_update::Vector{<:VarName}) - -Internal function to create a conditioned model for Gibbs sampling. This is different from conditioning, because conditioning -only marks a model parameter as observation, while the function effectively creates a sub-model with only the variables in the -Markov blanket of the variables that are being updated. - -Precondition: `variables_to_update` must be a subset of the model parameters. -""" -function condition_on_complement( - model::BUGSModel, variables_to_update::Vector{<:VarName} +function create_sub_model( + model::BUGSModel, + model_parameters_in_submodel::Vector{<:VarName}, + all_variables_in_submodel::Vector{<:VarName}, ) - markov_blanket = markov_blanket(model.g, variables_to_update) - sub_model = create_sub_model(model, variables_to_update, markov_blanket) - random_variables_to_condition_on = - return condition(sub_model, setdiff(variables_to_update)) -end - -function create_sub_model(model::BUGSModel, model_parameters_in_submodel::Vector{<:VarName}, all_variables_in_submodel::Vector{<:VarName}) return BUGSModel(model, model_parameters_in_submodel, all_variables_in_submodel) end @@ -387,12 +373,12 @@ function AbstractPPL.evaluate!!(model::BUGSModel, ctx::SamplingContext) args = prepare_arg_values(node_args, evaluation_env, loop_vars) if !is_stochastic value = node_function(; args...) - evaluation_env = setindex!!(evaluation_env, value, vn) + evaluation_env = setindex!!(evaluation_env, value, vn; prefer_mutation=false) else dist = node_function(; args...) - value = rand(ctx.rng, dist) # just sample from the prior + value = rand(ctx.rng, dist) logp += logpdf(dist, value) - evaluation_env = setindex!!(evaluation_env, value, vn) + evaluation_env = setindex!!(evaluation_env, value, vn; prefer_mutation=false) end end return evaluation_env, logp diff --git a/src/utils.jl b/src/utils.jl index 9e35ddc26..a98bbfbf2 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -553,10 +553,14 @@ function BangBang.NoBang._setindex(xs::AbstractArray, v::AbstractArray, I...) return ys end -function BangBang.setindex!!(nt::NamedTuple, val, vn::VarName{sym}) where {sym} - optic = BangBang.prefermutation( +function BangBang.setindex!!( + nt::NamedTuple, val, vn::VarName{sym}; prefer_mutation::Bool=true +) where {sym} + optic = if prefer_mutation + BangBang.prefermutation(AbstractPPL.getoptic(vn) ∘ Accessors.PropertyLens{sym}()) + else AbstractPPL.getoptic(vn) ∘ Accessors.PropertyLens{sym}() - ) + end return Accessors.set(nt, optic, val) end diff --git a/test/utils.jl b/test/utils.jl index 2f9b55681..eb554570b 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -9,7 +9,7 @@ using JuliaBUGS: CompilerUtils end loop_var, lb, ub, body = JuliaBUGS.decompose_for_expr(ex) - + @test loop_var == :i @test lb == 1 @test ub == 3 @@ -20,3 +20,17 @@ using JuliaBUGS: CompilerUtils end end end + +@testset "BangBang.setindex!!" begin + nt = (a=1, b=[1, 2, 3], c=[1, 2, 3]) + nt1 = BangBang.setindex!!(nt, 2, @varname(a)) + @test nt1.a == 2 + + nt2 = BangBang.setindex!!(nt, 5, @varname(b[1])) + @test nt2.b == [5, 2, 3] + @test nt2.b === nt.b # mutation + + nt3 = BangBang.setindex!!(nt, 2, @varname(c[1]); prefer_mutation=false) + @test nt3.c == [2, 2, 3] + @test nt3.c !== nt.c # no mutation +end