Skip to content

Commit

Permalink
Fix promotions with irrationals (#259)
Browse files Browse the repository at this point in the history
* Fix promotions with irrationals

* Add tests

* Fix typo
  • Loading branch information
devmotion authored Sep 25, 2023
1 parent 5847e86 commit b7c2184
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 19 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DistributionsAD"
uuid = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
version = "0.6.52"
version = "0.6.53"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
24 changes: 15 additions & 9 deletions src/multivariate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -154,40 +154,46 @@ end

function Distributions._logpdf(d::TuringScalMvNormal, x::AbstractVector)
σ2 = abs2(d.σ)
return -(length(x) * log(2π * σ2) + sum(abs2.(x .- d.m)) / σ2) / 2
return -(length(x) * log(twoπ * σ2) + sum(abs2.(x .- d.m)) / σ2) / 2
end
function Distributions.logpdf(d::TuringScalMvNormal, x::AbstractMatrix{<:Real})
size(x, 1) == length(d) ||
throw(DimensionMismatch("Inconsistent array dimensions."))
return -(size(x, 1) * log(2π * abs2(d.σ)) .+ vec(sum(abs2.((x .- d.m) ./ d.σ), dims=1))) ./ 2
return -(size(x, 1) * log(twoπ * abs2(d.σ)) .+ vec(sum(abs2.((x .- d.m) ./ d.σ), dims=1))) ./ 2
end
function Distributions.loglikelihood(d::TuringScalMvNormal, x::AbstractMatrix{<:Real})
σ2 = abs2(d.σ)
return -(length(x) * log(2π * σ2) + sum(abs2.(x .- d.m)) / σ2) / 2
return -(length(x) * log(twoπ * σ2) + sum(abs2.(x .- d.m)) / σ2) / 2
end

function Distributions._logpdf(d::TuringDiagMvNormal, x::AbstractVector)
return -(length(x) * log(2π) + 2 * sum(log.(d.σ)) + sum(abs2.((x .- d.m) ./ d.σ))) / 2
z = 2 * sum(log.(d.σ)) + sum(abs2.((x .- d.m) ./ d.σ))
return -(length(x) * oftype(z, log2π) + z) / 2
end
function Distributions.logpdf(d::TuringDiagMvNormal, x::AbstractMatrix{<:Real})
size(x, 1) == length(d) ||
throw(DimensionMismatch("Inconsistent array dimensions."))
return -((size(x, 1) * log(2π) + 2 * sum(log.(d.σ))) .+ vec(sum(abs2.((x .- d.m) ./ d.σ), dims=1))) ./ 2
s = 2 * sum(log.(d.σ))
return -((size(x, 1) * oftype(s, log2π) + s) .+ vec(sum(abs2.((x .- d.m) ./ d.σ), dims=1))) ./ 2
end
function Distributions.loglikelihood(d::TuringDiagMvNormal, x::AbstractMatrix{<:Real})
return -(length(x) * log(2π) + 2 * size(x, 2) * sum(log.(d.σ)) + sum(abs2.((x .- d.m) ./ d.σ))) / 2
z = 2 * size(x, 2) * sum(log.(d.σ)) + sum(abs2.((x .- d.m) ./ d.σ))
return -(length(x) * oftype(z, log2π) + z) / 2
end

function Distributions._logpdf(d::TuringDenseMvNormal, x::AbstractVector)
return -(length(x) * log(2π) + logdet(d.C) + sum(abs2.(zygote_ldiv(d.C.U', x .- d.m)))) / 2
z = logdet(d.C) + sum(abs2.(zygote_ldiv(d.C.U', x .- d.m)))
return -(length(x) * oftype(z, log2π) + z) / 2
end
function Distributions.logpdf(d::TuringDenseMvNormal, x::AbstractMatrix{<:Real})
size(x, 1) == length(d) ||
throw(DimensionMismatch("Inconsistent array dimensions."))
return -((size(x, 1) * log(2π) + logdet(d.C)) .+ vec(sum(abs2.(zygote_ldiv(d.C.U', x .- d.m)), dims=1))) ./ 2
s = logdet(d.C)
return -((size(x, 1) * oftype(s, log2π) + s) .+ vec(sum(abs2.(zygote_ldiv(d.C.U', x .- d.m)), dims=1))) ./ 2
end
function Distributions.loglikelihood(d::TuringDenseMvNormal, x::AbstractMatrix{<:Real})
return -(length(x) * log(2π) + size(x, 2) * logdet(d.C) + sum(abs2.(zygote_ldiv(d.C.U', x .- d.m)))) / 2
z = size(x, 2) * logdet(d.C) + sum(abs2.(zygote_ldiv(d.C.U', x .- d.m)))
return -(length(x) * oftype(z, log2π) + z) / 2
end

function Distributions.entropy(d::TuringScalMvNormal)
Expand Down
33 changes: 24 additions & 9 deletions test/others.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,33 +36,48 @@
end

@testset "TuringMvNormal" begin
@testset "$TD" for TD in [TuringDenseMvNormal, TuringDiagMvNormal, TuringScalMvNormal]
m = rand(3)
@testset for TD in (TuringDenseMvNormal, TuringDiagMvNormal, TuringScalMvNormal), T in (Float64, Float32)
m = rand(T, 3)
if TD <: TuringDenseMvNormal
A = rand(3, 3)
A = rand(T, 3, 3)
C = A' * A + I
d1 = TuringMvNormal(m, C)
d2 = MvNormal(m, C)
elseif TD <: TuringDiagMvNormal
C = rand(3)
C = rand(T, 3)
d1 = TuringMvNormal(m, C)
d2 = MvNormal(m, Diagonal(C .^ 2))
else
C = rand()
C = rand(T)
d1 = TuringMvNormal(m, C)
d2 = MvNormal(m, C^2 * I)
end

@testset "$F" for F in (length, size, mean)
@test F(d1) == F(d2)
end
@test cov(d1) cov(d2)
@test var(d1) var(d2)
C1 = @inferred(cov(d1))
@test C1 isa AbstractMatrix{T}
@test C1 cov(d2)
V1 = @inferred(var(d1))
@test V1 isa AbstractVector{T}
@test V1 var(d2)

x1 = rand(d1)
x2 = rand(d1, 3)
@test isapprox(logpdf(d1, x1), logpdf(d2, x1), rtol = 1e-6)
@test isapprox(logpdf(d1, x2), logpdf(d2, x2), rtol = 1e-6)
for S in (Float64, Float32)
ST = promote_type(S, T)

z = map(S, x1)
logp = @inferred(logpdf(d1, z))
@test logp isa ST
@test logp logpdf(d2, z) rtol = 1e-6

zs = map(S, x2)
logps = @inferred(logpdf(d1, zs))
@test eltype(logps) === ST
@test logps logpdf(d2, zs) rtol = 1e-6
end
end
end

Expand Down

2 comments on commit b7c2184

@devmotion
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/92205

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.6.53 -m "<description of version>" b7c2184b7791fa5f379ef3948e760961890eb40f
git push origin v0.6.53

Please sign in to comment.