Skip to content

Commit

Permalink
revert some changes
Browse files Browse the repository at this point in the history
  • Loading branch information
sunxd3 committed Nov 14, 2024
1 parent c6db48a commit 243a699
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 9 deletions.
1 change: 0 additions & 1 deletion src/JuliaBUGS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ using StaticArrays

import Base: ==, hash, Symbol, size
import Distributions: truncated
import AbstractPPL: condition, decondition, evaluate!!

export @bugs
export compile, initialize!
Expand Down
4 changes: 3 additions & 1 deletion src/gibbs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,9 @@ function AbstractMCMC.step(
param_values = state.values
for vs in keys(state.conditioning_schedule)
model = initialize!(model, param_values)
cond_model = condition(model, vs, model.evaluation_env)
cond_model = AbstractPPL.condition(
model, vs, model.evaluation_env, state.cached_eval_caches[vs]
)
param_values = gibbs_internal(rng, cond_model, state.conditioning_schedule[vs])
end
return param_values,
Expand Down
21 changes: 14 additions & 7 deletions src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -347,18 +347,25 @@ function settrans(model::BUGSModel, bool::Bool=!(model.transformed))
return BangBang.setproperty!!(model, :transformed, bool)
end

function condition(model::BUGSModel, d::Dict{<:VarName,<:Any})
function AbstractPPL.condition(
model::BUGSModel,
d::Dict{<:VarName,<:Any},
sorted_nodes=Nothing, # support cached sorted Markov blanket nodes
)
new_evaluation_env = deepcopy(model.evaluation_env)
for (p, value) in d
new_evaluation_env = setindex!!(new_evaluation_env, value, p)
end
return condition(model, collect(keys(d)), new_evaluation_env; sorted_nodes)
return AbstractPPL.condition(
model, collect(keys(d)), new_evaluation_env; sorted_nodes=sorted_nodes
)
end

function condition(
function AbstractPPL.condition(
model::BUGSModel,
var_group::Vector{<:VarName},
evaluation_env::NamedTuple=model.evaluation_env,
sorted_nodes=Nothing,
)
check_var_group(var_group, model)
new_parameters = setdiff(model.parameters, var_group)
Expand Down Expand Up @@ -390,7 +397,7 @@ function condition(
return BangBang.setproperty!!(new_model, :g, g)
end

function decondition(model::BUGSModel, var_group::Vector{<:VarName})
function AbstractPPL.decondition(model::BUGSModel, var_group::Vector{<:VarName})
check_var_group(var_group, model)
base_model = model.base_model isa Nothing ? model : model.base_model

Expand Down Expand Up @@ -423,7 +430,7 @@ function check_var_group(var_group::Vector{<:VarName}, model::BUGSModel)
)
end

function evaluate!!(rng::Random.AbstractRNG, model::BUGSModel)
function AbstractPPL.evaluate!!(rng::Random.AbstractRNG, model::BUGSModel)
(; evaluation_env, g) = model
vi = deepcopy(evaluation_env)
logp = 0.0
Expand All @@ -444,7 +451,7 @@ function evaluate!!(rng::Random.AbstractRNG, model::BUGSModel)
return evaluation_env, logp
end

function evaluate!!(model::BUGSModel)
function AbstractPPL.evaluate!!(model::BUGSModel)
logp = 0.0
evaluation_env = deepcopy(model.evaluation_env)
for (i, vn) in enumerate(model.flattened_graph_node_data.sorted_nodes)
Expand Down Expand Up @@ -473,7 +480,7 @@ function evaluate!!(model::BUGSModel)
return evaluation_env, logp
end

function evaluate!!(model::BUGSModel, flattened_values::AbstractVector)
function AbstractPPL.evaluate!!(model::BUGSModel, flattened_values::AbstractVector)
var_lengths = if model.transformed
model.transformed_var_lengths
else
Expand Down

0 comments on commit 243a699

Please sign in to comment.