diff --git a/HISTORY.md b/HISTORY.md index 45be1772d..dff3bb0d3 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -4,6 +4,10 @@ Removed the method `returned(::Model, values, keys)`; please use `returned(::Model, ::AbstractDict{<:VarName})` instead. +### Breaking changes + +`PrefixContext` is removed, etc., etc.... (I'll write this if I actually have to) + ## 0.38.4 Improve performance of VarNamedVector. It should now be very nearly on par with Metadata for all models we've benchmarked on. diff --git a/docs/src/api.md b/docs/src/api.md index b04bd445d..237c4a242 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -467,7 +467,6 @@ Contexts are subtypes of `AbstractPPL.AbstractContext`. ```@docs DefaultContext -PrefixContext ConditionContext InitContext ``` diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index f5bd33d6d..94e554ace 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -97,7 +97,6 @@ export AbstractVarInfo, # Contexts contextualize, DefaultContext, - PrefixContext, ConditionContext, # Tilde pipeline tilde_assume!!, @@ -174,9 +173,9 @@ include("contexts.jl") include("contexts/default.jl") include("contexts/init.jl") include("contexts/transformation.jl") -include("contexts/prefix.jl") -include("contexts/conditionfix.jl") # Must come after contexts/prefix.jl +include("contexts/conditionfix.jl") include("model.jl") +include("prefix.jl") include("varname.jl") include("distribution_wrappers.jl") include("submodel.jl") diff --git a/src/compiler.jl b/src/compiler.jl index 6384eaa7c..d9473315d 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -42,7 +42,7 @@ function make_varname_expression(expr) end """ - isassumption(expr[, vn]) + isassumption(expr[, left_vn]) Return an expression that can be evaluated to check if `expr` is an assumption in the model. @@ -55,16 +55,19 @@ Let `expr` be `:(x[1])`. It is an assumption in the following cases: When `expr` is not an expression or symbol (i.e., a literal), this expands to `false`. -If `vn` is specified, it will be assumed to refer to a expression which +If `left_vn` is specified, it will be assumed to refer to a expression which evaluates to a `VarName`, and this will be used in the subsequent checks. -If `vn` is not specified, `AbstractPPL.varname(expr, need_concretize(expr))` will be +If `left_vn` is not specified, `AbstractPPL.varname(expr, need_concretize(expr))` will be used in its place. """ -function isassumption(expr::Union{Expr,Symbol}, vn=make_varname_expression(expr)) +function isassumption(expr::Union{Expr,Symbol}, left_vn=make_varname_expression(expr)) + @gensym vn return quote - if $(DynamicPPL.contextual_isassumption)( - __model__.context, $(DynamicPPL.prefix)(__model__.context, $vn) - ) + # TODO(penelopeysm): This re-prefixing seems a bit wasteful. I'd really like + # the whole `isassumption` thing to be simplified, though, so I'll + # leave it till later. + $vn = $(DynamicPPL.maybe_prefix)($left_vn, __model__.prefix) + if $(DynamicPPL.contextual_isassumption)(__model__.context, $vn) # Considered an assumption by `__model__.context` which means either: # 1. We hit the default implementation, e.g. using `DefaultContext`, # which in turn means that we haven't considered if it's one of @@ -78,8 +81,8 @@ function isassumption(expr::Union{Expr,Symbol}, vn=make_varname_expression(expr) # TODO: Support by adding context to model, and use `model.args` # as the default conditioning. Then we no longer need to check `inargnames` # since it will all be handled by `contextual_isassumption`. - if !($(DynamicPPL.inargnames)($vn, __model__)) || - $(DynamicPPL.inmissings)($vn, __model__) + if !($(DynamicPPL.inargnames)($left_vn, __model__)) || + $(DynamicPPL.inmissings)($left_vn, __model__) true else $(maybe_view(expr)) === missing @@ -99,7 +102,7 @@ isassumption(expr) = :(false) Return `true` if `vn` is considered an assumption by `context`. """ -function contextual_isassumption(context::AbstractContext, vn) +function contextual_isassumption(context::AbstractContext, vn::VarName) if hasconditioned_nested(context, vn) val = getconditioned_nested(context, vn) # TODO: Do we even need the `>: Missing`, i.e. does it even help the compiler? @@ -115,9 +118,7 @@ end isfixed(expr, vn) = false function isfixed(::Union{Symbol,Expr}, vn) - return :($(DynamicPPL.contextual_isfixed)( - __model__.context, $(DynamicPPL.prefix)(__model__.context, $vn) - )) + return :($(DynamicPPL.contextual_isfixed)(__model__.context, $vn)) end """ @@ -413,7 +414,9 @@ function generate_assign(left, right) return quote $right_val = $right if $(DynamicPPL.is_extracting_values)(__varinfo__) - $vn = $(DynamicPPL.prefix)(__model__.context, $(make_varname_expression(left))) + $vn = $(DynamicPPL.maybe_prefix)( + $(make_varname_expression(left)), __model__.prefix + ) __varinfo__ = $(map_accumulator!!)( $acc -> push!($acc, $vn, $right_val), __varinfo__, Val(:ValuesAsInModel) ) @@ -448,24 +451,23 @@ function generate_tilde(left, right) # Otherwise it is determined by the model or its value, # if the LHS represents an observation - @gensym vn isassumption value dist + @gensym left_vn vn isassumption value dist return quote $dist = $right - $vn = $(DynamicPPL.resolve_varnames)($(make_varname_expression(left)), $dist) - $isassumption = $(DynamicPPL.isassumption(left, vn)) + $left_vn = $(DynamicPPL.resolve_varnames)($(make_varname_expression(left)), $dist) + $vn = $(DynamicPPL.maybe_prefix)($left_vn, __model__.prefix) + $isassumption = $(DynamicPPL.isassumption(left, left_vn)) if $(DynamicPPL.isfixed(left, vn)) - $left = $(DynamicPPL.getfixed_nested)( - __model__.context, $(DynamicPPL.prefix)(__model__.context, $vn) - ) + $left = $(DynamicPPL.getfixed_nested)(__model__.context, $vn) elseif $isassumption $(generate_tilde_assume(left, dist, vn)) else - # If `vn` is not in `argnames`, we need to make sure that the variable is defined. - if !$(DynamicPPL.inargnames)($vn, __model__) - $left = $(DynamicPPL.getconditioned_nested)( - __model__.context, $(DynamicPPL.prefix)(__model__.context, $vn) - ) + # If `left_vn` is not in `argnames`, we need to make sure that the variable is defined. + # (Note: we use the unprefixed `left_vn` here rather than `vn` which will have had + # prefixes applied!) + if !$(DynamicPPL.inargnames)($left_vn, __model__) + $left = $(DynamicPPL.getconditioned_nested)(__model__.context, $vn) end $value, __varinfo__ = $(DynamicPPL.tilde_observe!!)( @@ -495,6 +497,7 @@ function generate_tilde_assume(left, right, vn) return quote $value, __varinfo__ = $(DynamicPPL.tilde_assume!!)( __model__.context, + __model__.prefix, $(DynamicPPL.unwrap_right_vn)($(DynamicPPL.check_tilde_rhs)($right), $vn)..., __varinfo__, ) diff --git a/src/contexts.jl b/src/contexts.jl index 32a236e8e..3c6882116 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -123,6 +123,7 @@ setleafcontext(::IsLeaf, ::IsLeaf, left::AbstractContext, right::AbstractContext """ DynamicPPL.tilde_assume!!( context::AbstractContext, + prefix::Union{VarName,Nothing}, right::Distribution, vn::VarName, vi::AbstractVarInfo @@ -134,13 +135,21 @@ sampled value and updated `vi`. `vn` is the VarName on the left-hand side of the tilde statement. +`prefix` is the currently active prefix; this is `nothing` if there is no active prefix. +For example, in `a ~ to_submodel(inner_model)`, when executing `inner_model`, the active +prefix will be `@varname(a)`. + This function should return a tuple `(x, vi)`, where `x` is the sampled value (which must be in unlinked space!) and `vi` is the updated VarInfo. """ function tilde_assume!!( - context::AbstractContext, right::Distribution, vn::VarName, vi::AbstractVarInfo + context::AbstractContext, + prefix::Union{VarName,Nothing}, + right::Distribution, + vn::VarName, + vi::AbstractVarInfo, ) - return tilde_assume!!(childcontext(context), right, vn, vi) + return tilde_assume!!(childcontext(context), prefix, right, vn, vi) end """ diff --git a/src/contexts/conditionfix.jl b/src/contexts/conditionfix.jl index d3802de85..5445f64ba 100644 --- a/src/contexts/conditionfix.jl +++ b/src/contexts/conditionfix.jl @@ -83,9 +83,6 @@ hasconditioned_nested(::IsLeaf, context, vn) = hasconditioned(context, vn) function hasconditioned_nested(::IsParent, context, vn) return hasconditioned(context, vn) || hasconditioned_nested(childcontext(context), vn) end -function hasconditioned_nested(context::PrefixContext, vn) - return hasconditioned_nested(collapse_prefix_stack(context), vn) -end """ getconditioned_nested(context, vn) @@ -101,9 +98,6 @@ end function getconditioned_nested(::IsLeaf, context, vn) return error("context $(context) does not contain value for $vn") end -function getconditioned_nested(context::PrefixContext, vn) - return getconditioned_nested(collapse_prefix_stack(context), vn) -end function getconditioned_nested(::IsParent, context, vn) return if hasconditioned(context, vn) getconditioned(context, vn) @@ -172,9 +166,6 @@ function conditioned(context::ConditionContext) # precedence over decendants of `context`. return _merge(context.values, conditioned(childcontext(context))) end -function conditioned(context::PrefixContext) - return conditioned(collapse_prefix_stack(context)) -end struct FixedContext{Values,Ctx<:AbstractContext} <: AbstractContext values::Values @@ -237,9 +228,6 @@ hasfixed_nested(::IsLeaf, context, vn) = hasfixed(context, vn) function hasfixed_nested(::IsParent, context, vn) return hasfixed(context, vn) || hasfixed_nested(childcontext(context), vn) end -function hasfixed_nested(context::PrefixContext, vn) - return hasfixed_nested(collapse_prefix_stack(context), vn) -end """ getfixed_nested(context, vn) @@ -255,9 +243,6 @@ end function getfixed_nested(::IsLeaf, context, vn) return error("context $(context) does not contain value for $vn") end -function getfixed_nested(context::PrefixContext, vn) - return getfixed_nested(collapse_prefix_stack(context), vn) -end function getfixed_nested(::IsParent, context, vn) return if hasfixed(context, vn) getfixed(context, vn) @@ -351,117 +336,3 @@ function fixed(context::FixedContext) # precedence over decendants of `context`. return _merge(context.values, fixed(childcontext(context))) end -function fixed(context::PrefixContext) - return fixed(collapse_prefix_stack(context)) -end - -########################################################################### -### Interaction of PrefixContext with ConditionContext and FixedContext ### -########################################################################### - -""" - collapse_prefix_stack(context::AbstractContext) - -Apply `PrefixContext`s to any conditioned or fixed values inside them, and remove -the `PrefixContext`s from the context stack. - -!!! note - If you are reading this docstring, you might probably be interested in a more -thorough explanation of how PrefixContext and ConditionContext / FixedContext -interact with one another, especially in the context of submodels. - The DynamicPPL documentation contains [a separate page on this -topic](https://turinglang.org/DynamicPPL.jl/previews/PR892/internals/submodel_condition/) -which explains this in much more detail. - -```jldoctest -julia> using DynamicPPL: collapse_prefix_stack - -julia> c1 = PrefixContext(@varname(a), ConditionContext((x=1, ))); - -julia> collapse_prefix_stack(c1) -ConditionContext(Dict(a.x => 1), DefaultContext()) - -julia> # Here, `x` gets prefixed only with `a`, whereas `y` is prefixed with both. - c2 = PrefixContext(@varname(a), ConditionContext((x=1, ), PrefixContext(@varname(b), ConditionContext((y=2,))))); - -julia> collapsed = collapse_prefix_stack(c2); - -julia> # `collapsed` really looks something like this: - # ConditionContext(Dict{VarName{:a}, Int64}(a.b.y => 2, a.x => 1), DefaultContext()) - # To avoid fragility arising from the order of the keys in the doctest, we test - # this indirectly: - collapsed.values[@varname(a.x)], collapsed.values[@varname(a.b.y)] -(1, 2) -``` -""" -function collapse_prefix_stack(context::PrefixContext) - # Collapse the child context (thus applying any inner prefixes first) - collapsed = collapse_prefix_stack(childcontext(context)) - # Prefix any conditioned variables with the current prefix - # Note: prefix_conditioned_variables is O(N) in the depth of the context stack. - # So is this function. In the worst case scenario, this is O(N^2) in the - # depth of the context stack. - return prefix_cond_and_fixed_variables(collapsed, context.vn_prefix) -end -function collapse_prefix_stack(context::AbstractContext) - return collapse_prefix_stack(NodeTrait(collapse_prefix_stack, context), context) -end -collapse_prefix_stack(::IsLeaf, context) = context -function collapse_prefix_stack(::IsParent, context) - new_child_context = collapse_prefix_stack(childcontext(context)) - return setchildcontext(context, new_child_context) -end - -""" - prefix_cond_and_fixed_variables(context::AbstractContext, prefix::VarName) - -Prefix all the conditioned and fixed variables in a given context with a single -`prefix`. - -```jldoctest -julia> using DynamicPPL: prefix_cond_and_fixed_variables, ConditionContext - -julia> c1 = ConditionContext((a=1, )) -ConditionContext((a = 1,), DefaultContext()) - -julia> prefix_cond_and_fixed_variables(c1, @varname(y)) -ConditionContext(Dict(y.a => 1), DefaultContext()) -``` -""" -function prefix_cond_and_fixed_variables(ctx::ConditionContext, prefix::VarName) - # Replace the prefix of the conditioned variables - vn_dict = to_varname_dict(ctx.values) - prefixed_vn_dict = Dict( - AbstractPPL.prefix(vn, prefix) => value for (vn, value) in vn_dict - ) - # Prefix the child context as well - prefixed_child_ctx = prefix_cond_and_fixed_variables(childcontext(ctx), prefix) - return ConditionContext(prefixed_vn_dict, prefixed_child_ctx) -end -function prefix_cond_and_fixed_variables(ctx::FixedContext, prefix::VarName) - # Replace the prefix of the conditioned variables - vn_dict = to_varname_dict(ctx.values) - prefixed_vn_dict = Dict( - AbstractPPL.prefix(vn, prefix) => value for (vn, value) in vn_dict - ) - # Prefix the child context as well - prefixed_child_ctx = prefix_cond_and_fixed_variables(childcontext(ctx), prefix) - return FixedContext(prefixed_vn_dict, prefixed_child_ctx) -end -function prefix_cond_and_fixed_variables(c::AbstractContext, prefix::VarName) - return prefix_cond_and_fixed_variables( - NodeTrait(prefix_cond_and_fixed_variables, c), c, prefix - ) -end -function prefix_cond_and_fixed_variables( - ::IsLeaf, context::AbstractContext, prefix::VarName -) - return context -end -function prefix_cond_and_fixed_variables( - ::IsParent, context::AbstractContext, prefix::VarName -) - return setchildcontext( - context, prefix_cond_and_fixed_variables(childcontext(context), prefix) - ) -end diff --git a/src/contexts/default.jl b/src/contexts/default.jl index ec21e1a56..92e07c538 100644 --- a/src/contexts/default.jl +++ b/src/contexts/default.jl @@ -21,14 +21,22 @@ NodeTrait(::DefaultContext) = IsLeaf() """ DynamicPPL.tilde_assume!!( - ::DefaultContext, right::Distribution, vn::VarName, vi::AbstractVarInfo + ::DefaultContext, + prefix::Union{VarName,Nothing}, + right::Distribution, + vn::VarName, + vi::AbstractVarInfo ) Handle assumed variables. For `DefaultContext`, this function extracts the value associated with `vn` from `vi`, If `vi` does not contain an appropriate value then this will error. """ function tilde_assume!!( - ::DefaultContext, right::Distribution, vn::VarName, vi::AbstractVarInfo + ::DefaultContext, + ::Union{VarName,Nothing}, + right::Distribution, + vn::VarName, + vi::AbstractVarInfo, ) y = getindex_internal(vi, vn) f = from_maybe_linked_internal_transform(vi, vn, right) diff --git a/src/contexts/init.jl b/src/contexts/init.jl index 44dbc5508..c7bebeb39 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -153,7 +153,11 @@ end NodeTrait(::InitContext) = IsLeaf() function tilde_assume!!( - ctx::InitContext, dist::Distribution, vn::VarName, vi::AbstractVarInfo + ctx::InitContext, + ::Union{VarName,Nothing}, + dist::Distribution, + vn::VarName, + vi::AbstractVarInfo, ) in_varinfo = haskey(vi, vn) # `init()` always returns values in original space, i.e. possibly diff --git a/src/contexts/prefix.jl b/src/contexts/prefix.jl deleted file mode 100644 index 24615e683..000000000 --- a/src/contexts/prefix.jl +++ /dev/null @@ -1,116 +0,0 @@ -""" - PrefixContext(vn::VarName[, context::AbstractContext]) - PrefixContext(vn::Val{sym}[, context::AbstractContext]) where {sym} - -Create a context that allows you to use the wrapped `context` when running the model and -prefixes all parameters with the VarName `vn`. - -`PrefixContext(Val(:a), context)` is equivalent to `PrefixContext(@varname(a), context)`. -If `context` is not provided, it defaults to `DefaultContext()`. - -This context is useful in nested models to ensure that the names of the parameters are -unique. - -See also: [`to_submodel`](@ref) -""" -struct PrefixContext{Tvn<:VarName,C<:AbstractContext} <: AbstractContext - vn_prefix::Tvn - context::C -end -PrefixContext(vn::VarName) = PrefixContext(vn, DefaultContext()) -function PrefixContext(::Val{sym}, context::AbstractContext) where {sym} - return PrefixContext(VarName{sym}(), context) -end -PrefixContext(::Val{sym}) where {sym} = PrefixContext(VarName{sym}()) - -NodeTrait(::PrefixContext) = IsParent() -childcontext(context::PrefixContext) = context.context -function setchildcontext(ctx::PrefixContext, child::AbstractContext) - return PrefixContext(ctx.vn_prefix, child) -end - -""" - prefix(ctx::AbstractContext, vn::VarName) - -Apply the prefixes in the context `ctx` to the variable name `vn`. -""" -function prefix(ctx::PrefixContext, vn::VarName) - return AbstractPPL.prefix(prefix(childcontext(ctx), vn), ctx.vn_prefix) -end -function prefix(ctx::AbstractContext, vn::VarName) - return prefix(NodeTrait(ctx), ctx, vn) -end -prefix(::IsLeaf, ::AbstractContext, vn::VarName) = vn -function prefix(::IsParent, ctx::AbstractContext, vn::VarName) - return prefix(childcontext(ctx), vn) -end - -""" - prefix_and_strip_contexts(ctx::PrefixContext, vn::VarName) - -Same as `prefix`, but additionally returns a new context stack that has all the -PrefixContexts removed. - -NOTE: This does _not_ modify any variables in any `ConditionContext` and -`FixedContext` that may be present in the context stack. This is because this -function is only used in `tilde_assume!!`, which is lower in the tilde-pipeline -than `contextual_isassumption` and `contextual_isfixed` (the functions which -actually use the `ConditionContext` and `FixedContext` values). Thus, by this -time, any `ConditionContext`s and `FixedContext`s present have already served -their purpose. - -If you call this function, you must therefore be careful to ensure that you _do -not_ need to modify any inner `ConditionContext`s and `FixedContext`s. If you -_do_ need to modify them, then you may need to use -`prefix_cond_and_fixed_variables` instead. -""" -function prefix_and_strip_contexts(ctx::PrefixContext, vn::VarName) - child_context = childcontext(ctx) - # vn_prefixed contains the prefixes from all lower levels - vn_prefixed, child_context_without_prefixes = prefix_and_strip_contexts( - child_context, vn - ) - return AbstractPPL.prefix(vn_prefixed, ctx.vn_prefix), child_context_without_prefixes -end -function prefix_and_strip_contexts(ctx::AbstractContext, vn::VarName) - return prefix_and_strip_contexts(NodeTrait(ctx), ctx, vn) -end -prefix_and_strip_contexts(::IsLeaf, ctx::AbstractContext, vn::VarName) = (vn, ctx) -function prefix_and_strip_contexts(::IsParent, ctx::AbstractContext, vn::VarName) - vn, new_ctx = prefix_and_strip_contexts(childcontext(ctx), vn) - return vn, setchildcontext(ctx, new_ctx) -end - -function tilde_assume!!( - context::PrefixContext, right::Distribution, vn::VarName, vi::AbstractVarInfo -) - # Note that we can't use something like this here: - # new_vn = prefix(context, vn) - # return tilde_assume!!(childcontext(context), right, new_vn, vi) - # This is because `prefix` applies _all_ prefixes in a given context to a - # variable name. Thus, if we had two levels of nested prefixes e.g. - # `PrefixContext{:a}(PrefixContext{:b}(DefaultContext()))`, then the - # first call would apply the prefix `a.b._`, and the recursive call - # would apply the prefix `b._`, resulting in `b.a.b._`. - # This is why we need a special function, `prefix_and_strip_contexts`. - new_vn, new_context = prefix_and_strip_contexts(context, vn) - return tilde_assume!!(new_context, right, new_vn, vi) -end - -function tilde_observe!!( - context::PrefixContext, - right::Distribution, - left, - vn::Union{VarName,Nothing}, - vi::AbstractVarInfo, -) - # In the observe case, unlike assume, `vn` may be `nothing` if the LHS is a literal - # value. For the need for prefix_and_strip_contexts rather than just prefix, see the - # comment in `tilde_assume!!`. - new_vn, new_context = if vn !== nothing - prefix_and_strip_contexts(context, vn) - else - vn, childcontext(context) - end - return tilde_observe!!(new_context, right, left, new_vn, vi) -end diff --git a/src/contexts/transformation.jl b/src/contexts/transformation.jl index 5153f7857..106f13974 100644 --- a/src/contexts/transformation.jl +++ b/src/contexts/transformation.jl @@ -14,6 +14,7 @@ NodeTrait(::DynamicTransformationContext) = IsLeaf() function tilde_assume!!( ::DynamicTransformationContext{isinverse}, + ::Union{VarName,Nothing}, right::Distribution, vn::VarName, vi::AbstractVarInfo, diff --git a/src/model.jl b/src/model.jl index ec98b90cd..ef0dcfb33 100644 --- a/src/model.jl +++ b/src/model.jl @@ -1,9 +1,10 @@ """ - struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractContext} + struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractContext,Tprefix<:Union{VarName,Nothing}} f::F args::NamedTuple{argnames,Targs} defaults::NamedTuple{defaultnames,Tdefaults} context::Ctx=DefaultContext() + prefix::Tprefix=nothing end A `Model` struct with model evaluation function of type `F`, arguments of names `argnames` @@ -33,12 +34,21 @@ julia> Model{(:y,)}(f, (x = 1.0, y = 2.0), (x = 42,)) # with special definition Model{typeof(f),(:x, :y),(:x,),(:y,),Tuple{Float64,Float64},Tuple{Int64}}(f, (x = 1.0, y = 2.0), (x = 42,)) ``` """ -struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractContext} <: - AbstractProbabilisticProgram +struct Model{ + F, + argnames, + defaultnames, + missings, + Targs, + Tdefaults, + Ctx<:AbstractContext, + Tprefix<:Union{VarName,Nothing}, +} <: AbstractProbabilisticProgram f::F args::NamedTuple{argnames,Targs} defaults::NamedTuple{defaultnames,Tdefaults} context::Ctx + prefix::Tprefix @doc """ Model{missings}(f, args::NamedTuple, defaults::NamedTuple) @@ -51,9 +61,10 @@ struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractConte args::NamedTuple{argnames,Targs}, defaults::NamedTuple{defaultnames,Tdefaults}, context::Ctx=DefaultContext(), - ) where {missings,F,argnames,Targs,defaultnames,Tdefaults,Ctx} - return new{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx}( - f, args, defaults, context + prefix::Tprefix=nothing, + ) where {missings,F,argnames,Targs,defaultnames,Tdefaults,Ctx,Tprefix} + return new{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx,Tprefix}( + f, args, defaults, context, prefix ) end end @@ -71,6 +82,7 @@ model with different arguments. args::NamedTuple{argnames,Targs}, defaults::NamedTuple{kwargnames,Tkwargs}, context::AbstractContext=DefaultContext(), + prefix::Union{VarName,Nothing}=nothing, ) where {F,argnames,Targs,kwargnames,Tkwargs} missing_args = Tuple( name for (name, typ) in zip(argnames, Targs.types) if typ <: Missing @@ -78,11 +90,19 @@ model with different arguments. missing_kwargs = Tuple( name for (name, typ) in zip(kwargnames, Tkwargs.types) if typ <: Missing ) - return :(Model{$(missing_args..., missing_kwargs...)}(f, args, defaults, context)) + return :(Model{$(missing_args..., missing_kwargs...)}( + f, args, defaults, context, prefix + )) end -function Model(f, args::NamedTuple, context::AbstractContext=DefaultContext(); kwargs...) - return Model(f, args, NamedTuple(kwargs), context) +function Model( + f, + args::NamedTuple, + context::AbstractContext=DefaultContext(), + prefix::Union{VarName,Nothing}=nothing; + kwargs..., +) + return Model(f, args, NamedTuple(kwargs), context, prefix) end """ @@ -92,7 +112,7 @@ Return a new `Model` with the same evaluation function and other arguments, but with its underlying context set to `context`. """ function contextualize(model::Model, context::AbstractContext) - return Model(model.f, model.args, model.defaults, context) + return Model(model.f, model.args, model.defaults, context, model.prefix) end """ @@ -427,7 +447,7 @@ Return the conditioned values in `model`. ```jldoctest julia> using Distributions -julia> using DynamicPPL: conditioned, contextualize +julia> using DynamicPPL: conditioned julia> @model function demo() m ~ Normal() @@ -437,36 +457,25 @@ demo (generic function with 2 methods) julia> m = demo(); -julia> # Returns all the variables we have conditioned on + their values. - conditioned(condition(m, x=100.0, m=1.0)) -(x = 100.0, m = 1.0) - -julia> # Nested ones also work. - # (Note that `PrefixContext` also prefixes the variables of any - # ConditionContext that is _inside_ it; because of this, the type of the - # container has to be broadened to a `Dict`.) - cm = condition(contextualize(m, PrefixContext(@varname(a), ConditionContext((m=1.0,)))), x=100.0); +julia> # Condition on some values. + cm = m | (; x = 100.0, m = 1.0); -julia> Set(keys(conditioned(cm))) == Set([@varname(a.m), @varname(x)]) -true - -julia> # Since we conditioned on `a.m`, it is not treated as a random variable. - # However, `a.x` will still be a random variable. - keys(VarInfo(cm)) -1-element Vector{VarName{:a, Accessors.PropertyLens{:x}}}: - a.x +julia> # Returns all the variables we have conditioned on, and their values. + conditioned(cm) +(x = 100.0, m = 1.0) -julia> # We can also condition on `a.m` _outside_ of the PrefixContext: - cm = condition(contextualize(m, PrefixContext(@varname(a))), (@varname(a.m) => 1.0)); +julia> # If we prefix the model, the conditioned variables will also be prefixed. + pm = prefix(cm, @varname(f)); conditioned(pm) +Dict{VarName{:f}, Float64} with 2 entries: + f.x => 100.0 + f.m => 1.0 -julia> conditioned(cm) -Dict{VarName{:a, Accessors.PropertyLens{:m}}, Float64} with 1 entry: - a.m => 1.0 +julia> # If we condition _after_ the prefix, the prefix is not applied. + pm2 = prefix(m, @varname(f)); cm2 = pm2 | (; x = 100.0, m = 1.0); -julia> # Now `a.x` will be sampled. - keys(VarInfo(cm)) -1-element Vector{VarName{:a, Accessors.PropertyLens{:x}}}: - a.x +julia> # When running this model, the variables inside are not treated as conditioned! + conditioned(cm2) +(x = 100.0, m = 1.0) ``` """ conditioned(model::Model) = conditioned(model.context) @@ -770,7 +779,7 @@ Return the fixed values in `model`. ```jldoctest julia> using Distributions -julia> using DynamicPPL: fixed, contextualize +julia> using DynamicPPL: fixed julia> @model function demo() m ~ Normal() @@ -780,70 +789,29 @@ demo (generic function with 2 methods) julia> m = demo(); -julia> # Returns all the variables we have fixed on + their values. - fixed(fix(m, x=100.0, m=1.0)) -(x = 100.0, m = 1.0) - -julia> # The rest of this is the same as the `condition` example above. - cm = fix(contextualize(m, PrefixContext(@varname(a), fix(m=1.0))), x=100.0); - -julia> Set(keys(fixed(cm))) == Set([@varname(a.m), @varname(x)]) -true +julia> # Fix some values. + fm = fix(m, (; x = 100.0, m = 1.0)); -julia> keys(VarInfo(cm)) -1-element Vector{VarName{:a, Accessors.PropertyLens{:x}}}: - a.x +julia> # Returns all the variables we have fixed on, and their values. + fixed(fm) +(x = 100.0, m = 1.0) -julia> # We can also condition on `a.m` _outside_ of the PrefixContext: - cm = fix(contextualize(m, PrefixContext(@varname(a))), (@varname(a.m) => 1.0)); +julia> # If we prefix the model, the fixed variables will also be prefixed. + pm = prefix(fm, @varname(f)); fixed(pm) +Dict{VarName{:f}, Float64} with 2 entries: + f.x => 100.0 + f.m => 1.0 -julia> fixed(cm) -Dict{VarName{:a, Accessors.PropertyLens{:m}}, Float64} with 1 entry: - a.m => 1.0 +julia> # If we fix _after_ the prefix, the prefix is not applied. + pm2 = prefix(m, @varname(f)); fm2 = fix(pm2, (; x = 100.0, m = 1.0)); -julia> # Now `a.x` will be sampled. - keys(VarInfo(cm)) -1-element Vector{VarName{:a, Accessors.PropertyLens{:x}}}: - a.x +julia> # When running this model, the variables inside are not treated as fixed! + fixed(fm2) +(x = 100.0, m = 1.0) ``` """ fixed(model::Model) = fixed(model.context) -""" - prefix(model::Model, x::VarName) - prefix(model::Model, x::Val{sym}) - prefix(model::Model, x::Any) - -Return `model` but with all random variables prefixed by `x`, where `x` is either: -- a `VarName` (e.g. `@varname(a)`), -- a `Val{sym}` (e.g. `Val(:a)`), or -- for any other type, `x` is converted to a Symbol and then to a `VarName`. Note that - this will introduce runtime overheads so is not recommended unless absolutely - necessary. - -# Examples - -```jldoctest -julia> using DynamicPPL: prefix - -julia> @model demo() = x ~ Dirac(1) -demo (generic function with 2 methods) - -julia> rand(prefix(demo(), @varname(my_prefix))) -(var"my_prefix.x" = 1,) - -julia> rand(prefix(demo(), Val(:my_prefix))) -(var"my_prefix.x" = 1,) -``` -""" -prefix(model::Model, x::VarName) = contextualize(model, PrefixContext(x, model.context)) -function prefix(model::Model, x::Val{sym}) where {sym} - return contextualize(model, PrefixContext(VarName{sym}(), model.context)) -end -function prefix(model::Model, x) - return contextualize(model, PrefixContext(VarName{Symbol(x)}(), model.context)) -end - """ (model::Model)([rng, varinfo]) diff --git a/src/prefix.jl b/src/prefix.jl new file mode 100644 index 000000000..c8f258cac --- /dev/null +++ b/src/prefix.jl @@ -0,0 +1,108 @@ +""" + maybe_prefix(inner::Union{Nothing,<:VarName}, outer::Union{Nothing,<:VarName}) + +Prefix `inner` with the prefix `outer`. Both `inner` and `outer` can be either +`VarName`s or `Nothing`. + +Note that this differs from `AbstractPPL.prefix` in that it handles `nothing` values. +This can happen e.g. when prefixing a model that is not already prefixed; or when +executing submodels without automatic prefixing. +""" +maybe_prefix(inner::VarName, outer::VarName) = AbstractPPL.prefix(inner, outer) +maybe_prefix(vn::VarName, ::Nothing) = vn +maybe_prefix(::Nothing, vn::VarName) = vn +maybe_prefix(::Nothing, ::Nothing) = nothing + +""" + prefix_cond_and_fixed_variables(context::AbstractContext, prefix::VarName) + +Prefix all the conditioned and fixed variables in a given context with a single +`prefix`. + +```jldoctest +julia> using DynamicPPL: prefix_cond_and_fixed_variables, ConditionContext + +julia> c1 = ConditionContext((a=1, )) +ConditionContext((a = 1,), DefaultContext()) + +julia> prefix_cond_and_fixed_variables(c1, @varname(y)) +ConditionContext(Dict(y.a => 1), DefaultContext()) +``` +""" +function prefix_cond_and_fixed_variables(ctx::ConditionContext, prefix::VarName) + # Replace the prefix of the conditioned variables + vn_dict = to_varname_dict(ctx.values) + prefixed_vn_dict = Dict( + AbstractPPL.prefix(vn, prefix) => value for (vn, value) in vn_dict + ) + # Prefix the child context as well + prefixed_child_ctx = prefix_cond_and_fixed_variables(childcontext(ctx), prefix) + return ConditionContext(prefixed_vn_dict, prefixed_child_ctx) +end +function prefix_cond_and_fixed_variables(ctx::FixedContext, prefix::VarName) + # Replace the prefix of the conditioned variables + vn_dict = to_varname_dict(ctx.values) + prefixed_vn_dict = Dict( + AbstractPPL.prefix(vn, prefix) => value for (vn, value) in vn_dict + ) + # Prefix the child context as well + prefixed_child_ctx = prefix_cond_and_fixed_variables(childcontext(ctx), prefix) + return FixedContext(prefixed_vn_dict, prefixed_child_ctx) +end +function prefix_cond_and_fixed_variables(c::AbstractContext, prefix::VarName) + return prefix_cond_and_fixed_variables( + NodeTrait(prefix_cond_and_fixed_variables, c), c, prefix + ) +end +function prefix_cond_and_fixed_variables( + ::IsLeaf, context::AbstractContext, prefix::VarName +) + return context +end +function prefix_cond_and_fixed_variables( + ::IsParent, context::AbstractContext, prefix::VarName +) + return setchildcontext( + context, prefix_cond_and_fixed_variables(childcontext(context), prefix) + ) +end + +""" + DynamicPPL.prefix(model::Model, x::VarName) + DynamicPPL.prefix(model::Model, x::Val{sym}) + DynamicPPL.prefix(model::Model, x::Any) + +Return `model` but with all random variables prefixed by `x`, where `x` is either: +- a `VarName` (e.g. `@varname(a)`), +- a `Val{sym}` (e.g. `Val(:a)`), or +- for any other type, `x` is converted to a Symbol and then to a `VarName`. Note that + this will introduce runtime overheads so is not recommended unless absolutely + necessary. + +If `x` is `nothing`, then the model is returned unchanged. + +# Examples + +```jldoctest +julia> using DynamicPPL: prefix + +julia> @model demo() = x ~ Dirac(1) +demo (generic function with 2 methods) + +julia> rand(prefix(demo(), @varname(my_prefix))) +(var"my_prefix.x" = 1,) + +julia> rand(prefix(demo(), Val(:my_prefix))) +(var"my_prefix.x" = 1,) +``` +""" +prefix(model::Model, ::Nothing) = model +function prefix(model::Model, vn::VarName) + # Add it to the model prefix field + new_prefix = maybe_prefix(model.prefix, vn) + # And also make sure to prefix any conditioned and fixed variables stored in the model + new_context = prefix_cond_and_fixed_variables(model.context, vn) + return Model(model.f, model.args, model.defaults, new_context, new_prefix) +end +prefix(model::Model, ::Val{sym}) where {sym} = prefix(model, VarName{sym}()) +prefix(model::Model, x) = return prefix(model, VarName{Symbol(x)}()) diff --git a/src/submodel.jl b/src/submodel.jl index 145bd42c9..c369b286b 100644 --- a/src/submodel.jl +++ b/src/submodel.jl @@ -163,6 +163,7 @@ to_submodel(m::Model, auto_prefix::Bool=true) = Submodel{typeof(m),auto_prefix}( """ DynamicPPL.tilde_assume!!( context::AbstractContext, + prefix::Union{VarName,Nothing}, right::DynamicPPL.Submodel, vn::VarName, vi::AbstractVarInfo @@ -171,9 +172,13 @@ to_submodel(m::Model, auto_prefix::Bool=true) = Submodel{typeof(m),auto_prefix}( Evaluate the submodel with the given context. """ function tilde_assume!!( - context::AbstractContext, right::DynamicPPL.Submodel, vn::VarName, vi::AbstractVarInfo + context::AbstractContext, + prefix::Union{VarName,Nothing}, + right::DynamicPPL.Submodel, + vn::VarName, + vi::AbstractVarInfo, ) - return _evaluate!!(right, vi, context, vn) + return _evaluate!!(right, vi, context, prefix, vn) end # When automatic prefixing is used, the submodel itself doesn't carry the @@ -182,28 +187,31 @@ end # passed into this function. # # `parent_context` here refers to the context of the model that contains the -# submodel. +# submodel. `parent_prefix` is the prefix applied to the parent model. function _evaluate!!( submodel::Submodel{M,AutoPrefix}, vi::AbstractVarInfo, parent_context::AbstractContext, - left_vn::VarName, + parent_prefix::Union{VarName,Nothing}, + vn::VarName, ) where {M<:Model,AutoPrefix} # First, we construct the context to be used when evaluating the submodel. There # are several considerations here: - # (1) We need to apply an appropriate PrefixContext when evaluating the submodel, but - # _only_ if automatic prefixing is supposed to be applied. - submodel_context_prefixed = if AutoPrefix - PrefixContext(left_vn, submodel.model.context) + # (1) Before even touching the contexts, we need to make sure that we apply + # automatic prefixing if it was requested. (If the prefix was manually applied, then + # `prefix()` will have been called by the user, and we don't need to do it again.) + submodel_prefix = if AutoPrefix + # Note that by the time we see it here (in `tilde_assume!!`), `vn` + # has already prefixed with `parent_prefix`, so no need to re-prefix it + vn else - submodel.model.context + parent_prefix end + submodel_model = DynamicPPL.prefix(submodel.model, submodel_prefix) # (2) We need to respect the leaf-context of the parent model. This, unfortunately, # means disregarding the leaf-context of the submodel. - submodel_context = setleafcontext( - submodel_context_prefixed, leafcontext(parent_context) - ) + submodel_context = setleafcontext(submodel_model.context, leafcontext(parent_context)) # (3) We need to use the parent model's context to wrap the whole thing, so that # e.g. if the user conditions the parent model, the conditioned variables will be @@ -211,7 +219,7 @@ function _evaluate!!( eval_context = setleafcontext(parent_context, submodel_context) # (4) Finally, we need to store that context inside the submodel. - model = contextualize(submodel.model, eval_context) + model = contextualize(submodel_model, eval_context) # Once that's all set up nicely, we can just _evaluate!! the wrapped model. This # returns a tuple of submodel.model's return value and the new varinfo. diff --git a/test/contexts.jl b/test/contexts.jl index 972d833a5..d2f8484f0 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -18,9 +18,7 @@ using DynamicPPL: conditioned, fixed, hasconditioned_nested, - getconditioned_nested, - collapse_prefix_stack, - prefix_cond_and_fixed_variables + getconditioned_nested using LinearAlgebra: I using Random: Xoshiro @@ -48,16 +46,11 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() contexts = Dict( :default => DefaultContext(), :testparent => DynamicPPL.TestUtils.TestParentContext(DefaultContext()), - :prefix => PrefixContext(@varname(x)), :condition1 => ConditionContext((x=1.0,)), :condition2 => ConditionContext( (x=1.0,), DynamicPPL.TestUtils.TestParentContext(ConditionContext((y=2.0,))) ), - :condition3 => ConditionContext( - (x=1.0,), - PrefixContext(@varname(a), ConditionContext(Dict(@varname(y) => 2.0))), - ), - :condition4 => ConditionContext((x=[1.0, missing],)), + :condition3 => ConditionContext((x=[1.0, missing],)), ) @testset "$(name)" for (name, context) in contexts @@ -118,89 +111,6 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() end end - @testset "PrefixContext" begin - @testset "prefixing" begin - ctx = @inferred PrefixContext( - @varname(a), - PrefixContext( - @varname(b), - PrefixContext( - @varname(c), - PrefixContext( - @varname(d), - PrefixContext( - @varname(e), PrefixContext(@varname(f), DefaultContext()) - ), - ), - ), - ), - ) - vn = @varname(x) - vn_prefixed = @inferred DynamicPPL.prefix(ctx, vn) - @test vn_prefixed == @varname(a.b.c.d.e.f.x) - - vn = @varname(x[1]) - vn_prefixed = @inferred DynamicPPL.prefix(ctx, vn) - @test vn_prefixed == @varname(a.b.c.d.e.f.x[1]) - end - - @testset "nested within arbitrary context stacks" begin - vn = @varname(x[1]) - ctx1 = PrefixContext(@varname(a)) - @test DynamicPPL.prefix(ctx1, vn) == @varname(a.x[1]) - ctx2 = ConditionContext(Dict{VarName,Any}(), ctx1) - @test DynamicPPL.prefix(ctx2, vn) == @varname(a.x[1]) - ctx3 = PrefixContext(@varname(b), ctx2) - @test DynamicPPL.prefix(ctx3, vn) == @varname(b.a.x[1]) - ctx4 = FixedContext(Dict(), ctx3) - @test DynamicPPL.prefix(ctx4, vn) == @varname(b.a.x[1]) - end - - @testset "prefix_and_strip_contexts" begin - vn = @varname(x[1]) - ctx1 = PrefixContext(@varname(a)) - new_vn, new_ctx = DynamicPPL.prefix_and_strip_contexts(ctx1, vn) - @test new_vn == @varname(a.x[1]) - @test new_ctx == DefaultContext() - - ctx2 = FixedContext((b=4,), PrefixContext(@varname(a))) - new_vn, new_ctx = DynamicPPL.prefix_and_strip_contexts(ctx2, vn) - @test new_vn == @varname(a.x[1]) - @test new_ctx == FixedContext((b=4,)) - - ctx3 = PrefixContext(@varname(a), ConditionContext((a=1,))) - new_vn, new_ctx = DynamicPPL.prefix_and_strip_contexts(ctx3, vn) - @test new_vn == @varname(a.x[1]) - @test new_ctx == ConditionContext((a=1,)) - - ctx4 = FixedContext( - (b=4,), PrefixContext(@varname(a), ConditionContext((a=1,))) - ) - new_vn, new_ctx = DynamicPPL.prefix_and_strip_contexts(ctx4, vn) - @test new_vn == @varname(a.x[1]) - @test new_ctx == FixedContext((b=4,), ConditionContext((a=1,))) - end - - @testset "evaluation: $(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS - prefix_vn = @varname(my_prefix) - context = DynamicPPL.PrefixContext(prefix_vn, DefaultContext()) - new_model = contextualize(model, context) - # Initialize a new varinfo with the prefixed model - _, varinfo = DynamicPPL.init!!(new_model, DynamicPPL.VarInfo()) - # Extract the resulting varnames - vns_actual = Set(keys(varinfo)) - - # Extract the ground truth varnames - vns_expected = Set([ - AbstractPPL.prefix(vn, prefix_vn) for - vn in DynamicPPL.TestUtils.varnames(model) - ]) - - # Check that all variables are prefixed correctly. - @test vns_actual == vns_expected - end - end - @testset "ConditionContext" begin @testset "Nesting" begin @testset "NamedTuple" begin @@ -316,105 +226,6 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() end end - @testset "PrefixContext + Condition/FixedContext interactions" begin - @testset "prefix_cond_and_fixed_variables" begin - c1 = ConditionContext((c=1, d=2)) - c1_prefixed = prefix_cond_and_fixed_variables(c1, @varname(a)) - @test c1_prefixed isa ConditionContext - @test childcontext(c1_prefixed) isa DefaultContext - @test c1_prefixed.values[@varname(a.c)] == 1 - @test c1_prefixed.values[@varname(a.d)] == 2 - - c2 = FixedContext((f=1, g=2)) - c2_prefixed = prefix_cond_and_fixed_variables(c2, @varname(a)) - @test c2_prefixed isa FixedContext - @test childcontext(c2_prefixed) isa DefaultContext - @test c2_prefixed.values[@varname(a.f)] == 1 - @test c2_prefixed.values[@varname(a.g)] == 2 - - c3 = ConditionContext((c=1, d=2), FixedContext((f=1, g=2))) - c3_prefixed = prefix_cond_and_fixed_variables(c3, @varname(a)) - c3_prefixed_child = childcontext(c3_prefixed) - @test c3_prefixed isa ConditionContext - @test c3_prefixed.values[@varname(a.c)] == 1 - @test c3_prefixed.values[@varname(a.d)] == 2 - @test c3_prefixed_child isa FixedContext - @test c3_prefixed_child.values[@varname(a.f)] == 1 - @test c3_prefixed_child.values[@varname(a.g)] == 2 - @test childcontext(c3_prefixed_child) isa DefaultContext - end - - @testset "collapse_prefix_stack" begin - # Utility function to make sure that there are no PrefixContexts in - # the context stack. - function has_no_prefixcontexts(ctx::AbstractContext) - return !(ctx isa PrefixContext) && ( - NodeTrait(ctx) isa IsLeaf || has_no_prefixcontexts(childcontext(ctx)) - ) - end - - # Prefix -> Condition - c1 = PrefixContext(@varname(a), ConditionContext((c=1, d=2))) - c1 = collapse_prefix_stack(c1) - @test has_no_prefixcontexts(c1) - c1_vals = conditioned(c1) - @test length(c1_vals) == 2 - @test getvalue(c1_vals, @varname(a.c)) == 1 - @test getvalue(c1_vals, @varname(a.d)) == 2 - - # Condition -> Prefix - c2 = ConditionContext((c=1, d=2), PrefixContext(@varname(a))) - c2 = collapse_prefix_stack(c2) - @test has_no_prefixcontexts(c2) - c2_vals = conditioned(c2) - @test length(c2_vals) == 2 - @test getvalue(c2_vals, @varname(c)) == 1 - @test getvalue(c2_vals, @varname(d)) == 2 - - # Prefix -> Fixed - c3 = PrefixContext(@varname(a), FixedContext((f=1, g=2))) - c3 = collapse_prefix_stack(c3) - c3_vals = fixed(c3) - @test length(c3_vals) == 2 - @test length(c3_vals) == 2 - @test getvalue(c3_vals, @varname(a.f)) == 1 - @test getvalue(c3_vals, @varname(a.g)) == 2 - - # Fixed -> Prefix - c4 = FixedContext((f=1, g=2), PrefixContext(@varname(a))) - c4 = collapse_prefix_stack(c4) - @test has_no_prefixcontexts(c4) - c4_vals = fixed(c4) - @test length(c4_vals) == 2 - @test getvalue(c4_vals, @varname(f)) == 1 - @test getvalue(c4_vals, @varname(g)) == 2 - - # Prefix -> Condition -> Prefix -> Condition - c5 = PrefixContext( - @varname(a), - ConditionContext( - (c=1,), PrefixContext(@varname(b), ConditionContext((d=2,))) - ), - ) - c5 = collapse_prefix_stack(c5) - @test has_no_prefixcontexts(c5) - c5_vals = conditioned(c5) - @test length(c5_vals) == 2 - @test getvalue(c5_vals, @varname(a.c)) == 1 - @test getvalue(c5_vals, @varname(a.b.d)) == 2 - - # Prefix -> Condition -> Prefix -> Fixed - c6 = PrefixContext( - @varname(a), - ConditionContext((c=1,), PrefixContext(@varname(b), FixedContext((d=2,)))), - ) - c6 = collapse_prefix_stack(c6) - @test has_no_prefixcontexts(c6) - @test conditioned(c6) == Dict(@varname(a.c) => 1) - @test fixed(c6) == Dict(@varname(a.b.d) => 2) - end - end - @testset "InitContext" begin empty_varinfos = [ ("untyped+metadata", VarInfo()), diff --git a/test/prefix.jl b/test/prefix.jl new file mode 100644 index 000000000..57065689b --- /dev/null +++ b/test/prefix.jl @@ -0,0 +1,121 @@ +""" +Note that `test/submodel.jl` also contains a number of tests which make use of +prefixing functionality (more like end-to-end tests). This file contains what +are essentially unit tests for prefixing functions. +""" +module DPPLPrefixTests + +using DynamicPPL +# not exported +using DynamicPPL: FixedContext, prefix_cond_and_fixed_variables, childcontext +using Distributions +using Test + +@testset "prefix.jl" begin + @testset "prefix_cond_and_fixed_variables" begin + @testset "ConditionContext" begin + c1 = ConditionContext((c=1, d=2)) + c1_prefixed = prefix_cond_and_fixed_variables(c1, @varname(a)) + @test c1_prefixed isa ConditionContext + @test childcontext(c1_prefixed) isa DefaultContext + @test length(c1_prefixed.values) == 2 + @test c1_prefixed.values[@varname(a.c)] == 1 + @test c1_prefixed.values[@varname(a.d)] == 2 + end + + @testset "FixedContext" begin + c2 = FixedContext((f=1, g=2)) + c2_prefixed = prefix_cond_and_fixed_variables(c2, @varname(a)) + @test c2_prefixed isa FixedContext + @test childcontext(c2_prefixed) isa DefaultContext + @test length(c2_prefixed.values) == 2 + @test c2_prefixed.values[@varname(a.f)] == 1 + @test c2_prefixed.values[@varname(a.g)] == 2 + end + + @testset "Nested ConditionContext and FixedContext" begin + c3 = ConditionContext((c=1, d=2), FixedContext((f=1, g=2))) + c3_prefixed = prefix_cond_and_fixed_variables(c3, @varname(a)) + c3_prefixed_child = childcontext(c3_prefixed) + @test c3_prefixed isa ConditionContext + @test length(c3_prefixed.values) == 2 + @test c3_prefixed.values[@varname(a.c)] == 1 + @test c3_prefixed.values[@varname(a.d)] == 2 + @test c3_prefixed_child isa FixedContext + @test length(c3_prefixed_child.values) == 2 + @test c3_prefixed_child.values[@varname(a.f)] == 1 + @test c3_prefixed_child.values[@varname(a.g)] == 2 + @test childcontext(c3_prefixed_child) isa DefaultContext + end + end + + @testset "DynamicPPL.prefix(::Model, x)" begin + @model function demo() + x ~ Normal() + return y ~ Normal() + end + model = demo() + + @testset "No conditioning / fixing" begin + pmodel = DynamicPPL.prefix(model, @varname(a)) + @test pmodel.prefix == @varname(a) + vi = VarInfo(pmodel) + @test Set(keys(vi)) == Set([@varname(a.x), @varname(a.y)]) + end + + @testset "Prefixing a conditioned model" begin + cmodel = model | (; x=1.0) + # Sanity check. + vi = VarInfo(cmodel) + @test Set(keys(vi)) == Set([@varname(y)]) + # Now prefix. + pcmodel = DynamicPPL.prefix(cmodel, @varname(a)) + @test pcmodel.prefix == @varname(a) + # Because the model was conditioned on `x` _prior_ to prefixing, + # the resulting `a.x` variable should also be conditioned. In + # other words, which variables are treated as conditioned should be + # invariant to prefixing. + vi = VarInfo(pcmodel) + @test Set(keys(vi)) == Set([@varname(a.y)]) + end + + @testset "Prefixing a fixed model" begin + # Same as above but for FixedContext rather than Condition. + fmodel = fix(model, (; y=1.0)) + # Sanity check. + vi = VarInfo(fmodel) + @test Set(keys(vi)) == Set([@varname(x)]) + # Now prefix. + pfmodel = DynamicPPL.prefix(fmodel, @varname(a)) + @test pfmodel.prefix == @varname(a) + # Because the model was conditioned on `x` _prior_ to prefixing, + # the resulting `a.x` variable should also be conditioned. In + # other words, which variables are treated as conditioned should be + # invariant to prefixing. + vi = VarInfo(pfmodel) + @test Set(keys(vi)) == Set([@varname(a.x)]) + end + + @testset "Conditioning a prefixed model" begin + # If the prefixing happens first, then we want to make sure that the + # user is forced to apply conditioning WITH the prefix. + pmodel = DynamicPPL.prefix(model, @varname(a)) + + # If this doesn't happen... + cpmodel_wrong = pmodel | (; x=1.0) + @test cpmodel_wrong.prefix == @varname(a) + vi = VarInfo(cpmodel_wrong) + # Then `a.x` will be `assume`d + @test Set(keys(vi)) == Set([@varname(a.x), @varname(a.y)]) + + # If it does... + cpmodel_right = pmodel | (@varname(a.x) => 1.0) + @test cpmodel_right.prefix == @varname(a) + vi = VarInfo(cpmodel_right) + # Then `a.x` will be `observe`d + @test Set(keys(vi)) == Set([@varname(a.y)]) + end + end +end + +end diff --git a/test/runtests.jl b/test/runtests.jl index b6a3f7bf6..c80969b53 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -61,6 +61,7 @@ include("test_util.jl") include("varinfo.jl") include("simple_varinfo.jl") include("model.jl") + include("prefix.jl") include("distribution_wrappers.jl") include("logdensityfunction.jl") include("linking.jl") diff --git a/test/submodels.jl b/test/submodels.jl index 986aea1d0..6ccc6cede 100644 --- a/test/submodels.jl +++ b/test/submodels.jl @@ -140,43 +140,51 @@ end x ~ Normal() return y ~ Normal() end + # `g` and `gmanual` are the same model; just that one has automatic prefixing + # and the other manual. @model function g() + return b ~ to_submodel(f()) + end + @model function gmanual() return _unused ~ to_submodel(prefix(f(), :b), false) end @model function h() return a ~ to_submodel(g()) end - # No conditioning - vi = VarInfo(h()) - @test Set(keys(vi)) == Set([@varname(a.b.x), @varname(a.b.y)]) - @test getlogjoint(vi) == - logpdf(Normal(), vi[@varname(a.b.x)]) + - logpdf(Normal(), vi[@varname(a.b.y)]) - - # Conditioning/fixing at the top level - op_h = op(h(), (@varname(a.b.x) => x_val)) - - # Conditioning/fixing at the second level - op_g = op(g(), (@varname(b.x) => x_val)) - @model function h2() - return a ~ to_submodel(op_g) - end - - # Conditioning/fixing at the very bottom - op_f = op(f(), (@varname(x) => x_val)) - @model function g2() - return _unused ~ to_submodel(prefix(op_f, :b), false) - end - @model function h3() - return a ~ to_submodel(g2()) - end - - models = [("top", op_h), ("middle", h2()), ("bottom", h3())] - @testset "$name" for (name, model) in models - vi = VarInfo(model) - @test Set(keys(vi)) == Set([@varname(a.b.y)]) - @test getlogjoint(vi) == x_logp + logpdf(Normal(), vi[@varname(a.b.y)]) + @testset "$name prefix" for (name, g_model) in + [("auto", g()), ("manual", gmanual())] + # No conditioning + vi = VarInfo(h()) + @test Set(keys(vi)) == Set([@varname(a.b.x), @varname(a.b.y)]) + @test getlogjoint(vi) == + logpdf(Normal(), vi[@varname(a.b.x)]) + + logpdf(Normal(), vi[@varname(a.b.y)]) + + # Conditioning/fixing at the top level + op_h = op(h(), (@varname(a.b.x) => x_val)) + + # Conditioning/fixing at the second level + op_g = op(g_model, (@varname(b.x) => x_val)) + @model function h2() + return a ~ to_submodel(op_g) + end + + # Conditioning/fixing at the very bottom + op_f = op(f(), (@varname(x) => x_val)) + @model function g2() + return _unused ~ to_submodel(prefix(op_f, :b), false) + end + @model function h3() + return a ~ to_submodel(g2()) + end + + models = [("top", op_h), ("middle", h2()), ("bottom", h3())] + @testset "$name" for (name, model) in models + vi = VarInfo(model) + @test Set(keys(vi)) == Set([@varname(a.b.y)]) + @test getlogjoint(vi) == x_logp + logpdf(Normal(), vi[@varname(a.b.y)]) + end end end end