From 3effb12fecbf3453a188779112e1132cd418ab1c Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Thu, 14 Nov 2024 04:40:43 +0000 Subject: [PATCH] improve the show function of BUGSModel --- src/JuliaBUGS.jl | 2 +- src/model.jl | 74 +++++++++++++++++++++++++++++++----------------- 2 files changed, 49 insertions(+), 27 deletions(-) diff --git a/src/JuliaBUGS.jl b/src/JuliaBUGS.jl index db888c352..6987619af 100644 --- a/src/JuliaBUGS.jl +++ b/src/JuliaBUGS.jl @@ -16,7 +16,7 @@ using StaticArrays import Base: ==, hash, Symbol, size import Distributions: truncated -import AbstractPPL: AbstractContext, evaluate!! +import AbstractPPL: condition, decondition, evaluate!! export @bugs export compile, initialize! diff --git a/src/model.jl b/src/model.jl index 4c178d785..4a3f5e98b 100644 --- a/src/model.jl +++ b/src/model.jl @@ -77,23 +77,47 @@ struct BUGSModel{base_model_T<:Union{<:AbstractBUGSModel,Nothing},T<:NamedTuple, end function Base.show(io::IO, model::BUGSModel) - if model.transformed - println( - io, - "BUGSModel (transformed, with dimension $(model.transformed_param_length)):", - "\n", - ) + # Print model type and dimension + space_type = + model.transformed ? "transformed (unconstrained)" : "original (constrained)" + dim = if model.transformed + model.transformed_param_length else - println( - io, - "BUGSModel (untransformed, with dimension $(model.untransformed_param_length)):", - "\n", - ) + model.untransformed_param_length + end + printstyled(io, "BUGSModel"; bold=true, color=:blue) + println(io, " (parameters are in ", space_type, " space, with dimension ", dim, "):\n") + + # Group and print parameters + printstyled(io, " Model parameters:\n"; bold=true, color=:yellow) + grouped_params = Dict{Symbol,Vector{VarName}}() + for param in model.parameters + sym = AbstractPPL.getsym(param) + push!(get!(grouped_params, sym, VarName[]), param) + end + for (sym, params) in grouped_params + param_str = length(params) == 1 ? string(params[1]) : "$(join(params, ", "))" + print(io, " ") + printstyled(io, param_str; color=:cyan) + println(io) end - println(io, " Model parameters:") - println(io, " ", join(model.parameters, ", "), "\n") - println(io, " Variable values:") - return println(io, "$(model.evaluation_env)") + println(io) + + # Print variable info + printstyled(io, " Variable sizes and types:\n"; bold=true, color=:yellow) + for (name, value) in pairs(model.evaluation_env) + type_str = if isa(value, Number) + "type = $(typeof(value))" + else + "size = $(size(value)), type = $(typeof(value))" + end + print(io, " ") + printstyled(io, name; color=:cyan) + print(io, ": ") + printstyled(io, type_str; color=:green) + println(io) + end + return nothing end """ @@ -323,24 +347,22 @@ function settrans(model::BUGSModel, bool::Bool=!(model.transformed)) return BangBang.setproperty!!(model, :transformed, bool) end -function AbstractPPL.condition( +function condition( model::BUGSModel, - d::Dict{<:VarName,<:Any}, + d::Dict{<:VarName,<:Any}; sorted_nodes=Nothing, # support cached sorted Markov blanket nodes ) new_evaluation_env = deepcopy(model.evaluation_env) for (p, value) in d new_evaluation_env = setindex!!(new_evaluation_env, value, p) end - return AbstractPPL.condition( - model, collect(keys(d)), new_evaluation_env; sorted_nodes=sorted_nodes - ) + return condition(model, collect(keys(d)), new_evaluation_env; sorted_nodes) end -function AbstractPPL.condition( +function condition( model::BUGSModel, var_group::Vector{<:VarName}, - evaluation_env::NamedTuple=model.evaluation_env, + evaluation_env::NamedTuple=model.evaluation_env; sorted_nodes=Nothing, ) check_var_group(var_group, model) @@ -373,7 +395,7 @@ function AbstractPPL.condition( return BangBang.setproperty!!(new_model, :g, g) end -function AbstractPPL.decondition(model::BUGSModel, var_group::Vector{<:VarName}) +function decondition(model::BUGSModel, var_group::Vector{<:VarName}) check_var_group(var_group, model) base_model = model.base_model isa Nothing ? model : model.base_model @@ -406,7 +428,7 @@ function check_var_group(var_group::Vector{<:VarName}, model::BUGSModel) ) end -function AbstractPPL.evaluate!!(rng::Random.AbstractRNG, model::BUGSModel) +function evaluate!!(rng::Random.AbstractRNG, model::BUGSModel) (; evaluation_env, g) = model vi = deepcopy(evaluation_env) logp = 0.0 @@ -427,7 +449,7 @@ function AbstractPPL.evaluate!!(rng::Random.AbstractRNG, model::BUGSModel) return evaluation_env, logp end -function AbstractPPL.evaluate!!(model::BUGSModel) +function evaluate!!(model::BUGSModel) logp = 0.0 evaluation_env = deepcopy(model.evaluation_env) for (i, vn) in enumerate(model.flattened_graph_node_data.sorted_nodes) @@ -456,7 +478,7 @@ function AbstractPPL.evaluate!!(model::BUGSModel) return evaluation_env, logp end -function AbstractPPL.evaluate!!(model::BUGSModel, flattened_values::AbstractVector) +function evaluate!!(model::BUGSModel, flattened_values::AbstractVector) var_lengths = if model.transformed model.transformed_var_lengths else