From 26c24f0fc47a698b8a7a823dee7d719be4d275bd Mon Sep 17 00:00:00 2001
From: ludvigk
Date: Fri, 27 Sep 2024 14:04:37 +0200
Subject: [PATCH] Update test for NStepBatchSampler

---
 test/samplers.jl | 22 +++++++++++-----------
 1 file changed, 11 insertions(+), 11 deletions(-)

diff --git a/test/samplers.jl b/test/samplers.jl
index 2565ddd..69a6449 100644
--- a/test/samplers.jl
+++ b/test/samplers.jl
@@ -87,12 +87,12 @@ import ReinforcementLearningTrajectories.fetch
         push!(eb, (state = i+1, action =i+1, reward = i, terminal = i == 5))
     end
     push!(eb, (state = 7, action = 7))
-    for (j,i) = enumerate(8:11)
+    for (j,i) = enumerate(8:12)
         push!(eb, (state = i, action =i, reward = i-1, terminal = false))
     end
     weights, ns = ReinforcementLearningTrajectories.valid_range(s1, eb)
-    @test weights == [0,1,1,1,1,0,0,1,1,1,0]
-    @test ns == [3,3,3,2,1,-1,3,3,2,1,0] #the -1 is due to ep_lengths[6] being that of 2nd episode but step_numbers[6] being that of 1st episode
+    @test weights == [0,1,1,1,0,0,1,1,1,0,0]
+    @test ns == [3,3,2,1,-1,3,3,3,2,1,0] #the -1 is due to ep_lengths[5] being that of 2nd episode but step_numbers[5] being that of 1st episode
     inds = [i for i in eachindex(weights) if weights[i] == 1]
     batch = sample(s1, eb)
     for key in keys(eb)
@@ -100,29 +100,29 @@ import ReinforcementLearningTrajectories.fetch
     end
     #state: samples with stacksize
     states = ReinforcementLearningTrajectories.fetch(s1, eb[:state], Val(:state), inds, ns[inds])
-    @test states == [1 2 3 4 7 8 9;
-                     2 3 4 5 8 9 10]
+    @test states == [1 2 3 6 7 8;
+                     2 3 4 7 8 9]
     @test all(in(eachcol(states)), unique(eachcol(batch[:state])))
     #next_state: samples with stacksize and nsteps forward
     next_states = ReinforcementLearningTrajectories.fetch(s1, eb[:next_state], Val(:next_state), inds, ns[inds])
-    @test next_states == [4 5 5 5 10 10 10;
-                          5 6 6 6 11 11 11]
+    @test next_states == [4 4 4 9 10 10;
+                          5 5 5 10 11 11]
     @test all(in(eachcol(next_states)), unique(eachcol(batch[:next_state])))
     #action: samples normally
     actions = ReinforcementLearningTrajectories.fetch(s1, eb[:action], Val(:action), inds, ns[inds])
-    @test actions == inds
+    @test actions == [3, 4, 5, 8, 9, 10]
     @test all(in(actions), unique(batch[:action]))
     #next_action: is a multiplex trace: should automatically sample nsteps forward
    next_actions = ReinforcementLearningTrajectories.fetch(s1, eb[:next_action], Val(:next_action), inds, ns[inds])
-    @test next_actions == [5, 6, 6, 6, 11, 11, 11]
+    @test next_actions == [6, 6, 6, 11, 12, 12]
     @test all(in(next_actions), unique(batch[:next_action]))
     #reward: discounted sum
     rewards = ReinforcementLearningTrajectories.fetch(s1, eb[:reward], Val(:reward), inds, ns[inds])
-    @test rewards ≈ [2+0.99*3+0.99^2*4, 3+0.99*4+0.99^2*5, 4+0.99*5, 5, 8+0.99*9+0.99^2*10,9+0.99*10, 10]
+    @test rewards ≈ [2+0.99*3+0.99^2*4, 3+0.99*4, 4, 7+0.99*8+0.99^2*9, 8+0.99*9+0.99^2*10,9+0.99*10]
     @test all(in(rewards), unique(batch[:reward]))
     #terminal: nsteps forward
     terminals = ReinforcementLearningTrajectories.fetch(s1, eb[:terminal], Val(:terminal), inds, ns[inds])
-    @test terminals == [0,1,1,1,0,0,0]
+    @test terminals == [0,0,0,0,0,0]
 
     ### CircularPrioritizedTraces and NStepBatchSampler
     γ = 0.99
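
Not part of the patch: a quick sanity check of the discounted n-step sums the updated test expects. The values are read off the diff itself (rewards 1..5 for the first episode, 7..11 for the second, γ = 0.99, valid indices taken from the new `weights`, horizons from the new `ns`); the flat reward vector below is only an assumed layout for illustration, not the sampler's internal trace.

    γ = 0.99
    r = [1, 2, 3, 4, 5, 0, 7, 8, 9, 10, 11]   # slot 6 stands in for the episode boundary and is never summed over (assumption)
    inds = [2, 3, 4, 7, 8, 9]                 # indices where the updated `weights` are 1
    ns_inds = [3, 2, 1, 3, 3, 2]              # ns[inds] from the updated test
    nstep_return(i, n) = sum(γ^(k - 1) * r[i + k - 1] for k in 1:n)
    expected = [nstep_return(i, n) for (i, n) in zip(inds, ns_inds)]
    # expected ≈ [2 + 0.99*3 + 0.99^2*4, 3 + 0.99*4, 4,
    #             7 + 0.99*8 + 0.99^2*9, 8 + 0.99*9 + 0.99^2*10, 9 + 0.99*10]
    # which matches the new `@test rewards ≈ ...` line in the second hunk.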