From 45451f77408bac8721fa16d5bbcdc796ac2ec814 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 25 Nov 2024 10:22:19 +0100 Subject: [PATCH] removed redundant `SampleableModelWrapper` in favour of `ReturnedModelWrapper` + introduced `rand_like!!` to hide explicit calls to `_evaluate!!` --- src/compiler.jl | 194 ++++----------------------------- src/context_implementations.jl | 20 +++- src/model.jl | 194 +++++++++++++++++++++++++++++++++ 3 files changed, 231 insertions(+), 177 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index f19cd5c75..97b6aa308 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -1,152 +1,5 @@ const INTERNALNAMES = (:__model__, :__context__, :__varinfo__) -struct SampleableModelWrapper{M} - model::M -end - -""" - to_sampleable(model::Model) - -Return a wrapper around `model` which indicates that this model can only be sampled from. - -This is mainly meant to be used on the right-hand side of a `~` operator to indicate that -the model can be sampled from but not necessarily evaluated for its log density. - -!!! warning - Note that other operations that one typically associate with expressions of the form `left ~ right` - such as [`condition`](@ref) or [`fix`](@ref), will also not work with `to_sampleable`. - -!!! warning - It's generally recommended to use [`prefix(::Model, input)`](@ref) when working with submodels - to ensure that the variables in `model` are unique and do not clash with other variables in the - parent model or in other submodels. - -# Examples - -## Simple example -```jldoctest submodel-to-sampleable; setup=:(using Distributions) -julia> @model function demo1(x) - x ~ Normal() - return 1 + abs(x) - end; - -julia> @model function demo2(x, y) - a ~ to_sampleable(demo1(x)) - return y ~ Uniform(0, a) - end; -``` - -When we sample from the model `demo2(missing, 0.4)` random variable `x` will be sampled: -```jldoctest submodel-to-sampleable -julia> vi = VarInfo(demo2(missing, 0.4)); - -julia> @varname(x) in keys(vi) -true -``` - -Variable `a` is not tracked since it can be computed from the random variable `x` that was -tracked when running `demo1`: -```jldoctest submodel-to-sampleable -julia> @varname(a) in keys(vi) -false -``` - -We can check that the log joint probability of the model accumulated in `vi` is correct: - -```jldoctest submodel-to-sampleable -julia> x = vi[@varname(x)]; - -julia> getlogp(vi) ≈ logpdf(Normal(), x) + logpdf(Uniform(0, 1 + abs(x)), 0.4) -true -``` - -## With prefixing -```jldoctest submodel-to-sampleable-prefix; setup=:(using Distributions) -julia> @model function demo1(x) - x ~ Normal() - return 1 + abs(x) - end; - -julia> @model function demo2(x, y, z) - a ~ to_sampleable(prefix(demo1(x), :sub1)) - b ~ to_sampleable(prefix(demo1(y), :sub2)) - return z ~ Uniform(-a, b) - end; -``` - -When we sample from the model `demo2(missing, missing, 0.4)` random variables `sub1.x` and -`sub2.x` will be sampled: -```jldoctest submodel-to-sampleable-prefix -julia> vi = VarInfo(demo2(missing, missing, 0.4)); - -julia> @varname(var"sub1.x") in keys(vi) -true - -julia> @varname(var"sub2.x") in keys(vi) -true -``` - -Variables `a` and `b` are not tracked since they can be computed from the random variables `sub1.x` and -`sub2.x` that were tracked when running `demo1`: -```jldoctest submodel-to-sampleable-prefix -julia> @varname(a) in keys(vi) -false - -julia> @varname(b) in keys(vi) -false -``` - -We can check that the log joint probability of the model accumulated in `vi` is correct: - -```jldoctest submodel-to-sampleable-prefix -julia> sub1_x = vi[@varname(var"sub1.x")]; - -julia> sub2_x = vi[@varname(var"sub2.x")]; - -julia> logprior = logpdf(Normal(), sub1_x) + logpdf(Normal(), sub2_x); - -julia> loglikelihood = logpdf(Uniform(-1 - abs(sub1_x), 1 + abs(sub2_x)), 0.4); - -julia> getlogp(vi) ≈ logprior + loglikelihood -true -``` - -## Different ways of setting the prefix -```jldoctest submodel-to-sampleable-prefix-alts; setup=:(using DynamicPPL, Distributions) -julia> @model inner() = x ~ Normal() -inner (generic function with 2 methods) - -julia> # When `prefix` is unspecified, no prefix is used. - @model submodel_noprefix() = a ~ to_sampleable(inner()) -submodel_noprefix (generic function with 2 methods) - -julia> @varname(x) in keys(VarInfo(submodel_noprefix())) -true - -julia> # Using a static string. - @model submodel_prefix_string() = a ~ to_sampleable(prefix(inner(), "my prefix")) -submodel_prefix_string (generic function with 2 methods) - -julia> @varname(var"my prefix.x") in keys(VarInfo(submodel_prefix_string())) -true - -julia> # Using string interpolation. - @model submodel_prefix_interpolation() = a = to_sampleable(prefix(inner(), "\$(nameof(inner()))")) -submodel_prefix_interpolation (generic function with 2 methods) - -julia> @varname(var"inner.x") in keys(VarInfo(submodel_prefix_interpolation())) -true - -julia> # Or using some arbitrary expression. - @model submodel_prefix_expr() = a ~ to_sampleable(prefix(inner(), 1 + 2)) -submodel_prefix_expr (generic function with 2 methods) - -julia> @varname(var"3.x") in keys(VarInfo(submodel_prefix_expr())) -true -``` -""" -to_sampleable(model::Model) = SampleableModelWrapper(model) - """ need_concretize(expr) @@ -325,6 +178,7 @@ function check_tilde_rhs(@nospecialize(x)) end check_tilde_rhs(x::Distribution) = x check_tilde_rhs(x::AbstractArray{<:Distribution}) = x +check_tilde_rhs(x::ReturnedModelWrapper) = x """ unwrap_right_vn(right, vn) @@ -574,34 +428,28 @@ function generate_tilde(left, right) # more selective with our escape. Until that's the case, we remove them all. return quote $dist = $right - - if $dist isa $(SampleableModelWrapper) - $left, __varinfo__ = $(_evaluate!!)($dist.model, __varinfo__, __context__) - $left + $vn = $(DynamicPPL.resolve_varnames)( + $(AbstractPPL.drop_escape(varname(left, need_concretize(left)))), $dist + ) + $isassumption = $(DynamicPPL.isassumption(left, vn)) + if $(DynamicPPL.isfixed(left, vn)) + $left = $(DynamicPPL.getfixed_nested)(__context__, $vn) + elseif $isassumption + $(generate_tilde_assume(left, dist, vn)) else - $vn = $(DynamicPPL.resolve_varnames)( - $(AbstractPPL.drop_escape(varname(left, need_concretize(left)))), $dist - ) - $isassumption = $(DynamicPPL.isassumption(left, vn)) - if $(DynamicPPL.isfixed(left, vn)) - $left = $(DynamicPPL.getfixed_nested)(__context__, $vn) - elseif $isassumption - $(generate_tilde_assume(left, dist, vn)) - else - # If `vn` is not in `argnames`, we need to make sure that the variable is defined. - if !$(DynamicPPL.inargnames)($vn, __model__) - $left = $(DynamicPPL.getconditioned_nested)(__context__, $vn) - end - - $value, __varinfo__ = $(DynamicPPL.tilde_observe!!)( - __context__, - $(DynamicPPL.check_tilde_rhs)($dist), - $(maybe_view(left)), - $vn, - __varinfo__, - ) - $value + # If `vn` is not in `argnames`, we need to make sure that the variable is defined. + if !$(DynamicPPL.inargnames)($vn, __model__) + $left = $(DynamicPPL.getconditioned_nested)(__context__, $vn) end + + $value, __varinfo__ = $(DynamicPPL.tilde_observe!!)( + __context__, + $(DynamicPPL.check_tilde_rhs)($dist), + $(maybe_view(left)), + $vn, + __varinfo__, + ) + $value end end end diff --git a/src/context_implementations.jl b/src/context_implementations.jl index f3c5171b0..119fa457b 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -141,8 +141,12 @@ By default, calls `tilde_assume(context, right, vn, vi)` and accumulates the log probability of `vi` with the returned value. """ function tilde_assume!!(context, right, vn, vi) - value, logp, vi = tilde_assume(context, right, vn, vi) - return value, acclogp_assume!!(context, vi, logp) + return if is_rhs_model(right) + rand_like!!(right, context, vi) + else + value, logp, vi = tilde_assume(context, right, vn, vi) + value, acclogp_assume!!(context, vi, logp) + end end # observe @@ -197,6 +201,7 @@ Falls back to `tilde_observe!!(context, right, left, vi)` ignoring the informati and indices; if needed, these can be accessed through this function, though. """ function tilde_observe!!(context, right, left, vname, vi) + is_rhs_model(right) && throw(ArgumentError("`~` with a model on the right-hand side of an observe statement is not supported")) return tilde_observe!!(context, right, left, vi) end @@ -210,6 +215,7 @@ By default, calls `tilde_observe(context, right, left, vi)` and accumulates the probability of `vi` with the returned value. """ function tilde_observe!!(context, right, left, vi) + is_rhs_model(right) && throw(ArgumentError("`~` with a model on the right-hand side of an observe statement is not supported")) logp, vi = tilde_observe(context, right, left, vi) return left, acclogp_observe!!(context, vi, logp) end @@ -420,8 +426,12 @@ model inputs), accumulate the log probability, and return the sampled value and Falls back to `dot_tilde_assume(context, right, left, vn, vi)`. """ function dot_tilde_assume!!(context, right, left, vn, vi) - value, logp, vi = dot_tilde_assume(context, right, left, vn, vi) - return value, acclogp_assume!!(context, vi, logp), vi + return if is_rhs_model(right) + rand_like!!(right, context, vi) + else + value, logp, vi = dot_tilde_assume(context, right, left, vn, vi) + value, acclogp_assume!!(context, vi, logp) + end end # `dot_assume` @@ -672,6 +682,7 @@ Falls back to `dot_tilde_observe!!(context, right, left, vi)` ignoring the infor name and indices; if needed, these can be accessed through this function, though. """ function dot_tilde_observe!!(context, right, left, vn, vi) + is_rhs_model(right) && throw(ArgumentError("`~` with a model on the right-hand side of an observe statement is not supported")) return dot_tilde_observe!!(context, right, left, vi) end @@ -684,6 +695,7 @@ probability, and return the observed value and updated `vi`. Falls back to `dot_tilde_observe(context, right, left, vi)`. """ function dot_tilde_observe!!(context, right, left, vi) + is_rhs_model(right) && throw(ArgumentError("`~` with a model on the right-hand side of an observe statement is not supported")) logp, vi = dot_tilde_observe(context, right, left, vi) return left, acclogp_observe!!(context, vi, logp) end diff --git a/src/model.jl b/src/model.jl index a05b7c444..3cf722868 100644 --- a/src/model.jl +++ b/src/model.jl @@ -1250,3 +1250,197 @@ end function returned(model::Model, values, keys) return returned(model, NamedTuple{keys}(values)) end + +""" + is_rhs_model(x) + +Return `true` if `x` is a model or model wrapper, and `false` otherwise. +""" +is_rhs_model(x) = false + + +""" + to_sampleable(model) + +Return a wrapper around `model` indicating it is sampleable. +""" +function to_sampleable end + + +""" + ReturnedModelWrapper + +A wrapper around a model indicating it is a model over its return values. + +This should rarely be constructed explicitly; see [`returned(model)`](@ref) instead. +""" +struct ReturnedModelWrapper{M<:Model} + model::M +end + +is_rhs_model(::ReturnedModelWrapper) = true + +function rand_like!!(model_wrap::ReturnedModelWrapper, context::AbstractContext, varinfo::AbstractVarInfo) + # Return's the value and the (possibly mutated) varinfo. + return _evaluate!!(model_wrap.model, varinfo, context) +end + +""" + returned(model::Model) + +Return a wrapper around `model` which indicates that this model can only be sampled from. + +This is mainly meant to be used on the right-hand side of a `~` operator to indicate that +the model can be sampled from but not necessarily evaluated for its log density. + +!!! warning + Note that other operations that one typically associate with expressions of the form `left ~ right` + such as [`condition`](@ref) or [`fix`](@ref), will also not work with `returned`. + +!!! warning + It's generally recommended to use [`prefix(::Model, input)`](@ref) when working with submodels + to ensure that the variables in `model` are unique and do not clash with other variables in the + parent model or in other submodels. + +# Examples + +## Simple example +```jldoctest submodel-returned; setup=:(using Distributions) +julia> @model function demo1(x) + x ~ Normal() + return 1 + abs(x) + end; + +julia> @model function demo2(x, y) + a ~ returned(demo1(x)) + return y ~ Uniform(0, a) + end; +``` + +When we sample from the model `demo2(missing, 0.4)` random variable `x` will be sampled: +```jldoctest submodel-returned +julia> vi = VarInfo(demo2(missing, 0.4)); + +julia> @varname(x) in keys(vi) +true +``` + +Variable `a` is not tracked since it can be computed from the random variable `x` that was +tracked when running `demo1`: +```jldoctest submodel-returned +julia> @varname(a) in keys(vi) +false +``` + +We can check that the log joint probability of the model accumulated in `vi` is correct: + +```jldoctest submodel-returned +julia> x = vi[@varname(x)]; + +julia> getlogp(vi) ≈ logpdf(Normal(), x) + logpdf(Uniform(0, 1 + abs(x)), 0.4) +true +``` + +## With prefixing +```jldoctest submodel-returned-prefix; setup=:(using Distributions) +julia> @model function demo1(x) + x ~ Normal() + return 1 + abs(x) + end; + +julia> @model function demo2(x, y, z) + a ~ returned(prefix(demo1(x), :sub1)) + b ~ returned(prefix(demo1(y), :sub2)) + return z ~ Uniform(-a, b) + end; +``` + +When we sample from the model `demo2(missing, missing, 0.4)` random variables `sub1.x` and +`sub2.x` will be sampled: +```jldoctest submodel-returned-prefix +julia> vi = VarInfo(demo2(missing, missing, 0.4)); + +julia> @varname(var"sub1.x") in keys(vi) +true + +julia> @varname(var"sub2.x") in keys(vi) +true +``` + +Variables `a` and `b` are not tracked since they can be computed from the random variables `sub1.x` and +`sub2.x` that were tracked when running `demo1`: +```jldoctest submodel-returned-prefix +julia> @varname(a) in keys(vi) +false + +julia> @varname(b) in keys(vi) +false +``` + +We can check that the log joint probability of the model accumulated in `vi` is correct: + +```jldoctest submodel-returned-prefix +julia> sub1_x = vi[@varname(var"sub1.x")]; + +julia> sub2_x = vi[@varname(var"sub2.x")]; + +julia> logprior = logpdf(Normal(), sub1_x) + logpdf(Normal(), sub2_x); + +julia> loglikelihood = logpdf(Uniform(-1 - abs(sub1_x), 1 + abs(sub2_x)), 0.4); + +julia> getlogp(vi) ≈ logprior + loglikelihood +true +``` + +## Different ways of setting the prefix +```jldoctest submodel-returned-prefix-alts; setup=:(using DynamicPPL, Distributions) +julia> @model inner() = x ~ Normal() +inner (generic function with 2 methods) + +julia> # When `prefix` is unspecified, no prefix is used. + @model submodel_noprefix() = a ~ returned(inner()) +submodel_noprefix (generic function with 2 methods) + +julia> @varname(x) in keys(VarInfo(submodel_noprefix())) +true + +julia> # Using a static string. + @model submodel_prefix_string() = a ~ returned(prefix(inner(), "my prefix")) +submodel_prefix_string (generic function with 2 methods) + +julia> @varname(var"my prefix.x") in keys(VarInfo(submodel_prefix_string())) +true + +julia> # Using string interpolation. + @model submodel_prefix_interpolation() = a ~ returned(prefix(inner(), "\$(nameof(inner()))")) +submodel_prefix_interpolation (generic function with 2 methods) + +julia> @varname(var"inner.x") in keys(VarInfo(submodel_prefix_interpolation())) +true + +julia> # Or using some arbitrary expression. + @model submodel_prefix_expr() = a ~ returned(prefix(inner(), 1 + 2)) +submodel_prefix_expr (generic function with 2 methods) + +julia> @varname(var"3.x") in keys(VarInfo(submodel_prefix_expr())) +true +``` + +## Usage as likelihood is illegal + +Note that it is illegal to use a `returned` model as a likelihood in another model: + +```jldoctest submodel-returned-illegal; setup=:(using Distributions) +julia> @model inner() = x ~ Normal() +inner (generic function with 2 methods) + +julia> @model illegal_likelihood() = a ~ returned(inner()) +illegal_likelihood (generic function with 2 methods) + +julia> model = illegal_likelihood() | (a = 1.0,); + +julia> model() +ERROR: ArgumentError: `~` with a model on the right-hand side of an observe statement is not supported +[...] +""" +returned(model::Model) = ReturnedModelWrapper(model)