Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix von Mises-Fisher sampler #1930

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 97 additions & 42 deletions src/samplers/vonmisesfisher.jl
Original file line number Diff line number Diff line change
@@ -1,34 +1,39 @@
# Sampler for von Mises-Fisher
# Ref https://doi.org/10.18637/jss.v058.i10
# Ref https://hal.science/hal-04004568v3
# Precomputed state used to draw von Mises-Fisher samples.
# `b`, `x0` and `c` are derived from `p` and `κ` by the outer constructor
# `VonMisesFisherSampler(μ, κ)`; `v` and `rotate` are obtained there from
# `_vmf_householder_vec(μ)`.
struct VonMisesFisherSampler <: Sampleable{Multivariate,Continuous}
p::Int # the dimension
# the `κ` argument passed to the outer constructor
κ::Float64
# (p - 1) / (2κ + sqrt(4κ² + (p - 1)²))
b::Float64
# (1 - b) / (1 + b)
x0::Float64
# κ * x0 + (p - 1) * log1p(-x0²); threshold used in the rejection step
c::Float64
# vector of the reflection `x -= 2 (v' x) v` applied by `_vmf_rotate!`
v::Vector{Float64}
rotate::Bool # whether to rotate the samples
end

# Construct a sampler from the direction vector `μ` and the parameter `κ`.
#
# The dimension `p` is taken from `length(μ)`. The constants `b`, `x0` and `c`
# are precomputed once here so that they do not have to be recomputed for
# every draw. `_vmf_householder_vec(μ)` supplies the reflection vector `v`
# together with a flag telling whether the reflection has to be applied.
function VonMisesFisherSampler(μ::Vector{Float64}, κ::Float64)
    # Step 1: Calculate b, x₀, and c
    p = length(μ)
    b = (p - 1) / (2.0κ + sqrt(4 * abs2(κ) + abs2(p - 1)))
    x0 = (1.0 - b) / (1.0 + b)
    c = κ * x0 + (p - 1) * log1p(-abs2(x0))

    # Compute Householder transformation, and whether it has to be applied
    v, rotate = _vmf_householder_vec(μ)

    # The struct has seven fields, so all of them (including `rotate`) are passed
    return VonMisesFisherSampler(p, κ, b, x0, c, v, rotate)
end

# The length of a sampler is the length of its stored vector `v`.
function Base.length(s::VonMisesFisherSampler)
    return length(s.v)
end

