diff --git a/Project.toml b/Project.toml
index b4d6a57..8f2d5f1 100644
--- a/Project.toml
+++ b/Project.toml
@@ -17,6 +17,8 @@ PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Requires = "ae029012-a4dd-5104-9daa-d747884805df"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
+StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
+Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
 Tullio = "bc48ee85-29a4-5162-ae0b-a64e1601d4bc"

 [compat]
diff --git a/src/ImportanceSampling.jl b/src/ImportanceSampling.jl
index a34e581..a9bbb9e 100644
--- a/src/ImportanceSampling.jl
+++ b/src/ImportanceSampling.jl
@@ -51,13 +51,14 @@ function psis(

     dims = size(log_ratios)
     data_size = dims[1]
-    post_sample_size = dims[2] * dims[3]
+    mcmc_count = dims[2] * dims[3]
+
     # Reshape to matrix (easier to deal with)
-    log_ratios = reshape(log_ratios, data_size, post_sample_size)
+    log_ratios = reshape(log_ratios, data_size, mcmc_count)
     weights = similar(log_ratios)

-    # Shift ratios by maximum to prevent overflow
-    @tturbo @. weights = exp(log_ratios - $maximum(log_ratios; dims=2))
+    # Shift ratios by the maximum to avoid overflow, and by log(mcmc_count) to avoid subnormals
+    @tturbo @. weights = exp(log_ratios - $maximum(log_ratios; dims=2) + log(mcmc_count))

     r_eff = _generate_r_eff(weights, dims, r_eff, source)
     _check_input_validity_psis(reshape(log_ratios, dims), r_eff)
@@ -65,21 +66,20 @@
     tail_length = similar(log_ratios, Int, data_size)
     ξ = similar(log_ratios, data_size)
     @inbounds Threads.@threads for i in eachindex(tail_length)
-        tail_length[i] = @views _def_tail_length(post_sample_size, r_eff[i])
-        ξ[i] = @views ParetoSmooth._do_psis_i!(weights[i,:], tail_length[i])
+        tail_length[i] = @views _def_tail_length(mcmc_count, r_eff[i])
+        ξ[i] = @views ParetoSmooth._psis_smooth!(weights[i,:], tail_length[i])
     end
-
-    @tullio norm_const[i] := weights[i, j]
-    @turbo weights .= weights ./ norm_const
-    ess = psis_ess(weights, r_eff)
+    _normalize!(weights)
+
+    ess = psis_ess(weights, r_eff)

     weights = reshape(weights, dims)

     if log_weights
         @tturbo @. weights = log(weights)
     end

-    return Psis(weights, ξ, ess, r_eff, tail_length, post_sample_size, data_size)
+    return Psis(weights, ξ, ess, r_eff, tail_length, mcmc_count, data_size)
 end


@@ -95,24 +95,25 @@


 """
-    _do_psis_i!(is_ratios::AbstractVector{Real}, tail_length::Integer) -> T
+    _psis_smooth!(is_ratios::AbstractVector{AbstractFloat}, tail_length::Integer) -> T

-Do PSIS on a single vector, smoothing its tail values.
+Do PSIS on a single vector, smoothing its tail values in place before returning ξ.

 # Arguments

-- `is_ratios::AbstractVector{Real}`: A vector of importance sampling ratios,
-scaled to have a maximum of 1.
+  - `is_ratios::AbstractVector{AbstractFloat}`: A vector of importance sampling ratios,
+    scaled to have a maximum equal to the length of the vector.

 # Returns

-- `T<:Real`: ξ, the shape parameter for the GPD; big numbers indicate thick tails.
+  - `T<:AbstractFloat`: ξ, the estimated shape parameter for the GPD. Bigger numbers
+    indicate thicker tails.

 # Extended help

 Additional information can be found in the LOO package from R.

 """
-function _do_psis_i!(
+function _psis_smooth!(
     is_ratios::AbstractVector{T}, tail_length::Integer
 ) where {T<:Real}
@@ -135,8 +136,8 @@ function _do_psis_i!(
     cutoff = is_ratios[tail_start - 1]
     ξ = _psis_smooth_tail!(tail, cutoff)

-    # truncate at max of raw weights (1 after scaling)
-    clamp!(is_ratios, 0, 1)
+    # truncate at max of raw weights (equal to len after scaling)
+    clamp!(is_ratios, 0, len)

     # unsort the ratios to their original position:
     is_ratios .= @views is_ratios[invperm(ordering)]
@@ -175,6 +176,11 @@ function _psis_smooth_tail!(tail::AbstractVector{T}, cutoff::T) where {T<:Real}
 end


+function _normalize!(weights::AbstractArray)
+    @tullio norm_const[i] := weights[i, j]
+    @turbo @. weights /= norm_const
+end
+

 ##########################
 #### HELPER FUNCTIONS ####
@@ -251,7 +257,7 @@ function _check_tail(tail::AbstractVector{T}) where {T<:Real}
         throw(
             ArgumentError(
                 "Unable to fit generalized Pareto distribution: tail length was too " *
-                "short. Likely causese are: \n$LIKELY_ERROR_CAUSES"
+                "short. Likely causes are: \n$LIKELY_ERROR_CAUSES"
             ),
         )
     end
diff --git a/src/InternalHelpers.jl b/src/InternalHelpers.jl
index f25335c..f683169 100644
--- a/src/InternalHelpers.jl
+++ b/src/InternalHelpers.jl
@@ -39,6 +39,14 @@ function _assume_one_chain(matrix)
 end


+"""
+Safely exponentiate a vector for use in scale-invariant operations (exponentiates x - maximum(x)).
+"""
+function _safe_exp(x::AbstractVector)
+    return exp.(x .- maximum(x))
+end
+
+
 """
 Convert a matrix+chain_index representation to a 3d array representation.
 """
diff --git a/src/LeaveOneOut.jl b/src/LeaveOneOut.jl
index 8fb2b62..3e5a90b 100644
--- a/src/LeaveOneOut.jl
+++ b/src/LeaveOneOut.jl
@@ -35,8 +35,8 @@ end
         [, chain_index::Vector{Integer}, kwargs...]
     ) -> PsisLoo

-Use Pareto-Smoothed Importance Sampling to calculate the leave-one-out cross validation
-score.
+Use Pareto-Smoothed Importance Sampling to calculate the leave-one-out cross-validation
+estimate.

 # Arguments

@@ -52,7 +52,7 @@
 See also: [`psis`](@ref), [`loo`](@ref), [`PsisLoo`](@ref).
 """
 function psis_loo(
-    log_likelihood::T, args...;
+    log_likelihood::AbstractArray, args...;
     kwargs...
 ) where {F<:Real, T<:AbstractArray{F, 3}}

@@ -114,12 +114,6 @@ function psis_loo(
     new_log_ratios = _convert_to_array(log_likelihood, chain_index)
     return psis_loo(new_log_ratios, args...; kwargs...)
 end
-
-# function psis_loo(log_likelihood, args...;
-#     subsamples::Integer, rng::AbstractRNG=MersenneTwister(1776), kwargs...
-# )
-#     return log_likelihood = rand()
-# end


 function _generate_loo_table(
diff --git a/src/LooStructs.jl b/src/LooStructs.jl
index dce2430..c53d1ab 100644
--- a/src/LooStructs.jl
+++ b/src/LooStructs.jl
@@ -24,9 +24,9 @@ const CV_DESC = """
   - `:ess` is the effective sample size, which measures the simulation error caused by
     using Monte Carlo estimates. It is *not* related to the actual sample size, and it
     does not measure how accurate your predictions are.
-  - `:pareto_k` is the estimated value for the parameter `ξ` of the generalized Pareto
-distribution. Values above .7 indicate that PSIS has failed to approximate the true
-distribution.
+  - `:pareto_k` is the estimated value for the parameter `ξ` of the generalized Pareto
+    distribution. Values above .7 indicate that PSIS has failed to approximate the true
+    distribution.
   - `psis_object::Psis`: A `Psis` object containing the results of Pareto-smoothed
     importance sampling.
diff --git a/src/MomentMatch.jl b/src/MomentMatch.jl
new file mode 100644
index 0000000..6b983b7
--- /dev/null
+++ b/src/MomentMatch.jl
@@ -0,0 +1,154 @@
+using AxisKeys
+using LoopVectorization
+using StatsBase
+using Tables
+using Tullio
+
+"""
+    adapt_cv(
+        log_target::Function,
+        log_p::AbstractArray,
+        psis_object::Psis,
+        samples::AbstractArray,
+        data;
+        hard_thresh::Real = 2/3,
+        soft_thresh::Real = 1/2,
+        soft_cap::Integer = 10
+    )
+
+Perform importance-weighted moment matching, adapting a sample from a proposal distribution
+to more closely match the target distribution.
+
+# Arguments
+  - `log_target`: The log-pdf of the target distribution. This should be a function taking
+    θ, a vector of parameters, as its first argument and `x`, the data set, as its second.
+
+"""
+function adapt_cv(
+    log_target::Function,
+    log_p::AbstractArray,
+    psis_object::Psis,
+    samples::AbstractArray,
+    data;
+    hard_thresh::Real = 2/3,
+    soft_thresh::Real = 1/2,
+    soft_cap::Integer = 10
+)
+
+    dims = size(samples)
+    n_steps, n_params, n_chains = dims
+    mcmc_count = n_steps * n_chains
+    log_count = log(mcmc_count)
+
+    weights = psis_object.weights
+    original_weights = deepcopy(psis_object.weights)
+    ξ = psis_object.pareto_k
+    resample_count = size(weights, 1)
+
+    @inbounds Threads.@threads for resample in 1:resample_count
+        @views _match!(
+            log_target,
+            log_p[resample, :, :],
+            weights[resample, :, :],
+            samples,
+            ξ[resample],
+            log_count,
+            hard_thresh,
+            soft_thresh,
+            soft_cap
+        )
+    end
+
+end
+
+
+function _match!(
+    log_target::Function,
+    log_proposal::AbstractArray,
+    weights::AbstractArray,
+    θ_hats::AbstractArray,
+    ξ::Real,
+    log_count::Real,
+    hard_thresh::Real,
+    soft_thresh::Real,
+    soft_cap::Integer,
+)
+
+    # initialize variables
+    num_iter = 0  # iterations of IWMM
+    transform = 1
+    θ_proposed = similar(θ_hats)
+    ξ_proposed = soft_thresh
+    μ = mean(θ_hats; dims=:parameter)
+    μ_proposed = similar(μ)
+    σ = std(θ_hats; dims=:parameter)
+    σ_proposed = similar(σ)
+    weights_proposed = similar(weights)
+
+
+    while _keep_going(ξ, hard_thresh, soft_thresh, soft_cap, num_iter)
+
+        if transform == 1
+            μ_proposed = mean(θ_hats, weights; dims=2)
+            σ_proposed .= σ
+        elseif transform == 2
+            μ_proposed = mean(θ_hats, weights; dims=2)
+            σ_proposed = std(θ_hats, weights; mean=μ_proposed, dims=2)
+        elseif transform == 3
+            μ_proposed, Σ_proposed = mean_and_cov(θ_hats, weights; dims=2)
+            σ_proposed = sqrt(Σ_proposed)
+        elseif transform == 4
+            break
+        end
+
+        θ_proposed = (θ_hats + μ_proposed - μ) * (σ_proposed * inv(σ))
+        weights_proposed .= _safe_exp(log_target(θ_proposed) .- log_proposal, log_count)
+        ξ_proposed = _psis_smooth!(weights_proposed)
+
+        if ξ_proposed < ξ
+            num_iter += 1
+            _normalize!(weights_proposed)
+
+            ξ = ξ_proposed
+            @. begin
+                μ = μ_proposed
+                σ = σ_proposed
+                θ_hats = θ_proposed
+                weights = weights_proposed
+            end
+        else
+            transform += 1
+        end
+
+    end
+
+    return weights
+
+end
+
+
+function _keep_going(
+    ξ::Real,
+    hard_thresh::Real,
+    soft_thresh::Real,
+    soft_cap::Integer,
+    num_iter::Integer
+)
+    if ξ > hard_thresh
+        return true
+    elseif (ξ > soft_thresh) && (num_iter ≤ soft_cap)
+        return true
+    else
+        return false
+    end
+end
+
+
+
+"""
+Safely exponentiate x, preventing underflow/overflow by rescaling all elements
+by a common factor.
+"""
+function _safe_exp(x, log_count)
+    return @. exp(x - $maximum(x; dims=2) + log_count)
+end
\ No newline at end of file
diff --git a/src/PublicHelpers.jl b/src/PublicHelpers.jl
index a88d411..34afd4e 100644
--- a/src/PublicHelpers.jl
+++ b/src/PublicHelpers.jl
@@ -1,8 +1,8 @@
 export pointwise_log_likelihoods

-const ARRAY_DIMS_WARNING = "The supplied array of mcmc samples indicates you have more
-parameters than mcmc samples.This is possible, but highly unusual. Please check that your
-array of mcmc samples has the following dimensions: [n_samples,n_parms,n_chains]."
+const ARRAY_DIMS_WARNING = "The supplied array of MCMC samples indicates you have more " *
+"parameters than samples. This is possible, but unusual. Please check that your array of " *
+"MCMC samples has dimensions `[iter, var, chain]`."

 """
     pointwise_log_likelihoods(
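A small illustration of the `[iter, var, chain]` layout the reworded warning refers to, with made-up sizes; `permutedims` is plain Base Julia:

```julia
# An MCMC run with 2000 iterations, 12 variables, and 4 chains, stored as [var, iter, chain]:
samples_by_var = randn(12, 2000, 4)

# Reorder to the [iter, var, chain] layout the warning asks for.
samples = permutedims(samples_by_var, (2, 1, 3))
size(samples)   # (2000, 12, 4) — more samples than parameters, as expected
```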