
Update test for NStepBatchSampler
ludvigk committed Sep 27, 2024
1 parent bfc7610 commit 26c24f0
Showing 1 changed file with 11 additions and 11 deletions.
22 changes: 11 additions & 11 deletions test/samplers.jl
@@ -87,42 +87,42 @@ import ReinforcementLearningTrajectories.fetch
push!(eb, (state = i+1, action =i+1, reward = i, terminal = i == 5))
end
push!(eb, (state = 7, action = 7))
-for (j,i) = enumerate(8:11)
+for (j,i) = enumerate(8:12)
push!(eb, (state = i, action =i, reward = i-1, terminal = false))
end
weights, ns = ReinforcementLearningTrajectories.valid_range(s1, eb)
-@test weights == [0,1,1,1,1,0,0,1,1,1,0]
-@test ns == [3,3,3,2,1,-1,3,3,2,1,0] #the -1 is due to ep_lengths[6] being that of 2nd episode but step_numbers[6] being that of 1st episode
+@test weights == [0,1,1,1,0,0,1,1,1,0,0]
+@test ns == [3,3,2,1,-1,3,3,3,2,1,0] #the -1 is due to ep_lengths[5] being that of the 2nd episode but step_numbers[5] being that of the 1st episode
inds = [i for i in eachindex(weights) if weights[i] == 1]
batch = sample(s1, eb)
for key in keys(eb)
@test haskey(batch, key)
end
#state: samples with stacksize
states = ReinforcementLearningTrajectories.fetch(s1, eb[:state], Val(:state), inds, ns[inds])
-@test states == [1 2 3 4 7 8 9;
-                 2 3 4 5 8 9 10]
+@test states == [1 2 3 6 7 8;
+                 2 3 4 7 8 9]
@test all(in(eachcol(states)), unique(eachcol(batch[:state])))
#next_state: samples with stacksize and nsteps forward
next_states = ReinforcementLearningTrajectories.fetch(s1, eb[:next_state], Val(:next_state), inds, ns[inds])
-@test next_states == [4 5 5 5 10 10 10;
-                      5 6 6 6 11 11 11]
+@test next_states == [4 4 4 9 10 10;
+                      5 5 5 10 11 11]
@test all(in(eachcol(next_states)), unique(eachcol(batch[:next_state])))
#action: samples normally
actions = ReinforcementLearningTrajectories.fetch(s1, eb[:action], Val(:action), inds, ns[inds])
-@test actions == inds
+@test actions == [3, 4, 5, 8, 9, 10]
@test all(in(actions), unique(batch[:action]))
#next_action: is a multiplex trace: should automatically sample nsteps forward
next_actions = ReinforcementLearningTrajectories.fetch(s1, eb[:next_action], Val(:next_action), inds, ns[inds])
-@test next_actions == [5, 6, 6, 6, 11, 11, 11]
+@test next_actions == [6, 6, 6, 11, 12, 12]
@test all(in(next_actions), unique(batch[:next_action]))
#reward: discounted sum
rewards = ReinforcementLearningTrajectories.fetch(s1, eb[:reward], Val(:reward), inds, ns[inds])
-@test rewards ≈ [2+0.99*3+0.99^2*4, 3+0.99*4+0.99^2*5, 4+0.99*5, 5, 8+0.99*9+0.99^2*10, 9+0.99*10, 10]
+@test rewards ≈ [2+0.99*3+0.99^2*4, 3+0.99*4, 4, 7+0.99*8+0.99^2*9, 8+0.99*9+0.99^2*10, 9+0.99*10]
@test all(in(rewards), unique(batch[:reward]))
#terminal: nsteps forward
terminals = ReinforcementLearningTrajectories.fetch(s1, eb[:terminal], Val(:terminal), inds, ns[inds])
-@test terminals == [0,1,1,1,0,0,0]
+@test terminals == [0,0,0,0,0,0]

### CircularPrioritizedTraces and NStepBatchSampler
γ = 0.99
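For readers verifying the arithmetic in the updated expectations, below is a minimal standalone sketch of the two computations these assertions encode: the n-step discounted return with γ = 0.99, and the stacked next_state lookup n steps ahead. The helpers nstep_return and stacked_next_state are hypothetical illustrations, not part of the ReinforcementLearningTrajectories API; the per-index lookahead n comes from the ns vector above, truncated at episode boundaries.

# Hypothetical helpers, for illustration only (not package API).
# n-step discounted return: G = sum over k of γ^(k-1) * rewards[k].
nstep_return(rs, γ) = sum(γ^(k - 1) * r for (k, r) in enumerate(rs))

# Next state stacked to `stacksize`, fetched `n` steps ahead
# (simplified: ignores circular-buffer wrap-around).
stacked_next_state(states, i, n, stacksize) = states[i + n - stacksize + 1 : i + n]

γ = 0.99
nstep_return([2, 3, 4], γ)                  # 2 + 0.99*3 + 0.99^2*4 (index 2, n = 3)
nstep_return([3, 4], γ)                     # 3 + 0.99*4 (index 3, truncated to n = 2)
stacked_next_state(collect(1:12), 2, 3, 2)  # [4, 5], the first column of next_states

With inds = [2, 3, 4, 7, 8, 9] and ns[inds] = [3, 2, 1, 3, 3, 2], these reproduce the expected rewards and next_states columns tested above.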
