From 30252dc72461f3f048860462f7a7c0ba8a7950fa Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 17 Sep 2024 13:22:19 +0100 Subject: [PATCH] Rename getindex_raw to getindex_internal --- src/varinfo.jl | 9 +++------ src/varnamedvector.jl | 34 +++++++++++++++++----------------- test/varnamedvector.jl | 19 ++++++++++--------- 3 files changed, 30 insertions(+), 32 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index c2a535d3f..8b548cc14 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -637,9 +637,6 @@ getindex_internal(vi::VarInfo, vn::VarName) = getindex_internal(getmetadata(vi, # what a bijector would result in, even if the input is a view (`SubArray`). # TODO(torfjelde): An alternative is to implement `view` directly instead. getindex_internal(md::Metadata, vn::VarName) = getindex(md.vals, getrange(md, vn)) -# TODO(mhauru) Maybe rename getindex_raw to getindex_internal and obviate the need for this -# method. -getindex_internal(vnv::VarNamedVector, vn::VarName) = getindex_raw(vnv, vn) function getindex_internal(vi::VarInfo, vns::Vector{<:VarName}) return mapreduce(Base.Fix1(getindex_internal, vi), vcat, vns) @@ -676,7 +673,7 @@ function getall(md::Metadata) Base.Fix1(getindex_internal, md), vcat, md.vns; init=similar(md.vals, 0) ) end -getall(vnv::VarNamedVector) = getindex_raw(vnv, Colon()) +getall(vnv::VarNamedVector) = getindex_internal(vnv, Colon()) """ setall!(vi::VarInfo, val) @@ -1439,7 +1436,7 @@ function _link_metadata!!( # First transform from however the variable is stored in vnv to the model # representation. transform_to_orig = gettransform(metadata, vn) - val_old = getindex_raw(metadata, vn) + val_old = getindex_internal(metadata, vn) val_orig, logjac1 = with_logabsdet_jacobian(transform_to_orig, val_old) # Then transform from the model representation to the linked representation. transform_from_linked = from_linked_vec_transform(dists[vn]) @@ -1559,7 +1556,7 @@ function _invlink_metadata!!( vns = target_vns === nothing ? keys(metadata) : target_vns for vn in vns transform = gettransform(metadata, vn) - old_val = getindex_raw(metadata, vn) + old_val = getindex_internal(metadata, vn) new_val, logjac = with_logabsdet_jacobian(transform, old_val) # TODO(mhauru) We are calling a !! function but ignoring the return value. acclogp!!(varinfo, -logjac) diff --git a/src/varnamedvector.jl b/src/varnamedvector.jl index 9d9c9812d..389819ecf 100644 --- a/src/varnamedvector.jl +++ b/src/varnamedvector.jl @@ -298,9 +298,9 @@ Base.pairs(vnv::VarNamedVector) = (vn => vnv[vn] for vn in keys(vnv)) Base.haskey(vnv::VarNamedVector, vn::VarName) = haskey(vnv.varname_to_index, vn) # `getindex` & `setindex!` -Base.getindex(vnv::VarNamedVector, i::Int) = getindex_raw(vnv, i) +Base.getindex(vnv::VarNamedVector, i::Int) = getindex_internal(vnv, i) function Base.getindex(vnv::VarNamedVector, vn::VarName) - x = getindex_raw(vnv, vn) + x = getindex_internal(vnv, vn) f = gettransform(vnv, vn) return f(x) end @@ -369,15 +369,15 @@ function index_to_vals_index(vnv::VarNamedVector, i::Int) end """ - getindex_raw(vnv::VarNamedVector, i::Int) - getindex_raw(vnv::VarNamedVector, vn::VarName) + getindex_internal(vnv::VarNamedVector, i::Int) + getindex_internal(vnv::VarNamedVector, vn::VarName) Like `getindex`, but returns the values as they are stored in `vnv` without transforming. For integer indices this is the same as `getindex`, but for `VarName`s this is different. """ -getindex_raw(vnv::VarNamedVector, i::Int) = vnv.vals[index_to_vals_index(vnv, i)] -getindex_raw(vnv::VarNamedVector, vn::VarName) = vnv.vals[getrange(vnv, vn)] +getindex_internal(vnv::VarNamedVector, i::Int) = vnv.vals[index_to_vals_index(vnv, i)] +getindex_internal(vnv::VarNamedVector, vn::VarName) = vnv.vals[getrange(vnv, vn)] # `getindex` for `Colon` function Base.getindex(vnv::VarNamedVector, ::Colon) @@ -388,7 +388,7 @@ function Base.getindex(vnv::VarNamedVector, ::Colon) end end -getindex_raw(vnv::VarNamedVector, ::Colon) = getindex(vnv, Colon()) +getindex_internal(vnv::VarNamedVector, ::Colon) = getindex(vnv, Colon()) # TODO(mhauru): Remove this as soon as possible. Only needed because of the old Gibbs # sampler. @@ -396,26 +396,26 @@ function Base.getindex(vnv::VarNamedVector, spl::AbstractSampler) throw(ErrorException("Cannot index a VarNamedVector with a sampler.")) end -Base.setindex!(vnv::VarNamedVector, val, i::Int) = setindex_raw!(vnv, val, i) +Base.setindex!(vnv::VarNamedVector, val, i::Int) = setindex_internal!(vnv, val, i) function Base.setindex!(vnv::VarNamedVector, val, vn::VarName) # Since setindex! does not change the transform, we need to apply it to `val`. f = inverse(gettransform(vnv, vn)) - return setindex_raw!(vnv, f(val), vn) + return setindex_internal!(vnv, f(val), vn) end """ - setindex_raw!(vnv::VarNamedVector, val, i::Int) - setindex_raw!(vnv::VarNamedVector, val, vn::VarName) + setindex_internal!(vnv::VarNamedVector, val, i::Int) + setindex_internal!(vnv::VarNamedVector, val, vn::VarName) Like `setindex!`, but sets the values as they are stored in `vnv` without transforming. For integer indices this is the same as `setindex!`, but for `VarName`s this is different. """ -function setindex_raw!(vnv::VarNamedVector, val, i::Int) +function setindex_internal!(vnv::VarNamedVector, val, i::Int) return vnv.vals[index_to_vals_index(vnv, i)] = val end -function setindex_raw!(vnv::VarNamedVector, val::AbstractVector, vn::VarName) +function setindex_internal!(vnv::VarNamedVector, val::AbstractVector, vn::VarName) return vnv.vals[getrange(vnv, vn)] = val end @@ -565,13 +565,13 @@ function Base.merge(left_vnv::VarNamedVector, right_vnv::VarNamedVector) # Extract the necessary information from `left` or `right`. if vn in vns_left && !(vn in vns_right) # `vn` is only in `left`. - val = getindex_raw(left_vnv, vn) + val = getindex_internal(left_vnv, vn) f = gettransform(left_vnv, vn) is_unconstrained[idx] = istrans(left_vnv, vn) else # `vn` is either in both or just `right`. # Note that in a `merge` the right value has precedence. - val = getindex_raw(right_vnv, vn) + val = getindex_internal(right_vnv, vn) f = gettransform(right_vnv, vn) is_unconstrained[idx] = istrans(right_vnv, vn) end @@ -621,7 +621,7 @@ function subset(vnv::VarNamedVector, vns_given::AbstractVector{VN}) where {VN<:V isempty(vnv) && return vnv_new for vn in vns - push!(vnv_new, vn, getindex_raw(vnv, vn), gettransform(vnv, vn)) + push!(vnv_new, vn, getindex_internal(vnv, vn), gettransform(vnv, vn)) settrans!(vnv_new, istrans(vnv, vn), vn) end @@ -977,7 +977,7 @@ end set!!(vnv::VarNamedVector, vn::VarName, val) = update!!(vnv, vn, val) function setval!(vnv::VarNamedVector, val, vn::VarName) - return setindex_raw!(vnv, tovec(val), vn) + return setindex_internal!(vnv, tovec(val), vn) end function recontiguify_ranges!(ranges::AbstractVector{<:AbstractRange}) diff --git a/test/varnamedvector.jl b/test/varnamedvector.jl index 425d809dd..ba365d24a 100644 --- a/test/varnamedvector.jl +++ b/test/varnamedvector.jl @@ -228,25 +228,26 @@ end @test vnv[vn_right] == val_right .+ 100 end - # `getindex_raw` - @testset "getindex_raw" begin + # `getindex_internal` + @testset "getindex_internal" begin # With `VarName` index. - @test DynamicPPL.getindex_raw(vnv_base, vn_left) == to_vec_left(val_left) - @test DynamicPPL.getindex_raw(vnv_base, vn_right) == to_vec_right(val_right) + @test DynamicPPL.getindex_internal(vnv_base, vn_left) == to_vec_left(val_left) + @test DynamicPPL.getindex_internal(vnv_base, vn_right) == + to_vec_right(val_right) # With `Int` index. val_vec = vcat(to_vec_left(val_left), to_vec_right(val_right)) @test all( - DynamicPPL.getindex_raw(vnv_base, i) == val_vec[i] for + DynamicPPL.getindex_internal(vnv_base, i) == val_vec[i] for i in 1:length(val_vec) ) end - # `setindex_raw!` - @testset "setindex_raw!" begin + # `setindex_internal!` + @testset "setindex_internal!" begin vnv = deepcopy(vnv_base) - DynamicPPL.setindex_raw!(vnv, to_vec_left(val_left .+ 100), vn_left) + DynamicPPL.setindex_internal!(vnv, to_vec_left(val_left .+ 100), vn_left) @test vnv[vn_left] == val_left .+ 100 - DynamicPPL.setindex_raw!(vnv, to_vec_right(val_right .+ 100), vn_right) + DynamicPPL.setindex_internal!(vnv, to_vec_right(val_right .+ 100), vn_right) @test vnv[vn_right] == val_right .+ 100 end