Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix argmax, findmax, findXwithfirst, and expand testing #99

Merged
merged 23 commits into from
Jul 5, 2024
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions src/chainedvector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -825,8 +825,9 @@ function findXwithfirst(comp, f, x, y, i)
for A in x.arrays
for y′ in A
y′′ = f(y′)
y = ifelse(comp(y′′, y), y′′, y)
i = ifelse(comp(y′′, y), i′, i)
c = comp(y′′, y) # store this before y changes
y = ifelse(c, y′′, y)
i = ifelse(c, i′, i)
i′ += 1
end
end
Expand All @@ -835,8 +836,8 @@ end

Base.findmax(x::ChainedVector) = findmax(identity, x)
Base.findmin(x::ChainedVector) = findmin(identity, x)
Base.argmax(x::ChainedVector) = findmax(identity, x)[1]
Base.argmin(x::ChainedVector) = findmin(identity, x)[1]
Base.argmax(x::ChainedVector) = findmax(identity, x)[2]
Base.argmin(x::ChainedVector) = findmin(identity, x)[2]
Base.argmax(f::F, x::ChainedVector) where {F} = x[findmax(f, x)[2]]
Base.argmin(f::F, x::ChainedVector) where {F} = x[findmin(f, x)[2]]

Expand Down
117 changes: 117 additions & 0 deletions test/chainedvector.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
@testset "ChainedVector" begin

# identity checks
x = ChainedVector([[1,2,3], [4,5,6], [7,8,9,10]])
@test x == 1:10
@test length(x) == 10
Expand Down Expand Up @@ -55,6 +56,7 @@
insert!(x, 1, 2)
@test x[1] == 2


x = ChainedVector([[1,2,3], [4,5,6], [7,8,9,10]])
y = ChainedVector([[11,12,13], [14,15,16], [17,18,19,20]])

Expand Down Expand Up @@ -284,6 +286,17 @@
@test all(==(2), x)
@test length(x) == 5

# https://github.com/JuliaData/SentinelArrays.jl/issues/97
x = ChainedVector([[18, 70, 92, 15, 65], [25, 14, 95, 54, 57]])
@test findmax(x) == (95, 8)
@test findmin(x) == (14, 7)
@test argmax(x) == 8
@test argmin(x) == 7
@test findmax(inv, x) == (inv(14), 7)
@test findmin(inv, x) == (inv(95), 8)
@test argmax(inv, x) == 14
@test argmin(inv, x) == 95

x = ChainedVector(Vector{Float64}[])
@test !any(x -> iseven(x), x)
@test !any(map(x -> iseven(x), x))
Expand Down Expand Up @@ -423,6 +436,110 @@
@test (rand(10,10) * v) isa ChainedVector
end

@testset "ChainedVectors on Generated Vectors" begin
#=
# Use to generate text below
function test_vector_generator(;
lengths = rand(0:5, 5),
possible_values = 1:100,
)
values = rand(possible_values, sum(lengths))
remaining_values = copy(values)
arrays = map(lengths) do length
result = remaining_values[1:length]
remaining_values = @view remaining_values[length+1:end]
return result
end
return ChainedVector(arrays) => values
end
function Base.show(io::IO, cv::ChainedVector)
print(io, "ChainedVector(")
show(io, cv.arrays)
print(io, ")")
end
for i in 1:10
test_vector_generator() |>
repr |>
x->replace(x, "=>" => "=>\n ") |>
x->println(x,",")
end
=#

