Skip to content

Commit

Permalink
finish renaming
Browse files Browse the repository at this point in the history
  • Loading branch information
sunxd3 committed Nov 7, 2024
1 parent e4e7583 commit 9517682
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 4 deletions.
2 changes: 1 addition & 1 deletion ext/JuliaBUGSMCMCChainsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ function JuliaBUGS.gen_chains(
g = model.g

generated_vars = find_generated_vars(g)
generated_vars = [v for v in model.eval_cache.sorted_nodes if v in generated_vars] # keep the order
generated_vars = [v for v in model.flattened_graph_node_data.sorted_nodes if v in generated_vars] # keep the order

param_vals = []
generated_quantities = []
Expand Down
3 changes: 2 additions & 1 deletion src/gibbs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ function AbstractMCMC.step(
conditioned_model = AbstractPPL.condition(
model, variable_to_condition_on, model.evaluation_env
)
cached_eval_caches[variable_to_condition_on] = conditioned_model.eval_cache
cached_eval_caches[variable_to_condition_on] =
conditioned_model.flattened_graph_node_data
end
param_values = JuliaBUGS.getparams(model)
return param_values, GibbsState(param_values, conditioning_schedule, cached_eval_caches)
Expand Down
5 changes: 3 additions & 2 deletions test/graphs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,12 @@ c = @varname c
cond_model = AbstractPPL.condition(model, setdiff(model.parameters, [c]))
# tests for MarkovBlanketBUGSModel constructor
@test cond_model.parameters == [c]
@test Set(Symbol.(cond_model.eval_cache.sorted_nodes)) == Set([:l, :a, :b, :f, :c])
@test Set(Symbol.(cond_model.flattened_graph_node_data.sorted_nodes)) ==
Set([:l, :a, :b, :f, :c])

decond_model = AbstractPPL.decondition(cond_model, [a, l])
@test Set(Symbol.(decond_model.parameters)) == Set([:a, :c, :l])
@test Set(Symbol.(decond_model.eval_cache.sorted_nodes)) ==
@test Set(Symbol.(decond_model.flattened_graph_node_data.sorted_nodes)) ==
Set([:l, :b, :f, :a, :d, :e, :c, :h, :g, :i])

c_value = 4.0
Expand Down

0 comments on commit 9517682

Please sign in to comment.