diff --git a/NEWS.md b/NEWS.md index cacb6af14..d1f778617 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,40 +1,54 @@ -BAT.jl v3.0.0 Release Notes -=========================== +BAT.jl Release Notes +==================== -New features ------------- +BAT.jl v4.0.0 +------------- -* Support for [DensityInterface](https://github.com/JuliaMath/DensityInterface.jl): BAT will now accept any object that implements the DensityInterface API (specifically `DensityInterface.densitykind` and `DensityInterface.logdensityof`) as likelihoods. In return, all BAT priors and posteriors support the DensityInterface API as well. +### Breaking changes -* Support for [InverseFunctions](https://github.com/JuliaMath/InverseFunctions.jl) and [ChangesOfVariables](https://github.com/JuliaMath/ChangesOfVariables.jl): Parameter transformations in BAT now implement the DensityInterface API. Any function that supports +Several algorithms have changed their names, but also their role: - * `InverseFunctions.inverse` - * `ChangesOfVariables.with_logabsdet_jacobian` - * `output_shape = f(input_shape::ValueShapes.AbstractValueShape)::AbstractValueShape` +* `MCMCSampling` has become `TransformedMCMC`. -can now be used as a parameter transformation in BAT. +* `MetropolisHastings` has become `RandomWalk`. Its parameters have + changed (no deprecation for the parameter changes). Tuning and + sample weighting scheme selection have moved to `TransformedMCMC`. -* `BATContext`, `get_batcontext` and `set_batcontext` +Partial deprecations are available for the above; most old code should +run more or less unchanged, with deprecation warnings (see the migration +example at the end of this section). Also: -* `bat_transform` with enhanced capabilities +* `AdaptiveMHTuning` has become `AdaptiveAffineTuning`, but is now + used as a parameter for `TransformedMCMC` (formerly `MCMCSampling`) + instead of `RandomWalk` (formerly `MetropolisHastings`). -* `distprod`, `distbind` and `lbqintegral` are the new ways to express priors and posteriors in BAT. +* `MCMCNoOpTuning` has become `NoMCMCTransformTuning`. -* `bat_report` +* The parameters of `HamiltonianMC` have changed. -* `BAT.enable_error_log` (experimental) +* `MCMCTuningAlgorithm` has been replaced by `MCMCTransformTuning`. -* `BAT.error_log` (experimental) -* `BridgeSampling` (experimental) +### New features -* `EllipsoidalNestedSampling` (experimental) +* The new `RAMTuning` is now the default (transform) tuning algorithm for + `RandomWalk` (formerly `MetropolisHastings`). It typically results in a much + faster burn-in process than `AdaptiveAffineTuning` (formerly + `AdaptiveMHTuning`, the previous default). + +* MCMC sampling handles parameter scale and correlation adaptivity via + tunable space transformations instead of tuning covariance matrices + in proposal distributions. + + MCMC tuning has been split into proposal tuning (algorithms of type + `MCMCProposalTuning`) and transform tuning (algorithms of type + `MCMCTransformTuning`). Proposal tuning now has a much more limited role + and may often be `NoMCMCProposalTuning()` (e.g. for `RandomWalk`).
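+
+A minimal migration sketch (hypothetical user code; `posterior` stands for any
+BAT posterior, e.g. a `PosteriorMeasure`):
+
+```julia
+using BAT
+
+# BAT.jl v3.x (now deprecated):
+# samples = bat_sample(posterior, MCMCSampling(mcalg = MetropolisHastings(), nsteps = 10^5, nchains = 4)).result
+
+# BAT.jl v4.0:
+samples = bat_sample(posterior, TransformedMCMC(proposal = RandomWalk(), nsteps = 10^5, nchains = 4)).result
+```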
-* `ReactiveNestedSampling` (experimental) +BAT.jl v3.0.0 +------------- -Breaking changes ----------------- +### Breaking changes * `AbstractVariateTransform` and the function `ladjof` have been removed, BAT parameter transformations do not need to have a specific supertype any longer (see above). @@ -90,3 +104,35 @@ Breaking changes * Pending: BAT will rely less on ValueShapes in the future. Do not use ValueShapes functionality directly where avoidable. Use `distprod` instead of using `ValueShapes.NamedTupleDist` directly, and favor using `bat_transform` instead of shaping and unshaping data using value shapes directly, if possible. * Use the new function `bat_report` to generate a sampling output report instead of `show(BAT.SampledDensity(samples))`. + + +### New features + +* Support for [DensityInterface](https://github.com/JuliaMath/DensityInterface.jl): BAT will now accept any object that implements the DensityInterface API (specifically `DensityInterface.densitykind` and `DensityInterface.logdensityof`) as likelihoods. In return, all BAT priors and posteriors support the DensityInterface API as well. + +* Support for [InverseFunctions](https://github.com/JuliaMath/InverseFunctions.jl) and [ChangesOfVariables](https://github.com/JuliaMath/ChangesOfVariables.jl): Parameter transformations in BAT now implement the DensityInterface API. Any function that supports + + * `InverseFunctions.inverse` + * `ChangesOfVariables.with_logabsdet_jacobian` + * `output_shape = f(input_shape::ValueShapes.AbstractValueShape)::AbstractValueShape` + +can now be used as a parameter transformation in BAT. + +* `BATContext`, `get_batcontext` and `set_batcontext` + +* `bat_transform` with enhanced capabilities + +* `distprod`, `distbind` and `lbqintegral` are the new ways to express priors and posteriors in BAT. + +* `bat_report` + +* `BAT.enable_error_log` (experimental) + +* `BAT.error_log` (experimental) + +* `BridgeSampling` (experimental) + +* `EllipsoidalNestedSampling` (experimental) + +* `ReactiveNestedSampling` (experimental) diff --git a/Project.toml b/Project.toml index 2d29e5010..27b4209f3 100644 --- a/Project.toml +++ b/Project.toml @@ -95,7 +95,7 @@ AdvancedHMC = "0.5, 0.6" AffineMaps = "0.2.3, 0.3" ArgCheck = "1, 2.0" ArraysOfArrays = "0.4, 0.5, 0.6" -AutoDiffOperators = "0.2" +AutoDiffOperators = "0.1.8, 0.2" ChainRulesCore = "0.9.44, 0.10, 1" ChangesOfVariables = "0.1.1" Clustering = "0.13, 0.14, 0.15" diff --git a/docs/src/experimental_api.md b/docs/src/experimental_api.md index e8feb7af9..174dcee57 100644 --- a/docs/src/experimental_api.md +++ b/docs/src/experimental_api.md @@ -36,13 +36,14 @@ SobolSampler truncate_batmeasure ValueAndThreshold -BAT.MCMCIterator -BAT.MCMCTunerState -BAT.TemperingState -BAT.MCMCProposalState -BAT.MCMCTempering -BAT.MCMCState BAT.MCMCChainState BAT.MCMCChainStateInfo +BAT.MCMCIterator BAT.MCMCProposal +BAT.MCMCProposalState +BAT.MCMCProposalTunerState +BAT.MCMCState +BAT.MCMCTempering +BAT.MCMCTransformTunerState +BAT.TemperingState ``` diff --git a/docs/src/list_of_algorithms.md b/docs/src/list_of_algorithms.md index 1ca94bb4e..614244e26 100644 --- a/docs/src/list_of_algorithms.md +++ b/docs/src/list_of_algorithms.md @@ -19,21 +19,21 @@ bat_sample(target.prior, IIDSampling(nsamples=10^5)) ### Metropolis-Hastings -BAT sampling algorithm type: [`MCMCSampling`](@ref), MCMC algorithm subtype: [`MetropolisHastings`](@ref) +BAT sampling algorithm type: [`TransformedMCMC`](@ref), MCMC algorithm subtype: [`RandomWalk`](@ref) ```julia -bat_sample(target, MCMCSampling(mcalg = MetropolisHastings(), nsteps = 10^5, nchains = 4)) +bat_sample(target, TransformedMCMC(mcalg = RandomWalk(), nsteps = 10^5, nchains = 4)) ``` ### Hamiltonian MC -BAT sampling algorithm type: [`MCMCSampling`](@ref), MCMC algorithm subtype: [`HamiltonianMC`](@ref) +BAT sampling algorithm type: [`TransformedMCMC`](@ref), MCMC algorithm subtype: [`HamiltonianMC`](@ref) ```julia import AdvancedHMC, ForwardDiff set_batcontext(ad = 
ADSelector(ForwardDiff)) -bat_sample(target, MCMCSampling(mcalg = HamiltonianMC())) +bat_sample(target, TransformedMCMC(mcalg = HamiltonianMC())) ``` Requires the [AdvancedHMC](https://github.com/TuringLang/AdvancedHMC.jl) Julia package to be loaded explicitly. diff --git a/docs/src/stable_api.md b/docs/src/stable_api.md index 9a728e8d9..a126b8403 100644 --- a/docs/src/stable_api.md +++ b/docs/src/stable_api.md @@ -52,12 +52,13 @@ lbqintegral AbstractMCMCWeightingScheme AbstractPosteriorMeasure AbstractTransformTarget -AdaptiveMHTuning +AdaptiveAffineTuning AssumeConvergence AutocorLenAlgorithm BATContext BATHDF5IO BATIOAlgorithm +BinningAlgorithm BrooksGelmanConvergence CuhreIntegration DensitySample @@ -68,6 +69,8 @@ EffSampleSizeAlgorithm EffSampleSizeFromAC EvaluatedMeasure ExplicitInit +FixedNBins +FreedmanDiaconisBinning GelmanRubinConvergence GeyerAutocorLen HamiltonianMC @@ -85,12 +88,11 @@ MCMCBurninAlgorithm MCMCChainPoolInit MCMCInitAlgorithm MCMCMultiCycleBurnin -MCMCNoOpTuning -MCMCSampling -MCMCTuning -MetropolisHastings -MHProposalDistTuning +MCMCProposalTuning +MCMCTransformTuning ModeAsDefined +NoMCMCProposalTuning +NoMCMCTransformTuning OptimAlg OptimizationAlg OrderedResampling @@ -98,21 +100,21 @@ PosteriorMeasure PriorSubstitution PriorToGaussian PriorToUniform +RAMTuning +RandomWalk RandResampling RepetitionWeighting -SampleMedianEstimator -SokalAutocorLen -SuaveIntegration -TransformAlgorithm -VEGASIntegration -BinningAlgorithm -FixedNBins -FreedmanDiaconisBinning RiceBinning +SampleMedianEstimator ScottBinning +SokalAutocorLen SquareRootBinning SturgesBinning +SuaveIntegration ToRealVector +TransformAlgorithm +TransformedMCMC +VEGASIntegration BAT.AbstractMedianEstimator BAT.AbstractModeEstimator diff --git a/docs/src/tutorial_lit.jl b/docs/src/tutorial_lit.jl index fe290c9b9..0965bb912 100644 --- a/docs/src/tutorial_lit.jl +++ b/docs/src/tutorial_lit.jl @@ -230,7 +230,7 @@ posterior = PosteriorMeasure(likelihood, prior) # Now we can generate a set of MCMC samples via [`bat_sample`](@ref). We'll # use 4 MCMC chains with 10^5 MC steps in each chain (after tuning/burn-in): -samples = bat_sample(posterior, MCMCSampling(proposal = MetropolisHastings(), nsteps = 10^5, nchains = 4)).result +samples = bat_sample(posterior, TransformedMCMC(proposal = RandomWalk(), nsteps = 10^5, nchains = 4)).result #md nothing # hide #nb nothing # hide @@ -374,11 +374,9 @@ plot!(-4:0.01:4, x -> fit_function(true_par_values, x), color=4, label = "Truth" # All option value used in the following are the default values, any or all # may be omitted. -# We'll sample using the The Metropolis-Hastings MCMC algorithm: +# We'll sample using the random-walk Metropolis-Hastings MCMC algorithm: -mcmcalgo = MetropolisHastings( - weighting = RepetitionWeighting() -) +mcmcalgo = RandomWalk() # BAT requires a counter-based random number generator (RNG), since it # partitions the RNG space over the MCMC chains. This way, a single RNG seed @@ -393,7 +391,7 @@ context = BATContext(rng = Philox4x()) #md nothing # hide -# By default, `MetropolisHastings()` uses the following options. +# By default, `RandomWalk()` uses the following options. 
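+#
+# (For reference, based on the defaults defined in this PR: a plain `RandomWalk()`
+# corresponds to `RandomWalk(proposaldist = TDist(1.0))`, and its transform
+# tuning defaults to `RAMTuning()`.)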
# # For Markov chain initialization: @@ -412,7 +410,7 @@ convergence = BrooksGelmanConvergence() samples = bat_sample( posterior, - MCMCSampling( + TransformedMCMC( proposal = mcmcalgo, nchains = 4, nsteps = 10^5, diff --git a/examples/benchmarks/benchmarks.jl b/examples/benchmarks/benchmarks.jl index a9893a924..662bef128 100644 --- a/examples/benchmarks/benchmarks.jl +++ b/examples/benchmarks/benchmarks.jl @@ -21,10 +21,10 @@ function setup_benchmark() include("run_benchmark_ND.jl") end -function do_benchmarks(;algorithm=MetropolisHastings(), nsteps=10^6, nchains=8) +function do_benchmarks(;algorithm=RandomWalk(), nsteps=10^6, nchains=8) #run_1D_benchmark(algorithm=algorithm, nsteps=nsteps, nchains=nchains) run_2D_benchmark(algorithm=algorithm, nsteps=nsteps, nchains=nchains) - run_ND_benchmark(n_dim=2:2:20,algorithm=MetropolisHastings(), nsteps=nsteps, nchains=4) + run_ND_benchmark(n_dim=2:2:20,algorithm=RandomWalk(), nsteps=nsteps, nchains=4) run_ks_ahmc_vs_mh(n_dim=20:5:35) end diff --git a/examples/benchmarks/run_benchmark_1D.jl b/examples/benchmarks/run_benchmark_1D.jl index 584691984..961e2ea52 100644 --- a/examples/benchmarks/run_benchmark_1D.jl +++ b/examples/benchmarks/run_benchmark_1D.jl @@ -1,4 +1,4 @@ -function run_1D_benchmark(;algorithm=MetropolisHastings(), nsteps=10^5, nchains=8) +function run_1D_benchmark(;algorithm=RandomWalk(), nsteps=10^5, nchains=8) for i in 1:length(testfunctions_1D) sample_stats_all = run1D( collect(keys(testfunctions_1D))[i], #There might be a nicer way but I need the name to save the plots diff --git a/examples/benchmarks/run_benchmark_2D.jl b/examples/benchmarks/run_benchmark_2D.jl index 7155339ad..0c18f3af5 100644 --- a/examples/benchmarks/run_benchmark_2D.jl +++ b/examples/benchmarks/run_benchmark_2D.jl @@ -1,4 +1,4 @@ -function run_2D_benchmark(;algorithm = MetropolisHastings(),nchains = 8,nsteps = 10^5) +function run_2D_benchmark(;algorithm = RandomWalk(),nchains = 8,nsteps = 10^5) for i in 1:length(testfunctions_2D) sample_stats_all = run2D( collect(keys(testfunctions_2D))[i], #There might be a nicer way but I need the name to save the plots diff --git a/examples/benchmarks/run_benchmark_ND.jl b/examples/benchmarks/run_benchmark_ND.jl index aa2db59e7..6c909efc2 100644 --- a/examples/benchmarks/run_benchmark_ND.jl +++ b/examples/benchmarks/run_benchmark_ND.jl @@ -203,7 +203,7 @@ end function run_ND_benchmark(; n_dim = 2:2:20, - algorithm = MetropolisHastings(), + algorithm = RandomWalk(), nchains = 4, nsteps = 4*10^5, time_benchmark = true, @@ -244,10 +244,10 @@ function run_ND_benchmark(; mcmc_sample = nothing tbf = time() - if isa(algorithm,BAT.MetropolisHastings) + if isa(algorithm,BAT.RandomWalk) mcmc_sample = bat_sample( dis, - MCMCSampling( + TransformedMCMC( mcalg = algorithm, trafo = DoNotTransform(), nchains = nchains, @@ -262,7 +262,7 @@ function run_ND_benchmark(; elseif isa(algorithm,BAT.HamiltonianMC) mcmc_sample = bat_sample( dis, - MCMCSampling(mcalg = algorithm, trafo = DoNotTransform(), nchains = nchains, nsteps = nsteps) + TransformedMCMC(mcalg = algorithm, trafo = DoNotTransform(), nchains = nchains, nsteps = nsteps) ).result end taf = time() @@ -271,10 +271,10 @@ function run_ND_benchmark(; time_per_run = Array{Float64}(undef,5) for i in 1:5 tbf = time() - if isa(algorithm,BAT.MetropolisHastings) + if isa(algorithm,BAT.RandomWalk) bat_sample( dis, - MCMCSampling( + TransformedMCMC( mcalg = algorithm, trafo = DoNotTransform(), nchains = nchains, @@ -289,7 +289,7 @@ function run_ND_benchmark(; elseif 
isa(algorithm,BAT.HamiltonianMC) bat_sample( dis, - MCMCSampling(mcalg = algorithm, trafo = DoNotTransform(), nchains = nchains, nsteps = nsteps) + TransformedMCMC(mcalg = algorithm, trafo = DoNotTransform(), nchains = nchains, nsteps = nsteps) ).result end taf = time() @@ -359,6 +359,6 @@ end function run_ks_ahmc_vs_mh(;n_dim=20:5:35,nsteps=2*10^5, nchains=4) ks_res_ahmc = run_ND_benchmark(n_dim=n_dim,algorithm=HamiltonianMC(), nsteps=nsteps, nchains=nchains, time_benchmark=false,ahmi_benchmark=false,hmc_benchmark=true)[1] - ks_res_mh = run_ND_benchmark(n_dim=n_dim,algorithm=MetropolisHastings(), nsteps=nsteps, nchains=nchains, time_benchmark=false,ahmi_benchmark=false,hmc_benchmark=true)[1] + ks_res_mh = run_ND_benchmark(n_dim=n_dim,algorithm=RandomWalk(), nsteps=nsteps, nchains=nchains, time_benchmark=false,ahmi_benchmark=false,hmc_benchmark=true)[1] plot_ks_values_ahmc_vs_mh(ks_res_ahmc,ks_res_mh,n_dim) end diff --git a/examples/benchmarks/utils.jl b/examples/benchmarks/utils.jl index 53bc5156d..db90af749 100644 --- a/examples/benchmarks/utils.jl +++ b/examples/benchmarks/utils.jl @@ -207,10 +207,10 @@ function run1D( ) sample_stats_all = [] - samples, chains = bat_sample(testfunctions[key].posterior, MCMCSampling(mcalg = algorithm, trafo = DoNotTransform(), nchains = nchains, nsteps = nsteps)) + samples, chains = bat_sample(testfunctions[key].posterior, TransformedMCMC(mcalg = algorithm, trafo = DoNotTransform(), nchains = nchains, nsteps = nsteps)) for i in 1:n_runs time_before = time() - samples, chains = bat_sample(testfunctions[key].posterior, MCMCSampling(mcalg = algorithm, trafo = DoNotTransform(), nchains = nchains, nsteps = nsteps)) + samples, chains = bat_sample(testfunctions[key].posterior, TransformedMCMC(mcalg = algorithm, trafo = DoNotTransform(), nchains = nchains, nsteps = nsteps)) time_after = time() h = plot1D(samples,testfunctions,key,sample_stats)# posterior, key, analytical_stats,sample_stats) @@ -438,10 +438,10 @@ function run2D( sample_stats_all = [] - samples, stats = bat_sample(testfunctions[key].posterior, MCMCSampling(mcalg = algorithm, trafo = DoNotTransform(), nchains = nchains, nsteps = nsteps)) + samples, stats = bat_sample(testfunctions[key].posterior, TransformedMCMC(mcalg = algorithm, trafo = DoNotTransform(), nchains = nchains, nsteps = nsteps)) for i in 1:n_runs time_before = time() - samples, stats = bat_sample(testfunctions[key].posterior, MCMCSampling(mcalg = algorithm, trafo = DoNotTransform(), nchains = nchains, nsteps = nsteps)) + samples, stats = bat_sample(testfunctions[key].posterior, TransformedMCMC(mcalg = algorithm, trafo = DoNotTransform(), nchains = nchains, nsteps = nsteps)) time_after = time() h = plot2D(samples, testfunctions, key, sample_stats) diff --git a/examples/dev-internal/output_examples.jl b/examples/dev-internal/output_examples.jl index 743fbb536..460125c90 100644 --- a/examples/dev-internal/output_examples.jl +++ b/examples/dev-internal/output_examples.jl @@ -23,7 +23,7 @@ prior = BAT.NamedTupleDist( posterior = PosteriorMeasure(likelihood, prior); -samples, chains = bat_sample(posterior, MCMCSampling(mcalg = MetropolisHastings(), nsteps = 10^5)); +samples, chains = bat_sample(posterior, TransformedMCMC(mcalg = RandomWalk(), nsteps = 10^5)); #samples = bat_sample(posterior, SobolSampler(nsamples = 10^5)).result; sd = EvaluatedMeasure(posterior, samples) diff --git a/examples/dev-internal/plotting_examples.jl b/examples/dev-internal/plotting_examples.jl index ba62cbe74..a2df0be33 100644 --- 
a/examples/dev-internal/plotting_examples.jl +++ b/examples/dev-internal/plotting_examples.jl @@ -31,7 +31,7 @@ prior = BAT.NamedTupleDist( posterior = PosteriorMeasure(likelihood, prior); -samples, chains = bat_sample(posterior, MCMCSampling(mcalg = MetropolisHastings(), nsteps = 10^5)); +samples, chains = bat_sample(posterior, TransformedMCMC(mcalg = RandomWalk(), nsteps = 10^5)); # ## Set up plotting # Set up plotting using the [Plots.jl](https://github.com/JuliaPlots/Plots.jl) package: diff --git a/examples/paper-example/paper_example.jl b/examples/paper-example/paper_example.jl index bfcddb2da..078a9b249 100644 --- a/examples/paper-example/paper_example.jl +++ b/examples/paper-example/paper_example.jl @@ -143,7 +143,7 @@ posterior_bkg_signal = PosteriorMeasure(SignalBkgLikelihood(summary_dataset_tabl nchains = 4 nsteps = 10^5 -algorithm = MCMCSampling(proposal = HamiltonianMC(), nchains = nchains, nsteps = nsteps) +algorithm = TransformedMCMC(proposal = HamiltonianMC(), nchains = nchains, nsteps = nsteps) samples_bkg = bat_sample(posterior_bkg, algorithm).result eval_bkg = EvaluatedMeasure(posterior_bkg, samples = samples_bkg) diff --git a/ext/BATAdvancedHMCExt.jl b/ext/BATAdvancedHMCExt.jl index 8fd9f89a8..323d90c6f 100644 --- a/ext/BATAdvancedHMCExt.jl +++ b/ext/BATAdvancedHMCExt.jl @@ -15,13 +15,13 @@ using BAT: MeasureLike, BATMeasure using BAT: get_context, get_adselector, _NoADSelected using BAT: getproposal, mcmc_target -using BAT: MCMCChainState, HMCState, HamiltonianMC, HMCProposalState, MCMCChainStateInfo, MCMCChainPoolInit, MCMCMultiCycleBurnin, MCMCTunerState, NoMCMCTempering +using BAT: MCMCChainState, HMCState, HamiltonianMC, HMCProposalState, MCMCChainStateInfo, MCMCChainPoolInit, MCMCMultiCycleBurnin, MCMCProposalTunerState, MCMCTransformTunerState, NoMCMCTempering using BAT: _current_sample_idx, _proposed_sample_idx, _cleanup_samples -using BAT: AbstractTransformTarget, TriangularAffineTransform +using BAT: AbstractTransformTarget, NoAdaptiveTransform using BAT: RNGPartition, get_rng, set_rng! using BAT: mcmc_step!!, nsamples, nsteps, samples_available, eff_acceptance_ratio using BAT: get_samples!, reset_rng_counters! -using BAT: create_trafo_tuner_state, create_proposal_tuner_state, mcmc_tuning_init!!, mcmc_tuning_postinit!!, mcmc_tuning_reinit!!, mcmc_tune_transform_post_cycle!!, transform_mcmc_tuning_finalize!!, tuning_callback +using BAT: create_trafo_tuner_state, create_proposal_tuner_state, mcmc_tuning_init!!, mcmc_tuning_postinit!!, mcmc_tuning_reinit!!, mcmc_tune_transform_post_cycle!!, transform_mcmc_tuning_finalize!! using BAT: totalndof, measure_support, checked_logdensityof using BAT: CURRENT_SAMPLE, PROPOSED_SAMPLE, INVALID_SAMPLE, ACCEPTED_SAMPLE, REJECTED_SAMPLE diff --git a/ext/ahmc_impl/ahmc_sampler_impl.jl b/ext/ahmc_impl/ahmc_sampler_impl.jl index cf5330854..a452551ab 100644 --- a/ext/ahmc_impl/ahmc_sampler_impl.jl +++ b/ext/ahmc_impl/ahmc_sampler_impl.jl @@ -1,20 +1,20 @@ # This file is a part of BAT.jl, licensed under the MIT License (MIT). 
-BAT.bat_default(::Type{MCMCSampling}, ::Val{:pre_transform}, proposal::HamiltonianMC) = PriorToGaussian() +BAT.bat_default(::Type{TransformedMCMC}, ::Val{:pre_transform}, proposal::HamiltonianMC) = PriorToGaussian() -BAT.bat_default(::Type{MCMCSampling}, ::Val{:trafo_tuning}, proposal::HamiltonianMC) = StanHMCTuning() +BAT.bat_default(::Type{TransformedMCMC}, ::Val{:proposal_tuning}, proposal::HamiltonianMC) = StanHMCTuning() -BAT.bat_default(::Type{MCMCSampling}, ::Val{:adaptive_transform}, proposal::HamiltonianMC) = TriangularAffineTransform() +BAT.bat_default(::Type{TransformedMCMC}, ::Val{:adaptive_transform}, proposal::HamiltonianMC) = NoAdaptiveTransform() -BAT.bat_default(::Type{MCMCSampling}, ::Val{:tempering}, proposal::HamiltonianMC) = NoMCMCTempering() +BAT.bat_default(::Type{TransformedMCMC}, ::Val{:tempering}, proposal::HamiltonianMC) = NoMCMCTempering() -BAT.bat_default(::Type{MCMCSampling}, ::Val{:nsteps}, proposal::HamiltonianMC, pre_transform::AbstractTransformTarget, nchains::Integer) = 10^4 +BAT.bat_default(::Type{TransformedMCMC}, ::Val{:nsteps}, proposal::HamiltonianMC, pre_transform::AbstractTransformTarget, nchains::Integer) = 10^4 -BAT.bat_default(::Type{MCMCSampling}, ::Val{:init}, proposal::HamiltonianMC, pre_transform::AbstractTransformTarget, nchains::Integer, nsteps::Integer) = +BAT.bat_default(::Type{TransformedMCMC}, ::Val{:init}, proposal::HamiltonianMC, pre_transform::AbstractTransformTarget, nchains::Integer, nsteps::Integer) = MCMCChainPoolInit(nsteps_init = 25) # clamp(div(nsteps, 100), 25, 250) -BAT.bat_default(::Type{MCMCSampling}, ::Val{:burnin}, proposal::HamiltonianMC, pre_transform::AbstractTransformTarget, nchains::Integer, nsteps::Integer) = +BAT.bat_default(::Type{TransformedMCMC}, ::Val{:burnin}, proposal::HamiltonianMC, pre_transform::AbstractTransformTarget, nchains::Integer, nsteps::Integer) = MCMCMultiCycleBurnin(nsteps_per_cycle = max(div(nsteps, 10), 250), max_ncycles = 4) @@ -50,8 +50,7 @@ function BAT._create_proposal_state( termination, hamiltonian, kernel, - transition, - proposal.weighting + transition ) end @@ -114,7 +113,11 @@ function BAT.mcmc_propose!!(mc_state::HMCState) accepted = x_current != x_proposed - return mc_state, accepted, Float64(accepted) + # TODO: Setting p_accept to 1 or 0 for now. + # Use AdvancedHMC.stat(transition).acceptance_rate in the future? + p_accept = Float64(accepted) + + return mc_state, accepted, p_accept end function BAT._accept_reject!(mc_state::HMCState, accepted::Bool, p_accept::Float64, current::Integer, proposed::Integer) @@ -137,11 +140,7 @@ function BAT._accept_reject!(mc_state::HMCState, accepted::Bool, p_accept::Float samples.info.sampletype[proposed] = REJECTED_SAMPLE end - delta_w_current, w_proposed = if accepted - (0, 1) - else - (1, 0) - end + delta_w_current, w_proposed = BAT.mcmc_weight_values(mc_state.weighting, p_accept, accepted) samples.weight[current] += delta_w_current samples.weight[proposed] = w_proposed diff --git a/ext/ahmc_impl/ahmc_tuner_impl.jl b/ext/ahmc_impl/ahmc_tuner_impl.jl index 754847b0a..b7b58e504 100644 --- a/ext/ahmc_impl/ahmc_tuner_impl.jl +++ b/ext/ahmc_impl/ahmc_tuner_impl.jl @@ -1,36 +1,26 @@ # This file is a part of BAT.jl, licensed under the MIT License (MIT). 
-struct HMCTrafoTunerState <: MCMCTunerState end - -mutable struct HMCProposalTunerState{A<:AdvancedHMC.AbstractAdaptor} <: MCMCTunerState +mutable struct HMCProposalTunerState{A<:AdvancedHMC.AbstractAdaptor} <: MCMCProposalTunerState tuning::HMCTuning target_acceptance::Float64 adaptor::A end -(tuning::HMCTuning)(chain_state::HMCState) = HMCProposalTunerState(tuning, chain_state), HMCTrafoTunerState() - -HMCTrafoTunerState(tuning::HMCTuning) = HMCTrafoTunerState() - function HMCProposalTunerState(tuning::HMCTuning, chain_state::MCMCChainState) θ = first(chain_state.samples).v adaptor = ahmc_adaptor(tuning, chain_state.proposal.hamiltonian.metric, chain_state.proposal.kernel.τ.integrator, θ) HMCProposalTunerState(tuning, tuning.target_acceptance, adaptor) end -BAT.create_trafo_tuner_state(tuning::HMCTuning, chain_state::MCMCChainState, iteration::Integer) = HMCTrafoTunerState(tuning) - BAT.create_proposal_tuner_state(tuning::HMCTuning, chain_state::MCMCChainState, iteration::Integer) = HMCProposalTunerState(tuning, chain_state) -BAT.mcmc_tuning_init!!(tuner::HMCTrafoTunerState, chain_state::HMCState, max_nsteps::Integer) = nothing function BAT.mcmc_tuning_init!!(tuner::HMCProposalTunerState, chain_state::HMCState, max_nsteps::Integer) AdvancedHMC.Adaptation.initialize!(tuner.adaptor, Int(max_nsteps - 1)) nothing end -BAT.mcmc_tuning_reinit!!(tuner::HMCTrafoTunerState, chain_state::HMCState, max_nsteps::Integer) = nothing function BAT.mcmc_tuning_reinit!!(tuner::HMCProposalTunerState, chain_state::HMCState, max_nsteps::Integer) AdvancedHMC.Adaptation.initialize!(tuner.adaptor, Int(max_nsteps - 1)) @@ -38,13 +28,9 @@ function BAT.mcmc_tuning_reinit!!(tuner::HMCProposalTunerState, chain_state::HMC end -BAT.mcmc_tuning_postinit!!(tuner::HMCTrafoTunerState, chain_state::HMCState, samples::DensitySampleVector) = nothing - BAT.mcmc_tuning_postinit!!(tuner::HMCProposalTunerState, chain_state::HMCState, samples::DensitySampleVector) = nothing -BAT.mcmc_tune_post_cycle!!(tuner::HMCTrafoTunerState, chain_state::HMCState, samples::DensitySampleVector) = chain_state, tuner, false - function BAT.mcmc_tune_post_cycle!!(tuner::HMCProposalTunerState, chain_state::HMCState, samples::DensitySampleVector) max_log_posterior = maximum(samples.logd) accept_ratio = eff_acceptance_ratio(chain_state) @@ -59,8 +45,6 @@ function BAT.mcmc_tune_post_cycle!!(tuner::HMCProposalTunerState, chain_state::H end -BAT.mcmc_tuning_finalize!!(tuner::HMCTrafoTunerState, chain_state::HMCState) = nothing - function BAT.mcmc_tuning_finalize!!(tuner::HMCProposalTunerState, chain_state::HMCState) adaptor = tuner.adaptor proposal = chain_state.proposal @@ -71,19 +55,6 @@ function BAT.mcmc_tuning_finalize!!(tuner::HMCProposalTunerState, chain_state::H end -BAT.tuning_callback(::HMCTrafoTunerState) = nop_func - -BAT.tuning_callback(::HMCProposalTunerState) = nop_func - - -function BAT.mcmc_tune_post_step!!( - tuner_state::HMCTrafoTunerState, - chain_state::MCMCChainState, - p_accept::Real -) - return chain_state, tuner_state, false -end - # TODO: MD, make actually !! 
function function BAT.mcmc_tune_post_step!!( tuner_state::HMCProposalTunerState, diff --git a/src/algodefaults/default_sampling_algorithm.jl b/src/algodefaults/default_sampling_algorithm.jl index c25794173..563bf182f 100644 --- a/src/algodefaults/default_sampling_algorithm.jl +++ b/src/algodefaults/default_sampling_algorithm.jl @@ -12,7 +12,7 @@ end bat_default(::typeof(bat_sample), ::Val{:algorithm}, ::DensitySampleVector) = OrderedResampling() bat_default(::typeof(bat_sample), ::Val{:algorithm}, ::DensitySampleMeasure) = OrderedResampling() -bat_default(::typeof(bat_sample), ::Val{:algorithm}, ::PosteriorMeasure) = MCMCSampling() +bat_default(::typeof(bat_sample), ::Val{:algorithm}, ::PosteriorMeasure) = TransformedMCMC() function bat_default(::typeof(bat_sample), ::Val{:algorithm}, m::EvaluatedMeasure) bat_default(bat_sample, Val(:algorithm), m.measure) diff --git a/src/algotypes/bat_default.jl b/src/algotypes/bat_default.jl index 9d3efd9ff..150313bbd 100644 --- a/src/algotypes/bat_default.jl +++ b/src/algotypes/bat_default.jl @@ -15,7 +15,7 @@ Which arguments are considered to be objectives is function-specific. For example: ```julia -bat_default(bat_sample, :algorithm, density::PosteriorMeasure) == MetropolisHastings() +bat_default(bat_sample, :algorithm, density::PosteriorMeasure) == RandomWalk() bat_default(bat_sample, Val(:algorithm), samples::DensitySampleVector) == OrderedResampling() ``` """ diff --git a/src/deprecations.jl b/src/deprecations.jl index d5867b06f..213ea6a30 100644 --- a/src/deprecations.jl +++ b/src/deprecations.jl @@ -50,3 +50,33 @@ export PosteriorDensity @deprecate bat_initval(rng::AbstractRNG, target::MeasureLike, n::Integer, algorithm::InitvalAlgorithm) = bat_initval(target, n, algorithm, BAT.set_rng(BAT.get_batcontext(), rng)) @deprecate bat_initval(rng::AbstractRNG, target::MeasureLike, n::Integer) = bat_initval(target, n, BAT.set_rng(BAT.get_batcontext(), rng)) =# + + +Base.@deprecate MetropolisHastings() RandomWalk() + +Base.@deprecate MCMCSampling(; + mcalg::MCMCProposal = RandomWalk(), + trafo::AbstractTransformTarget = bat_default(TransformedMCMC, Val(:pre_transform), mcalg), + nchains::Int = 4, + nsteps::Int = bat_default(TransformedMCMC, Val(:nsteps), mcalg, trafo, nchains), + init::MCMCInitAlgorithm = bat_default(TransformedMCMC, Val(:init), mcalg, trafo, nchains, nsteps), + burnin::MCMCBurninAlgorithm = bat_default(TransformedMCMC, Val(:burnin), mcalg, trafo, nchains, nsteps), + convergence::ConvergenceTest = BrooksGelmanConvergence(), + strict::Bool = true, + store_burnin::Bool = false, + nonzero_weights::Bool = true, + callback::Function = nop_func +) TransformedMCMC( + proposal = mcalg, + pre_transform = trafo, + nchains = nchains, + nsteps = nsteps, + init = init, + burnin = burnin, + convergence = convergence, + strict = strict, + store_burnin = store_burnin, + nonzero_weights = nonzero_weights, + callback = callback +) +export MCMCSampling diff --git a/src/extdefs/ahmc_defs/ahmc_alg.jl b/src/extdefs/ahmc_defs/ahmc_alg.jl index 97d202489..3772987b4 100644 --- a/src/extdefs/ahmc_defs/ahmc_alg.jl +++ b/src/extdefs/ahmc_defs/ahmc_alg.jl @@ -33,13 +33,11 @@ $(TYPEDFIELDS) @with_kw struct HamiltonianMC{ MT<:HMCMetric, IT, - TC, - WS<:AbstractMCMCWeightingScheme + TC } <: MCMCProposal metric::MT = DiagEuclideanMetric() integrator::IT = ext_default(pkgext(Val(:AdvancedHMC)), Val(:DEFAULT_INTEGRATOR)) termination::TC = ext_default(pkgext(Val(:AdvancedHMC)), Val(:DEFAULT_TERMINATION_CRITERION)) - weighting::WS = RepetitionWeighting() end export 
HamiltonianMC @@ -50,24 +48,15 @@ mutable struct HMCProposalState{ TC, HA,#<:AdvancedHMC.Hamiltonian, KRNL,#<:AdvancedHMC.HMCKernel - TR,# <:AdvancedHMC.Transition - WS<:AbstractMCMCWeightingScheme + TR# <:AdvancedHMC.Transition } <: MCMCProposalState integrator::IT termination::TC hamiltonian::HA kernel::KRNL transition::TR - weighting::WS end export HMCProposalState -const HMCState = MCMCChainState{<:BATMeasure, - <:RNGPartition, - <:Function, - <:HMCProposalState, - <:DensitySampleVector, - <:DensitySampleVector, - <:BATContext -} +const HMCState = MCMCChainState{<:BATMeasure, <:RNGPartition, <:Function, <:HMCProposalState} diff --git a/src/extdefs/ahmc_defs/ahmc_config.jl b/src/extdefs/ahmc_defs/ahmc_config.jl index 04f68ac64..653ca0b9d 100644 --- a/src/extdefs/ahmc_defs/ahmc_config.jl +++ b/src/extdefs/ahmc_defs/ahmc_config.jl @@ -14,7 +14,7 @@ struct DenseEuclideanMetric <: HMCMetric end # Tuning ============================================== -abstract type HMCTuning <: MCMCTuning end +abstract type HMCTuning <: MCMCProposalTuning end @with_kw struct MassMatrixAdaptor <: HMCTuning target_acceptance::Float64 = 0.8 diff --git a/src/precompile.jl b/src/precompile.jl index 2b0707db2..4837c2d0b 100644 --- a/src/precompile.jl +++ b/src/precompile.jl @@ -9,7 +9,7 @@ let precompile(EvaluatedMeasure, map(typeof, (posterior, dummy_samples))) - for mcalg in (MetropolisHastings(), HamiltonianMC()) - precompile(bat_sample, map(typeof, (posterior, MCMCSampling(mcalg = mcalg)))) + for mcalg in (RandomWalk(), HamiltonianMC()) + precompile(bat_sample, map(typeof, (posterior, TransformedMCMC(mcalg = mcalg)))) end end diff --git a/src/transforms/adaptive_transform.jl b/src/samplers/adaptive_transform.jl similarity index 52% rename from src/transforms/adaptive_transform.jl rename to src/samplers/adaptive_transform.jl index 6bf4f5b40..c3423dcce 100644 --- a/src/transforms/adaptive_transform.jl +++ b/src/samplers/adaptive_transform.jl @@ -9,22 +9,30 @@ end CustomTransform() = CustomTransform(identity) +init_adaptive_transform(at::CustomTransform, ::AbstractMeasure, ::BATContext) = at.f + + + +struct NoAdaptiveTransform <: AbstractAdaptiveTransform end + +init_adaptive_transform(::NoAdaptiveTransform, ::AbstractMeasure, ::BATContext) = identity + + + struct TriangularAffineTransform <: AbstractAdaptiveTransform end # TODO: MD, make typestable -function init_adaptive_transform( - adaptive_transform::BAT.TriangularAffineTransform, - target, - context -) +function init_adaptive_transform(adaptive_transform::TriangularAffineTransform, target::AbstractMeasure, ::BATContext) n = totalndof(varshape(target)) M = _approx_cov(target, n) + b = _approx_mean(target, n) s = cholesky(M).L - g = Mul(s) + g = MulAdd(s, b) return g end +# TODO: Implement DiagonalAffineTransform struct DiagonalAffineTransform <: AbstractAdaptiveTransform end diff --git a/src/samplers/mcmc/chain_pool_init.jl b/src/samplers/mcmc/chain_pool_init.jl index c2ee72c5d..74b7a6d79 100644 --- a/src/samplers/mcmc/chain_pool_init.jl +++ b/src/samplers/mcmc/chain_pool_init.jl @@ -32,7 +32,7 @@ end function _construct_mcmc_state( - sampling::MCMCSampling, + samplingalg::TransformedMCMC, target::BATMeasure, rngpart::RNGPartition, id::Integer, @@ -41,27 +41,27 @@ function _construct_mcmc_state( ) new_context = set_rng(parent_context, AbstractRNG(rngpart, id)) v_init = bat_initval(target, initval_alg, new_context).result - return MCMCState(sampling, target, Int32(id), v_init, new_context) + return MCMCState(samplingalg, target, Int32(id), v_init, 
new_context) end _gen_mcmc_states( - sampling::MCMCSampling, + samplingalg::TransformedMCMC, target::BATMeasure, rngpart::RNGPartition, ids::AbstractRange{<:Integer}, initval_alg::InitvalAlgorithm, context::BATContext -) = [_construct_mcmc_state(sampling, target, rngpart, id, initval_alg, context) for id in ids] +) = [_construct_mcmc_state(samplingalg, target, rngpart, id, initval_alg, context) for id in ids] function mcmc_init!( - sampling::MCMCSampling, + samplingalg::TransformedMCMC, target::BATMeasure, init_alg::MCMCChainPoolInit, callback::Function, context::BATContext )::NamedTuple{(:mcmc_states, :outputs), Tuple{Vector{MCMCState}, Vector{DensitySampleVector}}} - @unpack tempering, nchains, trafo_tuning, proposal_tuning, nonzero_weights = sampling + @unpack tempering, nchains, transform_tuning, proposal_tuning, nonzero_weights = samplingalg @info "MCMCChainPoolInit: trying to generate $nchains viable MCMC chain state(s)." @@ -79,7 +79,7 @@ function mcmc_init!( dummy_context = deepcopy(context) dummy_initval = unshaped(bat_initval(target, InitFromTarget(), dummy_context).result, varshape(target)) - dummy_mcmc_state = MCMCState(sampling, target, one(Int32), dummy_initval, dummy_context) + dummy_mcmc_state = MCMCState(samplingalg, target, one(Int32), dummy_initval, dummy_context) mcmc_states = similar([dummy_mcmc_state], 0) outputs = similar([DensitySampleVector(dummy_mcmc_state)], 0) @@ -90,7 +90,7 @@ function mcmc_init!( n = min(min_nviable, max_ncandidates - ncandidates) @debug "Generating $n $(cycle > 1 ? "additional " : "")candidate MCMC chain state(s)." - new_mcmc_states = _gen_mcmc_states(sampling, target, rngpart, ncandidates .+ (one(Int64):n), initval_alg, context) + new_mcmc_states = _gen_mcmc_states(samplingalg, target, rngpart, ncandidates .+ (one(Int64):n), initval_alg, context) filter!(isvalidstate, new_mcmc_states) diff --git a/src/samplers/mcmc/mcmc_algorithm.jl b/src/samplers/mcmc/mcmc_algorithm.jl index d9394c35e..6e21d638a 100644 --- a/src/samplers/mcmc/mcmc_algorithm.jl +++ b/src/samplers/mcmc/mcmc_algorithm.jl @@ -34,20 +34,35 @@ apply_trafo_to_init(trafo::Function, initalg::MCMCInitAlgorithm) = initalg """ - abstract type MCMCTuning + abstract type MCMCProposalTuning Abstract type for MCMC tuning algorithms. """ -abstract type MCMCTuning end -export MCMCTuning +abstract type MCMCProposalTuning end +export MCMCProposalTuning """ - abstract type MCMCTuning + abstract type MCMCProposalTunerState Abstract type for MCMC tuning algorithm states. """ -abstract type MCMCTunerState end -export MCMCTunerState +abstract type MCMCProposalTunerState end + + +""" + abstract type MCMCTransformTuning + +Abstract type for MCMC tuning algorithms. +""" +abstract type MCMCTransformTuning end +export MCMCTransformTuning + +""" + abstract type MCMCTransformTunerState + +Abstract type for MCMC tuning algorithm states. +""" +abstract type MCMCTransformTunerState end """ @@ -164,13 +179,13 @@ of the tuning and tempering algorithms used for sampling. """ struct MCMCState{ C<:MCMCIterator, - TT<:MCMCTunerState, - PT<:MCMCTunerState, + PT<:MCMCProposalTunerState, + TT<:MCMCTransformTunerState, T<:TemperingState } chain_state::C - trafo_tuner_state::TT proposal_tuner_state::PT + trafo_tuner_state::TT temperer_state::T end export MCMCState @@ -233,8 +248,6 @@ function mcmc_tune_post_step!! end function transform_mcmc_tuning_finalize!! end -function tuning_callback end - function mcmc_init! 
end diff --git a/src/samplers/mcmc/mcmc_sample.jl b/src/samplers/mcmc/mcmc_sample.jl index 3ada8f276..71eea7123 100644 --- a/src/samplers/mcmc/mcmc_sample.jl +++ b/src/samplers/mcmc/mcmc_sample.jl @@ -2,7 +2,7 @@ """ - struct MCMCSampling <: AbstractSamplingAlgorithm + struct TransformedMCMC <: AbstractSamplingAlgorithm Samples a probability density using Markov chain Monte Carlo. @@ -14,58 +14,66 @@ Fields: $(TYPEDFIELDS) """ -@with_kw struct MCMCSampling{ +@with_kw struct TransformedMCMC{ PR<:MCMCProposal, - TU<:MCMCTuning, + PRT<:MCMCProposalTuning, TR<:AbstractTransformTarget, - ATR<:AbstractAdaptiveTransform, + AT<:AbstractAdaptiveTransform, + ATT<:MCMCTransformTuning, TE<:MCMCTempering, IN<:MCMCInitAlgorithm, BI<:MCMCBurninAlgorithm, CT<:ConvergenceTest, + WS<:AbstractMCMCWeightingScheme, CB<:Function } <: AbstractSamplingAlgorithm - proposal::PR = MetropolisHastings(proposaldist = TDist(1.0)) - pre_transform::TR = bat_default(MCMCSampling, Val(:pre_transform), proposal) - trafo_tuning::TU = bat_default(MCMCSampling, Val(:trafo_tuning), proposal) - proposal_tuning::TU = trafo_tuning - adaptive_transform::ATR = bat_default(MCMCSampling, Val(:adaptive_transform), proposal) - tempering::TE = bat_default(MCMCSampling, Val(:tempering), proposal) + proposal::PR = RandomWalk(proposaldist = TDist(1.0)) + proposal_tuning::PRT = bat_default(TransformedMCMC, Val(:proposal_tuning), proposal) + pre_transform::TR = bat_default(TransformedMCMC, Val(:pre_transform), proposal) + adaptive_transform::AT = bat_default(TransformedMCMC, Val(:adaptive_transform), proposal) + transform_tuning::ATT = bat_default(TransformedMCMC, Val(:transform_tuning), adaptive_transform) + tempering::TE = bat_default(TransformedMCMC, Val(:tempering), proposal) nchains::Int = 4 - nsteps::Int = bat_default(MCMCSampling, Val(:nsteps), proposal, pre_transform, nchains) + nsteps::Int = bat_default(TransformedMCMC, Val(:nsteps), proposal, pre_transform, nchains) #TODO: max_time ? 
- init::IN = bat_default(MCMCSampling, Val(:init), proposal, pre_transform, nchains, nsteps) - burnin::BI = bat_default(MCMCSampling, Val(:burnin), proposal, pre_transform, nchains, nsteps) + init::IN = bat_default(TransformedMCMC, Val(:init), proposal, pre_transform, nchains, nsteps) + burnin::BI = bat_default(TransformedMCMC, Val(:burnin), proposal, pre_transform, nchains, nsteps) convergence::CT = BrooksGelmanConvergence() strict::Bool = true store_burnin::Bool = false nonzero_weights::Bool = true + sample_weighting::WS = RepetitionWeighting() callback::CB = nop_func end -export MCMCSampling +export TransformedMCMC -function MCMCState(samplingalg::MCMCSampling, target::BATMeasure, id::Integer, v_init::AbstractVector, context::BATContext) +bat_default(::Type{TransformedMCMC}, ::Val{:transform_tuning}, ::CustomTransform) = NoMCMCTransformTuning() +bat_default(::Type{TransformedMCMC}, ::Val{:transform_tuning}, ::NoAdaptiveTransform) = NoMCMCTransformTuning() +bat_default(::Type{TransformedMCMC}, ::Val{:transform_tuning}, ::TriangularAffineTransform) = RAMTuning() + + +function MCMCState(samplingalg::TransformedMCMC, target::BATMeasure, id::Integer, v_init::AbstractVector, context::BATContext) chain_state = MCMCChainState(samplingalg, target, Int32(id), v_init, context) - trafo_tuner_state = create_trafo_tuner_state(samplingalg.trafo_tuning, chain_state, 0) + trafo_tuner_state = create_trafo_tuner_state(samplingalg.transform_tuning, chain_state, 0) proposal_tuner_state = create_proposal_tuner_state(samplingalg.proposal_tuning, chain_state, 0) temperer_state = create_temperering_state(samplingalg.tempering, target) - MCMCState(chain_state, trafo_tuner_state, proposal_tuner_state, temperer_state) + MCMCState(chain_state, proposal_tuner_state, trafo_tuner_state, temperer_state) end -bat_default(::MCMCSampling, ::Val{:pre_transform}) = PriorToGaussian() +bat_default(::TransformedMCMC, ::Val{:pre_transform}) = PriorToGaussian() -bat_default(::MCMCSampling, ::Val{:nsteps}, trafo::AbstractTransformTarget, nchains::Integer) = 10^5 +bat_default(::TransformedMCMC, ::Val{:nsteps}, trafo::AbstractTransformTarget, nchains::Integer) = 10^5 -bat_default(::MCMCSampling, ::Val{:init}, trafo::AbstractTransformTarget, nchains::Integer, nsteps::Integer) = +bat_default(::TransformedMCMC, ::Val{:init}, trafo::AbstractTransformTarget, nchains::Integer, nsteps::Integer) = MCMCChainPoolInit(nsteps_init = max(div(nsteps, 100), 250)) -bat_default(::MCMCSampling, ::Val{:burnin}, trafo::AbstractTransformTarget, nchains::Integer, nsteps::Integer) = +bat_default(::TransformedMCMC, ::Val{:burnin}, trafo::AbstractTransformTarget, nchains::Integer, nsteps::Integer) = MCMCMultiCycleBurnin(nsteps_per_cycle = max(div(nsteps, 10), 2500)) -function bat_sample_impl(target::BATMeasure, samplingalg::MCMCSampling, context::BATContext) +function bat_sample_impl(target::BATMeasure, samplingalg::TransformedMCMC, context::BATContext) target_transformed, pre_transform = transform_and_unshape(samplingalg.pre_transform, target, context) diff --git a/src/samplers/mcmc/mcmc_state.jl b/src/samplers/mcmc/mcmc_state.jl index e20add7c7..de4034190 100644 --- a/src/samplers/mcmc/mcmc_state.jl +++ b/src/samplers/mcmc/mcmc_state.jl @@ -13,6 +13,7 @@ mutable struct MCMCChainState{ PR<:RNGPartition, FT<:Function, P<:MCMCProposalState, + WS<:AbstractMCMCWeightingScheme, SVX<:DensitySampleVector, SVZ<:DensitySampleVector, CTX<:BATContext @@ -20,6 +21,7 @@ mutable struct MCMCChainState{ target::M proposal::P f_transform::FT + weighting::WS samples::SVX 
sample_z::SVZ info::MCMCChainStateInfo @@ -31,7 +33,7 @@ end export MCMCChainState function MCMCChainState( - samplingalg::MCMCSampling, + samplingalg::TransformedMCMC, target::BATMeasure, id::Integer, v_init::AbstractVector{P}, @@ -56,7 +58,7 @@ function MCMCChainState( z = inverse_g(v_init) logd_z = logdensityof(MeasureBase.pullback(g, target), z) - W = _weight_type(proposal.weighting) + W = mcmc_weight_type(samplingalg.sample_weighting) T = typeof(logd_x) info, sample_id_type = _get_sample_id(proposal, Int32(id), cycle, 1, CURRENT_SAMPLE) @@ -74,6 +76,7 @@ function MCMCChainState( target, proposal, g, + samplingalg.sample_weighting, samples, sample_z, MCMCChainStateInfo(id, cycle, false, false), @@ -140,9 +143,8 @@ function DensitySampleVector(chain_state::MCMCChainState) DensitySampleVector(sample_type(chain_state), totalndof(varshape(mcmc_target(chain_state)))) end -# TODO: MD, make into !! -function mcmc_step!!(mcmc_state::MCMCState) +function mcmc_step!!(mcmc_state::MCMCState) # TODO: MD, include sample_z in _cleanup_samples() _cleanup_samples(mcmc_state) diff --git a/src/samplers/mcmc/mcmc_tuning/mcmc_adaptive_mh_tuner.jl b/src/samplers/mcmc/mcmc_tuning/mcmc_adaptive_mh_tuner.jl index 1792e64a6..09badd13e 100644 --- a/src/samplers/mcmc/mcmc_tuning/mcmc_adaptive_mh_tuner.jl +++ b/src/samplers/mcmc/mcmc_tuning/mcmc_adaptive_mh_tuner.jl @@ -1,13 +1,13 @@ # This file is a part of BAT.jl, licensed under the MIT License (MIT). -# ToDo: Add literature references to AdaptiveMHTuning docstring. +# ToDo: Add literature references to AdaptiveAffineTuning docstring. """ - struct AdaptiveMHTuning <: MHProposalDistTuning + struct AdaptiveAffineTuning <: MCMCTransformTuning -Adaptive MCMC tuning strategy for Metropolis-Hastings samplers. +Adaptive cycle-based MCMC tuning strategy. -Adapts the proposal function based on the acceptance ratio and covariance -of the previous samples. +Adapts an affine space transformation based on the acceptance ratio and +covariance of the previous samples. Constructors: @@ -17,20 +17,20 @@ Fields: $(TYPEDFIELDS) """ -@with_kw struct AdaptiveMHTuning <: MHProposalDistTuning +@with_kw struct AdaptiveAffineTuning <: MCMCTransformTuning "Controls the weight given to new covariance information in adapting the - proposal distribution." + affine transform." λ::Float64 = 0.5 "Metropolis-Hastings acceptance ratio target, tuning will try to adapt - the proposal distribution to bring the acceptance ratio inside this interval." + the affine transform to bring the acceptance ratio inside this interval." α::IntervalSets.ClosedInterval{Float64} = ClosedInterval(0.15, 0.35) - "Controls how much the spread of the proposal distribution is + "Controls how much the scale of the affine transform is widened/narrowed depending on the current MH acceptance ratio." β::Float64 = 1.5 - "Interval for allowed scale/spread of the proposal distribution." + "Interval for allowed scale of the affine transform distribution." c::IntervalSets.ClosedInterval{Float64} = ClosedInterval(1e-4, 1e2) "Reweighting factor. 
Take accumulated sample statistics of previous @@ -39,42 +39,32 @@ $(TYPEDFIELDS) r::Real = 0.5 end -export AdaptiveMHTuning +export AdaptiveAffineTuning # TODO: MD, make immutable and use Accessors.jl -mutable struct AdaptiveMHTrafoTunerState{ +mutable struct AdaptiveAffineTuningState{ S<:MCMCBasicStats -} <: MCMCTunerState - tuning::AdaptiveMHTuning +} <: MCMCTransformTunerState + tuning::AdaptiveAffineTuning stats::S iteration::Int scale::Float64 end -struct AdaptiveMHProposalTunerState <: MCMCTunerState end -(tuning::AdaptiveMHTuning)(chain_state::MCMCChainState) = AdaptiveMHTrafoTunerState(tuning, chain_state), AdaptiveMHProposalTunerState() - -# TODO: MD, what should the default be? -default_adaptive_transform(tuning::AdaptiveMHTuning) = TriangularAffineTransform() - -function AdaptiveMHTrafoTunerState(tuning::AdaptiveMHTuning, chain_state::MCMCChainState) +function AdaptiveAffineTuningState(tuning::AdaptiveAffineTuning, chain_state::MCMCChainState) m = totalndof(varshape(mcmc_target(chain_state))) scale = 2.38^2 / m - AdaptiveMHTrafoTunerState(tuning, MCMCBasicStats(chain_state), 1, scale) + AdaptiveAffineTuningState(tuning, MCMCBasicStats(chain_state), 1, scale) end -AdaptiveMHProposalTunerState(tuning::AdaptiveMHTuning, chain_state::MCMCChainState) = AdaptiveMHProposalTunerState() - - -create_trafo_tuner_state(tuning::AdaptiveMHTuning, chain_state::MCMCChainState, iteration::Integer) = AdaptiveMHTrafoTunerState(tuning, chain_state) - -create_proposal_tuner_state(tuning::AdaptiveMHTuning, chain_state::MCMCChainState, iteration::Integer) = AdaptiveMHProposalTunerState() +create_trafo_tuner_state(tuning::AdaptiveAffineTuning, chain_state::MCMCChainState, iteration::Integer) = AdaptiveAffineTuningState(tuning, chain_state) -function mcmc_tuning_init!!(tuner_state::AdaptiveMHTrafoTunerState, chain_state::MCMCChainState, max_nsteps::Integer) +function mcmc_tuning_init!!(tuner_state::AdaptiveAffineTuningState, chain_state::MCMCChainState, max_nsteps::Integer) n = totalndof(varshape(mcmc_target(chain_state))) + b = chain_state.f_transform.b proposaldist = chain_state.proposal.proposaldist Σ_unscaled = _approx_cov(proposaldist, n) @@ -82,35 +72,31 @@ function mcmc_tuning_init!!(tuner_state::AdaptiveMHTrafoTunerState, chain_state: S = cholesky(Σ) - chain_state.f_transform = Mul(S.L) + chain_state.f_transform = MulAdd(S.L, b) nothing end -mcmc_tuning_init!!(tuner_state::AdaptiveMHProposalTunerState, chain_state::MCMCChainState, max_nsteps::Integer) = nothing +mcmc_tuning_reinit!!(tuner_state::AdaptiveAffineTuningState, chain_state::MCMCChainState, max_nsteps::Integer) = nothing -mcmc_tuning_reinit!!(tuner_state::AdaptiveMHTrafoTunerState, chain_state::MCMCChainState, max_nsteps::Integer) = nothing -mcmc_tuning_reinit!!(tuner_state::AdaptiveMHProposalTunerState, chain_state::MCMCChainState, max_nsteps::Integer) = nothing - - -function mcmc_tuning_postinit!!(tuner::AdaptiveMHTrafoTunerState, chain_state::MCMCChainState, samples::DensitySampleVector) +function mcmc_tuning_postinit!!(tuner::AdaptiveAffineTuningState, chain_state::MCMCChainState, samples::DensitySampleVector) # The very first samples of a chain can be very valuable to init tuner # stats, especially if the chain gets stuck early after: stats = tuner.stats append!(stats, samples) end -mcmc_tuning_postinit!!(tuner_state::AdaptiveMHProposalTunerState, chain_state::MCMCChainState, samples::DensitySampleVector) = nothing # TODO: MD, make properly !! 
-function mcmc_tune_post_cycle!!(tuner::AdaptiveMHTrafoTunerState, chain_state::MCMCChainState, samples::DensitySampleVector) +function mcmc_tune_post_cycle!!(tuner::AdaptiveAffineTuningState, chain_state::MCMCChainState, samples::DensitySampleVector) tuning = tuner.tuning stats = tuner.stats stats_reweight_factor = tuning.r reweight_relative!(stats, stats_reweight_factor) append!(stats, samples) + b = chain_state.f_transform.b α_min = minimum(tuning.α) α_max = maximum(tuning.α) @@ -153,7 +139,7 @@ function mcmc_tune_post_cycle!!(tuner::AdaptiveMHTrafoTunerState, chain_state::M Σ_new = new_Σ_unscal * tuner.scale S_new = cholesky(Positive, Σ_new) - chain_state.f_transform = Mul(S_new.L) + chain_state.f_transform = MulAdd(S_new.L, b) tuner.iteration += 1 @@ -161,29 +147,12 @@ function mcmc_tune_post_cycle!!(tuner::AdaptiveMHTrafoTunerState, chain_state::M chain_state, tuner, true end -mcmc_tune_post_cycle!!(tuner::AdaptiveMHProposalTunerState, chain_state::MCMCChainState, samples::DensitySampleVector) = chain_state, tuner, false - - -mcmc_tuning_finalize!!(tuner::AdaptiveMHTrafoTunerState, chain_state::MCMCChainState) = nothing - -mcmc_tuning_finalize!!(tuner::AdaptiveMHProposalTunerState, chain_state::MCMCChainState) = nothing - -tuning_callback(::AdaptiveMHTrafoTunerState) = nop_func - -tuning_callback(::AdaptiveMHProposalTunerState) = nop_func +mcmc_tuning_finalize!!(tuner::AdaptiveAffineTuningState, chain_state::MCMCChainState) = nothing # add a boold to return if the transfom changes function mcmc_tune_post_step!!( - tuner::AdaptiveMHTrafoTunerState, - chain_state::MCMCChainState, - p_accept::Real -) - return chain_state, tuner, false -end - -function mcmc_tune_post_step!!( - tuner::AdaptiveMHProposalTunerState, + tuner::AdaptiveAffineTuningState, chain_state::MCMCChainState, p_accept::Real ) diff --git a/src/samplers/mcmc/mcmc_tuning/mcmc_noop_tuner.jl b/src/samplers/mcmc/mcmc_tuning/mcmc_noop_tuner.jl index 8fbc41495..2d7d90de0 100644 --- a/src/samplers/mcmc/mcmc_tuning/mcmc_noop_tuner.jl +++ b/src/samplers/mcmc/mcmc_tuning/mcmc_noop_tuner.jl @@ -1,48 +1,54 @@ # This file is a part of BAT.jl, licensed under the MIT License (MIT). + """ - MCMCNoOpTuning <: MCMCTuning + NoMCMCTransformTuning <: MCMCTransformTuning -No-op tuning, marks MCMC chain states as tuned without performing any other changes -on them. Useful if chain states are pre-tuned or tuning is an internal part of the -MCMC sampler implementation. +Do not perform any MCMC transform tuning. 
""" -struct MCMCNoOpTuning <: MCMCTuning end -export MCMCNoOpTuning +struct NoMCMCTransformTuning <: MCMCTransformTuning end +export NoMCMCTransformTuning + +struct NoMCMCTransformTuningState <: MCMCTransformTunerState end + + +create_trafo_tuner_state(::NoMCMCTransformTuning, ::MCMCChainState, ::Integer) = NoMCMCTransformTuningState() + +mcmc_tuning_init!!(::NoMCMCTransformTuningState, ::MCMCChainState, ::Integer) = nothing -struct MCMCNoOpTunerState <: MCMCTunerState end +mcmc_tuning_reinit!!(::NoMCMCTransformTuningState, ::MCMCChainState, ::Integer) = nothing -(tuning::MCMCNoOpTuning)(mc_state::MCMCChainState) = MCMCNoOpTunerState(), MCMCNoOpTunerState() +mcmc_tuning_postinit!!(::NoMCMCTransformTuningState, ::MCMCChainState, ::DensitySampleVector) = nothing -default_adaptive_transform(tuning::MCMCNoOpTuning) = nop_func +mcmc_tune_post_cycle!!(tuner::NoMCMCTransformTuningState, chain_state::MCMCChainState, ::DensitySampleVector) = chain_state, tuner, false -function NoOpTunerState(tuning::MCMCNoOpTuning, mc_state::MCMCChainState, iteration::Integer) - MCMCNoOpTunerState() -end +mcmc_tuning_finalize!!(::NoMCMCTransformTuningState, ::MCMCChainState) = nothing -create_trafo_tuner_state(tuning::MCMCNoOpTuning, mc_state::MCMCChainState, iteration::Integer) = MCMCNoOpTunerState() +mcmc_tune_post_step!!(tuner::NoMCMCTransformTuningState, chain_state::MCMCChainState, ::Real) = chain_state, tuner, false -create_proposal_tuner_state(tuning::MCMCNoOpTuning, mc_state::MCMCChainState, iteration::Integer) = MCMCNoOpTunerState() -mcmc_tuning_init!!(tuner_state::MCMCNoOpTunerState, mc_state::MCMCChainState, max_nsteps::Integer) = nothing -mcmc_tuning_reinit!!(tuner::MCMCNoOpTunerState, mc_state::MCMCChainState, max_nsteps::Integer) = nothing +""" + NoMCMCProposalTuning <: MCMCProposalTuning + +Do not perform any MCMC proposal tuning. 
+""" +struct NoMCMCProposalTuning <: MCMCProposalTuning end +export NoMCMCProposalTuning + +struct NoMCMCProposalTunerState <: MCMCProposalTunerState end -mcmc_tuning_postinit!!(tuner::MCMCNoOpTunerState, mc_state::MCMCChainState, samples::DensitySampleVector) = nothing -mcmc_tune_post_cycle!!(tuner::MCMCNoOpTunerState, mc_state::MCMCChainState, samples::DensitySampleVector) = mc_state, tuner, false +create_proposal_tuner_state(::NoMCMCProposalTuning, ::MCMCChainState, ::Integer) = NoMCMCProposalTunerState() -mcmc_tuning_finalize!!(tuner::MCMCNoOpTunerState, mc_state::MCMCChainState) = nothing +mcmc_tuning_init!!(::NoMCMCProposalTunerState, ::MCMCChainState, ::Integer) = nothing -tuning_callback(::MCMCNoOpTuning) = nop_func +mcmc_tuning_reinit!!(::NoMCMCProposalTunerState, ::MCMCChainState, ::Integer) = nothing -tuning_callback(::Nothing) = nop_func +mcmc_tuning_postinit!!(::NoMCMCProposalTunerState, ::MCMCChainState, ::DensitySampleVector) = nothing +mcmc_tune_post_cycle!!(tuner::NoMCMCProposalTunerState, chain_state::MCMCChainState, ::DensitySampleVector) = chain_state, tuner, false -function mcmc_tune_post_step!!(chain_state::MCMCChainState, tuner::MCMCNoOpTunerState, ::Real) - return chain_state, tuner, false -end +mcmc_tuning_finalize!!(::NoMCMCProposalTunerState, ::MCMCChainState) = nothing -function mcmc_tune_post_step!!(chain_state::MCMCChainState, tuner::Nothing, ::Real) - return chain_state, nothing, false -end +mcmc_tune_post_step!!(tuner::NoMCMCProposalTunerState, chain_state::MCMCChainState, ::Real) = chain_state, tuner, false diff --git a/src/samplers/mcmc/mcmc_tuning/mcmc_ram_tuner.jl b/src/samplers/mcmc/mcmc_tuning/mcmc_ram_tuner.jl index fa8567fd8..28bbe57de 100644 --- a/src/samplers/mcmc/mcmc_tuning/mcmc_ram_tuner.jl +++ b/src/samplers/mcmc/mcmc_tuning/mcmc_ram_tuner.jl @@ -1,30 +1,45 @@ # This file is a part of BAT.jl, licensed under the MIT License (MIT). -@with_kw struct RAMTuning <: MCMCTuning - target_acceptance::Float64 = 0.234 #TODO AC: how to pass custom intitial value for cov matrix? +""" + struct RAMTuning <: MCMCTransformTuning + +Tunes MCMC space transformations based on the +[Robust adaptive Metropolis algorithm](https://doi.org/10.1007/s11222-011-9269-5). + +In contrast to the original RAM algorithm, `RAMTuning` does not use the +covariance estimate to change a proposal distribution, but instead +uses it as the basis for an affine transformation. The sampling process is +mathematically equivalent, though. + +Constructors: + +* ```$(FUNCTIONNAME)(; fields...)``` + +Fields: + +$(TYPEDFIELDS) +""" +@with_kw struct RAMTuning <: MCMCTransformTuning + "MCMC target acceptance ratio." + target_acceptance::Float64 = 0.234 + + "Width around `target_acceptance`." + σ_target_acceptance::Float64 = 0.05 + + "Negative adaptation rate exponent." 
gamma::Float64 = 2/3 end export RAMTuning -mutable struct RAMTrafoTunerState <: MCMCTunerState +mutable struct RAMTrafoTunerState <: MCMCTransformTunerState tuning::RAMTuning nsteps::Int end -mutable struct RAMProposalTunerState <: MCMCTunerState end - -(tuning::RAMTuning)(mc_state::MCMCChainState) = RAMTrafoTunerState(tuning, 0), RAMProposalTunerState() - -default_adaptive_transform(tuning::RAMTuning) = TriangularAffineTransform() +mutable struct RAMProposalTunerState <: MCMCTransformTunerState end -RAMTrafoTunerState(tuning::RAMTuning) = RAMTrafoTunerState(tuning, 0) -RAMProposalTunerState(tuning::RAMTuning) = RAMProposalTunerState() - -create_trafo_tuner_state(tuning::RAMTuning, chain::MCMCChainState, n_steps_hint::Integer) = RAMTrafoTunerState(tuning, n_steps_hint) - -create_proposal_tuner_state(tuning::RAMTuning, chain::MCMCChainState, n_steps_hint::Integer) = RAMProposalTunerState() +create_trafo_tuner_state(tuning::RAMTuning, chain::MCMCChainState, n_steps_hint::Integer) = RAMTrafoTunerState(tuning, 0) function mcmc_tuning_init!!(tuner_state::RAMTrafoTunerState, chain_state::MCMCChainState, max_nsteps::Integer) chain_state.info = MCMCChainStateInfo(chain_state.info, tuned = false) # TODO ? @@ -32,19 +47,14 @@ function mcmc_tuning_init!!(tuner_state::RAMTrafoTunerState, chain_state::MCMCCh return nothing end -mcmc_tuning_init!!(tuner_state::RAMProposalTunerState, chain_state::MCMCChainState, max_nsteps::Integer) = nothing - mcmc_tuning_reinit!!(tuner_state::RAMTrafoTunerState, chain_state::MCMCChainState, max_nsteps::Integer) = nothing -mcmc_tuning_reinit!!(tuner_state::RAMProposalTunerState, chain_state::MCMCChainState, max_nsteps::Integer) = nothing - mcmc_tuning_postinit!!(tuner::RAMTrafoTunerState, chain::MCMCChainState, samples::DensitySampleVector) = nothing -mcmc_tuning_postinit!!(tuner::RAMProposalTunerState, chain::MCMCChainState, samples::DensitySampleVector) = nothing - function mcmc_tune_post_cycle!!(tuner::RAMTrafoTunerState, chain_state::MCMCChainState, samples::DensitySampleVector) - α_min, α_max = map(op -> op(1, tuner.tuning.σ_target_acceptance), [-,+]) .* tuner.tuning.target_acceptance + α_min = (1 - tuner.tuning.σ_target_acceptance) * tuner.tuning.target_acceptance + α_max = (1 + tuner.tuning.σ_target_acceptance) * tuner.tuning.target_acceptance α = eff_acceptance_ratio(chain_state) max_log_posterior = maximum(samples.logd) @@ -59,16 +69,8 @@ function mcmc_tune_post_cycle!!(tuner::RAMTrafoTunerState, chain_state::MCMCChai return chain_state, tuner, false end -mcmc_tune_post_cycle!!(tuner::RAMProposalTunerState, chain::MCMCChainState, samples::DensitySampleVector) = chain, tuner, false - mcmc_tuning_finalize!!(tuner::RAMTrafoTunerState, chain::MCMCChainState) = nothing -mcmc_tuning_finalize!!(tuner::RAMProposalTunerState, chain::MCMCChainState) = nothing - -tuning_callback(::RAMTrafoTunerState) = nop_func - -tuning_callback(::RAMProposalTunerState) = nop_func - # Return mc_state instead of f_transform function mcmc_tune_post_step!!( tuner_state::RAMTrafoTunerState, @@ -77,6 +79,7 @@ function mcmc_tune_post_step!!( ) @unpack target_acceptance, gamma = tuner_state.tuning @unpack f_transform, sample_z = mc_state + b = f_transform.b n_dims = size(sample_z.v[1], 1) η = min(1, n_dims * tuner_state.nsteps^(-gamma)) @@ -87,7 +90,7 @@ function mcmc_tune_post_step!!( M = s_L * (I + η * (p_accept - target_acceptance) * (u * u') / norm(u)^2 ) * s_L' S = cholesky(Positive, M) - f_transform_new = Mul(S.L) + f_transform_new = MulAdd(S.L, b) tuner_state_new = @set 
tuner_state.nsteps = tuner_state.nsteps + 1 @@ -95,11 +98,3 @@ return mc_state_new, tuner_state_new, true end - -function mcmc_tune_post_step!!( - tuner_state::RAMProposalTunerState, - mc_state::MCMCChainState, - p_accept::Real, -) - return mc_state, tuner_state, false -end diff --git a/src/samplers/mcmc/mcmc_utils.jl b/src/samplers/mcmc/mcmc_utils.jl index 5ded6c374..682a7f635 100644 --- a/src/samplers/mcmc/mcmc_utils.jl +++ b/src/samplers/mcmc/mcmc_utils.jl @@ -40,3 +40,46 @@ end _approx_cov(target::Distribution, n) = _cov_with_fallback(target, n) _approx_cov(target::BATDistMeasure, n) = _cov_with_fallback(Distribution(target), n) _approx_cov(target::AbstractPosteriorMeasure, n) = _approx_cov(getprior(target), n) + + + +function _mean_with_fallback(d::UnivariateDistribution, n::Integer) + rng = _bat_determ_rng() + T = float(eltype(rand(rng, d))) + m = fill(T(NaN), n) + try + m[:] = fill(mean(d),n) + catch err + if err isa MethodError + m[:] = fill(mean(nestedview(rand(rng, d, 10^5))), n) + else + throw(err) + end + end + return m +end + +function _mean_with_fallback(d::TDist, n::Integer) # include arg for desired type of output? + return ones(Float64, n) # technically only for degrees of freedom > 1 +end + + +function _mean_with_fallback(d::MultivariateDistribution, n::Integer) + rng = _bat_determ_rng() + T = float(eltype(rand(rng, d))) + m = fill(T(NaN), n) + try + m[:] = mean(d) + catch err + if err isa MethodError + m[:] = mean(nestedview(rand(rng, d, 10^5))) + else + throw(err) + end + end + return m +end + +_approx_mean(target::Distribution, n) = _mean_with_fallback(target, n) +_approx_mean(target::BATDistMeasure, n) = _mean_with_fallback(Distribution(target), n) +_approx_mean(target::AbstractPosteriorMeasure, n) = _approx_mean(getprior(target), n) diff --git a/src/samplers/mcmc/mcmc_weighting.jl b/src/samplers/mcmc/mcmc_weighting.jl index 57ddcb959..ee40a853c 100644 --- a/src/samplers/mcmc/mcmc_weighting.jl +++ b/src/samplers/mcmc/mcmc_weighting.jl @@ -21,7 +21,7 @@ sample_weight_type(::Type{<:AbstractMCMCWeightingScheme{T}}) where {T} = T Sample weighting scheme suitable for sampling algorithms which may repeated samples multiple times in direct succession (e.g. -[`MetropolisHastings`](@ref)). The repeated sample is stored only once, +[`RandomWalk`](@ref)). The repeated sample is stored only once, with a weight equal to the number of times it has been repeated (e.g. because a Markov chain has not moved during a sampling step). @@ -34,7 +34,20 @@ export RepetitionWeighting RepetitionWeighting() = RepetitionWeighting{Int}() -_weight_type(::RepetitionWeighting) = Int +mcmc_weight_type(::RepetitionWeighting) = Int + +function mcmc_weight_values( + ::RepetitionWeighting, + p_accept::Real, + accepted::Bool +) + if accepted + (0, 1) + else + (1, 0) + end +end + """ ARPWeighting{T<:AbstractFloat} <: AbstractMCMCWeightingScheme{T} @@ -42,7 +55,7 @@ _weight_type(::RepetitionWeighting) = Int *Experimental feature, not part of stable public API.* Sample weighting scheme suitable for accept/reject-based sampling algorithms -(e.g. [`MetropolisHastings`](@ref)). Both accepted and rejected samples +(e.g. [`RandomWalk`](@ref)). Both accepted and rejected samples become part of the output, with a weight proportional to their original acceptance probability.
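As a brief aside on the RAM transform tuning implemented in `mcmc_ram_tuner.jl` above: the update that `mcmc_tune_post_step!!` applies for `RAMTrafoTunerState` can be summarized in a short standalone sketch. The function name `ram_update` and the example values below are illustrative only and not part of BAT's API; the actual implementation additionally re-bases the affine transform via `MulAdd(S.L, b)` and uses `cholesky(Positive, M)` for a numerically robust factorization.

```julia
using LinearAlgebra

# Illustrative sketch of one RAM-style adaptation step (not BAT API):
# `L` is the lower-triangular factor of the current transform, `u` the proposal
# step in transformed space, `p_accept` the acceptance probability of that step,
# and `nsteps` the number of tuning steps taken so far.
function ram_update(L::AbstractMatrix, u::AbstractVector, p_accept::Real, nsteps::Integer;
                    target_acceptance::Real = 0.234, gamma::Real = 2/3)
    n_dims = length(u)
    η = min(1, n_dims * nsteps^(-gamma))  # adaptation rate, decaying with the step count
    # Rank-one correction of M = L*L', nudging the acceptance rate towards the target:
    M = L * (I + η * (p_accept - target_acceptance) * (u * u') / norm(u)^2) * L'
    # A plain Cholesky factorization suffices here, since the update keeps M positive definite.
    return cholesky(Symmetric(M)).L
end

# Example: adapt a 2-dimensional identity transform after a mostly rejected proposal.
L_new = ram_update(Matrix{Float64}(I, 2, 2), [0.5, -0.2], 0.1, 10)
```

The resulting lower-triangular factor becomes the linear part of the chain's affine transform, so scale and correlation adaptation happens in the transformation rather than in the proposal distribution itself, as described in the `RAMTuning` docstring above.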
@@ -55,4 +68,19 @@ export ARPWeighting ARPWeighting() = ARPWeighting{Float64}() -_weight_type(::ARPWeighting) = Float64 +mcmc_weight_type(::ARPWeighting) = Float64 + +function mcmc_weight_values( + scheme::ARPWeighting, + p_accept::Real, + accepted::Bool +) + T = typeof(p_accept) + if p_accept ≈ 1 + (zero(T), one(T)) + elseif p_accept ≈ 0 + (one(T), zero(T)) + else + (T(1 - p_accept), p_accept) + end +end diff --git a/src/samplers/mcmc/mh_sampler.jl b/src/samplers/mcmc/mh_sampler.jl index 530d1cb00..b6acb0ef4 100644 --- a/src/samplers/mcmc/mh_sampler.jl +++ b/src/samplers/mcmc/mh_sampler.jl @@ -1,17 +1,8 @@ # This file is a part of BAT.jl, licensed under the MIT License (MIT). -""" - abstract type MHProposalDistTuning - -Abstract type for Metropolis-Hastings tuning strategies for -proposal distributions. -""" -abstract type MHProposalDistTuning <: MCMCTuning end -export MHProposalDistTuning - """ - struct MetropolisHastings <: MCMCAlgorithm + struct RandomWalk <: MCMCAlgorithm Metropolis-Hastings MCMC sampling algorithm. @@ -23,51 +14,44 @@ Fields: $(TYPEDFIELDS) """ -@with_kw struct MetropolisHastings{ - Q<:ContinuousDistribution, - WS<:AbstractMCMCWeightingScheme, -} <: MCMCProposal +@with_kw struct RandomWalk{Q<:ContinuousUnivariateDistribution} <: MCMCProposal proposaldist::Q = TDist(1.0) - weighting::WS = RepetitionWeighting() end -export MetropolisHastings +export RandomWalk -mutable struct MHProposalState{ - Q<:ContinuousDistribution, - WS<:AbstractMCMCWeightingScheme, -} <: MCMCProposalState +struct MHProposalState{Q<:ContinuousUnivariateDistribution} <: MCMCProposalState proposaldist::Q - weighting::WS end export MHProposalState +bat_default(::Type{TransformedMCMC}, ::Val{:pre_transform}, proposal::RandomWalk) = PriorToGaussian() -bat_default(::Type{MCMCSampling}, ::Val{:pre_transform}, proposal::MetropolisHastings) = PriorToGaussian() +bat_default(::Type{TransformedMCMC}, ::Val{:proposal_tuning}, proposal::RandomWalk) = NoMCMCProposalTuning() -bat_default(::Type{MCMCSampling}, ::Val{:trafo_tuning}, proposal::MetropolisHastings) = RAMTuning() +bat_default(::Type{TransformedMCMC}, ::Val{:transform_tuning}, proposal::RandomWalk) = RAMTuning() -bat_default(::Type{MCMCSampling}, ::Val{:adaptive_transform}, proposal::MetropolisHastings) = TriangularAffineTransform() +bat_default(::Type{TransformedMCMC}, ::Val{:adaptive_transform}, proposal::RandomWalk) = TriangularAffineTransform() -bat_default(::Type{MCMCSampling}, ::Val{:tempering}, proposal::MetropolisHastings) = NoMCMCTempering() +bat_default(::Type{TransformedMCMC}, ::Val{:tempering}, proposal::RandomWalk) = NoMCMCTempering() -bat_default(::Type{MCMCSampling}, ::Val{:nsteps}, proposal::MetropolisHastings, pre_transform::AbstractTransformTarget, nchains::Integer) = 10^5 +bat_default(::Type{TransformedMCMC}, ::Val{:nsteps}, proposal::RandomWalk, pre_transform::AbstractTransformTarget, nchains::Integer) = 10^5 -bat_default(::Type{MCMCSampling}, ::Val{:init}, proposal::MetropolisHastings, pre_transform::AbstractTransformTarget, nchains::Integer, nsteps::Integer) = +bat_default(::Type{TransformedMCMC}, ::Val{:init}, proposal::RandomWalk, pre_transform::AbstractTransformTarget, nchains::Integer, nsteps::Integer) = MCMCChainPoolInit(nsteps_init = max(div(nsteps, 100), 250)) -bat_default(::Type{MCMCSampling}, ::Val{:burnin}, proposal::MetropolisHastings, pre_transform::AbstractTransformTarget, nchains::Integer, nsteps::Integer) = +bat_default(::Type{TransformedMCMC}, ::Val{:burnin}, proposal::RandomWalk, 
pre_transform::AbstractTransformTarget, nchains::Integer, nsteps::Integer) = MCMCMultiCycleBurnin(nsteps_per_cycle = max(div(nsteps, 10), 2500)) function _create_proposal_state( - proposal::MetropolisHastings, + proposal::RandomWalk, target::BATMeasure, context::BATContext, v_init::AbstractVector{<:Real}, rng::AbstractRNG ) - return MHProposalState(proposal.proposaldist, proposal.weighting) + return MHProposalState(proposal.proposaldist) end @@ -76,14 +60,7 @@ function _get_sample_id(proposal::MHProposalState, id::Int32, cycle::Int32, step end -const MHChainState = MCMCChainState{<:BATMeasure, - <:RNGPartition, - <:Function, - <:MHProposalState, - <:DensitySampleVector, - <:DensitySampleVector, - <:BATContext -} +const MHChainState = MCMCChainState{<:BATMeasure, <:RNGPartition, <:Function, <:MHProposalState} function mcmc_propose!!(mc_state::MHChainState) @unpack target, proposal, f_transform, context = mc_state @@ -117,6 +94,7 @@ function mcmc_propose!!(mc_state::MHChainState) return mc_state, accepted, p_accept end + function _accept_reject!(mc_state::MHChainState, accepted::Bool, p_accept::Float64, current::Integer, proposed::Integer) @unpack samples, proposal = mc_state @@ -131,36 +109,10 @@ function _accept_reject!(mc_state::MHChainState, accepted::Bool, p_accept::Float samples.info.sampletype[proposed] = REJECTED_SAMPLE end - delta_w_current, w_proposed = _weights(proposal, p_accept, accepted) + delta_w_current, w_proposed = mcmc_weight_values(mc_state.weighting, p_accept, accepted) samples.weight[current] += delta_w_current samples.weight[proposed] = w_proposed end -function _weights( - proposal::MHProposalState{Q,<:RepetitionWeighting}, - p_accept::Real, - accepted::Bool -) where Q - if accepted - (0, 1) - else - (1, 0) - end -end - -function _weights( - proposal::MHProposalState{Q,<:ARPWeighting}, - p_accept::Real, - accepted::Bool -) where Q - T = typeof(p_accept) - if p_accept ≈ 1 - (zero(T), one(T)) - elseif p_accept ≈ 0 - (one(T), zero(T)) - else - (T(1 - p_accept), p_accept) - end -end eff_acceptance_ratio(mc_state::MHChainState) = nsamples(mc_state) / nsteps(mc_state) diff --git a/src/samplers/mcmc/multi_cycle_burnin.jl b/src/samplers/mcmc/multi_cycle_burnin.jl index 978a2dcbe..fffe87a6f 100644 --- a/src/samplers/mcmc/multi_cycle_burnin.jl +++ b/src/samplers/mcmc/multi_cycle_burnin.jl @@ -26,7 +26,7 @@ export MCMCMultiCycleBurnin function mcmc_burnin!( outputs::Union{AbstractVector{<:DensitySampleVector},Nothing}, mcmc_states::AbstractVector{<:MCMCState}, - samplingalg::MCMCSampling, + samplingalg::TransformedMCMC, callback::Function ) nchains = length(mcmc_states) diff --git a/src/samplers/samplers.jl b/src/samplers/samplers.jl index de7264c18..eda5f4ab7 100644 --- a/src/samplers/samplers.jl +++ b/src/samplers/samplers.jl @@ -1,5 +1,6 @@ # This file is a part of BAT.jl, licensed under the MIT License (MIT). 
+include("adaptive_transform.jl") include("bat_sample.jl") include("mcmc/mcmc.jl") include("evaluated_measure.jl") diff --git a/src/transforms/transforms.jl b/src/transforms/transforms.jl index fdd814e9a..6c19878a2 100644 --- a/src/transforms/transforms.jl +++ b/src/transforms/transforms.jl @@ -2,4 +2,3 @@ include("trafo_utils.jl") include("distribution_transform.jl") -include("adaptive_transform.jl") diff --git a/test/distributions/test_hierarchical_distribution.jl b/test/distributions/test_hierarchical_distribution.jl index a317a32ce..103562014 100644 --- a/test/distributions/test_hierarchical_distribution.jl +++ b/test/distributions/test_hierarchical_distribution.jl @@ -45,7 +45,7 @@ import AdvancedHMC @test @inferred(logpdf(hd, varshape(hd)(ux))) == logpdf(ud, ux) @test @inferred(logpdf(hd, varshape(hd)(ux))) == logpdf(ud, ux) - samples = bat_sample(hd, MCMCSampling(mcalg = HamiltonianMC(), trafo = DoNotTransform(), nsteps = 10^4), context).result + samples = bat_sample(hd, TransformedMCMC(mcalg = HamiltonianMC(), trafo = DoNotTransform(), nsteps = 10^4), context).result @test isapprox(cov(unshaped.(samples)), cov(ud), rtol = 0.25) end diff --git a/test/integration/test_brigde_sampling_integration.jl b/test/integration/test_brigde_sampling_integration.jl index 80a003827..f3be56c1b 100644 --- a/test/integration/test_brigde_sampling_integration.jl +++ b/test/integration/test_brigde_sampling_integration.jl @@ -15,7 +15,7 @@ using LinearAlgebra: Diagonal, ones dist::Distribution; val_expected::Real=1.0, val_rtol::Real=3.5, err_max::Real=0.2) @testset "$title" begin - samplingalg = MCMCSampling( + samplingalg = TransformedMCMC( pre_transform = DoNotTransform(), nsteps = 2*10^5, burnin = MCMCMultiCycleBurnin(nsteps_per_cycle = 10^5, max_ncycles = 60) diff --git a/test/io/test_hdf5.jl b/test/io/test_hdf5.jl index 686a39350..5c3262a22 100644 --- a/test/io/test_hdf5.jl +++ b/test/io/test_hdf5.jl @@ -11,7 +11,7 @@ if Int == Int64 @testset "hdf5" begin mktempdir() do tmp_datadir results_filename = joinpath(tmp_datadir, "results.hdf5") - samples = bat_sample(BAT.example_posterior(), MCMCSampling(nsteps = 1000, strict = false)).result + samples = bat_sample(BAT.example_posterior(), TransformedMCMC(nsteps = 1000, strict = false)).result bat_write(results_filename, samples) samples2 = bat_read(results_filename).result @test samples == samples2 diff --git a/test/measures/test_bat_pushfwd_measure.jl b/test/measures/test_bat_pushfwd_measure.jl index d27d623a2..0a065becb 100644 --- a/test/measures/test_bat_pushfwd_measure.jl +++ b/test/measures/test_bat_pushfwd_measure.jl @@ -87,7 +87,7 @@ using Optim @test isfinite(@inferred logdensityof(m)(@inferred(bat_initval(m, context)).result)) @test isapprox(cov(@inferred(bat_initval(m, 10^4, context)).result), I(totalndof(varshape(m))), rtol = 0.1) - samples_is = bat_sample(m, MCMCSampling(mcalg = HamiltonianMC(), trafo = DoNotTransform(), nsteps = 10^4), context).result + samples_is = bat_sample(m, TransformedMCMC(mcalg = HamiltonianMC(), trafo = DoNotTransform(), nsteps = 10^4), context).result @test isapprox(cov(samples_is), I(totalndof(varshape(m))), rtol = 0.1) samples_os = inverse(trafo).(samples_is) @test all(isfinite, logpdf.(Ref(src_d), samples_os.v)) @@ -99,7 +99,7 @@ using Optim prior = HierarchicalDistribution(f_secondary, primary_dist) likelihood = logfuncdensity(logdensityof(varshape(prior)(MvNormal(Diagonal(fill(1.0, totalndof(varshape(prior)))))))) m = PosteriorMeasure(likelihood, prior) - hmc_samples = bat_sample(m, MCMCSampling(mcalg = 
HamiltonianMC(), trafo = PriorToGaussian(), nsteps = 10^4), context).result + hmc_samples = bat_sample(m, TransformedMCMC(mcalg = HamiltonianMC(), trafo = PriorToGaussian(), nsteps = 10^4), context).result is_samples = bat_sample(m, PriorImportanceSampler(nsamples = 10^4), context).result @test isapprox(mean(unshaped.(hmc_samples)), mean(unshaped.(is_samples)), rtol = 0.1) @test isapprox(cov(unshaped.(hmc_samples)), cov(unshaped.(is_samples)), rtol = 0.2) diff --git a/test/measures/test_truncate_batmeasure.jl b/test/measures/test_truncate_batmeasure.jl index 8071ea544..685392aca 100644 --- a/test/measures/test_truncate_batmeasure.jl +++ b/test/measures/test_truncate_batmeasure.jl @@ -66,7 +66,7 @@ using ArraysOfArrays, Distributions, StatsBase, IntervalSets let trunc_prior_dist = parent(BAT.getprior(trunc_pstr)).dist - s = bat_sample(trunc_pstr, MCMCSampling(mcalg = MetropolisHastings(), trafo = DoNotTransform(), nsteps = 10^5)).result + s = bat_sample(trunc_pstr, TransformedMCMC(mcalg = RandomWalk(), trafo = DoNotTransform(), nsteps = 10^5)).result s_flat = flatview(unshaped.(s)) @test all(minimum.(bounds) .< minimum(s_flat)) @test all(maximum.(bounds) .> maximum(s_flat)) diff --git a/test/plotting/test_BATHistogram.jl b/test/plotting/test_BATHistogram.jl index 1ff73a3e7..a9bdf0bd6 100644 --- a/test/plotting/test_BATHistogram.jl +++ b/test/plotting/test_BATHistogram.jl @@ -31,7 +31,7 @@ prior = BAT.NamedTupleDist( ) posterior = PosteriorMeasure(likelihood, prior); -algorithm = MCMCSampling(mcalg = MetropolisHastings(), nchains = 4, nsteps = 10^4) +algorithm = TransformedMCMC(mcalg = RandomWalk(), nchains = 4, nsteps = 10^4) shaped_samples = bat_sample(posterior, algorithm).result unshaped_samples = BAT.unshaped.(shaped_samples) diff --git a/test/samplers/mcmc/test_hmc.jl b/test/samplers/mcmc/test_hmc.jl index 3d5baf622..403567c90 100644 --- a/test/samplers/mcmc/test_hmc.jl +++ b/test/samplers/mcmc/test_hmc.jl @@ -20,9 +20,9 @@ import AdvancedHMC @test target isa BAT.BATDistMeasure proposal = HamiltonianMC() - tuning = StanHMCTuning() + proposal_tuning = StanHMCTuning() nchains = 4 - samplingalg = MCMCSampling(proposal = proposal, trafo_tuning = tuning) + samplingalg = TransformedMCMC(proposal = proposal, proposal_tuning = proposal_tuning, nchains = nchains) @testset "MCMC iteration" begin v_init = bat_initval(target, InitFromTarget(), context).result @@ -35,9 +35,6 @@ import AdvancedHMC BAT.mcmc_tuning_init!!(mcmc_state, 0) BAT.mcmc_tuning_reinit!!(mcmc_state, div(nsteps, 10)) - samplingalg = BAT.MCMCSampling(proposal = proposal, trafo_tuning = tuning, nchains = nchains) - - samples = DensitySampleVector(mcmc_state) mcmc_state = BAT.mcmc_iterate!!(samples, mcmc_state; max_nsteps = nsteps, nonzero_weights = false) @test mcmc_state.chain_state.stepno == nsteps @@ -53,23 +50,23 @@ import AdvancedHMC @testset "MCMC tuning and burn-in" begin max_nsteps = 10^5 - tuning_alg = BAT.StanHMCTuning() + proposal_tuning = BAT.StanHMCTuning() trafo = DoNotTransform() - init_alg = bat_default(MCMCSampling, Val(:init), proposal, trafo, nchains, max_nsteps) - burnin_alg = bat_default(MCMCSampling, Val(:burnin), proposal, trafo, nchains, max_nsteps) + init_alg = bat_default(TransformedMCMC, Val(:init), proposal, trafo, nchains, max_nsteps) + burnin_alg = bat_default(TransformedMCMC, Val(:burnin), proposal, trafo, nchains, max_nsteps) convergence_test = BrooksGelmanConvergence() strict = true nonzero_weights = false callback = (x...) 
-> nothing - samplingalg = MCMCSampling(proposal = proposal, - trafo_tuning = tuning_alg, - pre_transform = trafo, - init = init_alg, - burnin = burnin_alg, - convergence = convergence_test, - strict = strict, - nonzero_weights = nonzero_weights + samplingalg = TransformedMCMC(proposal = proposal, + proposal_tuning = proposal_tuning, + pre_transform = trafo, + init = init_alg, + burnin = burnin_alg, + convergence = convergence_test, + strict = strict, + nonzero_weights = nonzero_weights ) # Note: No @inferred, not type stable (yet) with HamiltonianMC @@ -110,9 +107,9 @@ import AdvancedHMC @testset "bat_sample" begin samples = bat_sample( shaped_target, - MCMCSampling( + TransformedMCMC( proposal = proposal, - trafo_tuning = StanHMCTuning(), + proposal_tuning = StanHMCTuning(), pre_transform = DoNotTransform(), nsteps = 10^4, store_burnin = true @@ -126,9 +123,9 @@ import AdvancedHMC smplres = BAT.sample_and_verify( shaped_target, - MCMCSampling( + TransformedMCMC( proposal = proposal, - trafo_tuning = StanHMCTuning(), + proposal_tuning = StanHMCTuning(), pre_transform = DoNotTransform(), nsteps = 10^4, store_burnin = false @@ -148,7 +145,7 @@ import AdvancedHMC inner_posterior = PosteriorMeasure(likelihood, prior) # Test with nested posteriors: posterior = PosteriorMeasure(likelihood, inner_posterior) - @test BAT.sample_and_verify(posterior, MCMCSampling(proposal = HamiltonianMC(), trafo_tuning = StanHMCTuning(), pre_transform = PriorToGaussian()), prior.dist, context).verified + @test BAT.sample_and_verify(posterior, TransformedMCMC(proposal = HamiltonianMC(), proposal_tuning = StanHMCTuning(), pre_transform = PriorToGaussian()), prior.dist, context).verified end @testset "HMC autodiff" begin @@ -158,9 +155,9 @@ import AdvancedHMC @testset "$adsel" begin context = BATContext(ad = adsel) - hmc_samplingalg = MCMCSampling( + hmc_samplingalg = TransformedMCMC( proposal = HamiltonianMC(), - trafo_tuning = StanHMCTuning(), + proposal_tuning = StanHMCTuning(), nchains = 2, nsteps = 100, init = MCMCChainPoolInit(init_tries_per_chain = 2..2, nsteps_init = 5), diff --git a/test/samplers/mcmc/test_mcmc_sample.jl b/test/samplers/mcmc/test_mcmc_sample.jl index f3e69b245..9a3b12d8c 100644 --- a/test/samplers/mcmc/test_mcmc_sample.jl +++ b/test/samplers/mcmc/test_mcmc_sample.jl @@ -18,15 +18,14 @@ using DensityInterface nchains = 4 nsteps = 10^4 - samplingalg_MW = @inferred(MCMCSampling(pre_transform = DoNotTransform(), nchains = nchains, nsteps = nsteps)) + samplingalg_MW = @inferred(TransformedMCMC(pre_transform = DoNotTransform(), nchains = nchains, nsteps = nsteps)) smplres = BAT.sample_and_verify(PosteriorMeasure(likelihood, prior), samplingalg_MW, mv_dist) samples = smplres.result @test smplres.verified @test (nchains * nsteps - sum(samples.weight)) < 100 - - samplingalg_PW = @inferred MCMCSampling(proposal = MetropolisHastings(weighting = ARPWeighting()), pre_transform = DoNotTransform(), nsteps = 10^5) + samplingalg_PW = @inferred TransformedMCMC(proposal = RandomWalk(), pre_transform = DoNotTransform(), nsteps = 10^5, sample_weighting = ARPWeighting()) @test BAT.sample_and_verify(mv_dist, samplingalg_PW).verified @@ -36,5 +35,5 @@ using DensityInterface @test gensamples(context) != gensamples(context) @test gensamples(deepcopy(context)) == gensamples(deepcopy(context)) - @test BAT.sample_and_verify(Normal(), MCMCSampling(pre_transform = DoNotTransform(), nsteps = 10^4)).verified + @test BAT.sample_and_verify(Normal(), TransformedMCMC(pre_transform = DoNotTransform(), nsteps = 10^4)).verified 
end diff --git a/test/samplers/mcmc/test_mh.jl b/test/samplers/mcmc/test_mh.jl index e15eca024..a23017722 100644 --- a/test/samplers/mcmc/test_mh.jl +++ b/test/samplers/mcmc/test_mh.jl @@ -5,7 +5,7 @@ using Test using LinearAlgebra using StatsBase, Distributions, StatsBase, ValueShapes, ArraysOfArrays, DensityInterface -@testset "MetropolisHastings" begin +@testset "RandomWalk" begin context = BATContext() objective = NamedTupleDist(a = Normal(1, 1.5), b = MvNormal([-1.0, 2.0], [2.0 1.5; 1.5 3.0])) @@ -14,10 +14,10 @@ using StatsBase, Distributions, StatsBase, ValueShapes, ArraysOfArrays, DensityI target = unshaped(shaped_target) @test target isa BAT.BATDistMeasure - proposal = MetropolisHastings() + proposal = RandomWalk() nchains = 4 - samplingalg = MCMCSampling() + samplingalg = TransformedMCMC() @testset "MCMC iteration" begin v_init = bat_initval(target, InitFromTarget(), context).result @@ -41,7 +41,7 @@ using StatsBase, Distributions, StatsBase, ValueShapes, ArraysOfArrays, DensityI @testset "MCMC tuning and burn-in" begin init_alg = MCMCChainPoolInit() - tuning_alg = AdaptiveMHTuning() + tuning_alg = AdaptiveAffineTuning() burnin_alg = MCMCMultiCycleBurnin() convergence_test = BrooksGelmanConvergence() strict = true @@ -49,9 +49,9 @@ using StatsBase, Distributions, StatsBase, ValueShapes, ArraysOfArrays, DensityI callback = (x...) -> nothing max_nsteps = 10^5 - samplingalg = MCMCSampling( + samplingalg = TransformedMCMC( proposal = proposal, - trafo_tuning = tuning_alg, + transform_tuning = tuning_alg, burnin = burnin_alg, nchains = nchains, convergence = convergence_test, @@ -71,7 +71,7 @@ using StatsBase, Distributions, StatsBase, ValueShapes, ArraysOfArrays, DensityI # TODO: MD, Reactivate, for some reason fail # @test mcmc_states isa AbstractVector{<:BAT.MHChainState} - # @test tuners isa AbstractVector{<:BAT.AdaptiveMHTrafoTunerState} + # @test tuners isa AbstractVector{<:BAT.AdaptiveAffineTuningState} @test outputs isa AbstractVector{<:DensitySampleVector} BAT.mcmc_burnin!( @@ -98,7 +98,7 @@ using StatsBase, Distributions, StatsBase, ValueShapes, ArraysOfArrays, DensityI @testset "bat_sample" begin samples = bat_sample( shaped_target, - MCMCSampling( + TransformedMCMC( proposal = proposal, pre_transform = DoNotTransform(), store_burnin = true @@ -110,7 +110,7 @@ using StatsBase, Distributions, StatsBase, ValueShapes, ArraysOfArrays, DensityI smplres = BAT.sample_and_verify( shaped_target, - MCMCSampling( + TransformedMCMC( proposal = proposal, pre_transform = DoNotTransform() ), @@ -128,6 +128,6 @@ using StatsBase, Distributions, StatsBase, ValueShapes, ArraysOfArrays, DensityI inner_posterior = PosteriorMeasure(likelihood, prior) # Test with nested posteriors: posterior = PosteriorMeasure(likelihood, inner_posterior) - @test BAT.sample_and_verify(posterior, MCMCSampling(proposal = MetropolisHastings(), pre_transform = PriorToGaussian()), prior.dist).verified + @test BAT.sample_and_verify(posterior, TransformedMCMC(proposal = RandomWalk(), pre_transform = PriorToGaussian()), prior.dist).verified end end
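Taken together, the test changes above exercise the renamed sampling API end to end. As a closing illustration, here is a minimal usage sketch; the target comes from `BAT.example_posterior()` (the toy posterior also used in the HDF5 test above), the keyword names follow the updated tests, and several of them simply restate the defaults:

```julia
using BAT

# Minimal sketch of the renamed API (assumptions: default-like settings, toy target).
samples = bat_sample(
    BAT.example_posterior(),
    TransformedMCMC(
        proposal = RandomWalk(),          # formerly MetropolisHastings()
        transform_tuning = RAMTuning(),   # tunes the space transformation
        sample_weighting = RepetitionWeighting(),
        nchains = 4,
        nsteps = 10^5
    )
).result
```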