Skip to content

Commit

Permalink
Develop new itsample signatures
Browse files Browse the repository at this point in the history
  • Loading branch information
Tortar committed Oct 23, 2023
1 parent 1262ba7 commit a94d6a8
Show file tree
Hide file tree
Showing 3 changed files with 202 additions and 37 deletions.
5 changes: 5 additions & 0 deletions src/IteratorSampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@ module IteratorSampling

using StatsBase, Random


# Iterator-size traits that are dispatched on together as "length is known"
# (`Base.HasLength` and `Base.HasShape`). Declared `const` so the alias is a
# fixed binding rather than an untyped, reassignable global.
const Indexable = Union{Base.HasLength, Base.HasShape}
# Iterator-size trait used for dispatch when the length is not known up front.
const NonIndexable = Base.SizeUnknown

include("SortedRand.jl")
include("UnweightedSamplingSingle.jl")
include("UnweightedSamplingMulti.jl")

Expand Down
141 changes: 122 additions & 19 deletions src/UnweightedSamplingMulti.jl
Original file line number Diff line number Diff line change
@@ -1,38 +1,106 @@

function itsample(iter, n::Int; alloc = true, iter_type = Any)
return itsample(Random.GLOBAL_RNG, iter, n; alloc = alloc, iter_type = iter_type)
# Unweighted

"""
    itsample(iter, n::Int; replace = false, alloc = true, iter_type = Any)

Convenience method that uses `Random.GLOBAL_RNG` as the random number
generator. `iter`, `n` and all keywords are passed on unchanged to the
`itsample` method that takes the RNG as its first argument.
"""
function itsample(iter, n::Int; replace = false, alloc = true, iter_type = Any)
    forwarded = (replace = replace, alloc = alloc, iter_type = iter_type)
    return itsample(Random.GLOBAL_RNG, iter, n; forwarded...)
end

function itsample(rng::AbstractRNG, iter, n::Int; alloc = true, iter_type = Any)
if alloc
unweighted_sampling_multi(iter, rng, n)
function itsample(
rng::AbstractRNG, iter, n::Int;
replace = false, alloc = true, iter_type = Any
)
if alloc
if replace
error("Not implemented yet")
else
unweighted_sampling(iter, rng, n)
end
else
IterHasKnownSize = Base.IteratorSize(iter)
unweighted_resorvoir_sampling_multi(iter, rng, n, IterHasKnownSize, iter_type)
if replace
error("Not implemented yet")
else
IterHasKnownSize = Base.IteratorSize(iter)
unweighted_resorvoir_sampling(iter, rng, n, IterHasKnownSize,
iter_type)
end
end
end

function itsample(iter, condition::Function, n::Int; alloc = true, iter_type = Any)
return itsample(Random.GLOBAL_RNG, iter, condition, n; alloc = alloc, iter_type = iter_type)
end
"""
    itsample(condition::Function, iter, n::Int;
             replace = false, alloc = true, iter_type = Any)

Convenience method that uses `Random.GLOBAL_RNG` as the random number
generator. `condition`, `iter`, `n` and all keywords are passed on unchanged
to the `itsample` method that takes the RNG as its first argument.
"""
function itsample(
    condition::Function, iter, n::Int; replace = false, alloc = true, iter_type = Any
)
    forwarded = (replace = replace, alloc = alloc, iter_type = iter_type)
    return itsample(Random.GLOBAL_RNG, condition, iter, n; forwarded...)
end


function itsample(rng::AbstractRNG, iter, condition::Function, n::Int; alloc = true, iter_type = Any)
function itsample(
rng::AbstractRNG, condition::Function, iter, n::Int;
replace = false, alloc = true, iter_type = Any
)
if alloc
unweighted_sampling_with_condition_multi(iter, rng, n, condition)
if replace
error("Not implemented yet")
else
conditioned_unweighted_sampling(iter, rng, n, condition)
end
else
iter_filtered = Iterators.filter(x -> condition(x), iter)
IterHasKnownSize = Base.IteratorSize(iter_filtered)
unweighted_resorvoir_sampling_multi(iter_filtered, rng, n, IterHasKnownSize, iter_type)
if replace
error("Not implemented yet")
else
iter_filtered = Iterators.filter(x -> condition(x), iter)
IterHasKnownSize = Base.IteratorSize(iter_filtered)
unweighted_resorvoir_sampling(iter_filtered, rng, n, IterHasKnownSize,
iter_type)
end
end
end

