Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix tests #70

Merged
merged 11 commits on
Apr 26, 2024
3 changes: 1 addition & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,7 @@ julia> rng = Xoshiro(42);

julia> iter = Iterators.filter(x -> x != 10, 1:10^7);

julia> wv(el) = 1.0
wv (generic function with 1 method)
julia> wv(el) = 1.0;

julia> @btime itsample($rng, $iter, 10^4, algRSWRSKIP);
12.209 ms (8 allocations: 156.47 KiB)
Expand Down
3 changes: 2 additions & 1 deletion src/StreamSampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,8 @@ export ReservoirSample

"""

update!(rs::AbstractReservoirSample, el, [w])
update!(rs::AbstractReservoirSample, el)
update!(rs::AbstractReservoirSample, el, w::Float64)

Updates the reservoir sample by scanning the passed element.
In the case of weighted sampling also the weight of the element
Expand Down
30 changes: 18 additions & 12 deletions src/UnweightedSamplingSingle.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,30 +26,33 @@ function value(s::SampleSingleAlgR)
end

function ReservoirSample(T, method::ReservoirAlgorithm = algL)
return ReservoirSample(Random.default_rng(), T, method)
return ReservoirSample(Random.default_rng(), T, method, ms)
end
function ReservoirSample(rng::AbstractRNG, T, method::AlgL = algL)
function ReservoirSample(rng::AbstractRNG, T, method::ReservoirAlgorithm = algL)
return ReservoirSample(rng, T, method, ms)
end
function ReservoirSample(rng::AbstractRNG, T, ::AlgL, ::MutSample)
return SampleSingleAlgL{T, typeof(rng)}(1.0, 0, 0, rng)
end
function ReservoirSample(rng::AbstractRNG, T, method::AlgR)
function ReservoirSample(rng::AbstractRNG, T, ::AlgR, ::MutSample)
return SampleSingleAlgR{T, typeof(rng)}(0, rng)
end

@inline function update!(s::SampleSingleAlgR, el)
@imm_reset s.seen_k += 1
s.seen_k += 1
if rand(s.rng) <= 1/s.seen_k
@imm_reset s.value = el
s.value = el
end
return s
end
@inline function update!(s::SampleSingleAlgL, el)
s.seen_k += 1
if s.skip_k > 0
@imm_reset s.skip_k -= 1
s.skip_k -= 1
else
@imm_reset s.value = el
@imm_reset s.state *= rand(s.rng)
@imm_reset s.skip_k = -ceil(Int, randexp(s.rng)/log(1-s.state))
s.value = el
s.state *= rand(s.rng)
s.skip_k = -ceil(Int, randexp(s.rng)/log(1-s.state))
end
return s
end
Expand Down Expand Up @@ -80,15 +83,18 @@ end
function itsample(rng::AbstractRNG, iter, method::ReservoirAlgorithm = algL;
iter_type = infer_eltype(iter))
if Base.IteratorSize(iter) isa Base.SizeUnknown
return reservoir_sample(rng, iter, method; iter_type)
return reservoir_sample(rng, iter, iter_type, method)
else
return sortedindices_sample(rng, iter)
end
end

function reservoir_sample(rng, iter, method::ReservoirAlgorithm = algL;
iter_type = infer_eltype(iter))
function reservoir_sample(rng, iter, iter_type, method::ReservoirAlgorithm = algL)
s = ReservoirSample(rng, iter_type, method)
return update_all!(s, iter)
end

function update_all!(s, iter)
for x in iter
s = update!(s, x)
end
Expand Down
51 changes: 39 additions & 12 deletions src/WeightedSamplingSingle.jl
Original file line number Diff line number Diff line change
@@ -1,24 +1,47 @@

mutable struct SampleSingleAlgARes{T,R} <: AbstractWeightedReservoirSampleSingle
struct ImmutSampleSingleAlgARes{T,R} <: AbstractWeightedReservoirSampleSingle
state::Float64
rng::R
value::T
ImmutSampleSingleAlgARes(state, rng::R, value::T) where {T,R} = new{T,R}(state, rng, value)
ImmutSampleSingleAlgARes{T,R}(state, rng) where {T,R} = new{T,R}(state, rng)
end
mutable struct MutSampleSingleAlgARes{T,R} <: AbstractWeightedReservoirSampleSingle
state::Float64
const rng::R
value::T
SampleSingleAlgARes{T,R}(state, rng) where {T,R} = new{T,R}(state, rng)
MutSampleSingleAlgARes{T,R}(state, rng) where {T,R} = new{T,R}(state, rng)
end
const SampleSingleAlgARes = Union{ImmutSampleSingleAlgARes, MutSampleSingleAlgARes}

mutable struct SampleSingleAlgAExpJ{T,R} <: AbstractWeightedReservoirSampleSingle
struct ImmutSampleSingleAlgAExpJ{T,R} <: AbstractWeightedReservoirSampleSingle
state::Float64
skip_w::Float64
rng::R
value::T
ImmutSampleSingleAlgAExpJ(state, skip_w, rng::R, value::T) where {T,R} = new{T,R}(state, skip_w, rng, value)
ImmutSampleSingleAlgAExpJ{T,R}(state, skip_w, rng) where {T,R} = new{T,R}(state, skip_w, rng)
end
mutable struct MutSampleSingleAlgAExpJ{T,R} <: AbstractWeightedReservoirSampleSingle
state::Float64
skip_w::Float64
const rng::R
value::T
SampleSingleAlgAExpJ{T,R}(state, skip_w, rng) where {T,R} = new{T,R}(state, skip_w, rng)
MutSampleSingleAlgAExpJ{T,R}(state, skip_w, rng) where {T,R} = new{T,R}(state, skip_w, rng)
end
const SampleSingleAlgAExpJ = Union{ImmutSampleSingleAlgAExpJ, MutSampleSingleAlgAExpJ}

function ReservoirSample(rng::R, T, method::AlgARes) where {R<:AbstractRNG}
return SampleSingleAlgARes{T,R}(0.0, rng)
function ReservoirSample(rng::R, T, ::AlgARes, ::MutSample) where {R<:AbstractRNG}
return MutSampleSingleAlgARes{T,R}(typemax(Float64), rng)
end
function ReservoirSample(rng::R, T, ::AlgARes, ::ImmutSample) where {R<:AbstractRNG}
return ImmutSampleSingleAlgARes{T,R}(typemax(Float64), rng)
end
function ReservoirSample(rng::R, T, ::AlgAExpJ, ::MutSample) where {R<:AbstractRNG}
return MutSampleSingleAlgAExpJ{T,R}(0.0, 0.0, rng)
end
function ReservoirSample(rng::R, T, method::AlgAExpJ) where {R<:AbstractRNG}
return SampleSingleAlgAExpJ{T,R}(0.0, 0.0, rng)
function ReservoirSample(rng::R, T, ::AlgAExpJ, ::ImmutSample) where {R<:AbstractRNG}
return ImmutSampleSingleAlgAExpJ{T,R}(0.0, 0.0, rng)
end

function value(s::AbstractWeightedReservoirSampleSingle)
Expand All @@ -27,8 +50,8 @@ function value(s::AbstractWeightedReservoirSampleSingle)
end

@inline function update!(s::SampleSingleAlgARes, el, w)
priority = -randexp(s.rng)/w
if priority > s.state
priority = randexp(s.rng)/w
if priority < s.state
@imm_reset s.state = priority
@imm_reset s.value = el
end
Expand All @@ -45,12 +68,16 @@ end

function itsample(iter, wv::Function, method::ReservoirAlgorithm = algAExpJ;
iter_type = infer_eltype(iter))
return itsample(Random.default_rng(), iter, wv, method; iter_type)
return itsample(Random.default_rng(), iter, wv, method)
end

function itsample(rng::AbstractRNG, iter, wv::Function, method::ReservoirAlgorithm = algAExpJ;
iter_type = infer_eltype(iter))
s = ReservoirSample(rng, iter_type, algAExpJ)
s = ReservoirSample(rng, iter_type, method, ms)
return update_all!(s, iter, wv)
end

function update_all!(s, iter, wv::Function)
for x in iter
s = update!(s, x, wv(x))
end
Expand Down
2 changes: 1 addition & 1 deletion test/unweighted_sampling_single_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
@test a <= z <= b

iter = Iterators.filter(x -> x != b + 1, a:b+1)
rs = ReservoirSample(Int, method; ordered = ordered)
rs = ReservoirSample(Int, method)
for x in iter
update!(rs, x)
end
Expand Down
4 changes: 2 additions & 2 deletions test/weighted_sampling_single_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
@test a <= z <= b

iter = Iterators.filter(x -> x != b + 1, a:b+1)
rs = ReservoirSample(Int, method; ordered = ordered)
rs = ReservoirSample(Int, method)
for x in iter
update!(rs, x)
update!(rs, x, wv(x))
end
@test a <= value(rs) <= b

Expand Down
Loading