Skip to content

Commit

Permalink
Merge branch 'master' into yebai-patch-1
Browse files Browse the repository at this point in the history
  • Loading branch information
yebai authored Aug 30, 2023
2 parents 7065f17 + 549d9b1 commit 5da3fae
Show file tree
Hide file tree
Showing 7 changed files with 68 additions and 14 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.23.13"
version = "0.23.14"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down
2 changes: 1 addition & 1 deletion src/simple_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -544,7 +544,7 @@ values_as(vi::SimpleVarInfo) = vi.values
values_as(vi::SimpleVarInfo{<:T}, ::Type{T}) where {T} = vi.values
function values_as(vi::SimpleVarInfo{<:Any,T}, ::Type{Vector}) where {T}
isempty(vi) && return T[]
return mapreduce(v -> vec([v;]), vcat, values(vi.values))
return mapreduce(vectorize, vcat, values(vi.values))
end
function values_as(vi::SimpleVarInfo, ::Type{D}) where {D<:AbstractDict}
return ConstructionBase.constructorof(D)(zip(keys(vi), values(vi.values)))
Expand Down
18 changes: 13 additions & 5 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -209,11 +209,10 @@ invlink_transform(dist) = inverse(link_transform(dist))
# Helper functions for vectorize/reconstruct values #
#####################################################

vectorize(d, r) = vec(r)
vectorize(d::UnivariateDistribution, r::Real) = [r]
vectorize(d::MultivariateDistribution, r::AbstractVector{<:Real}) = copy(r)
vectorize(d::MatrixDistribution, r::AbstractMatrix{<:Real}) = copy(vec(r))
vectorize(d::Distribution{CholeskyVariate}, r::Cholesky) = copy(vec(r.UL))
vectorize(d, r) = vectorize(r)
vectorize(r::Real) = [r]
vectorize(r::AbstractArray{<:Real}) = copy(vec(r))
vectorize(r::Cholesky) = copy(vec(r.UL))

# NOTE:
# We cannot use reconstruct{T} because val is always Vector{Real} then T will be Real.
Expand All @@ -237,6 +236,15 @@ reconstruct(::UnivariateDistribution, val::Real) = val
reconstruct(::MultivariateDistribution, val::AbstractVector{<:Real}) = copy(val)
reconstruct(::MatrixDistribution, val::AbstractMatrix{<:Real}) = copy(val)
reconstruct(::Inverse{Bijectors.VecCorrBijector}, ::LKJ, val::AbstractVector) = copy(val)

function reconstruct(dist::LKJCholesky, val::AbstractVector{<:Real})
return reconstruct(dist, reshape(val, size(dist)))
end
function reconstruct(dist::LKJCholesky, val::AbstractMatrix{<:Real})
return Cholesky(val, dist.uplo, 0)
end
reconstruct(::LKJCholesky, val::Cholesky) = val

function reconstruct(
::Inverse{Bijectors.VecCholeskyBijector}, ::LKJCholesky, val::AbstractVector
)
Expand Down
7 changes: 6 additions & 1 deletion src/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,12 @@ Set the value(s) of `vn` in the metadata of `vi` to `val`.
The values may or may not be transformed to Euclidean space.
"""
setval!(vi::VarInfo, val, vn::VarName) = setval!(getmetadata(vi, vn), val, vn)
setval!(md::Metadata, val, vn::VarName) = md.vals[getrange(md, vn)] = [val;]
function setval!(md::Metadata, val::AbstractVector, vn::VarName)
return md.vals[getrange(md, vn)] = val
end
function setval!(md::Metadata, val, vn::VarName)
return md.vals[getrange(md, vn)] = vectorize(getdist(md, vn), val)
end

"""
getval(vi::VarInfo, vns::Vector{<:VarName})
Expand Down
39 changes: 37 additions & 2 deletions test/linking.jl
Original file line number Diff line number Diff line change
Expand Up @@ -91,16 +91,51 @@ end
end
end

@testset "LKJCholesky" begin
@testset "uplo=$uplo" for uplo in ['L', 'U']
@model demo_lkj(d) = x ~ LKJCholesky(d, 1.0, uplo)
@testset "d=$d" for d in [2, 3, 5]
model = demo_lkj(d)
dist = LKJCholesky(d, 1.0, uplo)
values_original = rand(model)
vis = DynamicPPL.TestUtils.setup_varinfos(
model, values_original, (@varname(x),)
)
@testset "$(short_varinfo_name(vi))" for vi in vis
val = vi[@varname(x), dist]
# Ensure that `reconstruct` works as intended.
@test val isa Cholesky
@test val.uplo == uplo

@test length(vi[:]) == d^2
lp = logpdf(dist, val)
lp_model = logjoint(model, vi)
@test lp_model lp
# Linked.
vi_linked = DynamicPPL.link!!(deepcopy(vi), model)
@test length(vi_linked[:]) == d * (d - 1) ÷ 2
# Should now include the log-absdet-jacobian correction.
@test !(getlogp(vi_linked) lp)
# Invlinked.
vi_invlinked = DynamicPPL.invlink!!(deepcopy(vi_linked), model)
@test length(vi_invlinked[:]) == d^2
@test getlogp(vi_invlinked) lp
end
end
end
end

# Related: https://github.com/TuringLang/DynamicPPL.jl/issues/504
@testset "dirichlet" begin
@testset "Dirichlet" begin
@model demo_dirichlet(d::Int) = x ~ Dirichlet(d, 1.0)
@testset "d=$d" for d in [2, 3, 5]
model = demo_dirichlet(d)
vis = DynamicPPL.TestUtils.setup_varinfos(model, rand(model), (@varname(x),))
@testset "$(short_varinfo_name(vi))" for vi in vis
lp = logpdf(Dirichlet(d, 1.0), vi[:])
@test length(vi[:]) == d
@test getlogp(vi) lp
lp_model = logjoint(model, vi)
@test lp_model lp
# Linked.
vi_linked = DynamicPPL.link!!(deepcopy(vi), model)
@test length(vi_linked[:]) == d - 1
Expand Down
13 changes: 9 additions & 4 deletions test/test_util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,18 @@ function test_setval!(model, chain; sample_idx=1, chain_idx=1)
nt = DynamicPPL.tonamedtuple(var_info)
for (k, (vals, names)) in pairs(nt)
for (n, v) in zip(names, vals)
chain_val = if Symbol(n) keys(chain)
if Symbol(n) keys(chain)
# Assume it's a group
vec(MCMCChains.group(chain, Symbol(n)).value[sample_idx, :, chain_idx])
chain_val = vec(
MCMCChains.group(chain, Symbol(n)).value[sample_idx, :, chain_idx]
)
v_true = vec(v)
else
chain[sample_idx, n, chain_idx]
chain_val = chain[sample_idx, n, chain_idx]
v_true = v
end
@test v == chain_val

@test v_true == chain_val
end
end
end
Expand Down
1 change: 1 addition & 0 deletions test/turing/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ setprogress!(false)
Random.seed!(100)

# load test utilities
include(joinpath(pathof(DynamicPPL), "..", "..", "test", "test_util.jl"))
include(joinpath(pathof(Turing), "..", "..", "test", "test_utils", "numerical_tests.jl"))

@testset "Turing" begin
Expand Down

0 comments on commit 5da3fae

Please sign in to comment.