Skip to content

Commit

Permalink
updates, Gibbs still need updating
Browse files Browse the repository at this point in the history
  • Loading branch information
sunxd3 committed Oct 15, 2024
1 parent 0b9b160 commit 798bb21
Show file tree
Hide file tree
Showing 7 changed files with 146 additions and 83 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ uuid = "ba9fb4c0-828e-4473-b6a1-cd2560fee5bf"
version = "0.6.2"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
Expand Down
18 changes: 7 additions & 11 deletions ext/JuliaBUGSAdvancedHMCExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,13 @@ module JuliaBUGSAdvancedHMCExt

using AbstractMCMC
using AdvancedHMC
using AdvancedHMC: Transition, stat
using JuliaBUGS
using JuliaBUGS:
AbstractBUGSModel, BUGSModel, Gibbs, find_generated_vars, LogDensityContext, evaluate!!
using JuliaBUGS.BUGSPrimitives
using JuliaBUGS.BangBang
using JuliaBUGS.LogDensityProblems
using JuliaBUGS.LogDensityProblemsAD
using JuliaBUGS.Bijectors
using JuliaBUGS.Random
using MCMCChains: Chains

using AdvancedHMC: Transition, stat
using JuliaBUGS: AbstractBUGSModel, BUGSModel, Gibbs, find_generated_vars, LogDensityContext, evaluate!!
using JuliaBUGS: BUGSPrimitives, BangBang, LogDensityProblems, LogDensityProblemsAD, Bijectors, Random

import JuliaBUGS: gibbs_internal

