9 changes: 9 additions & 0 deletions HISTORY.md
@@ -1,3 +1,12 @@
# 0.41.0

The HMC and NUTS samplers no longer take a single extra step before starting the chain.
This means that if you do not discard any samples at the start, the first sample will be the initial parameters (which may be user-provided).

Note that if the initial sample is included, the corresponding sampler statistics will be `missing`.
Due to a technical limitation of MCMCChains.jl, this causes indexing into the resulting chain to return `Union{Float64, Missing}` or similar.
If you want the old behaviour, you can discard the first sample (e.g. using `discard_initial=1`).
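
For illustration, a minimal sketch of the new behaviour and the opt-out (the model, initial value, and sample count here are made up for this example):

```julia
using Turing

@model function demo()
    x ~ Normal()
end

# New behaviour: with adaptation samples kept, the first stored sample is the
# initial value itself, and its sampler statistics are `missing`.
chn = sample(demo(), NUTS(), 100; discard_adapt=false, initial_params=(x=0.5,))
chn[:x][1]  # == 0.5

# Old behaviour: discard the first sample.
chn = sample(demo(), NUTS(), 100; discard_initial=1)
```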

# 0.40.3

This patch makes the `resume_from` keyword argument work correctly when sampling multiple chains.
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,6 +1,6 @@
name = "Turing"
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
version = "0.40.3"
version = "0.41.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
26 changes: 3 additions & 23 deletions src/mcmc/hmc.jl
@@ -216,32 +216,12 @@ function DynamicPPL.initialstep(
else
ϵ = spl.alg.ϵ
end

# Generate a kernel.
# Generate a kernel and adaptor.
kernel = make_ahmc_kernel(spl.alg, ϵ)

# Create initial transition and state.
# Already perform one step since otherwise we don't get any statistics.
t = AHMC.transition(rng, hamiltonian, kernel, z)

# Adaptation
adaptor = AHMCAdaptor(spl.alg, hamiltonian.metric; ϵ=ϵ)
if spl.alg isa AdaptiveHamiltonian
hamiltonian, kernel, _ = AHMC.adapt!(
hamiltonian, kernel, adaptor, 1, nadapts, t.z.θ, t.stat.acceptance_rate
)
end

# Update VarInfo parameters based on acceptance
new_params = if t.stat.is_accept
t.z.θ
else
theta
end
vi = DynamicPPL.unflatten(vi, new_params)

transition = Transition(model, vi, t)
state = HMCState(vi, 1, kernel, hamiltonian, t.z, adaptor)
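# The initial sample carries no sampler statistics, hence the empty NamedTuple.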
transition = Transition(model, vi, NamedTuple())
state = HMCState(vi, 1, kernel, hamiltonian, z, adaptor)

return transition, state
end
10 changes: 6 additions & 4 deletions test/mcmc/Inference.jl
@@ -41,7 +41,9 @@ using Turing
Random.seed!(5)
chain2 = sample(model, sampler, MCMCThreads(), 10, 4)

@test chain1.value == chain2.value
# For HMC, the first step does not have stats, so we need to use isequal to
# avoid comparing `missing`s
@test isequal(chain1.value, chain2.value)
end

# Should also be stable with an explicit RNG
@@ -54,7 +56,7 @@ using Turing
Random.seed!(rng, local_seed)
chain2 = sample(rng, model, sampler, MCMCThreads(), 10, 4)

@test chain1.value == chain2.value
@test isequal(chain1.value, chain2.value)
end
end

@@ -608,8 +610,8 @@ using Turing

@testset "names_values" begin
ks, xs = Turing.Inference.names_values([(a=1,), (b=2,), (a=3, b=4)])
@test all(xs[:, 1] .=== [1, missing, 3])
@test all(xs[:, 2] .=== [missing, 2, 4])
@test isequal(xs[:, 1], [1, missing, 3])
@test isequal(xs[:, 2], [missing, 2, 4])
end

@testset "check model" begin
22 changes: 22 additions & 0 deletions test/mcmc/hmc.jl
@@ -171,6 +171,28 @@ using Turing
@test Array(res1) == Array(res2) == Array(res3)
end

@testset "initial params are respected" begin
@model demo_norm() = x ~ Beta(2, 2)
init_x = 0.5
@testset "$spl_name" for (spl_name, spl) in
(("HMC", HMC(0.1, 10)), ("NUTS", NUTS()))
chain = sample(
demo_norm(), spl, 5; discard_adapt=false, initial_params=(x=init_x,)
)
@test chain[:x][1] == init_x
chain = sample(
demo_norm(),
spl,
MCMCThreads(),
5,
5;
discard_adapt=false,
initial_params=(fill((x=init_x,), 5)),
)
@test all(chain[:x][1, :] .== init_x)
end
end

@testset "warning for difficult init params" begin
attempt = 0
@model function demo_warn_initial_params()
3 changes: 2 additions & 1 deletion test/mcmc/repeat_sampler.jl
@@ -35,7 +35,8 @@ using Turing
num_chains;
chain_type=Chains,
)
@test chn1.value == chn2.value
# isequal to avoid comparing `missing`s in chain stats
@test isequal(chn1.value, chn2.value)
end
end

4 changes: 1 addition & 3 deletions test/stdlib/distributions.jl
@@ -51,8 +51,6 @@ using Turing
end

@testset "single distribution correctness" begin
rng = StableRNG(1)

n_samples = 10_000
mean_tol = 0.1
var_atol = 1.0
@@ -132,7 +130,7 @@

@model m() = x ~ dist

chn = sample(rng, m(), HMC(0.05, 20), n_samples)
chn = sample(StableRNG(468), m(), HMC(0.05, 20), n_samples)

# Numerical tests.
check_dist_numerical(