Skip to content

Commit

Permalink
:= to keep track of generated quantities (#594)
Browse files Browse the repository at this point in the history
* added assignemnt operator

* use special construct to hijack assume for `:=`

* forgot to include change in previous commit

* test the assignment operator

* remove syntax incompat with older Julia versions

* improved existing test
  • Loading branch information
torfjelde authored May 6, 2024
1 parent 4cf395b commit d5ae280
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 3 deletions.
24 changes: 24 additions & 0 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -371,9 +371,33 @@ function generate_mainbody!(mod, found, expr::Expr, warn)
)
end

# Modify the assignment operators.
args_assign = getargs_coloneq(expr)
if args_assign !== nothing
L, R = args_assign
return Base.remove_linenums!(
generate_assign(
generate_mainbody!(mod, found, L, warn),
generate_mainbody!(mod, found, R, warn),
),
)
end

return Expr(expr.head, map(x -> generate_mainbody!(mod, found, x, warn), expr.args)...)
end

function generate_assign(left, right)
right_expr = :($(TrackedValue)($right))
tilde_expr = generate_tilde(left, right_expr)
return quote
if $(is_extracting_values)(__context__)
$tilde_expr
else
$left = $right
end
end
end

function generate_tilde_literal(left, right)
# If the LHS is a literal, it is always an observation
@gensym value
Expand Down
14 changes: 14 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,20 @@ function getargs_assignment(expr::Expr)
end
end

"""
getargs_coloneq(x)
Return the arguments `L` and `R`, if `x` is an expression of the form `L := R`, or `nothing`
otherwise.
"""
getargs_coloneq(x) = nothing
function getargs_coloneq(expr::Expr)
return MacroTools.@match expr begin
(L_ := R_) => (L, R)
x_ => nothing
end
end

function to_namedtuple_expr(syms)
length(syms) == 0 && return :(NamedTuple())

Expand Down
29 changes: 27 additions & 2 deletions src/values_as_in_model.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
struct TrackedValue{T}
value::T
end

is_tracked_value(::TrackedValue) = true
is_tracked_value(::Any) = false

check_tilde_rhs(x::TrackedValue) = x

"""
ValuesAsInModelContext
Expand Down Expand Up @@ -29,6 +37,13 @@ function setchildcontext(context::ValuesAsInModelContext, child)
return ValuesAsInModelContext(context.values, child)
end

is_extracting_values(context::ValuesAsInModelContext) = true
function is_extracting_values(context::AbstractContext)
return is_extracting_values(NodeTrait(context), context)
end
is_extracting_values(::IsParent, ::AbstractContext) = false
is_extracting_values(::IsLeaf, ::AbstractContext) = false

function Base.push!(context::ValuesAsInModelContext, vn::VarName, value)
return setindex!(context.values, copy(value), vn)
end
Expand All @@ -48,7 +63,12 @@ end

# `tilde_asssume`
function tilde_assume(context::ValuesAsInModelContext, right, vn, vi)
value, logp, vi = tilde_assume(childcontext(context), right, vn, vi)
if is_tracked_value(right)
value = right.value
logp = zero(getlogp(vi))
else
value, logp, vi = tilde_assume(childcontext(context), right, vn, vi)
end
# Save the value.
push!(context, vn, value)
# Save the value.
Expand All @@ -58,7 +78,12 @@ end
function tilde_assume(
rng::Random.AbstractRNG, context::ValuesAsInModelContext, sampler, right, vn, vi
)
value, logp, vi = tilde_assume(rng, childcontext(context), sampler, right, vn, vi)
if is_tracked_value(right)
value = right.value
logp = zero(getlogp(vi))
else
value, logp, vi = tilde_assume(rng, childcontext(context), sampler, right, vn, vi)
end
# Save the value.
push!(context, vn, value)
# Pass on.
Expand Down
30 changes: 29 additions & 1 deletion test/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -687,6 +687,34 @@ module Issue537 end
# And one explicit test for logging so know that is working.
@model demo_with_logging() = @info "hi"
model = demo_with_logging()
@test model() == nothing
@test model() === nothing
# Make sure that the log message is present.
@test_logs (:info, "hi") model()
end

@testset ":= (tracked values)" begin
@model function demo_tracked()
x ~ Normal()
y := 100 + x
return (; x, y)
end
@model function demo_tracked_submodel()
@submodel (x, y) = demo_tracked()
return (; x, y)
end
for model in [demo_tracked(), demo_tracked_submodel()]
# Make sure it's runnable and `y` is present in the return-value.
@test model() isa NamedTuple{(:x, :y)}

# `VarInfo` should only contain `x`.
varinfo = VarInfo(model)
@test haskey(varinfo, @varname(x))
@test !haskey(varinfo, @varname(y))

# While `values_as_in_model` should contain both `x` and `y`.
values = values_as_in_model(model, deepcopy(varinfo))
@test haskey(values, @varname(x))
@test haskey(values, @varname(y))
end
end
end

0 comments on commit d5ae280

Please sign in to comment.