From 2abf497419a892d7719ce015d41e96f5e821de8c Mon Sep 17 00:00:00 2001 From: Tortar <68152031+Tortar@users.noreply.github.com> Date: Sat, 20 Apr 2024 03:15:15 +0200 Subject: [PATCH] Fix merge and test it (#66) --- src/UnweightedSamplingMulti.jl | 23 ++++-------------- src/UnweightedSamplingSingle.jl | 3 --- src/WeightedSamplingMulti.jl | 8 ------- test/merge_tests.jl | 41 +++++++++++++++++++++++++++++++++ test/runtests.jl | 1 + 5 files changed, 47 insertions(+), 29 deletions(-) create mode 100644 test/merge_tests.jl diff --git a/src/UnweightedSamplingMulti.jl b/src/UnweightedSamplingMulti.jl index 766b1fd..3e43a64 100644 --- a/src/UnweightedSamplingMulti.jl +++ b/src/UnweightedSamplingMulti.jl @@ -178,41 +178,28 @@ end is_ordered(s::AbstractOrdWrReservoirSampleMulti) = true is_ordered(s::AbstractWrReservoirSampleMulti) = false -function Base.merge(s1::AbstractWorReservoirSampleMulti, s2::AbstractWorReservoirSampleMulti) - len1, len2, n1, n2 = check_merging_support(s1, s2) - n_tot = n1 + n2 - p = n2 / n_tot - value = create_new_res_vec(s1, s2, p, len1) - return SampleSingleAlgR(n_tot, s1.rng, value) -end function Base.merge(s1::AbstractWrReservoirSampleMulti, s2::AbstractWrReservoirSampleMulti) len1, len2, n1, n2 = check_merging_support(s1, s2) + shuffle!(s1.rng, s1.value) + shuffle!(s2.rng, s2.value) n_tot = n1 + n2 p = n2 / n_tot value = create_new_res_vec(s1, s2, p, len1) - s_merged = SampleMultiAlgRSWRSKIP(n_tot, s1.rng, value) + s_merged = SampleMultiAlgRSWRSKIP(0, n_tot, s1.rng, value) recompute_skip!(s_merged, len1) return s_merged end -function Base.merge!(s1::SampleMultiAlgR, s2::AbstractWorReservoirSampleMulti) - len1, len2, n1, n2 = check_merging_support(s1, s2) - n_tot = n1 + n2 - p = n2 / n_tot - s1 = merge_res_vec!(s1, s2, p, len1, n_tot) - return s1 -end function Base.merge!(s1::SampleMultiAlgRSWRSKIP, s2::AbstractWrReservoirSampleMulti) len1, len2, n1, n2 = check_merging_support(s1, s2) + shuffle!(s1.rng, s1.value) + shuffle!(s2.rng, s2.value) n_tot = n1 + n2 p = n2 / n_tot s1 = merge_res_vec!(s1, s2, p, len1, n_tot) recompute_skip!(s1, len1) return s1 end -function Base.merge!(s1::SampleSingleAlgL, s2::AbstractWorReservoirSampleMulti) - error("Merging into a ReservoirSample using method algL is not supported") -end function check_merging_support(s1, s2) len1, len2 = length(s1.value), length(s2.value) diff --git a/src/UnweightedSamplingSingle.jl b/src/UnweightedSamplingSingle.jl index edb6f0d..43d86d7 100644 --- a/src/UnweightedSamplingSingle.jl +++ b/src/UnweightedSamplingSingle.jl @@ -72,9 +72,6 @@ function Base.merge!(s1::SampleSingleAlgR, s2::AbstractReservoirSampleSingle) s1.seen_k = n_tot return s1 end -function Base.merge!(s1::SampleSingleAlgL, s2::AbstractReservoirSampleSingle) - error("Merging into a ReservoirSample using method algL is not supported") -end function itsample(iter, method::ReservoirAlgorithm = algL) return itsample(Random.default_rng(), iter, method) diff --git a/src/WeightedSamplingMulti.jl b/src/WeightedSamplingMulti.jl index b2174df..1d22947 100644 --- a/src/WeightedSamplingMulti.jl +++ b/src/WeightedSamplingMulti.jl @@ -160,14 +160,6 @@ end is_ordered(s::SampleMultiOrdAlgWRSWRSKIP) = true is_ordered(s::SampleMultiAlgWRSWRSKIP) = false -function Base.merge(s1::AbstractWeightedReservoirSample, s2::AbstractWeightedReservoirSample) - error("Merging is not supported for weighted sampling") -end - -function Base.merge!(s1::AbstractWeightedReservoirSample, s2::AbstractWeightedReservoirSample) - error("Merging is not supported for weighted sampling") -end - function value(s::AbstractWeightedWorReservoirSampleMulti) if n_seen(s) < s.n return first.(s.value.valtree)[1:n_seen(s)] diff --git a/test/merge_tests.jl b/test/merge_tests.jl new file mode 100644 index 0000000..df01949 --- /dev/null +++ b/test/merge_tests.jl @@ -0,0 +1,41 @@ + +@testset "merge tests" begin + rng = StableRNG(43) + iters = (1:2, 3:10) + reps = 10^5 + size = 2 + for (m1, m2) in [(algRSWRSKIP, algRSWRSKIP)] + res = zeros(Int, 10, 10) + for _ in 1:reps + s1 = ReservoirSample(rng, Int, size, m1) + s2 = ReservoirSample(rng, Int, size, m2) + s_all = (s1, s2) + for (s, it) in zip(s_all, iters) + for x in it + update!(s, x) + end + end + s_merged = merge(s1, s2) + res[value(s_merged)...] += 1 + end + cases = m1 == algRSWRSKIP ? 10^size : factorial(10)/factorial(10-size) + ps_exact = [1/cases for _ in 1:cases] + count_est = vec(res) + chisq_test = ChisqTest(count_est, ps_exact) + @test pvalue(chisq_test) > 0.05 + end + s1 = ReservoirSample(rng, Int, 2, algRSWRSKIP) + s2 = ReservoirSample(rng, Int, 2, algRSWRSKIP) + s_all = (s1, s2) + for (s, it) in zip(s_all, iters) + for x in it + update!(s, x) + end + end + @test length(merge!(s1, s2).value) == 2 + s1 = ReservoirSample(rng, Int, algR) + s2 = ReservoirSample(rng, Int, algR) + update!(s1, 1) + update!(s2, 2) + @test merge!(s1, s2).value in (1, 2) +end diff --git a/test/runtests.jl b/test/runtests.jl index 331e515..c502cf6 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -13,4 +13,5 @@ using Test include("unweighted_sampling_multi_tests.jl") include("weighted_sampling_single_tests.jl") include("weighted_sampling_multi_tests.jl") + include("merge_tests.jl") end