Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Importance Weighted Moment Matching #23

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
Tullio = "bc48ee85-29a4-5162-ae0b-a64e1601d4bc"

[compat]
Expand Down
46 changes: 26 additions & 20 deletions src/ImportanceSampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,35 +51,35 @@ function psis(
dims = size(log_ratios)

data_size = dims[1]
post_sample_size = dims[2] * dims[3]
mcmc_count = dims[2] * dims[3]


# Reshape to matrix (easier to deal with)
log_ratios = reshape(log_ratios, data_size, post_sample_size)
log_ratios = reshape(log_ratios, data_size, mcmc_count)
weights = similar(log_ratios)
# Shift ratios by maximum to prevent overflow
@tturbo @. weights = exp(log_ratios - $maximum(log_ratios; dims=2))
# Shift ratios by maximum to avoid overflow, and log(mcmc_count) to avoid subnormals
@tturbo @. weights = exp(log_ratios - $maximum(log_ratios; dims=2) + log(mcmc_count))

r_eff = _generate_r_eff(weights, dims, r_eff, source)
_check_input_validity_psis(reshape(log_ratios, dims), r_eff)

tail_length = similar(log_ratios, Int, data_size)
ξ = similar(log_ratios, data_size)
@inbounds Threads.@threads for i in eachindex(tail_length)
tail_length[i] = @views _def_tail_length(post_sample_size, r_eff[i])
ξ[i] = @views ParetoSmooth._do_psis_i!(weights[i,:], tail_length[i])
tail_length[i] = @views _def_tail_length(mcmc_count, r_eff[i])
ξ[i] = @views ParetoSmooth._psis_smooth!(weights[i,:], tail_length[i])
end

@tullio norm_const[i] := weights[i, j]
@turbo weights .= weights ./ norm_const
ess = psis_ess(weights, r_eff)

_normalize!(weights)

ess = psis_ess(weights, r_eff)
weights = reshape(weights, dims)

if log_weights
@tturbo @. weights = log(weights)
end

return Psis(weights, ξ, ess, r_eff, tail_length, post_sample_size, data_size)
return Psis(weights, ξ, ess, r_eff, tail_length, mcmc_count, data_size)
end


Expand All @@ -95,24 +95,25 @@ end


"""
_do_psis_i!(is_ratios::AbstractVector{Real}, tail_length::Integer) -> T
_psis_smooth!(is_ratios::AbstractVector{AbstractFloat}, tail_length::Integer) -> T

Do PSIS on a single vector, smoothing its tail values.
Do PSIS on a single vector, smoothing its tail values in place before returning ξ.

# Arguments

- `is_ratios::AbstractVector{Real}`: A vector of importance sampling ratios,
scaled to have a maximum of 1.
- `is_ratios::AbstractVector{AbstractFloat}`: A vector of importance sampling ratios,
scaled to have a maximum of 1.

# Returns

- `T<:Real`: ξ, the shape parameter for the GPD; big numbers indicate thick tails.
- `T<:AbstractFloat`: ξ, the estimated shape parameter for the GPD. Bigger numbers
indicate thicker tails.

# Extended help

Additional information can be found in the LOO package from R.
"""
function _do_psis_i!(
function _psis_smooth!(
is_ratios::AbstractVector{T}, tail_length::Integer
) where {T<:Real}

Expand All @@ -135,8 +136,8 @@ function _do_psis_i!(
cutoff = is_ratios[tail_start - 1]
ξ = _psis_smooth_tail!(tail, cutoff)

# truncate at max of raw weights (1 after scaling)
clamp!(is_ratios, 0, 1)
# truncate at max of raw weights (equal to len after scaling)
clamp!(is_ratios, 0, len)
# unsort the ratios to their original position:
is_ratios .= @views is_ratios[invperm(ordering)]

Expand Down Expand Up @@ -175,6 +176,11 @@ function _psis_smooth_tail!(tail::AbstractVector{T}, cutoff::T) where {T<:Real}
end


"""
    _normalize!(weights::AbstractArray)

Rescale `weights` in place by dividing every entry `weights[i, j]` by the total of all
entries that share its first index `i`.
"""
function _normalize!(weights::AbstractArray)
    # One total per value of the first index, accumulated over the second index.
    @tullio row_totals[i] := weights[i, j]
    @turbo @. weights = weights / row_totals
end


##########################
#### HELPER FUNCTIONS ####
Expand Down Expand Up @@ -251,7 +257,7 @@ function _check_tail(tail::AbstractVector{T}) where {T<:Real}
throw(
ArgumentError(
"Unable to fit generalized Pareto distribution: tail length was too " *
"short. Likely causese are: \n$LIKELY_ERROR_CAUSES"
"short. Likely causes are: \n$LIKELY_ERROR_CAUSES"
),
)
end
Expand Down
8 changes: 8 additions & 0 deletions src/InternalHelpers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,14 @@ function _assume_one_chain(matrix)
end


"""
    _safe_exp(x::AbstractVector)

Exponentiate `x` after subtracting its largest element, so the largest entry of the result
is `exp(0) == 1` and no entry can overflow. Only appropriate when the caller's computation
is invariant to a common rescaling of the result.
"""
function _safe_exp(x::AbstractVector)
    largest = maximum(x)
    shifted = x .- largest
    return exp.(shifted)
end


"""
Convert a matrix+chain_index representation to a 3d array representation.
"""
Expand Down
12 changes: 3 additions & 9 deletions src/LeaveOneOut.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ end
[, chain_index::Vector{Integer}, kwargs...]
) -> PsisLoo

Use Pareto-Smoothed Importance Sampling to calculate the leave-one-out cross validation
score.
Use Pareto-Smoothed Importance Sampling to calculate the leave-one-out cross-validation
estimate.

# Arguments

Expand All @@ -52,7 +52,7 @@ score.
See also: [`psis`](@ref), [`loo`](@ref), [`PsisLoo`](@ref).
"""
function psis_loo(
log_likelihood::T, args...;
log_likelihood::AbstractArray, args...;
kwargs...
) where {F<:Real, T<:AbstractArray{F, 3}}

Expand Down Expand Up @@ -114,12 +114,6 @@ function psis_loo(
new_log_ratios = _convert_to_array(log_likelihood, chain_index)
return psis_loo(new_log_ratios, args...; kwargs...)
end

# function psis_loo(log_likelihood, args...;
# subsamples::Integer, rng::AbstractRNG=MersenneTwister(1776), kwargs...
# )
# return log_likelihood = rand()
# end


function _generate_loo_table(
Expand Down
6 changes: 3 additions & 3 deletions src/LooStructs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ const CV_DESC = """
- `:ess` is the effective sample size, which measures the simulation error caused by
using Monte Carlo estimates. It is *not* related to the actual sample size, and it
does not measure how accurate your predictions are.
- `:pareto_k` is the estimated value for the parameter `ξ` of the generalized Pareto
distribution. Values above .7 indicate that PSIS has failed to approximate the true
distribution.
- `:pareto_k` is the estimated value for the parameter `ξ` of the generalized Pareto
distribution. Values above .7 indicate that PSIS has failed to approximate the true
distribution.
- `psis_object::Psis`: A `Psis` object containing the results of Pareto-smoothed
importance sampling.

Expand Down
153 changes: 153 additions & 0 deletions src/MomentMatch.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
using AxisKeys
using LoopVectorization
using StatsBase
using Tables
using Tullio

"""
    adapt_cv(
        log_target::Function,
        log_p::AbstractArray,
        cv_object::AbstractCV,
        samples::AbstractArray,
        data;
        hard_thresh::Real = 2/3,
        soft_thresh::Real = 1/2,
        soft_cap::Integer = 10
    )

Perform importance-weighted moment matching, adapting a sample from a proposal distribution
to more closely match the target distribution.

# Arguments
- `log_target`: The log-pdf of the target distribution. This should be a function having
  θ, a vector of parameters, as its input; and `x`, the data set, as its second input.
- `log_p`: Array of log-proposal values; the slice `log_p[i, :, :]` is paired with the
  `i`-th slice of the PSIS weights.
- `cv_object`: Cross-validation result whose `psis_object` field supplies the importance
  weights (`weights`) and Pareto shape estimates (`pareto_k`) to be adapted.
- `samples`: Array of draws with dimensions `[step, parameter, chain]`.
- `data`: The data set. NOTE: it is accepted but not yet forwarded to `_match!`.

# Keywords
- `hard_thresh`, `soft_thresh`, `soft_cap`: Stopping controls forwarded unchanged to
  `_match!` (see `_keep_going`).

# Returns
- The weights array held by `cv_object.psis_object`, after `_match!` has been applied to
  each of its slices along the first dimension.
"""
function adapt_cv(
    log_target::Function,
    log_p::AbstractArray,
    cv_object::AbstractCV,
    samples::AbstractArray,
    data;
    hard_thresh::Real = 2/3,
    soft_thresh::Real = 1/2,
    soft_cap::Integer = 10
)

    n_steps, _, n_chains = size(samples)
    mcmc_count = n_steps * n_chains
    # `psis` rescales raw ratios by `+ log(mcmc_count)`; pass the same quantity on the
    # log scale rather than the raw count.
    log_count = log(mcmc_count)

    # NOTE(review): assumes every `AbstractCV` exposes a `psis_object` field (as `PsisLoo`
    # is documented to) -- confirm for other subtypes.
    psis_object = cv_object.psis_object
    weights = psis_object.weights
    ξ = psis_object.pareto_k

    # `Threads.@threads` must be applied directly to the `for` expression. Bounds checks
    # are kept because the size of `log_p` is not validated against `weights`.
    Threads.@threads for resample in axes(weights, 1)
        @views _match!(
            log_target,
            log_p[resample, :, :],
            weights[resample, :, :],
            # `_match!` overwrites its draws in place; give each task a private copy so
            # concurrent iterations do not race on (or clobber) the caller's `samples`.
            copy(samples),
            ξ[resample],
            log_count,
            hard_thresh,
            soft_thresh,
            soft_cap
        )
    end

    return weights

end


# Iteratively propose moment-matching transforms of the draws `θ_hats` and accept a
# proposal whenever it lowers the Pareto shape estimate ξ; returns `weights`.
#
# Arguments, as used in the body:
# - `log_target`: callable evaluated on the proposed draws.
# - `log_proposal`: subtracted from `log_target(θ_proposed)` to form new log ratios.
# - `weights`: current importance weights; overwritten in place on acceptance.
# - `θ_hats`: current draws; overwritten in place on acceptance.
# - `ξ`: current Pareto shape estimate.
# - `log_count`: offset handed to `_safe_exp`.
# - `hard_thresh`, `soft_thresh`, `soft_cap`: stopping controls passed to `_keep_going`.
function _match!(
log_target::Function,
log_proposal::AbstractArray,
weights::AbstractArray,
θ_hats::AbstractArray,
ξ::Real,
log_count::Real,
hard_thresh::Real,
soft_thresh::Real,
soft_cap::Integer,
)

# initialize variables
num_iter = 0 # iterations of IWMM
# `transform` selects the proposal type below: 1 = shift the mean only, 2 = mean and
# standard deviation, 3 = mean and covariance, 4 = nothing left to try.
transform = 1
# This buffer is rebound (not filled) when `θ_proposed` is assigned inside the loop.
θ_proposed = similar(θ_hats)
# Placeholder only: `ξ_proposed` is reassigned in the loop before it is first read.
ξ_proposed = soft_thresh
# NOTE(review): the unweighted moments reduce over `dims=:parameter` while the weighted
# ones in the loop use `dims=2` -- confirm both refer to the intended axis.
μ = mean(θ_hats; dims=:parameter)
μ_proposed = similar(μ)
σ = std(θ_hats; dims=:parameter)
σ_proposed = similar(σ)
weights_proposed = similar(weights)


while _keep_going(ξ, hard_thresh, soft_thresh, soft_cap, num_iter)

# NOTE(review): presumably the weighted `mean`/`std`/`mean_and_cov` methods come from
# StatsBase and expect an `AbstractWeights` argument -- verify that a plain array works.
if transform == 1
μ_proposed = mean(θ_hats, weights; dims=2)
σ_proposed .= σ
elseif transform == 2
μ_proposed = mean(θ_hats, weights; dims=2)
σ_proposed = std(θ_hats, weights; mean=μ_proposed, dims=2)
elseif transform == 3
μ_proposed, Σ_proposed = mean_and_cov(θ_hats, weights; dims=2)
σ_proposed = sqrt(Σ_proposed)
elseif transform == 4
break
end

# NOTE(review): `+`, `-`, `*` and `inv` here are whole-array (not broadcast) operations;
# confirm the operand shapes are compatible with that.
θ_proposed = (θ_hats + μ_proposed - μ) * (σ_proposed * inv(σ))
# NOTE(review): `@.` dots every call on this line, including `_safe_exp` and
# `log_target`, so both are applied element by element -- confirm that is intended.
@. weights_proposed = _safe_exp(log_target(θ_proposed) - log_proposal, log_count)
# NOTE(review): this smooths `weights`, not the freshly computed `weights_proposed`, and
# passes a single argument; `_psis_smooth!` in ImportanceSampling.jl is declared as
# `(is_ratios, tail_length)` -- confirm.
ξ_proposed = _psis_smooth!(weights)

# Accept the proposal only if it reduced the shape estimate; otherwise move on to the
# next transform type.
if ξ_proposed < ξ
num_iter += 1
_normalize!(weights_proposed)

ξ = ξ_proposed
# `@.` turns each assignment in this block into an in-place `.=`, so the arrays passed
# in by the caller (`θ_hats`, `weights`) are overwritten rather than rebound.
@. begin
μ = μ_proposed
σ = σ_proposed
θ_hats = θ_proposed
weights = weights_proposed
end
else
transform += 1
end

end

return weights

end


"""
    _keep_going(ξ, hard_thresh, soft_thresh, soft_cap, num_iter) -> Bool

Decide whether moment matching should attempt another iteration.

# Arguments
- `ξ::Real`: Current Pareto shape estimate.
- `hard_thresh::Real`: While `ξ` exceeds this, iteration always continues.
- `soft_thresh::Real`: While `ξ` exceeds this (but not `hard_thresh`), iteration continues
  only as long as `num_iter ≤ soft_cap`.
- `soft_cap::Integer`: Iteration budget that applies between the two thresholds. Any
  `Integer` is accepted, matching the signatures of `adapt_cv` and `_match!`.
- `num_iter::Integer`: Number of iterations accepted so far.

# Returns
- `Bool`: `true` to keep iterating, `false` to stop. A `NaN` shape estimate compares false
  against both thresholds and therefore stops.
"""
function _keep_going(
    ξ::Real,
    hard_thresh::Real,
    soft_thresh::Real,
    soft_cap::Integer,
    num_iter::Integer,
)
    # Above the hard threshold the iteration budget is ignored.
    ξ > hard_thresh && return true
    return (ξ > soft_thresh) && (num_iter ≤ soft_cap)
end



"""
    _safe_exp(x, log_count)

Exponentiate `x` after shifting it by its maximum along the second dimension and adding
`log_count`. The shift keeps the largest entry of each slice at `exp(log_count)`, which
guards against overflow, while the `log_count` offset lifts small entries away from
underflow.
"""
function _safe_exp(x, log_count)
    slice_max = maximum(x; dims=2)
    return exp.(x .- slice_max .+ log_count)
end
6 changes: 3 additions & 3 deletions src/PublicHelpers.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
export pointwise_log_likelihoods

const ARRAY_DIMS_WARNING = "The supplied array of mcmc samples indicates you have more
parameters than mcmc samples.This is possible, but highly unusual. Please check that your
array of mcmc samples has the following dimensions: [n_samples,n_parms,n_chains]."
const ARRAY_DIMS_WARNING = "The supplied array of MCMC samples indicates you have more " *
"parameters than samples. This is possible, but unusual. Please check your array of MCMC " *
"samples has dimensions `[iter, var, chain]`."

"""
pointwise_log_likelihoods(
Expand Down