From b7c2184b7791fa5f379ef3948e760961890eb40f Mon Sep 17 00:00:00 2001 From: David Widmann Date: Tue, 26 Sep 2023 00:26:16 +0200 Subject: [PATCH] Fix promotions with irrationals (#259) * Fix promotions with irrationals * Add tests * Fix typo --- Project.toml | 2 +- src/multivariate.jl | 24 +++++++++++++++--------- test/others.jl | 33 ++++++++++++++++++++++++--------- 3 files changed, 40 insertions(+), 19 deletions(-) diff --git a/Project.toml b/Project.toml index 135695d..7b61f7f 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DistributionsAD" uuid = "ced4e74d-a319-5a8a-b0ac-84af2272839c" -version = "0.6.52" +version = "0.6.53" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/multivariate.jl b/src/multivariate.jl index 31adf95..4c5c30d 100644 --- a/src/multivariate.jl +++ b/src/multivariate.jl @@ -154,40 +154,46 @@ end function Distributions._logpdf(d::TuringScalMvNormal, x::AbstractVector) σ2 = abs2(d.σ) - return -(length(x) * log(2π * σ2) + sum(abs2.(x .- d.m)) / σ2) / 2 + return -(length(x) * log(twoπ * σ2) + sum(abs2.(x .- d.m)) / σ2) / 2 end function Distributions.logpdf(d::TuringScalMvNormal, x::AbstractMatrix{<:Real}) size(x, 1) == length(d) || throw(DimensionMismatch("Inconsistent array dimensions.")) - return -(size(x, 1) * log(2π * abs2(d.σ)) .+ vec(sum(abs2.((x .- d.m) ./ d.σ), dims=1))) ./ 2 + return -(size(x, 1) * log(twoπ * abs2(d.σ)) .+ vec(sum(abs2.((x .- d.m) ./ d.σ), dims=1))) ./ 2 end function Distributions.loglikelihood(d::TuringScalMvNormal, x::AbstractMatrix{<:Real}) σ2 = abs2(d.σ) - return -(length(x) * log(2π * σ2) + sum(abs2.(x .- d.m)) / σ2) / 2 + return -(length(x) * log(twoπ * σ2) + sum(abs2.(x .- d.m)) / σ2) / 2 end function Distributions._logpdf(d::TuringDiagMvNormal, x::AbstractVector) - return -(length(x) * log(2π) + 2 * sum(log.(d.σ)) + sum(abs2.((x .- d.m) ./ d.σ))) / 2 + z = 2 * sum(log.(d.σ)) + sum(abs2.((x .- d.m) ./ d.σ)) + return -(length(x) * oftype(z, log2π) + z) / 2 end function Distributions.logpdf(d::TuringDiagMvNormal, x::AbstractMatrix{<:Real}) size(x, 1) == length(d) || throw(DimensionMismatch("Inconsistent array dimensions.")) - return -((size(x, 1) * log(2π) + 2 * sum(log.(d.σ))) .+ vec(sum(abs2.((x .- d.m) ./ d.σ), dims=1))) ./ 2 + s = 2 * sum(log.(d.σ)) + return -((size(x, 1) * oftype(s, log2π) + s) .+ vec(sum(abs2.((x .- d.m) ./ d.σ), dims=1))) ./ 2 end function Distributions.loglikelihood(d::TuringDiagMvNormal, x::AbstractMatrix{<:Real}) - return -(length(x) * log(2π) + 2 * size(x, 2) * sum(log.(d.σ)) + sum(abs2.((x .- d.m) ./ d.σ))) / 2 + z = 2 * size(x, 2) * sum(log.(d.σ)) + sum(abs2.((x .- d.m) ./ d.σ)) + return -(length(x) * oftype(z, log2π) + z) / 2 end function Distributions._logpdf(d::TuringDenseMvNormal, x::AbstractVector) - return -(length(x) * log(2π) + logdet(d.C) + sum(abs2.(zygote_ldiv(d.C.U', x .- d.m)))) / 2 + z = logdet(d.C) + sum(abs2.(zygote_ldiv(d.C.U', x .- d.m))) + return -(length(x) * oftype(z, log2π) + z) / 2 end function Distributions.logpdf(d::TuringDenseMvNormal, x::AbstractMatrix{<:Real}) size(x, 1) == length(d) || throw(DimensionMismatch("Inconsistent array dimensions.")) - return -((size(x, 1) * log(2π) + logdet(d.C)) .+ vec(sum(abs2.(zygote_ldiv(d.C.U', x .- d.m)), dims=1))) ./ 2 + s = logdet(d.C) + return -((size(x, 1) * oftype(s, log2π) + s) .+ vec(sum(abs2.(zygote_ldiv(d.C.U', x .- d.m)), dims=1))) ./ 2 end function Distributions.loglikelihood(d::TuringDenseMvNormal, x::AbstractMatrix{<:Real}) - return -(length(x) * log(2π) + size(x, 2) * logdet(d.C) + sum(abs2.(zygote_ldiv(d.C.U', x .- d.m)))) / 2 + z = size(x, 2) * logdet(d.C) + sum(abs2.(zygote_ldiv(d.C.U', x .- d.m))) + return -(length(x) * oftype(z, log2π) + z) / 2 end function Distributions.entropy(d::TuringScalMvNormal) diff --git a/test/others.jl b/test/others.jl index 23549ae..66274c4 100644 --- a/test/others.jl +++ b/test/others.jl @@ -36,19 +36,19 @@ end @testset "TuringMvNormal" begin - @testset "$TD" for TD in [TuringDenseMvNormal, TuringDiagMvNormal, TuringScalMvNormal] - m = rand(3) + @testset for TD in (TuringDenseMvNormal, TuringDiagMvNormal, TuringScalMvNormal), T in (Float64, Float32) + m = rand(T, 3) if TD <: TuringDenseMvNormal - A = rand(3, 3) + A = rand(T, 3, 3) C = A' * A + I d1 = TuringMvNormal(m, C) d2 = MvNormal(m, C) elseif TD <: TuringDiagMvNormal - C = rand(3) + C = rand(T, 3) d1 = TuringMvNormal(m, C) d2 = MvNormal(m, Diagonal(C .^ 2)) else - C = rand() + C = rand(T) d1 = TuringMvNormal(m, C) d2 = MvNormal(m, C^2 * I) end @@ -56,13 +56,28 @@ @testset "$F" for F in (length, size, mean) @test F(d1) == F(d2) end - @test cov(d1) ≈ cov(d2) - @test var(d1) ≈ var(d2) + C1 = @inferred(cov(d1)) + @test C1 isa AbstractMatrix{T} + @test C1 ≈ cov(d2) + V1 = @inferred(var(d1)) + @test V1 isa AbstractVector{T} + @test V1 ≈ var(d2) x1 = rand(d1) x2 = rand(d1, 3) - @test isapprox(logpdf(d1, x1), logpdf(d2, x1), rtol = 1e-6) - @test isapprox(logpdf(d1, x2), logpdf(d2, x2), rtol = 1e-6) + for S in (Float64, Float32) + ST = promote_type(S, T) + + z = map(S, x1) + logp = @inferred(logpdf(d1, z)) + @test logp isa ST + @test logp ≈ logpdf(d2, z) rtol = 1e-6 + + zs = map(S, x2) + logps = @inferred(logpdf(d1, zs)) + @test eltype(logps) === ST + @test logps ≈ logpdf(d2, zs) rtol = 1e-6 + end end end