Skip to content

Commit

Permalink
Fix tests (#70)
Browse files Browse the repository at this point in the history
  • Loading branch information
Tortar committed Apr 26, 2024
1 parent ce75f40 commit 37f65e2
Show file tree
Hide file tree
Showing 6 changed files with 63 additions and 30 deletions.
3 changes: 1 addition & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,7 @@ julia> rng = Xoshiro(42);

julia> iter = Iterators.filter(x -> x != 10, 1:10^7);

julia> wv(el) = 1.0
wv (generic function with 1 method)
julia> wv(el) = 1.0;

julia> @btime itsample($rng, $iter, 10^4, algRSWRSKIP);
12.209 ms (8 allocations: 156.47 KiB)
Expand Down
3 changes: 2 additions & 1 deletion src/StreamSampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,8 @@ export ReservoirSample

"""
update!(rs::AbstractReservoirSample, el, [w])
update!(rs::AbstractReservoirSample, el)
update!(rs::AbstractReservoirSample, el, w::Float64)
Updates the reservoir sample by scanning the passed element.
In the case of weighted sampling also the weight of the element
Expand Down
30 changes: 18 additions & 12 deletions src/UnweightedSamplingSingle.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,30 +26,33 @@ function value(s::SampleSingleAlgR)
end

function ReservoirSample(T, method::ReservoirAlgorithm = algL)
return ReservoirSample(Random.default_rng(), T, method)
return ReservoirSample(Random.default_rng(), T, method, ms)
end
function ReservoirSample(rng::AbstractRNG, T, method::AlgL = algL)
function ReservoirSample(rng::AbstractRNG, T, method::ReservoirAlgorithm = algL)
return ReservoirSample(rng, T, method, ms)
end
function ReservoirSample(rng::AbstractRNG, T, ::AlgL, ::MutSample)
return SampleSingleAlgL{T, typeof(rng)}(1.0, 0, 0, rng)
end
function ReservoirSample(rng::AbstractRNG, T, method::AlgR)
function ReservoirSample(rng::AbstractRNG, T, ::AlgR, ::MutSample)
return SampleSingleAlgR{T, typeof(rng)}(0, rng)
end

@inline function update!(s::SampleSingleAlgR, el)
@imm_reset s.seen_k += 1
s.seen_k += 1
if rand(s.rng) <= 1/s.seen_k
@imm_reset s.value = el
s.value = el
end
return s
end
@inline function update!(s::SampleSingleAlgL, el)
s.seen_k += 1
if s.skip_k > 0
@imm_reset s.skip_k -= 1
s.skip_k -= 1
else
@imm_reset s.value = el
@imm_reset s.state *= rand(s.rng)
@imm_reset s.skip_k = -ceil(Int, randexp(s.rng)/log(1-s.state))
s.value = el
s.state *= rand(s.rng)
s.skip_k = -ceil(Int, randexp(s.rng)/log(1-s.state))
end
return s
end
Expand Down Expand Up @@ -80,15 +83,18 @@ end
function itsample(rng::AbstractRNG, iter, method::ReservoirAlgorithm = algL;
iter_type = infer_eltype(iter))
if Base.IteratorSize(iter) isa Base.SizeUnknown
return reservoir_sample(rng, iter, method; iter_type)
return reservoir_sample(rng, iter, iter_type, method)
else
return sortedindices_sample(rng, iter)
end
end

function reservoir_sample(rng, iter, method::ReservoirAlgorithm = algL;
iter_type = infer_eltype(iter))
function reservoir_sample(rng, iter, iter_type, method::ReservoirAlgorithm = algL)
s = ReservoirSample(rng, iter_type, method)
return update_all!(s, iter)
end

function update_all!(s, iter)
for x in iter
s = update!(s, x)
end
Expand Down
51 changes: 39 additions & 12 deletions src/WeightedSamplingSingle.jl
Original file line number Diff line number Diff line change
@@ -1,24 +1,47 @@

mutable struct SampleSingleAlgARes{T,R} <: AbstractWeightedReservoirSampleSingle
struct ImmutSampleSingleAlgARes{T,R} <: AbstractWeightedReservoirSampleSingle
state::Float64
rng::R
value::T
ImmutSampleSingleAlgARes(state, rng::R, value::T) where {T,R} = new{T,R}(state, rng, value)
ImmutSampleSingleAlgARes{T,R}(state, rng) where {T,R} = new{T,R}(state, rng)
end
mutable struct MutSampleSingleAlgARes{T,R} <: AbstractWeightedReservoirSampleSingle
state::Float64
const rng::R
value::T
SampleSingleAlgARes{T,R}(state, rng) where {T,R} = new{T,R}(state, rng)
MutSampleSingleAlgARes{T,R}(state, rng) where {T,R} = new{T,R}(state, rng)
end
const SampleSingleAlgARes = Union{ImmutSampleSingleAlgARes, MutSampleSingleAlgARes}

mutable struct SampleSingleAlgAExpJ{T,R} <: AbstractWeightedReservoirSampleSingle
struct ImmutSampleSingleAlgAExpJ{T,R} <: AbstractWeightedReservoirSampleSingle
state::Float64
skip_w::Float64
rng::R
value::T
ImmutSampleSingleAlgAExpJ(state, skip_w, rng::R, value::T) where {T,R} = new{T,R}(state, skip_w, rng, value)
ImmutSampleSingleAlgAExpJ{T,R}(state, skip_w, rng) where {T,R} = new{T,R}(state, skip_w, rng)
end
mutable struct MutSampleSingleAlgAExpJ{T,R} <: AbstractWeightedReservoirSampleSingle
state::Float64
skip_w::Float64
const rng::R
value::T
SampleSingleAlgAExpJ{T,R}(state, skip_w, rng) where {T,R} = new{T,R}(state, skip_w, rng)
MutSampleSingleAlgAExpJ{T,R}(state, skip_w, rng) where {T,R} = new{T,R}(state, skip_w, rng)
end
const SampleSingleAlgAExpJ = Union{ImmutSampleSingleAlgAExpJ, MutSampleSingleAlgAExpJ}

function ReservoirSample(rng::R, T, method::AlgARes) where {R<:AbstractRNG}
return SampleSingleAlgARes{T,R}(0.0, rng)
function ReservoirSample(rng::R, T, ::AlgARes, ::MutSample) where {R<:AbstractRNG}
return MutSampleSingleAlgARes{T,R}(typemax(Float64), rng)
end
function ReservoirSample(rng::R, T, ::AlgARes, ::ImmutSample) where {R<:AbstractRNG}
return ImmutSampleSingleAlgARes{T,R}(typemax(Float64), rng)
end
function ReservoirSample(rng::R, T, ::AlgAExpJ, ::MutSample) where {R<:AbstractRNG}
return MutSampleSingleAlgAExpJ{T,R}(0.0, 0.0, rng)
end
function ReservoirSample(rng::R, T, method::AlgAExpJ) where {R<:AbstractRNG}
return SampleSingleAlgAExpJ{T,R}(0.0, 0.0, rng)
function ReservoirSample(rng::R, T, ::AlgAExpJ, ::ImmutSample) where {R<:AbstractRNG}
return ImmutSampleSingleAlgAExpJ{T,R}(0.0, 0.0, rng)
end

function value(s::AbstractWeightedReservoirSampleSingle)
Expand All @@ -27,8 +50,8 @@ function value(s::AbstractWeightedReservoirSampleSingle)
end

@inline function update!(s::SampleSingleAlgARes, el, w)
priority = -randexp(s.rng)/w
if priority > s.state
priority = randexp(s.rng)/w
if priority < s.state
@imm_reset s.state = priority
@imm_reset s.value = el
end
Expand All @@ -45,12 +68,16 @@ end

function itsample(iter, wv::Function, method::ReservoirAlgorithm = algAExpJ;
iter_type = infer_eltype(iter))
return itsample(Random.default_rng(), iter, wv, method; iter_type)
return itsample(Random.default_rng(), iter, wv, method)
end

function itsample(rng::AbstractRNG, iter, wv::Function, method::ReservoirAlgorithm = algAExpJ;
iter_type = infer_eltype(iter))
s = ReservoirSample(rng, iter_type, algAExpJ)
s = ReservoirSample(rng, iter_type, method, ms)
return update_all!(s, iter, wv)
end

function update_all!(s, iter, wv::Function)
for x in iter
s = update!(s, x, wv(x))
end
Expand Down
2 changes: 1 addition & 1 deletion test/unweighted_sampling_single_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
@test a <= z <= b

iter = Iterators.filter(x -> x != b + 1, a:b+1)
rs = ReservoirSample(Int, method; ordered = ordered)
rs = ReservoirSample(Int, method)
for x in iter
update!(rs, x)
end
Expand Down
4 changes: 2 additions & 2 deletions test/weighted_sampling_single_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
@test a <= z <= b

iter = Iterators.filter(x -> x != b + 1, a:b+1)
rs = ReservoirSample(Int, method; ordered = ordered)
rs = ReservoirSample(Int, method)
for x in iter
update!(rs, x)
update!(rs, x, wv(x))
end
@test a <= value(rs) <= b

Expand Down

0 comments on commit 37f65e2

Please sign in to comment.