diff --git a/Project.toml b/Project.toml index 009bc15d5..c03084058 100644 --- a/Project.toml +++ b/Project.toml @@ -46,7 +46,7 @@ ADTypes = "0.2" AbstractMCMC = "5" AbstractPPL = "0.8.4" Accessors = "0.1" -BangBang = "0.4" +BangBang = "0.4.1" Bijectors = "0.13.9" ChainRulesCore = "1" Compat = "4" diff --git a/src/utils.jl b/src/utils.jl index 4b45988df..15c7078b8 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -548,43 +548,6 @@ function remove_parent_optic(vn_parent::VarName{sym}, vn_child::VarName{sym}) wh return child end -# HACK: All of these are related to https://github.com/JuliaFolds/BangBang.jl/issues/233 -# and https://github.com/JuliaFolds/BangBang.jl/pull/238, https://github.com/JuliaFolds2/BangBang.jl/pull/16. -# This avoids type-instability in `dot_assume` for `SimpleVarInfo`. -# The following code a copy from https://github.com/JuliaFolds2/BangBang.jl/pull/16 authored by torfjelde -# Default implementation for `_setindex!` with `AbstractArray`. -# But this will return `false` even in cases such as -# -# setindex!!([1, 2, 3], [4, 5, 6], :) -# -# because `promote_type(eltype(C), T) <: eltype(C)` is `false`. -# To address this, we specialize on the case where `T<:AbstractArray`. -# In addition, we need to support a wide range of indexing behaviors: -# -# We also need to ensure that the dimensionality of the index is -# valid, i.e. that we're not returning `true` in cases such as -# -# setindex!!([1, 2, 3], [4, 5], 1) -# -# which should return `false`. -_index_dimension(::Any) = 0 -_index_dimension(::Colon) = 1 -_index_dimension(::AbstractVector) = 1 -_index_dimension(indices::Tuple) = sum(map(_index_dimension, indices)) - -function BangBang.possible( - ::typeof(BangBang._setindex!), ::C, ::T, indices::Vararg -) where {M,C<:AbstractArray{<:Real},T<:AbstractArray{<:Real,M}} - return BangBang.implements(setindex!, C) && - promote_type(eltype(C), eltype(T)) <: eltype(C) && - # This will still return `false` for scenarios such as - # - # setindex!!([1, 2, 3], [4, 5, 6], :, 1) - # - # which are in fact valid. However, this cases are rare. - (_index_dimension(indices) == M || _index_dimension(indices) == 1) -end - # HACK(torfjelde): This makes it so it works on iterators, etc. by default. # TODO(torfjelde): Do better. """ diff --git a/test/utils.jl b/test/utils.jl index a2d6f46fb..1fcf09ef1 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -48,119 +48,4 @@ x = rand(dist) @test vectorize(dist, x) == vec(x.UL) end - - @testset "BangBang.possible" begin - using DynamicPPL.BangBang: setindex!! - - # Some utility methods for testing `setindex!`. - test_linear_index_only(::Tuple, ::AbstractArray) = false - test_linear_index_only(inds::NTuple{1}, ::AbstractArray) = true - test_linear_index_only(inds::NTuple{1}, ::AbstractVector) = false - - function replace_colon_with_axis(inds::Tuple, x) - ntuple(length(inds)) do i - inds[i] isa Colon ? axes(x, i) : inds[i] - end - end - function replace_colon_with_vector(inds::Tuple, x) - ntuple(length(inds)) do i - inds[i] isa Colon ? collect(axes(x, i)) : inds[i] - end - end - function replace_colon_with_range(inds::Tuple, x) - ntuple(length(inds)) do i - inds[i] isa Colon ? (1:size(x, i)) : inds[i] - end - end - function replace_colon_with_booleans(inds::Tuple, x) - ntuple(length(inds)) do i - inds[i] isa Colon ? trues(size(x, i)) : inds[i] - end - end - - function replace_colon_with_range_linear(inds::NTuple{1}, x::AbstractArray) - return inds[1] isa Colon ? (1:length(x),) : inds - end - - @testset begin - @test setindex!!((1, 2, 3), :two, 2) === (1, :two, 3) - @test setindex!!((a=1, b=2, c=3), :two, :b) === (a=1, b=:two, c=3) - @test setindex!!([1, 2, 3], :two, 2) == [1, :two, 3] - @test setindex!!(Dict{Symbol,Int}(:a => 1, :b => 2), 10, :a) == - Dict(:a => 10, :b => 2) - @test setindex!!(Dict{Symbol,Int}(:a => 1, :b => 2), 3, "c") == - Dict(:a => 1, :b => 2, "c" => 3) - end - - @testset "mutation" begin - @testset "without type expansion" begin - for args in [([1, 2, 3], 20, 2), (Dict(:a => 1, :b => 2), 10, :a)] - @test setindex!!(args...) === args[1] - end - end - - @testset "with type expansion" begin - @test setindex!!([1, 2, 3], [4, 5], 1) == [[4, 5], 2, 3] - @test setindex!!([1, 2, 3], [4, 5, 6], :, 1) == [4, 5, 6] - end - end - - @testset "slices" begin - @testset "$(typeof(x)) with $(src_idx)" for (x, src_idx) in [ - # Vector. - (randn(2), (:,)), - (randn(2), (1:2,)), - # Matrix. - (randn(2, 3), (:,)), - (randn(2, 3), (:, 1)), - (randn(2, 3), (:, 1:3)), - # 3D array. - (randn(2, 3, 4), (:, 1, :)), - (randn(2, 3, 4), (:, 1:3, :)), - (randn(2, 3, 4), (1, 1:3, :)), - ] - # Base case. - @test @inferred(setindex!!(x, x[src_idx...], src_idx...)) === x - - # If we have `Colon` in the index, we replace this with other equivalent indices. - if any(Base.Fix2(isa, Colon), src_idx) - if test_linear_index_only(src_idx, x) - # With range instead of `Colon`. - @test @inferred( - setindex!!( - x, - x[src_idx...], - replace_colon_with_range_linear(src_idx, x)..., - ) - ) === x - else - # With axis instead of `Colon`. - @test @inferred( - setindex!!( - x, x[src_idx...], replace_colon_with_axis(src_idx, x)... - ) - ) === x - # With range instead of `Colon`. - @test @inferred( - setindex!!( - x, x[src_idx...], replace_colon_with_range(src_idx, x)... - ) - ) === x - # With vectors instead of `Colon`. - @test @inferred( - setindex!!( - x, x[src_idx...], replace_colon_with_vector(src_idx, x)... - ) - ) === x - # With boolean index instead of `Colon`. - @test @inferred( - setindex!!( - x, x[src_idx...], replace_colon_with_booleans(src_idx, x)... - ) - ) === x - end - end - end - end - end end