Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@

Removed the method `returned(::Model, values, keys)`; please use `returned(::Model, ::AbstractDict{<:VarName})` instead.

### Breaking changes

`PrefixContext` is removed, etc., etc.... (I'll write this if I actually have to)

## 0.38.4

Improve performance of VarNamedVector. It should now be very nearly on par with Metadata for all models we've benchmarked on.
Expand Down
1 change: 0 additions & 1 deletion docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,6 @@ Contexts are subtypes of `AbstractPPL.AbstractContext`.

```@docs
DefaultContext
PrefixContext
ConditionContext
InitContext
```
Expand Down
5 changes: 2 additions & 3 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ export AbstractVarInfo,
# Contexts
contextualize,
DefaultContext,
PrefixContext,
ConditionContext,
# Tilde pipeline
tilde_assume!!,
Expand Down Expand Up @@ -174,9 +173,9 @@ include("contexts.jl")
include("contexts/default.jl")
include("contexts/init.jl")
include("contexts/transformation.jl")
include("contexts/prefix.jl")
include("contexts/conditionfix.jl") # Must come after contexts/prefix.jl
include("contexts/conditionfix.jl")
include("model.jl")
include("prefix.jl")
include("varname.jl")
include("distribution_wrappers.jl")
include("submodel.jl")
Expand Down
53 changes: 28 additions & 25 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ function make_varname_expression(expr)
end

"""
isassumption(expr[, vn])
isassumption(expr[, left_vn])

Return an expression that can be evaluated to check if `expr` is an assumption in the
model.
Expand All @@ -55,16 +55,19 @@ Let `expr` be `:(x[1])`. It is an assumption in the following cases:

When `expr` is not an expression or symbol (i.e., a literal), this expands to `false`.

If `vn` is specified, it will be assumed to refer to a expression which
If `left_vn` is specified, it will be assumed to refer to a expression which
evaluates to a `VarName`, and this will be used in the subsequent checks.
If `vn` is not specified, `AbstractPPL.varname(expr, need_concretize(expr))` will be
If `left_vn` is not specified, `AbstractPPL.varname(expr, need_concretize(expr))` will be
used in its place.
"""
function isassumption(expr::Union{Expr,Symbol}, vn=make_varname_expression(expr))
function isassumption(expr::Union{Expr,Symbol}, left_vn=make_varname_expression(expr))
@gensym vn
return quote
if $(DynamicPPL.contextual_isassumption)(
__model__.context, $(DynamicPPL.prefix)(__model__.context, $vn)
)
# TODO(penelopeysm): This re-prefixing seems a bit wasteful. I'd really like
# the whole `isassumption` thing to be simplified, though, so I'll
# leave it till later.
$vn = $(DynamicPPL.maybe_prefix)($left_vn, __model__.prefix)
if $(DynamicPPL.contextual_isassumption)(__model__.context, $vn)
# Considered an assumption by `__model__.context` which means either:
# 1. We hit the default implementation, e.g. using `DefaultContext`,
# which in turn means that we haven't considered if it's one of
Expand All @@ -78,8 +81,8 @@ function isassumption(expr::Union{Expr,Symbol}, vn=make_varname_expression(expr)
# TODO: Support by adding context to model, and use `model.args`
# as the default conditioning. Then we no longer need to check `inargnames`
# since it will all be handled by `contextual_isassumption`.
if !($(DynamicPPL.inargnames)($vn, __model__)) ||
$(DynamicPPL.inmissings)($vn, __model__)
if !($(DynamicPPL.inargnames)($left_vn, __model__)) ||
$(DynamicPPL.inmissings)($left_vn, __model__)
true
else
$(maybe_view(expr)) === missing
Expand All @@ -99,7 +102,7 @@ isassumption(expr) = :(false)

Return `true` if `vn` is considered an assumption by `context`.
"""
function contextual_isassumption(context::AbstractContext, vn)
function contextual_isassumption(context::AbstractContext, vn::VarName)
if hasconditioned_nested(context, vn)
val = getconditioned_nested(context, vn)
# TODO: Do we even need the `>: Missing`, i.e. does it even help the compiler?
Expand All @@ -115,9 +118,7 @@ end

isfixed(expr, vn) = false
function isfixed(::Union{Symbol,Expr}, vn)
return :($(DynamicPPL.contextual_isfixed)(
__model__.context, $(DynamicPPL.prefix)(__model__.context, $vn)
))
return :($(DynamicPPL.contextual_isfixed)(__model__.context, $vn))
end

"""
Expand Down Expand Up @@ -413,7 +414,9 @@ function generate_assign(left, right)
return quote
$right_val = $right
if $(DynamicPPL.is_extracting_values)(__varinfo__)
$vn = $(DynamicPPL.prefix)(__model__.context, $(make_varname_expression(left)))
$vn = $(DynamicPPL.maybe_prefix)(
$(make_varname_expression(left)), __model__.prefix
)
__varinfo__ = $(map_accumulator!!)(
$acc -> push!($acc, $vn, $right_val), __varinfo__, Val(:ValuesAsInModel)
)
Expand Down Expand Up @@ -448,24 +451,23 @@ function generate_tilde(left, right)

# Otherwise it is determined by the model or its value,
# if the LHS represents an observation
@gensym vn isassumption value dist
@gensym left_vn vn isassumption value dist

return quote
$dist = $right
$vn = $(DynamicPPL.resolve_varnames)($(make_varname_expression(left)), $dist)
$isassumption = $(DynamicPPL.isassumption(left, vn))
$left_vn = $(DynamicPPL.resolve_varnames)($(make_varname_expression(left)), $dist)
$vn = $(DynamicPPL.maybe_prefix)($left_vn, __model__.prefix)
$isassumption = $(DynamicPPL.isassumption(left, left_vn))
if $(DynamicPPL.isfixed(left, vn))
$left = $(DynamicPPL.getfixed_nested)(
__model__.context, $(DynamicPPL.prefix)(__model__.context, $vn)
)
$left = $(DynamicPPL.getfixed_nested)(__model__.context, $vn)
elseif $isassumption
$(generate_tilde_assume(left, dist, vn))
else
# If `vn` is not in `argnames`, we need to make sure that the variable is defined.
if !$(DynamicPPL.inargnames)($vn, __model__)
$left = $(DynamicPPL.getconditioned_nested)(
__model__.context, $(DynamicPPL.prefix)(__model__.context, $vn)
)
# If `left_vn` is not in `argnames`, we need to make sure that the variable is defined.
# (Note: we use the unprefixed `left_vn` here rather than `vn` which will have had
# prefixes applied!)
if !$(DynamicPPL.inargnames)($left_vn, __model__)
$left = $(DynamicPPL.getconditioned_nested)(__model__.context, $vn)
end

$value, __varinfo__ = $(DynamicPPL.tilde_observe!!)(
Expand Down Expand Up @@ -495,6 +497,7 @@ function generate_tilde_assume(left, right, vn)
return quote
$value, __varinfo__ = $(DynamicPPL.tilde_assume!!)(
__model__.context,
__model__.prefix,
$(DynamicPPL.unwrap_right_vn)($(DynamicPPL.check_tilde_rhs)($right), $vn)...,
__varinfo__,
)
Expand Down
13 changes: 11 additions & 2 deletions src/contexts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ setleafcontext(::IsLeaf, ::IsLeaf, left::AbstractContext, right::AbstractContext
"""
DynamicPPL.tilde_assume!!(
context::AbstractContext,
prefix::Union{VarName,Nothing},
right::Distribution,
vn::VarName,
vi::AbstractVarInfo
Expand All @@ -134,13 +135,21 @@ sampled value and updated `vi`.

`vn` is the VarName on the left-hand side of the tilde statement.

`prefix` is the currently active prefix; this is `nothing` if there is no active prefix.
For example, in `a ~ to_submodel(inner_model)`, when executing `inner_model`, the active
prefix will be `@varname(a)`.

This function should return a tuple `(x, vi)`, where `x` is the sampled value (which
must be in unlinked space!) and `vi` is the updated VarInfo.
"""
function tilde_assume!!(
context::AbstractContext, right::Distribution, vn::VarName, vi::AbstractVarInfo
context::AbstractContext,
prefix::Union{VarName,Nothing},
right::Distribution,
vn::VarName,
vi::AbstractVarInfo,
)
return tilde_assume!!(childcontext(context), right, vn, vi)
return tilde_assume!!(childcontext(context), prefix, right, vn, vi)
end

"""
Expand Down
129 changes: 0 additions & 129 deletions src/contexts/conditionfix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,6 @@ hasconditioned_nested(::IsLeaf, context, vn) = hasconditioned(context, vn)
function hasconditioned_nested(::IsParent, context, vn)
return hasconditioned(context, vn) || hasconditioned_nested(childcontext(context), vn)
end
function hasconditioned_nested(context::PrefixContext, vn)
return hasconditioned_nested(collapse_prefix_stack(context), vn)
end

"""
getconditioned_nested(context, vn)
Expand All @@ -101,9 +98,6 @@ end
function getconditioned_nested(::IsLeaf, context, vn)
return error("context $(context) does not contain value for $vn")
end
function getconditioned_nested(context::PrefixContext, vn)
return getconditioned_nested(collapse_prefix_stack(context), vn)
end
function getconditioned_nested(::IsParent, context, vn)
return if hasconditioned(context, vn)
getconditioned(context, vn)
Expand Down Expand Up @@ -172,9 +166,6 @@ function conditioned(context::ConditionContext)
# precedence over decendants of `context`.
return _merge(context.values, conditioned(childcontext(context)))
end
function conditioned(context::PrefixContext)
return conditioned(collapse_prefix_stack(context))
end

struct FixedContext{Values,Ctx<:AbstractContext} <: AbstractContext
values::Values
Expand Down Expand Up @@ -237,9 +228,6 @@ hasfixed_nested(::IsLeaf, context, vn) = hasfixed(context, vn)
function hasfixed_nested(::IsParent, context, vn)
return hasfixed(context, vn) || hasfixed_nested(childcontext(context), vn)
end
function hasfixed_nested(context::PrefixContext, vn)
return hasfixed_nested(collapse_prefix_stack(context), vn)
end

"""
getfixed_nested(context, vn)
Expand All @@ -255,9 +243,6 @@ end
function getfixed_nested(::IsLeaf, context, vn)
return error("context $(context) does not contain value for $vn")
end
function getfixed_nested(context::PrefixContext, vn)
return getfixed_nested(collapse_prefix_stack(context), vn)
end
function getfixed_nested(::IsParent, context, vn)
return if hasfixed(context, vn)
getfixed(context, vn)
Expand Down Expand Up @@ -351,117 +336,3 @@ function fixed(context::FixedContext)
# precedence over decendants of `context`.
return _merge(context.values, fixed(childcontext(context)))
end
function fixed(context::PrefixContext)
return fixed(collapse_prefix_stack(context))
end

###########################################################################
### Interaction of PrefixContext with ConditionContext and FixedContext ###
###########################################################################

"""
collapse_prefix_stack(context::AbstractContext)

Apply `PrefixContext`s to any conditioned or fixed values inside them, and remove
the `PrefixContext`s from the context stack.

!!! note
If you are reading this docstring, you might probably be interested in a more
thorough explanation of how PrefixContext and ConditionContext / FixedContext
interact with one another, especially in the context of submodels.
The DynamicPPL documentation contains [a separate page on this
topic](https://turinglang.org/DynamicPPL.jl/previews/PR892/internals/submodel_condition/)
which explains this in much more detail.

```jldoctest
julia> using DynamicPPL: collapse_prefix_stack

julia> c1 = PrefixContext(@varname(a), ConditionContext((x=1, )));

julia> collapse_prefix_stack(c1)
ConditionContext(Dict(a.x => 1), DefaultContext())

julia> # Here, `x` gets prefixed only with `a`, whereas `y` is prefixed with both.
c2 = PrefixContext(@varname(a), ConditionContext((x=1, ), PrefixContext(@varname(b), ConditionContext((y=2,)))));

julia> collapsed = collapse_prefix_stack(c2);

julia> # `collapsed` really looks something like this:
# ConditionContext(Dict{VarName{:a}, Int64}(a.b.y => 2, a.x => 1), DefaultContext())
# To avoid fragility arising from the order of the keys in the doctest, we test
# this indirectly:
collapsed.values[@varname(a.x)], collapsed.values[@varname(a.b.y)]
(1, 2)
```
"""
function collapse_prefix_stack(context::PrefixContext)
# Collapse the child context (thus applying any inner prefixes first)
collapsed = collapse_prefix_stack(childcontext(context))
# Prefix any conditioned variables with the current prefix
# Note: prefix_conditioned_variables is O(N) in the depth of the context stack.
# So is this function. In the worst case scenario, this is O(N^2) in the
# depth of the context stack.
return prefix_cond_and_fixed_variables(collapsed, context.vn_prefix)
end
function collapse_prefix_stack(context::AbstractContext)
return collapse_prefix_stack(NodeTrait(collapse_prefix_stack, context), context)
end
collapse_prefix_stack(::IsLeaf, context) = context
function collapse_prefix_stack(::IsParent, context)
new_child_context = collapse_prefix_stack(childcontext(context))
return setchildcontext(context, new_child_context)
end

"""
prefix_cond_and_fixed_variables(context::AbstractContext, prefix::VarName)

Prefix all the conditioned and fixed variables in a given context with a single
`prefix`.

```jldoctest
julia> using DynamicPPL: prefix_cond_and_fixed_variables, ConditionContext

julia> c1 = ConditionContext((a=1, ))
ConditionContext((a = 1,), DefaultContext())

julia> prefix_cond_and_fixed_variables(c1, @varname(y))
ConditionContext(Dict(y.a => 1), DefaultContext())
```
"""
function prefix_cond_and_fixed_variables(ctx::ConditionContext, prefix::VarName)
# Replace the prefix of the conditioned variables
vn_dict = to_varname_dict(ctx.values)
prefixed_vn_dict = Dict(
AbstractPPL.prefix(vn, prefix) => value for (vn, value) in vn_dict
)
# Prefix the child context as well
prefixed_child_ctx = prefix_cond_and_fixed_variables(childcontext(ctx), prefix)
return ConditionContext(prefixed_vn_dict, prefixed_child_ctx)
end
function prefix_cond_and_fixed_variables(ctx::FixedContext, prefix::VarName)
# Replace the prefix of the conditioned variables
vn_dict = to_varname_dict(ctx.values)
prefixed_vn_dict = Dict(
AbstractPPL.prefix(vn, prefix) => value for (vn, value) in vn_dict
)
# Prefix the child context as well
prefixed_child_ctx = prefix_cond_and_fixed_variables(childcontext(ctx), prefix)
return FixedContext(prefixed_vn_dict, prefixed_child_ctx)
end
function prefix_cond_and_fixed_variables(c::AbstractContext, prefix::VarName)
return prefix_cond_and_fixed_variables(
NodeTrait(prefix_cond_and_fixed_variables, c), c, prefix
)
end
function prefix_cond_and_fixed_variables(
::IsLeaf, context::AbstractContext, prefix::VarName
)
return context
end
function prefix_cond_and_fixed_variables(
::IsParent, context::AbstractContext, prefix::VarName
)
return setchildcontext(
context, prefix_cond_and_fixed_variables(childcontext(context), prefix)
)
end
12 changes: 10 additions & 2 deletions src/contexts/default.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,22 @@ NodeTrait(::DefaultContext) = IsLeaf()

"""
DynamicPPL.tilde_assume!!(
::DefaultContext, right::Distribution, vn::VarName, vi::AbstractVarInfo
::DefaultContext,
prefix::Union{VarName,Nothing},
right::Distribution,
vn::VarName,
vi::AbstractVarInfo
)

Handle assumed variables. For `DefaultContext`, this function extracts the value associated
with `vn` from `vi`, If `vi` does not contain an appropriate value then this will error.
"""
function tilde_assume!!(
::DefaultContext, right::Distribution, vn::VarName, vi::AbstractVarInfo
::DefaultContext,
::Union{VarName,Nothing},
right::Distribution,
vn::VarName,
vi::AbstractVarInfo,
)
y = getindex_internal(vi, vn)
f = from_maybe_linked_internal_transform(vi, vn, right)
Expand Down
6 changes: 5 additions & 1 deletion src/contexts/init.jl
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,11 @@ end
NodeTrait(::InitContext) = IsLeaf()

function tilde_assume!!(
ctx::InitContext, dist::Distribution, vn::VarName, vi::AbstractVarInfo
ctx::InitContext,
::Union{VarName,Nothing},
dist::Distribution,
vn::VarName,
vi::AbstractVarInfo,
)
in_varinfo = haskey(vi, vn)
# `init()` always returns values in original space, i.e. possibly
Expand Down
Loading