Skip to content

Commit

Permalink
Rename PriorToGaussian to PriorToNormal
Browse files Browse the repository at this point in the history
  • Loading branch information
oschulz committed Oct 23, 2024
1 parent 5e9e7dc commit ab202a2
Show file tree
Hide file tree
Showing 17 changed files with 34 additions and 29 deletions.
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ Several algorithms have changed their names, but also their role:
changed (no deprecation for the parameter changes). Tuning and
sample weighting scheme selection have moved to `TransformedMCMC`.

* `PriorToGaussian` has become `PriorToNormal`.

Partial deprecations are available for the above, a lot of old code should
run more or less unchanged (with deprecation warnings). Also:

Expand Down
2 changes: 1 addition & 1 deletion docs/src/stable_api.md
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ OptimizationAlg
OrderedResampling
PosteriorMeasure
PriorSubstitution
PriorToGaussian
PriorToNormal
PriorToUniform
RAMTuning
RandomWalk
Expand Down
2 changes: 1 addition & 1 deletion examples/dev-internal/test_findmode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ my_function(; nt...)

context = get_batcontext()
target = posterior
transformed_density, f_transform = BAT.transform_and_unshape(PriorToGaussian(), target, context)
transformed_density, f_transform = BAT.transform_and_unshape(PriorToNormal(), target, context)
inv_trafo = inverse(f_transform)
initalg = BAT.apply_trafo_to_init(f_transform, InitFromTarget())
x_init = collect(bat_initval(transformed_density, initalg, context).result)
Expand Down
2 changes: 1 addition & 1 deletion ext/ahmc_impl/ahmc_sampler_impl.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# This file is a part of BAT.jl, licensed under the MIT License (MIT).


BAT.bat_default(::Type{TransformedMCMC}, ::Val{:pretransform}, proposal::HamiltonianMC) = PriorToGaussian()
BAT.bat_default(::Type{TransformedMCMC}, ::Val{:pretransform}, proposal::HamiltonianMC) = PriorToNormal()

BAT.bat_default(::Type{TransformedMCMC}, ::Val{:proposal_tuning}, proposal::HamiltonianMC) = StanHMCTuning()