function AbstractMCMC.bundle_samples(
Expand Down Expand Up @@ -43,10 +39,10 @@ function AbstractMCMC.bundle_samples(
end

function JuliaBUGS.gibbs_internal(
rng::Random.AbstractRNG, cond_model::BUGSModel, sampler::HMC
rng::Random.AbstractRNG, cond_model::BUGSModel, sampler::HMC, adtype::ADTypes.AbstractADType
)
logdensitymodel = AbstractMCMC.LogDensityModel(
LogDensityProblemsAD.ADgradient(:ReverseDiff, cond_model)
LogDensityProblemsAD.ADgradient(adtype, cond_model)
)
t, s = AbstractMCMC.step(
rng,
Expand Down
1 change: 1 addition & 0 deletions src/JuliaBUGS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module JuliaBUGS
using AbstractMCMC
using AbstractPPL
using Accessors
using ADTypes
using BangBang
using Bijectors: Bijectors
using Distributions
Expand Down
153 changes: 107 additions & 46 deletions src/gibbs.jl
Original file line number Diff line number Diff line change
@@ -1,80 +1,142 @@
struct Gibbs{N,S} <: AbstractMCMC.AbstractSampler
struct Gibbs{N,S,ADT<:ADTypes.AbstractADType} <: AbstractMCMC.AbstractSampler
sampler_map::OrderedDict{N,S}
adtype::ADT
end

function Gibbs(model::BUGSModel, s::AbstractMCMC.AbstractSampler)
return Gibbs(OrderedDict([v => s for v in model.parameters]))
function verify_sampler_map(model::BUGSModel, sampler_map::OrderedDict)
all_variables_in_keys = Set(Iterators.flatten(keys(sampler_map)))
model_parameters = Set(model.parameters)

# Check for extra variables in sampler_map that are not in model parameters
extra_variables = setdiff(all_variables_in_keys, model_parameters)
if !isempty(extra_variables)
throw(
ArgumentError(
"Sampler map contains variables not in the model: $extra_variables"
),
)
end

# Check for model parameters not covered by sampler_map
left_over_variables = setdiff(model_parameters, all_variables_in_keys)
if !isempty(left_over_variables)
throw(
ArgumentError(
"Some model parameters are not covered by the sampler map: $left_over_variables",
),
)
end

return true
end

abstract type AbstractGibbsState end
"""
_create_submodel_for_gibbs_sampling(model::BUGSModel, variables_to_update::Vector{<:VarName})
# do the most basic thinkings right now
# - one `evaluation_env` throughout,
Internal function to create a conditioned model for Gibbs sampling. This is different from conditioning, because conditioning
only marks a model parameter as observation, while the function effectively creates a sub-model with only the variables in the
Markov blanket of the variables that are being updated.
"""
_create_submodel_for_gibbs_sampling(model::BUGSModel, variables_to_update::VarName) =
_create_submodel_for_gibbs_sampling(model, [variables_to_update])
function _create_submodel_for_gibbs_sampling(
model::BUGSModel, variables_to_update::NTuple{N,<:VarName}
) where {N}
return _create_submodel_for_gibbs_sampling(model, collect(variables_to_update))
end
function _create_submodel_for_gibbs_sampling(
model::BUGSModel, variables_to_update::Vector{<:VarName}
)
markov_blanket = markov_blanket(model.g, variables_to_update)
mb_without_variables_to_update = setdiff(markov_blanket, variables_to_update)
random_variables_in_mb = filter(
Base.Fix1(is_stochastic, model.g), mb_without_variables_to_update
)
observed_random_variables_in_mb = filter(
Base.Fix1(is_observation, model.g), random_variables_in_mb
)
model_parameters_in_mb = setdiff(
random_variables_in_mb, observed_random_variables_in_mb
)
sub_model = BUGSModel(
model; parameters=variables_to_update, sorted_nodes=markov_blanket
)
return condition(sub_model, model_parameters_in_mb)
end

struct GibbsState{T,S,C} <: AbstractGibbsState
struct GibbsState{T,S,C}
evaluation_env::T
conditioning_schedule::S
sorted_nodes_cache::C
sub_model_cache::C
sub_states::S
end

ensure_vector(x) = x isa Union{Number,VarName} ? [x] : x
function gibbs_internal end

function AbstractMCMC.step(
rng::Random.AbstractRNG,
l_model::AbstractMCMC.LogDensityModel{<:BUGSModel},
sampler::Gibbs{N,S};
model=l_model.logdensity,
logdensitymodel::AbstractMCMC.LogDensityModel{<:BUGSModel},
sampler::Gibbs;
model=logdensitymodel.logdensity,
kwargs...,
) where {N,S}
sorted_nodes_cache, conditioning_schedule = OrderedDict(), OrderedDict()
for variable_group in keys(sampler.sampler_map)
variable_to_condition_on = setdiff(model.parameters, ensure_vector(variable_group))
conditioning_schedule[variable_to_condition_on] = sampler.sampler_map[variable_group]
conditioned_model = AbstractPPL.condition(
model, variable_to_condition_on, model.evaluation_env
)
verify_sampler_map(model, sampler.sampler_map)

submodel_cache = Vector{BUGSModel}(undef, length(sampler.sampler_map))
sub_states = Any[]
for (i, variable_group) in enumerate(keys(sampler.sampler_map))
submodel = _create_submodel_for_gibbs_sampling(model, variable_group)
sublogdensitymodel = AbstractMCMC.LogDensityModel(
LogDensityProblemsAD.ADgradient(sampler.adtype, submodel)
)
sorted_nodes_cache[variable_to_condition_on] = conditioned_model.sorted_nodes
_, s = AbstractMCMC.step(
rng, sublogdensitymodel, sampler.sampler_map[variable_group]
)
submodel_cache[i] = s
push!(sub_states, s)
end
param_values = JuliaBUGS.getparams(model)
return param_values, GibbsState(param_values, conditioning_schedule, sorted_nodes_cache)

return getparams(model),
GibbsState(model.evaluation_env, submodel_cache, map(identity, sub_states))
end

function AbstractMCMC.step(
rng::Random.AbstractRNG,
l_model::AbstractMCMC.LogDensityModel{<:BUGSModel},
logdensitymodel::AbstractMCMC.LogDensityModel{<:BUGSModel},
sampler::Gibbs,
state::AbstractGibbsState;
model=l_model.logdensity,
state::GibbsState;
model=logdensitymodel.logdensity,
kwargs...,
)
param_values = state.values
for vs in keys(state.conditioning_schedule)
model = initialize!(model, param_values)
cond_model = AbstractPPL.condition(
model, vs, model.evaluation_env, state.sorted_nodes_cache[vs]
)
param_values = gibbs_internal(rng, cond_model, state.conditioning_schedule[vs])
evaluation_env = state.evaluation_env
for (i, vs) in enumerate(keys(sampler.sampler_map))
sub_model = BangBang.setproperty!!(model, :evaluation_env, evaluation_env)
evaluation_env = gibbs_internal(rng, sub_model, sampler.sampler_map[vs])
end
return param_values,
GibbsState(param_values, state.conditioning_schedule, state.sorted_nodes_cache)
return getparams(model), GibbsState(evaluation_env, state.sub_model_cache)
end

function gibbs_internal end

struct MHFromPrior <: AbstractMCMC.AbstractSampler end

function gibbs_internal(rng::Random.AbstractRNG, cond_model::BUGSModel, ::MHFromPrior)
transformed_original = JuliaBUGS.getparams(cond_model)
values, logp = evaluate!!(cond_model, LogDensityContext(), transformed_original)
values_proposed, logp_proposed = evaluate!!(cond_model, SamplingContext())
struct MHState{T}
evaluation_env::T
logp::Float64
end

function gibbs_internal(
rng::Random.AbstractRNG,
sub_model::BUGSModel,
::MHFromPrior,
state::MHState,
adtype::ADTypes.AbstractADType,
)
evaluation_env, logp = evaluate!!(sub_model, DefaultContext())
proposed_evaluation_env, logp_proposed = evaluate!!(sub_model, SamplingContext())

if logp_proposed - logp > log(rand(rng))
values = values_proposed
evaluation_env = proposed_evaluation_env
end

return JuliaBUGS.getparams(
BangBang.setproperty!!(cond_model.base_model, :evaluation_env, values)
)
return MHState(evaluation_env, logp)
end

function AbstractMCMC.bundle_samples(
Expand All @@ -90,4 +152,3 @@ function AbstractMCMC.bundle_samples(
logdensitymodel, ts, [], []; discard_initial=discard_initial, kwargs...
)
end

30 changes: 8 additions & 22 deletions src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ function BUGSModel(
end

function BUGSModel(
model::BUGSModel,
model::BUGSModel;
parameters::Vector{<:VarName},
sorted_nodes::Vector{<:VarName},
evaluation_env::NamedTuple=model.evaluation_env,
Expand Down Expand Up @@ -277,25 +277,11 @@ function settrans(model::BUGSModel, bool::Bool=!(model.transformed))
return BangBang.setproperty!!(model, :transformed, bool)
end

"""
condition_on_complement(model::BUGSModel, variables_to_update::Vector{<:VarName})
Internal function to create a conditioned model for Gibbs sampling. This is different from conditioning, because conditioning
only marks a model parameter as observation, while the function effectively creates a sub-model with only the variables in the
Markov blanket of the variables that are being updated.
Precondition: `variables_to_update` must be a subset of the model parameters.
"""
function condition_on_complement(
model::BUGSModel, variables_to_update::Vector{<:VarName}
function create_sub_model(
model::BUGSModel,
model_parameters_in_submodel::Vector{<:VarName},
all_variables_in_submodel::Vector{<:VarName},
)
markov_blanket = markov_blanket(model.g, variables_to_update)
sub_model = create_sub_model(model, variables_to_update, markov_blanket)
random_variables_to_condition_on =
return condition(sub_model, setdiff(variables_to_update))
end

function create_sub_model(model::BUGSModel, model_parameters_in_submodel::Vector{<:VarName}, all_variables_in_submodel::Vector{<:VarName})
return BUGSModel(model, model_parameters_in_submodel, all_variables_in_submodel)
end

Expand Down Expand Up @@ -387,12 +373,12 @@ function AbstractPPL.evaluate!!(model::BUGSModel, ctx::SamplingContext)
args = prepare_arg_values(node_args, evaluation_env, loop_vars)
if !is_stochastic
value = node_function(; args...)
evaluation_env = setindex!!(evaluation_env, value, vn)
evaluation_env = setindex!!(evaluation_env, value, vn; prefer_mutation=false)
else
dist = node_function(; args...)
value = rand(ctx.rng, dist) # just sample from the prior
value = rand(ctx.rng, dist)
logp += logpdf(dist, value)
evaluation_env = setindex!!(evaluation_env, value, vn)
evaluation_env = setindex!!(evaluation_env, value, vn; prefer_mutation=false)
end
end
return evaluation_env, logp
Expand Down
10 changes: 7 additions & 3 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -553,10 +553,14 @@ function BangBang.NoBang._setindex(xs::AbstractArray, v::AbstractArray, I...)
return ys
end

function BangBang.setindex!!(nt::NamedTuple, val, vn::VarName{sym}) where {sym}
optic = BangBang.prefermutation(
function BangBang.setindex!!(
nt::NamedTuple, val, vn::VarName{sym}; prefer_mutation::Bool=true
) where {sym}
optic = if prefer_mutation
BangBang.prefermutation(AbstractPPL.getoptic(vn) Accessors.PropertyLens{sym}())
else
AbstractPPL.getoptic(vn) Accessors.PropertyLens{sym}()
)
end
return Accessors.set(nt, optic, val)
end

Expand Down
16 changes: 15 additions & 1 deletion test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ using JuliaBUGS: CompilerUtils
end

loop_var, lb, ub, body = JuliaBUGS.decompose_for_expr(ex)

@test loop_var == :i
@test lb == 1
@test ub == 3
Expand All @@ -20,3 +20,17 @@ using JuliaBUGS: CompilerUtils
end
end
end

@testset "BangBang.setindex!!" begin
nt = (a=1, b=[1, 2, 3], c=[1, 2, 3])
nt1 = BangBang.setindex!!(nt, 2, @varname(a))
@test nt1.a == 2

nt2 = BangBang.setindex!!(nt, 5, @varname(b[1]))
@test nt2.b == [5, 2, 3]
@test nt2.b === nt.b # mutation

nt3 = BangBang.setindex!!(nt, 2, @varname(c[1]); prefer_mutation=false)
@test nt3.c == [2, 2, 3]
@test nt3.c !== nt.c # no mutation
end

0 comments on commit 798bb21

Please sign in to comment.