Skip to content

Commit

Permalink
implement IterHasKnownSize trait
Browse files Browse the repository at this point in the history
  • Loading branch information
Tortar committed Oct 15, 2023
1 parent 84dee30 commit 9a061cb
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 13 deletions.
23 changes: 20 additions & 3 deletions src/UnweightedSamplingMulti.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ function itsample(rng::AbstractRNG, iter, n::Int; alloc = true, iter_type = Any)
if alloc
unweighted_sampling_multi(iter, rng, n)
else
unweighted_resorvoir_sampling_multi(iter, rng, n, iter_type)
IterHasKnownSize = Base.IteratorSize(iter)
unweighted_resorvoir_sampling_multi(iter, rng, n, IterHasKnownSize, iter_type)
end
end

Expand All @@ -20,7 +21,8 @@ function itsample(rng::AbstractRNG, iter, condition::Function, n::Int; alloc = t
unweighted_sampling_with_condition_multi(iter, rng, n, condition)
else
iter_filtered = Iterators.filter(x -> condition(x), iter)
unweighted_resorvoir_sampling_multi(iter_filtered, rng, n, iter_type)
IterHasKnownSize = Base.IteratorSize(iter_filtered)
unweighted_resorvoir_sampling_multi(iter_filtered, rng, n, IterHasKnownSize, iter_type)
end
end

Expand Down Expand Up @@ -50,7 +52,7 @@ function unweighted_sampling_with_condition_multi(iter, rng, n, condition)
return res[1:i]
end

function unweighted_resorvoir_sampling_multi(iter, rng, n, iter_type = Any)
function unweighted_resorvoir_sampling_multi(iter, rng, n, ::Base.SizeUnknown, iter_type = eltype(iter))
it = iterate(iter)
isnothing(it) && return iter_type[]
el, state = it
Expand All @@ -77,4 +79,19 @@ function unweighted_resorvoir_sampling_multi(iter, rng, n, iter_type = Any)
reservoir[rand(rng, 1:n)] = el
w *= rand(rng)^(1/n)
end
end

function unweighted_resorvoir_sampling_multi(iter, rng, n, ::Union{Base.HasLength, Base.HasShape}, iter_type = eltype(iter))
N = length(iter)
N <= n && return collect(iter)
indices = sort!(sample(rng, 1:N, n; replace=false))
reservoir = Vector{iter_type}(undef, n)
j = 1
for (i, x) in enumerate(iter)
if i == indices[j]
reservoir[j] = x
j == n && return reservoir
j += 1
end
end
end
20 changes: 15 additions & 5 deletions src/UnweightedSamplingSingle.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ function itsample(rng::AbstractRNG, iter; alloc = false)
if alloc
unweighted_sampling_single(iter, rng)
else
unweighted_resorvoir_sampling_single(iter, rng)
IterHasKnownSize = Base.IteratorSize(iter)
unweighted_resorvoir_sampling_single(iter, rng, IterHasKnownSize)
end
end

Expand All @@ -20,11 +21,11 @@ function itsample(rng::AbstractRNG, iter, condition::Function; alloc = false)
unweighted_sampling_with_condition_single(iter, rng, condition)
else
iter_filtered = Iterators.filter(x -> condition(x), iter)
unweighted_resorvoir_sampling_single(iter_filtered, rng)
IterHasKnownSize = Base.IteratorSize(iter_filtered)
unweighted_resorvoir_sampling_single(iter_filtered, rng, IterHasKnownSize)
end
end


function unweighted_sampling_single(iter, rng)
pop = collect(iter)
isempty(pop) && return nothing
Expand All @@ -44,7 +45,7 @@ function unweighted_sampling_with_condition_single(iter, rng, condition)
return nothing
end

function unweighted_resorvoir_sampling_single(iter, rng)
function unweighted_resorvoir_sampling_single(iter, rng, ::Base.SizeUnknown)
res = iterate(iter)
isnothing(res) && return nothing
w = rand(rng)
Expand All @@ -61,4 +62,13 @@ function unweighted_resorvoir_sampling_single(iter, rng)
isnothing(res) && return choice
w *= rand(rng)
end
end
end

function unweighted_resorvoir_sampling_single(iter, rng, ::Union{Base.HasLength, Base.HasShape})
k = rand(rng, 1:length(iter))
for (i, x) in enumerate(iter)
i == k && return x
end
end


27 changes: 22 additions & 5 deletions test/unweighted_sampling_multi_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,27 +2,44 @@
@testset "Unweighted sampling multi tests" begin
@testset "alloc=$(alloc)" for alloc in [false, true]

# test values are inrange
a, b = 1, 10
s = itsample(a:b, 2, alloc=alloc)
# test return values of iter with known lengths are inrange
iter = a:b
s = itsample(iter, 2, alloc=alloc)
@test length(s) == 2
@test all(x -> a <= x <= b, s)

@test typeof(s) == Vector{ifelse(alloc, Int, Any)}
s = itsample(a:b, 2, alloc=alloc, iter_type=Int)
s = itsample(iter, 2, alloc=alloc, iter_type=Int)
@test length(s) == 2
@test all(x -> a <= x <= b, s)
@test typeof(s) == Vector{Int}
s = itsample(a:b, 100, alloc=alloc)
s = itsample(iter, 100, alloc=alloc)
@test length(s) == 10
@test length(unique(s)) == 10

# test return values of iter with unknown lengths are inrange
iter = Iterators.filter(x -> x < 5, a:b)
s = itsample(iter, 2, alloc=alloc)
@test length(s) == 2
@test all(x -> a <= x <= b, s)

@test typeof(s) == Vector{ifelse(alloc, Int, Any)}
s = itsample(iter, 2, alloc=alloc, iter_type=Int)
@test length(s) == 2
@test all(x -> a <= x <= b, s)
@test typeof(s) == Vector{Int}
s = itsample(iter, 100, alloc=alloc)
@test length(s) == 4
@test length(unique(s)) == 4

# create empirical distribution
iter = a:b
rng = StableRNG(43)
reps = 10000
dict_res = Dict{Vector, Int}()
for _ in 1:reps
s = itsample(rng, a:b, 2, alloc=alloc)
s = itsample(rng, iter, 2, alloc=alloc)
if s in keys(dict_res)
dict_res[s] += 1
elseif ifelse(alloc, Int, Any)[s[2], s[1]] in keys(dict_res)
Expand Down

0 comments on commit 9a061cb

Please sign in to comment.