@@ -25,10 +25,11 @@ function _pobserve(expr::Expr)
2525 end
2626 retvals_and_likelihoods = fetch .(likelihood_tasks)
2727 total_likelihoods = sum (last, retvals_and_likelihoods)
28- # println("Total likelihoods: ", total_likelihoods)
29- $ (esc (:(__varinfo__))) = $ (DynamicPPL. accloglikelihood!!)(
30- $ (esc (:(__varinfo__))), total_likelihoods
31- )
28+ if $ (DynamicPPL. hasacc)($ (esc (:(__varinfo__))), Val (:LogLikelihood ))
29+ $ (esc (:(__varinfo__))) = $ (DynamicPPL. accloglikelihood!!)(
30+ $ (esc (:(__varinfo__))), total_likelihoods
31+ )
32+ end
3233 map (first, retvals_and_likelihoods)
3334 end
3435 return return_expr
@@ -49,8 +50,13 @@ function process_tilde_statements(expr::Expr)
4950 end
5051 ) || error (" expected block" )
5152 @gensym loglike
52- beginning_statement =
53- :($ loglike = zero ($ (DynamicPPL. getloglikelihood)($ (esc (:(__varinfo__))))))
53+ beginning_expr = quote
54+ $ loglike = if $ (DynamicPPL. hasacc)($ (esc (:(__varinfo__))), Val (:LogLikelihood ))
55+ zero ($ (DynamicPPL. getloglikelihood)($ (esc (:(__varinfo__)))))
56+ else
57+ zero ($ (DynamicPPL. LogProbType))
58+ end
59+ end
5460 n_statements = length (statements)
5561 transformed_statements:: Vector{Vector{Expr}} = map (enumerate (statements)) do (i, stmt)
5662 is_last = i == n_statements
@@ -79,6 +85,6 @@ function process_tilde_statements(expr::Expr)
7985 e
8086 end
8187 end
82- new_statements = [beginning_statement , reduce (vcat, transformed_statements)... ]
88+ new_statements = [beginning_expr . args ... , reduce (vcat, transformed_statements)... ]
8389 return Expr (:block , new_statements... )
8490end
0 commit comments