Skip to content

Commit

Permalink
Fix sampling of CircularPrioritizedTraces
Browse files Browse the repository at this point in the history
  • Loading branch information
Dharanish committed May 9, 2024
1 parent 0ec5743 commit 73f2efb
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions src/samplers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ function StatsBase.sample(s::BatchSampler, e::EpisodesBuffer{<:Any, <:Any, <:Cir
t = e.traces
p = collect(deepcopy(t.priorities))
w = StatsBase.FrequencyWeights(p)
w .*= e.sampleable_inds[1:end-1]
w .*= e.sampleable_inds[1:length(t)]
inds = StatsBase.sample(s.rng, eachindex(w), w, s.batchsize)
NamedTuple{(:key, :priority, names...)}((t.keys[inds], p[inds], map(x -> collect(t.traces[Val(x)][inds]), names)...))
end
Expand Down Expand Up @@ -247,7 +247,7 @@ function StatsBase.sample(s::NStepBatchSampler{names}, e::EpisodesBuffer{<:Any,
p = collect(deepcopy(t.priorities))
w = StatsBase.FrequencyWeights(p)
valids, ns = valid_range(s,e)
w .*= valids[1:end-1]
w .*= valids[1:length(t)]
inds = StatsBase.sample(s.rng, eachindex(w), w, s.batchsize)
merge(
(key=t.keys[inds], priority=p[inds]),
Expand Down Expand Up @@ -362,7 +362,7 @@ function StatsBase.sample(s::MultiStepSampler{names}, e::EpisodesBuffer{<:Any, <
p = collect(deepcopy(t.priorities))
w = StatsBase.FrequencyWeights(p)
valids, ns = valid_range(s,e)
w .*= valids[1:end-1]
w .*= valids[1:length(t)]
inds = StatsBase.sample(s.rng, eachindex(w), w, s.batchsize)
merge(
(key=t.keys[inds], priority=p[inds]),
Expand Down

0 comments on commit 73f2efb

Please sign in to comment.