From d5ae280776f64ef7ee56356cddffdcb80e944606 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 6 May 2024 19:01:04 +0100 Subject: [PATCH] `:=` to keep track of generated quantities (#594) * added assignemnt operator * use special construct to hijack assume for `:=` * forgot to include change in previous commit * test the assignment operator * remove syntax incompat with older Julia versions * improved existing test --- src/compiler.jl | 24 ++++++++++++++++++++++++ src/utils.jl | 14 ++++++++++++++ src/values_as_in_model.jl | 29 +++++++++++++++++++++++++++-- test/compiler.jl | 30 +++++++++++++++++++++++++++++- 4 files changed, 94 insertions(+), 3 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index f8a04a557..898acad10 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -371,9 +371,33 @@ function generate_mainbody!(mod, found, expr::Expr, warn) ) end + # Modify the assignment operators. + args_assign = getargs_coloneq(expr) + if args_assign !== nothing + L, R = args_assign + return Base.remove_linenums!( + generate_assign( + generate_mainbody!(mod, found, L, warn), + generate_mainbody!(mod, found, R, warn), + ), + ) + end + return Expr(expr.head, map(x -> generate_mainbody!(mod, found, x, warn), expr.args)...) end +function generate_assign(left, right) + right_expr = :($(TrackedValue)($right)) + tilde_expr = generate_tilde(left, right_expr) + return quote + if $(is_extracting_values)(__context__) + $tilde_expr + else + $left = $right + end + end +end + function generate_tilde_literal(left, right) # If the LHS is a literal, it is always an observation @gensym value diff --git a/src/utils.jl b/src/utils.jl index 291020c2e..9493e1bc9 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -163,6 +163,20 @@ function getargs_assignment(expr::Expr) end end +""" + getargs_coloneq(x) + +Return the arguments `L` and `R`, if `x` is an expression of the form `L := R`, or `nothing` +otherwise. +""" +getargs_coloneq(x) = nothing +function getargs_coloneq(expr::Expr) + return MacroTools.@match expr begin + (L_ := R_) => (L, R) + x_ => nothing + end +end + function to_namedtuple_expr(syms) length(syms) == 0 && return :(NamedTuple()) diff --git a/src/values_as_in_model.jl b/src/values_as_in_model.jl index dcf68c15c..52ba6eb61 100644 --- a/src/values_as_in_model.jl +++ b/src/values_as_in_model.jl @@ -1,3 +1,11 @@ +struct TrackedValue{T} + value::T +end + +is_tracked_value(::TrackedValue) = true +is_tracked_value(::Any) = false + +check_tilde_rhs(x::TrackedValue) = x """ ValuesAsInModelContext @@ -29,6 +37,13 @@ function setchildcontext(context::ValuesAsInModelContext, child) return ValuesAsInModelContext(context.values, child) end +is_extracting_values(context::ValuesAsInModelContext) = true +function is_extracting_values(context::AbstractContext) + return is_extracting_values(NodeTrait(context), context) +end +is_extracting_values(::IsParent, ::AbstractContext) = false +is_extracting_values(::IsLeaf, ::AbstractContext) = false + function Base.push!(context::ValuesAsInModelContext, vn::VarName, value) return setindex!(context.values, copy(value), vn) end @@ -48,7 +63,12 @@ end # `tilde_asssume` function tilde_assume(context::ValuesAsInModelContext, right, vn, vi) - value, logp, vi = tilde_assume(childcontext(context), right, vn, vi) + if is_tracked_value(right) + value = right.value + logp = zero(getlogp(vi)) + else + value, logp, vi = tilde_assume(childcontext(context), right, vn, vi) + end # Save the value. push!(context, vn, value) # Save the value. @@ -58,7 +78,12 @@ end function tilde_assume( rng::Random.AbstractRNG, context::ValuesAsInModelContext, sampler, right, vn, vi ) - value, logp, vi = tilde_assume(rng, childcontext(context), sampler, right, vn, vi) + if is_tracked_value(right) + value = right.value + logp = zero(getlogp(vi)) + else + value, logp, vi = tilde_assume(rng, childcontext(context), sampler, right, vn, vi) + end # Save the value. push!(context, vn, value) # Pass on. diff --git a/test/compiler.jl b/test/compiler.jl index 6f23cd38e..9fa36b5ff 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -687,6 +687,34 @@ module Issue537 end # And one explicit test for logging so know that is working. @model demo_with_logging() = @info "hi" model = demo_with_logging() - @test model() == nothing + @test model() === nothing + # Make sure that the log message is present. + @test_logs (:info, "hi") model() + end + + @testset ":= (tracked values)" begin + @model function demo_tracked() + x ~ Normal() + y := 100 + x + return (; x, y) + end + @model function demo_tracked_submodel() + @submodel (x, y) = demo_tracked() + return (; x, y) + end + for model in [demo_tracked(), demo_tracked_submodel()] + # Make sure it's runnable and `y` is present in the return-value. + @test model() isa NamedTuple{(:x, :y)} + + # `VarInfo` should only contain `x`. + varinfo = VarInfo(model) + @test haskey(varinfo, @varname(x)) + @test !haskey(varinfo, @varname(y)) + + # While `values_as_in_model` should contain both `x` and `y`. + values = values_as_in_model(model, deepcopy(varinfo)) + @test haskey(values, @varname(x)) + @test haskey(values, @varname(y)) + end end end