Skip to content

Commit

Permalink
Fix test of CircularPrioritizedTraces with SARTSA
Browse files Browse the repository at this point in the history
The usage of SARTSA traces is more restrictive and should be done in this way
  • Loading branch information
Dharanish committed May 9, 2024
1 parent fbf054a commit 320e3f8
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions test/samplers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -130,15 +130,17 @@ import ReinforcementLearningTrajectories.fetch
batchsize = 4
eb = EpisodesBuffer(CircularPrioritizedTraces(CircularArraySARTSATraces(capacity=10), default_priority = 10f0))
s1 = NStepBatchSampler(eb, n=n_horizon, γ=γ, batchsize=batchsize)

push!(eb, (state = 1, action = 1))
push!(eb, (state = 1,))
for i = 1:5
push!(eb, (state = i+1, action =i+1, reward = i, terminal = i == 5))
push!(eb, (state = i+1, action =i, reward = i, terminal = i == 5))
end
push!(eb, (state = 7, action = 7))
for (j,i) = enumerate(8:11)
push!(eb, (state = i, action =i, reward = i-1, terminal = false))
push!(eb, PartialNamedTuple((action=6,)))
push!(eb, (state = 7,))
for (j,i) = enumerate(7:10)
push!(eb, (state = i+1, action =i, reward = i, terminal = i==10))
end
push!(eb, PartialNamedTuple((action = 11,)))
weights, ns = ReinforcementLearningTrajectories.valid_range(s1, eb)
inds = [i for i in eachindex(weights) if weights[i] == 1]
batch = sample(s1, eb)
Expand Down

0 comments on commit 320e3f8

Please sign in to comment.