Skip to content

Commit

Permalink
Further improve performance of logdensity evaluation (#234)
Browse files Browse the repository at this point in the history
A source of type instability is the `prepare_arg_values`. This PR
removes this function by modifying node functions to take
`evaluation_env` directly and unpack within node functions, thus save
the lookup and constructing keyword arguments step.

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
sunxd3 and github-actions[bot] authored Nov 8, 2024
1 parent b5673a2 commit aaeed00
Show file tree
Hide file tree
Showing 5 changed files with 157 additions and 125 deletions.
4 changes: 3 additions & 1 deletion ext/JuliaBUGSMCMCChainsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,9 @@ function JuliaBUGS.gen_chains(
g = model.g

generated_vars = find_generated_vars(g)
generated_vars = [v for v in model.eval_cache.sorted_nodes if v in generated_vars] # keep the order
generated_vars = [
v for v in model.flattened_graph_node_data.sorted_nodes if v in generated_vars
] # keep the order

param_vals = []
generated_quantities = []
Expand Down
96 changes: 72 additions & 24 deletions src/compiler_pass.jl
Original file line number Diff line number Diff line change
Expand Up @@ -691,12 +691,12 @@ function build_node_functions(
)
for statement in expr.args
if is_deterministic(statement) || is_stochastic(statement)
rhs = if is_deterministic(statement)
statement.args[2]
lhs, rhs = if is_deterministic(statement)
statement.args[1], statement.args[2]
else
statement.args[3]
statement.args[2], statement.args[3]
end
args, node_func_expr = make_function_expr(rhs, eval_env)
args, node_func_expr = make_function_expr(lhs, rhs, eval_env)
node_func = eval(node_func_expr)
f_dict[statement] = (args, node_func_expr, node_func)
elseif Meta.isexpr(statement, :for)
Expand All @@ -709,25 +709,55 @@ function build_node_functions(
return f_dict
end

function make_function_expr(expr, env::NamedTuple{vars}) where {vars}
args = Tuple(keys(extract_variable_names_and_numdims(expr, ())))
arg_exprs = Expr[]
for v in args
if v vars
value = env[v]
if value isa Int || value isa Float64 || value isa Missing
push!(arg_exprs, Expr(:(::), v, :Real))
elseif value isa AbstractArray
push!(arg_exprs, Expr(:(::), v, :(Array{<:Real})))
else
error("Unexpected argument type: $(typeof(value))")
end
else # loop variable
push!(arg_exprs, Expr(:(::), v, :Int))
end
end
"""
make_function_expr(lhs, rhs, env::NamedTuple{vars}; use_lhs_as_func_name=false)
expr = MacroTools.postwalk(expr) do sub_expr
Generate a function expression for the given right-hand side expression `rhs`. The generated function will take
a `NamedTuple` as its argument, which contains the values of the variables used in `rhs`.
# Examples
```jldoctest; setup = :(using JuliaBUGS: make_function_expr)
julia> make_function_expr(:(x[a, b]), :(x[a, b] + 1), (x = [1 2 3; 4 5 6], a = missing, b = missing))
((:a, :b, :x), :(function (evaluation_env, loop_vars)
(; a, b, x) = evaluation_env
(;) = loop_vars
return x[Int(a), Int(b)] + 1
end))
julia> make_function_expr(:(x[a, b]), :(x[a, b] + 1), (;x = [1 2 3; 4 5 6]))
((:a, :b, :x), :(function (evaluation_env, loop_vars)
(; x) = evaluation_env
(; a, b) = loop_vars
return x[Int(a), Int(b)] + 1
end))
```
"""
function make_function_expr(
lhs, rhs, env::NamedTuple{vars}; use_lhs_as_func_name=false
) where {vars}
args = Tuple(keys(extract_variable_names_and_numdims(rhs, ())))
loop_vars = Tuple([v for v in args if v vars])
variables = setdiff(args, loop_vars)
# arg_exprs = Expr[]
# for v in args
# if v ∈ vars
# value = env[v]
# if value isa Int || value isa Float64 || value isa Missing
# push!(arg_exprs, Expr(:(::), v, :Real))
# elseif value isa AbstractArray
# push!(arg_exprs, Expr(:(::), v, :(Array{<:Real})))
# else
# error("Unexpected argument type: $(typeof(value))")
# end
# else # loop variable
# push!(arg_exprs, Expr(:(::), v, :Int))
# end
# end

unpacking_expr = :((; $(variables...),) = evaluation_env)
unpacking_loop_vars_expr = :((; $(loop_vars...),) = loop_vars)

func_body = MacroTools.postwalk(rhs) do sub_expr
if @capture(sub_expr, v_[indices__])
new_indices = Any[]
for i in eachindex(indices)
Expand All @@ -744,8 +774,26 @@ function make_function_expr(expr, env::NamedTuple{vars}) where {vars}
return sub_expr
end

return args, MacroTools.@q function (; $(arg_exprs...))
return $(expr)
# if use_lhs_as_func_name
# func_name = if lhs isa Symbol
# lhs
# else
# Symbol("__", String(lhs.args[1]), "_", join(lhs.args[2:end], "_"), "__")
# end

# return args, MacroTools.@q function $func_name($(arg_exprs...))
# return $(func_body)
# end
# else
# return args, MacroTools.@q function ($(arg_exprs...))
# return $(func_body)
# end
# end

return args, MacroTools.@q function (evaluation_env, loop_vars)
$(unpacking_expr)
$(unpacking_loop_vars_expr)
return $(func_body)
end
end

Expand Down
3 changes: 2 additions & 1 deletion src/gibbs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ function AbstractMCMC.step(
conditioned_model = AbstractPPL.condition(
model, variable_to_condition_on, model.evaluation_env
)
cached_eval_caches[variable_to_condition_on] = conditioned_model.eval_cache
cached_eval_caches[variable_to_condition_on] =
conditioned_model.flattened_graph_node_data
end
param_values = JuliaBUGS.getparams(model)
return param_values, GibbsState(param_values, conditioning_schedule, cached_eval_caches)
Expand Down
Loading

0 comments on commit aaeed00

Please sign in to comment.