Skip to content

Commit

Permalink
Stop exporting any VNV stuff
Browse files Browse the repository at this point in the history
  • Loading branch information
mhauru committed Oct 4, 2024
1 parent b3f92c2 commit 65c94ca
Show file tree
Hide file tree
Showing 7 changed files with 121 additions and 114 deletions.
18 changes: 0 additions & 18 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,27 +45,9 @@ export AbstractVarInfo,
VarInfo,
UntypedVarInfo,
TypedVarInfo,
VectorVarInfo,
SimpleVarInfo,
VarNamedVector,
length_internal,
getindex_internal,
update!,
reset!,
setindex_internal!,
update_internal!,
insert_internal!,
update!!,
insert!!,
reset!!,
setindex!!,
setindex_internal!!,
update_internal!!,
insert_internal!!,
push!!,
empty!!,
loosen_types!!,
tighten_types,
subset,
getlogp,
setlogp!!,
Expand Down
4 changes: 2 additions & 2 deletions src/test_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,12 @@ function setup_varinfos(
# SimpleVarInfo
svi_typed = SimpleVarInfo(example_values)
svi_untyped = SimpleVarInfo(OrderedDict())
svi_vnv = SimpleVarInfo(VarNamedVector())
svi_vnv = SimpleVarInfo(DynamicPPL.VarNamedVector())

# SimpleVarInfo{<:Any,<:Ref}
svi_typed_ref = SimpleVarInfo(example_values, Ref(getlogp(svi_typed)))
svi_untyped_ref = SimpleVarInfo(OrderedDict(), Ref(getlogp(svi_untyped)))
svi_vnv_ref = SimpleVarInfo(VarNamedVector(), Ref(getlogp(svi_vnv)))
svi_vnv_ref = SimpleVarInfo(DynamicPPL.VarNamedVector(), Ref(getlogp(svi_vnv)))

lp = getlogp(vi_typed_metadata)
varinfos = map((
Expand Down
2 changes: 1 addition & 1 deletion src/varnamedvector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ $(FIELDS)
The values for different variables are internally all stored in a single vector. For
instance,
```jldoctest varnamedvector-struct
julia> using DynamicPPL: ReshapeTransform, VarNamedVector, @varname, setindex!, update!
julia> using DynamicPPL: ReshapeTransform, VarNamedVector, @varname, setindex!, update!, getindex_internal
julia> vnv = VarNamedVector();
Expand Down
16 changes: 9 additions & 7 deletions test/simple_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,27 +58,29 @@
end

@testset "VarNamedVector" begin
svi = SimpleVarInfo(push!!(VarNamedVector(), @varname(m) => 1.0))
svi = SimpleVarInfo(push!!(DynamicPPL.VarNamedVector(), @varname(m) => 1.0))
@test getlogp(svi) == 0.0
@test haskey(svi, @varname(m))
@test !haskey(svi, @varname(m[1]))

svi = SimpleVarInfo(push!!(VarNamedVector(), @varname(m) => [1.0]))
svi = SimpleVarInfo(push!!(DynamicPPL.VarNamedVector(), @varname(m) => [1.0]))
@test getlogp(svi) == 0.0
@test haskey(svi, @varname(m))
@test haskey(svi, @varname(m[1]))
@test !haskey(svi, @varname(m[2]))
@test svi[@varname(m)][1] == svi[@varname(m[1])]

svi = SimpleVarInfo(push!!(VarNamedVector(), @varname(m.a) => [1.0]))
svi = SimpleVarInfo(push!!(DynamicPPL.VarNamedVector(), @varname(m.a) => [1.0]))
@test haskey(svi, @varname(m))
@test haskey(svi, @varname(m.a))
@test haskey(svi, @varname(m.a[1]))
@test !haskey(svi, @varname(m.a[2]))
@test !haskey(svi, @varname(m.a.b))
# The implementation of haskey and getvalue fo VarNamedVector is incomplete, the
# next test is here to remind of us that.
svi = SimpleVarInfo(push!!(VarNamedVector(), @varname(m.a.b) => [1.0]))
svi = SimpleVarInfo(
push!!(DynamicPPL.VarNamedVector(), @varname(m.a.b) => [1.0])
)
@test_broken !haskey(svi, @varname(m.a.b.c.d))
end
end
Expand All @@ -89,7 +91,7 @@
@testset "$(typeof(vi))" for vi in (
SimpleVarInfo(Dict()),
SimpleVarInfo(values_constrained),
SimpleVarInfo(VarNamedVector()),
SimpleVarInfo(DynamicPPL.VarNamedVector()),
VarInfo(model),
)
for vn in DynamicPPL.TestUtils.varnames(model)
Expand Down Expand Up @@ -143,7 +145,7 @@
# to see whether this is the case.
svi_nt = SimpleVarInfo(DynamicPPL.TestUtils.rand_prior_true(model))
svi_dict = SimpleVarInfo(VarInfo(model), Dict)
vnv = VarNamedVector()
vnv = DynamicPPL.VarNamedVector()
for (k, v) in pairs(DynamicPPL.TestUtils.rand_prior_true(model))
vnv = push!!(vnv, VarName{k}() => v)
end
Expand Down Expand Up @@ -232,7 +234,7 @@
# Initialize.
svi_nt = DynamicPPL.settrans!!(SimpleVarInfo(), true)
svi_nt = last(DynamicPPL.evaluate!!(model, svi_nt, SamplingContext()))
svi_vnv = DynamicPPL.settrans!!(SimpleVarInfo(VarNamedVector()), true)
svi_vnv = DynamicPPL.settrans!!(SimpleVarInfo(DynamicPPL.VarNamedVector()), true)
svi_vnv = last(DynamicPPL.evaluate!!(model, svi_vnv, SamplingContext()))

for svi in (svi_nt, svi_vnv)
Expand Down
6 changes: 4 additions & 2 deletions test/test_util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,12 @@ function short_varinfo_name(vi::TypedVarInfo)
return "TypedVarInfo"
end
short_varinfo_name(::UntypedVarInfo) = "UntypedVarInfo"
short_varinfo_name(::VectorVarInfo) = "VectorVarInfo"
short_varinfo_name(::DynamicPPL.VectorVarInfo) = "VectorVarInfo"
short_varinfo_name(::SimpleVarInfo{<:NamedTuple}) = "SimpleVarInfo{<:NamedTuple}"
short_varinfo_name(::SimpleVarInfo{<:OrderedDict}) = "SimpleVarInfo{<:OrderedDict}"
short_varinfo_name(::SimpleVarInfo{<:VarNamedVector}) = "SimpleVarInfo{<:VarNamedVector}"
function short_varinfo_name(::SimpleVarInfo{<:DynamicPPL.VarNamedVector})
return "SimpleVarInfo{<:VarNamedVector}"
end

# convenient functions for testing model.jl
# function to modify the representation of values based on their length
Expand Down
8 changes: 4 additions & 4 deletions test/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,)
@test vi[SampleFromPrior()][1] == 3 * r

# TODO(mhauru) Implement these functions for other VarInfo types too.
if vi isa VectorVarInfo
if vi isa DynamicPPL.VectorVarInfo
delete!(vi, vn)
@test isempty(vi)
vi = push!!(vi, vn, r, dist, gid)
Expand All @@ -128,7 +128,7 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,)
test_base!!(TypedVarInfo(vi))
test_base!!(SimpleVarInfo())
test_base!!(SimpleVarInfo(Dict()))
test_base!!(SimpleVarInfo(VarNamedVector()))
test_base!!(SimpleVarInfo(DynamicPPL.VarNamedVector()))
end
@testset "flags" begin
# Test flag setting:
Expand Down Expand Up @@ -209,7 +209,7 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,)
model, SampleFromPrior(), DefaultContext(), DynamicPPL.Metadata
)
vi_untyped = VarInfo(DynamicPPL.Metadata())
vi_vnv = VarInfo(VarNamedVector())
vi_vnv = VarInfo(DynamicPPL.VarNamedVector())
vi_vnv_typed = VarInfo(
model, SampleFromPrior(), DefaultContext(), DynamicPPL.VarNamedVector
)
Expand Down Expand Up @@ -374,7 +374,7 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,)
@test getlogp(vi) Bijectors.logpdf_with_trans(dist, x, true)

## `SimpleVarInfo{<:VarNamedVector}`
vi = DynamicPPL.settrans!!(SimpleVarInfo(VarNamedVector()), true)
vi = DynamicPPL.settrans!!(SimpleVarInfo(DynamicPPL.VarNamedVector()), true)
# Sample in unconstrained space.
vi = last(DynamicPPL.evaluate!!(model, vi, SamplingContext()))
f = DynamicPPL.from_linked_internal_transform(vi, vn, dist)
Expand Down
Loading

0 comments on commit 65c94ca

Please sign in to comment.