Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove Contexts for the sack of simplicity #229

Merged
merged 3 commits into from
Oct 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions ext/JuliaBUGSAdvancedHMCExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@ using AbstractMCMC
using AdvancedHMC
using AdvancedHMC: Transition, stat
using JuliaBUGS
using JuliaBUGS:
AbstractBUGSModel, BUGSModel, Gibbs, find_generated_vars, LogDensityContext, evaluate!!
using JuliaBUGS: AbstractBUGSModel, BUGSModel, Gibbs, find_generated_vars, evaluate!!
using JuliaBUGS.BUGSPrimitives
using JuliaBUGS.BangBang
using JuliaBUGS.LogDensityProblems
Expand Down
2 changes: 1 addition & 1 deletion ext/JuliaBUGSAdvancedMHExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module JuliaBUGSAdvancedMHExt
using AbstractMCMC
using AdvancedMH
using JuliaBUGS
using JuliaBUGS: BUGSModel, find_generated_vars, LogDensityContext, evaluate!!
using JuliaBUGS: BUGSModel, find_generated_vars, evaluate!!
using JuliaBUGS.BUGSPrimitives
using JuliaBUGS.LogDensityProblems
using JuliaBUGS.LogDensityProblemsAD
Expand Down
4 changes: 2 additions & 2 deletions ext/JuliaBUGSMCMCChainsExt.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module JuliaBUGSMCMCChainsExt