# Pairs of text vectors
mkitti marked this conversation as resolved.
Show resolved Hide resolved
# Some were inspired by https://github.com/JuliaData/SentinelArrays.jl/issues/97
test_vectors = [
ChainedVector([[100, 20], [10, 30, 70, 40], [50], [], [60, 90, 80]]) =>
[100, 20, 10, 30, 70, 40, 50, 60, 90, 80],
ChainedVector([[2,1,3], [4,5,6], [7,8,10,9]]) =>
[2, 1, 3, 4, 5, 6, 7, 8, 10, 9],
ChainedVector([[18, 70, 92, 15, 65], [25, 14, 95, 54, 57]]) =>
[18, 70, 92, 15, 65, 25, 14, 95, 54, 57],
ChainedVector([[2, 34], [61, 8, 71], [65, 81, 51], [48, 93, 48, 94], [59, 15, 16, 56, 83]]) =>
[2, 34, 61, 8, 71, 65, 81, 51, 48, 93, 48, 94, 59, 15, 16, 56, 83],
ChainedVector([[23, 97, 70, 70], [4, 4], [61, 17], [95, 84, 91]]) =>
[23, 97, 70, 70, 4, 4, 61, 17, 95, 84, 91],
ChainedVector([[61, 23, 67, 61], [27, 19, 100], [26, 95], [2, 27, 63], [51, 52, 25]]) =>
[61, 23, 67, 61, 27, 19, 100, 26, 95, 2, 27, 63, 51, 52, 25],
ChainedVector([[25, 6, 94], [50], [63, 1, 76], [96, 6]]) =>
[25, 6, 94, 50, 63, 1, 76, 96, 6],
ChainedVector([[98, 5, 94], [82, 60], [58, 46, 13, 62, 48]]) =>
[98, 5, 94, 82, 60, 58, 46, 13, 62, 48],
ChainedVector([[28, 26], [21, 18, 64, 15], [11, 81, 17, 90], [29], [16, 67, 34, 84]]) =>
[28, 26, 21, 18, 64, 15, 11, 81, 17, 90, 29, 16, 67, 34, 84],
ChainedVector([[95, 15, 49, 31, 63], [79, 88], [76], [87, 52], [86, 50, 68, 61]]) =>
[95, 15, 49, 31, 63, 79, 88, 76, 87, 52, 86, 50, 68, 61],
ChainedVector([[71], [96, 84], [88, 3], [76, 47]]) =>
[71, 96, 84, 88, 3, 76, 47],
ChainedVector([[7, 21, 31], [45], [53, 53]]) =>
[7, 21, 31, 45, 53, 53],
ChainedVector([[24, 28, 75, 42], [7, 38, 59, 10], [21, 30, 14], [8, 39], [13, 68, 42]]) =>
[24, 28, 75, 42, 7, 38, 59, 10, 21, 30, 14, 8, 39, 13, 68, 42],
ChainedVector([[2.1, -4.6, -2.5], [-5.0, 6.4, 2.0, -0.5], [-6.1, -7.6, -3.2, -4.7, 4.3], [-1.7, 6.4, -8.9, -7.4], [-7.7, -1.4, 3.1, 4.5]]) =>
[2.1, -4.6, -2.5, -5.0, 6.4, 2.0, -0.5, -6.1, -7.6, -3.2, -4.7, 4.3, -1.7, 6.4, -8.9, -7.4, -7.7, -1.4, 3.1, 4.5],
ChainedVector([[-8.5, -1.2, -3.8, 7.5], [8.2, 7.5, -5.3], [-2.7, 0.6, -6.2, 6.1, 1.4]]) =>
[-8.5, -1.2, -3.8, 7.5, 8.2, 7.5, -5.3, -2.7, 0.6, -6.2, 6.1, 1.4],
ChainedVector([[-7.2], [8.1, 2.3, 7.5], [-8.4, -5.7]]) =>
[-7.2, 8.1, 2.3, 7.5, -8.4, -5.7],
ChainedVector([[-3.7, 7.8, -5.0], [0.1], [5.0, -4.1], [-1.6, -0.9, 8.7, -7.8]]) =>
[-3.7, 7.8, -5.0, 0.1, 5.0, -4.1, -1.6, -0.9, 8.7, -7.8],
ChainedVector([[8.6, -2.0], [8.0, 3.4, 3.3], [1.0], [5.4, -2.6, -4.7, 4.4, 4.4], [7.9]]) =>
[8.6, -2.0, 8.0, 3.4, 3.3, 1.0, 5.4, -2.6, -4.7, 4.4, 4.4, 7.9],
ChainedVector([[7.6, 5.9], [7.9, -8.8, -1.5, -0.4, 6.0], [-5.1, -0.4, 4.4, 7.3]]) =>
[7.6, 5.9, 7.9, -8.8, -1.5, -0.4, 6.0, -5.1, -0.4, 4.4, 7.3],
ChainedVector([[3.2, -3.2, 1.2, -1.2, -2.1], [0.5], [6.2], [2.9], [-8.1, 5.8, 4.8, -3.4, -3.1]]) =>
[3.2, -3.2, 1.2, -1.2, -2.1, 0.5, 6.2, 2.9, -8.1, 5.8, 4.8, -3.4, -3.1],
ChainedVector([[-8.0, -1.9, -5.1, -1.4, -8.3], [5.1, -3.7, 6.3, -4.8, -3.3], [-7.0], [-2.4, 4.0, -3.7], [-6.6, -6.9, 2.5, -1.3]]) =>
[-8.0, -1.9, -5.1, -1.4, -8.3, 5.1, -3.7, 6.3, -4.8, -3.3, -7.0, -2.4, 4.0, -3.7, -6.6, -6.9, 2.5, -1.3],
ChainedVector([[-7.5], [-1.5, -5.8, 8.4], [-8.4, -1.9, 2.3, -0.8, -8.5], [0.2, 0.5, -7.4, 2.1, -3.9]]) =>
[-7.5, -1.5, -5.8, 8.4, -8.4, -1.9, 2.3, -0.8, -8.5, 0.2, 0.5, -7.4, 2.1, -3.9],
ChainedVector([[3.9, -8.9], [-0.3, 0.0, 7.3], [-2.9, 8.6, 5.8, 0.5], [0.0, -4.5, 3.3, 0.4, -3.2]]) =>
[3.9, -8.9, -0.3, 0.0, 7.3, -2.9, 8.6, 5.8, 0.5, 0.0, -4.5, 3.3, 0.4, -3.2],
]
for (x,y) in test_vectors
try
@test length(x) == length(y)
# should this be approx?
@test sum(x) ≈ sum(y)
@test findmax(x) == findmax(y)
@test findmin(x) == findmin(y)
@test maximum(x) == maximum(y)
@test minimum(x) == minimum(y)
@test argmax(x) == argmax(y)
@test argmin(x) == argmin(y)
@test findmax(x->x+1, x) == findmax(x->x+1, y)
@test findmin(x->x-1, x) == findmin(x->x-1, y)
@test findfirst(isodd, x) == findfirst(isodd, y)
@test findfirst(iseven, x) == findfirst(iseven ,y)
@test findlast(isodd, x) == findlast(isodd, y)
@test findlast(iseven, x) == findlast(iseven ,y)
@test findnext(isodd, x, 5) == findnext(isodd, y, 5)
catch err
error("Test failed for $x")
end
end
mkitti marked this conversation as resolved.
Show resolved Hide resolved
end


@testset "iteration protocol on ChainedVector" begin
for len in 0:6
cv = ChainedVector([1:len])
Expand Down
Loading