diff --git a/Project.toml b/Project.toml index c6d662c44..305b1c52c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.24.5" +version = "0.24.6" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 2b28b44a9..49605832a 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -16,10 +16,22 @@ require_particles(spl::Sampler) = false # Allows samplers, etc. to hook into the final logp accumulation in the tilde-pipeline. function acclogp_assume!!(context::AbstractContext, vi::AbstractVarInfo, logp) + return acclogp_assume!!(NodeTrait(acclogp_assume!!, context), context, vi, logp) +end +function acclogp_assume!!(::IsParent, context::AbstractContext, vi::AbstractVarInfo, logp) + return acclogp_assume!!(childcontext(context), vi, logp) +end +function acclogp_assume!!(::IsLeaf, context::AbstractContext, vi::AbstractVarInfo, logp) return acclogp!!(context, vi, logp) end function acclogp_observe!!(context::AbstractContext, vi::AbstractVarInfo, logp) + return acclogp_observe!!(NodeTrait(acclogp_observe!!, context), context, vi, logp) +end +function acclogp_observe!!(::IsParent, context::AbstractContext, vi::AbstractVarInfo, logp) + return acclogp_observe!!(childcontext(context), vi, logp) +end +function acclogp_observe!!(::IsLeaf, context::AbstractContext, vi::AbstractVarInfo, logp) return acclogp!!(context, vi, logp) end