Commit d8af17f

Fix offline agent test (#1025)
joelreymont authored Mar 6, 2024
1 parent 1d3c7da commit d8af17f
Showing 1 changed file with 28 additions and 20 deletions.
48 changes: 28 additions & 20 deletions src/ReinforcementLearningCore/test/policies/agent.jl
@@ -6,14 +6,14 @@ import ReinforcementLearningCore.SRT
         a_1 = Agent(
             RandomPolicy(),
             Trajectory(
-                CircularArraySARTSTraces(; capacity = 1_000),
+                CircularArraySARTSTraces(; capacity=1_000),
                 DummySampler(),
             ),
         )
         a_2 = Agent(
             RandomPolicy(),
             Trajectory(
-                CircularArraySARTSTraces(; capacity = 1_000),
+                CircularArraySARTSTraces(; capacity=1_000),
                 BatchSampler(1),
                 InsertSampleRatioController(),
             ),
@@ -26,25 +26,25 @@ import ReinforcementLearningCore.SRT
                 env = RandomWalk1D()
                 push!(agent, PreEpisodeStage(), env)
                 action = RLBase.plan!(agent, env)
-                @test action in (1,2)
-                @test length(agent.trajectory.container) == 0
+                @test action in (1, 2)
+                @test length(agent.trajectory.container) == 0
                 push!(agent, PostActStage(), env, action)
                 push!(agent, PreActStage(), env)
-                @test RLBase.plan!(agent, env) in (1,2)
+                @test RLBase.plan!(agent, env) in (1, 2)
                 @test length(agent.trajectory.container) == 1

                 # The following tests check that args / kwargs passed to the policy cause an error
                 @test_throws "MethodError: no method matching plan!(::Agent{RandomPolicy" RLBase.plan!(agent, env, 1)
-                @test_throws "MethodError: no method matching plan!(::Agent{RandomPolicy" RLBase.plan!(agent, env, fake_kwarg = 1)
+                @test_throws "MethodError: no method matching plan!(::Agent{RandomPolicy" RLBase.plan!(agent, env, fake_kwarg=1)
             end
         end
     end
     @testset "OfflineAgent" begin
         env = RandomWalk1D()
         a_1 = OfflineAgent(
-            policy = RandomPolicy(),
-            trajectory = Trajectory(
-                CircularArraySARTSTraces(; capacity = 1_000),
+            policy=RandomPolicy(),
+            trajectory=Trajectory(
+                CircularArraySARTSTraces(; capacity=1_000),
                 DummySampler(),
             ),
         )
@@ -53,27 +53,35 @@ import ReinforcementLearningCore.SRT
         @test isempty(a_1.trajectory.container)

         trajectory = Trajectory(
-            CircularArraySARTSTraces(; capacity = 1_000),
-            DummySampler(),
-        )
+            CircularArraySARTSTraces(; capacity=1_000),
+            DummySampler(),
+        )

         a_2 = OfflineAgent(
-            policy = RandomPolicy(),
-            trajectory = trajectory,
-            offline_behavior = OfflineBehavior(
+            policy=RandomPolicy(),
+            trajectory=trajectory,
+            offline_behavior=OfflineBehavior(
                 Agent(RandomPolicy(), trajectory),
-                steps = 5,
+                steps=5,
             )
         )
         push!(a_2, PreExperimentStage(), env)
-        @test length(a_2.trajectory.container) == 5
+        # We'll have 1 extra element where terminal is true
+        # if the environment was terminated mid-episode and restarted!
+        ix = findfirst(x -> x.terminal, map(identity, a_2.trajectory.container))
+        len = length(a_2.trajectory.container)
+        max = isnothing(ix) || ix == len ? 5 : 6
+        @test len == max

         for agent in [a_1, a_2]
             action = RLBase.plan!(agent, env)
-            @test action in (1,2)
+            @test action in (1, 2)
             for stage in [PreEpisodeStage(), PreActStage(), PostActStage(), PostEpisodeStage()]
                 push!(agent, stage, env)
-                @test length(agent.trajectory.container) in (0,5)
+                ix = findfirst(x -> x.terminal, map(identity, agent.trajectory.container))
+                len = length(agent.trajectory.container)
+                max = isnothing(ix) || ix == len ? 5 : 6
+                @test len in (0, max)
             end
         end
     end
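For context, the relaxed assertions encode a simple length rule: an offline run of `steps` steps normally records exactly `steps` transitions, but when the environment reaches a terminal state before the last recorded step, the restart leaves one extra element in the trajectory. Below is a minimal standalone sketch of that rule, assuming only that each trace element exposes a Bool `terminal` field; the `expected_length` helper and the dummy containers are illustrative and not part of the commit.

    using Test

    # Length rule asserted by the updated test: a terminal element anywhere
    # except the last slot means the episode was restarted mid-run, which
    # leaves one extra entry in the container.
    function expected_length(container, steps)
        ix = findfirst(x -> x.terminal, collect(container))  # first terminal element, if any
        len = length(container)
        return isnothing(ix) || ix == len ? steps : steps + 1
    end

    # Terminal reached at step 3 of a 5-step run: 6 elements expected.
    demo = [(terminal = (i == 3),) for i in 1:6]
    @test expected_length(demo, 5) == 6

    # No terminal element at all: exactly `steps` entries.
    @test expected_length([(terminal = false,) for _ in 1:5], 5) == 5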