Skip to content

Commit

Permalink
use arrays to store the data in nodeinfo instead of getindex from met…
Browse files Browse the repository at this point in the history
…agraph
  • Loading branch information
sunxd3 committed Oct 27, 2024
1 parent f9c253b commit a1443e8
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 35 deletions.
2 changes: 1 addition & 1 deletion ext/JuliaBUGSMCMCChainsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ function JuliaBUGS.gen_chains(
g = model.g

generated_vars = find_generated_vars(g)
generated_vars = [v for v in model.sorted_nodes if v in generated_vars] # keep the order
generated_vars = [v for v in model.eval_cache.sorted_nodes if v in generated_vars] # keep the order

param_vals = []
generated_quantities = []
Expand Down
12 changes: 6 additions & 6 deletions src/gibbs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ abstract type AbstractGibbsState end
struct GibbsState{T,S,C} <: AbstractGibbsState
values::T
conditioning_schedule::S
sorted_nodes_cache::C
cached_eval_caches::C
end

ensure_vector(x) = x isa Union{Number,VarName} ? [x] : x
Expand All @@ -25,17 +25,17 @@ function AbstractMCMC.step(
model=l_model.logdensity,
kwargs...,
) where {N,S}
sorted_nodes_cache, conditioning_schedule = OrderedDict(), OrderedDict()
cached_eval_caches, conditioning_schedule = OrderedDict(), OrderedDict()
for variable_group in keys(sampler.sampler_map)
variable_to_condition_on = setdiff(model.parameters, ensure_vector(variable_group))
conditioning_schedule[variable_to_condition_on] = sampler.sampler_map[variable_group]
conditioned_model = AbstractPPL.condition(
model, variable_to_condition_on, model.evaluation_env
)
sorted_nodes_cache[variable_to_condition_on] = conditioned_model.sorted_nodes
cached_eval_caches[variable_to_condition_on] = conditioned_model.eval_cache
end
param_values = JuliaBUGS.getparams(model)
return param_values, GibbsState(param_values, conditioning_schedule, sorted_nodes_cache)
return param_values, GibbsState(param_values, conditioning_schedule, cached_eval_caches)
end

function AbstractMCMC.step(
Expand All @@ -50,12 +50,12 @@ function AbstractMCMC.step(
for vs in keys(state.conditioning_schedule)
model = initialize!(model, param_values)
cond_model = AbstractPPL.condition(
model, vs, model.evaluation_env, state.sorted_nodes_cache[vs]
model, vs, model.evaluation_env, state.cached_eval_caches[vs]
)
param_values = gibbs_internal(rng, cond_model, state.conditioning_schedule[vs])
end
return param_values,
GibbsState(param_values, state.conditioning_schedule, state.sorted_nodes_cache)
GibbsState(param_values, state.conditioning_schedule, state.cached_eval_caches)
end

function gibbs_internal end
Expand Down
105 changes: 79 additions & 26 deletions src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,44 @@
# instead of https://github.com/TuringLang/AbstractMCMC.jl/blob/d7c549fe41a80c1f164423c7ac458425535f624b/src/logdensityproblems.jl#L90
abstract type AbstractBUGSModel end

"""
EvalCache{TNF,TNA,TV}
Pre-compute the values of the nodes in the model to avoid lookups from MetaGraph.
"""
struct EvalCache{TNF,TNA,TV}
sorted_nodes::Vector{<:VarName}
is_stochastic_vals::Vector{Bool}
is_observed_vals::Vector{Bool}
node_function_vals::TNF
node_args_vals::TNA
loop_vars_vals::TV
end

function EvalCache(sorted_nodes::Vector{<:VarName}, g::BUGSGraph)
is_stochastic_vals = Array{Bool}(undef, length(sorted_nodes))
is_observed_vals = Array{Bool}(undef, length(sorted_nodes))
node_function_vals = []
node_args_vals = []
loop_vars_vals = []
for (i, vn) in enumerate(sorted_nodes)
(; is_stochastic, is_observed, node_function, node_args, loop_vars) = g[vn]
is_stochastic_vals[i] = is_stochastic
is_observed_vals[i] = is_observed
push!(node_function_vals, node_function)
push!(node_args_vals, Val(node_args))
push!(loop_vars_vals, loop_vars)
end
return EvalCache(
sorted_nodes,
is_stochastic_vals,
is_observed_vals,
node_function_vals,
node_args_vals,
loop_vars_vals,
)
end

"""
BUGSModel
Expand All @@ -27,8 +65,8 @@ struct BUGSModel{base_model_T<:Union{<:AbstractBUGSModel,Nothing},T<:NamedTuple}
evaluation_env::T
"A vector containing the names of the model parameters (unobserved stochastic variables)."
parameters::Vector{<:VarName}
"A vector containing the names of all the variables in the model, sorted in topological order."
sorted_nodes::Vector{<:VarName}
"An `EvalCache` object containing pre-computed values of the nodes in the model. For each topological order, this needs to be recomputed."
eval_cache::EvalCache

"An instance of `BUGSGraph`, representing the dependency graph of the model."
g::BUGSGraph
Expand Down Expand Up @@ -144,14 +182,15 @@ function BUGSModel(
transformed_var_lengths,
evaluation_env,
parameters,
sorted_nodes,
EvalCache(sorted_nodes, g),
g,
nothing,
)
end

function BUGSModel(
model::BUGSModel,
g::BUGSGraph,
parameters::Vector{<:VarName},
sorted_nodes::Vector{<:VarName},
evaluation_env::NamedTuple=model.evaluation_env,
Expand All @@ -164,8 +203,8 @@ function BUGSModel(
model.transformed_var_lengths,
evaluation_env,
parameters,
sorted_nodes,
model.g,
EvalCache(sorted_nodes, g),
g,
isnothing(model.base_model) ? model : model.base_model,
)
end
Expand All @@ -177,9 +216,13 @@ Initialize the model with a NamedTuple of initial values, the values are expecte
"""
function initialize!(model::BUGSModel, initial_params::NamedTuple)
check_input(initial_params)
for vn in model.sorted_nodes
(; is_stochastic, is_observed, node_function, node_args, loop_vars) = model.g[vn]
args = prepare_arg_values(Val(node_args), model.evaluation_env, loop_vars)
for (i, vn) in enumerate(model.eval_cache.sorted_nodes)
is_stochastic = model.eval_cache.is_stochastic_vals[i]
is_observed = model.eval_cache.is_observed_vals[i]
node_function = model.eval_cache.node_function_vals[i]
node_args = model.eval_cache.node_args_vals[i]
loop_vars = model.eval_cache.loop_vars_vals[i]
args = prepare_arg_values(node_args, model.evaluation_env, loop_vars)
if !is_stochastic
value = Base.invokelatest(node_function; args...)
BangBang.@set!! model.evaluation_env = setindex!!(
Expand Down Expand Up @@ -318,11 +361,11 @@ function AbstractPPL.condition(
new_parameters = setdiff(model.parameters, var_group)

sorted_blanket_with_vars = if sorted_nodes isa Nothing
sorted_nodes
model.eval_cache.sorted_nodes
else
filter(
vn -> vn in union(markov_blanket(model.g, new_parameters), new_parameters),
model.sorted_nodes,
model.eval_cache.sorted_nodes,
)
end

Expand All @@ -338,7 +381,7 @@ function AbstractPPL.condition(
end
end

new_model = BUGSModel(model, new_parameters, sorted_blanket_with_vars, evaluation_env)
new_model = BUGSModel(model, g, new_parameters, sorted_blanket_with_vars, evaluation_env)
return BangBang.setproperty!!(new_model, :g, g)
end

Expand All @@ -347,18 +390,19 @@ function AbstractPPL.decondition(model::BUGSModel, var_group::Vector{<:VarName})
base_model = model.base_model isa Nothing ? model : model.base_model

new_parameters = [
v for v in base_model.sorted_nodes if v in union(model.parameters, var_group)
v for
v in base_model.eval_cache.sorted_nodes if v in union(model.parameters, var_group)
] # keep the order

markov_blanket_with_vars = union(
markov_blanket(base_model.g, new_parameters), new_parameters
)
sorted_blanket_with_vars = filter(
vn -> vn in markov_blanket_with_vars, base_model.sorted_nodes
vn -> vn in markov_blanket_with_vars, base_model.eval_cache.sorted_nodes
)

new_model = BUGSModel(
model, new_parameters, sorted_blanket_with_vars, base_model.evaluation_env
model, model.g, new_parameters, sorted_blanket_with_vars, base_model.evaluation_env
)
evaluate_env, _ = evaluate!!(new_model)
return BangBang.setproperty!!(new_model, :evaluation_env, evaluate_env)
Expand All @@ -374,12 +418,15 @@ function check_var_group(var_group::Vector{<:VarName}, model::BUGSModel)
end

function AbstractPPL.evaluate!!(rng::Random.AbstractRNG, model::BUGSModel)
(; evaluation_env, g, sorted_nodes) = model
(; evaluation_env, g) = model
vi = deepcopy(evaluation_env)
logp = 0.0
for vn in sorted_nodes
(; is_stochastic, node_function, node_args, loop_vars) = g[vn]
args = prepare_arg_values(Val(node_args), evaluation_env, loop_vars)
for (i, vn) in enumerate(model.eval_cache.sorted_nodes)
is_stochastic = model.eval_cache.is_stochastic_vals[i]
node_function = model.eval_cache.node_function_vals[i]
node_args = model.eval_cache.node_args_vals[i]
loop_vars = model.eval_cache.loop_vars_vals[i]
args = prepare_arg_values(node_args, evaluation_env, loop_vars)
if !is_stochastic
value = node_function(; args...)
evaluation_env = setindex!!(evaluation_env, value, vn)
Expand All @@ -394,11 +441,14 @@ function AbstractPPL.evaluate!!(rng::Random.AbstractRNG, model::BUGSModel)
end

function AbstractPPL.evaluate!!(model::BUGSModel)
(; sorted_nodes, g, evaluation_env) = model
logp = 0.0
for vn in sorted_nodes
(; is_stochastic, node_function, node_args, loop_vars) = g[vn]
args = prepare_arg_values(Val(node_args), evaluation_env, loop_vars)
evaluation_env = deepcopy(model.evaluation_env)
for (i, vn) in enumerate(model.eval_cache.sorted_nodes)
is_stochastic = model.eval_cache.is_stochastic_vals[i]
node_function = model.eval_cache.node_function_vals[i]
node_args = model.eval_cache.node_args_vals[i]
loop_vars = model.eval_cache.loop_vars_vals[i]
args = prepare_arg_values(node_args, evaluation_env, loop_vars)
if !is_stochastic
value = node_function(; args...)
evaluation_env = setindex!!(evaluation_env, value, vn)
Expand Down Expand Up @@ -428,13 +478,16 @@ function AbstractPPL.evaluate!!(model::BUGSModel, flattened_values::AbstractVect
model.untransformed_var_lengths
end

g = model.g
evaluation_env = deepcopy(model.evaluation_env)
current_idx = 1
logp = 0.0
for vn in model.sorted_nodes
(; is_stochastic, is_observed, node_function, node_args, loop_vars) = g[vn]
args = prepare_arg_values(Val(node_args), evaluation_env, loop_vars)
for (i, vn) in enumerate(model.eval_cache.sorted_nodes)
is_stochastic = model.eval_cache.is_stochastic_vals[i]
is_observed = model.eval_cache.is_observed_vals[i]
node_function = model.eval_cache.node_function_vals[i]
node_args = model.eval_cache.node_args_vals[i]
loop_vars = model.eval_cache.loop_vars_vals[i]
args = prepare_arg_values(node_args, evaluation_env, loop_vars)
if !is_stochastic
value = node_function(; args...)
evaluation_env = BangBang.setindex!!(evaluation_env, value, vn)
Expand Down
4 changes: 2 additions & 2 deletions test/graphs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,11 @@ c = @varname c
cond_model = AbstractPPL.condition(model, setdiff(model.parameters, [c]))
# tests for MarkovBlanketBUGSModel constructor
@test cond_model.parameters == [c]
@test Set(Symbol.(cond_model.sorted_nodes)) == Set([:l, :a, :b, :f, :c])
@test Set(Symbol.(cond_model.eval_cache.sorted_nodes)) == Set([:l, :a, :b, :f, :c])

decond_model = AbstractPPL.decondition(cond_model, [a, l])
@test Set(Symbol.(decond_model.parameters)) == Set([:a, :c, :l])
@test Set(Symbol.(decond_model.sorted_nodes)) ==
@test Set(Symbol.(decond_model.eval_cache.sorted_nodes)) ==
Set([:l, :b, :f, :a, :d, :e, :c, :h, :g, :i])

c_value = 4.0
Expand Down

0 comments on commit a1443e8

Please sign in to comment.