Fix reset and stop conditions #1046

Merged
merged 13 commits on Mar 14, 2024
4 changes: 2 additions & 2 deletions README.md
@@ -45,7 +45,7 @@ julia> using ReinforcementLearning
julia> run(
RandomPolicy(),
CartPoleEnv(),
StopAfterStep(1_000),
StopAfterNSteps(1_000),
TotalRewardPerEpisode()
)
```
@@ -66,7 +66,7 @@ reinforcement learning experiment:
to test reinforcement learning algorithms.

- **Stop Condition**. The
[`StopAfterStep(1_000)`](https://juliareinforcementlearning.org/docs/rlcore/#ReinforcementLearningCore.StopAfterStep)
[`StopAfterNSteps(1_000)`](https://juliareinforcementlearning.org/docs/rlcore/#ReinforcementLearningCore.StopAfterNSteps)
is to inform that our experiment should stop after
`1_000` steps.

20 changes: 10 additions & 10 deletions docs/homepage/blog/a_practical_introduction_to_RL.jl/index.html
@@ -13872,7 +13872,7 @@ <h2 id="Q2:-What-if-we-want-to-stop-after-several-episodes?"><strong>Q2: What if
<div class="prompt input_prompt">In&nbsp;[15]:</div>
<div class="inner_cell">
<div class="input_area">
<div class=" highlight hl-julia"><pre><span></span><span class="n">run</span><span class="p">(</span><span class="n">policy</span><span class="p">,</span> <span class="n">env</span><span class="p">,</span> <span class="n">StopAfterEpisode</span><span class="p">(</span><span class="mi">10</span><span class="p">))</span>
<div class=" highlight hl-julia"><pre><span></span><span class="n">run</span><span class="p">(</span><span class="n">policy</span><span class="p">,</span> <span class="n">env</span><span class="p">,</span> <span class="n">StopAfterNEpisodes</span><span class="p">(</span><span class="mi">10</span><span class="p">))</span>
</pre></div>

</div>
@@ -13906,8 +13906,8 @@ <h2 id="Q2:-What-if-we-want-to-stop-after-several-episodes?"><strong>Q2: What if
<h3 id="Q2.b:-What-if-we-want-to-stop-until-arbitrary-condition-meets?"><strong>Q2.b: What if we want to stop until arbitrary condition meets?</strong><a class="anchor-link" href="#Q2.b:-What-if-we-want-to-stop-until-arbitrary-condition-meets?">&#182;</a></h3><p>Well, in that case, you need to implement your customized <em>stop condition</em> here.
In RL.jl, several common ones are already provided, like:</p>
<ul>
<li><code>StopAfterStep</code></li>
<li><code>StopAfterEpisode</code></li>
<li><code>StopAfterNSteps</code></li>
<li><code>StopAfterNEpisodes</code></li>
<li><code>StopAfterNSeconds</code></li>
<li>...</li>
</ul>
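For readers following this rename, here is a minimal sketch of what such a customized stop condition could look like. The type name is hypothetical, and the `check!(condition, policy, env)` signature is an assumption inferred from the `run.jl` change further down in this PR rather than documented API:

```julia
# Hypothetical custom stop condition: stop as soon as the latest reward
# reaches a target value. Assumes stop conditions subtype `AbstractStopCondition`
# and are queried via `check!(condition, policy, env)` (see the run.jl diff below).
using ReinforcementLearning
import ReinforcementLearningCore: check!   # exact module path is an assumption

struct StopWhenRewardReached <: AbstractStopCondition
    target::Float64
end

check!(s::StopWhenRewardReached, policy::AbstractPolicy, env::AbstractEnv) =
    reward(env) >= s.target

# Usage sketch:
# run(RandomPolicy(), CartPoleEnv(), StopWhenRewardReached(1.0), TotalRewardPerEpisode())
```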
@@ -13936,7 +13936,7 @@ <h2 id="Q3:-How-to-collect-experiment-results?"><strong>Q3: How to collect exper
<div class="prompt input_prompt">In&nbsp;[16]:</div>
<div class="inner_cell">
<div class="input_area">
<div class=" highlight hl-julia"><pre><span></span><span class="n">run</span><span class="p">(</span><span class="n">policy</span><span class="p">,</span> <span class="n">env</span><span class="p">,</span> <span class="n">StopAfterEpisode</span><span class="p">(</span><span class="mi">10</span><span class="p">),</span> <span class="n">TotalRewardPerEpisode</span><span class="p">())</span>
<div class=" highlight hl-julia"><pre><span></span><span class="n">run</span><span class="p">(</span><span class="n">policy</span><span class="p">,</span> <span class="n">env</span><span class="p">,</span> <span class="n">StopAfterNEpisodes</span><span class="p">(</span><span class="mi">10</span><span class="p">),</span> <span class="n">TotalRewardPerEpisode</span><span class="p">())</span>
</pre></div>

</div>
@@ -14144,7 +14144,7 @@ <h2 id="The-Actor-Mode">The <em>Actor</em> Mode<a class="anchor-link" href="#The
<div class=" highlight hl-julia"><pre><span></span><span class="n">run</span><span class="p">(</span>
<span class="n">policy</span><span class="p">,</span>
<span class="n">RandomWalk1D</span><span class="p">(),</span>
<span class="n">StopAfterEpisode</span><span class="p">(</span><span class="mi">10</span><span class="p">),</span>
<span class="n">StopAfterNEpisodes</span><span class="p">(</span><span class="mi">10</span><span class="p">),</span>
<span class="n">TotalRewardPerEpisode</span><span class="p">()</span>
<span class="p">)</span>
</pre></div>
@@ -14311,7 +14311,7 @@ <h2 id="The-Training-Mode">The <em>Training</em> Mode<a class="anchor-link" href
<div class="prompt input_prompt">In&nbsp;[22]:</div>
<div class="inner_cell">
<div class="input_area">
<div class=" highlight hl-julia"><pre><span></span><span class="n">run</span><span class="p">(</span><span class="n">agent</span><span class="p">,</span> <span class="n">env</span><span class="p">,</span> <span class="n">StopAfterEpisode</span><span class="p">(</span><span class="mi">10</span><span class="p">),</span> <span class="n">TotalRewardPerEpisode</span><span class="p">())</span>
<div class=" highlight hl-julia"><pre><span></span><span class="n">run</span><span class="p">(</span><span class="n">agent</span><span class="p">,</span> <span class="n">env</span><span class="p">,</span> <span class="n">StopAfterNEpisodes</span><span class="p">(</span><span class="mi">10</span><span class="p">),</span> <span class="n">TotalRewardPerEpisode</span><span class="p">())</span>
</pre></div>

</div>
@@ -14383,7 +14383,7 @@ <h2 id="The-Training-Mode">The <em>Training</em> Mode<a class="anchor-link" href
<div class="inner_cell">
<div class="input_area">
<div class=" highlight hl-julia"><pre><span></span><span class="n">hook</span> <span class="o">=</span> <span class="n">StepsPerEpisode</span><span class="p">()</span>
<span class="n">run</span><span class="p">(</span><span class="n">agent</span><span class="p">,</span> <span class="n">env</span><span class="p">,</span> <span class="n">StopAfterEpisode</span><span class="p">(</span><span class="mi">10</span><span class="p">),</span> <span class="n">hook</span><span class="p">)</span>
<span class="n">run</span><span class="p">(</span><span class="n">agent</span><span class="p">,</span> <span class="n">env</span><span class="p">,</span> <span class="n">StopAfterNEpisodes</span><span class="p">(</span><span class="mi">10</span><span class="p">),</span> <span class="n">hook</span><span class="p">)</span>
<span class="n">plot</span><span class="p">(</span><span class="n">hook</span><span class="o">.</span><span class="n">steps</span><span class="p">[</span><span class="mi">1</span><span class="o">:</span><span class="k">end</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>
</pre></div>

@@ -14593,7 +14593,7 @@ <h3 id="Q4:-Why-does-it-need-more-than-3-steps-to-reach-our-goal?">Q4: Why does
<span class="n">explorer</span><span class="o">=</span><span class="n">GreedyExplorer</span><span class="p">()</span>
<span class="p">),</span>
<span class="n">env</span><span class="p">,</span>
<span class="n">StopAfterEpisode</span><span class="p">(</span><span class="mi">10</span><span class="p">),</span>
<span class="n">StopAfterNEpisodes</span><span class="p">(</span><span class="mi">10</span><span class="p">),</span>
<span class="n">hook</span>
<span class="p">)</span>
<span class="n">plot</span><span class="p">(</span><span class="n">hook</span><span class="o">.</span><span class="n">steps</span><span class="p">[</span><span class="mi">1</span><span class="o">:</span><span class="k">end</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span>
@@ -15158,7 +15158,7 @@ <h2 id="Two-Most-Commonly-Used-Algorithms">Two Most Commonly Used Algorithms<a c
<div class="prompt input_prompt">In&nbsp;[35]:</div>
<div class="inner_cell">
<div class="input_area">
<div class=" highlight hl-julia"><pre><span></span><span class="n">stop_condition</span> <span class="o">=</span> <span class="n">StopAfterStep</span><span class="p">(</span><span class="mi">10_000</span><span class="p">)</span>
<div class=" highlight hl-julia"><pre><span></span><span class="n">stop_condition</span> <span class="o">=</span> <span class="n">StopAfterNSteps</span><span class="p">(</span><span class="mi">10_000</span><span class="p">)</span>
<span class="n">hook</span> <span class="o">=</span> <span class="n">TotalRewardPerEpisode</span><span class="p">()</span>
<span class="n">run</span><span class="p">(</span><span class="n">policy</span><span class="p">,</span> <span class="n">env</span><span class="p">,</span> <span class="n">stop_condition</span><span class="p">,</span> <span class="n">hook</span><span class="p">)</span>
</pre></div>
@@ -15306,7 +15306,7 @@ <h2 id="Two-Most-Commonly-Used-Algorithms">Two Most Commonly Used Algorithms<a c
<div class="prompt input_prompt">In&nbsp;[38]:</div>
<div class="inner_cell">
<div class="input_area">
<div class=" highlight hl-julia"><pre><span></span><span class="n">stop_condition</span> <span class="o">=</span> <span class="n">StopAfterStep</span><span class="p">(</span><span class="mi">10_000</span><span class="p">)</span>
<div class=" highlight hl-julia"><pre><span></span><span class="n">stop_condition</span> <span class="o">=</span> <span class="n">StopAfterNSteps</span><span class="p">(</span><span class="mi">10_000</span><span class="p">)</span>
<span class="n">hook</span> <span class="o">=</span> <span class="n">TotalBatchRewardPerEpisode</span><span class="p">(</span><span class="n">N_ENV</span><span class="p">)</span>
<span class="n">run</span><span class="p">(</span><span class="n">agent</span><span class="p">,</span> <span class="n">env</span><span class="p">,</span> <span class="n">stop_condition</span><span class="p">,</span> <span class="n">hook</span><span class="p">)</span>
</pre></div>
4 changes: 2 additions & 2 deletions docs/homepage/blog/ospp_report_210370190/index.md
@@ -656,7 +656,7 @@ Besides, I implement the [`EDManager`](https://juliareinforcementlearning.org/do
function Base.run(
π::EDManager,
env::AbstractEnv,
stop_condition = StopAfterEpisode(1),
stop_condition = StopAfterNEpisodes(1),
hook::AbstractHook = EmptyHook(),
)
@assert NumAgentStyle(env) == MultiAgent(2) "ED algorithm only support 2-players games."
@@ -757,7 +757,7 @@ EDmanager = EDManager(
)
)
# initialize the `stop_condition` and `hook`.
stop_condition = StopAfterEpisode(100_000, is_show_progress=!haskey(ENV, "CI"))
stop_condition = StopAfterNEpisodes(100_000, is_show_progress=!haskey(ENV, "CI"))
hook = KuhnOpenNewEDHook(0, 100, [], [])
```

10 changes: 5 additions & 5 deletions docs/src/How_to_use_hooks.md
@@ -68,7 +68,7 @@ end
(h::TimeCostPerEpisode)(::PreEpisodeStage, policy, env) = h.t = time_ns()
(h::TimeCostPerEpisode)(::PostEpisodeStage, policy, env) = push!(h.time_costs, time_ns()-h.t)
h = TimeCostPerEpisode()
run(RandomPolicy(), CartPoleEnv(), StopAfterEpisode(10), h)
run(RandomPolicy(), CartPoleEnv(), StopAfterNEpisodes(10), h)
h.time_costs
```

@@ -97,7 +97,7 @@ policy = RandomPolicy()
run(
policy,
CartPoleEnv(),
StopAfterEpisode(100),
StopAfterNEpisodes(100),
DoEveryNEpisode(;n=10) do t, policy, env
# In real world cases, the policy is usually wrapped in an Agent,
# we need to extract the inner policy to run it in the *actor* mode.
@@ -107,7 +107,7 @@ run(
# polluting the original env.

hook = TotalRewardPerEpisode(;is_display_on_exit=false)
run(policy, CartPoleEnv(), StopAfterEpisode(10), hook)
run(policy, CartPoleEnv(), StopAfterNEpisodes(10), hook)

# now you can report the result of the hook.
println("avg reward at episode $t is: $(mean(hook.rewards))")
@@ -159,7 +159,7 @@ parameters_dir = mktempdir()
run(
policy,
env,
StopAfterStep(10_000),
StopAfterNSteps(10_000),
DoEveryNStep(n=1_000) do t, p, e
ps = params(p)
f = joinpath(parameters_dir, "parameters_at_step_$t.bson")
@@ -192,7 +192,7 @@ hook = ComposedHook(
end
end
)
run(RandomPolicy(), CartPoleEnv(), StopAfterEpisode(50), hook)
run(RandomPolicy(), CartPoleEnv(), StopAfterNEpisodes(50), hook)
readdir(tf_log_dir)
```

6 changes: 3 additions & 3 deletions docs/src/How_to_write_a_customized_environment.md
@@ -117,7 +117,7 @@ ReinforcementLearning.jl also work. Similar to the test above, let's try the
[`RandomPolicy`](@ref) first:

```@repl customized_env
run(RandomPolicy(action_space(env)), env, StopAfterEpisode(1_000))
run(RandomPolicy(action_space(env)), env, StopAfterNEpisodes(1_000))
```

If no error shows up, then it means our environment at least works with
@@ -126,7 +126,7 @@ episode to see the performance of the `RandomPolicy`.

```@repl customized_env
hook = TotalRewardPerEpisode()
run(RandomPolicy(action_space(env)), env, StopAfterEpisode(1_000), hook)
run(RandomPolicy(action_space(env)), env, StopAfterNEpisodes(1_000), hook)
using Plots
pyplot() #hide
plot(hook.rewards)
@@ -198,7 +198,7 @@ Nice job! Now we are ready to run the experiment:

```@repl customized_env
h = TotalRewardPerEpisode()
run(p, wrapped_env, StopAfterEpisode(1_000), h)
run(p, wrapped_env, StopAfterNEpisodes(1_000), h)
plot(h.rewards)
savefig("custom_env_random_policy_reward_wrapped_env.svg"); nothing # hide
```
4 changes: 2 additions & 2 deletions docs/src/non_episodic.md
@@ -9,7 +9,7 @@ Using this means that the value of the terminal state is set to 0 when learning

Also called _Continuing tasks_ (Sutton & Barto, 2018), non-episodic environments do not have a terminal state and thus may run forever, or until the `stop_condition` is reached. Sometimes, however, one may want to periodically reset the environment to start fresh. A first possibility is to implement `RLBase.is_terminated(::YourEnvironment)` to reset according to an arbitrary condition. However, this may not be a good idea because the value of the last state (note that it is not a _terminal_ state) will be bootstrapped to 0 during learning, even though it is not the true value of the state.

To manage this, we provide the `ResetAfterNSteps(n)` condition as an argument to `run(policy, env, stop_condition, hook, reset_condition = ResetAtTerminal())`. The default `ResetAtTerminal()` assumes an episodic environment, changing that to `ResetAfterNSteps(n)` will no longer check `is_terminated` but will instead call `reset!` every `n` steps. This way, the value of the last state will not be multiplied by 0 during bootstrapping and the correct value can be learned.
To manage this, we provide the `ResetAfterNSteps(n)` condition as an argument to `run(policy, env, stop_condition, hook, reset_condition = ResetIfEnvTerminated())`. The default `ResetIfEnvTerminated()` assumes an episodic environment; changing that to `ResetAfterNSteps(n)` will no longer check `is_terminated` but will instead call `reset!` every `n` steps. This way, the value of the last state will not be multiplied by 0 during bootstrapping and the correct value can be learned.

## Custom reset conditions

@@ -39,7 +39,7 @@ end
run(agent, env, stop_condition, hook, MyCondition(ResetAfterNSteps(10000)))
```

A last possibility is to use an anonymous function. This approach cannot be used to implement stateful conditions (such as `ResetAfterNSteps`). For example here is alternative way to implement `ResetAtTerminal`:
A last possibility is to use an anonymous function. This approach cannot be used to implement stateful conditions (such as `ResetAfterNSteps`). For example, here is an alternative way to implement `ResetIfEnvTerminated`:

```julia
run(agent, env, stop_condition, hook, (p,e) -> is_terminated(e))
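As a complement to the snippets above, here is a minimal sketch of a stateful custom reset condition written against the `check!`-based interface introduced by this PR (see `reset_conditions.jl` below). The type name and reset rule are made up for illustration:

```julia
# Hypothetical reset condition: reset when the environment terminates OR when a
# per-episode step budget is spent, combining the behaviour of ResetIfEnvTerminated
# and ResetAfterNSteps. Mirrors the mutable-struct + check! pattern in
# reset_conditions.jl; treat the import path as an assumption.
using ReinforcementLearning
import ReinforcementLearningCore: check!

mutable struct ResetIfTerminatedOrBudgetSpent <: AbstractResetCondition
    budget::Int
    t::Int
end

ResetIfTerminatedOrBudgetSpent(budget::Int) = ResetIfTerminatedOrBudgetSpent(budget, 0)

function check!(c::ResetIfTerminatedOrBudgetSpent, policy::AbstractPolicy, env::AbstractEnv)
    c.t += 1
    should_reset = is_terminated(env) || c.t >= c.budget
    should_reset && (c.t = 0)   # start counting afresh after each reset
    return should_reset
end

# Usage sketch:
# run(agent, env, stop_condition, hook, ResetIfTerminatedOrBudgetSpent(10_000))
```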
8 changes: 4 additions & 4 deletions docs/src/tutorial.md
@@ -43,7 +43,7 @@ a descriptive pattern.
run(
RandomPolicy(),
RandomWalk1D(),
StopAfterEpisode(10),
StopAfterNEpisodes(10),
TotalRewardPerEpisode()
)
```
@@ -58,7 +58,7 @@ policy = TabularPolicy(;table=Dict(zip(1:NS, fill(2, NS))))
run(
policy,
RandomWalk1D(),
StopAfterEpisode(10),
StopAfterNEpisodes(10),
TotalRewardPerEpisode()
)
```
@@ -91,7 +91,7 @@ this policy to the `env` to estimate its performance.
run(
policy,
RandomWalk1D(),
StopAfterEpisode(10),
StopAfterNEpisodes(10),
TotalRewardPerEpisode()
)
```
@@ -109,7 +109,7 @@ agent = Agent(
policy = policy,
trajectory = VectorSARTTrajectory()
)
run(agent, env, StopAfterEpisode(10), TotalRewardPerEpisode())
run(agent, env, StopAfterNEpisodes(10), TotalRewardPerEpisode())
```

Here the [`VectorSARTTrajectory`](@ref) is used to store the **S**tate,
10 changes: 5 additions & 5 deletions src/ReinforcementLearningCore/src/core/reset_conditions.jl
@@ -1,15 +1,15 @@
export AbstractResetCondition, ResetAtTerminal, ResetAfterNSteps
export AbstractResetCondition, ResetIfEnvTerminated, ResetAfterNSteps

abstract type AbstractResetCondition end

"""
ResetAtTerminal()
ResetIfEnvTerminated()

A reset condition that resets the environment if is_terminated(env) is true.
"""
struct ResetAtTerminal <: AbstractResetCondition end
struct ResetIfEnvTerminated <: AbstractResetCondition end

(::ResetAtTerminal)(policy, env) = is_terminated(env)
check!(::ResetIfEnvTerminated, policy::AbstractPolicy, env::AbstractEnv) = is_terminated(env)

"""
ResetAfterNSteps(n)
@@ -23,7 +23,7 @@ end

ResetAfterNSteps(n::Int) = ResetAfterNSteps(0, n)

function (r::ResetAfterNSteps)(policy, env)
function check!(r::ResetAfterNSteps, policy::AbstractPolicy, env::AbstractEnv)
stop = r.t >= r.n
r.t += 1
if stop
8 changes: 4 additions & 4 deletions src/ReinforcementLearningCore/src/core/run.jl
@@ -17,9 +17,9 @@ end
function Base.run(
policy::AbstractPolicy,
env::AbstractEnv,
stop_condition::AbstractStopCondition=StopAfterEpisode(1),
stop_condition::AbstractStopCondition=StopAfterNEpisodes(1),
hook::AbstractHook=EmptyHook(),
reset_condition::AbstractResetCondition=ResetAtTerminal()
reset_condition::AbstractResetCondition=ResetIfEnvTerminated()
)
policy, env = check(policy, env)
_run(policy, env, stop_condition, hook, reset_condition)
@@ -44,7 +44,7 @@ function _run(policy::AbstractPolicy,
@timeit_debug timer "push!(hook) PreEpisodeStage" push!(hook, PreEpisodeStage(), policy, env)


while !reset_condition(policy, env) # one episode
while !check!(reset_condition, policy, env) # one episode
@timeit_debug timer "push!(policy) PreActStage" push!(policy, PreActStage(), env)
@timeit_debug timer "optimise! PreActStage" optimise!(policy, PreActStage())
@timeit_debug timer "push!(hook) PreActStage" push!(hook, PreActStage(), policy, env)
@@ -56,7 +56,7 @@ function _run(policy::AbstractPolicy,
@timeit_debug timer "optimise! PostActStage" optimise!(policy, PostActStage())
@timeit_debug timer "push!(hook) PostActStage" push!(hook, PostActStage(), policy, env)

if check_stop(stop_condition, policy, env)
if check!(stop_condition, policy, env)
is_stop = true
break
end
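Given the defaults in the updated `Base.run` signature above, a bare call appears equivalent to spelling every argument out; a small sketch for illustration:

```julia
# These two calls should behave the same under the defaults shown in the new
# Base.run signature: StopAfterNEpisodes(1), EmptyHook(), ResetIfEnvTerminated().
using ReinforcementLearning

run(RandomPolicy(), CartPoleEnv())

run(
    RandomPolicy(),
    CartPoleEnv(),
    StopAfterNEpisodes(1),     # default stop condition
    EmptyHook(),               # default hook
    ResetIfEnvTerminated(),    # default reset condition
)
```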