Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bugfix for linking #513

Merged
merged 7 commits into from
Aug 7, 2023
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.23.10"
version = "0.23.11"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down
4 changes: 2 additions & 2 deletions src/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -360,11 +360,11 @@ The values may or may not be transformed to Euclidean space.
"""
setall!(vi::UntypedVarInfo, val) = vi.metadata.vals .= val
setall!(vi::TypedVarInfo, val) = _setall!(vi.metadata, val)
@generated function _setall!(metadata::NamedTuple{names}, val, start=0) where {names}
@generated function _setall!(metadata::NamedTuple{names}, val) where {names}
expr = Expr(:block)
start = :(1)
for f in names
length = :(length(metadata.$f.vals))
length = :(sum(length, metadata.$f.ranges))
finish = :($start + $length - 1)
push!(expr.args, :(metadata.$f.vals .= val[($start):($finish)]))
start = :($start + $length)
Expand Down
71 changes: 48 additions & 23 deletions test/linking.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ end

Base.size(d::MyMatrixDistribution) = (d.dim, d.dim)
function Distributions._rand!(
rng::AbstractRNG, d::MyMatrixDistribution, x::AbstractMatrix{<:Real}
rng::Random.AbstractRNG, d::MyMatrixDistribution, x::AbstractMatrix{<:Real}
)
return randn!(rng, x)
end
Expand All @@ -58,29 +58,54 @@ function Bijectors.logpdf_with_trans(dist::MyMatrixDistribution, x, istrans::Boo
end

@testset "Linking" begin
# Just making sure the transformations are okay.
x = randn(3, 3)
f = TrilToVec((3, 3))
f_inv = inverse(f)
y = f(x)
@test y isa AbstractVector
@test f_inv(f(x)) == LowerTriangular(x)
@testset "simple matrix distribution" begin
# Just making sure the transformations are okay.
x = randn(3, 3)
f = TrilToVec((3, 3))
f_inv = inverse(f)
y = f(x)
@test y isa AbstractVector
@test f_inv(f(x)) == LowerTriangular(x)

# Within a model.
dist = MyMatrixDistribution(3)
@model demo() = m ~ dist
model = demo()
# Within a model.
dist = MyMatrixDistribution(3)
@model demo() = m ~ dist
model = demo()

vis = DynamicPPL.TestUtils.setup_varinfos(model, rand(model), (@varname(m),))
@testset "$(short_varinfo_name(vi))" for vi in vis
# Evaluate once to ensure we have `logp` value.
vi = last(DynamicPPL.evaluate!!(model, vi, DefaultContext()))
vi_linked = DynamicPPL.link!!(deepcopy(vi), model)
# Difference should just be the log-absdet-jacobian "correction".
@test DynamicPPL.getlogp(vi) - DynamicPPL.getlogp(vi_linked) ≈ log(2)
@test vi_linked[@varname(m), dist] == LowerTriangular(vi[@varname(m), dist])
# Linked one should be working with a lower-dimensional representation.
@test length(vi_linked[:]) < length(vi[:])
@test length(vi_linked[:]) == 3
vis = DynamicPPL.TestUtils.setup_varinfos(model, rand(model), (@varname(m),))
@testset "$(short_varinfo_name(vi))" for vi in vis
# Evaluate once to ensure we have `logp` value.
vi = last(DynamicPPL.evaluate!!(model, vi, DefaultContext()))
vi_linked = DynamicPPL.link!!(deepcopy(vi), model)
# Difference should just be the log-absdet-jacobian "correction".
@test DynamicPPL.getlogp(vi) - DynamicPPL.getlogp(vi_linked) ≈ log(2)
@test vi_linked[@varname(m), dist] == LowerTriangular(vi[@varname(m), dist])
# Linked one should be working with a lower-dimensional representation.
@test length(vi_linked[:]) < length(vi[:])
@test length(vi_linked[:]) == length(y)
# Invlinked.
vi_invlinked = DynamicPPL.invlink!!(deepcopy(vi_linked), model)
@test length(vi_invlinked[:]) == length(vi[:])
@test vi_invlinked[@varname(m), dist] ≈ LowerTriangular(vi[@varname(m), dist])
@test DynamicPPL.getlogp(vi_invlinked) ≈ DynamicPPL.getlogp(vi)
end
end

@testset "dirichlet" begin
@model demo_dirichlet() = x ~ Dirichlet(2, 1.0)
model = demo_dirichlet()
vis = DynamicPPL.TestUtils.setup_varinfos(model, rand(model), (@varname(x),))
@testset "$(short_varinfo_name(vi))" for vi in vis
@test length(vi[:]) == 2
@test getlogp(vi) ≈ 0
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
# Linked.
vi_linked = DynamicPPL.link!!(deepcopy(vi), model)
@test length(vi_linked[:]) == 1
@test !(getlogp(vi_linked) ≈ 0) # should now include the log-absdet-jacobian correction
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
# Invlinked.
vi_invlinked = DynamicPPL.invlink!!(deepcopy(vi_linked), model)
@test length(vi_invlinked[:]) == 2
@test getlogp(vi_invlinked) ≈ 0
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
end
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ include("test_util.jl")
include("contexts.jl")
include("context_implementations.jl")
include("logdensityfunction.jl")
include("linking.jl")

include("threadsafe.jl")

Expand Down