diff --git a/docs/src/ref/combinators.md b/docs/src/ref/combinators.md index 8c18258bd..71f51cc83 100644 --- a/docs/src/ref/combinators.md +++ b/docs/src/ref/combinators.md @@ -119,4 +119,41 @@ TODO: document me schematic of recurse combinatokr ``` +## Switch combinator +```@docs +Switch +``` + +In the schematic below, the kernel is denoted `S` and accepts an integer index `k`. + +Consider the following constructions: + +```julia +@gen function bang((grad)(x::Float64), (grad)(y::Float64)) + std::Float64 = 3.0 + z = @trace(normal(x + y, std), :z) + return z +end + +@gen function fuzz((grad)(x::Float64), (grad)(y::Float64)) + std::Float64 = 3.0 + z = @trace(normal(x + 2 * y, std), :z) + return z +end + +sc = Switch(bang, fuzz) +``` + +This creates a new generative function `sc`. We can then obtain the trace of `sc`: + +```julia +trace = simulate(sc, (2, 5.0, 3.0)) +``` + +The resulting trace contains the subtrace from the branch with index `2` - in this case, a call to `fuzz`: + +``` +│ +└── :z : 13.552870875213735 +``` diff --git a/src/modeling_library/cond.jl b/src/modeling_library/cond.jl new file mode 100644 index 000000000..9c0ce4fd3 --- /dev/null +++ b/src/modeling_library/cond.jl @@ -0,0 +1,20 @@ +# ------------ Switch trace ------------ # + +struct SwitchTrace{T} <: Trace + gen_fn::GenerativeFunction{T} + index::Int + branch::Trace + retval::T + args::Tuple + score::Float64 + noise::Float64 +end + +@inline get_choices(tr::SwitchTrace) = get_choices(tr.branch) +@inline get_retval(tr::SwitchTrace) = tr.retval +@inline get_args(tr::SwitchTrace) = tr.args +@inline get_score(tr::SwitchTrace) = tr.score +@inline get_gen_fn(tr::SwitchTrace) = tr.gen_fn +@inline Base.getindex(tr::SwitchTrace, addr) = Base.getindex(tr.branch, addr) +@inline project(tr::SwitchTrace, selection::Selection) = project(tr.branch, selection) +@inline project(tr::SwitchTrace, ::EmptySelection) = tr.noise diff --git a/src/modeling_library/modeling_library.jl 
b/src/modeling_library/modeling_library.jl index d0797426c..2572b0fdd 100644 --- a/src/modeling_library/modeling_library.jl +++ b/src/modeling_library/modeling_library.jl @@ -66,12 +66,16 @@ include("dist_dsl/dist_dsl.jl") # code shared by vector-shaped combinators include("vector.jl") +# traces for with prob/switch combinator +include("cond.jl") + # built-in generative function combinators include("choice_at/choice_at.jl") include("call_at/call_at.jl") include("map/map.jl") include("unfold/unfold.jl") include("recurse/recurse.jl") +include("switch/switch.jl") ############################################################# # abstractions for constructing custom generative functions # diff --git a/src/modeling_library/switch/assess.jl b/src/modeling_library/switch/assess.jl new file mode 100644 index 000000000..4371eb8a4 --- /dev/null +++ b/src/modeling_library/switch/assess.jl @@ -0,0 +1,26 @@ +mutable struct SwitchAssessState{T} + weight::Float64 + retval::T + SwitchAssessState{T}(weight::Float64) where T = new{T}(weight) +end + +function process!(gen_fn::Switch{C, N, K, T}, + index::Int, + args::Tuple, + choices::ChoiceMap, + state::SwitchAssessState{T}) where {C, N, K, T} + (weight, retval) = assess(getindex(gen_fn.branches, index), args, choices) + state.weight = weight + state.retval = retval +end + +@inline process!(gen_fn::Switch{C, N, K, T}, index::C, args::Tuple, choices::ChoiceMap, state::SwitchAssessState{T}) where {C, N, K, T} = process!(gen_fn, getindex(gen_fn.cases, index), args, choices, state) + +function assess(gen_fn::Switch{C, N, K, T}, + args::Tuple, + choices::ChoiceMap) where {C, N, K, T} + index = args[1] + state = SwitchAssessState{T}(0.0) + process!(gen_fn, index, args[2 : end], choices, state) + return state.weight, state.retval +end diff --git a/src/modeling_library/switch/backprop.jl b/src/modeling_library/switch/backprop.jl new file mode 100644 index 000000000..28add2423 --- /dev/null +++ b/src/modeling_library/switch/backprop.jl @@ -0,0 
+1,2 @@ +@inline choice_gradients(trace::SwitchTrace{T}, selection::Selection, retval_grad) where T = choice_gradients(getfield(trace, :branch), selection, retval_grad) +@inline accumulate_param_gradients!(trace::SwitchTrace{T}, retval_grad, scale_factor = 1.) where {T} = accumulate_param_gradients!(getfield(trace, :branch), retval_grad, scale_factor) diff --git a/src/modeling_library/switch/generate.jl b/src/modeling_library/switch/generate.jl new file mode 100644 index 000000000..bd03f632e --- /dev/null +++ b/src/modeling_library/switch/generate.jl @@ -0,0 +1,34 @@ +mutable struct SwitchGenerateState{T} + score::Float64 + noise::Float64 + weight::Float64 + index::Int + subtrace::Trace + retval::T + SwitchGenerateState{T}(score::Float64, noise::Float64, weight::Float64) where T = new{T}(score, noise, weight) +end + +function process!(gen_fn::Switch{C, N, K, T}, + index::Int, + args::Tuple, + choices::ChoiceMap, + state::SwitchGenerateState{T}) where {C, N, K, T} + + (subtrace, weight) = generate(getindex(gen_fn.branches, index), args, choices) + state.index = index + state.subtrace = subtrace + state.weight += weight + state.retval = get_retval(subtrace) +end + +@inline process!(gen_fn::Switch{C, N, K, T}, index::C, args::Tuple, choices::ChoiceMap, state::SwitchGenerateState{T}) where {C, N, K, T} = process!(gen_fn, getindex(gen_fn.cases, index), args, choices, state) + +function generate(gen_fn::Switch{C, N, K, T}, + args::Tuple, + choices::ChoiceMap) where {C, N, K, T} + + index = args[1] + state = SwitchGenerateState{T}(0.0, 0.0, 0.0) + process!(gen_fn, index, args[2 : end], choices, state) + return SwitchTrace{T}(gen_fn, state.index, state.subtrace, state.retval, args[2 : end], state.score, state.noise), state.weight +end diff --git a/src/modeling_library/switch/propose.jl b/src/modeling_library/switch/propose.jl new file mode 100644 index 000000000..b4df1d97f --- /dev/null +++ b/src/modeling_library/switch/propose.jl @@ -0,0 +1,29 @@ +mutable struct 
SwitchProposeState{T} + choices::DynamicChoiceMap + weight::Float64 + retval::T + SwitchProposeState{T}(choices, weight) where T = new{T}(choices, weight) +end + +function process!(gen_fn::Switch{C, N, K, T}, + index::Int, + args::Tuple, + state::SwitchProposeState{T}) where {C, N, K, T} + + (submap, weight, retval) = propose(getindex(gen_fn.branches, index), args) + state.choices = submap + state.weight += weight + state.retval = retval +end + +@inline process!(gen_fn::Switch{C, N, K, T}, index::C, args::Tuple, state::SwitchProposeState{T}) where {C, N, K, T} = process!(gen_fn, getindex(gen_fn.cases, index), args, state) + +function propose(gen_fn::Switch{C, N, K, T}, + args::Tuple) where {C, N, K, T} + + index = args[1] + choices = choicemap() + state = SwitchProposeState{T}(choices, 0.0) + process!(gen_fn, index, args[2:end], state) + return state.choices, state.weight, state.retval +end diff --git a/src/modeling_library/switch/regenerate.jl b/src/modeling_library/switch/regenerate.jl new file mode 100644 index 000000000..cb1094aff --- /dev/null +++ b/src/modeling_library/switch/regenerate.jl @@ -0,0 +1,60 @@ +mutable struct SwitchRegenerateState{T} + weight::Float64 + score::Float64 + noise::Float64 + prev_trace::Trace + trace::Trace + index::Int + retdiff::Diff + SwitchRegenerateState{T}(weight::Float64, score::Float64, noise::Float64, prev_trace::Trace) where T = new{T}(weight, score, noise, prev_trace) +end + +function process!(gen_fn::Switch{C, N, K, T}, + index::Int, + index_argdiff::Diff, + args::Tuple, + kernel_argdiffs::Tuple, + selection::Selection, + state::SwitchRegenerateState{T}) where {C, N, K, T} + branch_fn = getfield(gen_fn.branches, index) + merged = get_selected(get_choices(state.prev_trace), complement(selection)) + new_trace, weight = generate(branch_fn, args, merged) + retdiff = UnknownChange() + weight -= project(state.prev_trace, complement(selection)) + weight += (project(new_trace, selection) - project(state.prev_trace, selection)) + 
state.index = index + state.weight = weight + state.noise = project(new_trace, EmptySelection()) - project(state.prev_trace, EmptySelection()) + state.score = get_score(new_trace) + state.trace = new_trace + state.retdiff = retdiff +end + +function process!(gen_fn::Switch{C, N, K, T}, + index::Int, + index_argdiff::NoChange, + args::Tuple, + kernel_argdiffs::Tuple, + selection::Selection, + state::SwitchRegenerateState{T}) where {C, N, K, T} + new_trace, weight, retdiff = regenerate(getfield(state.prev_trace, :branch), args, kernel_argdiffs, selection) + state.index = index + state.weight = weight + state.noise = project(new_trace, EmptySelection()) - project(state.prev_trace, EmptySelection()) + state.score = get_score(new_trace) + state.trace = new_trace + state.retdiff = retdiff +end + +@inline process!(gen_fn::Switch{C, N, K, T}, index::C, index_argdiff::Diff, args::Tuple, kernel_argdiffs::Tuple, selection::Selection, state::SwitchRegenerateState{T}) where {C, N, K, T} = process!(gen_fn, getindex(gen_fn.cases, index), index_argdiff, args, kernel_argdiffs, selection, state) + +function regenerate(trace::SwitchTrace{T}, + args::Tuple, + argdiffs::Tuple, + selection::Selection) where T + gen_fn = trace.gen_fn + index, index_argdiff = args[1], argdiffs[1] + state = SwitchRegenerateState{T}(0.0, 0.0, 0.0, trace) + process!(gen_fn, index, index_argdiff, args[2 : end], argdiffs[2 : end], selection, state) + return SwitchTrace(gen_fn, state.index, state.trace, get_retval(state.trace), args, state.score, state.noise), state.weight, state.retdiff +end diff --git a/src/modeling_library/switch/simulate.jl b/src/modeling_library/switch/simulate.jl new file mode 100644 index 000000000..fc4b3b02a --- /dev/null +++ b/src/modeling_library/switch/simulate.jl @@ -0,0 +1,32 @@ +mutable struct SwitchSimulateState{T} + score::Float64 + noise::Float64 + index::Int + subtrace::Trace + retval::T + SwitchSimulateState{T}(score::Float64, noise::Float64) where T = new{T}(score, noise) 
+end + +function process!(gen_fn::Switch{C, N, K, T}, + index::Int, + args::Tuple, + state::SwitchSimulateState{T}) where {C, N, K, T} + local retval::T + subtrace = simulate(getindex(gen_fn.branches, index), args) + state.index = index + state.noise += project(subtrace, EmptySelection()) + state.subtrace = subtrace + state.score += get_score(subtrace) + state.retval = get_retval(subtrace) +end + +@inline process!(gen_fn::Switch{C, N, K, T}, index::C, args::Tuple, state::SwitchSimulateState{T}) where {C, N, K, T} = process!(gen_fn, getindex(gen_fn.cases, index), args, state) + +function simulate(gen_fn::Switch{C, N, K, T}, + args::Tuple) where {C, N, K, T} + + index = args[1] + state = SwitchSimulateState{T}(0.0, 0.0) + process!(gen_fn, index, args[2 : end], state) + return SwitchTrace{T}(gen_fn, state.index, state.subtrace, state.retval, args[2 : end], state.score, state.noise) +end diff --git a/src/modeling_library/switch/switch.jl b/src/modeling_library/switch/switch.jl new file mode 100644 index 000000000..821143448 --- /dev/null +++ b/src/modeling_library/switch/switch.jl @@ -0,0 +1,56 @@ +struct Switch{C, N, K, T} <: GenerativeFunction{T, Trace} + branches::NTuple{N, GenerativeFunction{T}} + cases::Dict{C, Int} + function Switch(gen_fns::GenerativeFunction...) + @assert !isempty(gen_fns) + rettype = get_return_type(getindex(gen_fns, 1)) + new{Int, length(gen_fns), typeof(gen_fns), rettype}(gen_fns, Dict{Int, Int}()) + end + function Switch(d::Dict{C, Int}, gen_fns::GenerativeFunction...) where C + @assert !isempty(gen_fns) + rettype = get_return_type(getindex(gen_fns, 1)) + new{C, length(gen_fns), typeof(gen_fns), rettype}(gen_fns, d) + end +end +export Switch + +has_argument_grads(switch_fn::Switch) = map(zip(map(has_argument_grads, switch_fn.branches)...)) do as + all(as) +end +accepts_output_grad(switch_fn::Switch) = all(accepts_output_grad, switch_fn.branches) + +function (gen_fn::Switch)(index::Int, args...) 
+ (_, _, retval) = propose(gen_fn, (index, args...)) + retval +end + +function (gen_fn::Switch{C})(index::C, args...) where C + (_, _, retval) = propose(gen_fn, (gen_fn.cases[index], args...)) + retval +end + +include("assess.jl") +include("propose.jl") +include("simulate.jl") +include("generate.jl") +include("update.jl") +include("regenerate.jl") +include("backprop.jl") + +@doc( +""" + gen_fn = Switch(gen_fns::GenerativeFunction...) + +Returns a new generative function that accepts an argument tuple of type `Tuple{Int, ...}` where the first index indicates which branch to call. + + gen_fn = Switch(d::Dict{T, Int}, gen_fns::GenerativeFunction...) where T + +Returns a new generative function that accepts an argument tuple of type `Tuple{Int, ...}` or an argument tuple of type `Tuple{T, ...}` where the first index either indicates which branch to call, or indicates an index into `d` which maps to the selected branch. This form is meant for convenience - it allows the programmer to use `d` like if-else or case statements. + +`Switch` is designed to allow for the expression of patterns of if-else control flow. `gen_fns` must satisfy a few requirements: + +1. Each `gen_fn` in `gen_fns` must accept the same argument types. +2. Each `gen_fn` in `gen_fns` must return the same return type. + +Otherwise, each `gen_fn` can come from different modeling languages, possess different traces, etc. 
+""", Switch) diff --git a/src/modeling_library/switch/update.jl b/src/modeling_library/switch/update.jl new file mode 100644 index 000000000..6aa672fd4 --- /dev/null +++ b/src/modeling_library/switch/update.jl @@ -0,0 +1,141 @@ +mutable struct SwitchUpdateState{T} + weight::Float64 + score::Float64 + noise::Float64 + prev_trace::Trace + trace::Trace + index::Int + discard::ChoiceMap + updated_retdiff::Diff + SwitchUpdateState{T}(weight::Float64, score::Float64, noise::Float64, prev_trace::Trace) where T = new{T}(weight, score, noise, prev_trace) +end + +function update_recurse_merge(prev_choices::ChoiceMap, choices::ChoiceMap) + prev_choice_submap_iterator = get_submaps_shallow(prev_choices) + prev_choice_value_iterator = get_values_shallow(prev_choices) + choice_submap_iterator = get_submaps_shallow(choices) + choice_value_iterator = get_values_shallow(choices) + new_choices = DynamicChoiceMap() + + # Add (address, value) to new_choices from prev_choices if address does not occur in choices. + for (address, value) in prev_choice_value_iterator + address in keys(choice_value_iterator) && continue + set_value!(new_choices, address, value) + end + + # Add (address, submap) to new_choices from prev_choices if address does not occur in choices. + # If it does, enter a recursive call to update_recurse_merge. + for (address, node1) in prev_choice_submap_iterator + if address in keys(choice_submap_iterator) + node2 = get_submap(choices, address) + node = update_recurse_merge(node1, node2) + set_submap!(new_choices, address, node) + else + set_submap!(new_choices, address, node1) + end + end + + # Add (address, value) from choices to new_choices. This is okay because we've excluded any conflicting addresses from the prev_choices above. + for (address, value) in choice_value_iterator + set_value!(new_choices, address, value) + end + + sel, _ = zip(prev_choice_submap_iterator...) 
+ comp = complement(select(sel...)) + for (address, node) in get_submaps_shallow(get_selected(choices, comp)) + set_submap!(new_choices, address, node) + end + return new_choices +end + +@doc( +""" +update_recurse_merge(prev_choices::ChoiceMap, choices::ChoiceMap) + +Returns choices that are in constraints, merged with all choices in the previous trace that do not have the same address as some choice in the constraints. +""", update_recurse_merge) + +function update_discard(prev_choices::ChoiceMap, choices::ChoiceMap, new_choices::ChoiceMap) + discard = choicemap() + for (k, v) in get_submaps_shallow(prev_choices) + new_submap = get_submap(new_choices, k) + choices_submap = get_submap(choices, k) + sub_discard = update_discard(v, choices_submap, new_submap) + set_submap!(discard, k, sub_discard) + end + for (k, v) in get_values_shallow(prev_choices) + if (!has_value(new_choices, k) || has_value(choices, k)) + set_value!(discard, k, v) + end + end + discard +end + +@doc( +""" +update_discard(prev_choices::ChoiceMap, choices::ChoiceMap, new_choices::ChoiceMap) + +Returns choices from previous trace that: + 1. have an address which does not appear in the new trace. + 2. have an address which does appear in the constraints. +""", update_discard) + +@inline update_discard(prev_trace::Trace, choices::ChoiceMap, new_trace::Trace) = update_discard(get_choices(prev_trace), choices, get_choices(new_trace)) + +function process!(gen_fn::Switch{C, N, K, T}, + index::Int, + index_argdiff::UnknownChange, + args::Tuple, + kernel_argdiffs::Tuple, + choices::ChoiceMap, + state::SwitchUpdateState{T}) where {C, N, K, T, DV} + + # Generate new trace. + merged = update_recurse_merge(get_choices(state.prev_trace), choices) + branch_fn = getfield(gen_fn.branches, index) + new_trace, weight = generate(branch_fn, args, merged) + weight -= get_score(state.prev_trace) + state.discard = update_discard(state.prev_trace, choices, new_trace) + + # Set state. 
+ state.index = index + state.weight = weight + state.noise = project(new_trace, EmptySelection()) - project(state.prev_trace, EmptySelection()) + state.score = get_score(new_trace) + state.trace = new_trace + state.updated_retdiff = UnknownChange() +end + +function process!(gen_fn::Switch{C, N, K, T}, + index::Int, + index_argdiff::NoChange, # TODO: Diffed wrapper? + args::Tuple, + kernel_argdiffs::Tuple, + choices::ChoiceMap, + state::SwitchUpdateState{T}) where {C, N, K, T} + + # Update trace. + new_trace, weight, retdiff, discard = update(getfield(state.prev_trace, :branch), args, kernel_argdiffs, choices) + + # Set state. + state.index = index + state.weight = weight + state.noise = project(new_trace, EmptySelection()) - project(state.prev_trace, EmptySelection()) + state.score = get_score(new_trace) + state.trace = new_trace + state.updated_retdiff = retdiff + state.discard = discard +end + +@inline process!(gen_fn::Switch{C, N, K, T}, index::C, index_argdiff::Diff, args::Tuple, kernel_argdiffs::Tuple, choices::ChoiceMap, state::SwitchUpdateState{T}) where {C, N, K, T} = process!(gen_fn, getindex(gen_fn.cases, index), index_argdiff, args, kernel_argdiffs, choices, state) + +function update(trace::SwitchTrace{T}, + args::Tuple, + argdiffs::Tuple, + choices::ChoiceMap) where T + gen_fn = trace.gen_fn + index, index_argdiff = args[1], argdiffs[1] + state = SwitchUpdateState{T}(0.0, 0.0, 0.0, trace) + process!(gen_fn, index, index_argdiff, args[2 : end], argdiffs[2 : end], choices, state) + return SwitchTrace(gen_fn, state.index, state.trace, get_retval(state.trace), args, state.score, state.noise), state.weight, state.updated_retdiff, state.discard +end diff --git a/test/modeling_library/modeling_library.jl b/test/modeling_library/modeling_library.jl index 616110f84..2ebb8929d 100644 --- a/test/modeling_library/modeling_library.jl +++ b/test/modeling_library/modeling_library.jl @@ -5,4 +5,5 @@ include("call_at.jl") include("map.jl") include("unfold.jl") 
include("recurse.jl") +include("switch.jl") include("dist_dsl.jl") diff --git a/test/modeling_library/switch.jl b/test/modeling_library/switch.jl new file mode 100644 index 000000000..8c183aa16 --- /dev/null +++ b/test/modeling_library/switch.jl @@ -0,0 +1,340 @@ +@testset "switch combinator" begin + + # ------------ Trace ------------ # + + @gen function swtrg() + z ~ normal(3.0, 5.0) + return z + end + + @testset "switch trace" begin + tr = simulate(swtrg, ()) + swtr = Gen.SwitchTrace(swtrg, 1, tr, get_retval(tr), (), get_score(tr), 0.0) + @test swtr[:z] == tr[:z] + @test project(swtr, AllSelection()) == project(swtr.branch, AllSelection()) + @test project(swtr, EmptySelection()) == swtr.noise + end + + # ------------ Bare combinator ------------ # + + # Model chunk. + @gen (grad) function bang0((grad)(x::Float64), (grad)(y::Float64)) + std::Float64 = 3.0 + z = @trace(normal(x + y, std), :z) + return z + end + + @gen (grad) function fuzz0((grad)(x::Float64), (grad)(y::Float64)) + std::Float64 = 3.0 + z = @trace(normal(x + 2 * y, std), :z) + return z + end + sc = Switch(bang0, fuzz0) + # ----. 
+ + @testset "simulate" begin + tr = simulate(sc, (1, 5.0, 3.0)) + @test isapprox(get_score(tr), logpdf(normal, tr[:z], 5.0 + 3.0, 3.0)) + tr = simulate(sc, (2, 5.0, 3.0)) + @test isapprox(get_score(tr), logpdf(normal, tr[:z], 5.0 + 2 * 3.0, 3.0)) + end + + @testset "generate" begin + chm = choicemap() + chm[:z] = 5.0 + tr, w = generate(sc, (2, 5.0, 3.0), chm) + assignment = get_choices(tr) + @test assignment[:z] == 5.0 + @test isapprox(w, logpdf(normal, 5.0, 5.0 + 2 * 3.0, 3.0)) + end + + @testset "assess" begin + chm = choicemap() + chm[:z] = 5.0 + w, ret = assess(sc, (2, 5.0, 3.0), chm) + @test isapprox(w, logpdf(normal, 5.0, 5.0 + 2 * 3.0, 3.0)) + end + + @testset "propose" begin + chm, w = propose(sc, (2, 5.0, 3.0)) + @test isapprox(w, logpdf(normal, chm[:z], 5.0 + 2 * 3.0, 3.0)) + end + + @testset "update" begin + tr = simulate(sc, (1, 5.0, 3.0)) + old_sc = get_score(tr) + chm = choicemap((:x => :z, 5.0)) + new_tr, w, rd, discard = update(tr, (2, 5.0, 3.0), + (UnknownChange(), NoChange(), NoChange()), + chm) + @test isapprox(old_sc, get_score(new_tr) - w) + chm = choicemap((:x => :z, 10.0)) + new_tr, w, rd, discard = update(tr, (1, 5.0, 3.0), + (UnknownChange(), NoChange(), NoChange()), + chm) + @test isapprox(old_sc, get_score(new_tr) - w) + end + + @testset "regenerate" begin + tr = simulate(sc, (2, 5.0, 3.0)) + old_sc = get_score(tr) + sel = select(:z) + new_tr, w, rd = regenerate(tr, (2, 5.0, 3.0), + (UnknownChange(), NoChange(), NoChange()), + sel) + @test isapprox(old_sc, get_score(new_tr) - w) + new_tr, w, rd = regenerate(tr, (1, 5.0, 3.0), + (UnknownChange(), NoChange(), NoChange()), + sel) + @test isapprox(old_sc, get_score(new_tr) - w) + end + + @testset "choice gradients" begin + for z in [1.0, 3.0, 5.0, 10.0] + chm = choicemap((:z, z)) + tr, _ = generate(sc, (1, 5.0, 3.0), chm) + sel = select(:z) + input_grads, choices, gradients = choice_gradients(tr, sel) + expected_choice_grad = logpdf_grad(normal, z, 5.0 + 3.0, 3.0) + @test 
isapprox(gradients[:z], expected_choice_grad[1]) + tr, _ = generate(sc, (2, 5.0, 3.0), chm) + input_grads, choices, gradients = choice_gradients(tr, sel) + expected_choice_grad = logpdf_grad(normal, z, 5.0 + 2 * 3.0, 3.0) + @test isapprox(gradients[:z], expected_choice_grad[1]) + end + end + + # ------------ Hierarchy ------------ # + + # Model chunk. + @gen (grad) function bang1((grad)(x::Float64), (grad)(y::Float64)) + @param(std::Float64) + z = @trace(normal(x + y, std), :z) + return z + end + init_param!(bang1, :std, 3.0) + @gen (grad) function fuzz1((grad)(x::Float64), (grad)(y::Float64)) + @param(std::Float64) + z = @trace(normal(x + 2 * y, std), :z) + return z + end + init_param!(fuzz1, :std, 3.0) + sc = Switch(bang1, fuzz1) + @gen (grad) function bam(s::Int) + x ~ sc(s, 5.0, 3.0) + return x + end + # ----. + + @testset "simulate" begin + tr = simulate(bam, (2, )) + @test isapprox(get_score(tr), logpdf(normal, tr[:x => :z], 5.0 + 2 * 3.0, 3.0)) + end + + @testset "generate" begin + chm = choicemap() + chm[:x => :z] = 5.0 + tr, w = generate(bam, (2, ), chm) + assignment = get_choices(tr) + @test assignment[:x => :z] == 5.0 + @test isapprox(w, logpdf(normal, 5.0, 5.0 + 2 * 3.0, 3.0)) + end + + @testset "assess" begin + chm = choicemap() + chm[:x => :z] = 5.0 + w, ret = assess(bam, (2, ), chm) + @test isapprox(w, logpdf(normal, 5.0, 5.0 + 2 * 3.0, 3.0)) + end + + @testset "propose" begin + chm, w = propose(bam, (2, )) + @test isapprox(w, logpdf(normal, chm[:x => :z], 5.0 + 2 * 3.0, 3.0)) + end + + @testset "update" begin + tr = simulate(bam, (2, )) + old_sc = get_score(tr) + chm = choicemap((:x => :z, 5.0)) + new_tr, w, rd, discard = update(tr, (2, ), (UnknownChange(), ), chm) + @test isapprox(old_sc, get_score(new_tr) - w) + chm = choicemap((:x => :z, 10.0)) + new_tr, w, rd, discard = update(tr, (1, ), (UnknownChange(), ), chm) + @test isapprox(old_sc, get_score(new_tr) - w) + end + + @testset "regenerate" begin + tr = simulate(bam, (2, )) + old_sc = 
get_score(tr) + new_tr, w = regenerate(tr, (1, ), (UnknownChange(), ), select()) + @test isapprox(old_sc, get_score(new_tr) - w) + new_tr, w = regenerate(tr, (2, ), (UnknownChange(), ), select()) + @test isapprox(old_sc, get_score(new_tr) - w) + new_tr, w = regenerate(tr, (1, ), (UnknownChange(), ), select()) + @test isapprox(old_sc, get_score(new_tr) - w) + end + + @testset "choice gradients" begin + for z in [1.0, 3.0, 5.0, 10.0] + chm = choicemap((:x => :z, z)) + tr, _ = generate(bam, (1, ), chm) + sel = select(:x => :z) + input_grads, choices, gradients = choice_gradients(tr, sel) + expected_choice_grad = logpdf_grad(normal, z, 5.0 + 3.0, 3.0) + @test isapprox(gradients[:x => :z], expected_choice_grad[1]) + chm = choicemap((:x => :z, z)) + tr, _ = generate(bam, (2, ), chm) + sel = select(:x => :z) + input_grads, choices, gradients = choice_gradients(tr, sel) + expected_choice_grad = logpdf_grad(normal, z, 5.0 + 2 * 3.0, 3.0) + @test isapprox(gradients[:x => :z], expected_choice_grad[1]) + end + end + + @testset "accumulate parameter gradients" begin + for z in [1.0, 3.0, 5.0, 10.0] + chm = choicemap((:z, z)) + tr, _ = generate(bam, (1, ), chm) + zero_param_grad!(bang1, :std) + input_grads = accumulate_param_gradients!(tr, 1.0) + expected_std_grad = logpdf_grad(normal, tr[:x => :z], 5.0 + 3.0, 3.0)[3] + @test isapprox(get_param_grad(bang1, :std), expected_std_grad) + tr, _ = generate(bam, (2, ), chm) + zero_param_grad!(fuzz1, :std) + input_grads = accumulate_param_gradients!(tr, 1.0) + expected_std_grad = logpdf_grad(normal, tr[:x => :z], 5.0 + 2 * 3.0, 3.0)[3] + @test isapprox(get_param_grad(fuzz1, :std), expected_std_grad) + end + end + + # ------------ (More complex) hierarchy ------------ # + + # Model chunk. 
+ @gen (grad) function bang2((grad)(x::Float64), (grad)(y::Float64)) + std::Float64 = 3.0 + z = @trace(normal(x + y, std), :z) + return z + end + @gen (grad) function fuzz2((grad)(x::Float64), (grad)(y::Float64)) + std::Float64 = 3.0 + z = @trace(normal(x + 2 * y, std), :z) + q = @trace(bang2(z, y), :q) + return z + end + sc2 = Switch(bang2, fuzz2) + @gen (grad) function bam2(s::Int) + x ~ sc2(s, 5.0, 3.0) + return x + end + # ----. + + @testset "simulate" begin + tr = simulate(bam2, (1, )) + @test isapprox(get_score(tr), logpdf(normal, tr[:x => :z], 5.0 + 3.0, 3.0)) + tr = simulate(bam2, (2, )) + @test isapprox(get_score(tr), logpdf(normal, tr[:x => :z], 5.0 + 2 * 3.0, 3.0) + logpdf(normal, tr[:x => :q => :z], tr[:x => :z] + 3.0, 3.0)) + end + + @testset "generate" begin + chm = choicemap() + chm[:x => :z] = 5.0 + tr, w = generate(bam2, (1, ), chm) + assignment = get_choices(tr) + @test assignment[:x => :z] == 5.0 + @test isapprox(w, logpdf(normal, 5.0, 5.0 + 3.0, 3.0)) + tr, w = generate(bam2, (2, ), chm) + assignment = get_choices(tr) + @test assignment[:x => :z] == 5.0 + @test isapprox(w, logpdf(normal, tr[:x => :z], 5.0 + 2 * 3.0, 3.0)) + end + + @testset "assess" begin + chm = choicemap() + chm[:x => :z] = 5.0 + w, ret = assess(bam2, (1, ), chm) + @test isapprox(w, logpdf(normal, 5.0, 5.0 + 3.0, 3.0)) + chm[:x => :q => :z] = 5.0 + w, ret = assess(bam2, (2, ), chm) + @test isapprox(w, logpdf(normal, 5.0, 5.0 + 2 * 3.0, 3.0) + logpdf(normal, 5.0, 5.0 + 3.0, 3.0)) + end + + @testset "propose" begin + chm, w = propose(bam2, (1, )) + @test isapprox(w, logpdf(normal, chm[:x => :z], 5.0 + 3.0, 3.0)) + chm, w = propose(bam2, (2, )) + @test isapprox(w, logpdf(normal, chm[:x => :z], 5.0 + 2 * 3.0, 3.0) + logpdf(normal, chm[:x => :q => :z], chm[:x => :z] + 3.0, 3.0)) + end + + @testset "update" begin + tr = simulate(bam2, (2, )) + old_sc = get_score(tr) + chm = choicemap((:x => :z, 5.0)) + new_tr, w, rd, discard = update(tr, (2, ), (UnknownChange(), ), chm) + @test 
isapprox(old_sc, get_score(new_tr) - w) + chm = choicemap((:x => :z, 10.0)) + new_tr, w, rd, discard = update(tr, (1, ), (UnknownChange(), ), chm) + @test isapprox(old_sc, get_score(new_tr) - w) + end + + @testset "regenerate" begin + tr = simulate(bam2, (2, )) + old_sc = get_score(tr) + new_tr, w = regenerate(tr, (1, ), (UnknownChange(), ), select()) + @test isapprox(old_sc, get_score(new_tr) - w) + new_tr, w = regenerate(tr, (2, ), (UnknownChange(), ), select()) + @test isapprox(old_sc, get_score(new_tr) - w) + new_tr, w = regenerate(tr, (1, ), (UnknownChange(), ), select()) + @test isapprox(old_sc, get_score(new_tr) - w) + new_tr, w = regenerate(tr, (1, ), (UnknownChange(), ), select(:x => :z)) + @test isapprox(old_sc, get_score(new_tr) - w) + end + + @testset "choice gradients" begin + for z in [1.0, 3.0, 5.0, 10.0] + chm = choicemap((:x => :z, z)) + tr, _ = generate(bam2, (1, ), chm) + sel = select(:x => :z) + input_grads, choices, gradients = choice_gradients(tr, sel) + expected_choice_grad = logpdf_grad(normal, z, 5.0 + 3.0, 3.0) + @test isapprox(gradients[:x => :z], expected_choice_grad[1]) + end + end + + # ------------ (More complex) hierarchy to test discard ------------ # + + # Model chunk. + @gen (grad) function bang3((grad)(x::Float64), (grad)(y::Float64)) + std::Float64 = 3.0 + z = @trace(normal(x + y, std), :z) + q = @trace(bang2(z, y), :q) + return z + end + @gen (grad) function fuzz3((grad)(x::Float64), (grad)(y::Float64)) + std::Float64 = 3.0 + z = @trace(normal(x + 2 * y, std), :z) + m = @trace(normal(x + 3 * y, std), :m) + q = @trace(bang3(z, y), :q) + return z + end + sc3 = Switch(bang3, fuzz3) + @gen (grad) function bam3(s::Int) + x ~ sc3(s, 5.0, 3.0) + return x + end + # ----. 
+ + @testset "update" begin + tr = simulate(bam3, (2, )) + old_sc = get_score(tr) + chm = choicemap((:x => :z, 5.0)) + future_discarded = tr[:x => :z] + new_tr, w, rd, discard = update(tr, (2, ), (UnknownChange(), ), chm) + @test discard[:x => :z] == future_discarded + @test isapprox(old_sc, get_score(new_tr) - w) + chm = choicemap((:x => :z, 10.0)) + future_discarded = tr[:x => :q => :q => :z] + new_tr, w, rd, discard = update(tr, (1, ), (UnknownChange(), ), chm) + @test discard[:x => :q => :q => :z] == future_discarded + @test isapprox(old_sc, get_score(new_tr) - w) + end +end