Skip to content

Commit 044f6c3

Browse files
committed
added automatic prefixing for sub-models involved in ~ statements
1 parent 3c204d9 commit 044f6c3

File tree

3 files changed

+58
-53
lines changed

3 files changed

+58
-53
lines changed

src/compiler.jl

+4-1
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,10 @@ end
179179
check_tilde_rhs(x::Distribution) = x
180180
check_tilde_rhs(x::AbstractArray{<:Distribution}) = x
181181
check_tilde_rhs(x::ReturnedModelWrapper) = x
182-
check_tilde_rhs(x::Sampleable) = Sampleable(check_tilde_rhs(x.model))
182+
function check_tilde_rhs(x::Sampleable{<:Any,AutoPrefix}) where {AutoPrefix}
183+
model = check_tilde_rhs(x.model)
184+
return Sampleable{typeof(model),AutoPrefix}(model)
185+
end
183186

184187
"""
185188
unwrap_right_vn(right, vn)

src/context_implementations.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ probability of `vi` with the returned value.
105105
function tilde_assume!!(context, right, vn, vi)
106106
return if is_rhs_model(right)
107107
# Prefix the variables using the `vn`.
108-
rand_like!!(right, prefix(context Symbol(vn)), vi)
108+
rand_like!!(right, should_auto_prefix(right) ? PrefixContext{Symbol(vn)}(context) : context, vi)
109109
else
110110
value, logp, vi = tilde_assume(context, right, vn, vi)
111111
value, acclogp_assume!!(context, vi, logp)

src/model.jl

+53-51
Original file line numberDiff line numberDiff line change
@@ -1263,24 +1263,43 @@ Abstract type for type indicating that something is "distributional".
12631263
"""
12641264
abstract type Distributional end
12651265

1266+
"""
1267+
should_auto_prefix(distributional)
1268+
1269+
Return `true` if the `distributional` should use automatic prefixing, and `false` otherwise.
1270+
"""
1271+
function should_auto_prefix end
1272+
1273+
"""
1274+
is_rhs_model(x)
1275+
1276+
Return `true` if the `distributional` is a model, and `false` otherwise.
1277+
"""
1278+
function is_rhs_model end
1279+
12661280
"""
12671281
Sampleable{M} <: Distributional
12681282
12691283
A wrapper around a model indicating it is sampleable.
12701284
"""
1271-
struct Sampleable{M} <: Distributional
1285+
struct Sampleable{M,AutoPrefix} <: Distributional
12721286
model::M
12731287
end
12741288

1289+
should_auto_prefix(::Sampleable{<:Any,AutoPrefix}) where {AutoPrefix} = AutoPrefix
12751290
is_rhs_model(x::Sampleable) = is_rhs_model(x.model)
12761291

12771292
# TODO: Export this if it end up having a purpose beyond `to_submodel`.
12781293
"""
1279-
to_sampleable(model)
1294+
to_sampleable(model[, auto_prefix])
12801295
12811296
Return a wrapper around `model` indicating it is sampleable.
1297+
1298+
# Arguments
1299+
- `model::Model`: the model to wrap.
1300+
- `auto_prefix::Bool`: whether to prefix the variables in the model. Default: `true`.
12821301
"""
1283-
to_sampleable(model) = Sampleable(model)
1302+
to_sampleable(model, auto_prefix::Bool=true) = Sampleable{typeof(model),auto_prefix}(model)
12841303

