Skip to content

Commit

Permalink
use full matrix instead of lower cholesky in AdaptiveMHTuner
Browse files Browse the repository at this point in the history
  • Loading branch information
Cornelius-G committed Jul 3, 2023
1 parent f640fbc commit d0ce010
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 8 deletions.
15 changes: 12 additions & 3 deletions examples/dev-internal/transformed_example.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,20 @@ rng = Philox4x()
posterior = BAT.example_posterior()

my_result = @time BAT.bat_sample_impl(rng, posterior, TransformedMCMCSampling(pre_transform=PriorToGaussian(), nchains=4, nsteps=4*100000))
my_samples = my_result.result

mh_result = @time BAT.bat_sample_impl(rng, posterior, TransformedMCMCSampling(tuning_alg=TransformedAdaptiveMHTuning(), pre_transform=PriorToGaussian(), nchains=4, nsteps=4*100000))

(;chain, tuner) = BAT.g_state


density_notrafo = convert(BAT.AbstractMeasureOrDensity, posterior)
density, trafo = BAT.transform_and_unshape(PriorToGaussian(), density_notrafo)

c = BAT._approx_cov(density)
f = BAT.CustomTransform(Mul(c))

my_result = @time BAT.bat_sample_impl(rng, posterior, TransformedMCMCSampling(pre_transform=PriorToGaussian(), tuning_alg=TransformedAdaptiveMHTuning(), nchains=4, nsteps=4*100000, adaptive_transform=f))

my_samples = my_result.result



using Plots
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ function tuning_update!(tuner::TransformedProposalCovTuner, chain::MCMCIterator,
transform = chain.f_transform


#TODO AC: check with Oli
#TODO AC: rename S_L to S, check with Oli

This comment has been minimized.

Copy link
@oschulz

oschulz Jul 3, 2023

Member

We do not want to use the full matrix. It's numerically more expensive (especially in cases where we have to invert it for a reverse transform later). Log-abs-det-jacobian calculation is also very costly for a full matrix. We should definitely only use triangular matrices (or diagonal, e.g. for high-dim HMC later).

S_L = transform.A
Σ_old = S_L

Expand Down Expand Up @@ -113,8 +113,8 @@ function tuning_update!(tuner::TransformedProposalCovTuner, chain::MCMCIterator,

Σ_new = new_Σ_unscal * tuner.scale
#TODO AC: check
S = cholesky(Positive, Σ_new)
chain.f_transform = Mul(S.L)
#S = cholesky(Positive, Σ_new)
chain.f_transform = Mul(Σ_new)
tuner.iteration += 1

nothing
Expand All @@ -133,7 +133,7 @@ tuning_callback(::TransformedProposalCovTuner) = nop_func
function tune_mcmc_transform!!(
rng::AbstractRNG,
tuner::TransformedProposalCovTuner,
transform::Mul{<:LowerTriangular}, #AffineMaps.AbstractAffineMap,#{<:typeof(*), <:LowerTriangular{<:Real}},
transform::Any, #AffineMaps.AbstractAffineMap,#{<:typeof(*), <:LowerTriangular{<:Real}},
p_accept::Real,
z_proposed::Vector{<:Float64}, #TODO: use DensitySamples instead
z_current::Vector{<:Float64},
Expand Down
2 changes: 1 addition & 1 deletion src/transforms/adaptive_transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ function init_adaptive_transform(
adaptive_transform::CustomTransform,
density
)
return adaptive_transform
return adaptive_transform.f
end


Expand Down

0 comments on commit d0ce010

Please sign in to comment.