Skip to content

Commit

Permalink
Store multivariate proposal in MHProposalState
Browse files Browse the repository at this point in the history
  • Loading branch information
oschulz committed Oct 28, 2024
1 parent ab202a2 commit 4bbe199
Showing 1 changed file with 47 additions and 22 deletions.
69 changes: 47 additions & 22 deletions src/samplers/mcmc/mh_sampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@ Fields:
$(TYPEDFIELDS)
"""
@with_kw struct RandomWalk{Q<:ContinuousUnivariateDistribution} <: MCMCProposal
@with_kw struct RandomWalk{Q<:Union{AbstractMeasure,Distribution{<:Union{Univariate,Multivariate},Continuous}}} <: MCMCProposal
proposaldist::Q = TDist(1.0)
end

export RandomWalk

struct MHProposalState{Q<:ContinuousUnivariateDistribution} <: MCMCProposalState
struct MHProposalState{Q<:BATMeasure} <: MCMCProposalState
proposaldist::Q
end
export MHProposalState
Expand All @@ -44,36 +44,66 @@ bat_default(::Type{TransformedMCMC}, ::Val{:burnin}, proposal::RandomWalk, pretr
MCMCMultiCycleBurnin(nsteps_per_cycle = max(div(nsteps, 10), 2500))


function _get_sample_id(proposal::MHProposalState, id::Int32, cycle::Int32, stepno::Integer, sample_type::Integer)
return MCMCSampleID(id, cycle, stepno, sample_type), MCMCSampleID
end


function _create_proposal_state(
proposal::RandomWalk,
target::BATMeasure,
context::BATContext,
v_init::AbstractVector{<:Real},
rng::AbstractRNG
)
return MHProposalState(proposal.proposaldist)
n_dims = length(v_init)
mv_pdist = batmeasure(_full_random_walk_proposal(proposal.proposaldist, n_dims))
return MHProposalState(mv_pdist)
end


function _get_sample_id(proposal::MHProposalState, id::Int32, cycle::Int32, stepno::Integer, sample_type::Integer)
return MCMCSampleID(id, cycle, stepno, sample_type), MCMCSampleID
function _full_random_walk_proposal(m::AbstractMeasure, n_dims::Integer)
x = testvalue(m)
@argcheck x isa AbstractVector{<:Real} && length(x) == n_dims
return m

Check warning on line 68 in src/samplers/mcmc/mh_sampler.jl

View check run for this annotation

Codecov / codecov/patch

src/samplers/mcmc/mh_sampler.jl#L65-L68

Added lines #L65 - L68 were not covered by tests
end

function _full_random_walk_proposal(m::BATDistMeasure, n_dims::Integer)
d = convert(Distribution, m)
return batmeasure(_full_random_walk_proposal(d, n_dims))

Check warning on line 73 in src/samplers/mcmc/mh_sampler.jl

View check run for this annotation

Codecov / codecov/patch

src/samplers/mcmc/mh_sampler.jl#L71-L73

Added lines #L71 - L73 were not covered by tests
end

function _full_random_walk_proposal(d::Distribution{Multivariate,Continuous}, n_dims::Integer)
@assert false
@argcheck length(d) == n_dims
return d

Check warning on line 79 in src/samplers/mcmc/mh_sampler.jl

View check run for this annotation

Codecov / codecov/patch

src/samplers/mcmc/mh_sampler.jl#L76-L79

Added lines #L76 - L79 were not covered by tests
end

# Theoretical optimally proposal scale for random walk with gaussian proposal, according to
# [Gelman et al., Ann. Appl. Probab. 7 (1) 110 - 120, 1997](https://doi.org/10.1214/aoap/1034625254)
_optimal_proposal_scale(d::ContinuousUnivariateDistribution, n_dims::Integer) = 2.38 / sqrt(n_dims) / sqrt(var(d))
function _full_random_walk_proposal(d::Normal, n_dims::Integer)

Check warning on line 82 in src/samplers/mcmc/mh_sampler.jl

View check run for this annotation

Codecov / codecov/patch

src/samplers/mcmc/mh_sampler.jl#L82

Added line #L82 was not covered by tests
# Theoretical optimally proposal scale for random walk with gaussian proposal, according to
# [Gelman et al., Ann. Appl. Probab. 7 (1) 110 - 120, 1997](https://doi.org/10.1214/aoap/1034625254):
proposal_scale = 2.38 / sqrt(n_dims)

Check warning on line 85 in src/samplers/mcmc/mh_sampler.jl

View check run for this annotation

Codecov / codecov/patch

src/samplers/mcmc/mh_sampler.jl#L85

Added line #L85 was not covered by tests

# Determined experimentally for TDist
const _tdist_corr_exp = [0.5, 0.2, 0.14, 0.085, 0.06, 0.045, 0.035, 0.02, 0.015, 0.015]
function _optimal_proposal_scale(d::TDist, n_dims::Integer)
ν_int = round(Int, d.ν)
k = ν_int > 10 ? zero(eltype(_tdist_corr_exp)) : _tdist_corr_exp[ν_int]
2.38 / sqrt(n_dims) / n_dims^k
@argcheck mean(d) 0
σ² = var(d)
Σ = ScalMat(n_dims, proposal_scale^2 * σ²)
return MvNormal(Σ)

Check warning on line 90 in src/samplers/mcmc/mh_sampler.jl

View check run for this annotation

Codecov / codecov/patch

src/samplers/mcmc/mh_sampler.jl#L87-L90

Added lines #L87 - L90 were not covered by tests
end

function _full_random_walk_proposal(d::TDist, n_dims::Integer)
# Theoretically optimal proposal scale for gaussian seems to work quite well for
# t-distribution proposals with any degrees of freedom as well:
proposal_scale = 2.38 / sqrt(n_dims)

ν = dof(d)
Σ = ScalMat(n_dims, proposal_scale^2)
return Distributions.IsoTDist(ν, Σ)
end


const MHChainState = MCMCChainState{<:BATMeasure, <:RNGPartition, <:Function, <:MHProposalState}


function mcmc_propose!!(mc_state::MHChainState)
@unpack target, proposal, f_transform, context = mc_state
rng = get_rng(context)
Expand All @@ -85,11 +115,9 @@ function mcmc_propose!!(mc_state::MHChainState)

z_current, logd_z_current = sample_z_current.v, sample_z_current.logd
T = eltype(z_current)
n_dims = size(z_current, 1)

proposal_scale = T(_optimal_proposal_scale(pdist, n_dims))

z_proposed = z_current + proposal_scale .* T.(rand(rng, pdist, n_dims)) #TODO: check if proposal is symmetric? otherwise need additional factor?
# ToDo: Use gen-context:
z_proposed = z_current + T.(rand(rng, pdist))
x_proposed, ladj = with_logabsdet_jacobian(f_transform, z_proposed)
logd_x_proposed = BAT.checked_logdensityof(target, x_proposed)
logd_z_proposed = logd_x_proposed + ladj
Expand All @@ -99,12 +127,9 @@ function mcmc_propose!!(mc_state::MHChainState)
mc_state.samples[proposed_x_idx] = DensitySample(x_proposed, logd_x_proposed, 0, _get_sample_id(proposal, mc_state.info.id, mc_state.info.cycle, mc_state.stepno, PROPOSED_SAMPLE)[1], nothing)
mc_state.sample_z[2] = DensitySample(z_proposed, logd_z_proposed, 0, _get_sample_id(proposal, mc_state.info.id, mc_state.info.cycle, mc_state.stepno, PROPOSED_SAMPLE)[1], nothing)

# TODO: MD, should we check for symmetriy of proposal distribution?
# TODO: check if proposal is symmetric - otherwise need Hastings correction:
p_accept = clamp(exp(logd_z_proposed - logd_z_current), 0, 1)


@assert p_accept >= 0

accepted = rand(rng) <= p_accept

return mc_state, accepted, p_accept
Expand Down

0 comments on commit 4bbe199

Please sign in to comment.