function unweighted_sampling_multi(iter, rng, n)
# Weighted

"""
    itsample(iter, wv::Function, n::Int;
             replace = false, alloc = true, iter_type = Any)

Weighted variant without an explicit RNG: uses `Random.GLOBAL_RNG` and passes
`iter`, `wv`, `n` and all keywords on unchanged to the `itsample` method that
takes the RNG as its first argument.
"""
function itsample(
    iter, wv::Function, n::Int; replace = false, alloc = true, iter_type = Any
)
    forwarded = (replace = replace, alloc = alloc, iter_type = iter_type)
    return itsample(Random.GLOBAL_RNG, iter, wv, n; forwarded...)
end

"""
    itsample(rng::AbstractRNG, iter, wv::Function, n::Int;
             replace = false, alloc = true, iter_type = Any)

Placeholder for weighted multi-element sampling with an explicit RNG.
Always throws an `ErrorException` with the message "Not implemented yet";
none of the arguments or keywords are used.
"""
function itsample(
    rng::AbstractRNG,
    iter,
    wv::Function,
    n::Int;
    replace = false,
    alloc = true,
    iter_type = Any,
)
    throw(ErrorException("Not implemented yet"))
end

"""
    itsample(condition::Function, iter, wv::Function, n::Int;
             replace = false, alloc = true, iter_type = Any)

Conditioned, weighted variant without an explicit RNG: uses
`Random.GLOBAL_RNG` and passes `condition`, `iter`, `wv`, `n` and all keywords
on unchanged to the `itsample` method that takes the RNG as its first argument.
"""
function itsample(
    condition::Function,
    iter,
    wv::Function,
    n::Int;
    replace = false,
    alloc = true,
    iter_type = Any,
)
    forwarded = (replace = replace, alloc = alloc, iter_type = iter_type)
    return itsample(Random.GLOBAL_RNG, condition, iter, wv, n; forwarded...)
end

"""
    itsample(rng::AbstractRNG, condition::Function, iter, wv::Function, n::Int;
             replace = false, alloc = true, iter_type = Any)

Placeholder for conditioned, weighted multi-element sampling with an explicit
RNG. Always throws an `ErrorException` with the message "Not implemented yet";
none of the arguments or keywords are used.
"""
function itsample(
    rng::AbstractRNG,
    condition::Function,
    iter,
    wv::Function,
    n::Int;
    replace = false,
    alloc = true,
    iter_type = Any,
)
    throw(ErrorException("Not implemented yet"))
end

# ALGORITHMS

"""
    unweighted_sampling(iter, rng, n::Int)

Materialise `iter` with `collect` and return `n` of its elements drawn
without replacement via `sample(rng, ...; replace = false)`. When the
collection holds `n` elements or fewer, the whole collected vector is
returned as is and `rng` is not used.
"""
function unweighted_sampling(iter, rng, n::Int)
    population = collect(iter)
    if length(population) > n
        return sample(rng, population, n; replace = false)
    end
    return population
end

function unweighted_sampling_with_condition_multi(iter, rng, n, condition)
function conditioned_unweighted_sampling(iter, rng, n::Int, condition)
pop = collect(iter)
n_p = length(pop)
n_p <= n && return filter(el -> condition(el), pop)
Expand All @@ -52,7 +120,13 @@ function unweighted_sampling_with_condition_multi(iter, rng, n, condition)
return res[1:i]
end