using JuliaBUGS
using JuliaBUGS: AbstractBUGSModel, find_generated_vars, LogDensityContext, evaluate!!
using JuliaBUGS: AbstractBUGSModel, find_generated_vars, evaluate!!
using JuliaBUGS.AbstractPPL
using JuliaBUGS.BUGSPrimitives
using JuliaBUGS.LogDensityProblems
Expand Down Expand Up @@ -89,7 +89,7 @@ function JuliaBUGS.gen_chains(
param_vals = []
generated_quantities = []
for i in axes(samples)[1]
evaluation_env = first(evaluate!!(model, LogDensityContext(), samples[i]))
evaluation_env = first(evaluate!!(model, samples[i]))
push!(
param_vals,
[AbstractPPL.get(evaluation_env, param_var) for param_var in param_vars],
Expand Down
4 changes: 2 additions & 2 deletions src/gibbs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ function gibbs_internal end

function gibbs_internal(rng::Random.AbstractRNG, cond_model::BUGSModel, ::MHFromPrior)
transformed_original = JuliaBUGS.getparams(cond_model)
values, logp = evaluate!!(cond_model, LogDensityContext(), transformed_original)
values_proposed, logp_proposed = evaluate!!(cond_model, SamplingContext())
values, logp = evaluate!!(cond_model, transformed_original)
values_proposed, logp_proposed = evaluate!!(rng, cond_model)

if logp_proposed - logp > log(rand(rng))
values = values_proposed
Expand Down
2 changes: 1 addition & 1 deletion src/logdensityproblems.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
function LogDensityProblems.logdensity(model::AbstractBUGSModel, x::AbstractArray)
_, logp = evaluate!!(model, LogDensityContext(), x)
_, logp = evaluate!!(model, x)
return logp
end

Expand Down
59 changes: 16 additions & 43 deletions src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ end
Initialize the model with a vector of initial values, the values can be in transformed space if `model.transformed` is set to true.
"""
function initialize!(model::BUGSModel, initial_params::AbstractVector)
evaluation_env, _ = AbstractPPL.evaluate!!(model, LogDensityContext(), initial_params)
evaluation_env, _ = AbstractPPL.evaluate!!(model, initial_params)
return BangBang.setproperty!!(model, :evaluation_env, evaluation_env)
end

Expand Down Expand Up @@ -260,18 +260,23 @@ function getparams(model::BUGSModel)
return param_vals
end

function getparams_as_ordereddict(model::BUGSModel)
d = OrderedDict{VarName,Any}()
"""
getparams(T::Type{<:AbstractDict}, model::BUGSModel)

Extract the parameter values from the model into a dictionary of type T.
If model.transformed is true, returns parameters in transformed space.
"""
function getparams(T::Type{<:AbstractDict}, model::BUGSModel)
d = T()
for v in model.parameters
value = AbstractPPL.get(model.evaluation_env, v)
if !model.transformed
d[v] = AbstractPPL.get(model.evaluation_env, v)
d[v] = value
else
(; node_function, node_args, loop_vars) = model.g[v]
args = prepare_arg_values(Val(node_args), model.evaluation_env, loop_vars)
dist = node_function(; args...)
d[v] = Bijectors.transform(
Bijectors.bijector(dist), AbstractPPL.get(model.evaluation_env, v)
)
d[v] = Bijectors.transform(Bijectors.bijector(dist), value)
end
end
return d
Expand Down Expand Up @@ -355,7 +360,7 @@ function AbstractPPL.decondition(model::BUGSModel, var_group::Vector{<:VarName})
new_model = BUGSModel(
model, new_parameters, sorted_blanket_with_vars, base_model.evaluation_env
)
evaluate_env, _ = evaluate!!(new_model, DefaultContext())
evaluate_env, _ = evaluate!!(new_model)
return BangBang.setproperty!!(new_model, :evaluation_env, evaluate_env)
end

Expand All @@ -368,33 +373,7 @@ function check_var_group(var_group::Vector{<:VarName}, model::BUGSModel)
)
end

"""
DefaultContext

Use values in varinfo to compute the log joint density.
"""
struct DefaultContext <: AbstractPPL.AbstractContext end

"""
SamplingContext

Do an ancestral sampling of the model parameters. Also accumulate log joint density.
"""
@kwdef struct SamplingContext{T<:Random.AbstractRNG} <: AbstractPPL.AbstractContext
rng::T = Random.default_rng()
end

"""
LogDensityContext

Use the given values to compute the log joint density.
"""
struct LogDensityContext <: AbstractPPL.AbstractContext end

function AbstractPPL.evaluate!!(model::BUGSModel, rng::Random.AbstractRNG)
return evaluate!!(model, SamplingContext(rng))
end
function AbstractPPL.evaluate!!(model::BUGSModel, ctx::SamplingContext)
function AbstractPPL.evaluate!!(rng::Random.AbstractRNG, model::BUGSModel)
(; evaluation_env, g, sorted_nodes) = model
vi = deepcopy(evaluation_env)
logp = 0.0
Expand All @@ -406,7 +385,7 @@ function AbstractPPL.evaluate!!(model::BUGSModel, ctx::SamplingContext)
evaluation_env = setindex!!(evaluation_env, value, vn)
else
dist = node_function(; args...)
value = rand(ctx.rng, dist) # just sample from the prior
value = rand(rng, dist) # just sample from the prior
logp += logpdf(dist, value)
evaluation_env = setindex!!(evaluation_env, value, vn)
end
Expand All @@ -415,11 +394,7 @@ function AbstractPPL.evaluate!!(model::BUGSModel, ctx::SamplingContext)
end

function AbstractPPL.evaluate!!(model::BUGSModel)
return AbstractPPL.evaluate!!(model, DefaultContext())
end
function AbstractPPL.evaluate!!(model::BUGSModel, ::DefaultContext)
(; sorted_nodes, g, evaluation_env) = model
vi = deepcopy(evaluation_env)
logp = 0.0
for vn in sorted_nodes
(; is_stochastic, node_function, node_args, loop_vars) = g[vn]
Expand All @@ -446,9 +421,7 @@ function AbstractPPL.evaluate!!(model::BUGSModel, ::DefaultContext)
return evaluation_env, logp
end

function AbstractPPL.evaluate!!(
model::BUGSModel, ::LogDensityContext, flattened_values::AbstractVector
)
function AbstractPPL.evaluate!!(model::BUGSModel, flattened_values::AbstractVector)
var_lengths = if model.transformed
model.transformed_var_lengths
else
Expand Down
8 changes: 2 additions & 6 deletions test/graphs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,8 @@ mb_logp = begin
end

# order: b, l, c, a
@test mb_logp ≈ evaluate!!(cond_model, JuliaBUGS.LogDensityContext(), [c_value])[2] rtol =
1e-8
@test mb_logp ≈ evaluate!!(cond_model, [c_value])[2] rtol = 1e-8

# test LogDensityContext
@test begin
logp = 0
logp += logpdf(dnorm(1.0, 3.0), 1.0) # a, where f = 1.0
Expand All @@ -80,9 +78,7 @@ end
logp += logpdf(dnorm(2.0, 1.0), 4.0) # d, where g = 2.0
logp += logpdf(dnorm(4.0, 4.0), 5.0) # e, where h = 4.0
logp
end ≈ evaluate!!(
model, JuliaBUGS.LogDensityContext(), [-2.0, 4.0, 3.0, 2.0, 1.0, 4.0, 5.0]
)[2] atol = 1e-8
end ≈ evaluate!!(model, [-2.0, 4.0, 3.0, 2.0, 1.0, 4.0, 5.0])[2] atol = 1e-8

# AuxiliaryNodeInfo
test_model = @bugs begin
Expand Down
2 changes: 1 addition & 1 deletion test/log_density.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# TODO: make this available in JuliaBUGS
function _logjoint(model::JuliaBUGS.BUGSModel)
return JuliaBUGS.evaluate!!(model, JuliaBUGS.DefaultContext())[2]
return JuliaBUGS.evaluate!!(model)[2]
end

@testset "Log density" begin
Expand Down
Loading