Skip to content

Commit

Permalink
improve perf
Browse files Browse the repository at this point in the history
  • Loading branch information
sunxd3 committed Oct 24, 2024
1 parent 53e9de6 commit e97ade5
Showing 1 changed file with 15 additions and 15 deletions.
30 changes: 15 additions & 15 deletions src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,18 +71,18 @@ Return a vector of `VarName` containing the names of all the variables in the mo
"""
variables(m::BUGSModel) = collect(labels(m.g))

function prepare_arg_values(
args::Tuple{Vararg{Symbol}}, evaluation_env::NamedTuple, loop_vars::NamedTuple{lvars}
) where {lvars}
return NamedTuple{args}(Tuple(
map(args) do arg
if arg in lvars
loop_vars[arg]
else
AbstractPPL.get(evaluation_env, @varname($arg))
end
end,
))
@generated function prepare_arg_values(
::Val{args}, evaluation_env::NamedTuple, loop_vars::NamedTuple{lvars}
) where {args, lvars}
fields = []
for arg in args
if arg in lvars
push!(fields, :(loop_vars[$(QuoteNode(arg))]))
else
push!(fields, :(evaluation_env[$(QuoteNode(arg))]))
end
end
return :(NamedTuple{$(args)}(($(fields...),)))
end

function BUGSModel(
Expand Down Expand Up @@ -460,14 +460,14 @@ function AbstractPPL.evaluate!!(
current_idx = 1
logp = 0.0
for vn in sorted_nodes
(; is_stochastic, node_function, node_args, loop_vars) = g[vn]
args = prepare_arg_values(node_args, evaluation_env, loop_vars)
(; is_stochastic, is_observed, node_function, node_args, loop_vars) = g[vn]
args = prepare_arg_values(Val(node_args), evaluation_env, loop_vars)
if !is_stochastic
value = node_function(; args...)
evaluation_env = BangBang.setindex!!(evaluation_env, value, vn)
else
dist = node_function(; args...)
if vn in model.parameters
if is_stochastic && !is_observed
l = var_lengths[vn]
if model.transformed
b = Bijectors.bijector(dist)
Expand Down

0 comments on commit e97ade5

Please sign in to comment.