Skip to content

Commit

Permalink
Improve show function of BUGSModel (#236)
Browse files Browse the repository at this point in the history
  • Loading branch information
sunxd3 authored Nov 14, 2024
1 parent f3230cc commit 41620db
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 17 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "JuliaBUGS"
uuid = "ba9fb4c0-828e-4473-b6a1-cd2560fee5bf"
version = "0.6.4"
version = "0.6.5"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down
1 change: 0 additions & 1 deletion src/JuliaBUGS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ using StaticArrays

import Base: ==, hash, Symbol, size
import Distributions: truncated
import AbstractPPL: AbstractContext, evaluate!!

export @bugs
export compile, initialize!
Expand Down
54 changes: 39 additions & 15 deletions src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,23 +77,47 @@ struct BUGSModel{base_model_T<:Union{<:AbstractBUGSModel,Nothing},T<:NamedTuple,
end

function Base.show(io::IO, model::BUGSModel)
if model.transformed
println(
io,
"BUGSModel (transformed, with dimension $(model.transformed_param_length)):",
"\n",
)
# Print model type and dimension
space_type =
model.transformed ? "transformed (unconstrained)" : "original (constrained)"
dim = if model.transformed
model.transformed_param_length
else
println(
io,
"BUGSModel (untransformed, with dimension $(model.untransformed_param_length)):",
"\n",
)
model.untransformed_param_length
end
printstyled(io, "BUGSModel"; bold=true, color=:blue)
println(io, " (parameters are in ", space_type, " space, with dimension ", dim, "):\n")

# Group and print parameters
printstyled(io, " Model parameters:\n"; bold=true, color=:yellow)
grouped_params = Dict{Symbol,Vector{VarName}}()
for param in model.parameters
sym = AbstractPPL.getsym(param)
push!(get!(grouped_params, sym, VarName[]), param)
end
for (sym, params) in grouped_params
param_str = length(params) == 1 ? string(params[1]) : "$(join(params, ", "))"
print(io, " ")
printstyled(io, param_str; color=:cyan)
println(io)
end
println(io)

# Print variable info
printstyled(io, " Variable sizes and types:\n"; bold=true, color=:yellow)
for (name, value) in pairs(model.evaluation_env)
type_str = if isa(value, Number)
"type = $(typeof(value))"
else
"size = $(size(value)), type = $(typeof(value))"
end
print(io, " ")
printstyled(io, name; color=:cyan)
print(io, ": ")
printstyled(io, type_str; color=:green)
println(io)
end
println(io, " Model parameters:")
println(io, " ", join(model.parameters, ", "), "\n")
println(io, " Variable values:")
return println(io, "$(model.evaluation_env)")
return nothing
end

"""
Expand Down

0 comments on commit 41620db

Please sign in to comment.