Fix for issue #95 (#96)

Merged (6 commits) · Jun 5, 2024

2 changes: 1 addition & 1 deletion Project.toml
@@ -1,6 +1,6 @@
name = "AdvancedMH"
uuid = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"
version = "0.8.1"
version = "0.8.2"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
9 changes: 3 additions & 6 deletions src/MALA.jl
@@ -62,8 +62,8 @@ function AbstractMCMC.step(

# Compute the log ratio of proposal densities.
logratio_proposal_density = q(
proposal(-gradient_logdensity_candidate), state, candidate
) - q(proposal(-gradient_logdensity_state), candidate, state)
proposal(gradient_logdensity_candidate), state, candidate
) - q(proposal(gradient_logdensity_state), candidate, state)

# Compute the log acceptance probability.
logα = logdensity_candidate - logdensity_state + logratio_proposal_density
@@ -72,10 +72,7 @@ function AbstractMCMC.step(
transition = if -Random.randexp(rng) < logα
GradientTransition(candidate, logdensity_candidate, gradient_logdensity_candidate, true)
else
candidate = transition_prev.params
lp = transition_prev.lp
gradient = transition_prev.gradient
GradientTransition(candidate, lp, gradient, false)
GradientTransition(transition_prev.params, transition_prev.lp, transition_prev.gradient, false)
end

return transition, transition
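
For context on the sign change in this hunk: MALA drifts its proposals along the gradient of the target, so the proposal-density ratio must be evaluated with the gradient itself rather than its negation. A minimal sketch of that convention in plain Distributions.jl code (helper names are mine, not the package internals):

using Distributions, LinearAlgebra

# MALA with step size σ² proposes x′ ~ N(x + (σ²/2)·∇logπ(x), σ²·I).
# Log-density of moving to `to` from `from`, given the gradient at `from`:
logq(to, from, grad_from, σ²) = logpdf(MvNormal(from .+ (σ² / 2) .* grad_from, σ² * I), to)

# Acceptance correction: logα = logπ(x′) - logπ(x) + logq(x, x′, ∇logπ(x′), σ²) - logq(x′, x, ∇logπ(x), σ²).
# Both q terms are built from the un-negated gradients, which is what the fix restores.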
6 changes: 4 additions & 2 deletions test/emcee.jl
@@ -19,7 +19,7 @@
sampler = Ensemble(1_000, StretchProposal([InverseGamma(2, 3), Normal(0, 1)]))

chain = sample(model, sampler, 1_000;
param_names = ["s", "m"], chain_type = Chains)
param_names = ["s", "m"], chain_type = Chains, progress=false)
@test chain isa Chains
@test range(chain) == 1:1_000
@test mean(chain["s"]) ≈ 49/24 atol=0.1
@@ -33,6 +33,7 @@
chain_type = Chains,
discard_initial=25,
thinning=4,
progress=false
)
@test chain2 isa Chains
@test range(chain2) == range(26; step=4, length=1_000)
@@ -59,7 +60,7 @@
Random.seed!(100)
sampler = Ensemble(1_000, StretchProposal(MvNormal(zeros(2), I)))
chain = sample(model, sampler, 1_000;
param_names = ["logs", "m"], chain_type = Chains)
param_names = ["logs", "m"], chain_type = Chains, progress=false)
@test chain isa Chains
@test range(chain) == 1:1_000
@test mean(exp, chain["logs"]) ≈ 49/24 atol=0.1
@@ -73,6 +74,7 @@
chain_type = Chains,
discard_initial=25,
thinning=4,
progress=false
)
@test chain2 isa Chains
@test range(chain2) == range(26; step=4, length=1_000)
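
The only functional change in this file is passing progress=false to sample; the keyword is forwarded to AbstractMCMC and turns off the progress meter so test logs stay quiet. A small sketch, assuming the model and sampler defined in the tests above:

# Per call: disable the progress meter for one run.
chain = sample(model, sampler, 1_000; chain_type = Chains, progress = false)

# Package-wide alternative provided by AbstractMCMC:
AbstractMCMC.setprogress!(false)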
122 changes: 86 additions & 36 deletions test/runtests.jl
@@ -40,9 +40,9 @@ include("util.jl")
spl3 = StaticMH(2)

# Sample from the posterior.
chain1 = sample(model, spl1, 100000; chain_type=StructArray, param_names=["μ", "σ"])
chain2 = sample(model, spl2, 100000; chain_type=StructArray, param_names=["μ", "σ"])
chain3 = sample(model, spl3, 100000; chain_type=StructArray, param_names=["μ", "σ"])
chain1 = sample(model, spl1, 100000; chain_type=StructArray, param_names=["μ", "σ"], progress=false)
chain2 = sample(model, spl2, 100000; chain_type=StructArray, param_names=["μ", "σ"], progress=false)
chain3 = sample(model, spl3, 100000; chain_type=StructArray, param_names=["μ", "σ"], progress=false)

# chn_mean ≈ dist_mean atol=atol_v
@test mean(chain1.μ) ≈ 0.0 atol=0.1
@@ -60,9 +60,9 @@ include("util.jl")
spl3 = RWMH(2)

# Sample from the posterior.
chain1 = sample(model, spl1, 100000; chain_type=StructArray, param_names=["μ", "σ"])
chain2 = sample(model, spl2, 100000; chain_type=StructArray, param_names=["μ", "σ"])
chain3 = sample(model, spl3, 200000; chain_type=StructArray, param_names=["μ", "σ"])
chain1 = sample(model, spl1, 100000; chain_type=StructArray, param_names=["μ", "σ"], progress=false)
chain2 = sample(model, spl2, 100000; chain_type=StructArray, param_names=["μ", "σ"], progress=false)
chain3 = sample(model, spl3, 200000; chain_type=StructArray, param_names=["μ", "σ"], progress=false)

# chn_mean ≈ dist_mean atol=atol_v
@test mean(chain1.μ) ≈ 0.0 atol=0.1
@@ -77,13 +77,13 @@ include("util.jl")
spl1 = StaticMH([Normal(0,1), Normal(0, 1)])

chain1 = sample(model, spl1, MCMCDistributed(), 10000, 4;
param_names=["μ", "σ"], chain_type=Chains)
param_names=["μ", "σ"], chain_type=Chains, progress=false)
@test mean(chain1["μ"]) ≈ 0.0 atol=0.1
@test mean(chain1["σ"]) ≈ 1.0 atol=0.1

if VERSION >= v"1.3"
chain2 = sample(model, spl1, MCMCThreads(), 10000, 4;
param_names=["μ", "σ"], chain_type=Chains)
param_names=["μ", "σ"], chain_type=Chains, progress=false)
@test mean(chain2["μ"]) ≈ 0.0 atol=0.1
@test mean(chain2["σ"]) ≈ 1.0 atol=0.1
end
@@ -93,7 +93,7 @@ include("util.jl")
# Array of parameters
chain1 = sample(
model, StaticMH([Normal(0,1), Normal(0, 1)]), 10_000;
param_names=["μ", "σ"], chain_type=Chains
param_names=["μ", "σ"], chain_type=Chains, progress=false
)
@test chain1 isa Chains
@test range(chain1) == 1:10_000
@@ -103,6 +103,7 @@ include("util.jl")
chain1b = sample(
model, StaticMH([Normal(0,1), Normal(0, 1)]), 10_000;
param_names=["μ", "σ"], chain_type=Chains, discard_initial=25, thinning=4,
progress=false
)
@test chain1b isa Chains
@test range(chain1b) == range(26; step=4, length=10_000)
@@ -115,7 +116,8 @@ include("util.jl")
MetropolisHastings(
(μ = StaticProposal(Normal(0,1)), σ = StaticProposal(Normal(0, 1)))
), 10_000;
chain_type=Chains
chain_type=Chains,
progress=false
)
@test chain2 isa Chains
@test range(chain2) == 1:10_000
@@ -128,6 +130,7 @@ include("util.jl")
(μ = StaticProposal(Normal(0,1)), σ = StaticProposal(Normal(0, 1)))
), 10_000;
chain_type=Chains, discard_initial=25, thinning=4,
progress=false
)
@test chain2b isa Chains
@test range(chain2b) == range(26; step=4, length=10_000)
@@ -137,7 +140,8 @@ include("util.jl")
# Scalar parameter
chain3 = sample(
DensityModel(x -> loglikelihood(Normal(x, 1), data)),
StaticMH(Normal(0, 1)), 10_000; param_names=["μ"], chain_type=Chains
StaticMH(Normal(0, 1)), 10_000; param_names=["μ"], chain_type=Chains,
progress=false
)
@test chain3 isa Chains
@test range(chain3) == 1:10_000
@@ -147,6 +151,7 @@ include("util.jl")
DensityModel(x -> loglikelihood(Normal(x, 1), data)),
StaticMH(Normal(0, 1)), 10_000;
param_names=["μ"], chain_type=Chains, discard_initial=25, thinning=4,
progress=false
)
@test chain3b isa Chains
@test range(chain3b) == range(26; step=4, length=10_000)
@@ -164,10 +169,10 @@ include("util.jl")
p3 = (a=StaticProposal(Normal(0,1)), b=StaticProposal(InverseGamma(2,3)))
p4 = StaticProposal((x=1.0) -> Normal(x, 1))

c1 = sample(m1, MetropolisHastings(p1), 100; chain_type=Vector{NamedTuple})
c2 = sample(m2, MetropolisHastings(p2), 100; chain_type=Vector{NamedTuple})
c3 = sample(m3, MetropolisHastings(p3), 100; chain_type=Vector{NamedTuple})
c4 = sample(m4, MetropolisHastings(p4), 100; chain_type=Vector{NamedTuple})
c1 = sample(m1, MetropolisHastings(p1), 100; chain_type=Vector{NamedTuple}, progress=false)
c2 = sample(m2, MetropolisHastings(p2), 100; chain_type=Vector{NamedTuple}, progress=false)
c3 = sample(m3, MetropolisHastings(p3), 100; chain_type=Vector{NamedTuple}, progress=false)
c4 = sample(m4, MetropolisHastings(p4), 100; chain_type=Vector{NamedTuple}, progress=false)

@test keys(c1[1]) == (:param_1, :lp)
@test keys(c2[1]) == (:param_1, :param_2, :lp)
@@ -182,7 +187,7 @@ include("util.jl")
val = [0.4, 1.2]

# Sample from the posterior.
chain1 = sample(model, spl1, 10, initial_params = val)
chain1 = sample(model, spl1, 10, initial_params = val, progress=false)

@test chain1[1].params == val
end
@@ -199,12 +204,12 @@ include("util.jl")
p1 = RandomWalkProposal(CustomNormal())
@test p1 isa RandomWalkProposal{false}
@test_throws MethodError AdvancedMH.logratio_proposal_density(p1, randn(), randn())
@test_throws MethodError sample(m1, MetropolisHastings(p1), 10)
@test_throws MethodError sample(m1, MetropolisHastings(p1), 10, progress=false)

p1 = StaticProposal(x -> CustomNormal(x))
@test p1 isa StaticProposal{false}
@test_throws MethodError AdvancedMH.logratio_proposal_density(p1, randn(), randn())
@test_throws MethodError sample(m1, MetropolisHastings(p1), 10)
@test_throws MethodError sample(m1, MetropolisHastings(p1), 10, progress=false)

# If the proposal is declared to be symmetric, the log ratio of the proposal
# density is not evaluated.
@@ -227,7 +232,8 @@ include("util.jl")
))
chain1 = sample(
m1, MetropolisHastings(p2), 100000;
chain_type=StructArray, param_names=["x"]
chain_type=StructArray, param_names=["x"],
progress=false
)
@test mean(chain1.x) ≈ mean(d1) atol=0.05
@test std(chain1.x) ≈ std(d1) atol=0.05
@@ -260,29 +266,73 @@
end

@testset "MALA" begin
# Set up the sampler.
σ² = 0.01
spl1 = MALA(x -> MvNormal((σ² / 2) .* x, σ² * I))
@testset "basic" begin
# Set up the sampler.
σ² = 1e-3
spl1 = MALA(x -> MvNormal((σ² / 2) .* x, σ² * I))

# Sample from the posterior with initial parameters.
chain1 = sample(model, spl1, 100000; initial_params=ones(2), chain_type=StructArray, param_names=["μ", "σ"])
# Sample from the posterior with initial parameters.
chain1 = sample(
model, spl1, 1000;
initial_params=ones(2),
chain_type=StructArray,
param_names=["μ", "σ"],
discard_initial=100,
progress=false
)

@test mean(chain1.μ) ≈ 0.0 atol=0.1
@test mean(chain1.σ) ≈ 1.0 atol=0.1
@test mean(chain1.μ) ≈ 0.0 atol = 0.1
@test mean(chain1.σ) ≈ 1.0 atol = 0.1

@testset "LogDensityProblems interface" begin
admodel = LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), density)
chain2 = sample(
admodel,
spl1,
1000;
initial_params=ones(2),
chain_type=StructArray,
param_names=["μ", "σ"],
discard_initial=100,
progress=false
)

@test mean(chain2.μ) ≈ 0.0 atol = 0.1
@test mean(chain2.σ) ≈ 1.0 atol = 0.1
end
end

@testset "LogDensityProblems interface" begin
admodel = LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), density)
chain2 = sample(
admodel,
spl1,
100000;
@testset "issue #95" begin
struct TheNormalLogDensity{M}
A::M
end

# can do gradient
LogDensityProblems.capabilities(::Type{<:TheNormalLogDensity}) = LogDensityProblems.LogDensityOrder{1}()

LogDensityProblems.dimension(d::TheNormalLogDensity) = size(d.A, 1)
LogDensityProblems.logdensity(d::TheNormalLogDensity, x) = -x' * d.A * x / 2

function LogDensityProblems.logdensity_and_gradient(d::TheNormalLogDensity, x)
return -x' * d.A * x / 2, -d.A * x
end

Σ = [1.5 0.35; 0.35 1.0]
σ² = 0.5
spl = AdvancedMH.MALA(g -> Distributions.MvNormal((σ² / 2) .* g, σ² * I))

chain = sample(
TheNormalLogDensity(inv(Σ)),
spl,
500000;
initial_params=ones(2),
chain_type=StructArray,
param_names=["μ", "σ"]
progress=false
)
data = mapreduce(Base.Fix2(getproperty, :params), hcat, chain)
Σ_est = cov(data, dims=2)

@test mean(chain2.μ) ≈ 0.0 atol=0.1
@test mean(chain2.σ) ≈ 1.0 atol=0.1
@test mean(data, dims=2) ≈ zeros(2) atol = 0.1
@test Σ ≈ Σ_est atol = 2e-1
end
end
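
The issue #95 testset above samples a target that implements the LogDensityProblems gradient interface directly (LogDensityOrder{1}, no AD wrapper): for π = N(0, Σ) with precision matrix A = inv(Σ), logπ(x) = -x'Ax/2 and ∇logπ(x) = -Ax, so the corrected MALA drift pulls proposals toward the mode and the sample covariance should recover Σ. A short sketch of that arithmetic with the test's values (not part of the diff):

using LinearAlgebra

Σ = [1.5 0.35; 0.35 1.0]
A = inv(Σ)                          # precision matrix handed to TheNormalLogDensity
σ² = 0.5
x = ones(2)

grad = -A * x                       # ∇logπ(x), as returned by logdensity_and_gradient
proposal_mean = x + (σ² / 2) * grad # MALA drift: pulled toward the mode at zero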
