Skip to content

Commit

Permalink
Fix merge and test it (#66)
Browse files Browse the repository at this point in the history
  • Loading branch information
Tortar committed Apr 20, 2024
1 parent 039eeaa commit 2abf497
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 29 deletions.
23 changes: 5 additions & 18 deletions src/UnweightedSamplingMulti.jl
Original file line number Diff line number Diff line change
Expand Up @@ -178,41 +178,28 @@ end
# `is_ordered` trait methods: return `true` for any `AbstractOrdWrReservoirSampleMulti`
# and `false` for any `AbstractWrReservoirSampleMulti`.
is_ordered(s::AbstractOrdWrReservoirSampleMulti) = true
is_ordered(s::AbstractWrReservoirSampleMulti) = false

# Build a new sample out of two `AbstractWorReservoirSampleMulti` samples.
# `check_merging_support` supplies the two reservoir lengths plus `n1` and `n2`
# (presumably the number of elements each sample has seen — confirm in that helper).
# NOTE(review): the result is constructed with `SampleSingleAlgR` although both arguments
# are `...Multi` samples — looks inconsistent; confirm whether this method is still live.
function Base.merge(s1::AbstractWorReservoirSampleMulti, s2::AbstractWorReservoirSampleMulti)
len1, len2, n1, n2 = check_merging_support(s1, s2)
n_tot = n1 + n2
# Share of the combined count that belongs to `s2`.
p = n2 / n_tot
value = create_new_res_vec(s1, s2, p, len1)
return SampleSingleAlgR(n_tot, s1.rng, value)
end
# Build a new `SampleMultiAlgRSWRSKIP` out of two `AbstractWrReservoirSampleMulti` samples.
# Note that both input reservoirs are shuffled in place (each with its own `rng`) before
# the merged value vector is created, so the arguments are mutated by this call.
function Base.merge(s1::AbstractWrReservoirSampleMulti, s2::AbstractWrReservoirSampleMulti)
len1, len2, n1, n2 = check_merging_support(s1, s2)
shuffle!(s1.rng, s1.value)
shuffle!(s2.rng, s2.value)
n_tot = n1 + n2
# Share of the combined count that belongs to `s2`.
p = n2 / n_tot
value = create_new_res_vec(s1, s2, p, len1)
# NOTE(review): two constructor calls appear back to back here; presumably the
# 3-argument call is the pre-commit line and the 4-argument call (leading `0`) replaces
# it — confirm against the actual file.
s_merged = SampleMultiAlgRSWRSKIP(n_tot, s1.rng, value)
s_merged = SampleMultiAlgRSWRSKIP(0, n_tot, s1.rng, value)
recompute_skip!(s_merged, len1)
return s_merged
end

# In-place merge of `s2` into a `SampleMultiAlgR`.
# The work is delegated to `merge_res_vec!`, whose return value is rebound to `s1` and
# returned to the caller.
function Base.merge!(s1::SampleMultiAlgR, s2::AbstractWorReservoirSampleMulti)
len1, len2, n1, n2 = check_merging_support(s1, s2)
n_tot = n1 + n2
# Share of the combined count that belongs to `s2`.
p = n2 / n_tot
s1 = merge_res_vec!(s1, s2, p, len1, n_tot)
return s1
end
# In-place merge of `s2` into a `SampleMultiAlgRSWRSKIP`.
# Both reservoirs are shuffled in place first (so `s2.value` is reordered as well), then
# `merge_res_vec!` performs the merge and `recompute_skip!` is called on the result.
function Base.merge!(s1::SampleMultiAlgRSWRSKIP, s2::AbstractWrReservoirSampleMulti)
len1, len2, n1, n2 = check_merging_support(s1, s2)
shuffle!(s1.rng, s1.value)
shuffle!(s2.rng, s2.value)
n_tot = n1 + n2
# Share of the combined count that belongs to `s2`.
p = n2 / n_tot
s1 = merge_res_vec!(s1, s2, p, len1, n_tot)
recompute_skip!(s1, len1)
return s1
end
# Unsupported combination: always throws an `ErrorException` instead of merging.
# NOTE(review): the first argument is a `SampleSingleAlgL` while the second is a
# `...Multi` sample — confirm this pairing is the intended signature.
function Base.merge!(s1::SampleSingleAlgL, s2::AbstractWorReservoirSampleMulti)
error("Merging into a ReservoirSample using method algL is not supported")
end

function check_merging_support(s1, s2)
len1, len2 = length(s1.value), length(s2.value)
Expand Down
3 changes: 0 additions & 3 deletions src/UnweightedSamplingSingle.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,6 @@ function Base.merge!(s1::SampleSingleAlgR, s2::AbstractReservoirSampleSingle)
s1.seen_k = n_tot
return s1
end
# Unsupported combination: merging into a `SampleSingleAlgL` always throws an
# `ErrorException`, whatever single-element sample is passed as `s2`.
function Base.merge!(s1::SampleSingleAlgL, s2::AbstractReservoirSampleSingle)
error("Merging into a ReservoirSample using method algL is not supported")
end

function itsample(iter, method::ReservoirAlgorithm = algL)
return itsample(Random.default_rng(), iter, method)
Expand Down
8 changes: 0 additions & 8 deletions src/WeightedSamplingMulti.jl
Original file line number Diff line number Diff line change
Expand Up @@ -160,14 +160,6 @@ end
# `is_ordered` trait methods: `true` for `SampleMultiOrdAlgWRSWRSKIP`, `false` for
# `SampleMultiAlgWRSWRSKIP`.
is_ordered(s::SampleMultiOrdAlgWRSWRSKIP) = true
is_ordered(s::SampleMultiAlgWRSWRSKIP) = false

# Weighted samples cannot be merged: this method always throws an `ErrorException`.
function Base.merge(s1::AbstractWeightedReservoirSample, s2::AbstractWeightedReservoirSample)
error("Merging is not supported for weighted sampling")
end

# Weighted samples cannot be merged in place: this method always throws an
# `ErrorException`.
function Base.merge!(s1::AbstractWeightedReservoirSample, s2::AbstractWeightedReservoirSample)
error("Merging is not supported for weighted sampling")
end

function value(s::AbstractWeightedWorReservoirSampleMulti)
if n_seen(s) < s.n
return first.(s.value.valtree)[1:n_seen(s)]
Expand Down
41 changes: 41 additions & 0 deletions test/merge_tests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@

# Tests for merging reservoir samples: a chi-squared uniformity check on the result of
# `merge`, followed by smoke tests of `merge!` for the multi- and single-element cases.
@testset "merge tests" begin
    rng = StableRNG(43)
    # Two disjoint streams that together cover 1:10.
    iters = (1:2, 3:10)
    reps = 10^5
    # Named `sample_size` rather than `size` so that `Base.size` is not shadowed locally.
    sample_size = 2
    for (m1, m2) in [(algRSWRSKIP, algRSWRSKIP)]
        # res[i, j] counts how often the merged reservoir came out as [i, j].
        res = zeros(Int, 10, 10)
        for _ in 1:reps
            s1 = ReservoirSample(rng, Int, sample_size, m1)
            s2 = ReservoirSample(rng, Int, sample_size, m2)
            for (s, it) in zip((s1, s2), iters)
                for x in it
                    update!(s, x)
                end
            end
            s_merged = merge(s1, s2)
            res[value(s_merged)...] += 1
        end
        # Integer division keeps the number of cases an `Int` in both branches.
        # NOTE(review): for any method other than `algRSWRSKIP` the expected vector would
        # have 90 entries while `vec(res)` has 100; `count_est` must be adapted before
        # such a method pair is added to the loop above.
        cases = if m1 == algRSWRSKIP
            10^sample_size
        else
            factorial(10) ÷ factorial(10 - sample_size)
        end
        ps_exact = fill(1 / cases, cases)
        count_est = vec(res)
        chisq_test = ChisqTest(count_est, ps_exact)
        @test pvalue(chisq_test) > 0.05
    end

    # In-place merge of two multi-element samples keeps the reservoir length.
    s1 = ReservoirSample(rng, Int, 2, algRSWRSKIP)
    s2 = ReservoirSample(rng, Int, 2, algRSWRSKIP)
    for (s, it) in zip((s1, s2), iters)
        for x in it
            update!(s, x)
        end
    end
    @test length(merge!(s1, s2).value) == 2

    # In-place merge of two single-element samples yields one of the two inputs.
    s1 = ReservoirSample(rng, Int, algR)
    s2 = ReservoirSample(rng, Int, algR)
    update!(s1, 1)
    update!(s2, 2)
    @test merge!(s1, s2).value in (1, 2)
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@ using Test
include("unweighted_sampling_multi_tests.jl")
include("weighted_sampling_single_tests.jl")
include("weighted_sampling_multi_tests.jl")
include("merge_tests.jl")
end

0 comments on commit 2abf497

Please sign in to comment.