Skip to content

Commit

Permalink
Improve perf of weighted sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
Tortar committed Apr 25, 2024
1 parent 491adf5 commit ce75f40
Show file tree
Hide file tree
Showing 13 changed files with 236 additions and 111 deletions.
17 changes: 9 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,28 +71,29 @@ julia> rng = Xoshiro(42);
julia> iter = Iterators.filter(x -> x != 10, 1:10^7);

julia> wv(el) = 1.0
wv (generic function with 1 method)

julia> @btime itsample($rng, $iter, 10^4, algRSWRSKIP);
11.744 ms (5 allocations: 156.39 KiB)
12.209 ms (8 allocations: 156.47 KiB)

julia> @btime sample($rng, collect($iter), 10^4; replace=true);
131.933 ms (20 allocations: 146.91 MiB)
134.622 ms (20 allocations: 146.91 MiB)

julia> @btime itsample($rng, $iter, 10^4, algL);
10.260 ms (3 allocations: 78.22 KiB)
10.450 ms (6 allocations: 78.30 KiB)

julia> @btime sample($rng, collect($iter), 10^4; replace=false);
132.069 ms (27 allocations: 147.05 MiB)
135.039 ms (27 allocations: 147.05 MiB)

julia> @btime itsample($rng, $iter, $wv, 10^4, algWRSWRSKIP);
32.278 ms (18 allocations: 547.34 KiB)
14.017 ms (13 allocations: 568.84 KiB)

julia> @btime sample($rng, collect($iter), Weights($wv.($iter)), 10^4; replace=true);
348.220 ms (49 allocations: 675.21 MiB)
543.582 ms (45 allocations: 702.33 MiB)

julia> @btime itsample($rng, $iter, $wv, 10^4, algAExpJ);
39.965 ms (11 allocations: 234.78 KiB)
20.968 ms (9 allocations: 234.73 KiB)

julia> @btime sample($rng, collect($iter), Weights($wv.($iter)), 10^4; replace=false);
306.039 ms (43 allocations: 370.19 MiB)
305.226 ms (43 allocations: 370.19 MiB)
```
3 changes: 3 additions & 0 deletions docs/src/index.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# API

This is the API page of the package. For a general overview of the functionalities
consult the [ReadMe](https://github.com/JuliaDynamics/StreamSampling.jl).

## General functionalities

```@docs
Expand Down
4 changes: 2 additions & 2 deletions src/SortedSamplingMulti.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@

