From 243a699fa65d993cf4d32beca4de40d29e82eb56 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Thu, 14 Nov 2024 07:08:50 +0000 Subject: [PATCH] revert some changes --- src/JuliaBUGS.jl | 1 - src/gibbs.jl | 4 +++- src/model.jl | 21 ++++++++++++++------- 3 files changed, 17 insertions(+), 9 deletions(-) diff --git a/src/JuliaBUGS.jl b/src/JuliaBUGS.jl index 6987619af..d2add3de3 100644 --- a/src/JuliaBUGS.jl +++ b/src/JuliaBUGS.jl @@ -16,7 +16,6 @@ using StaticArrays import Base: ==, hash, Symbol, size import Distributions: truncated -import AbstractPPL: condition, decondition, evaluate!! export @bugs export compile, initialize! diff --git a/src/gibbs.jl b/src/gibbs.jl index 6c191635b..f61e8b319 100644 --- a/src/gibbs.jl +++ b/src/gibbs.jl @@ -50,7 +50,9 @@ function AbstractMCMC.step( param_values = state.values for vs in keys(state.conditioning_schedule) model = initialize!(model, param_values) - cond_model = condition(model, vs, model.evaluation_env) + cond_model = AbstractPPL.condition( + model, vs, model.evaluation_env, state.cached_eval_caches[vs] + ) param_values = gibbs_internal(rng, cond_model, state.conditioning_schedule[vs]) end return param_values, diff --git a/src/model.jl b/src/model.jl index 625882555..06f39eded 100644 --- a/src/model.jl +++ b/src/model.jl @@ -347,18 +347,25 @@ function settrans(model::BUGSModel, bool::Bool=!(model.transformed)) return BangBang.setproperty!!(model, :transformed, bool) end -function condition(model::BUGSModel, d::Dict{<:VarName,<:Any}) +function AbstractPPL.condition( + model::BUGSModel, + d::Dict{<:VarName,<:Any}, + sorted_nodes=Nothing, # support cached sorted Markov blanket nodes +) new_evaluation_env = deepcopy(model.evaluation_env) for (p, value) in d new_evaluation_env = setindex!!(new_evaluation_env, value, p) end - return condition(model, collect(keys(d)), new_evaluation_env; sorted_nodes) + return AbstractPPL.condition( + model, collect(keys(d)), new_evaluation_env; sorted_nodes=sorted_nodes + ) end -function condition( +function AbstractPPL.condition( model::BUGSModel, var_group::Vector{<:VarName}, evaluation_env::NamedTuple=model.evaluation_env, + sorted_nodes=Nothing, ) check_var_group(var_group, model) new_parameters = setdiff(model.parameters, var_group) @@ -390,7 +397,7 @@ function condition( return BangBang.setproperty!!(new_model, :g, g) end -function decondition(model::BUGSModel, var_group::Vector{<:VarName}) +function AbstractPPL.decondition(model::BUGSModel, var_group::Vector{<:VarName}) check_var_group(var_group, model) base_model = model.base_model isa Nothing ? model : model.base_model @@ -423,7 +430,7 @@ function check_var_group(var_group::Vector{<:VarName}, model::BUGSModel) ) end -function evaluate!!(rng::Random.AbstractRNG, model::BUGSModel) +function AbstractPPL.evaluate!!(rng::Random.AbstractRNG, model::BUGSModel) (; evaluation_env, g) = model vi = deepcopy(evaluation_env) logp = 0.0 @@ -444,7 +451,7 @@ function evaluate!!(rng::Random.AbstractRNG, model::BUGSModel) return evaluation_env, logp end -function evaluate!!(model::BUGSModel) +function AbstractPPL.evaluate!!(model::BUGSModel) logp = 0.0 evaluation_env = deepcopy(model.evaluation_env) for (i, vn) in enumerate(model.flattened_graph_node_data.sorted_nodes) @@ -473,7 +480,7 @@ function evaluate!!(model::BUGSModel) return evaluation_env, logp end -function evaluate!!(model::BUGSModel, flattened_values::AbstractVector) +function AbstractPPL.evaluate!!(model::BUGSModel, flattened_values::AbstractVector) var_lengths = if model.transformed model.transformed_var_lengths else