Skip to content

Commit

Permalink
Rename getindex_raw to getindex_internal
Browse files Browse the repository at this point in the history
  • Loading branch information
mhauru committed Sep 17, 2024
1 parent 45c89c4 commit 30252dc
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 32 deletions.
9 changes: 3 additions & 6 deletions src/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -637,9 +637,6 @@ getindex_internal(vi::VarInfo, vn::VarName) = getindex_internal(getmetadata(vi,
# what a bijector would result in, even if the input is a view (`SubArray`).
# TODO(torfjelde): An alternative is to implement `view` directly instead.
getindex_internal(md::Metadata, vn::VarName) = getindex(md.vals, getrange(md, vn))
# TODO(mhauru) Maybe rename getindex_raw to getindex_internal and obviate the need for this
# method.
getindex_internal(vnv::VarNamedVector, vn::VarName) = getindex_raw(vnv, vn)

function getindex_internal(vi::VarInfo, vns::Vector{<:VarName})
return mapreduce(Base.Fix1(getindex_internal, vi), vcat, vns)
Expand Down Expand Up @@ -676,7 +673,7 @@ function getall(md::Metadata)
Base.Fix1(getindex_internal, md), vcat, md.vns; init=similar(md.vals, 0)
)
end
getall(vnv::VarNamedVector) = getindex_raw(vnv, Colon())
getall(vnv::VarNamedVector) = getindex_internal(vnv, Colon())

"""
setall!(vi::VarInfo, val)
Expand Down Expand Up @@ -1439,7 +1436,7 @@ function _link_metadata!!(
# First transform from however the variable is stored in vnv to the model
# representation.
transform_to_orig = gettransform(metadata, vn)
val_old = getindex_raw(metadata, vn)
val_old = getindex_internal(metadata, vn)
val_orig, logjac1 = with_logabsdet_jacobian(transform_to_orig, val_old)
# Then transform from the model representation to the linked representation.
transform_from_linked = from_linked_vec_transform(dists[vn])
Expand Down Expand Up @@ -1559,7 +1556,7 @@ function _invlink_metadata!!(
vns = target_vns === nothing ? keys(metadata) : target_vns
for vn in vns
transform = gettransform(metadata, vn)
old_val = getindex_raw(metadata, vn)
old_val = getindex_internal(metadata, vn)
new_val, logjac = with_logabsdet_jacobian(transform, old_val)
# TODO(mhauru) We are calling a !! function but ignoring the return value.
acclogp!!(varinfo, -logjac)
Expand Down
34 changes: 17 additions & 17 deletions src/varnamedvector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -298,9 +298,9 @@ Base.pairs(vnv::VarNamedVector) = (vn => vnv[vn] for vn in keys(vnv))
Base.haskey(vnv::VarNamedVector, vn::VarName) = haskey(vnv.varname_to_index, vn)

# `getindex` & `setindex!`
Base.getindex(vnv::VarNamedVector, i::Int) = getindex_raw(vnv, i)
Base.getindex(vnv::VarNamedVector, i::Int) = getindex_internal(vnv, i)
function Base.getindex(vnv::VarNamedVector, vn::VarName)
x = getindex_raw(vnv, vn)
x = getindex_internal(vnv, vn)
f = gettransform(vnv, vn)
return f(x)
end
Expand Down Expand Up @@ -369,15 +369,15 @@ function index_to_vals_index(vnv::VarNamedVector, i::Int)
end

"""
getindex_raw(vnv::VarNamedVector, i::Int)
getindex_raw(vnv::VarNamedVector, vn::VarName)
getindex_internal(vnv::VarNamedVector, i::Int)
getindex_internal(vnv::VarNamedVector, vn::VarName)
Like `getindex`, but returns the values as they are stored in `vnv` without transforming.
For integer indices this is the same as `getindex`, but for `VarName`s this is different.
"""
getindex_raw(vnv::VarNamedVector, i::Int) = vnv.vals[index_to_vals_index(vnv, i)]
getindex_raw(vnv::VarNamedVector, vn::VarName) = vnv.vals[getrange(vnv, vn)]
getindex_internal(vnv::VarNamedVector, i::Int) = vnv.vals[index_to_vals_index(vnv, i)]
getindex_internal(vnv::VarNamedVector, vn::VarName) = vnv.vals[getrange(vnv, vn)]

# `getindex` for `Colon`
function Base.getindex(vnv::VarNamedVector, ::Colon)
Expand All @@ -388,34 +388,34 @@ function Base.getindex(vnv::VarNamedVector, ::Colon)
end
end

getindex_raw(vnv::VarNamedVector, ::Colon) = getindex(vnv, Colon())
getindex_internal(vnv::VarNamedVector, ::Colon) = getindex(vnv, Colon())

# TODO(mhauru): Remove this as soon as possible. Only needed because of the old Gibbs
# sampler.
function Base.getindex(vnv::VarNamedVector, spl::AbstractSampler)
throw(ErrorException("Cannot index a VarNamedVector with a sampler."))
end

Base.setindex!(vnv::VarNamedVector, val, i::Int) = setindex_raw!(vnv, val, i)
Base.setindex!(vnv::VarNamedVector, val, i::Int) = setindex_internal!(vnv, val, i)
function Base.setindex!(vnv::VarNamedVector, val, vn::VarName)
# Since setindex! does not change the transform, we need to apply it to `val`.
f = inverse(gettransform(vnv, vn))
return setindex_raw!(vnv, f(val), vn)
return setindex_internal!(vnv, f(val), vn)
end

"""
setindex_raw!(vnv::VarNamedVector, val, i::Int)
setindex_raw!(vnv::VarNamedVector, val, vn::VarName)
setindex_internal!(vnv::VarNamedVector, val, i::Int)
setindex_internal!(vnv::VarNamedVector, val, vn::VarName)
Like `setindex!`, but sets the values as they are stored in `vnv` without transforming.
For integer indices this is the same as `setindex!`, but for `VarName`s this is different.
"""
function setindex_raw!(vnv::VarNamedVector, val, i::Int)
function setindex_internal!(vnv::VarNamedVector, val, i::Int)
return vnv.vals[index_to_vals_index(vnv, i)] = val
end

function setindex_raw!(vnv::VarNamedVector, val::AbstractVector, vn::VarName)
function setindex_internal!(vnv::VarNamedVector, val::AbstractVector, vn::VarName)
return vnv.vals[getrange(vnv, vn)] = val
end

Expand Down Expand Up @@ -565,13 +565,13 @@ function Base.merge(left_vnv::VarNamedVector, right_vnv::VarNamedVector)
# Extract the necessary information from `left` or `right`.
if vn in vns_left && !(vn in vns_right)
# `vn` is only in `left`.
val = getindex_raw(left_vnv, vn)
val = getindex_internal(left_vnv, vn)
f = gettransform(left_vnv, vn)
is_unconstrained[idx] = istrans(left_vnv, vn)
else
# `vn` is either in both or just `right`.
# Note that in a `merge` the right value has precedence.
val = getindex_raw(right_vnv, vn)
val = getindex_internal(right_vnv, vn)
f = gettransform(right_vnv, vn)
is_unconstrained[idx] = istrans(right_vnv, vn)
end
Expand Down Expand Up @@ -621,7 +621,7 @@ function subset(vnv::VarNamedVector, vns_given::AbstractVector{VN}) where {VN<:V
isempty(vnv) && return vnv_new

for vn in vns
push!(vnv_new, vn, getindex_raw(vnv, vn), gettransform(vnv, vn))
push!(vnv_new, vn, getindex_internal(vnv, vn), gettransform(vnv, vn))
settrans!(vnv_new, istrans(vnv, vn), vn)
end

Expand Down Expand Up @@ -977,7 +977,7 @@ end
set!!(vnv::VarNamedVector, vn::VarName, val) = update!!(vnv, vn, val)

function setval!(vnv::VarNamedVector, val, vn::VarName)
return setindex_raw!(vnv, tovec(val), vn)
return setindex_internal!(vnv, tovec(val), vn)
end

function recontiguify_ranges!(ranges::AbstractVector{<:AbstractRange})
Expand Down
19 changes: 10 additions & 9 deletions test/varnamedvector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -228,25 +228,26 @@ end
@test vnv[vn_right] == val_right .+ 100
end

# `getindex_raw`
@testset "getindex_raw" begin
# `getindex_internal`
@testset "getindex_internal" begin
# With `VarName` index.
@test DynamicPPL.getindex_raw(vnv_base, vn_left) == to_vec_left(val_left)
@test DynamicPPL.getindex_raw(vnv_base, vn_right) == to_vec_right(val_right)
@test DynamicPPL.getindex_internal(vnv_base, vn_left) == to_vec_left(val_left)
@test DynamicPPL.getindex_internal(vnv_base, vn_right) ==
to_vec_right(val_right)
# With `Int` index.
val_vec = vcat(to_vec_left(val_left), to_vec_right(val_right))
@test all(
DynamicPPL.getindex_raw(vnv_base, i) == val_vec[i] for
DynamicPPL.getindex_internal(vnv_base, i) == val_vec[i] for
i in 1:length(val_vec)
)
end

# `setindex_raw!`
@testset "setindex_raw!" begin
# `setindex_internal!`
@testset "setindex_internal!" begin
vnv = deepcopy(vnv_base)
DynamicPPL.setindex_raw!(vnv, to_vec_left(val_left .+ 100), vn_left)
DynamicPPL.setindex_internal!(vnv, to_vec_left(val_left .+ 100), vn_left)
@test vnv[vn_left] == val_left .+ 100
DynamicPPL.setindex_raw!(vnv, to_vec_right(val_right .+ 100), vn_right)
DynamicPPL.setindex_internal!(vnv, to_vec_right(val_right .+ 100), vn_right)
@test vnv[vn_right] == val_right .+ 100
end

Expand Down

0 comments on commit 30252dc

Please sign in to comment.