Skip to content

Commit

Permalink
new project structure
Browse files Browse the repository at this point in the history
  • Loading branch information
Tortar committed Dec 26, 2023
1 parent a94d6a8 commit 0dbccbd
Show file tree
Hide file tree
Showing 6 changed files with 147 additions and 316 deletions.
218 changes: 61 additions & 157 deletions src/UnweightedSamplingMulti.jl
Original file line number Diff line number Diff line change
@@ -1,132 +1,32 @@

# Unweighted

function itsample(
iter, n::Int; replace = false, alloc = true, iter_type = Any
)
function itsample(iter, n::Int;
replace = false, ordered = false, is_stateful = false)
return itsample(Random.GLOBAL_RNG, iter, n;
replace = replace, alloc = alloc, iter_type = iter_type)
end

function itsample(
rng::AbstractRNG, iter, n::Int;
replace = false, alloc = true, iter_type = Any
)
if alloc
if replace
error("Not implemented yet")
else
unweighted_sampling(iter, rng, n)
end
else
if replace
error("Not implemented yet")
else
IterHasKnownSize = Base.IteratorSize(iter)
unweighted_resorvoir_sampling(iter, rng, n, IterHasKnownSize,
iter_type)
end
end
replace=replace, ordered=ordered, is_stateful=is_stateful)
end

"""
    itsample(condition::Function, iter, n::Int; replace=false, alloc=true, iter_type=Any)

Convenience method without an explicit RNG: forwards `condition`, `iter`, `n` and all
keyword arguments unchanged to the `rng`-first method, using `Random.GLOBAL_RNG`.
"""
function itsample(
    condition::Function, iter, n::Int;
    replace = false, alloc = true, iter_type = Any
)
    default_rng = Random.GLOBAL_RNG
    return itsample(
        default_rng, condition, iter, n;
        replace = replace,
        alloc = alloc,
        iter_type = iter_type,
    )
end


function itsample(
rng::AbstractRNG, condition::Function, iter, n::Int;
replace = false, alloc = true, iter_type = Any
)
if alloc
if replace
error("Not implemented yet")
function itsample(rng::AbstractRNG, iter, n::Int;
replace = false, ordered = false, is_stateful = false)
IterHasKnownSize = Base.IteratorSize(iter)
if IterHasKnownSize isa NonIndexable
if is_stateful
if replace
error("Not implemented yet")
else
unweighted_resorvoir_sampling(rng, iter, n)
end
else
conditioned_unweighted_sampling(iter, rng, n, condition)
unweighted_resorvoir_sampling(rng, iter, n)
#double_scan_sampling(rng, iter, n, replace, ordered)
end
else
if replace
error("Not implemented yet")
else
iter_filtered = Iterators.filter(x -> condition(x), iter)
IterHasKnownSize = Base.IteratorSize(iter_filtered)
unweighted_resorvoir_sampling(iter_filtered, rng, n, IterHasKnownSize,
iter_type)
end
unweighted_resorvoir_sampling(rng, iter, n)
#single_scan_sampling(rng, iter, n, replace, ordered)
end
end

# Weighted

"""
    itsample(iter, wv::Function, n::Int; replace=false, alloc=true, iter_type=Any)

Convenience method without an explicit RNG: forwards `iter`, the weight function `wv`,
`n` and all keyword arguments unchanged to the `rng`-first method, using
`Random.GLOBAL_RNG`.
"""
function itsample(
    iter, wv::Function, n::Int;
    replace = false, alloc = true, iter_type = Any
)
    default_rng = Random.GLOBAL_RNG
    return itsample(
        default_rng, iter, wv, n;
        replace = replace,
        alloc = alloc,
        iter_type = iter_type,
    )
end

"""
    itsample(rng::AbstractRNG, iter, wv::Function, n::Int;
             replace=false, alloc=true, iter_type=Any)

Placeholder for the weighted multi-element method. It ignores all of its arguments and
always throws an `ErrorException` with the message `"Not implemented yet"`.
"""
function itsample(
    rng::AbstractRNG, iter, wv::Function, n::Int;
    replace = false, alloc = true, iter_type = Any
)
    # `error(msg)` is exactly `throw(ErrorException(msg))`.
    throw(ErrorException("Not implemented yet"))
end

"""
    itsample(condition::Function, iter, wv::Function, n::Int;
             replace=false, alloc=true, iter_type=Any)

Convenience method without an explicit RNG: forwards `condition`, `iter`, the weight
function `wv`, `n` and all keyword arguments unchanged to the `rng`-first method, using
`Random.GLOBAL_RNG`.
"""
function itsample(
    condition::Function, iter, wv::Function, n::Int;
    replace = false, alloc = true, iter_type = Any
)
    default_rng = Random.GLOBAL_RNG
    return itsample(
        default_rng, condition, iter, wv, n;
        replace = replace,
        alloc = alloc,
        iter_type = iter_type,
    )
end

"""
    itsample(rng::AbstractRNG, condition::Function, iter, wv::Function, n::Int;
             replace=false, alloc=true, iter_type=Any)

Placeholder for the conditioned, weighted multi-element method. It ignores all of its
arguments and always throws an `ErrorException` with the message
`"Not implemented yet"`.
"""
function itsample(
    rng::AbstractRNG, condition::Function, iter, wv::Function, n::Int;
    replace = false, alloc = true, iter_type = Any
)
    # `error(msg)` is exactly `throw(ErrorException(msg))`.
    throw(ErrorException("Not implemented yet"))
end

# ALGORITHMS

