
Commit d10fae7

Fix reset and stop conditions (#1046)
* Use a consistent, in-place API for reset and stop conditions
* Fix reset naming
* Fix duplicated environments
* Fix multiagent
* Drop parentheses
* Simplify logic
* Fix syntax, calls, and reset tests
* Fix naming per PR comments
* Update stop condition names to use "NSteps" instead of "Step"

---------

Co-authored-by: Jeremiah Lewis <--get>
1 parent 4c935e7 commit d10fae7

File tree: 26 files changed, +135 −415 lines
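
Taken together, the user-facing changes in this commit are renames of the stop and reset conditions plus a new in-place `check!` entry point (see `reset_conditions.jl` and `run.jl` below). A minimal before/after sketch, reusing the `CartPoleEnv` quick-start from the README diff below; the names in the comments are the pre-commit ones:

```julia
using ReinforcementLearning

# Before this commit:
#   run(RandomPolicy(), CartPoleEnv(), StopAfterStep(1_000), TotalRewardPerEpisode())
#   run(RandomPolicy(), CartPoleEnv(), StopAfterEpisode(10), TotalRewardPerEpisode())
#   run(policy, env, stop_condition, hook, ResetAtTerminal())

# After this commit:
run(RandomPolicy(), CartPoleEnv(), StopAfterNSteps(1_000), TotalRewardPerEpisode())
run(RandomPolicy(), CartPoleEnv(), StopAfterNEpisodes(10), TotalRewardPerEpisode())
run(RandomPolicy(), CartPoleEnv(), StopAfterNSteps(1_000), TotalRewardPerEpisode(), ResetIfEnvTerminated())
```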

README.md

+2 −2

@@ -45,7 +45,7 @@ julia> using ReinforcementLearning
 julia> run(
            RandomPolicy(),
            CartPoleEnv(),
-           StopAfterStep(1_000),
+           StopAfterNSteps(1_000),
            TotalRewardPerEpisode()
        )
 ```
@@ -66,7 +66,7 @@ reinforcement learning experiment:
   to test reinforcement learning algorithms.

 - **Stop Condition**. The
-  [`StopAfterStep(1_000)`](https://juliareinforcementlearning.org/docs/rlcore/#ReinforcementLearningCore.StopAfterStep)
+  [`StopAfterNSteps(1_000)`](https://juliareinforcementlearning.org/docs/rlcore/#ReinforcementLearningCore.StopAfterNSteps)
   is to inform that our experiment should stop after
   `1_000` steps.

docs/homepage/blog/a_practical_introduction_to_RL.jl/index.html

+10 −10

@@ -13872,7 +13872,7 @@ <h2 id="Q2:-What-if-we-want-to-stop-after-several-episodes?"><strong>Q2: What if
 <div class="prompt input_prompt">In&nbsp;[15]:</div>
 <div class="inner_cell">
 <div class="input_area">
-<div class=" highlight hl-julia"><pre><span></span><span class="n">run</span><span class="p">(</span><span class="n">policy</span><span class="p">,</span> <span class="n">env</span><span class="p">,</span> <span class="n">StopAfterEpisode</span><span class="p">(</span><span class="mi">10</span><span class="p">))</span>
+<div class=" highlight hl-julia"><pre><span></span><span class="n">run</span><span class="p">(</span><span class="n">policy</span><span class="p">,</span> <span class="n">env</span><span class="p">,</span> <span class="n">StopAfterNEpisodes</span><span class="p">(</span><span class="mi">10</span><span class="p">))</span>
 </pre></div>

 </div>
@@ -13906,8 +13906,8 @@ <h2 id="Q2:-What-if-we-want-to-stop-after-several-episodes?"><strong>Q2: What if
 <h3 id="Q2.b:-What-if-we-want-to-stop-until-arbitrary-condition-meets?"><strong>Q2.b: What if we want to stop until arbitrary condition meets?</strong><a class="anchor-link" href="#Q2.b:-What-if-we-want-to-stop-until-arbitrary-condition-meets?">&#182;</a></h3><p>Well, in that case, you need to implement your customized <em>stop condition</em> here.
 In RL.jl, several common ones are already provided, like:</p>
 <ul>
-<li><code>StopAfterStep</code></li>
-<li><code>StopAfterEpisode</code></li>
+<li><code>StopAfterNSteps</code></li>
+<li><code>StopAfterNEpisodes</code></li>
 <li><code>StopAfterNSeconds</code></li>
 <li>...</li>
 </ul>
@@ -13936,7 +13936,7 @@ <h2 id="Q3:-How-to-collect-experiment-results?"><strong>Q3: How to collect exper
 <div class="prompt input_prompt">In&nbsp;[16]:</div>
 <div class="inner_cell">
 <div class="input_area">
-<div class=" highlight hl-julia"><pre><span></span><span class="n">run</span><span class="p">(</span><span class="n">policy</span><span class="p">,</span> <span class="n">env</span><span class="p">,</span> <span class="n">StopAfterEpisode</span><span class="p">(</span><span class="mi">10</span><span class="p">),</span> <span class="n">TotalRewardPerEpisode</span><span class="p">())</span>
+<div class=" highlight hl-julia"><pre><span></span><span class="n">run</span><span class="p">(</span><span class="n">policy</span><span class="p">,</span> <span class="n">env</span><span class="p">,</span> <span class="n">StopAfterNEpisodes</span><span class="p">(</span><span class="mi">10</span><span class="p">),</span> <span class="n">TotalRewardPerEpisode</span><span class="p">())</span>
 </pre></div>

 </div>
@@ -14144,7 +14144,7 @@ <h2 id="The-Actor-Mode">The <em>Actor</em> Mode<a class="anchor-link" href="#The
 <div class=" highlight hl-julia"><pre><span></span><span class="n">run</span><span class="p">(</span>
 <span class="n">policy</span><span class="p">,</span>
 <span class="n">RandomWalk1D</span><span class="p">(),</span>
-<span class="n">StopAfterEpisode</span><span class="p">(</span><span class="mi">10</span><span class="p">),</span>
+<span class="n">StopAfterNEpisodes</span><span class="p">(</span><span class="mi">10</span><span class="p">),</span>
 <span class="n">TotalRewardPerEpisode</span><span class="p">()</span>
 <span class="p">)</span>
 </pre></div>
@@ -14311,7 +14311,7 @@ <h2 id="The-Training-Mode">The <em>Training</em> Mode<a class="anchor-link" href
 <div class="prompt input_prompt">In&nbsp;[22]:</div>
 <div class="inner_cell">
 <div class="input_area">
-<div class=" highlight hl-julia"><pre><span></span><span class="n">run</span><span class="p">(</span><span class="n">agent</span><span class="p">,</span> <span class="n">env</span><span class="p">,</span> <span class="n">StopAfterEpisode</span><span class="p">(</span><span class="mi">10</span><span class="p">),</span> <span class="n">TotalRewardPerEpisode</span><span class="p">())</span>
+<div class=" highlight hl-julia"><pre><span></span><span class="n">run</span><span class="p">(</span><span class="n">agent</span><span class="p">,</span> <span class="n">env</span><span class="p">,</span> <span class="n">StopAfterNEpisodes</span><span class="p">(</span><span class="mi">10</span><span class="p">),</span> <span class="n">TotalRewardPerEpisode</span><span class="p">())</span>
 </pre></div>

 </div>
@@ -14383,7 +14383,7 @@ <h2 id="The-Training-Mode">The <em>Training</em> Mode<a class="anchor-link" href
 <div class="inner_cell">
 <div class="input_area">
 <div class=" highlight hl-julia"><pre><span></span><span class="n">hook</span> <span class="o">=</span> <span class="n">StepsPerEpisode</span><span class="p">()</span>
-<span class="n">run</span><span class="p">(</span><span class="n">agent</span><span class="p">,</span> <span class="n">env</span><span class="p">,</span> <span class="n">StopAfterEpisode</span><span class="p">(</span><span class="mi">10</span><span class="p">),</span> <span class="n">hook</span><span class="p">)</span>
+<span class="n">run</span><span class="p">(</span><span class="n">agent</span><span class="p">,</span> <span class="n">env</span><span class="p">,</span> <span class="n">StopAfterNEpisodes</span><span class="p">(</span><span class="mi">10</span><span class="p">),</span> <span class="n">hook</span><span class="p">)</span>
 <span class="n">plot</span><span class="p">(</span><span class="n">hook</span><span class="o">.</span><span class="n">steps</span><span class="p">[</span><span class="mi">1</span><span class="o">:</span><span class="k">end</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>
 </pre></div>

@@ -14593,7 +14593,7 @@ <h3 id="Q4:-Why-does-it-need-more-than-3-steps-to-reach-our-goal?">Q4: Why does
 <span class="n">explorer</span><span class="o">=</span><span class="n">GreedyExplorer</span><span class="p">()</span>
 <span class="p">),</span>
 <span class="n">env</span><span class="p">,</span>
-<span class="n">StopAfterEpisode</span><span class="p">(</span><span class="mi">10</span><span class="p">),</span>
+<span class="n">StopAfterNEpisodes</span><span class="p">(</span><span class="mi">10</span><span class="p">),</span>
 <span class="n">hook</span>
 <span class="p">)</span>
 <span class="n">plot</span><span class="p">(</span><span class="n">hook</span><span class="o">.</span><span class="n">steps</span><span class="p">[</span><span class="mi">1</span><span class="o">:</span><span class="k">end</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>
@@ -15158,7 +15158,7 @@ <h2 id="Two-Most-Commonly-Used-Algorithms">Two Most Commonly Used Algorithms<a c
 <div class="prompt input_prompt">In&nbsp;[35]:</div>
 <div class="inner_cell">
 <div class="input_area">
-<div class=" highlight hl-julia"><pre><span></span><span class="n">stop_condition</span> <span class="o">=</span> <span class="n">StopAfterStep</span><span class="p">(</span><span class="mi">10_000</span><span class="p">)</span>
+<div class=" highlight hl-julia"><pre><span></span><span class="n">stop_condition</span> <span class="o">=</span> <span class="n">StopAfterNSteps</span><span class="p">(</span><span class="mi">10_000</span><span class="p">)</span>
 <span class="n">hook</span> <span class="o">=</span> <span class="n">TotalRewardPerEpisode</span><span class="p">()</span>
 <span class="n">run</span><span class="p">(</span><span class="n">policy</span><span class="p">,</span> <span class="n">env</span><span class="p">,</span> <span class="n">stop_condition</span><span class="p">,</span> <span class="n">hook</span><span class="p">)</span>
 </pre></div>
@@ -15306,7 +15306,7 @@ <h2 id="Two-Most-Commonly-Used-Algorithms">Two Most Commonly Used Algorithms<a c
 <div class="prompt input_prompt">In&nbsp;[38]:</div>
 <div class="inner_cell">
 <div class="input_area">
-<div class=" highlight hl-julia"><pre><span></span><span class="n">stop_condition</span> <span class="o">=</span> <span class="n">StopAfterStep</span><span class="p">(</span><span class="mi">10_000</span><span class="p">)</span>
+<div class=" highlight hl-julia"><pre><span></span><span class="n">stop_condition</span> <span class="o">=</span> <span class="n">StopAfterNSteps</span><span class="p">(</span><span class="mi">10_000</span><span class="p">)</span>
 <span class="n">hook</span> <span class="o">=</span> <span class="n">TotalBatchRewardPerEpisode</span><span class="p">(</span><span class="n">N_ENV</span><span class="p">)</span>
 <span class="n">run</span><span class="p">(</span><span class="n">agent</span><span class="p">,</span> <span class="n">env</span><span class="p">,</span> <span class="n">stop_condition</span><span class="p">,</span> <span class="n">hook</span><span class="p">)</span>
 </pre></div>

docs/homepage/blog/ospp_report_210370190/index.md

+2 −2

@@ -656,7 +656,7 @@ Besides, I implement the [`EDManager`](https://juliareinforcementlearning.org/do
 function Base.run(
     π::EDManager,
     env::AbstractEnv,
-    stop_condition = StopAfterEpisode(1),
+    stop_condition = StopAfterNEpisodes(1),
     hook::AbstractHook = EmptyHook(),
 )
     @assert NumAgentStyle(env) == MultiAgent(2) "ED algorithm only support 2-players games."
@@ -757,7 +757,7 @@ EDmanager = EDManager(
     )
 )
 # initialize the `stop_condition` and `hook`.
-stop_condition = StopAfterEpisode(100_000, is_show_progress=!haskey(ENV, "CI"))
+stop_condition = StopAfterNEpisodes(100_000, is_show_progress=!haskey(ENV, "CI"))
 hook = KuhnOpenNewEDHook(0, 100, [], [])
 ```

docs/src/How_to_use_hooks.md

+5 −5

@@ -68,7 +68,7 @@ end
 (h::TimeCostPerEpisode)(::PreEpisodeStage, policy, env) = h.t = time_ns()
 (h::TimeCostPerEpisode)(::PostEpisodeStage, policy, env) = push!(h.time_costs, time_ns()-h.t)
 h = TimeCostPerEpisode()
-run(RandomPolicy(), CartPoleEnv(), StopAfterEpisode(10), h)
+run(RandomPolicy(), CartPoleEnv(), StopAfterNEpisodes(10), h)
 h.time_costs
 ```

@@ -97,7 +97,7 @@ policy = RandomPolicy()
 run(
     policy,
     CartPoleEnv(),
-    StopAfterEpisode(100),
+    StopAfterNEpisodes(100),
     DoEveryNEpisode(;n=10) do t, policy, env
         # In real world cases, the policy is usually wrapped in an Agent,
         # we need to extract the inner policy to run it in the *actor* mode.
@@ -107,7 +107,7 @@ run(
         # polluting the original env.

         hook = TotalRewardPerEpisode(;is_display_on_exit=false)
-        run(policy, CartPoleEnv(), StopAfterEpisode(10), hook)
+        run(policy, CartPoleEnv(), StopAfterNEpisodes(10), hook)

         # now you can report the result of the hook.
         println("avg reward at episode $t is: $(mean(hook.rewards))")
@@ -159,7 +159,7 @@ parameters_dir = mktempdir()
 run(
     policy,
     env,
-    StopAfterStep(10_000),
+    StopAfterNSteps(10_000),
     DoEveryNStep(n=1_000) do t, p, e
         ps = params(p)
         f = joinpath(parameters_dir, "parameters_at_step_$t.bson")
@@ -192,7 +192,7 @@ hook = ComposedHook(
         end
     end
 )
-run(RandomPolicy(), CartPoleEnv(), StopAfterEpisode(50), hook)
+run(RandomPolicy(), CartPoleEnv(), StopAfterNEpisodes(50), hook)
 readdir(tf_log_dir)
 ```

docs/src/How_to_write_a_customized_environment.md

+3 −3

@@ -117,7 +117,7 @@ ReinforcementLearning.jl also work. Similar to the test above, let's try the
 [`RandomPolicy`](@ref) first:

 ```@repl customized_env
-run(RandomPolicy(action_space(env)), env, StopAfterEpisode(1_000))
+run(RandomPolicy(action_space(env)), env, StopAfterNEpisodes(1_000))
 ```

 If no error shows up, then it means our environment at least works with
@@ -126,7 +126,7 @@ episode to see the performance of the `RandomPolicy`.

 ```@repl customized_env
 hook = TotalRewardPerEpisode()
-run(RandomPolicy(action_space(env)), env, StopAfterEpisode(1_000), hook)
+run(RandomPolicy(action_space(env)), env, StopAfterNEpisodes(1_000), hook)
 using Plots
 pyplot() #hide
 plot(hook.rewards)
@@ -198,7 +198,7 @@ Nice job! Now we are ready to run the experiment:

 ```@repl customized_env
 h = TotalRewardPerEpisode()
-run(p, wrapped_env, StopAfterEpisode(1_000), h)
+run(p, wrapped_env, StopAfterNEpisodes(1_000), h)
 plot(h.rewards)
 savefig("custom_env_random_policy_reward_wrapped_env.svg"); nothing # hide
 ```

docs/src/non_episodic.md

+2 −2

@@ -9,7 +9,7 @@ Using this means that the value of the terminal state is set to 0 when learning

 Also called _Continuing tasks_ (Sutton & Barto, 2018), non-episodic environment do not have a terminal state and thus may run for ever, or until the `stop_condition` is reached. Sometimes however, one may want to periodically reset the environment to start fresh. A first possibility is to implement `RLBase.is_terminated(::YourEnvironment)` to reset according to an arbitrary condition. However this may not be a good idea because the value of the last state (note that it is not a _terminal_ state) will be bootstrapped to 0 during learning, even though it is not the true value of the state.

-To manage this, we provide the `ResetAfterNSteps(n)` condition as an argument to `run(policy, env, stop_condition, hook, reset_condition = ResetAtTerminal())`. The default `ResetAtTerminal()` assumes an episodic environment, changing that to `ResetAfterNSteps(n)` will no longer check `is_terminated` but will instead call `reset!` every `n` steps. This way, the value of the last state will not be multiplied by 0 during bootstrapping and the correct value can be learned.
+To manage this, we provide the `ResetAfterNSteps(n)` condition as an argument to `run(policy, env, stop_condition, hook, reset_condition = ResetIfEnvTerminated())`. The default `ResetIfEnvTerminated()` assumes an episodic environment, changing that to `ResetAfterNSteps(n)` will no longer check `is_terminated` but will instead call `reset!` every `n` steps. This way, the value of the last state will not be multiplied by 0 during bootstrapping and the correct value can be learned.

 ## Custom reset conditions

@@ -39,7 +39,7 @@ end
 run(agent, env, stop_condition, hook, MyCondition(ResetAfterNSteps(10000)))
 ```

-A last possibility is to use an anonymous function. This approach cannot be used to implement stateful conditions (such as `ResetAfterNSteps`). For example here is alternative way to implement `ResetAtTerminal`:
+A last possibility is to use an anonymous function. This approach cannot be used to implement stateful conditions (such as `ResetAfterNSteps`). For example here is alternative way to implement `ResetIfEnvTerminated`:

 ```julia
 run(agent, env, stop_condition, hook, (p,e) -> is_terminated(e))
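
The two changed lines in this file are pure renames; the call shape stays the same. For concreteness, a short usage sketch of the renamed reset conditions, assuming the `agent`, `env`, `stop_condition`, and `hook` defined earlier on that page:

```julia
# Explicit form of the (renamed) default reset condition:
run(agent, env, stop_condition, hook, ResetIfEnvTerminated())

# Reset every 10_000 steps instead of at termination, as the page describes:
run(agent, env, stop_condition, hook, ResetAfterNSteps(10_000))
```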

docs/src/tutorial.md

+4 −4

@@ -43,7 +43,7 @@ a descriptive pattern.
 run(
     RandomPolicy(),
     RandomWalk1D(),
-    StopAfterEpisode(10),
+    StopAfterNEpisodes(10),
     TotalRewardPerEpisode()
 )
 ```
@@ -58,7 +58,7 @@ policy = TabularPolicy(;table=Dict(zip(1:NS, fill(2, NS))))
 run(
     policy,
     RandomWalk1D(),
-    StopAfterEpisode(10),
+    StopAfterNEpisodes(10),
     TotalRewardPerEpisode()
 )
 ```
@@ -91,7 +91,7 @@ this policy to the `env` to estimate its performance.
 run(
     policy,
     RandomWalk1D(),
-    StopAfterEpisode(10),
+    StopAfterNEpisodes(10),
     TotalRewardPerEpisode()
 )
 ```
@@ -109,7 +109,7 @@ agent = Agent(
     policy = policy,
     trajectory = VectorSARTTrajectory()
 )
-run(agent, env, StopAfterEpisode(10), TotalRewardPerEpisode())
+run(agent, env, StopAfterNEpisodes(10), TotalRewardPerEpisode())
 ```

 Here the [`VectorSARTTrajectory`](@ref) is used to store the **S**tate,

src/ReinforcementLearningCore/src/core/reset_conditions.jl

+5 −5

@@ -1,15 +1,15 @@
-export AbstractResetCondition, ResetAtTerminal, ResetAfterNSteps
+export AbstractResetCondition, ResetIfEnvTerminated, ResetAfterNSteps

 abstract type AbstractResetCondition end

 """
-    ResetAtTerminal()
+    ResetIfEnvTerminated()

 A reset condition that resets the environment if is_terminated(env) is true.
 """
-struct ResetAtTerminal <: AbstractResetCondition end
+struct ResetIfEnvTerminated <: AbstractResetCondition end

-(::ResetAtTerminal)(policy, env) = is_terminated(env)
+check!(::ResetIfEnvTerminated, policy::AbstractPolicy, env::AbstractEnv) = is_terminated(env)

 """
     ResetAfterNSteps(n)
@@ -23,7 +23,7 @@ end

 ResetAfterNSteps(n::Int) = ResetAfterNSteps(0, n)

-function (r::ResetAfterNSteps)(policy, env)
+function check!(r::ResetAfterNSteps, policy::AbstractPolicy, env::AbstractEnv)
     stop = r.t >= r.n
     r.t += 1
     if stop
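
Downstream reset conditions that were written as callable structs need the same migration as `ResetIfEnvTerminated` and `ResetAfterNSteps` above: subtype `AbstractResetCondition` and extend `check!`. A rough sketch under stated assumptions — the `ResetIfTerminatedOrTimeout` type and its wall-clock budget are hypothetical, and qualifying `check!` through `ReinforcementLearningCore` assumes the function is not exported:

```julia
using ReinforcementLearning
import ReinforcementLearningCore

# Hypothetical reset condition: reset at termination or after a time budget per episode.
mutable struct ResetIfTerminatedOrTimeout <: ReinforcementLearningCore.AbstractResetCondition
    budget_s::Float64   # seconds allowed per episode
    start::Float64      # time the current episode started
end
ResetIfTerminatedOrTimeout(budget_s) = ResetIfTerminatedOrTimeout(budget_s, time())

# Old style (pre-commit): (c::ResetIfTerminatedOrTimeout)(policy, env) = ...
# New style: extend `check!`, keeping the state update inside it.
function ReinforcementLearningCore.check!(c::ResetIfTerminatedOrTimeout, policy, env)
    done = is_terminated(env) || (time() - c.start) >= c.budget_s
    done && (c.start = time())   # restart the clock for the next episode
    return done
end

# run(policy, env, StopAfterNSteps(10_000), TotalRewardPerEpisode(), ResetIfTerminatedOrTimeout(5.0))
```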

src/ReinforcementLearningCore/src/core/run.jl

+4 −4

@@ -17,9 +17,9 @@ end
 function Base.run(
     policy::AbstractPolicy,
     env::AbstractEnv,
-    stop_condition::AbstractStopCondition=StopAfterEpisode(1),
+    stop_condition::AbstractStopCondition=StopAfterNEpisodes(1),
     hook::AbstractHook=EmptyHook(),
-    reset_condition::AbstractResetCondition=ResetAtTerminal()
+    reset_condition::AbstractResetCondition=ResetIfEnvTerminated()
 )
     policy, env = check(policy, env)
     _run(policy, env, stop_condition, hook, reset_condition)
@@ -44,7 +44,7 @@ function _run(policy::AbstractPolicy,
     @timeit_debug timer "push!(hook) PreEpisodeStage" push!(hook, PreEpisodeStage(), policy, env)


-    while !reset_condition(policy, env) # one episode
+    while !check!(reset_condition, policy, env) # one episode
         @timeit_debug timer "push!(policy) PreActStage" push!(policy, PreActStage(), env)
         @timeit_debug timer "optimise! PreActStage" optimise!(policy, PreActStage())
         @timeit_debug timer "push!(hook) PreActStage" push!(hook, PreActStage(), policy, env)
@@ -56,7 +56,7 @@ function _run(policy::AbstractPolicy,
         @timeit_debug timer "optimise! PostActStage" optimise!(policy, PostActStage())
         @timeit_debug timer "push!(hook) PostActStage" push!(hook, PostActStage(), policy, env)

-        if check_stop(stop_condition, policy, env)
+        if check!(stop_condition, policy, env)
            is_stop = true
            break
        end
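
The loop above now routes stop conditions through the same `check!` verb that reset conditions use, replacing the old `check_stop`. A sketch of what a custom stop condition might look like against this signature — the `StopWhenRewardReaches` type, its threshold, and the module qualification of `check!` and `AbstractStopCondition` are assumptions for illustration, not part of this commit:

```julia
using ReinforcementLearning
import ReinforcementLearningCore

# Hypothetical stop condition: stop the experiment once a single step's reward hits a target.
struct StopWhenRewardReaches <: ReinforcementLearningCore.AbstractStopCondition
    target::Float64
end

# Pre-commit the run loop called `check_stop(stop_condition, policy, env)`;
# post-commit it calls `check!(stop_condition, policy, env)` as shown above.
function ReinforcementLearningCore.check!(s::StopWhenRewardReaches, policy, env)
    reward(env) >= s.target
end

# Usage mirroring the signature of `Base.run` above:
# run(policy, env, StopWhenRewardReaches(195.0), TotalRewardPerEpisode(), ResetIfEnvTerminated())
```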
