Skip to content

Commit

Permalink
fix test errors
Browse files Browse the repository at this point in the history
  • Loading branch information
sunxd3 committed Oct 31, 2024
1 parent 58162a3 commit e902811
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 10 deletions.
10 changes: 7 additions & 3 deletions src/AdvancedMH.jl
Original file line number Diff line number Diff line change
Expand Up @@ -146,9 +146,13 @@ function AbstractMCMC.getparams(t::Transition)
return t.params
end

function AbstractMCMC.setparams!!(model::AbstractMCMC.LogDensityModel, t::Transition, params)
t = BangBang.setproperty!!(t, :params, params)
return BangBang.setproperty!!(t, :lp, LogDensityProblems.logdensity(model.logdensity, params))
# TODO (sunxd): remove `DensityModel` in favor of `AbstractMCMC.LogDensityModel`
function AbstractMCMC.setparams!!(model::DensityModelOrLogDensityModel, t::Transition, params)
return Transition(
params,
logdensity(model, params),
t.accepted
)
end

# Include inference methods.
Expand Down
8 changes: 5 additions & 3 deletions src/MALA.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,12 @@ function AbstractMCMC.getparams(t::GradientTransition)
return t.params
end

function AbstractMCMC.setparams!!(model::AbstractMCMC.LogDensityModel, t::GradientTransition, params)
return AdvancedMH.GradientTransition(
function AbstractMCMC.setparams!!(model::DensityModelOrLogDensityModel, t::GradientTransition, params)
lp, gradient = logdensity_and_gradient(model, params)
return GradientTransition(
params,
AdvancedMH.logdensity_and_gradient(model.logdensity, params)...,
lp,
gradient,
t.accepted
)
end
Expand Down
16 changes: 12 additions & 4 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,18 @@ include("util.jl")
t1, _ = AbstractMCMC.step(Random.default_rng(), model, StaticMH([Normal(0, 1), Normal(0, 1)]))
t2, _ = AbstractMCMC.step(Random.default_rng(), model, MALA(x -> MvNormal(x, I)); initial_params=ones(2))
for t in [t1, t2]
@test AbstractMCMC.getparams(t) == t.params
@test AbstractMCMC.setparams!!(model, t, AbstractMCMC.getparams(t)) == t
t_replaced = AbstractMCMC.setparams!!(model, t, (μ=1.0, σ=2.0))
@test t_replaced.params ===1.0, σ=2.0)
@test AbstractMCMC.getparams(model, t) == t.params

new_transition = AbstractMCMC.setparams!!(model, t, AbstractMCMC.getparams(model, t))
@test new_transition.lp == t.lp
@test new_transition.accepted == t.accepted
@test new_transition.params == t.params
if hasfield(typeof(t), :gradient)
@test new_transition.gradient == t.gradient
end

t_replaced = AbstractMCMC.setparams!!(model, t, [1.0, 2.0])
@test t_replaced.params == [1.0, 2.0]
end
end

Expand Down

0 comments on commit e902811

Please sign in to comment.