Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix argmax, findmax, findXwithfirst, and expand testing #99

Merged
merged 23 commits into from
Jul 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions src/chainedvector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -809,7 +809,10 @@ function Base.findmax(f::F, x::ChainedVector) where {F}
cleanup!(x) # get rid of any empty arrays
i = 1
y = f(x.arrays[1][1])
return findXwithfirst(!isless, f, x, y, i)
# x > y iff y < x for a well ordered set
# nb. isgreater = !isless is not correct. That is `>=`
isgreater(x, y) = isless(y, x)
return findXwithfirst(isgreater, f, x, y, i)
end

function Base.findmin(f::F, x::ChainedVector) where {F}
Expand All @@ -825,8 +828,9 @@ function findXwithfirst(comp, f, x, y, i)
for A in x.arrays
for y′ in A
y′′ = f(y′)
y = ifelse(comp(y′′, y), y′′, y)
i = ifelse(comp(y′′, y), i′, i)
c = comp(y′′, y) # store this before y changes
y = ifelse(c, y′′, y)
i = ifelse(c, i′, i)
i′ += 1
end
end
Expand All @@ -835,8 +839,8 @@ end

Base.findmax(x::ChainedVector) = findmax(identity, x)
Base.findmin(x::ChainedVector) = findmin(identity, x)
Base.argmax(x::ChainedVector) = findmax(identity, x)[1]
Base.argmin(x::ChainedVector) = findmin(identity, x)[1]
Base.argmax(x::ChainedVector) = findmax(identity, x)[2]
Base.argmin(x::ChainedVector) = findmin(identity, x)[2]
Base.argmax(f::F, x::ChainedVector) where {F} = x[findmax(f, x)[2]]
Base.argmin(f::F, x::ChainedVector) where {F} = x[findmin(f, x)[2]]

Expand Down
171 changes: 171 additions & 0 deletions test/chainedvector.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
@testset "ChainedVector" begin

# identity checks
x = ChainedVector([[1,2,3], [4,5,6], [7,8,9,10]])
@test x == 1:10
@test length(x) == 10
Expand Down Expand Up @@ -55,6 +56,7 @@
insert!(x, 1, 2)
@test x[1] == 2


x = ChainedVector([[1,2,3], [4,5,6], [7,8,9,10]])
y = ChainedVector([[11,12,13], [14,15,16], [17,18,19,20]])

Expand Down Expand Up @@ -284,6 +286,17 @@
@test all(==(2), x)
@test length(x) == 5

# https://github.com/JuliaData/SentinelArrays.jl/issues/97
x = ChainedVector([[18, 70, 92, 15, 65], [25, 14, 95, 54, 57]])
@test findmax(x) == (95, 8)
@test findmin(x) == (14, 7)
@test argmax(x) == 8
@test argmin(x) == 7
@test findmax(inv, x) == (inv(14), 7)
@test findmin(inv, x) == (inv(95), 8)
@test argmax(inv, x) == 14
@test argmin(inv, x) == 95

x = ChainedVector(Vector{Float64}[])
@test !any(x -> iseven(x), x)
@test !any(map(x -> iseven(x), x))
Expand Down Expand Up @@ -423,6 +436,164 @@
@test (rand(10,10) * v) isa ChainedVector
end

@testset "ChainedVectors on Generated Vectors" begin
#=
# Use to generate text below
function test_vector_generator(;
lengths = rand(0:5, 5),
possible_values = 1:100,
)
values = rand(possible_values, sum(lengths))
remaining_values = copy(values)
arrays = map(lengths) do length
result = remaining_values[1:length]
remaining_values = @view remaining_values[length+1:end]
return result
end
return ChainedVector(arrays) => values
end
function Base.show(io::IO, cv::ChainedVector)
print(io, "ChainedVector(")
show(io, cv.arrays)
print(io, ")")
end
for i in 1:10
test_vector_generator() |>
repr |>
x->replace(x, "=>" => "=>\n ") |>
x->println(x,",")
end
=#

