Skip to content

Commit

Permalink
Merge branch 'trafo-sampling-changes'
Browse files Browse the repository at this point in the history
  • Loading branch information
oschulz committed Oct 19, 2024
2 parents 0f0563f + 2bfd997 commit 4f061a7
Show file tree
Hide file tree
Showing 46 changed files with 484 additions and 428 deletions.
88 changes: 67 additions & 21 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,40 +1,54 @@
BAT.jl v3.0.0 Release Notes
===========================
BAT.jl Release Notes
====================

New features
------------
BAT.jl v4.0.0
-------------

* Support for [DensityInterface](https://github.com/JuliaMath/DensityInterface.jl): BAT will now accept any object that implements the DensityInterface API (specifically `DensityInterface.densitykind` and `DensityInterface.logdensityof`) as likelihoods. In return, all BAT priors and posteriors support the DensityInterface API as well.
### Breaking changes

* Support for [InverseFunctions](https://github.com/JuliaMath/InverseFunctions.jl) and [ChangesOfVariables](https://github.com/JuliaMath/ChangesOfVariables.jl): Parameter transformations in BAT now implement the DensityInterface API. Any function that supports
Several algorithms have changed their names, but also their role:

* `InverseFunctions.inverse`
* `ChangesOfVariables.with_logabsdet_jacobian`
* `output_shape = f(input_shape::ValueShapes.AbstractValueShape)::AbstractValueShape`
* `MCMCSampling` has become `TransformedMCMC`.

can now be used as a parameter transformation in BAT.
* `MetropolisHastings` has become `RandomWalk`. It's parameters have
changed (no deprecation for the parameter changes). Tuning and
sample weighting scheme selection have moved to `TransformedMCMC`.

* `BATContext`, `get_batcontext` and `set_batcontext`
Partial deprecations are available for the above, a lot of old code should
run more or less unchanged (with deprecation warnings). Also:

* `bat_transform` with enhanced capabilities
* `AdaptiveMHTuning` has become `AdaptiveAffineTuning`, but is now
used as a parameter for `TransformedMCMC` (formerly `MCMCSampling`)
instead of `RandomWalk` (formerly `MetropolisHastings`).

* `distprod`, `distbind` and `lbqintegral` are the new ways to express priors and posteriors in BAT.
* `MCMCNoOpTuning` has become `NoMCMCTransformTuning`.

* `bat_report`
* The parameters of `HamiltonianMC` have changed.

* `BAT.enable_error_log` (experimental)
* `MCMCTuningAlgorithm` has been replaced by `MCMCTransformTuning`.

* `BAT.error_log` (experimental)

* `BridgeSampling` (experimental)
### New features

* `EllipsoidalNestedSampling` (experimental)
* The new `RAMTuning` is now the default (transform) tuning algorithm for
`RandomWalk` (formerly `MetropolisHastings`). It typically results in a much
faster burn-in process than `AdaptiveAffineTuning` (formerly
`AdaptiveMHTuning`, the previous default).

* MCMC Sampling handles parameter scale and correlation adaptivity via
via tunable space transformations instead of tuning covariance matrices
in proposal distributions.

MCMC tuning has been split into proposal tuning (algorithms of type
`MCMCProposalTuning`) and transform turning (algorithms of type
`MCMCTransformTuning`). Proposal tuning has now a much more limited role
and often may be `NoMCMCProposalTuning()` (e.g. for `RandomWalk`).

* `ReactiveNestedSampling` (experimental)

BAT.jl v3.0.0
-------------

Breaking changes
----------------
### Breaking changes

* `AbstractVariateTransform` and the function `ladjof` have been removed, BAT parameter transformations do not need to have a specific supertype any longer (see above).

Expand Down Expand Up @@ -90,3 +104,35 @@ Breaking changes
* Pending: BAT will rely less on ValueShapes in the future. Do not use ValueShapes functionality directly where avoidable. Use `distprod` instead of using `ValueShapes.NamedTupleDist` directly, and favor using `bat_transform` instead of shaping and unshaping data using values shapes directly, if possible.

* Use the new function `bat_report` to generate a sampling output report instead of `show(BAT.SampledDensity(samples))`.


### New features
------------

* Support for [DensityInterface](https://github.com/JuliaMath/DensityInterface.jl): BAT will now accept any object that implements the DensityInterface API (specifically `DensityInterface.densitykind` and `DensityInterface.logdensityof`) as likelihoods. In return, all BAT priors and posteriors support the DensityInterface API as well.

* Support for [InverseFunctions](https://github.com/JuliaMath/InverseFunctions.jl) and [ChangesOfVariables](https://github.com/JuliaMath/ChangesOfVariables.jl): Parameter transformations in BAT now implement the DensityInterface API. Any function that supports

* `InverseFunctions.inverse`
* `ChangesOfVariables.with_logabsdet_jacobian`
* `output_shape = f(input_shape::ValueShapes.AbstractValueShape)::AbstractValueShape`

can now be used as a parameter transformation in BAT.

* `BATContext`, `get_batcontext` and `set_batcontext`

* `bat_transform` with enhanced capabilities

* `distprod`, `distbind` and `lbqintegral` are the new ways to express priors and posteriors in BAT.

* `bat_report`

* `BAT.enable_error_log` (experimental)

* `BAT.error_log` (experimental)

* `BridgeSampling` (experimental)

* `EllipsoidalNestedSampling` (experimental)

* `ReactiveNestedSampling` (experimental)
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ AdvancedHMC = "0.5, 0.6"
AffineMaps = "0.2.3, 0.3"
ArgCheck = "1, 2.0"
ArraysOfArrays = "0.4, 0.5, 0.6"
AutoDiffOperators = "0.2"
AutoDiffOperators = "0.1.8, 0.2"
ChainRulesCore = "0.9.44, 0.10, 1"
ChangesOfVariables = "0.1.1"
Clustering = "0.13, 0.14, 0.15"
Expand Down
13 changes: 7 additions & 6 deletions docs/src/experimental_api.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,14 @@ SobolSampler
truncate_batmeasure
ValueAndThreshold
BAT.MCMCIterator
BAT.MCMCTunerState
BAT.TemperingState
BAT.MCMCProposalState
BAT.MCMCTempering
BAT.MCMCState
BAT.MCMCChainState
BAT.MCMCChainStateInfo
BAT.MCMCIterator
BAT.MCMCProposal
BAT.MCMCProposalState
BAT.MCMCProposalTunerState
BAT.MCMCState
BAT.MCMCTempering
BAT.MCMCTransformTunerState
BAT.TemperingState
```
8 changes: 4 additions & 4 deletions docs/src/list_of_algorithms.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,21 @@ bat_sample(target.prior, IIDSampling(nsamples=10^5))

### Metropolis-Hastings

BAT sampling algorithm type: [`MCMCSampling`](@ref), MCMC algorithm subtype: [`MetropolisHastings`](@ref)
BAT sampling algorithm type: [`TransformedMCMC`](@ref), MCMC algorithm subtype: [`RandomWalk`](@ref)

```julia
bat_sample(target, MCMCSampling(mcalg = MetropolisHastings(), nsteps = 10^5, nchains = 4))
bat_sample(target, TransformedMCMC(mcalg = RandomWalk(), nsteps = 10^5, nchains = 4))
```


### Hamiltonian MC

BAT sampling algorithm type: [`MCMCSampling`](@ref), MCMC algorithm subtype: [`HamiltonianMC`](@ref)
BAT sampling algorithm type: [`TransformedMCMC`](@ref), MCMC algorithm subtype: [`HamiltonianMC`](@ref)

```julia
import AdvancedHMC, ForwardDiff
set_batcontext(ad = ADSelector(ForwardDiff))
bat_sample(target, MCMCSampling(mcalg = HamiltonianMC()))
bat_sample(target, TransformedMCMC(mcalg = HamiltonianMC()))
```
Requires the [AdvancedHMC](https://github.com/TuringLang/AdvancedHMC.jl) Julia package to be loaded explicitly.

Expand Down
30 changes: 16 additions & 14 deletions docs/src/stable_api.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,13 @@ lbqintegral
AbstractMCMCWeightingScheme
AbstractPosteriorMeasure
AbstractTransformTarget
AdaptiveMHTuning
AdaptiveAffineTuning
AssumeConvergence
AutocorLenAlgorithm
BATContext
BATHDF5IO
BATIOAlgorithm
BinningAlgorithm
BrooksGelmanConvergence
CuhreIntegration
DensitySample
Expand All @@ -68,6 +69,8 @@ EffSampleSizeAlgorithm
EffSampleSizeFromAC
EvaluatedMeasure
ExplicitInit
FixedNBins
FreedmanDiaconisBinning
GelmanRubinConvergence
GeyerAutocorLen
HamiltonianMC
Expand All @@ -85,34 +88,33 @@ MCMCBurninAlgorithm
MCMCChainPoolInit
MCMCInitAlgorithm
MCMCMultiCycleBurnin
MCMCNoOpTuning
MCMCSampling
MCMCTuning
MetropolisHastings
MHProposalDistTuning
MCMCProposalTuning
MCMCTransformTuning
ModeAsDefined
NoMCMCProposalTuning
NoMCMCTransformTuning
OptimAlg
OptimizationAlg
OrderedResampling
PosteriorMeasure
PriorSubstitution
PriorToGaussian
PriorToUniform
RAMTuning
RandomWalk
RandResampling
RepetitionWeighting
SampleMedianEstimator
SokalAutocorLen
SuaveIntegration
TransformAlgorithm
VEGASIntegration
BinningAlgorithm
FixedNBins
FreedmanDiaconisBinning
RiceBinning
SampleMedianEstimator
ScottBinning
SokalAutocorLen
SquareRootBinning
SturgesBinning
SuaveIntegration
ToRealVector
TransformAlgorithm
TransformedMCMC
VEGASIntegration
BAT.AbstractMedianEstimator
BAT.AbstractModeEstimator
Expand Down
12 changes: 5 additions & 7 deletions docs/src/tutorial_lit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ posterior = PosteriorMeasure(likelihood, prior)
# Now we can generate a set of MCMC samples via [`bat_sample`](@ref). We'll
# use 4 MCMC chains with 10^5 MC steps in each chain (after tuning/burn-in):

samples = bat_sample(posterior, MCMCSampling(proposal = MetropolisHastings(), nsteps = 10^5, nchains = 4)).result
samples = bat_sample(posterior, TransformedMCMC(proposal = RandomWalk(), nsteps = 10^5, nchains = 4)).result
#md nothing # hide
#nb nothing # hide

Expand Down Expand Up @@ -374,11 +374,9 @@ plot!(-4:0.01:4, x -> fit_function(true_par_values, x), color=4, label = "Truth"
# All option value used in the following are the default values, any or all
# may be omitted.

# We'll sample using the The Metropolis-Hastings MCMC algorithm:
# We'll sample using the random-walk Metropolis-Hastings MCMC algorithm:

mcmcalgo = MetropolisHastings(
weighting = RepetitionWeighting()
)
mcmcalgo = RandomWalk()

# BAT requires a counter-based random number generator (RNG), since it
# partitions the RNG space over the MCMC chains. This way, a single RNG seed
Expand All @@ -393,7 +391,7 @@ context = BATContext(rng = Philox4x())
#md nothing # hide


# By default, `MetropolisHastings()` uses the following options.
# By default, `RandomWalk()` uses the following options.
#
# For Markov chain initialization:

Expand All @@ -412,7 +410,7 @@ convergence = BrooksGelmanConvergence()

samples = bat_sample(
posterior,
MCMCSampling(
TransformedMCMC(
proposal = mcmcalgo,
nchains = 4,
nsteps = 10^5,
Expand Down
4 changes: 2 additions & 2 deletions examples/benchmarks/benchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ function setup_benchmark()
include("run_benchmark_ND.jl")
end

function do_benchmarks(;algorithm=MetropolisHastings(), nsteps=10^6, nchains=8)
function do_benchmarks(;algorithm=RandomWalk(), nsteps=10^6, nchains=8)
#run_1D_benchmark(algorithm=algorithm, nsteps=nsteps, nchains=nchains)
run_2D_benchmark(algorithm=algorithm, nsteps=nsteps, nchains=nchains)
run_ND_benchmark(n_dim=2:2:20,algorithm=MetropolisHastings(), nsteps=nsteps, nchains=4)
run_ND_benchmark(n_dim=2:2:20,algorithm=RandomWalk(), nsteps=nsteps, nchains=4)
run_ks_ahmc_vs_mh(n_dim=20:5:35)
end

Expand Down
2 changes: 1 addition & 1 deletion examples/benchmarks/run_benchmark_1D.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
function run_1D_benchmark(;algorithm=MetropolisHastings(), nsteps=10^5, nchains=8)
function run_1D_benchmark(;algorithm=RandomWalk(), nsteps=10^5, nchains=8)
for i in 1:length(testfunctions_1D)
sample_stats_all = run1D(
collect(keys(testfunctions_1D))[i], #There might be a nicer way but I need the name to save the plots
Expand Down
2 changes: 1 addition & 1 deletion examples/benchmarks/run_benchmark_2D.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
function run_2D_benchmark(;algorithm = MetropolisHastings(),nchains = 8,nsteps = 10^5)
function run_2D_benchmark(;algorithm = RandomWalk(),nchains = 8,nsteps = 10^5)
for i in 1:length(testfunctions_2D)
sample_stats_all = run2D(
collect(keys(testfunctions_2D))[i], #There might be a nicer way but I need the name to save the plots
Expand Down
16 changes: 8 additions & 8 deletions examples/benchmarks/run_benchmark_ND.jl
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ end

function run_ND_benchmark(;
n_dim = 2:2:20,
algorithm = MetropolisHastings(),
algorithm = RandomWalk(),
nchains = 4,
nsteps = 4*10^5,
time_benchmark = true,
Expand Down Expand Up @@ -244,10 +244,10 @@ function run_ND_benchmark(;

mcmc_sample = nothing
tbf = time()
if isa(algorithm,BAT.MetropolisHastings)
if isa(algorithm,BAT.RandomWalk)
mcmc_sample = bat_sample(
dis,
MCMCSampling(
TransformedMCMC(
mcalg = algorithm,
trafo = DoNotTransform(),
nchains = nchains,
Expand All @@ -262,7 +262,7 @@ function run_ND_benchmark(;
elseif isa(algorithm,BAT.HamiltonianMC)
mcmc_sample = bat_sample(
dis,
MCMCSampling(mcalg = algorithm, trafo = DoNotTransform(), nchains = nchains, nsteps = nsteps)
TransformedMCMC(mcalg = algorithm, trafo = DoNotTransform(), nchains = nchains, nsteps = nsteps)
).result
end
taf = time()
Expand All @@ -271,10 +271,10 @@ function run_ND_benchmark(;
time_per_run = Array{Float64}(undef,5)
for i in 1:5
tbf = time()
if isa(algorithm,BAT.MetropolisHastings)
if isa(algorithm,BAT.RandomWalk)
bat_sample(
dis,
MCMCSampling(
TransformedMCMC(
mcalg = algorithm,
trafo = DoNotTransform(),
nchains = nchains,
Expand All @@ -289,7 +289,7 @@ function run_ND_benchmark(;
elseif isa(algorithm,BAT.HamiltonianMC)
bat_sample(
dis,
MCMCSampling(mcalg = algorithm, trafo = DoNotTransform(), nchains = nchains, nsteps = nsteps)
TransformedMCMC(mcalg = algorithm, trafo = DoNotTransform(), nchains = nchains, nsteps = nsteps)
).result
end
taf = time()
Expand Down Expand Up @@ -359,6 +359,6 @@ end

function run_ks_ahmc_vs_mh(;n_dim=20:5:35,nsteps=2*10^5, nchains=4)
ks_res_ahmc = run_ND_benchmark(n_dim=n_dim,algorithm=HamiltonianMC(), nsteps=nsteps, nchains=nchains, time_benchmark=false,ahmi_benchmark=false,hmc_benchmark=true)[1]
ks_res_mh = run_ND_benchmark(n_dim=n_dim,algorithm=MetropolisHastings(), nsteps=nsteps, nchains=nchains, time_benchmark=false,ahmi_benchmark=false,hmc_benchmark=true)[1]
ks_res_mh = run_ND_benchmark(n_dim=n_dim,algorithm=RandomWalk(), nsteps=nsteps, nchains=nchains, time_benchmark=false,ahmi_benchmark=false,hmc_benchmark=true)[1]
plot_ks_values_ahmc_vs_mh(ks_res_ahmc,ks_res_mh,n_dim)
end
8 changes: 4 additions & 4 deletions examples/benchmarks/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -207,10 +207,10 @@ function run1D(
)

sample_stats_all = []
samples, chains = bat_sample(testfunctions[key].posterior, MCMCSampling(mcalg = algorithm, trafo = DoNotTransform(), nchains = nchains, nsteps = nsteps))
samples, chains = bat_sample(testfunctions[key].posterior, TransformedMCMC(mcalg = algorithm, trafo = DoNotTransform(), nchains = nchains, nsteps = nsteps))
for i in 1:n_runs
time_before = time()
samples, chains = bat_sample(testfunctions[key].posterior, MCMCSampling(mcalg = algorithm, trafo = DoNotTransform(), nchains = nchains, nsteps = nsteps))
samples, chains = bat_sample(testfunctions[key].posterior, TransformedMCMC(mcalg = algorithm, trafo = DoNotTransform(), nchains = nchains, nsteps = nsteps))
time_after = time()

h = plot1D(samples,testfunctions,key,sample_stats)# posterior, key, analytical_stats,sample_stats)
Expand Down Expand Up @@ -438,10 +438,10 @@ function run2D(

sample_stats_all = []

samples, stats = bat_sample(testfunctions[key].posterior, MCMCSampling(mcalg = algorithm, trafo = DoNotTransform(), nchains = nchains, nsteps = nsteps))
samples, stats = bat_sample(testfunctions[key].posterior, TransformedMCMC(mcalg = algorithm, trafo = DoNotTransform(), nchains = nchains, nsteps = nsteps))
for i in 1:n_runs
time_before = time()
samples, stats = bat_sample(testfunctions[key].posterior, MCMCSampling(mcalg = algorithm, trafo = DoNotTransform(), nchains = nchains, nsteps = nsteps))
samples, stats = bat_sample(testfunctions[key].posterior, TransformedMCMC(mcalg = algorithm, trafo = DoNotTransform(), nchains = nchains, nsteps = nsteps))
time_after = time()

h = plot2D(samples, testfunctions, key, sample_stats)
Expand Down
Loading

0 comments on commit 4f061a7

Please sign in to comment.