"""
    unweighted_sampling(iter, rng, n::Int)

Materialise `iter` with `collect` and return `n` of its elements drawn without
replacement via `sample(rng, …; replace=false)`. When the collected population holds
`n` elements or fewer, the whole population is returned as collected, without sampling.
"""
function unweighted_sampling(iter, rng, n::Int)
    population = collect(iter)
    if length(population) > n
        return sample(rng, population, n; replace=false)
    end
    # Not enough elements to choose from: hand back everything.
    return population
end

"""
    conditioned_unweighted_sampling(iter, rng, n::Int, condition)

Draw up to `n` elements of `iter` that satisfy `condition`, without replacement.

`iter` is first materialised with `collect`. If the population holds `n` elements or
fewer, every element satisfying `condition` is returned in iteration order. Otherwise
elements are drawn uniformly at random from the part of the population that has not
been drawn yet; accepted elements are stored until `n` have been found or the
population is exhausted.

Returns a `Vector` holding between 0 and `n` elements. For `n == 0` an empty vector is
returned.
"""
function conditioned_unweighted_sampling(iter, rng, n::Int, condition)
    pop = collect(iter)
    n_p = length(pop)
    n_p <= n && return filter(condition, pop)
    # Nothing to draw. Without this guard the first accepted element would be written
    # to `res[1]` of a zero-length vector and throw a `BoundsError`.
    n == 0 && return eltype(pop)[]
    res = Vector{eltype(pop)}(undef, n)
    i = 0
    while n_p != 0
        idx = rand(rng, 1:n_p)
        el = pop[idx]
        if condition(el)
            i += 1
            res[i] = el
            i == n && return res
        end
        # Swap the drawn element out of the active range 1:n_p so it is never redrawn.
        pop[idx], pop[n_p] = pop[n_p], pop[idx]
        n_p -= 1
    end
    # Population exhausted before `n` matches were found: return only the filled part.
    return res[1:i]
end

function unweighted_resorvoir_sampling(
iter,
rng,
n::Int,
::NonIndexable,
function unweighted_resorvoir_sampling(rng, iter, n::Int)
iter_type = eltype(iter)
)
it = iterate(iter)
isnothing(it) && return iter_type[]
el, state = it
Expand All @@ -136,65 +36,69 @@ function unweighted_resorvoir_sampling(
it = iterate(iter, state)
isnothing(it) && return reservoir[1:i-1]
el, state = it
reservoir[i] = el
@inbounds reservoir[i] = el
end
w = rand(rng)^(1/n)
u = randexp(rng)
while true
skip_counter = floor(log(rand(rng))/log(1-w))
w = exp(-u/n)
skip_counter = ceil(Int, randexp(rng)/log(1-w))
while skip_counter != 0
skip_it = iterate(iter, state)
isnothing(skip_it) && return reservoir
state = skip_it[2]
skip_counter -= 1
skip_res = iterate(iter, state)
isnothing(skip_res) && return reservoir
state = skip_res[2]
skip_counter += 1
end
it = iterate(iter, state)
isnothing(it) && return reservoir
el, state = it
reservoir[rand(rng, 1:n)] = el
w *= rand(rng)^(1/n)
u += randexp(rng)
end
end

function unweighted_resorvoir_sampling(
iter,
rng,
n::Int,
::Indexable,
iter_type = eltype(iter)
)
N = length(iter)
N <= n && return collect(iter)
indices = sort!(sample(rng, 1:N, n; replace=false))
reservoir = Vector{iter_type}(undef, n)
j = 1
for (i, x) in enumerate(iter)
if i == indices[j]
reservoir[j] = x
j == n && return shuffle!(reservoir)
j += 1
end
end
"""
    double_scan_sampling(rng, iter, n::Int, replace, ordered)

Sample from `iter` in two passes: the first pass counts the elements with
`get_population_size`, the second delegates to
`single_scan_sampling(rng, iter, n, N, replace, ordered)` with that count `N`.
`replace` and `ordered` are forwarded unchanged.
"""
function double_scan_sampling(rng, iter, n::Int, replace, ordered)
    N = get_population_size(iter)
    # Argument order must be (rng, iter, ...) to match `single_scan_sampling`.
    return single_scan_sampling(rng, iter, n, N, replace, ordered)
end

"""
    single_scan_sampling(rng, iter, n::Int, replace, ordered)

Convenience method that takes the population size from `length(iter)` and forwards to
the six-argument `single_scan_sampling(rng, iter, n, N, replace, ordered)`.
"""
function single_scan_sampling(rng, iter, n::Int, replace, ordered)
    population_size = length(iter)
    return single_scan_sampling(rng, iter, n, population_size, replace, ordered)
end

function unweighted_resorvoir_sampling_multi(
iter,
rng,
n::Int,
::Indexable,
iter_type = eltype(iter),
replace = true
)
N = length(iter)
function single_scan_sampling(rng, iter, n::Int, N::Int, replace, ordered)
N <= n && return collect(iter)
indices = sort!(sample(rng, 1:N, n; replace=false))
iter_type = eltype(iter)
indices = sort!(sample(rng, 1:N, n; replace=replace))
reservoir = Vector{iter_type}(undef, n)
j = 1
i = 1
for (i, x) in enumerate(iter)
if i == indices[j]
@inbounds while i == indices[j]
reservoir[j] = x
j == n && return shuffle!(reservoir)
if j == n
if ordered
return reservoir
else
return shuffle!(reservoir)
end
end
j += 1
end
end
if ordered
return reservoir
else
return shuffle!(reservoir)
end
end

"""
    get_population_size(iter)

Count the elements of `iter` by iterating over it once from start to end and return
the count as an `Int`. Only the iteration protocol is used; `length` is never called.
"""
function get_population_size(iter)
    total = 0
    for _ in iter
        total += 1
    end
    return total
end
Loading

0 comments on commit 0dbccbd

Please sign in to comment.