Skip to content

Commit

Permalink
Merge pull request #1388 from pints-team/mala
Browse files Browse the repository at this point in the history
Add MALA methods
  • Loading branch information
ben18785 authored Aug 16, 2021
2 parents 24603ab + 32033d5 commit c3d98dc
Showing 1 changed file with 137 additions and 0 deletions.
137 changes: 137 additions & 0 deletions pints/functionaltests/mala_mcmc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
#!/usr/bin/env python3
#
# This file is part of PINTS (https://github.com/pints-team/pints/) which is
# released under the BSD 3-clause license. See accompanying LICENSE.md for
# copyright notice and full license details.
#

from __future__ import division

import pints

from ._problems import (RunMcmcMethodOnTwoDimGaussian,
RunMcmcMethodOnBanana,
RunMcmcMethodOnHighDimensionalGaussian,
RunMcmcMethodOnCorrelatedGaussian,
RunMcmcMethodOnAnnulus,
RunMcmcMethodOnMultimodalGaussian,
RunMcmcMethodOnCone)


def test_mala_mcmc_on_two_dim_gaussian(n_iterations=None):
if n_iterations is None:
n_iterations = 1000
problem = RunMcmcMethodOnTwoDimGaussian(
method=pints.MALAMCMC,
n_chains=4,
n_iterations=n_iterations,
n_warmup=200,
method_hyper_parameters=[[1.0, 1.0]]
)

return {
'kld': problem.estimate_kld(),
'mean-ess': problem.estimate_mean_ess()
}


def test_mala_mcmc_on_banana(n_iterations=None):
if n_iterations is None:
n_iterations = 2000
problem = RunMcmcMethodOnBanana(
method=pints.MALAMCMC,
n_chains=4,
n_iterations=n_iterations,
n_warmup=500,
method_hyper_parameters=[[0.8] * 2]
)

return {
'kld': problem.estimate_kld(),
'mean-ess': problem.estimate_mean_ess()
}


def test_mala_mcmc_on_high_dim_gaussian(n_iterations=None):
if n_iterations is None:
n_iterations = 2000
problem = RunMcmcMethodOnHighDimensionalGaussian(
method=pints.MALAMCMC,
n_chains=4,
n_iterations=n_iterations,
n_warmup=500,
method_hyper_parameters=[[1.2] * 20]
)

return {
'kld': problem.estimate_kld(),
'mean-ess': problem.estimate_mean_ess()
}


def test_mala_mcmc_on_correlated_gaussian(n_iterations=None):
if n_iterations is None:
n_iterations = 2000
problem = RunMcmcMethodOnCorrelatedGaussian(
method=pints.MALAMCMC,
n_chains=4,
n_iterations=n_iterations,
n_warmup=500,
method_hyper_parameters=[[1.0] * 6]
)

return {
'kld': problem.estimate_kld(),
'mean-ess': problem.estimate_mean_ess()
}


def test_mala_mcmc_on_annulus(n_iterations=None):
if n_iterations is None:
n_iterations = 2000
problem = RunMcmcMethodOnAnnulus(
method=pints.MALAMCMC,
n_chains=4,
n_iterations=n_iterations,
n_warmup=500,
method_hyper_parameters=[[1.2] * 2]
)

return {
'distance': problem.estimate_distance(),
'mean-ess': problem.estimate_mean_ess()
}


def test_mala_mcmc_on_multimodal_gaussian(n_iterations=None):
if n_iterations is None:
n_iterations = 2000
problem = RunMcmcMethodOnMultimodalGaussian(
method=pints.MALAMCMC,
n_chains=4,
n_iterations=n_iterations,
n_warmup=500,
method_hyper_parameters=[[2.0] * 2]
)

return {
'kld': problem.estimate_kld(),
'mean-ess': problem.estimate_mean_ess()
}


def test_mala_mcmc_on_cone(n_iterations=None):
if n_iterations is None:
n_iterations = 2000
problem = RunMcmcMethodOnCone(
method=pints.MALAMCMC,
n_chains=4,
n_iterations=n_iterations,
n_warmup=500,
method_hyper_parameters=[[1.0, 1.0]]
)

return {
'distance': problem.estimate_distance(),
'mean-ess': problem.estimate_mean_ess()
}

0 comments on commit c3d98dc

Please sign in to comment.