diff --git a/src/samplers.jl b/src/samplers.jl index 8701189..e5443a7 100644 --- a/src/samplers.jl +++ b/src/samplers.jl @@ -74,7 +74,7 @@ function StatsBase.sample(s::BatchSampler, e::EpisodesBuffer{<:Any, <:Any, <:Cir t = e.traces p = collect(deepcopy(t.priorities)) w = StatsBase.FrequencyWeights(p) - w .*= e.sampleable_inds[1:end-1] + w .*= e.sampleable_inds[1:length(t)] inds = StatsBase.sample(s.rng, eachindex(w), w, s.batchsize) NamedTuple{(:key, :priority, names...)}((t.keys[inds], p[inds], map(x -> collect(t.traces[Val(x)][inds]), names)...)) end @@ -247,7 +247,7 @@ function StatsBase.sample(s::NStepBatchSampler{names}, e::EpisodesBuffer{<:Any, p = collect(deepcopy(t.priorities)) w = StatsBase.FrequencyWeights(p) valids, ns = valid_range(s,e) - w .*= valids[1:end-1] + w .*= valids[1:length(t)] inds = StatsBase.sample(s.rng, eachindex(w), w, s.batchsize) merge( (key=t.keys[inds], priority=p[inds]), @@ -362,7 +362,7 @@ function StatsBase.sample(s::MultiStepSampler{names}, e::EpisodesBuffer{<:Any, < p = collect(deepcopy(t.priorities)) w = StatsBase.FrequencyWeights(p) valids, ns = valid_range(s,e) - w .*= valids[1:end-1] + w .*= valids[1:length(t)] inds = StatsBase.sample(s.rng, eachindex(w), w, s.batchsize) merge( (key=t.keys[inds], priority=p[inds]),