# Apply the reflection `x ← x - 2 (v' x) v` in place and return `x`.
@inline function _vmf_rot!(v::AbstractVector, x::AbstractVector)
    coeff = 2.0 * (v' * x)
    x .= x .- (coeff .* v)
    return x
end
function _rand!(rng::AbstractRNG, spl::VonMisesFisherSampler, x::AbstractVector{<:Real})
# TODO: Generalize to more general indices
Base.require_one_based_indexing(x)

# Sample angle `w`
w = _vmf_angle(rng, spl)

function _rand!(rng::AbstractRNG, spl::VonMisesFisherSampler, x::AbstractVector)
w = _vmf_genw(rng, spl)
# Generate sample assuming `μ = (1, 0, 0, ..., 0)`
p = spl.p
x[1] = w
s = 0.0
Expand All @@ -43,47 +48,81 @@ function _rand!(rng::AbstractRNG, spl::VonMisesFisherSampler, x::AbstractVector)
x[i] *= r
end

return _vmf_rot!(spl.v, x)
# Rotate for general `μ` (if necessary)
return _vmf_rotate!(x, spl)
end

### Core computation

# Compute the constant b = (p - 1) / (2κ + sqrt(4κ² + (p - 1)²)).
function _vmf_bval(p::Int, κ::Real)
    pm1 = p - 1
    denom = 2.0κ + sqrt(4 * abs2(κ) + abs2(pm1))
    return pm1 / denom
end

# Draw `w` directly by inverse transform for the case p == 3.
#
# Only `rng` and `κ` are used; `p`, `b`, `x0` and `c` are accepted so that the
# signature matches `_vmf_genwp`.
#
# `rand(rng)` may return exactly 0. The formula then yields a value that can
# fall marginally below -1 through rounding, or -Inf once `exp(-2κ)`
# underflows to 0. The result is therefore clamped from below at -1.
function _vmf_genw3(rng::AbstractRNG, p, b, x0, c, κ)
    ξ = rand(rng)
    w = 1.0 + (log(ξ + (1.0 - ξ)*exp(-2κ))/κ)
    w = max(w, -1.0)
    return w::Float64
end

# Draw `w` by rejection sampling.
#
# A proposal `z ~ Beta((p - 1)/2, (p - 1)/2)` is mapped to
# `w = (1 - (1 + b) z) / (1 - (1 - b) z)`. It is rejected while
# `κ w + (p - 1) log(1 - x0 w) - c < log(u)` with `u = rand(rng)`.
# The random draws happen in the order `z`, `u`, `z`, `u`, ...
function _vmf_genwp(rng::AbstractRNG, p, b, x0, c, κ)
    r = (p - 1) / 2.0
    betad = Beta(r, r)
    while true
        z = rand(rng, betad)
        w = (1.0 - (1.0 + b) * z) / (1.0 - (1.0 - b) * z)
        # `log1p(-x0 * w)` is more accurate than `log(1 - x0 * w)` when
        # `x0 * w` is small; `c` is computed with `log1p` as well
        rejected = κ * w + (p - 1) * log1p(-x0 * w) - c < log(rand(rng))
        rejected || return w::Float64
    end
end
# Step 2: Sample angle W
# Sample the value `w` that `_rand!` stores in the first coordinate of the
# sample, using the constants precomputed in the sampler `spl`.
#
# * p == 3: sampled directly via `_vmf_angle3`.
# * p == 2: rejection sampling with an `Arcsine(0, 1)` proposal.
# * otherwise: rejection sampling with a proposal built from two Gamma draws.
function _vmf_angle(rng::AbstractRNG, spl::VonMisesFisherSampler)
    p = spl.p
    κ = spl.κ

    # Special case: 2-sphere
    p == 3 && return _vmf_angle3(rng, κ)

    # General case: Rejection sampling
    # Ref https://doi.org/10.18637/jss.v058.i10
    b = spl.b
    c = spl.c
    x0 = spl.x0
    pm1 = p - 1

    if p == 2
        # In this case the distribution reduces to the von Mises distribution on the circle
        # We exploit the fact that `Beta(1/2, 1/2) = Arcsine(0, 1)`
        dist = Arcsine(zero(b), one(b))
        while true
            z = rand(rng, dist)
            w = (1 - (1 + b) * z) / (1 - (1 - b) * z)
            if κ * w + pm1 * log1p(- x0 * w) >= c - randexp(rng)
                return w::Float64
            end
        end
    else
        # We sample from a `Beta((p - 1)/2, (p - 1)/2)` distribution, possibly repeatedly
        # Therefore we construct a sampler
        # To avoid the type instability of `sampler(Beta(...))` and `sampler(Gamma(...))`
        # we directly construct the Gamma sampler for Gamma((p - 1)/2, 1)
        # Since (p - 1)/2 > 1, we construct a `GammaMTSampler`
        r = pm1 / 2
        gammasampler = GammaMTSampler(Gamma{typeof(r)}(r, one(r)))
        while true
            # w is supposed to be generated as
            #   z ~ Beta((p - 1)/ 2, (p - 1)/2)
            #   w = (1 - (1 + b) * z) / (1 - (1 - b) * z)
            # We sample z as
            #   z1 ~ Gamma((p - 1) / 2, 1)
            #   z2 ~ Gamma((p - 1) / 2, 1)
            #   z = z1 / (z1 + z2)
            # and rewrite the expression for w
            # Cf. case p == 2 above
            z1 = rand(rng, gammasampler)
            z2 = rand(rng, gammasampler)
            b_z1 = b * z1
            w = (z2 - b_z1) / (z2 + b_z1)
            if κ * w + pm1 * log1p(- x0 * w) >= c - randexp(rng)
                return w::Float64
            end
        end
    end
end

# generate the W value -- the key step in simulating vMF
#
# following movMF's document for the p != 3 case
# and Wenzel Jakob's document for the p == 3 case
function _vmf_genw(rng::AbstractRNG, p, b, x0, c, κ)
    if p == 3
        return _vmf_genw3(rng, p, b, x0, c, κ)
    else
        return _vmf_genwp(rng, p, b, x0, c, κ)
    end
end

# Special case: 2-sphere
# Draw `w` for the 2-sphere (p == 3) directly by inverse transform.
# Ref https://www.mitsuba-renderer.org/~wenzel/files/vmf.pdf
#
# `rand(rng)` may return exactly 0. The formula then yields a value that can
# fall marginally below -1 through rounding, or -Inf once `exp(-2κ)`
# underflows to 0. The result is therefore clamped from below at -1.
@inline function _vmf_angle3(rng::AbstractRNG, κ::Real)
    ξ = rand(rng)
    w = 1.0 + (log(ξ + (1.0 - ξ)*exp(-2κ))/κ)
    w = max(w, -1.0)
    return w::Float64
end

# Convenience method: unpack the sampler's fields and forward them to the
# six-argument `_vmf_genw`.
function _vmf_genw(rng::AbstractRNG, s::VonMisesFisherSampler)
    return _vmf_genw(rng, s.p, s.b, s.x0, s.c, s.κ)
end

# Create Householder transformation to rotate samples for `μ = (1, 0, ..., 0)`
# to samples for general `μ`
function _vmf_householder_vec(μ::Vector{Float64})
# assuming μ is a unit-vector (which it should be)
# can compute v in a single pass over μ
Expand All @@ -92,11 +131,27 @@ function _vmf_householder_vec(μ::Vector{Float64})
v = similar(μ)
v[1] = μ[1] - 1.0
s = sqrt(-2*v[1])
if iszero(s)
# In this case, μ is (approx.) (1, 0, ..., 0)
# Hence no rotation has to be performed and `v` is not used
return v, false
end

v[1] /= s

@inbounds for i in 2:p
v[i] = μ[i] / s
end

return v
return v, true
end

# Rotate samples for general `μ` (if needed)
# Apply the sampler's reflection `x ← x - 2 (v' x) v` in place when
# `spl.rotate` is set; otherwise leave `x` untouched. Returns `x`.
@inline function _vmf_rotate!(x::AbstractVector{<:Real}, spl::VonMisesFisherSampler)
    spl.rotate || return x
    v = spl.v
    coeff = 2.0 * (v' * x)
    x .= x .- (coeff .* v)
    return x
end
1 change: 0 additions & 1 deletion src/univariate/continuous/gamma.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,6 @@ function rand(rng::AbstractRNG, d::Gamma)
# TODO: shape(d) = 0.5 : use scaled chisq
return rand(rng, GammaIPSampler(d))
elseif shape(d) == 1.0
θ =
return rand(rng, Exponential{partype(d)}(scale(d)))
else
return rand(rng, GammaMTSampler(d))
Expand Down
42 changes: 29 additions & 13 deletions test/multivariate/vonmisesfisher.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ function gen_vmf_tdata(n::Int, p::Int,
end

function test_vmf_rot(p::Int, rng::Union{AbstractRNG, Missing} = missing)
# Random μ
if ismissing(rng)
μ = randn(p)
x = randn(p)
Expand All @@ -34,16 +35,24 @@ function test_vmf_rot(p::Int, rng::Union{AbstractRNG, Missing} = missing)
μ = μ ./ κ

s = Distributions.VonMisesFisherSampler(μ, κ)
@test s.rotate
v = μ - vcat(1, zeros(p-1))
H = I - 2*v*v'/(v'*v)

@test Distributions._vmf_rot!(s.v, copy(x)) ≈ (H*x)

end
@test Distributions._vmf_rotate!(copy(x), s) ≈ (H*x)

# Special case: μ = (1, 0, ..., 0)
# In this case no rotation is performed
μ = zeros(p)
μ[1] = 1
s = Distributions.VonMisesFisherSampler(μ, κ)
@test !s.rotate
@test Distributions._vmf_rotate!(copy(x), s) == x

return nothing
end

function test_genw3(κ::Real, ns::Int, rng::Union{AbstractRNG, Missing} = missing)
function test_angle3(κ::Real, ns::Int, rng::Union{AbstractRNG, Missing} = missing)
p = 3

if ismissing(rng)
Expand All @@ -53,21 +62,20 @@ function test_genw3(κ::Real, ns::Int, rng::Union{AbstractRNG, Missing} = missin
end
μ = μ ./ norm(μ)

s = Distributions.VonMisesFisherSampler(μ, float(κ))
spl = Distributions.VonMisesFisherSampler(μ, float(κ))
angle3_res = [Distributions._vmf_angle3(rng, spl.κ) for _ in 1:ns]
angle_res = [Distributions._vmf_angle(rng, spl) for _ in 1:ns]

genw3_res = [Distributions._vmf_genw3(rng, s.p, s.b, s.x0, s.c, s.κ) for _ in 1:ns]
genwp_res = [Distributions._vmf_genwp(rng, s.p, s.b, s.x0, s.c, s.κ) for _ in 1:ns]

@test isapprox(mean(genw3_res), mean(genwp_res), atol=0.01)
@test isapprox(std(genw3_res), std(genwp_res), atol=0.01/κ)
@test mean(angle3_res) ≈ mean(angle_res) rtol=5e-2
@test std(angle3_res) ≈ std(angle_res) rtol=1e-2

# test mean and stdev against analytical formulas
coth_κ = coth(κ)
mean_w = coth_κ - 1/κ
var_w = 1 - coth_κ^2 + 1/κ^2

@test isapprox(mean(genw3_res), mean_w, atol=0.01)
@test isapprox(std(genw3_res), sqrt(var_w), atol=0.01/κ)
@test mean(angle3_res) ≈ mean_w rtol=5e-2
@test std(angle3_res) ≈ sqrt(var_w) rtol=1e-2
end


Expand Down Expand Up @@ -178,7 +186,15 @@ ns = 10^6

if !ismissing(rng)
@testset "Testing genw with $key at (3, $κ)" for κ in [0.1, 0.5, 1.0, 2.0, 5.0]
test_genw3(κ, ns, rng)
test_angle3(κ, ns, rng)
end
end
end

# issue #1423
@testset "Special case: No rotation" begin
    # For μ = (1, 0, ..., 0) a drawn sample must still have unit norm
    for n in 2:10
        μ = zeros(n)
        μ[1] = 1
        d = VonMisesFisher(μ, 1.0)
        @test sum(abs2, rand(d)) ≈ 1
    end
end
Loading