Skip to content

Commit

Permalink
more perf improvement
Browse files Browse the repository at this point in the history
  • Loading branch information
sunxd3 committed Nov 7, 2024
1 parent 1900aed commit 4fb1eab
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 71 deletions.
96 changes: 65 additions & 31 deletions src/compiler_pass.jl
Original file line number Diff line number Diff line change
Expand Up @@ -709,27 +709,55 @@ function build_node_functions(
return f_dict
end

"""
make_function_expr(lhs, rhs, env::NamedTuple{vars}; use_lhs_as_func_name=false)
Generate a function expression for the given right-hand side expression `rhs`. The generated function will take
a `NamedTuple` as its argument, which contains the values of the variables used in `rhs`.
# Examples
```jldoctest; setup = :(using JuliaBUGS: make_function_expr)
julia> make_function_expr(:(x[a, b]), :(x[a, b] + 1), (x = [1 2 3; 4 5 6], a = missing, b = missing))
((:a, :b, :x), :(function (evaluation_env, loop_vars)
(; a, b, x) = evaluation_env
(;) = loop_vars
return x[Int(a), Int(b)] + 1
end))
julia> make_function_expr(:(x[a, b]), :(x[a, b] + 1), (;x = [1 2 3; 4 5 6]))
((:a, :b, :x), :(function (evaluation_env, loop_vars)
(; x) = evaluation_env
(; a, b) = loop_vars
return x[Int(a), Int(b)] + 1
end))
```
"""
function make_function_expr(
lhs, rhs, env::NamedTuple{vars}; use_lhs_as_func_name=false
) where {vars}
args = Tuple(keys(extract_variable_names_and_numdims(rhs, ())))
arg_exprs = Expr[]
for v in args
if v vars
value = env[v]
if value isa Int || value isa Float64 || value isa Missing
push!(arg_exprs, Expr(:(::), v, :Real))
elseif value isa AbstractArray
push!(arg_exprs, Expr(:(::), v, :(Array{<:Real})))
else
error("Unexpected argument type: $(typeof(value))")
end
else # loop variable
push!(arg_exprs, Expr(:(::), v, :Int))
end
end

expr = MacroTools.postwalk(rhs) do sub_expr
loop_vars = Tuple([v for v in args if v vars])
variables = setdiff(args, loop_vars)
# arg_exprs = Expr[]
# for v in args
# if v ∈ vars
# value = env[v]
# if value isa Int || value isa Float64 || value isa Missing
# push!(arg_exprs, Expr(:(::), v, :Real))
# elseif value isa AbstractArray
# push!(arg_exprs, Expr(:(::), v, :(Array{<:Real})))
# else
# error("Unexpected argument type: $(typeof(value))")
# end
# else # loop variable
# push!(arg_exprs, Expr(:(::), v, :Int))
# end
# end

unpacking_expr = :((; $(variables...),) = evaluation_env)
unpacking_loop_vars_expr = :((; $(loop_vars...),) = loop_vars)

func_body = MacroTools.postwalk(rhs) do sub_expr
if @capture(sub_expr, v_[indices__])
new_indices = Any[]
for i in eachindex(indices)
Expand All @@ -746,20 +774,26 @@ function make_function_expr(
return sub_expr
end

if use_lhs_as_func_name
func_name = if lhs isa Symbol
lhs
else
Symbol("__", String(lhs.args[1]), "_", join(lhs.args[2:end], "_"), "__")
end

return args, MacroTools.@q function $func_name($(arg_exprs...))
return $(expr)
end
else
return args, MacroTools.@q function ($(arg_exprs...))
return $(expr)
end
# if use_lhs_as_func_name
# func_name = if lhs isa Symbol
# lhs
# else
# Symbol("__", String(lhs.args[1]), "_", join(lhs.args[2:end], "_"), "__")
# end

# return args, MacroTools.@q function $func_name($(arg_exprs...))
# return $(func_body)
# end
# else
# return args, MacroTools.@q function ($(arg_exprs...))
# return $(func_body)
# end
# end

return args, MacroTools.@q function (evaluation_env, loop_vars)
$(unpacking_expr)
$(unpacking_loop_vars_expr)
return $(func_body)
end
end

Expand Down
55 changes: 15 additions & 40 deletions src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,20 +110,6 @@ Return a vector of `VarName` containing the names of all the variables in the mo
"""
variables(model::BUGSModel) = collect(labels(model.g))

@generated function prepare_arg_values(
::Val{args}, evaluation_env::NamedTuple, loop_vars::NamedTuple{lvars}
) where {args,lvars}
fields = []
for arg in args
if arg in lvars
push!(fields, :(loop_vars[$(QuoteNode(arg))]))
else
push!(fields, :(evaluation_env[$(QuoteNode(arg))]))
end
end
return :(NamedTuple{$(args)}(($(fields...),)))
end

function BUGSModel(
g::BUGSGraph,
evaluation_env::NamedTuple,
Expand All @@ -137,14 +123,13 @@ function BUGSModel(
Dict{VarName,Int}()

for vn in sorted_nodes
(; is_stochastic, is_observed, node_function, node_args, loop_vars) = g[vn]
args = prepare_arg_values(Val(node_args), evaluation_env, loop_vars)
(; is_stochastic, is_observed, node_function, loop_vars) = g[vn]
if !is_stochastic
value = Base.invokelatest(node_function; args...)
value = Base.invokelatest(node_function, evaluation_env, loop_vars)
evaluation_env = BangBang.setindex!!(evaluation_env, value, vn)
elseif !is_observed
push!(parameters, vn)
dist = Base.invokelatest(node_function; args...)
dist = Base.invokelatest(node_function, evaluation_env, loop_vars)

untransformed_var_lengths[vn] = length(dist)
# not all distributions are defined for `Bijectors.transformed`
Expand Down Expand Up @@ -221,11 +206,9 @@ function initialize!(model::BUGSModel, initial_params::NamedTuple)
is_stochastic = model.eval_cache.is_stochastic_vals[i]
is_observed = model.eval_cache.is_observed_vals[i]
node_function = model.eval_cache.node_function_vals[i]
node_args = model.eval_cache.node_args_vals[i]
loop_vars = model.eval_cache.loop_vars_vals[i]
args = prepare_arg_values(node_args, model.evaluation_env, loop_vars)
if !is_stochastic
value = Base.invokelatest(node_function; args...)
value = Base.invokelatest(node_function, model.evaluation_env, loop_vars)
BangBang.@set!! model.evaluation_env = setindex!!(
model.evaluation_env, value, vn
)
Expand All @@ -242,7 +225,7 @@ function initialize!(model::BUGSModel, initial_params::NamedTuple)
else
BangBang.@set!! model.evaluation_env = setindex!!(
model.evaluation_env,
rand(Base.invokelatest(node_function; args...)),
rand(Base.invokelatest(node_function, model.evaluation_env, loop_vars)),
vn,
)
end
Expand Down Expand Up @@ -286,9 +269,8 @@ function getparams(model::BUGSModel)
param_vals[pos] = val
end
else
(; node_function, node_args, loop_vars) = model.g[v]
args = prepare_arg_values(Val(node_args), model.evaluation_env, loop_vars)
dist = node_function(; args...)
(; node_function, loop_vars) = model.g[v]
dist = node_function(model.evaluation_env, loop_vars)
transformed_value = Bijectors.transform(
Bijectors.bijector(dist), AbstractPPL.get(model.evaluation_env, v)
)
Expand Down Expand Up @@ -317,9 +299,8 @@ function getparams(T::Type{<:AbstractDict}, model::BUGSModel)
if !model.transformed
d[v] = value
else
(; node_function, node_args, loop_vars) = model.g[v]
args = prepare_arg_values(Val(node_args), model.evaluation_env, loop_vars)
dist = node_function(; args...)
(; node_function, loop_vars) = model.g[v]
dist = node_function(model.evaluation_env, loop_vars)
d[v] = Bijectors.transform(Bijectors.bijector(dist), value)
end
end
Expand Down Expand Up @@ -427,14 +408,12 @@ function AbstractPPL.evaluate!!(rng::Random.AbstractRNG, model::BUGSModel)
for (i, vn) in enumerate(model.eval_cache.sorted_nodes)
is_stochastic = model.eval_cache.is_stochastic_vals[i]
node_function = model.eval_cache.node_function_vals[i]
node_args = model.eval_cache.node_args_vals[i]
loop_vars = model.eval_cache.loop_vars_vals[i]
args = prepare_arg_values(node_args, evaluation_env, loop_vars)
if !is_stochastic
value = node_function(; args...)
value = node_function(model.evaluation_env, loop_vars)
evaluation_env = setindex!!(evaluation_env, value, vn)
else
dist = node_function(; args...)
dist = node_function(model.evaluation_env, loop_vars)
value = rand(rng, dist) # just sample from the prior
logp += logpdf(dist, value)
evaluation_env = setindex!!(evaluation_env, value, vn)
Expand All @@ -449,14 +428,12 @@ function AbstractPPL.evaluate!!(model::BUGSModel)
for (i, vn) in enumerate(model.eval_cache.sorted_nodes)
is_stochastic = model.eval_cache.is_stochastic_vals[i]
node_function = model.eval_cache.node_function_vals[i]
node_args = model.eval_cache.node_args_vals[i]
loop_vars = model.eval_cache.loop_vars_vals[i]
args = prepare_arg_values(node_args, evaluation_env, loop_vars)
if !is_stochastic
value = node_function(; args...)
value = node_function(model.evaluation_env, loop_vars)
evaluation_env = setindex!!(evaluation_env, value, vn)
else
dist = node_function(; args...)
dist = node_function(model.evaluation_env, loop_vars)
value = AbstractPPL.get(evaluation_env, vn)
if model.transformed
# although the values stored in `evaluation_env` are in their original space,
Expand Down Expand Up @@ -488,14 +465,12 @@ function AbstractPPL.evaluate!!(model::BUGSModel, flattened_values::AbstractVect
is_stochastic = model.eval_cache.is_stochastic_vals[i]
is_observed = model.eval_cache.is_observed_vals[i]
node_function = model.eval_cache.node_function_vals[i]
node_args = model.eval_cache.node_args_vals[i]
loop_vars = model.eval_cache.loop_vars_vals[i]
args = prepare_arg_values(node_args, evaluation_env, loop_vars)
if !is_stochastic
value = node_function(; args...)
value = node_function(evaluation_env, loop_vars)
evaluation_env = BangBang.setindex!!(evaluation_env, value, vn)
else
dist = node_function(; args...)
dist = node_function(evaluation_env, loop_vars)
if !is_observed
l = var_lengths[vn]
if model.transformed
Expand Down

0 comments on commit 4fb1eab

Please sign in to comment.