Skip to content

Commit

Permalink
improve the show function of BUGSModel
Browse files Browse the repository at this point in the history
  • Loading branch information
sunxd3 committed Nov 14, 2024
1 parent f3230cc commit 3effb12
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 27 deletions.
2 changes: 1 addition & 1 deletion src/JuliaBUGS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ using StaticArrays

import Base: ==, hash, Symbol, size
import Distributions: truncated
import AbstractPPL: AbstractContext, evaluate!!
import AbstractPPL: condition, decondition, evaluate!!

export @bugs
export compile, initialize!
Expand Down
74 changes: 48 additions & 26 deletions src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,23 +77,47 @@ struct BUGSModel{base_model_T<:Union{<:AbstractBUGSModel,Nothing},T<:NamedTuple,
end

function Base.show(io::IO, model::BUGSModel)
if model.transformed
println(
io,
"BUGSModel (transformed, with dimension $(model.transformed_param_length)):",
"\n",
)
# Print model type and dimension
space_type =
model.transformed ? "transformed (unconstrained)" : "original (constrained)"
dim = if model.transformed
model.transformed_param_length
else
println(
io,
"BUGSModel (untransformed, with dimension $(model.untransformed_param_length)):",
"\n",
)
model.untransformed_param_length
end
printstyled(io, "BUGSModel"; bold=true, color=:blue)
println(io, " (parameters are in ", space_type, " space, with dimension ", dim, "):\n")

# Group and print parameters
printstyled(io, " Model parameters:\n"; bold=true, color=:yellow)
grouped_params = Dict{Symbol,Vector{VarName}}()
for param in model.parameters
sym = AbstractPPL.getsym(param)
push!(get!(grouped_params, sym, VarName[]), param)
end
for (sym, params) in grouped_params
param_str = length(params) == 1 ? string(params[1]) : "$(join(params, ", "))"
print(io, " ")
printstyled(io, param_str; color=:cyan)
println(io)
end
println(io, " Model parameters:")
println(io, " ", join(model.parameters, ", "), "\n")
println(io, " Variable values:")
return println(io, "$(model.evaluation_env)")
println(io)

# Print variable info
printstyled(io, " Variable sizes and types:\n"; bold=true, color=:yellow)
for (name, value) in pairs(model.evaluation_env)
type_str = if isa(value, Number)
"type = $(typeof(value))"
else
"size = $(size(value)), type = $(typeof(value))"
end
print(io, " ")
printstyled(io, name; color=:cyan)
print(io, ": ")
printstyled(io, type_str; color=:green)
println(io)
end
return nothing
end

"""
Expand Down Expand Up @@ -323,24 +347,22 @@ function settrans(model::BUGSModel, bool::Bool=!(model.transformed))
return BangBang.setproperty!!(model, :transformed, bool)
end

function AbstractPPL.condition(
function condition(
model::BUGSModel,
d::Dict{<:VarName,<:Any},
d::Dict{<:VarName,<:Any};
sorted_nodes=Nothing, # support cached sorted Markov blanket nodes
)
new_evaluation_env = deepcopy(model.evaluation_env)
for (p, value) in d
new_evaluation_env = setindex!!(new_evaluation_env, value, p)
end
return AbstractPPL.condition(
model, collect(keys(d)), new_evaluation_env; sorted_nodes=sorted_nodes
)
return condition(model, collect(keys(d)), new_evaluation_env; sorted_nodes)
end

function AbstractPPL.condition(
function condition(
model::BUGSModel,
var_group::Vector{<:VarName},
evaluation_env::NamedTuple=model.evaluation_env,
evaluation_env::NamedTuple=model.evaluation_env;
sorted_nodes=Nothing,
)
check_var_group(var_group, model)
Expand Down Expand Up @@ -373,7 +395,7 @@ function AbstractPPL.condition(
return BangBang.setproperty!!(new_model, :g, g)
end

function AbstractPPL.decondition(model::BUGSModel, var_group::Vector{<:VarName})
function decondition(model::BUGSModel, var_group::Vector{<:VarName})
check_var_group(var_group, model)
base_model = model.base_model isa Nothing ? model : model.base_model

Expand Down Expand Up @@ -406,7 +428,7 @@ function check_var_group(var_group::Vector{<:VarName}, model::BUGSModel)
)
end

function AbstractPPL.evaluate!!(rng::Random.AbstractRNG, model::BUGSModel)
function evaluate!!(rng::Random.AbstractRNG, model::BUGSModel)
(; evaluation_env, g) = model
vi = deepcopy(evaluation_env)
logp = 0.0
Expand All @@ -427,7 +449,7 @@ function AbstractPPL.evaluate!!(rng::Random.AbstractRNG, model::BUGSModel)
return evaluation_env, logp
end

function AbstractPPL.evaluate!!(model::BUGSModel)
function evaluate!!(model::BUGSModel)
logp = 0.0
evaluation_env = deepcopy(model.evaluation_env)
for (i, vn) in enumerate(model.flattened_graph_node_data.sorted_nodes)
Expand Down Expand Up @@ -456,7 +478,7 @@ function AbstractPPL.evaluate!!(model::BUGSModel)
return evaluation_env, logp
end

function AbstractPPL.evaluate!!(model::BUGSModel, flattened_values::AbstractVector)
function evaluate!!(model::BUGSModel, flattened_values::AbstractVector)
var_lengths = if model.transformed
model.transformed_var_lengths
else
Expand Down

0 comments on commit 3effb12

Please sign in to comment.