Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds @returned_quantities macro #696

Open
wants to merge 57 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
57 commits
Select commit Hold shift + click to select a range
5c746c4
Added `@returned_quantities` macro
torfjelde Oct 23, 2024
0b081b7
Added `@returned_quantities` to the docs
torfjelde Oct 23, 2024
dc699a5
Fixed names of doctests for `@returned_quantities`
torfjelde Oct 23, 2024
7067695
Update src/submodel_macro.jl
torfjelde Oct 24, 2024
8cb0796
Added `@prefix` macro which calls `prefix` with a `Val` argument to
torfjelde Oct 29, 2024
2d887c9
Convert the result of `prefix_expr` in `@prefix` into a `Sybmol`
torfjelde Oct 29, 2024
692cfff
Export `prefix` and `@prefix`
torfjelde Oct 29, 2024
32fd6b9
Updated docstring for `@returned_quantities`
torfjelde Oct 29, 2024
5478fb3
Fixed bug in `rand` for `Model` where it would duplicate the non-leaf
torfjelde Oct 29, 2024
5fe65b3
Merge remote-tracking branch 'origin/torfjelde/returned-quantities-ma…
torfjelde Oct 29, 2024
9e0730f
Update src/contexts.jl
torfjelde Oct 29, 2024
cc3af46
Added `prefix` and `@prefix` to docs
torfjelde Oct 29, 2024
720053a
removed the prefix=... syntax for `@returned_quantities`
torfjelde Oct 31, 2024
fe0403f
added deprecation.jl + deprecated `generated_quantities` in favour of…
torfjelde Oct 31, 2024
55b95a1
removed export of `prefix` and `generated_quantities` (the latter is
torfjelde Oct 31, 2024
34fb6bd
updated `DynamicPPLMCMCChainsExt` to define `returned_quantities`
torfjelde Oct 31, 2024
9a7e18f
updated docs
torfjelde Oct 31, 2024
7aef65b
Update docs/src/api.md
torfjelde Nov 1, 2024
5ee727b
improved docstring for `prefix` and `@prefix`
torfjelde Nov 6, 2024
d92141c
added `@returned_quantities` macro taking two arguments + removed
torfjelde Nov 6, 2024
64b519d
updated docs to reflect the new two-argument `@returned_quantities`
torfjelde Nov 6, 2024
1b48f65
added depwarn to `@submodel` macro
torfjelde Nov 6, 2024
db2102c
fixed reference
torfjelde Nov 6, 2024
da95aba
fixed reference to `@prefix` in `@returned_quantities` macro
torfjelde Nov 6, 2024
c8d567f
actually fixed doc references
torfjelde Nov 6, 2024
d477137
updated doctests for `@submodel` to include the depwarn + added
torfjelde Nov 8, 2024
4896793
Merge branch 'master' into torfjelde/returned-quantities-macro
torfjelde Nov 8, 2024
946fa6d
Merge branch 'master' into torfjelde/returned-quantities-macro
torfjelde Nov 15, 2024
bf35de4
added `to_sampleable` and limited `~` handling for submodels
torfjelde Nov 15, 2024
0f20624
added docs to `to_sampleable` + removed the unnecessary macro exports
torfjelde Nov 15, 2024
99d99b3
updated more docstrings
torfjelde Nov 15, 2024
0597b2a
added testing of deprecation warning of `@submodel` + replaced some
torfjelde Nov 15, 2024
0c6bada
Update test/compiler.jl
torfjelde Nov 15, 2024
5134ff7
renamed `returned_quantities` to `returned` as requested
torfjelde Nov 25, 2024
45451f7
removed redundant `SampleableModelWrapper` in favour of
torfjelde Nov 25, 2024
c00a9ae
updated tests + docstrings + warnings to use `returned`
torfjelde Nov 25, 2024
f0af1d5
updated docs
torfjelde Nov 25, 2024
1b231a9
formatting
torfjelde Nov 25, 2024
1faa627
Update src/model.jl
torfjelde Nov 25, 2024
92ac6b9
fix docs
torfjelde Nov 25, 2024
f73d1b0
Merge branch 'master' into torfjelde/returned-quantities-macro
torfjelde Nov 25, 2024
b7b2e1d
export `to_sampleable` and add to docs
torfjelde Nov 25, 2024
ed4bb76
fixed typo in warning
torfjelde Nov 25, 2024
36f02f6
removed unnecessary import in docstring
torfjelde Nov 25, 2024
98538c5
added docstring to `rand_like!!`
torfjelde Nov 25, 2024
d316306
fixed docstring for `returned(model)`
torfjelde Nov 25, 2024
0e05901
improvements to docstrings thanks to @penelopesym
torfjelde Nov 25, 2024
f073b25
added abstract type `Distributional` and concrete type `Sampleable`,
torfjelde Nov 26, 2024
2ec03c1
replaced usages of `returned` with `to_submodel`
torfjelde Nov 26, 2024
1f70dfc
formatting
torfjelde Nov 26, 2024
f645259
Merge remote-tracking branch 'origin/torfjelde/returned-quantities-ma…
torfjelde Nov 26, 2024
23355ea
Update docs/src/api.md
torfjelde Nov 27, 2024
0e82a60
removed export of `to_sampleable` since it currently has no purpose +
torfjelde Nov 27, 2024
b9017c4
formatting
torfjelde Nov 27, 2024
6e149a3
Merge remote-tracking branch 'origin/torfjelde/returned-quantities-ma…
torfjelde Nov 27, 2024
933e4ed
updated docstring for `condition` and `fix` to not use `@submdoel`
torfjelde Nov 27, 2024
4fc7b76
added `check_tilde_rhs` for `Sampleable`
torfjelde Nov 27, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@ These statements are rewritten by `@model` as calls of [internal functions](@ref
@model
```

One can nest models and call another model inside the model function with [`@submodel`](@ref).
One can nest models and call another model inside the model function with [`@submodel`](@ref) and [`@returned_quantities`](@ref).

```@docs
@submodel
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
@returned_quantities
```

### Type
Expand Down
1 change: 1 addition & 0 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ export AbstractVarInfo,
# Convenience macros
@addlogprob!,
@submodel,
@returned_quantities,
value_iterator_from_chain,
check_model,
check_model_and_trace,
Expand Down
182 changes: 182 additions & 0 deletions src/submodel_macro.jl
Original file line number Diff line number Diff line change
Expand Up @@ -248,3 +248,185 @@
end
end
end

"""
@returned_quantities [prefix=...] model

Run `model` nested inside of another model and return the return-values of the `model`.

Valid expressions for `prefix=...` are:
- `prefix=false`: no prefix is used. This is the default.
- `prefix=expression`: results in the prefix `Symbol(expression)`.

Prefixing makes it possible to run the same model multiple times while keeping track of
all random variables correctly, i.e. without name clashes.

# Examples

## Simple example
```jldoctest submodel-returned-quantities; setup=:(using Distributions)
julia> @model function demo1(x)
x ~ Normal()
return 1 + abs(x)
end;

julia> @model function demo2(x, y)
a = @returned_quantities(demo1(x))
return y ~ Uniform(0, a)
end;
```

When we sample from the model `demo2(missing, 0.4)` random variable `x` will be sampled:
```jldoctest submodel-returned-quantities
julia> vi = VarInfo(demo2(missing, 0.4));

julia> @varname(x) in keys(vi)
true
```

Variable `a` is not tracked since it can be computed from the random variable `x` that was
tracked when running `demo1`:
```jldoctest submodel-returned-quantities
julia> @varname(a) in keys(vi)
false
```

We can check that the log joint probability of the model accumulated in `vi` is correct:

```jldoctest submodel-returned-quantities
julia> x = vi[@varname(x)];

julia> getlogp(vi) ≈ logpdf(Normal(), x) + logpdf(Uniform(0, 1 + abs(x)), 0.4)
true
```

## With prefixing
```jldoctest submodel-returned-quantities-prefix; setup=:(using Distributions)
julia> @model function demo1(x)
x ~ Normal()
return 1 + abs(x)
end;

julia> @model function demo2(x, y, z)
a = @returned_quantities prefix="sub1" demo1(x)
b = @returned_quantities prefix="sub2" demo1(y)
return z ~ Uniform(-a, b)
end;
```

When we sample from the model `demo2(missing, missing, 0.4)` random variables `sub1.x` and
`sub2.x` will be sampled:
```jldoctest submodel-returned-quantities-prefix
julia> vi = VarInfo(demo2(missing, missing, 0.4));

julia> @varname(var"sub1.x") in keys(vi)
true

julia> @varname(var"sub2.x") in keys(vi)
true
```

Variables `a` and `b` are not tracked since they can be computed from the random variables `sub1.x` and
`sub2.x` that were tracked when running `demo1`:
```jldoctest submodel-returned-quantities-prefix
julia> @varname(a) in keys(vi)
false

julia> @varname(b) in keys(vi)
false
```

We can check that the log joint probability of the model accumulated in `vi` is correct:

```jldoctest submodel-returned-quantities-prefix
julia> sub1_x = vi[@varname(var"sub1.x")];

julia> sub2_x = vi[@varname(var"sub2.x")];

julia> logprior = logpdf(Normal(), sub1_x) + logpdf(Normal(), sub2_x);

julia> loglikelihood = logpdf(Uniform(-1 - abs(sub1_x), 1 + abs(sub2_x)), 0.4);

julia> getlogp(vi) ≈ logprior + loglikelihood
true
```

## Different ways of setting the prefix
```jldoctest submodel-returned-quantities-prefix-alts; setup=:(using DynamicPPL, Distributions)
julia> @model inner() = x ~ Normal()
inner (generic function with 2 methods)

julia> # When `prefix` is unspecified, no prefix is used.
@model submodel_noprefix() = a = @returned_quantities inner()
submodel_noprefix (generic function with 2 methods)

julia> @varname(x) in keys(VarInfo(submodel_noprefix()))
true

julia> # Explicitely don't use any prefix.
@model submodel_prefix_false() = a = @returned_quantities prefix=false inner()
submodel_prefix_false (generic function with 2 methods)

julia> @varname(x) in keys(VarInfo(submodel_prefix_false()))
true

julia> # Using a static string.
@model submodel_prefix_string() = a = @returned_quantities prefix="my prefix" inner()
submodel_prefix_string (generic function with 2 methods)

julia> @varname(var"my prefix.x") in keys(VarInfo(submodel_prefix_string()))
true

julia> # Using string interpolation.
@model submodel_prefix_interpolation() = a = @returned_quantities prefix="\$(nameof(inner()))" inner()
submodel_prefix_interpolation (generic function with 2 methods)

julia> @varname(var"inner.x") in keys(VarInfo(submodel_prefix_interpolation()))
true

julia> # Or using some arbitrary expression.
@model submodel_prefix_expr() = a = @returned_quantities prefix=1 + 2 inner()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I found

@returned_quantities prefix=1 + 2 inner()

hard and unintuitive to parse. I think

@returned_quantities prefix=(1 + 2) inner()

would be much clearer. Not sure if this a documentation issue, or if we should disallow the former.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a documentation issue IMO, as this is not doing any special parsing but reliying on Julia's expression parsing.

submodel_prefix_expr (generic function with 2 methods)

julia> @varname(var"3.x") in keys(VarInfo(submodel_prefix_expr()))
true
```
"""
macro returned_quantities(expr)
return returned_quantities_expr(:(prefix = false), expr)
end

macro returned_quantities(prefix_expr, expr)
return returned_quantities_expr(prefix_expr, expr)
end

"""
@returned_quantities_expr model

Returns an expression that captures the return-values of a model in addition to the varinfo.
"""
function returned_quantities_expr(prefix_expr, expr, ctx=esc(:__context__))
mhauru marked this conversation as resolved.
Show resolved Hide resolved
prefix_left, prefix = getargs_assignment(prefix_expr)
if prefix_left !== :prefix
error("$(prefix_left) is not a valid kwarg")

Check warning on line 411 in src/submodel_macro.jl

View check run for this annotation

Codecov / codecov/patch

src/submodel_macro.jl#L411

Added line #L411 was not covered by tests
end

# The user expects `@submodel ...` to return the
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
# return-value of the `...`, hence we need to capture
# the return-value and handle it correctly.
@gensym retval

# Prefix.
if prefix !== nothing
ctx = prefix_submodel_context(prefix, ctx)
end
return quote
# Evaluate the model and capture the return values + varinfo.
$retval, $(esc(:__varinfo__)) = $(_evaluate!!)(
$(esc(expr)), $(esc(:__varinfo__)), $(ctx)
)

# Return the return-value of the model.
$retval
end
end