diff --git a/ext/JuliaBUGSMCMCChainsExt.jl b/ext/JuliaBUGSMCMCChainsExt.jl
index 8b1a1a12c..b17d575ba 100644
--- a/ext/JuliaBUGSMCMCChainsExt.jl
+++ b/ext/JuliaBUGSMCMCChainsExt.jl
@@ -84,7 +84,9 @@ function JuliaBUGS.gen_chains(
     g = model.g
     generated_vars = find_generated_vars(g)
-    generated_vars = [v for v in model.eval_cache.sorted_nodes if v in generated_vars] # keep the order
+    generated_vars = [
+        v for v in model.flattened_graph_node_data.sorted_nodes if v in generated_vars
+    ] # keep the order

     param_vals = []
     generated_quantities = []
diff --git a/src/compiler_pass.jl b/src/compiler_pass.jl
index 661d8f027..d46f54920 100644
--- a/src/compiler_pass.jl
+++ b/src/compiler_pass.jl
@@ -691,12 +691,12 @@ function build_node_functions(
     )
     for statement in expr.args
         if is_deterministic(statement) || is_stochastic(statement)
-            rhs = if is_deterministic(statement)
-                statement.args[2]
+            lhs, rhs = if is_deterministic(statement)
+                statement.args[1], statement.args[2]
             else
-                statement.args[3]
+                statement.args[2], statement.args[3]
             end
-            args, node_func_expr = make_function_expr(rhs, eval_env)
+            args, node_func_expr = make_function_expr(lhs, rhs, eval_env)
             node_func = eval(node_func_expr)
             f_dict[statement] = (args, node_func_expr, node_func)
         elseif Meta.isexpr(statement, :for)
@@ -709,25 +709,55 @@ function build_node_functions(
     return f_dict
 end

-function make_function_expr(expr, env::NamedTuple{vars}) where {vars}
-    args = Tuple(keys(extract_variable_names_and_numdims(expr, ())))
-    arg_exprs = Expr[]
-    for v in args
-        if v ∈ vars
-            value = env[v]
-            if value isa Int || value isa Float64 || value isa Missing
-                push!(arg_exprs, Expr(:(::), v, :Real))
-            elseif value isa AbstractArray
-                push!(arg_exprs, Expr(:(::), v, :(Array{<:Real})))
-            else
-                error("Unexpected argument type: $(typeof(value))")
-            end
-        else # loop variable
-            push!(arg_exprs, Expr(:(::), v, :Int))
-        end
-    end
+"""
+    make_function_expr(lhs, rhs, env::NamedTuple{vars}; use_lhs_as_func_name=false)

-    expr = MacroTools.postwalk(expr) do sub_expr
+Generate a function expression for the given right-hand side expression `rhs`. The generated function takes two
+arguments, `evaluation_env` and `loop_vars`, both `NamedTuple`s supplying the values of the variables used in `rhs`.
+
+# Examples
+```jldoctest; setup = :(using JuliaBUGS: make_function_expr)
+julia> make_function_expr(:(x[a, b]), :(x[a, b] + 1), (x = [1 2 3; 4 5 6], a = missing, b = missing))
+((:a, :b, :x), :(function (evaluation_env, loop_vars)
+          (; a, b, x) = evaluation_env
+          (;) = loop_vars
+          return x[Int(a), Int(b)] + 1
+      end))
+
+julia> make_function_expr(:(x[a, b]), :(x[a, b] + 1), (;x = [1 2 3; 4 5 6]))
+((:a, :b, :x), :(function (evaluation_env, loop_vars)
+          (; x) = evaluation_env
+          (; a, b) = loop_vars
+          return x[Int(a), Int(b)] + 1
+      end))
+```
+"""
+function make_function_expr(
+    lhs, rhs, env::NamedTuple{vars}; use_lhs_as_func_name=false
+) where {vars}
+    args = Tuple(keys(extract_variable_names_and_numdims(rhs, ())))
+    loop_vars = Tuple([v for v in args if v ∉ vars])
+    variables = setdiff(args, loop_vars)
+    # arg_exprs = Expr[]
+    # for v in args
+    #     if v ∈ vars
+    #         value = env[v]
+    #         if value isa Int || value isa Float64 || value isa Missing
+    #             push!(arg_exprs, Expr(:(::), v, :Real))
+    #         elseif value isa AbstractArray
+    #             push!(arg_exprs, Expr(:(::), v, :(Array{<:Real})))
+    #         else
+    #             error("Unexpected argument type: $(typeof(value))")
+    #         end
+    #     else # loop variable
+    #         push!(arg_exprs, Expr(:(::), v, :Int))
+    #     end
+    # end
+
+    unpacking_expr = :((; $(variables...),) = evaluation_env)
+    unpacking_loop_vars_expr = :((; $(loop_vars...),) = loop_vars)
+
+    func_body = MacroTools.postwalk(rhs) do sub_expr
         if @capture(sub_expr, v_[indices__])
             new_indices = Any[]
             for i in eachindex(indices)
@@ -744,8 +774,26 @@ function make_function_expr(expr, env::NamedTuple{vars}) where {vars}
             return sub_expr
         end

-    return args, MacroTools.@q function (; $(arg_exprs...))
-        return $(expr)
+    # if use_lhs_as_func_name
+    #     func_name = if lhs isa Symbol
+    #         lhs
+    #     else
+    #         Symbol("__", String(lhs.args[1]), "_", join(lhs.args[2:end], "_"), "__")
+    #     end
+
+    #     return args, MacroTools.@q function $func_name($(arg_exprs...))
+    #         return $(func_body)
+    #     end
+    # else
+    #     return args, MacroTools.@q function ($(arg_exprs...))
+    #         return $(func_body)
+    #     end
+    # end
+
+    return args, MacroTools.@q function (evaluation_env, loop_vars)
+        $(unpacking_expr)
+        $(unpacking_loop_vars_expr)
+        return $(func_body)
     end
 end
diff --git a/src/gibbs.jl b/src/gibbs.jl
index ab447d2d7..f61e8b319 100644
--- a/src/gibbs.jl
+++ b/src/gibbs.jl
@@ -32,7 +32,8 @@ function AbstractMCMC.step(
         conditioned_model = AbstractPPL.condition(
             model, variable_to_condition_on, model.evaluation_env
         )
-        cached_eval_caches[variable_to_condition_on] = conditioned_model.eval_cache
+        cached_eval_caches[variable_to_condition_on] =
+            conditioned_model.flattened_graph_node_data
     end
     param_values = JuliaBUGS.getparams(model)
     return param_values, GibbsState(param_values, conditioning_schedule, cached_eval_caches)
diff --git a/src/model.jl b/src/model.jl
index b3400f4bc..4c178d785 100644
--- a/src/model.jl
+++ b/src/model.jl
@@ -4,40 +4,41 @@ abstract type AbstractBUGSModel end

 """
-    EvalCache{TNF,TNA,TV}
+    FlattenedGraphNodeData{TNF,TV}

 Pre-compute the values of the nodes in the model to avoid lookups from MetaGraph.
""" -struct EvalCache{TNF,TNA,TV} +struct FlattenedGraphNodeData{TNF,TV} sorted_nodes::Vector{<:VarName} is_stochastic_vals::Vector{Bool} is_observed_vals::Vector{Bool} node_function_vals::TNF - node_args_vals::TNA loop_vars_vals::TV end -function EvalCache(sorted_nodes::Vector{<:VarName}, g::BUGSGraph) +function FlattenedGraphNodeData( + g::BUGSGraph, + sorted_nodes::Vector{<:VarName}=VarName[ + label_for(g, node) for node in topological_sort(g) + ], +) is_stochastic_vals = Array{Bool}(undef, length(sorted_nodes)) is_observed_vals = Array{Bool}(undef, length(sorted_nodes)) - node_function_vals = [] - node_args_vals = [] - loop_vars_vals = [] + node_function_vals = Array{Any}(undef, length(sorted_nodes)) + loop_vars_vals = Array{Any}(undef, length(sorted_nodes)) for (i, vn) in enumerate(sorted_nodes) - (; is_stochastic, is_observed, node_function, node_args, loop_vars) = g[vn] + (; is_stochastic, is_observed, node_function, loop_vars) = g[vn] is_stochastic_vals[i] = is_stochastic is_observed_vals[i] = is_observed - push!(node_function_vals, node_function) - push!(node_args_vals, Val(node_args)) - push!(loop_vars_vals, loop_vars) + node_function_vals[i] = node_function + loop_vars_vals[i] = loop_vars end - return EvalCache( + return FlattenedGraphNodeData( sorted_nodes, is_stochastic_vals, is_observed_vals, - node_function_vals, - node_args_vals, - loop_vars_vals, + map(identity, node_function_vals), + map(identity, loop_vars_vals), ) end @@ -47,9 +48,8 @@ end The `BUGSModel` object is used for inference and represents the output of compilation. It implements the [`LogDensityProblems.jl`](https://github.com/tpapp/LogDensityProblems.jl) interface. """ -struct BUGSModel{ - base_model_T<:Union{<:AbstractBUGSModel,Nothing},T<:NamedTuple,TNF,TNA,TV -} <: AbstractBUGSModel +struct BUGSModel{base_model_T<:Union{<:AbstractBUGSModel,Nothing},T<:NamedTuple,TNF,TV} <: + AbstractBUGSModel " Indicates whether the model parameters are in the transformed space. " transformed::Bool @@ -66,8 +66,8 @@ struct BUGSModel{ evaluation_env::T "A vector containing the names of the model parameters (unobserved stochastic variables)." parameters::Vector{<:VarName} - "An `EvalCache` object containing pre-computed values of the nodes in the model. For each topological order, this needs to be recomputed." - eval_cache::EvalCache{TNF,TNA,TV} + "An `FlattenedGraphNodeData` object containing pre-computed values of the nodes in the model. For each topological order, this needs to be recomputed." + flattened_graph_node_data::FlattenedGraphNodeData{TNF,TV} "An instance of `BUGSGraph`, representing the dependency graph of the model." 
g::BUGSGraph @@ -76,53 +76,39 @@ struct BUGSModel{ base_model::base_model_T end -function Base.show(io::IO, m::BUGSModel) - if m.transformed +function Base.show(io::IO, model::BUGSModel) + if model.transformed println( io, - "BUGSModel (transformed, with dimension $(m.transformed_param_length)):", + "BUGSModel (transformed, with dimension $(model.transformed_param_length)):", "\n", ) else println( io, - "BUGSModel (untransformed, with dimension $(m.untransformed_param_length)):", + "BUGSModel (untransformed, with dimension $(model.untransformed_param_length)):", "\n", ) end println(io, " Model parameters:") - println(io, " ", join(m.parameters, ", "), "\n") + println(io, " ", join(model.parameters, ", "), "\n") println(io, " Variable values:") - return println(io, "$(m.evaluation_env)") + return println(io, "$(model.evaluation_env)") end """ - parameters(m::BUGSModel) + parameters(model::BUGSModel) Return a vector of `VarName` containing the names of the model parameters (unobserved stochastic variables). """ -parameters(m::BUGSModel) = m.parameters +parameters(model::BUGSModel) = model.parameters """ - variables(m::BUGSModel) + variables(model::BUGSModel) Return a vector of `VarName` containing the names of all the variables in the model. """ -variables(m::BUGSModel) = collect(labels(m.g)) - -@generated function prepare_arg_values( - ::Val{args}, evaluation_env::NamedTuple, loop_vars::NamedTuple{lvars} -) where {args,lvars} - fields = [] - for arg in args - if arg in lvars - push!(fields, :(loop_vars[$(QuoteNode(arg))])) - else - push!(fields, :(evaluation_env[$(QuoteNode(arg))])) - end - end - return :(NamedTuple{$(args)}(($(fields...),))) -end +variables(model::BUGSModel) = collect(labels(model.g)) function BUGSModel( g::BUGSGraph, @@ -130,21 +116,24 @@ function BUGSModel( initial_params::NamedTuple=NamedTuple(); is_transformed::Bool=true, ) - sorted_nodes = VarName[label_for(g, node) for node in topological_sort(g)] + flattened_graph_node_data = FlattenedGraphNodeData(g) parameters = VarName[] untransformed_param_length, transformed_param_length = 0, 0 untransformed_var_lengths, transformed_var_lengths = Dict{VarName,Int}(), Dict{VarName,Int}() - for vn in sorted_nodes - (; is_stochastic, is_observed, node_function, node_args, loop_vars) = g[vn] - args = prepare_arg_values(Val(node_args), evaluation_env, loop_vars) + for (i, vn) in enumerate(flattened_graph_node_data.sorted_nodes) + is_stochastic = flattened_graph_node_data.is_stochastic_vals[i] + is_observed = flattened_graph_node_data.is_observed_vals[i] + node_function = flattened_graph_node_data.node_function_vals[i] + loop_vars = flattened_graph_node_data.loop_vars_vals[i] + if !is_stochastic - value = Base.invokelatest(node_function; args...) + value = Base.invokelatest(node_function, evaluation_env, loop_vars) evaluation_env = BangBang.setindex!!(evaluation_env, value, vn) elseif !is_observed push!(parameters, vn) - dist = Base.invokelatest(node_function; args...) + dist = Base.invokelatest(node_function, evaluation_env, loop_vars) untransformed_var_lengths[vn] = length(dist) # not all distributions are defined for `Bijectors.transformed` @@ -183,7 +172,7 @@ function BUGSModel( transformed_var_lengths, evaluation_env, parameters, - EvalCache(sorted_nodes, g), + flattened_graph_node_data, g, nothing, ) @@ -204,7 +193,7 @@ function BUGSModel( model.transformed_var_lengths, evaluation_env, parameters, - EvalCache(sorted_nodes, g), + FlattenedGraphNodeData(g, sorted_nodes), g, isnothing(model.base_model) ? 
model : model.base_model, ) @@ -217,15 +206,13 @@ Initialize the model with a NamedTuple of initial values, the values are expecte """ function initialize!(model::BUGSModel, initial_params::NamedTuple) check_input(initial_params) - for (i, vn) in enumerate(model.eval_cache.sorted_nodes) - is_stochastic = model.eval_cache.is_stochastic_vals[i] - is_observed = model.eval_cache.is_observed_vals[i] - node_function = model.eval_cache.node_function_vals[i] - node_args = model.eval_cache.node_args_vals[i] - loop_vars = model.eval_cache.loop_vars_vals[i] - args = prepare_arg_values(node_args, model.evaluation_env, loop_vars) + for (i, vn) in enumerate(model.flattened_graph_node_data.sorted_nodes) + is_stochastic = model.flattened_graph_node_data.is_stochastic_vals[i] + is_observed = model.flattened_graph_node_data.is_observed_vals[i] + node_function = model.flattened_graph_node_data.node_function_vals[i] + loop_vars = model.flattened_graph_node_data.loop_vars_vals[i] if !is_stochastic - value = Base.invokelatest(node_function; args...) + value = Base.invokelatest(node_function, model.evaluation_env, loop_vars) BangBang.@set!! model.evaluation_env = setindex!!( model.evaluation_env, value, vn ) @@ -242,7 +229,7 @@ function initialize!(model::BUGSModel, initial_params::NamedTuple) else BangBang.@set!! model.evaluation_env = setindex!!( model.evaluation_env, - rand(Base.invokelatest(node_function; args...)), + rand(Base.invokelatest(node_function, model.evaluation_env, loop_vars)), vn, ) end @@ -286,9 +273,8 @@ function getparams(model::BUGSModel) param_vals[pos] = val end else - (; node_function, node_args, loop_vars) = model.g[v] - args = prepare_arg_values(Val(node_args), model.evaluation_env, loop_vars) - dist = node_function(; args...) + (; node_function, loop_vars) = model.g[v] + dist = node_function(model.evaluation_env, loop_vars) transformed_value = Bijectors.transform( Bijectors.bijector(dist), AbstractPPL.get(model.evaluation_env, v) ) @@ -317,9 +303,8 @@ function getparams(T::Type{<:AbstractDict}, model::BUGSModel) if !model.transformed d[v] = value else - (; node_function, node_args, loop_vars) = model.g[v] - args = prepare_arg_values(Val(node_args), model.evaluation_env, loop_vars) - dist = node_function(; args...) + (; node_function, loop_vars) = model.g[v] + dist = node_function(model.evaluation_env, loop_vars) d[v] = Bijectors.transform(Bijectors.bijector(dist), value) end end @@ -362,11 +347,11 @@ function AbstractPPL.condition( new_parameters = setdiff(model.parameters, var_group) sorted_blanket_with_vars = if sorted_nodes isa Nothing - model.eval_cache.sorted_nodes + model.flattened_graph_node_data.sorted_nodes else filter( vn -> vn in union(markov_blanket(model.g, new_parameters), new_parameters), - model.eval_cache.sorted_nodes, + model.flattened_graph_node_data.sorted_nodes, ) end @@ -393,15 +378,16 @@ function AbstractPPL.decondition(model::BUGSModel, var_group::Vector{<:VarName}) base_model = model.base_model isa Nothing ? 
model : model.base_model new_parameters = [ - v for - v in base_model.eval_cache.sorted_nodes if v in union(model.parameters, var_group) + v for v in base_model.flattened_graph_node_data.sorted_nodes if + v in union(model.parameters, var_group) ] # keep the order markov_blanket_with_vars = union( markov_blanket(base_model.g, new_parameters), new_parameters ) sorted_blanket_with_vars = filter( - vn -> vn in markov_blanket_with_vars, base_model.eval_cache.sorted_nodes + vn -> vn in markov_blanket_with_vars, + base_model.flattened_graph_node_data.sorted_nodes, ) new_model = BUGSModel( @@ -424,17 +410,15 @@ function AbstractPPL.evaluate!!(rng::Random.AbstractRNG, model::BUGSModel) (; evaluation_env, g) = model vi = deepcopy(evaluation_env) logp = 0.0 - for (i, vn) in enumerate(model.eval_cache.sorted_nodes) - is_stochastic = model.eval_cache.is_stochastic_vals[i] - node_function = model.eval_cache.node_function_vals[i] - node_args = model.eval_cache.node_args_vals[i] - loop_vars = model.eval_cache.loop_vars_vals[i] - args = prepare_arg_values(node_args, evaluation_env, loop_vars) + for (i, vn) in enumerate(model.flattened_graph_node_data.sorted_nodes) + is_stochastic = model.flattened_graph_node_data.is_stochastic_vals[i] + node_function = model.flattened_graph_node_data.node_function_vals[i] + loop_vars = model.flattened_graph_node_data.loop_vars_vals[i] if !is_stochastic - value = node_function(; args...) + value = node_function(model.evaluation_env, loop_vars) evaluation_env = setindex!!(evaluation_env, value, vn) else - dist = node_function(; args...) + dist = node_function(model.evaluation_env, loop_vars) value = rand(rng, dist) # just sample from the prior logp += logpdf(dist, value) evaluation_env = setindex!!(evaluation_env, value, vn) @@ -446,17 +430,15 @@ end function AbstractPPL.evaluate!!(model::BUGSModel) logp = 0.0 evaluation_env = deepcopy(model.evaluation_env) - for (i, vn) in enumerate(model.eval_cache.sorted_nodes) - is_stochastic = model.eval_cache.is_stochastic_vals[i] - node_function = model.eval_cache.node_function_vals[i] - node_args = model.eval_cache.node_args_vals[i] - loop_vars = model.eval_cache.loop_vars_vals[i] - args = prepare_arg_values(node_args, evaluation_env, loop_vars) + for (i, vn) in enumerate(model.flattened_graph_node_data.sorted_nodes) + is_stochastic = model.flattened_graph_node_data.is_stochastic_vals[i] + node_function = model.flattened_graph_node_data.node_function_vals[i] + loop_vars = model.flattened_graph_node_data.loop_vars_vals[i] if !is_stochastic - value = node_function(; args...) + value = node_function(model.evaluation_env, loop_vars) evaluation_env = setindex!!(evaluation_env, value, vn) else - dist = node_function(; args...) 
+ dist = node_function(model.evaluation_env, loop_vars) value = AbstractPPL.get(evaluation_env, vn) if model.transformed # although the values stored in `evaluation_env` are in their original space, @@ -484,18 +466,16 @@ function AbstractPPL.evaluate!!(model::BUGSModel, flattened_values::AbstractVect evaluation_env = deepcopy(model.evaluation_env) current_idx = 1 logp = 0.0 - for (i, vn) in enumerate(model.eval_cache.sorted_nodes) - is_stochastic = model.eval_cache.is_stochastic_vals[i] - is_observed = model.eval_cache.is_observed_vals[i] - node_function = model.eval_cache.node_function_vals[i] - node_args = model.eval_cache.node_args_vals[i] - loop_vars = model.eval_cache.loop_vars_vals[i] - args = prepare_arg_values(node_args, evaluation_env, loop_vars) + for (i, vn) in enumerate(model.flattened_graph_node_data.sorted_nodes) + is_stochastic = model.flattened_graph_node_data.is_stochastic_vals[i] + is_observed = model.flattened_graph_node_data.is_observed_vals[i] + node_function = model.flattened_graph_node_data.node_function_vals[i] + loop_vars = model.flattened_graph_node_data.loop_vars_vals[i] if !is_stochastic - value = node_function(; args...) + value = node_function(evaluation_env, loop_vars) evaluation_env = BangBang.setindex!!(evaluation_env, value, vn) else - dist = node_function(; args...) + dist = node_function(evaluation_env, loop_vars) if !is_observed l = var_lengths[vn] if model.transformed diff --git a/test/graphs.jl b/test/graphs.jl index bfb9dae12..0e6aa572a 100644 --- a/test/graphs.jl +++ b/test/graphs.jl @@ -47,11 +47,12 @@ c = @varname c cond_model = AbstractPPL.condition(model, setdiff(model.parameters, [c])) # tests for MarkovBlanketBUGSModel constructor @test cond_model.parameters == [c] -@test Set(Symbol.(cond_model.eval_cache.sorted_nodes)) == Set([:l, :a, :b, :f, :c]) +@test Set(Symbol.(cond_model.flattened_graph_node_data.sorted_nodes)) == + Set([:l, :a, :b, :f, :c]) decond_model = AbstractPPL.decondition(cond_model, [a, l]) @test Set(Symbol.(decond_model.parameters)) == Set([:a, :c, :l]) -@test Set(Symbol.(decond_model.eval_cache.sorted_nodes)) == +@test Set(Symbol.(decond_model.flattened_graph_node_data.sorted_nodes)) == Set([:l, :b, :f, :a, :d, :e, :c, :h, :g, :i]) c_value = 4.0
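The jldoctest in `make_function_expr`'s docstring shows the shape of the generated node functions. As a quick, self-contained illustration of the new calling convention (a minimal sketch only; `node_func`, `evaluation_env`, and `loop_vars` are stand-in names, not identifiers introduced by this patch):

```julia
# A node function generated for the expression `x[a, b] + 1` now takes two
# NamedTuples positionally, instead of keyword arguments assembled by the
# removed `prepare_arg_values` helper.
node_func = function (evaluation_env, loop_vars)
    (; x) = evaluation_env   # model variables come from the evaluation environment
    (; a, b) = loop_vars     # loop indices come from the per-node loop variables
    return x[Int(a), Int(b)] + 1
end

evaluation_env = (x=[1 2 3; 4 5 6],)
loop_vars = (a=2, b=3)
node_func(evaluation_env, loop_vars)  # x[2, 3] + 1 == 7
```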