diff --git a/src/vonmisesfisher.jl b/src/vonmisesfisher.jl index 2b389e3..da7017b 100644 --- a/src/vonmisesfisher.jl +++ b/src/vonmisesfisher.jl @@ -96,28 +96,28 @@ function MeasureTheory.logdensity(d::VonMises{ℝ,(:μ, :κ)}, x) end function MeasureTheory.logdensity(d::VonMisesFisher{M,(:μ, :κ)}, x) where {M} - p = size(x, 1) + p = manifold_dimension(base_manifold(d)) + 1 κ = d.κ return κ * real(dot(d.μ, x)) - lognorm_vmf(p, κ) end function MeasureTheory.logdensity(d::VonMisesFisher{M,(:c,)}, x) where {M} - p = size(x, 1) + p = manifold_dimension(base_manifold(d)) + 1 c = d.c κ = norm(c) return real(dot(c, x)) - lognorm_vmf(p, κ) end function MeasureTheory.logdensity(d::VonMisesFisher{M,(:F,)}, x) where {M} - n = size(x, 1) + n, _ = representation_size(base_manifold(d)) F = d.F return real(dot(F, x)) - logpFq((), (n//2,), (F'F) / 4) end function MeasureTheory.logdensity(d::VonMisesFisher{M,(:U, :D, :V)}, x) where {M} - n = size(x, 1) + n, _ = representation_size(base_manifold(d)) D = Diagonal(d.D) return real(dot(D * d.V', d.U' * x)) - logpFq((), (n//2,), D .^ 2 ./ 4) end function MeasureTheory.logdensity(d::VonMisesFisher{M,(:H, :P)}, x) where {M} - n = size(x, 1) + n, _ = representation_size(base_manifold(d)) P = d.P return real(dot(d.H, P, x)) - logpFq((), (n//2,), (P^2) / 4) end