diff --git a/src/UnweightedSamplingMulti.jl b/src/UnweightedSamplingMulti.jl
index f1b67c0..3e1d8e0 100644
--- a/src/UnweightedSamplingMulti.jl
+++ b/src/UnweightedSamplingMulti.jl
@@ -7,7 +7,8 @@ function itsample(rng::AbstractRNG, iter, n::Int; alloc = true, iter_type = Any)
     if alloc
         unweighted_sampling_multi(iter, rng, n)
     else
-        unweighted_resorvoir_sampling_multi(iter, rng, n, iter_type)
+        IterHasKnownSize = Base.IteratorSize(iter)
+        unweighted_resorvoir_sampling_multi(iter, rng, n, IterHasKnownSize, iter_type)
     end
 end
 
@@ -20,7 +21,8 @@ function itsample(rng::AbstractRNG, iter, condition::Function, n::Int; alloc = t
         unweighted_sampling_with_condition_multi(iter, rng, n, condition)
     else
         iter_filtered = Iterators.filter(x -> condition(x), iter)
-        unweighted_resorvoir_sampling_multi(iter_filtered, rng, n, iter_type)
+        IterHasKnownSize = Base.IteratorSize(iter_filtered)
+        unweighted_resorvoir_sampling_multi(iter_filtered, rng, n, IterHasKnownSize, iter_type)
     end
 end
 
@@ -50,7 +52,7 @@ function unweighted_sampling_with_condition_multi(iter, rng, n, condition)
     return res[1:i]
 end
 
-function unweighted_resorvoir_sampling_multi(iter, rng, n, iter_type = Any)
+function unweighted_resorvoir_sampling_multi(iter, rng, n, ::Base.SizeUnknown, iter_type = eltype(iter))
     it = iterate(iter)
     isnothing(it) && return iter_type[]
     el, state = it
@@ -77,4 +79,30 @@ function unweighted_resorvoir_sampling_multi(iter, rng, n, iter_type = Any)
         reservoir[rand(rng, 1:n)] = el
         w *= rand(rng)^(1/n)
     end
+end
+
+# Draw `n` elements without replacement from an iterator whose length is known
+# up front: pick `n` distinct positions in `1:N`, then gather the elements at
+# those positions in a single pass. Returns a `Vector{iter_type}` in iteration
+# order; when the iterator holds at most `n` elements, all of them are returned.
+function unweighted_resorvoir_sampling_multi(
+    iter, rng, n, ::Union{Base.HasLength, Base.HasShape}, iter_type = eltype(iter)
+)
+    N = length(iter)
+    # A bare `collect(iter)` would ignore `iter_type` and keep the shape of `iter`.
+    N <= n && return vec(collect(iter_type, iter))
+    # With `n == 0` the index list is empty and `indices[1]` below would throw.
+    n == 0 && return iter_type[]
+    indices = sort!(sample(rng, 1:N, n; replace=false))
+    reservoir = Vector{iter_type}(undef, n)
+    j = 1
+    for (i, x) in enumerate(iter)
+        if i == indices[j]
+            reservoir[j] = x
+            j == n && return reservoir
+            j += 1
+        end
+    end
+    # Only reached if `iter` yields fewer than `length(iter)` elements.
+    return resize!(reservoir, j - 1)
 end
\ No newline at end of file
diff --git a/src/UnweightedSamplingSingle.jl
b/src/UnweightedSamplingSingle.jl
index d7d8483..73caa30 100644
--- a/src/UnweightedSamplingSingle.jl
+++ b/src/UnweightedSamplingSingle.jl
@@ -7,7 +7,8 @@ function itsample(rng::AbstractRNG, iter; alloc = false)
     if alloc
         unweighted_sampling_single(iter, rng)
     else
-        unweighted_resorvoir_sampling_single(iter, rng)
+        IterHasKnownSize = Base.IteratorSize(iter)
+        unweighted_resorvoir_sampling_single(iter, rng, IterHasKnownSize)
     end
 end
 
@@ -20,11 +21,11 @@ function itsample(rng::AbstractRNG, iter, condition::Function; alloc = false)
         unweighted_sampling_with_condition_single(iter, rng, condition)
     else
         iter_filtered = Iterators.filter(x -> condition(x), iter)
-        unweighted_resorvoir_sampling_single(iter_filtered, rng)
+        IterHasKnownSize = Base.IteratorSize(iter_filtered)
+        unweighted_resorvoir_sampling_single(iter_filtered, rng, IterHasKnownSize)
     end
 end
 
-
 function unweighted_sampling_single(iter, rng)
     pop = collect(iter)
     isempty(pop) && return nothing
@@ -44,7 +45,7 @@ function unweighted_sampling_with_condition_single(iter, rng, condition)
     return nothing
 end
 
-function unweighted_resorvoir_sampling_single(iter, rng)
+function unweighted_resorvoir_sampling_single(iter, rng, ::Base.SizeUnknown)
     res = iterate(iter)
     isnothing(res) && return nothing
     w = rand(rng)
@@ -61,4 +62,21 @@ function unweighted_resorvoir_sampling_single(iter, rng)
         isnothing(res) && return choice
         w *= rand(rng)
     end
-end
\ No newline at end of file
+end
+
+# Pick one element uniformly at random from an iterator whose length is known
+# up front: draw a position in `1:N` and return the element found there.
+# Returns `nothing` for an empty iterator, like the `SizeUnknown` method.
+function unweighted_resorvoir_sampling_single(
+    iter, rng, ::Union{Base.HasLength, Base.HasShape}
+)
+    N = length(iter)
+    # `rand(rng, 1:0)` throws on an empty range, so handle that case explicitly.
+    N == 0 && return nothing
+    k = rand(rng, 1:N)
+    for (i, x) in enumerate(iter)
+        i == k && return x
+    end
+    # Only reached if `iter` yields fewer than `length(iter)` elements.
+    return nothing
+end
diff --git a/test/unweighted_sampling_multi_tests.jl b/test/unweighted_sampling_multi_tests.jl index 2d64409..8d9fce0 100644 --- a/test/unweighted_sampling_multi_tests.jl +++ b/test/unweighted_sampling_multi_tests.jl @@ -2,27 +2,44 @@ @testset "Unweighted sampling multi tests" begin @testset "alloc=$(alloc)" for alloc in [false, true] - # test values are 
inrange a, b = 1, 10 - s = itsample(a:b, 2, alloc=alloc) + # test return values of iter with known lengths are in range + iter = a:b + s = itsample(iter, 2, alloc=alloc) @test length(s) == 2 @test all(x -> a <= x <= b, s) @test typeof(s) == Vector{ifelse(alloc, Int, Any)} - s = itsample(a:b, 2, alloc=alloc, iter_type=Int) + s = itsample(iter, 2, alloc=alloc, iter_type=Int) @test length(s) == 2 @test all(x -> a <= x <= b, s) @test typeof(s) == Vector{Int} - s = itsample(a:b, 100, alloc=alloc) + s = itsample(iter, 100, alloc=alloc) @test length(s) == 10 @test length(unique(s)) == 10 + # test return values of iter with unknown lengths are in range + iter = Iterators.filter(x -> x < 5, a:b) + s = itsample(iter, 2, alloc=alloc) + @test length(s) == 2 + @test all(x -> a <= x <= b, s) + + @test typeof(s) == Vector{ifelse(alloc, Int, Any)} + s = itsample(iter, 2, alloc=alloc, iter_type=Int) + @test length(s) == 2 + @test all(x -> a <= x <= b, s) + @test typeof(s) == Vector{Int} + s = itsample(iter, 100, alloc=alloc) + @test length(s) == 4 + @test length(unique(s)) == 4 + # create empirical distribution + iter = a:b rng = StableRNG(43) reps = 10000 dict_res = Dict{Vector, Int}() for _ in 1:reps - s = itsample(rng, a:b, 2, alloc=alloc) + s = itsample(rng, iter, 2, alloc=alloc) if s in keys(dict_res) dict_res[s] += 1 elseif ifelse(alloc, Int, Any)[s[2], s[1]] in keys(dict_res)