function sortedindices_sample(rng, iter, n::Int; replace = false, ordered = false, kwargs...)
function sortedindices_sample(rng, iter, n::Int;
iter_type = infer_eltype(iter), replace = false, ordered = false)
N = length(iter)
if N <= n
reservoir = collect(iter)
Expand All @@ -13,7 +14,6 @@ function sortedindices_sample(rng, iter, n::Int; replace = false, ordered = fals
end
end
end
iter_type = calculate_eltype(iter)
reservoir = Vector{iter_type}(undef, n)
indices = get_sorted_indices(rng, n, N, replace)
first_idx = indices[1]
Expand Down
14 changes: 5 additions & 9 deletions src/StreamSampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,11 @@ using Distributions
using Random
using StatsBase

struct WRSample end
struct OrdWRSample end
struct WORSample end
struct OrdWORSample end

const wrsample = WRSample()
const ordwrsample = OrdWRSample()
const worsample = WORSample()
const ordworsample = OrdWORSample()
struct ImmutSample end
struct MutSample end

const ims = ImmutSample()
const ms = MutSample()

abstract type AbstractReservoirSample end

Expand Down
57 changes: 33 additions & 24 deletions src/UnweightedSamplingMulti.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,26 +44,34 @@ mutable struct SampleMultiOrdAlgRSWRSKIP{T,R} <: AbstractOrdWrReservoirSampleMul
const ord::Vector{Int}
end

function ReservoirSample(T, n::Integer, method::ReservoirAlgorithm=algL; ordered = false)
return ReservoirSample(Random.default_rng(), T, n, method; ordered = ordered)
function ReservoirSample(T, n::Integer, method::ReservoirAlgorithm=algL;
ordered = false)
return ReservoirSample(Random.default_rng(), T, n, method, ms; ordered = ordered)
end
function ReservoirSample(rng::AbstractRNG, T, n::Integer, method::AlgL=algL; ordered = false)
function ReservoirSample(rng::AbstractRNG, T, n::Integer, method::ReservoirAlgorithm=algL;
ordered = false)
return ReservoirSample(rng, T, n, method, ms; ordered = ordered)
end
function ReservoirSample(rng::AbstractRNG, T, n::Integer, ::AlgL, ::MutSample;
ordered = false)
value = Vector{T}(undef, n)
if ordered
return SampleMultiOrdAlgL(0.0, 0, 0, rng, value, collect(1:n))
else
return SampleMultiAlgL(0.0, 0, 0, rng, value)
end
end
function ReservoirSample(rng::AbstractRNG, T, n::Integer, method::AlgR; ordered = false)
function ReservoirSample(rng::AbstractRNG, T, n::Integer, ::AlgR, ::MutSample;
ordered = false)
value = Vector{T}(undef, n)
if ordered
return SampleMultiOrdAlgR(0, rng, value, collect(1:n))
else
return SampleMultiAlgR(0, rng, value)
end
end
function ReservoirSample(rng::AbstractRNG, T, n::Integer, method::AlgRSWRSKIP; ordered = false)
function ReservoirSample(rng::AbstractRNG, T, n::Integer, ::AlgRSWRSKIP, ::MutSample;
ordered = false)
value = Vector{T}(undef, n)
if ordered
return SampleMultiOrdAlgRSWRSKIP(0, 0, rng, value, collect(1:n))
Expand Down Expand Up @@ -139,28 +147,28 @@ end
end

function update_state!(s::Union{SampleMultiAlgR, SampleMultiOrdAlgR})
@imm_reset s.seen_k += 1
s.seen_k += 1
return s
end
function update_state!(s::Union{SampleMultiAlgL, SampleMultiOrdAlgL})
@imm_reset s.seen_k += 1
@imm_reset s.skip_k -= 1
s.seen_k += 1
s.skip_k -= 1
return s
end
function update_state!(s::AbstractWrReservoirSampleMulti)
@imm_reset s.seen_k += 1
@imm_reset s.skip_k -= 1
s.seen_k += 1
s.skip_k -= 1
return s
end

function recompute_skip!(s::AbstractWorReservoirSampleMulti, n)
@imm_reset s.state += randexp(s.rng)
@imm_reset s.skip_k = -ceil(Int, randexp(s.rng)/log(1-exp(-s.state/n)))
s.state += randexp(s.rng)
s.skip_k = -ceil(Int, randexp(s.rng)/log(1-exp(-s.state/n)))
return s
end
function recompute_skip!(s::AbstractWrReservoirSampleMulti, n)
q = rand(s.rng)^(1/n)
@imm_reset s.skip_k = ceil(Int, s.seen_k/q - s.seen_k - 1)
s.skip_k = ceil(Int, s.seen_k/q - s.seen_k - 1)
return s
end

Expand Down Expand Up @@ -274,22 +282,23 @@ function ordered_value(s::AbstractOrdWrReservoirSampleMulti)
end
end

function itsample(iter, n::Int, method::ReservoirAlgorithm = algL; ordered = false)
function itsample(iter, n::Int, method::ReservoirAlgorithm = algL;
iter_type = infer_eltype(iter), ordered = false)
return itsample(Random.default_rng(), iter, n, method; ordered)
end
function itsample(rng::AbstractRNG, iter, n::Int, method::ReservoirAlgorithm = algL; ordered = false)
iter_type = calculate_eltype(iter)
function itsample(rng::AbstractRNG, iter, n::Int, method::ReservoirAlgorithm = algL;
iter_type = infer_eltype(iter), ordered = false)
if Base.IteratorSize(iter) isa Base.SizeUnknown
reservoir_sample(rng, iter, n, method; ordered)::Vector{iter_type}
reservoir_sample(rng, iter, n, method; iter_type, ordered)::Vector{iter_type}
else
replace = method isa AlgL || method isa AlgR ? false : true
sortedindices_sample(rng, iter, n; replace, ordered)::Vector{iter_type}
sortedindices_sample(rng, iter, n; iter_type, replace, ordered)::Vector{iter_type}
end
end

function reservoir_sample(rng, iter, n::Int, method::ReservoirAlgorithm = algL; ordered = false)
iter_type = calculate_eltype(iter)
s = ReservoirSample(rng, iter_type, n, method; ordered = ordered)
function reservoir_sample(rng, iter, n::Int, method::ReservoirAlgorithm = algL;
iter_type = infer_eltype(iter), ordered = false)
s = ReservoirSample(rng, iter_type, n, method, ms; ordered = ordered)
return update_all!(s, iter, ordered)
end

Expand All @@ -300,7 +309,7 @@ function update_all!(s, iter, ordered)
return ordered ? ordered_value(s) : shuffle!(s.rng, value(s))
end

function calculate_eltype(iter)
T = eltype(iter)
return T === Any ? Base.@default_eltype(iter) : T
function infer_eltype(itr)
T1, T2 = eltype(itr), Base.@default_eltype(itr)
ifelse(T2 !== Union{} && T2 <: T1, T2, T1)
end
15 changes: 9 additions & 6 deletions src/UnweightedSamplingSingle.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,19 +73,22 @@ function Base.merge!(s1::SampleSingleAlgR, s2::AbstractReservoirSampleSingle)
return s1
end

function itsample(iter, method::ReservoirAlgorithm = algL)
return itsample(Random.default_rng(), iter, method)
function itsample(iter, method::ReservoirAlgorithm = algL;
iter_type = infer_eltype(iter))
return itsample(Random.default_rng(), iter, method; iter_type)
end
function itsample(rng::AbstractRNG, iter, method::ReservoirAlgorithm = algL)
function itsample(rng::AbstractRNG, iter, method::ReservoirAlgorithm = algL;
iter_type = infer_eltype(iter))
if Base.IteratorSize(iter) isa Base.SizeUnknown
return reservoir_sample(rng, iter, method)
return reservoir_sample(rng, iter, method; iter_type)
else
return sortedindices_sample(rng, iter)
end
end

function reservoir_sample(rng, iter, method::ReservoirAlgorithm = algL)
s = ReservoirSample(rng, calculate_eltype(iter), method)
function reservoir_sample(rng, iter, method::ReservoirAlgorithm = algL;
iter_type = infer_eltype(iter))
s = ReservoirSample(rng, iter_type, method)
for x in iter
s = update!(s, x)
end
Expand Down
Loading

0 comments on commit ce75f40

Please sign in to comment.