Skip to content

Commit

Permalink
Precompile functions
Browse files Browse the repository at this point in the history
  • Loading branch information
Tortar committed Apr 25, 2024
1 parent 99d5181 commit 491adf5
Show file tree
Hide file tree
Showing 8 changed files with 134 additions and 43 deletions.
4 changes: 4 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,18 @@ uuid = "ff63dad9-3335-55d8-95ec-f8139d39e468"
version = "0.3.2"

[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

[compat]
julia = "1.8"
Accessors = "0.1"
DataStructures = "0.18"
Distributions = "0.25"
PrecompileTools = "1"
Random = "1"
StatsBase = "0.32, 0.33, 0.34"
16 changes: 15 additions & 1 deletion src/StreamSampling.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
module StreamSampling

using Accessors
using DataStructures
using Distributions
using Random
Expand Down Expand Up @@ -42,6 +43,19 @@ struct AlgARes <: ReservoirAlgorithm end
struct AlgAExpJ <: ReservoirAlgorithm end
struct AlgWRSWRSKIP <: ReservoirAlgorithm end


"""
    @imm_reset obj.field = value
    @imm_reset obj.field += value

Apply a field assignment to either a mutable or an immutable object.

`e` must be an (updating) assignment whose left-hand side is a field access;
the object expression is taken from the first argument of that left-hand side.
At run time, if the object's type is mutable, `e` is evaluated as written.
Otherwise `e` is forwarded to `Accessors.@reset`, so the caller has to keep
using the variable named on the left-hand side afterwards.
"""
macro imm_reset(e)
    target = first(first(e.args).args)
    dispatch = quote
        if ismutabletype(typeof($target))
            $e
        else
            StreamSampling.Accessors.@reset $e
        end
    end
    # The whole expression is escaped so that the nested `@reset` receives the
    # caller's original expression and all names resolve in the caller's scope.
    return esc(dispatch)
end


"""
Implements random sampling without replacement.
Expand Down Expand Up @@ -96,7 +110,7 @@ include("UnweightedSamplingSingle.jl")
include("UnweightedSamplingMulti.jl")
include("WeightedSamplingSingle.jl")
include("WeightedSamplingMulti.jl")

include("precompile.jl")

"""
Expand Down
45 changes: 31 additions & 14 deletions src/UnweightedSamplingMulti.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ end

@inline function update!(s::Union{SampleMultiAlgR, SampleMultiOrdAlgR}, el)
n = length(s.value)
s.seen_k += 1
s = @inline update_state!(s)
if s.seen_k <= n
@inbounds s.value[s.seen_k] = el
else
Expand All @@ -88,27 +88,27 @@ end
end
# Feed one element `el` to an AlgL multi-element sample and return the sample.
#
# The counters are advanced exactly once through `update_state!`, and every
# helper that may rebuild the sample is rebound to `s`, so callers must use the
# returned object. While fewer than `n = length(s.value)` elements have been
# seen, `el` is stored at index `s.seen_k`; the skip is first computed when the
# `n`-th element arrives. Afterwards `el` overwrites a uniformly chosen slot
# only when `s.skip_k` has dropped below zero.
@inline function update!(s::Union{SampleMultiAlgL, SampleMultiOrdAlgL}, el)
    n = length(s.value)
    s = @inline update_state!(s)
    if s.seen_k <= n
        @inbounds s.value[s.seen_k] = el
        if s.seen_k == n
            s = @inline recompute_skip!(s, n)
        end
    elseif s.skip_k < 0
        j = rand(s.rng, 1:n)
        @inbounds s.value[j] = el
        # NOTE(review): presumably records the replaced position for the
        # ordered variants; `update_order!` is defined elsewhere.
        update_order!(s, j)
        s = @inline recompute_skip!(s, n)
    end
    return s
end
@inline function update!(s::AbstractWrReservoirSampleMulti, el)
n = length(s.value)
s.seen_k += 1
s.skip_k -= 1
s = @inline update_state!(s)
if s.seen_k <= n
@inbounds s.value[s.seen_k] = el
if s.seen_k == n
recompute_skip!(s, n)
s = recompute_skip!(s, n)
new_values = sample(s.rng, s.value, n, ordered=is_ordered(s))
@inbounds for i in 1:n
s.value[i] = new_values[i]
Expand All @@ -133,18 +133,35 @@ end
end
end
end
recompute_skip!(s, n)
s = recompute_skip!(s, n)
end
return s
end

# Count one more seen element on an AlgR multi sample. The sample is returned
# because `@imm_reset` may rebind `s` to a new object when its type is immutable.
function update_state!(s::Union{SampleMultiAlgR, SampleMultiOrdAlgR})
    @imm_reset s.seen_k += 1
    return s
end
# Advance the per-element bookkeeping of an AlgL multi sample: one more element
# seen, one fewer element left to skip. Returns the (possibly rebuilt) sample.
function update_state!(s::Union{SampleMultiAlgL, SampleMultiOrdAlgL})
    @imm_reset s.seen_k += 1
    @imm_reset s.skip_k -= 1
    return s
end
# Advance the per-element bookkeeping of a with-replacement multi sample:
# increment `seen_k`, decrement `skip_k`. Returns the (possibly rebuilt) sample.
function update_state!(s::AbstractWrReservoirSampleMulti)
    @imm_reset s.seen_k += 1
    @imm_reset s.skip_k -= 1
    return s
end

# Draw the next skip length for a without-replacement multi sample.
#
# `s.state` grows by one exponential variate, then `s.skip_k` is derived from a
# second exponential variate and the new state. Each field is written exactly
# once, so two random draws are consumed per call. The sample is returned
# because `@imm_reset` may rebind `s` when its type is immutable.
function recompute_skip!(s::AbstractWorReservoirSampleMulti, n)
    @imm_reset s.state += randexp(s.rng)
    @imm_reset s.skip_k = -ceil(Int, randexp(s.rng)/log(1-exp(-s.state/n)))
    return s
end
# Draw the next skip length for a with-replacement multi sample from a single
# uniform variate raised to `1/n`. The field is written only through
# `@imm_reset`, so the method also works for immutable samples; the (possibly
# rebuilt) sample is returned.
function recompute_skip!(s::AbstractWrReservoirSampleMulti, n)
    q = rand(s.rng)^(1/n)
    @imm_reset s.skip_k = ceil(Int, s.seen_k/q - s.seen_k - 1)
    return s
end

function choose(n, p, q, z)
Expand Down Expand Up @@ -196,7 +213,7 @@ function Base.merge!(s1::SampleMultiAlgRSWRSKIP, s2::AbstractWrReservoirSampleMu
shuffle!(s2.rng, s2.value)
n_tot = n1 + n2
p = n2 / n_tot
s1 = merge_res_vec!(s1, s2, p, len1, n_tot)
merge_res_vec!(s1, s2, p, len1, n_tot)
recompute_skip!(s1, len1)
return s1
end
Expand Down Expand Up @@ -278,7 +295,7 @@ end

# Feed every element of `iter` to the sample `s` exactly once and return the
# collected values: `ordered_value(s)` when `ordered` is true, otherwise the
# values shuffled in place with the sample's own RNG.
function update_all!(s, iter, ordered)
    for x in iter
        # `update!` may return a new object for immutable samples, so rebind.
        s = update!(s, x)
    end
    return ordered ? ordered_value(s) : shuffle!(s.rng, value(s))
end
Expand Down
14 changes: 7 additions & 7 deletions src/UnweightedSamplingSingle.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,20 +36,20 @@ function ReservoirSample(rng::AbstractRNG, T, method::AlgR)
end

# Feed one element to a single-element AlgR sample: after counting it, `el`
# replaces the stored value with probability `1/s.seen_k`. The counter is
# incremented once per call. The sample is returned because `@imm_reset` may
# rebind `s` when its type is immutable.
@inline function update!(s::SampleSingleAlgR, el)
    @imm_reset s.seen_k += 1
    if rand(s.rng) <= 1/s.seen_k
        @imm_reset s.value = el
    end
    return s
end
# Feed one element to a single-element AlgL sample.
#
# While `s.skip_k` is positive the element is skipped and the counter is
# decremented. Otherwise `el` becomes the stored value, `s.state` is scaled by
# a uniform variate and a new skip length is drawn. All field writes, including
# the `seen_k` increment, go through `@imm_reset` so the method also works for
# immutable samples; callers must use the returned object.
@inline function update!(s::SampleSingleAlgL, el)
    @imm_reset s.seen_k += 1
    if s.skip_k > 0
        @imm_reset s.skip_k -= 1
    else
        @imm_reset s.value = el
        @imm_reset s.state *= rand(s.rng)
        @imm_reset s.skip_k = -ceil(Int, randexp(s.rng)/log(1-s.state))
    end
    return s
end
Expand Down Expand Up @@ -87,7 +87,7 @@ end
# Draw a single element from `iter` with the reservoir algorithm `method`
# (default `algL`), using `rng` as the randomness source. Each element is fed
# to the sample exactly once.
function reservoir_sample(rng, iter, method::ReservoirAlgorithm = algL)
    s = ReservoirSample(rng, calculate_eltype(iter), method)
    for x in iter
        # `update!` may return a new object for immutable samples, so rebind.
        s = update!(s, x)
    end
    return value(s)
end
43 changes: 30 additions & 13 deletions src/WeightedSamplingMulti.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ end

@inline function update!(s::Union{SampleMultiAlgARes, SampleMultiOrdAlgARes}, el, w)
n = s.n
s.seen_k += 1
s = @inline update_state!(s, w)
priority = -randexp(s.rng)/w
if s.seen_k <= n
push_value!(s, el, priority)
Expand All @@ -100,24 +100,24 @@ end
end
# Feed one weighted element `(el, w)` to an A-ExpJ multi-element sample and
# return the sample.
#
# The counters are advanced exactly once through `update_state!`, and every
# helper that may rebuild the sample is rebound to `s`. The first `s.n`
# elements are pushed with a freshly drawn priority; the skip is first computed
# when the `n`-th element arrives. Afterwards an element is admitted only when
# `s.state` has dropped to zero or below: one entry is popped from `s.value`
# and `el` is pushed with a priority from `compute_skip_priority`.
@inline function update!(s::Union{SampleMultiAlgAExpJ, SampleMultiOrdAlgAExpJ}, el, w)
    n = s.n
    s = @inline update_state!(s, w)
    if s.seen_k <= n
        priority = exp(-randexp(s.rng)/w)
        push_value!(s, el, priority)
        if s.seen_k == n
            s = @inline recompute_skip!(s)
        end
    elseif s.state <= 0.0
        priority = @inline compute_skip_priority(s, w)
        # NOTE(review): presumably removes the lowest-priority entry; depends
        # on the container type of `s.value`, which is defined elsewhere.
        pop!(s.value)
        push_value!(s, el, priority)
        s = @inline recompute_skip!(s)
    end
    return s
end
@inline function update!(s::Union{SampleMultiAlgWRSWRSKIP, SampleMultiOrdAlgWRSWRSKIP}, el, w)
n = length(s.value)
s.seen_k += 1
s.state += w
s = @inline update_state!(s, w)
if s.seen_k <= n
@inbounds s.value[s.seen_k] = el
@inbounds s.weights[s.seen_k] = w
Expand All @@ -126,7 +126,7 @@ end
@inbounds for i in 1:n
s.value[i] = new_values[i]
end
@inline recompute_skip!(s, n)
s = @inline recompute_skip!(s, n)
empty!(s.weights)
end
elseif s.skip_w <= s.state
Expand All @@ -148,23 +148,40 @@ end
end
end
end
@inline recompute_skip!(s, n)
s = @inline recompute_skip!(s, n)
end
return s
end

# Count one more seen element on an A-Res multi sample. The weight `w` is
# accepted for a uniform call signature but not used by this method. Returns
# the (possibly rebuilt) sample.
function update_state!(s::Union{SampleMultiAlgARes, SampleMultiOrdAlgARes}, w)
    @imm_reset s.seen_k += 1
    return s
end
# Advance the bookkeeping of an A-ExpJ multi sample for an element of weight
# `w`: increment `seen_k` and subtract `w` from `state`. Returns the (possibly
# rebuilt) sample.
function update_state!(s::Union{SampleMultiAlgAExpJ, SampleMultiOrdAlgAExpJ}, w)
    @imm_reset s.seen_k += 1
    @imm_reset s.state -= w
    return s
end
# Advance the bookkeeping of a WRSWR-SKIP multi sample for an element of weight
# `w`: increment `seen_k` and add `w` to `state`. Returns the (possibly
# rebuilt) sample.
function update_state!(s::Union{SampleMultiAlgWRSWRSKIP, SampleMultiOrdAlgWRSWRSKIP}, w)
    @imm_reset s.seen_k += 1
    @imm_reset s.state += w
    return s
end

# Draw the priority for an element of weight `w` that is admitted after a
# skip: a uniform variate is sampled between `s.min_priority^w` (computed via
# exp/log) and 1, and its `w`-th root is returned.
function compute_skip_priority(s, w)
    lower = exp(log(s.min_priority)*w)
    u = rand(s.rng, Uniform(lower, 1))
    return exp(log(u)/w)
end

# Refresh the skip threshold of an A-ExpJ multi sample.
#
# `s.min_priority` is read from the first entry of `s.value`, and `s.state` is
# redrawn from one exponential variate scaled by the log of that priority. Each
# field is written exactly once, so one random draw is consumed per call. The
# sample is returned because `@imm_reset` may rebind `s` when its type is
# immutable.
function recompute_skip!(s::Union{SampleMultiAlgAExpJ, SampleMultiOrdAlgAExpJ})
    @imm_reset s.min_priority = last(first(s.value))
    @imm_reset s.state = -randexp(s.rng)/log(s.min_priority)
    return s
end
# Draw the next weight threshold for a WRSWR-SKIP multi sample: the current
# `s.state` divided by a uniform variate raised to `1/n`. The field is written
# only through `@imm_reset`, so the method also works for immutable samples;
# the (possibly rebuilt) sample is returned.
function recompute_skip!(s::Union{SampleMultiAlgWRSWRSKIP, SampleMultiOrdAlgWRSWRSKIP}, n)
    q = rand(s.rng)^(1/n)
    @imm_reset s.skip_w = s.state/q
    return s
end

function push_value!(s::Union{SampleMultiAlgARes, SampleMultiAlgAExpJ}, el, priority)
Expand Down Expand Up @@ -238,7 +255,7 @@ end

# Feed every element of `iter`, weighted by `wv(x)`, to the sample `s` exactly
# once and return the collected values: `ordered_value(s)` when `ordered` is
# true, otherwise the values shuffled in place with the sample's own RNG.
function update_all!(s, iter, wv, ordered)
    for x in iter
        # `update!` may return a new object for immutable samples, so rebind.
        s = update!(s, x, wv(x))
    end
    return ordered ? ordered_value(s) : shuffle!(s.rng, value(s))
end
12 changes: 6 additions & 6 deletions src/WeightedSamplingSingle.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,16 @@ end
# Feed one weighted element to a single-element A-Res sample: a priority is
# drawn for `el`, and `el` replaces the stored value only when that priority
# exceeds the best one seen so far (`s.state`). Fields are written through
# `@imm_reset`, so callers must use the returned object.
@inline function update!(s::SampleSingleAlgARes, el, w)
    priority = -randexp(s.rng)/w
    if priority > s.state
        @imm_reset s.state = priority
        @imm_reset s.value = el
    end
    return s
end
# Feed one weighted element to a single-element A-ExpJ sample.
#
# `weight` is added to the running total `s.state` exactly once. When the total
# reaches the threshold `s.skip_w`, `el` becomes the stored value and a new
# threshold is drawn with a single uniform variate. Fields are written through
# `@imm_reset`, so callers must use the returned object.
@inline function update!(s::SampleSingleAlgAExpJ, el, weight)
    @imm_reset s.state += weight
    if s.skip_w <= s.state
        @imm_reset s.value = el
        @imm_reset s.skip_w = s.state/rand(s.rng)
    end
    return s
end
Expand All @@ -50,7 +50,7 @@ end
# Draw a single element from `iter`, where each element `x` carries the weight
# `wv(x)`, using `rng` as the randomness source.
#
# The sample is built with the requested `method` (default `algAExpJ`) rather
# than a hard-coded algorithm, and each element is fed to it exactly once.
function itsample(rng::AbstractRNG, iter, wv::Function, method::ReservoirAlgorithm = algAExpJ)
    s = ReservoirSample(rng, calculate_eltype(iter), method)
    for x in iter
        # `update!` may return a new object for immutable samples, so rebind.
        s = update!(s, x, wv(x))
    end
    return value(s)
end
39 changes: 39 additions & 0 deletions src/precompile.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@

using PrecompileTools

# Precompilation workload: runs every reservoir constructor / `update!` pairing
# and every `itsample` entry point once on a small `Int` stream so the
# corresponding methods are compiled when the package is precompiled.
@setup_workload begin
    @compile_workload begin
        iter = Iterators.filter(x -> x != 10, 1:20)
        wv(el) = 1.0

        # `update!` returns the sample (a new object when the sample type is
        # immutable), so its result is rebound like everywhere else in the package.

        # Unweighted samples.
        rs = ReservoirSample(Int, algR)
        for x in iter; rs = update!(rs, x); end
        rs = ReservoirSample(Int, algL)
        for x in iter; rs = update!(rs, x); end
        rs = ReservoirSample(Int, 2, algR)
        for x in iter; rs = update!(rs, x); end
        rs = ReservoirSample(Int, 2, algL)
        for x in iter; rs = update!(rs, x); end
        rs = ReservoirSample(Int, 2, algRSWRSKIP)
        for x in iter; rs = update!(rs, x); end

        # Weighted samples.
        rs = ReservoirSample(Int, algARes)
        for x in iter; rs = update!(rs, x, wv(x)); end
        rs = ReservoirSample(Int, algAExpJ)
        for x in iter; rs = update!(rs, x, wv(x)); end
        rs = ReservoirSample(Int, 2, algARes)
        for x in iter; rs = update!(rs, x, wv(x)); end
        rs = ReservoirSample(Int, 2, algAExpJ)
        for x in iter; rs = update!(rs, x, wv(x)); end
        rs = ReservoirSample(Int, 2, algWRSWRSKIP)
        for x in iter; rs = update!(rs, x, wv(x)); end

        # One-shot sampling entry points.
        itsample(iter, algR)
        itsample(iter, algL)
        itsample(iter, 2, algR)
        itsample(iter, 2, algL)
        itsample(iter, 2, algRSWRSKIP)
        itsample(iter, wv, algARes)
        itsample(iter, wv, algAExpJ)
        itsample(iter, wv, 2, algARes)
        itsample(iter, wv, 2, algAExpJ)
        itsample(iter, wv, 2, algWRSWRSKIP)
    end
end
4 changes: 2 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@

using StreamSampling

using Distributions
using HypothesisTests
using Random
using StableRNGs
using Test

using StreamSampling

@testset "StreamSampling.jl Tests" begin
include("package_sanity_tests.jl")
include("unweighted_sampling_single_tests.jl")
Expand Down

0 comments on commit 491adf5

Please sign in to comment.