Skip to content

Commit

Permalink
Reintroduce Int indexing to VNV
Browse files Browse the repository at this point in the history
  • Loading branch information
mhauru committed Oct 4, 2024
1 parent eb5577b commit b3f92c2
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 4 deletions.
2 changes: 2 additions & 0 deletions docs/src/internals/varinfo.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@ To ensure that `VarInfo` is simple and intuitive to work with, we want `VarInfo`

- `getindex_internal(::VarInfo, ::VarName)`: get the flattened value of a single variable.
- `getindex_internal(::VarInfo, ::Colon)`: get the flattened values of all variables.
- `getindex_internal(::VarInfo, i::Int)`: get `i`th value of the flattened vector of all values
- `setindex_internal!(::VarInfo, ::AbstractVector, ::VarName)`: set the flattened value of a variable.
- `setindex_internal!(::VarInfo, val, i::Int)`: set the `i`th value of the flattened vector of all values
- `length_internal(::VarInfo)`: return the length of the flat representation of `metadata`.

The functions have `_internal` in their name because internally `VarInfo` always stores values as vectorised.
Expand Down
19 changes: 15 additions & 4 deletions src/varnamedvector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,13 @@ Like `getindex`, but returns the values as they are stored in `vnv`, without tra
"""
getindex_internal(vnv::VarNamedVector, vn::VarName) = vnv.vals[getrange(vnv, vn)]

"""
getindex_internal(vnv::VarNamedVector, i::Int)
Gets the `i`th element of the internal storage vector, ignoring inactive entries.
"""
getindex_internal(vnv::VarNamedVector, i::Int) = vnv.vals[index_to_vals_index(vnv, i)]

function getindex_internal(vnv::VarNamedVector, ::Colon)
return if has_inactive(vnv)
mapreduce(Base.Fix1(getindex, vnv.vals), vcat, vnv.ranges)
Expand Down Expand Up @@ -601,6 +608,14 @@ end

"""
setindex_internal!(vnv::VarNamedVector, val, i::Int)
Sets the `i`th element of the internal storage vector, ignoring inactive entries.
"""
function setindex_internal!(vnv::VarNamedVector, val, i::Int)
return vnv.vals[index_to_vals_index(vnv, i)] = val
end

"""
setindex_internal!(vnv::VarNamedVector, val, vn::VarName[, transform])
Like `setindex!`, but sets the values as they are stored internally in `vnv`.
Expand All @@ -609,10 +624,6 @@ Optionally can set the transformation, such that `transform(val)` is the origina
the variable. By default, the transform is the identity if creating a new entry in `vnv`, or
the existing transform if updating an existing entry.
"""
function setindex_internal!(vnv::VarNamedVector, val, i::Int)
return vnv.vals[index_to_vals_index(vnv, i)] = val
end

function setindex_internal!(
vnv::VarNamedVector, val::AbstractVector, vn::VarName, transform=nothing
)
Expand Down
23 changes: 23 additions & 0 deletions test/varnamedvector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,16 @@ end
to_vec_right(val_right)
end

@testset "getindex_internal with Ints" begin
for (i, val) in enumerate(to_vec_left(val_left))
@test DynamicPPL.getindex_internal(vnv_base, i) == val
end
offset = length(to_vec_left(val_left))
for (i, val) in enumerate(to_vec_right(val_right))
@test DynamicPPL.getindex_internal(vnv_base, offset + i) == val
end
end

@testset "update!" begin
vnv = deepcopy(vnv_base)
update!(vnv, val_left .+ 100, vn_left)
Expand Down Expand Up @@ -371,6 +381,16 @@ end
end
end

@testset "setindex_internal! with Ints" begin
vnv = deepcopy(vnv_base)
for i in 1:length_internal(vnv_base)
setindex_internal!(vnv, i, i)
end
for i in 1:length_internal(vnv_base)
@test getindex_internal(vnv, i) == i
end
end

@testset "setindex_internal!!" begin
# Not setting the transformation.
vnv = deepcopy(vnv_base)
Expand Down Expand Up @@ -418,6 +438,7 @@ end
@test vnv[vn] == val .+ 1
@test length_internal(vnv) == expected_length
@test length(x) == length_internal(vnv)
@test all(getindex_internal(vnv, i) == x[i] for i in eachindex(x))

# There should be no redundant values in the underlying vector.
@test !DynamicPPL.has_inactive(vnv)
Expand All @@ -441,6 +462,7 @@ end
@test vnv[vn] == val .+ 1
@test length_internal(vnv) == expected_length
@test length(x) == length_internal(vnv)
@test all(getindex_internal(vnv, i) == x[i] for i in eachindex(x))
end

vnv = relax_container_types(deepcopy(vnv_base), test_vns, test_vals)
Expand All @@ -461,6 +483,7 @@ end
@test vnv[vn] == val .+ 1
@test length_internal(vnv) == expected_length
@test length(x) == length_internal(vnv)
@test all(getindex_internal(vnv, i) == x[i] for i in eachindex(x))
end
end
end
Expand Down

0 comments on commit b3f92c2

Please sign in to comment.