# Pairs of test vectors
# Some were inspired by https://github.com/JuliaData/SentinelArrays.jl/issues/97
int_vectors = [
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of a hard-coded set of test vectors, could a few 1000 be generated when the test runs? That might catch a few more edge cases.

Also, I think it would be clearer to only test Int here, to avoid NaN and floating point rounding.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that is part of @Seelengrab 's Supposition.jl is supposed to do. I'm just trying to get this pull request through before I start messing with adding test dependencies, but we seem to be stuck here.

ChainedVector([[100, 20], [10, 30, 70, 40], [50], Int[], [60, 90, 80]]) =>
[100, 20, 10, 30, 70, 40, 50, 60, 90, 80],
ChainedVector([[2,1,3], [4,5,6], [7,8,10,9]]) =>
[2, 1, 3, 4, 5, 6, 7, 8, 10, 9],
ChainedVector([[18, 70, 92, 15, 65], [25, 14, 95, 54, 57]]) =>
[18, 70, 92, 15, 65, 25, 14, 95, 54, 57],
ChainedVector([[2, 34], [61, 8, 71], [65, 81, 51], [48, 93, 48, 94], [59, 15, 16, 56, 83]]) =>
[2, 34, 61, 8, 71, 65, 81, 51, 48, 93, 48, 94, 59, 15, 16, 56, 83],
ChainedVector([[23, 97, 70, 70], [4, 4], [61, 17], [95, 84, 91]]) =>
[23, 97, 70, 70, 4, 4, 61, 17, 95, 84, 91],
ChainedVector([[61, 23, 67, 61], [27, 19, 100], [26, 95], [2, 27, 63], [51, 52, 25]]) =>
[61, 23, 67, 61, 27, 19, 100, 26, 95, 2, 27, 63, 51, 52, 25],
ChainedVector([[25, 6, 94], [50], [63, 1, 76], [96, 6]]) =>
[25, 6, 94, 50, 63, 1, 76, 96, 6],
ChainedVector([[98, 5, 94], [82, 60], [58, 46, 13, 62, 48]]) =>
[98, 5, 94, 82, 60, 58, 46, 13, 62, 48],
ChainedVector([[28, 26], [21, 18, 64, 15], [11, 81, 17, 90], [29], [16, 67, 34, 84]]) =>
[28, 26, 21, 18, 64, 15, 11, 81, 17, 90, 29, 16, 67, 34, 84],
ChainedVector([[95, 15, 49, 31, 63], [79, 88], [76], [87, 52], [86, 50, 68, 61]]) =>
[95, 15, 49, 31, 63, 79, 88, 76, 87, 52, 86, 50, 68, 61],
ChainedVector([[71], [96, 84], [88, 3], [76, 47]]) =>
[71, 96, 84, 88, 3, 76, 47],
ChainedVector([[7, 21, 31], [45], [53, 53]]) =>
[7, 21, 31, 45, 53, 53],
ChainedVector([[24, 28, 75, 42], [7, 38, 59, 10], [21, 30, 14], [8, 39], [13, 68, 42]]) =>
[24, 28, 75, 42, 7, 38, 59, 10, 21, 30, 14, 8, 39, 13, 68, 42],
]
floating_point_vectors = [
ChainedVector([[2.1, -4.6, -2.5], [-5.0, 6.4, 2.0, -0.5], [-6.1, -7.6, -3.2, -4.7, 4.3], [-1.7, 6.4, -8.9, -7.4], [-7.7, -1.4, 3.1, 4.5]]) =>
[2.1, -4.6, -2.5, -5.0, 6.4, 2.0, -0.5, -6.1, -7.6, -3.2, -4.7, 4.3, -1.7, 6.4, -8.9, -7.4, -7.7, -1.4, 3.1, 4.5],
ChainedVector([[-8.5, -1.2, -3.8, 7.5], [8.2, 7.5, -5.3], [-2.7, 0.6, -6.2, 6.1, 1.4]]) =>
[-8.5, -1.2, -3.8, 7.5, 8.2, 7.5, -5.3, -2.7, 0.6, -6.2, 6.1, 1.4],
ChainedVector([[-7.2], [8.1, 2.3, 7.5], [-8.4, -5.7]]) =>
[-7.2, 8.1, 2.3, 7.5, -8.4, -5.7],
ChainedVector([[-3.7, 7.8, -5.0], [0.1], [5.0, -4.1], [-1.6, -0.9, 8.7, -7.8]]) =>
[-3.7, 7.8, -5.0, 0.1, 5.0, -4.1, -1.6, -0.9, 8.7, -7.8],
ChainedVector([[8.6, -2.0], [8.0, 3.4, 3.3], [1.0], [5.4, -2.6, -4.7, 4.4, 4.4], [7.9]]) =>
[8.6, -2.0, 8.0, 3.4, 3.3, 1.0, 5.4, -2.6, -4.7, 4.4, 4.4, 7.9],
ChainedVector([[7.6, 5.9], [7.9, -8.8, -1.5, -0.4, 6.0], [-5.1, -0.4, 4.4, 7.3]]) =>
[7.6, 5.9, 7.9, -8.8, -1.5, -0.4, 6.0, -5.1, -0.4, 4.4, 7.3],
ChainedVector([[3.2, -3.2, 1.2, -1.2, -2.1], [0.5], [6.2], [2.9], [-8.1, 5.8, 4.8, -3.4, -3.1]]) =>
[3.2, -3.2, 1.2, -1.2, -2.1, 0.5, 6.2, 2.9, -8.1, 5.8, 4.8, -3.4, -3.1],
ChainedVector([[-8.0, -1.9, -5.1, -1.4, -8.3], [5.1, -3.7, 6.3, -4.8, -3.3], [-7.0], [-2.4, 4.0, -3.7], [-6.6, -6.9, 2.5, -1.3]]) =>
[-8.0, -1.9, -5.1, -1.4, -8.3, 5.1, -3.7, 6.3, -4.8, -3.3, -7.0, -2.4, 4.0, -3.7, -6.6, -6.9, 2.5, -1.3],
ChainedVector([[-7.5], [-1.5, -5.8, 8.4], [-8.4, -1.9, 2.3, -0.8, -8.5], [0.2, 0.5, -7.4, 2.1, -3.9]]) =>
[-7.5, -1.5, -5.8, 8.4, -8.4, -1.9, 2.3, -0.8, -8.5, 0.2, 0.5, -7.4, 2.1, -3.9],
ChainedVector([[3.9, -8.9], [-0.3, 0.0, 7.3], [-2.9, 8.6, 5.8, 0.5], [0.0, -4.5, 3.3, 0.4, -3.2]]) =>
[3.9, -8.9, -0.3, 0.0, 7.3, -2.9, 8.6, 5.8, 0.5, 0.0, -4.5, 3.3, 0.4, -3.2],
]
rational_vectors = [
ChainedVector(Vector{Rational{Int64}}[[1, 1//2, 1//2, 4//5], [7//10], [1, 1//5, 7//10, 3//10, 1], [3//5], [1]]) =>
Rational{Int64}[1, 1//2, 1//2, 4//5, 7//10, 1, 1//5, 7//10, 3//10, 1, 3//5, 1],
ChainedVector(Vector{Rational{Int64}}[[1//5], [1, 4//5, 1//5], [3//5, 7//10, 3//5], [9//10, 1//5, 7//10, 1//2], [1//2, 7//10, 9//10, 3//5, 7//10]]) =>
Rational{Int64}[1//5, 1, 4//5, 1//5, 3//5, 7//10, 3//5, 9//10, 1//5, 7//10, 1//2, 1//2, 7//10, 9//10, 3//5, 7//10],
ChainedVector(Vector{Rational{Int64}}[[7//10, 1, 1//5, 1//2, 2//5], [1//5, 4//5, 1//2, 1//5], [3//10, 3//10, 1//2], [3//10, 1//10, 4//5, 3//5], [2//5, 7//10, 1, 3//10, 3//10]]) =>
Rational{Int64}[7//10, 1, 1//5, 1//2, 2//5, 1//5, 4//5, 1//2, 1//5, 3//10, 3//10, 1//2, 3//10, 1//10, 4//5, 3//5, 2//5, 7//10, 1, 3//10, 3//10],
ChainedVector(Vector{Rational{Int64}}[[1//10, 4//5], [1//2], [1//10], [4//5, 1, 3//5, 9//10, 9//10]]) =>
Rational{Int64}[1//10, 4//5, 1//2, 1//10, 4//5, 1, 3//5, 9//10, 9//10],
ChainedVector(Vector{Rational{Int64}}[[3//10, 1, 9//10, 3//5], [1, 1], [1, 4//5, 3//5, 9//10]]) =>
Rational{Int64}[3//10, 1, 9//10, 3//5, 1, 1, 1, 4//5, 3//5, 9//10],
ChainedVector(Vector{Rational{Int64}}[[3//10, 7//10], [4//5], [4//5, 1, 1//10, 9//10], [1, 1, 4//5]]) =>
Rational{Int64}[3//10, 7//10, 4//5, 4//5, 1, 1//10, 9//10, 1, 1, 4//5],
ChainedVector(Vector{Rational{Int64}}[[2//5], [3//5, 9//10, 7//10, 9//10], [1//2, 1, 1//10, 1//5], [1//5, 4//5, 7//10, 2//5]]) =>
Rational{Int64}[2//5, 3//5, 9//10, 7//10, 9//10, 1//2, 1, 1//10, 1//5, 1//5, 4//5, 7//10, 2//5],
ChainedVector(Vector{Rational{Int64}}[[7//10], [3//5, 1//5, 2//5, 3//5, 4//5], [4//5], [7//10, 3//5, 7//10, 7//10, 1//10]]) =>
Rational{Int64}[7//10, 3//5, 1//5, 2//5, 3//5, 4//5, 4//5, 7//10, 3//5, 7//10, 7//10, 1//10],
ChainedVector(Vector{Rational{Int64}}[[1//2, 1, 1//2, 9//10, 2//5], [9//10, 1//2, 3//5], [4//5, 7//10], [3//10, 2//5], [9//10, 1]]) =>
Rational{Int64}[1//2, 1, 1//2, 9//10, 2//5, 9//10, 1//2, 3//5, 4//5, 7//10, 3//10, 2//5, 9//10, 1],
ChainedVector(Vector{Rational{Int64}}[[9//10, 3//10, 1//10, 2//5], [4//5], [9//10, 2//5]]) =>
Rational{Int64}[9//10, 3//10, 1//10, 2//5, 4//5, 9//10, 2//5],

]
@testset for (x,y) in Iterators.flatten([int_vectors, floating_point_vectors, rational_vectors])
@test copy(x) == y
@test collect(x) == y
@test length(x) == length(y)
# Floating point tests fail if this is not approx
# See https://github.com/JuliaData/SentinelArrays.jl/pull/99#issuecomment-2171005657
@test sum(x) ≈ sum(y)
@test findmax(x) == findmax(y)
@test findmin(x) == findmin(y)
@test maximum(x) == maximum(y)
@test minimum(x) == minimum(y)
@test argmax(x) == argmax(y)
@test argmin(x) == argmin(y)
@test all(>(0),x) == all(>(0),y)
@test any(>(0),x) == any(>(0),y)
@test any(<(0),x) == any(<(0),y)
@test count(>(0),x) == count(>(0),y)
@test count(<(0),x) == count(<(0),y)
@test extrema(inv, x) == extrema(inv, y)
@static if VERSION ≥ v"1.6"
@test findmax(x->x+1, x) == findmax(x->x+1, y)
@test findmin(x->x-1, x) == findmin(x->x-1, y)
@test findfirst(isodd, x) == findfirst(isodd, y)
@test findfirst(iseven, x) == findfirst(iseven ,y)
@test findlast(isodd, x) == findlast(isodd, y)
@test findlast(iseven, x) == findlast(iseven ,y)
@test findall(iseven, x) == findall(iseven ,y)
@test findnext(isodd, x, 5) == findnext(isodd, y, 5)
@test findprev(isodd, x, 5) == findprev(isodd, y, 5)
end
@test let (val, idx) = findmax(x)
max_val = maximum(x)
val == max_val == x[idx]
end
@test let (val, idx) = findmin(x)
min_val = minimum(x)
val == min_val == x[idx]
end
@test x[argmax(x)] == maximum(x)
@test x[argmin(x)] == minimum(x)
@test let (val, idx) = findmax(inv, x)
max_val = maximum(inv, x)
val == max_val == inv(x[idx])
end
@test let (val, idx) = findmin(inv, x)
min_val = minimum(inv, x)
val == min_val == inv(x[idx])
end
@test inv(argmax(inv, x)) == maximum(inv, x)
@test inv(argmin(inv, x)) == minimum(inv, x)
end
end


@testset "iteration protocol on ChainedVector" begin
for len in 0:6
cv = ChainedVector([1:len])
Expand Down
Loading