12851304
"""
12861305
rand_like!!(model_wrap, context, varinfo)
@@ -1326,7 +1345,7 @@ Return a `model` wrapper indicating that it is a model over its return-values.
13261345
returned(model::Model) = ReturnedModelWrapper(model)
13271346

13281347
"""
1329-
to_submodel(model::Model)
1348+
to_submodel(model::Model[, auto_prefix::Bool])
13301349
13311350
Return a model wrapper indicating that it is a sampleable model over the return-values.
13321351
@@ -1338,9 +1357,13 @@ the model can be sampled from but not necessarily evaluated for its log density.
13381357
such as [`condition`](@ref) or [`fix`](@ref), will also not work with `to_submodel`.
13391358
13401359
!!! warning
1341-
It's generally recommended to use [`prefix(::Model, input)`](@ref) when working with submodels
1342-
to ensure that the variables in `model` are unique and do not clash with other variables in the
1343-
parent model or in other submodels.
1360+
To avoid variable names clashing between models, it is recommend leave argument `auto_prefix` equal to `true`.
1361+
If one does not use automatic prefixing, then it's recommended to use [`prefix(::Model, input)`](@ref) explicitly.
1362+
1363+
# Arguments
1364+
- `model::Model`: the model to wrap.
1365+
- `auto_prefix::Bool`: whether to automatically prefix the variables in the model using the left-hand
1366+
side of the `~` statement. Default: `true`.
13441367
13451368
# Examples
13461369
@@ -1361,7 +1384,7 @@ When we sample from the model `demo2(missing, 0.4)` random variable `x` will be
13611384
```jldoctest submodel-to_submodel
13621385
julia> vi = VarInfo(demo2(missing, 0.4));
13631386
1364-
julia> @varname(x) in keys(vi)
1387+
julia> @varname(var\"a.x\") in keys(vi)
13651388
true
13661389
```
13671390
@@ -1375,29 +1398,42 @@ false
13751398
We can check that the log joint probability of the model accumulated in `vi` is correct:
13761399
13771400
```jldoctest submodel-to_submodel
1378-
julia> x = vi[@varname(x)];
1401+
julia> x = vi[@varname(var\"a.x\")];
13791402
13801403
julia> getlogp(vi) ≈ logpdf(Normal(), x) + logpdf(Uniform(0, 1 + abs(x)), 0.4)
13811404
true
13821405
```
13831406
1384-
## With prefixing
1407+
## Without automatic prefixing
1408+
As mentioned earlier, by default, the `auto_prefix` argument specifies whether to automatically
1409+
prefix the variables in the submodel. If `auto_prefix=false`, then the variables in the submodel
1410+
will not be prefixed.
13851411
```jldoctest submodel-to_submodel-prefix; setup=:(using Distributions)
13861412
julia> @model function demo1(x)
13871413
x ~ Normal()
13881414
return 1 + abs(x)
13891415
end;
13901416
1417+
julia> @model function demo2_no_prefix(x, z)
1418+
a ~ to_submodel(demo1(x), false)
1419+
return z ~ Uniform(-a, 1)
1420+
end;
1421+
1422+
julia> vi = VarInfo(demo2_no_prefix(missing, 0.4));
1423+
1424+
julia> @varname(x) in keys(vi) # here we just use `x` instead of `a.x`
1425+
true
1426+
```
1427+
However, not using prefixing is generally not recommended as it can lead to variable name clashes
1428+
unless one is careful. For example, if we're re-using the same model twice in a model, not using prefixing
1429+
will lead to variable name clashes: However, one can manually prefix using the [`prefix(::Model, input)`](@ref):
1430+
```jldoctest submodel-to_submodel-prefix
13911431
julia> @model function demo2(x, y, z)
1392-
a ~ to_submodel(prefix(demo1(x), :sub1))
1393-
b ~ to_submodel(prefix(demo1(y), :sub2))
1432+
a ~ to_submodel(prefix(demo1(x), :sub1), false)
1433+
b ~ to_submodel(prefix(demo1(y), :sub2), false)
13941434
return z ~ Uniform(-a, b)
13951435
end;
1396-
```
13971436
1398-
When we sample from the model `demo2(missing, missing, 0.4)` random variables `sub1.x` and
1399-
`sub2.x` will be sampled:
1400-
```jldoctest submodel-to_submodel-prefix
14011437
julia> vi = VarInfo(demo2(missing, missing, 0.4));
14021438
14031439
julia> @varname(var"sub1.x") in keys(vi)
@@ -1432,40 +1468,6 @@ julia> getlogp(vi) ≈ logprior + loglikelihood
14321468
true
14331469
```
14341470
1435-
## Different ways of setting the prefix
1436-
```jldoctest submodel-to_submodel-prefix-alts; setup=:(using DynamicPPL, Distributions)
1437-
julia> @model inner() = x ~ Normal()
1438-
inner (generic function with 2 methods)
1439-
1440-
julia> # When `prefix` is unspecified, no prefix is used.
1441-
@model submodel_noprefix() = a ~ to_submodel(inner())
1442-
submodel_noprefix (generic function with 2 methods)
1443-
1444-
julia> @varname(x) in keys(VarInfo(submodel_noprefix()))
1445-
true
1446-
1447-
julia> # Using a static string.
1448-
@model submodel_prefix_string() = a ~ to_submodel(prefix(inner(), "my prefix"))
1449-
submodel_prefix_string (generic function with 2 methods)
1450-
1451-
julia> @varname(var"my prefix.x") in keys(VarInfo(submodel_prefix_string()))
1452-
true
1453-
1454-
julia> # Using string interpolation.
1455-
@model submodel_prefix_interpolation() = a ~ to_submodel(prefix(inner(), "\$(nameof(inner()))"))
1456-
submodel_prefix_interpolation (generic function with 2 methods)
1457-
1458-
julia> @varname(var"inner.x") in keys(VarInfo(submodel_prefix_interpolation()))
1459-
true
1460-
1461-
julia> # Or using some arbitrary expression.
1462-
@model submodel_prefix_expr() = a ~ to_submodel(prefix(inner(), 1 + 2))
1463-
submodel_prefix_expr (generic function with 2 methods)
1464-
1465-
julia> @varname(var"3.x") in keys(VarInfo(submodel_prefix_expr()))
1466-
true
1467-
```
1468-
14691471
## Usage as likelihood is illegal
14701472
14711473
Note that it is illegal to use a `to_submodel` model as a likelihood in another model:
@@ -1483,4 +1485,4 @@ julia> model()
14831485
ERROR: ArgumentError: `~` with a model on the right-hand side of an observe statement is not supported
14841486
[...]
14851487
"""
1486-
to_submodel(model::Model) = to_sampleable(returned(model))
1488+
to_submodel(model::Model, auto_prefix::Bool=true) = to_sampleable(returned(model), auto_prefix)

0 commit comments

Comments
 (0)