function unweighted_resorvoir_sampling_multi(iter, rng, n, ::Base.SizeUnknown, iter_type = eltype(iter))
function unweighted_resorvoir_sampling(
iter,
rng,
n::Int,
::NonIndexable,
iter_type = eltype(iter)
)
it = iterate(iter)
isnothing(it) && return iter_type[]
el, state = it
Expand Down Expand Up @@ -81,7 +155,13 @@ function unweighted_resorvoir_sampling_multi(iter, rng, n, ::Base.SizeUnknown, i
end
end

function unweighted_resorvoir_sampling_multi(iter, rng, n, ::Union{Base.HasLength, Base.HasShape}, iter_type = eltype(iter))
function unweighted_resorvoir_sampling(
iter,
rng,
n::Int,
::Indexable,
iter_type = eltype(iter)
)
N = length(iter)
N <= n && return collect(iter)
indices = sort!(sample(rng, 1:N, n; replace=false))
Expand All @@ -95,3 +175,26 @@ function unweighted_resorvoir_sampling_multi(iter, rng, n, ::Union{Base.HasLengt
end
end
end

"""
    unweighted_resorvoir_sampling_multi(iter, rng, n::Int, ::Indexable,
                                        iter_type = eltype(iter), replace = true)

Select `n` elements from an iterator of known length in a single pass.

`n` distinct positions in `1:length(iter)` are drawn with
`sample(rng, ...; replace = false)` and sorted; `iter` is then walked once and
the elements at those positions are stored in a `Vector{iter_type}`, which is
shuffled with `rng` before being returned.

Special cases:
- if `iter` has `n` elements or fewer, `collect(iter)` is returned;
- if `n == 0` (and `iter` is non-empty), an empty `iter_type[]` is returned.

`replace` is accepted but currently ignored: positions are always drawn
without replacement.
"""
function unweighted_resorvoir_sampling_multi(
    iter,
    rng,
    n::Int,
    ::Indexable,
    iter_type = eltype(iter),
    replace = true
)
    N = length(iter)
    N <= n && return collect(iter)
    # Nothing to draw: without this guard the scan below would read
    # `indices[1]` from an empty vector and throw a BoundsError.
    n == 0 && return iter_type[]
    indices = sort!(sample(rng, 1:N, n; replace=false))
    reservoir = Vector{iter_type}(undef, n)
    j = 1
    for (i, x) in enumerate(iter)
        if i == indices[j]
            reservoir[j] = x
            # Shuffle with the caller's RNG (not the global one) so that a
            # seeded `rng` gives reproducible output.
            j == n && return shuffle!(rng, reservoir)
            j += 1
        end
    end
    # Only reached if `iter` yields fewer elements than `length(iter)` reported.
    return nothing
end

93 changes: 75 additions & 18 deletions src/UnweightedSamplingSingle.jl
Original file line number Diff line number Diff line change
@@ -1,38 +1,95 @@

function itsample(iter; alloc = false)
return itsample(Random.GLOBAL_RNG, iter; alloc = alloc)
# Unweighted

"""
    itsample(iter; replace = false, alloc = false)

Single-element convenience method that uses `Random.GLOBAL_RNG` as the random
number generator. `iter` and both keywords are passed on unchanged to the
`itsample` method that takes the RNG as its first argument.
"""
function itsample(iter; replace = false, alloc = false)
    forwarded = (replace = replace, alloc = alloc)
    return itsample(Random.GLOBAL_RNG, iter; forwarded...)
end

function itsample(rng::AbstractRNG, iter; alloc = false)
if alloc
unweighted_sampling_single(iter, rng)
function itsample(rng::AbstractRNG, iter; replace = false, alloc = false)
if alloc
if replace
error("Not implemented yet")
else
unweighted_sampling(iter, rng)
end
else
IterHasKnownSize = Base.IteratorSize(iter)
unweighted_resorvoir_sampling_single(iter, rng, IterHasKnownSize)
if replace
error("Not implemented yet")
else
IterHasKnownSize = Base.IteratorSize(iter)
unweighted_resorvoir_sampling(iter, rng, IterHasKnownSize)
end
end
end

function itsample(iter, condition::Function; alloc = false)
return itsample(Random.GLOBAL_RNG, iter, condition; alloc = alloc)
"""
    itsample(condition::Function, iter; replace = false, alloc = false)

Single-element, conditioned convenience method that uses `Random.GLOBAL_RNG`
as the random number generator. `condition`, `iter` and both keywords are
passed on unchanged to the `itsample` method that takes the RNG as its first
argument.

`replace` and `alloc` are keyword arguments (declared after `;`), matching the
other `itsample` methods, so `itsample(cond, iter; alloc = true)` works.
"""
function itsample(condition::Function, iter; replace = false, alloc = false)
    return itsample(Random.GLOBAL_RNG, condition, iter;
        replace = replace, alloc = alloc)
end

function itsample(rng::AbstractRNG, iter, condition::Function; alloc = false)
function itsample(
rng::AbstractRNG, condition::Function, iter;
replace = false, alloc = false
)
if alloc
unweighted_sampling_with_condition_single(iter, rng, condition)
if replace
error("Not implemented yet")
else
conditioned_unweighted_sampling(iter, rng, condition)
end
else
iter_filtered = Iterators.filter(x -> condition(x), iter)
IterHasKnownSize = Base.IteratorSize(iter_filtered)
unweighted_resorvoir_sampling_single(iter_filtered, rng, IterHasKnownSize)
if replace
error("Not implemented yet")
else
iter_filtered = Iterators.filter(x -> condition(x), iter)
IterHasKnownSize = Base.IteratorSize(iter_filtered)
unweighted_resorvoir_sampling(iter_filtered, rng, IterHasKnownSize)
end
end
end

function unweighted_sampling_single(iter, rng)
# Weighted

"""
    itsample(iter, wv::Function; replace = false, alloc = true, iter_type = Any)

Weighted single-element variant without an explicit RNG: uses
`Random.GLOBAL_RNG` and passes `iter`, `wv` and all keywords on unchanged to
the `itsample` method that takes the RNG as its first argument.
"""
function itsample(iter, wv::Function; replace = false, alloc = true, iter_type = Any)
    forwarded = (replace = replace, alloc = alloc, iter_type = iter_type)
    return itsample(Random.GLOBAL_RNG, iter, wv; forwarded...)
end

"""
    itsample(rng::AbstractRNG, iter, wv::Function;
             replace = false, alloc = true, iter_type = Any)

Placeholder for weighted single-element sampling with an explicit RNG.
Always throws an `ErrorException` with the message "Not implemented yet";
none of the arguments or keywords are used.
"""
function itsample(
    rng::AbstractRNG,
    iter,
    wv::Function;
    replace = false,
    alloc = true,
    iter_type = Any,
)
    throw(ErrorException("Not implemented yet"))
end

"""
    itsample(condition::Function, iter, wv::Function;
             replace = false, alloc = true, iter_type = Any)

Conditioned, weighted single-element variant without an explicit RNG: uses
`Random.GLOBAL_RNG` and passes `condition`, `iter`, `wv` and all keywords on
unchanged to the `itsample` method that takes the RNG as its first argument.
"""
function itsample(
    condition::Function,
    iter,
    wv::Function;
    replace = false,
    alloc = true,
    iter_type = Any,
)
    forwarded = (replace = replace, alloc = alloc, iter_type = iter_type)
    return itsample(Random.GLOBAL_RNG, condition, iter, wv; forwarded...)
end

"""
    itsample(rng::AbstractRNG, condition::Function, iter, wv::Function;
             replace = false, alloc = true, iter_type = Any)

Placeholder for conditioned, weighted single-element sampling with an explicit
RNG. Always throws an `ErrorException` with the message "Not implemented yet";
none of the arguments or keywords are used.
"""
function itsample(
    rng::AbstractRNG,
    condition::Function,
    iter,
    wv::Function;
    replace = false,
    alloc = true,
    iter_type = Any,
)
    throw(ErrorException("Not implemented yet"))
end

# ALGORITHMS

"""
    unweighted_sampling(iter, rng)

Materialise `iter` with `collect` and return one element picked by
`rand(rng, ...)`, or `nothing` when `iter` yields no elements.
"""
function unweighted_sampling(iter, rng)
    candidates = collect(iter)
    return isempty(candidates) ? nothing : rand(rng, candidates)
end

function unweighted_sampling_with_condition_single(iter, rng, condition)
function conditioned_unweighted_sampling(iter, rng, condition)
pop = collect(iter)
n_p = length(pop)
while n_p != 0
Expand All @@ -45,7 +102,7 @@ function unweighted_sampling_with_condition_single(iter, rng, condition)
return nothing
end

function unweighted_resorvoir_sampling_single(iter, rng, ::Base.SizeUnknown)
function unweighted_resorvoir_sampling(iter, rng, ::NonIndexable)
res = iterate(iter)
isnothing(res) && return nothing
w = rand(rng)
Expand All @@ -64,7 +121,7 @@ function unweighted_resorvoir_sampling_single(iter, rng, ::Base.SizeUnknown)
end
end

function unweighted_resorvoir_sampling_single(iter, rng, ::Union{Base.HasLength, Base.HasShape})
function unweighted_resorvoir_sampling(iter, rng, ::Indexable)
k = rand(rng, 1:length(iter))
for (i, x) in enumerate(iter)
i == k && return x
Expand Down

0 comments on commit a94d6a8

Please sign in to comment.