Expand Down
8 changes: 4 additions & 4 deletions src/algodefaults/default_transform_algorithm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@ bat_default(::typeof(bat_transform), ::Val{:algorithm}, ::PriorToUniform, ::BATD
bat_default(::typeof(bat_transform), ::Val{:algorithm}, ::PriorToUniform, ::EvaluatedMeasure) = PriorSubstitution()
bat_default(::typeof(bat_transform), ::Val{:algorithm}, ::PriorToUniform, ::BATDistMeasure{<:StandardUniformDist}) = IdentityTransformAlgorithm()

bat_default(::typeof(bat_transform), ::Val{:algorithm}, ::PriorToGaussian, ::AbstractPosteriorMeasure) = PriorSubstitution()
bat_default(::typeof(bat_transform), ::Val{:algorithm}, ::PriorToGaussian, ::BATDistMeasure) = PriorSubstitution()
bat_default(::typeof(bat_transform), ::Val{:algorithm}, ::PriorToGaussian, ::EvaluatedMeasure) = PriorSubstitution()
bat_default(::typeof(bat_transform), ::Val{:algorithm}, ::PriorToGaussian, ::BATDistMeasure{<:StandardNormalDist}) = IdentityTransformAlgorithm()
bat_default(::typeof(bat_transform), ::Val{:algorithm}, ::PriorToNormal, ::AbstractPosteriorMeasure) = PriorSubstitution()
bat_default(::typeof(bat_transform), ::Val{:algorithm}, ::PriorToNormal, ::BATDistMeasure) = PriorSubstitution()
bat_default(::typeof(bat_transform), ::Val{:algorithm}, ::PriorToNormal, ::EvaluatedMeasure) = PriorSubstitution()

Check warning on line 14 in src/algodefaults/default_transform_algorithm.jl

View check run for this annotation

Codecov / codecov/patch

src/algodefaults/default_transform_algorithm.jl#L13-L14

Added lines #L13 - L14 were not covered by tests
bat_default(::typeof(bat_transform), ::Val{:algorithm}, ::PriorToNormal, ::BATDistMeasure{<:StandardNormalDist}) = IdentityTransformAlgorithm()

bat_default(::typeof(bat_transform), ::Val{:algorithm}, ::Function, ::DensitySampleVector) = SampleTransformation()
bat_default(::typeof(bat_transform), ::Val{:algorithm}, ::AbstractValueShape, ::DensitySampleVector) = SampleTransformation()
Expand Down
20 changes: 10 additions & 10 deletions src/algotypes/transform_algorithm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ struct ToRealVector <: AbstractTransformTarget end
export ToRealVector


# ToDo: Merge PriorToUniform and PriorToGaussian into PriorTo{Uniform|Normal}.
# ToDo: Merge PriorToUniform and PriorToNormal into PriorTo{Uniform|Normal}.

"""
struct PriorToUniform <: AbstractTransformTarget
Expand All @@ -168,7 +168,7 @@ end


"""
struct PriorToGaussian <: AbstractTransformTarget
struct PriorToNormal <: AbstractTransformTarget
Specifies that posterior densities should be transformed in a way that makes
their pior equivalent to a standard multivariate normal distribution with an
Expand All @@ -178,12 +178,12 @@ Constructors:
* ```$(FUNCTIONNAME)()```
"""
struct PriorToGaussian <: AbstractTransformTarget end
export PriorToGaussian
struct PriorToNormal <: AbstractTransformTarget end
export PriorToNormal

_distmeasure_trafo(target::PriorToGaussian, density::BATDistMeasure) = DistributionTransform(Normal, Distribution(density))
_distmeasure_trafo(target::PriorToNormal, density::BATDistMeasure) = DistributionTransform(Normal, Distribution(density))

function bat_transform_impl(target::PriorToGaussian, density::BATDistMeasure{<:StandardNormalDist}, algorithm::IdentityTransformAlgorithm, context::BATContext)
function bat_transform_impl(target::PriorToNormal, density::BATDistMeasure{<:StandardNormalDist}, algorithm::IdentityTransformAlgorithm, context::BATContext)
(result = density, f_transform = identity)
end

Expand All @@ -208,14 +208,14 @@ _get_deep_prior_for_trafo(m::BATDistMeasure) = m
_get_deep_prior_for_trafo(m::AbstractPosteriorMeasure) = _get_deep_prior_for_trafo(getprior(m))


function bat_transform_impl(target::Union{PriorToUniform,PriorToGaussian}, m::AbstractPosteriorMeasure, algorithm::FullMeasureTransform, context::BATContext)
function bat_transform_impl(target::Union{PriorToUniform,PriorToNormal}, m::AbstractPosteriorMeasure, algorithm::FullMeasureTransform, context::BATContext)
orig_prior = _get_deep_prior_for_trafo(m)
f_transform = _distmeasure_trafo(target, orig_prior)
(result = BATPushFwdMeasure(f_transform, m, KeepRootMeasure()), f_transform = f_transform)
end


function bat_transform_impl(target::Union{PriorToUniform,PriorToGaussian}, m::BATDistMeasure, algorithm::FullMeasureTransform, context::BATContext)
function bat_transform_impl(target::Union{PriorToUniform,PriorToNormal}, m::BATDistMeasure, algorithm::FullMeasureTransform, context::BATContext)
f_transform = _distmeasure_trafo(target, m)
(result = BATPushFwdMeasure(f_transform, m, KeepRootMeasure()), f_transform = f_transform)

Check warning on line 220 in src/algotypes/transform_algorithm.jl

View check run for this annotation

Codecov / codecov/patch

src/algotypes/transform_algorithm.jl#L218-L220

Added lines #L218 - L220 were not covered by tests
end
Expand All @@ -238,14 +238,14 @@ struct PriorSubstitution <: TransformAlgorithm end
export PriorSubstitution


function bat_transform_impl(target::Union{PriorToUniform,PriorToGaussian}, density::BATDistMeasure, algorithm::PriorSubstitution, context::BATContext)
function bat_transform_impl(target::Union{PriorToUniform,PriorToNormal}, density::BATDistMeasure, algorithm::PriorSubstitution, context::BATContext)
f_transform = _distmeasure_trafo(target, density)
transformed_density = BATDistMeasure(f_transform.target_dist)
(result = transformed_density, f_transform = f_transform)
end


function bat_transform_impl(target::Union{PriorToUniform,PriorToGaussian}, density::AbstractPosteriorMeasure, algorithm::PriorSubstitution, context::BATContext)
function bat_transform_impl(target::Union{PriorToUniform,PriorToNormal}, density::AbstractPosteriorMeasure, algorithm::PriorSubstitution, context::BATContext)
orig_prior = getprior(density)
orig_likelihood = getlikelihood(density)
new_prior, f_transform = bat_transform_impl(target, orig_prior, algorithm, context)
Expand Down
3 changes: 3 additions & 0 deletions src/deprecations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,6 @@ Base.@deprecate MCMCSampling(;
callback = callback
)
export MCMCSampling


@deprecate PriorToGaussian() PriorToNormal()
2 changes: 1 addition & 1 deletion src/extdefs/mgvi_defs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ Fields:
TR<:AbstractTransformTarget, IA<:InitvalAlgorithm,
CFG, SD<:MGVISchedule
} <: AbstractSamplingAlgorithm
pretransform::TR = (pkgext(Val(:MGVI)); PriorToGaussian())
pretransform::TR = (pkgext(Val(:MGVI)); PriorToNormal())
init::IA = InitFromTarget()
nsamples::Int = 10^4
schedule::SD = FixedMGVISchedule(range(12, nsamples, length = 10))
Expand Down
2 changes: 1 addition & 1 deletion src/extdefs/optim_defs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ $(TYPEDFIELDS)
IA<:InitvalAlgorithm
} <: AbstractModeEstimator
optalg::ALG = ext_default(pkgext(Val(:Optim)), Val(:DEFAULT_OPTALG))
pretransform::TR = PriorToGaussian()
pretransform::TR = PriorToNormal()
init::IA = InitFromTarget()
maxiters::Int = 1_000
maxtime::Float64 = NaN
Expand Down
2 changes: 1 addition & 1 deletion src/extdefs/optimization_defs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ $(TYPEDFIELDS)
IA<:InitvalAlgorithm
} <: AbstractModeEstimator
optalg::ALG = ext_default(pkgext(Val(:Optimization)), Val(:DEFAULT_OPTALG))
pretransform::TR = PriorToGaussian()
pretransform::TR = PriorToNormal()
init::IA = InitFromTarget()
maxiters::Int64 = 1_000
maxtime::Float64 = NaN
Expand Down
2 changes: 1 addition & 1 deletion src/integration/bridge_sampling_integration.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ Fields:
$(TYPEDFIELDS)
"""
@with_kw struct BridgeSampling{TR<:AbstractTransformTarget,ESS<:EffSampleSizeAlgorithm} <: IntegrationAlgorithm
pretransform::TR = PriorToGaussian()
pretransform::TR = PriorToNormal()
essalg::ESS = EffSampleSizeFromAC()
strict::Bool = true
# ToDo: add argument for proposal density generator
Expand Down
2 changes: 1 addition & 1 deletion src/samplers/mcmc/mcmc_sample.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ function MCMCState(samplingalg::TransformedMCMC, target::BATMeasure, id::Integer
end


bat_default(::TransformedMCMC, ::Val{:pretransform}) = PriorToGaussian()
bat_default(::TransformedMCMC, ::Val{:pretransform}) = PriorToNormal()

Check warning on line 66 in src/samplers/mcmc/mcmc_sample.jl

View check run for this annotation

Codecov / codecov/patch

src/samplers/mcmc/mcmc_sample.jl#L66

Added line #L66 was not covered by tests

bat_default(::TransformedMCMC, ::Val{:nsteps}, ::AbstractTransformTarget, nchains::Integer) = 10^5

Check warning on line 68 in src/samplers/mcmc/mcmc_sample.jl

View check run for this annotation

Codecov / codecov/patch

src/samplers/mcmc/mcmc_sample.jl#L68

Added line #L68 was not covered by tests

Expand Down
2 changes: 1 addition & 1 deletion src/samplers/mcmc/mh_sampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ struct MHProposalState{Q<:ContinuousUnivariateDistribution} <: MCMCProposalState
end
export MHProposalState

bat_default(::Type{TransformedMCMC}, ::Val{:pretransform}, proposal::RandomWalk) = PriorToGaussian()
bat_default(::Type{TransformedMCMC}, ::Val{:pretransform}, proposal::RandomWalk) = PriorToNormal()

bat_default(::Type{TransformedMCMC}, ::Val{:proposal_tuning}, proposal::RandomWalk) = NoMCMCProposalTuning()

Expand Down
2 changes: 1 addition & 1 deletion test/measures/test_bat_pushfwd_measure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ using Optim
prior = HierarchicalDistribution(f_secondary, primary_dist)
likelihood = logfuncdensity(logdensityof(varshape(prior)(MvNormal(Diagonal(fill(1.0, totalndof(varshape(prior))))))))
m = PosteriorMeasure(likelihood, prior)
hmc_samples = bat_sample(m, TransformedMCMC(mcalg = HamiltonianMC(), pretransform = PriorToGaussian(), nsteps = 10^4), context).result
hmc_samples = bat_sample(m, TransformedMCMC(mcalg = HamiltonianMC(), pretransform = PriorToNormal(), nsteps = 10^4), context).result
is_samples = bat_sample(m, PriorImportanceSampler(nsamples = 10^4), context).result
@test isapprox(mean(unshaped.(hmc_samples)), mean(unshaped.(is_samples)), rtol = 0.1)
@test isapprox(cov(unshaped.(hmc_samples)), cov(unshaped.(is_samples)), rtol = 0.2)
Expand Down
2 changes: 1 addition & 1 deletion test/samplers/mcmc/test_hmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ import AdvancedHMC
inner_posterior = PosteriorMeasure(likelihood, prior)
# Test with nested posteriors:
posterior = PosteriorMeasure(likelihood, inner_posterior)
@test BAT.sample_and_verify(posterior, TransformedMCMC(proposal = HamiltonianMC(), proposal_tuning = StanHMCTuning(), pretransform = PriorToGaussian()), prior.dist, context).verified
@test BAT.sample_and_verify(posterior, TransformedMCMC(proposal = HamiltonianMC(), proposal_tuning = StanHMCTuning(), pretransform = PriorToNormal()), prior.dist, context).verified
end

@testset "HMC autodiff" begin
Expand Down
2 changes: 1 addition & 1 deletion test/samplers/mcmc/test_mh.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,6 @@ using StatsBase, Distributions, StatsBase, ValueShapes, ArraysOfArrays, DensityI
inner_posterior = PosteriorMeasure(likelihood, prior)
# Test with nested posteriors:
posterior = PosteriorMeasure(likelihood, inner_posterior)
@test BAT.sample_and_verify(posterior, TransformedMCMC(proposal = RandomWalk(), pretransform = PriorToGaussian()), prior.dist).verified
@test BAT.sample_and_verify(posterior, TransformedMCMC(proposal = RandomWalk(), pretransform = PriorToNormal()), prior.dist).verified
end
end
6 changes: 3 additions & 3 deletions test/transforms/test_distribution_transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ end
posterior_uniform_prior = @inferred(PosteriorMeasure(logfuncdensity(logdensityof(mvn)), uniform_prior))
posterior_gaussian_prior = @inferred(PosteriorMeasure(logfuncdensity(logdensityof(mvn)), mvn))

@test @inferred(bat_transform(PriorToGaussian(), posterior_uniform_prior, context)).result.prior.dist == @inferred(BAT.StandardMvNormal(3))
@test @inferred(bat_transform(PriorToNormal(), posterior_uniform_prior, context)).result.prior.dist == @inferred(BAT.StandardMvNormal(3))
@test @inferred(bat_transform(PriorToUniform(), posterior_gaussian_prior, context)).result.prior.dist == @inferred(BAT.StandardMvUniform(3))
@test @inferred(bat_transform(DoNotTransform(), posterior_uniform_prior, context)).result.prior.dist == uniform_prior
pd = @inferred(product_distribution([Uniform() for i in 1:3]))
Expand All @@ -243,6 +243,6 @@ end
# ToDo: Improve comparison for bounds so `.dist` is not required here:
@inferred(bat_transform(PriorToUniform(), batmeasure(BAT.StandardUvUniform()), context)).result.dist == batmeasure(BAT.StandardUvUniform()).dist
@inferred(bat_transform(PriorToUniform(), batmeasure(BAT.StandardMvUniform(4)), context)).result.dist == batmeasure(BAT.StandardMvUniform(4)).dist
@inferred(bat_transform(PriorToGaussian(), batmeasure(BAT.StandardUvNormal()), context)).result.dist == batmeasure(BAT.StandardUvNormal()).dist
@inferred(bat_transform(PriorToGaussian(), batmeasure(BAT.StandardMvNormal(4)), context)).result.dist == batmeasure(BAT.StandardMvNormal(4)).dist
@inferred(bat_transform(PriorToNormal(), batmeasure(BAT.StandardUvNormal()), context)).result.dist == batmeasure(BAT.StandardUvNormal()).dist
@inferred(bat_transform(PriorToNormal(), batmeasure(BAT.StandardMvNormal(4)), context)).result.dist == batmeasure(BAT.StandardMvNormal(4)).dist
end

0 comments on commit ab202a2

Please sign in to comment.