TargetNetwork #966

Merged
merged 30 commits on Sep 14, 2023
17 changes: 2 additions & 15 deletions docs/src/How_to_implement_a_new_algorithm.md
@@ -114,21 +114,8 @@ The sampler is the object that will fetch data in your trajectory to create the

## Using resources from RLCore

RL algorithms typically differ only in parts but broadly rely on the same mechanisms. The subpackage RLCore contains some utilities that you can reuse to implement your algorithm.

### QBasedPolicy

`QBasedPolicy` is a policy that wraps a Q-Value _learner_ (tabular or approximated) and an _explorer_. Use this wrapper to implement a policy that directly uses a Q-value function to
decide its next action. In that case, instead of creating an `AbstractPolicy` subtype for your algorithm, define an `AbstractLearner` subtype and specialize `RLBase.optimise!(::YourLearnerType, ::Stage, ::Trajectory)`. This way you will not have to code the interaction between your policy and the explorer yourself.
RLCore provides the most common explorers (such as epsilon-greedy, UCB, etc.).

### Neural and linear approximators

If your algorithm uses a neural network or a linear approximator (trained with `Flux.jl`) to approximate a function, use the `Approximator`. An `Approximator`
wraps a `Flux` model and an `Optimiser` (such as Adam or SGD). Your `optimise!(::PolicyOrLearner, batch)` function will most likely consist of computing a gradient
and then calling `RLCore.optimise!(app::Approximator, gradient::Flux.Grads)`.

Common model architectures are also provided, such as `GaussianNetwork` for continuous policies with a diagonal multivariate Gaussian distribution, and `CovGaussianNetwork` for a full covariance matrix (currently very slow on GPUs).
RL algorithms typically differ only in parts but broadly rely on the same mechanisms. The subpackage RLCore contains some modules that you can reuse to implement your algorithm.
These will take care of many aspects of training for you. See the [RLCore manual](./rlcore.md).

### Utils
In `utils/distributions.jl` you will find implementations of Gaussian log-probability functions that are GPU-compatible and differentiable, and that avoid the overhead of Distributions.jl structs.
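
For instance, here is a minimal sketch of how such a helper can be used. It assumes a `normlogpdf(μ, σ, x)` function defined in `utils/distributions.jl`; check that file for the exact names and keyword arguments.

```julia
using ReinforcementLearningCore: normlogpdf

μ = [0.0f0, 1.0f0]   # per-action means
σ = [1.0f0, 0.5f0]   # per-action standard deviations
x = [0.5f0, 1.5f0]   # actions whose log-density we want

# Element-wise Gaussian log-densities; the broadcasted formula also works on
# GPU arrays and inside Zygote gradients.
logp = normlogpdf(μ, σ, x)
```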
29 changes: 28 additions & 1 deletion docs/src/rlcore.md
@@ -2,4 +2,31 @@

```@autodocs
Modules = [ReinforcementLearningCore]
```
```

In addition to containing the [run loop](./How_to_implement_a_new_algorithm.md), RLCore is a collection of pre-implemented components that are frequently used in RL.

## QBasedPolicy

`QBasedPolicy` is an `AbstractPolicy` that wraps a Q-Value _learner_ (tabular or approximated) and an _explorer_. Use this wrapper to implement a policy that directly uses a Q-value function to
decide its next action. In that case, instead of creating an `AbstractPolicy` subtype for your algorithm, define an `AbstractLearner` subtype and specialize `RLBase.optimise!(::YourLearnerType, ::Stage, ::Trajectory)`. This way you will not have to code the interaction between your policy and the explorer yourself.
RLCore provides the most common explorers (such as epsilon-greedy, UCB, etc.). You can find many examples of QBasedPolicies in the DQNs section of RLZoo.
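
As a rough sketch of how the pieces fit together (the `MyQLearner` type and all sizes below are hypothetical; a real learner would also specialize `RLCore.forward` and `RLBase.optimise!` for its type):

```julia
using Flux
using ReinforcementLearningCore

# Hypothetical learner holding a single Q-network.
struct MyQLearner <: AbstractLearner
    approximator::Approximator
end

policy = QBasedPolicy(
    learner = MyQLearner(Approximator(model = Chain(Dense(4, 2)), optimiser = Adam())),
    explorer = EpsilonGreedyExplorer(kind = :exp, ϵ_stable = 0.01, decay_steps = 500),
)
```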

## Parametric approximators
### Approximator

If your algorithm uses a neural network or a linear approximator (trained with `Flux.jl`) to approximate a function, use the `Approximator`. It
wraps a `Flux` model and an `Optimiser` (such as Adam or SGD). Your `optimise!(::PolicyOrLearner, batch)` function will most likely consist of computing a gradient
and then calling `RLBase.optimise!(app::Approximator, gradient::Flux.Grads)`.
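
A minimal sketch of that workflow, with a toy model and loss for illustration:

```julia
using Flux
import ReinforcementLearningCore as RLCore

app = RLCore.Approximator(model = Chain(Dense(4, 1)), optimiser = Adam())

# Differentiate a toy loss with respect to the wrapped model's parameters...
gs = Flux.Zygote.gradient(Flux.params(app)) do
    sum(RLCore.forward(app, ones(Float32, 4)))
end

# ...then let the Approximator apply the optimiser step.
RLCore.optimise!(app, gs)
```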

`Approximator` implements the `model(::Approximator)` and `target(::Approximator)` interface. Both return the underlying Flux model. The advantage of this interface is explained in the `TargetNetwork` section below.

### TargetNetwork

The use of a target network is frequent in state- or action-value-based RL. The principle is to train the main approximator with gradients while holding a copy of it, the target, that is either only partially updated or updated less frequently. `TargetNetwork` is constructed by wrapping an `Approximator`. Set the `sync_freq` keyword argument to a value greater than one to copy the main model into the target every `sync_freq` updates, or set the `ρ` (rho) parameter to a value greater than 0 (usually 0.99f0) to let the target be partially updated towards the main model at every update. `RLBase.optimise!(tn::TargetNetwork, gradient::Flux.Grads)` will take care of updating the target for you.
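
For example, a quick sketch of the two usual configurations (model and values are illustrative):

```julia
using Flux
using ReinforcementLearningCore

# Hard sync: copy the main model into the target every 100 calls to optimise!.
tn_hard = TargetNetwork(
    Approximator(model = Chain(Dense(4, 1)), optimiser = Adam()),
    sync_freq = 100,
)

# Soft sync: nudge the target towards the main model at every update (Polyak averaging).
tn_soft = TargetNetwork(
    Approximator(model = Chain(Dense(4, 1)), optimiser = Adam()),
    ρ = 0.99f0,
)
```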

The other advantage of `TargetNetwork` is that it uses Julia's multiple dispatch to let your algorithm be agnostic to the presence or absence of a target network. For example, the `DQNLearner` in RLZoo has an `approximator` field typed as `Union{Approximator, TargetNetwork}`. When computing the temporal-difference error, the learner calls `Q = model(learner.approximator)` and `Qt = target(learner.approximator)`. If `learner.approximator` is an `Approximator`, then no target network is used because both calls return the same neural network; if it is a `TargetNetwork`, then the automatically managed target is returned.
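
A minimal sketch of that pattern (the helper below is a simplified stand-in for the actual `DQNLearner` code, with terminal-state handling omitted):

```julia
using ReinforcementLearningCore   # brings `model` and `target` into scope

# `approx` may be an Approximator or a TargetNetwork; the same code handles both.
function one_step_td_target(approx, r, s′, γ)
    Qt = target(approx)                       # managed copy, or the main model itself
    r .+ γ .* vec(maximum(Qt(s′), dims = 1))  # bootstrap with the greedy action value
end

# The prediction side always uses the trainable network: model(approx)(s)
```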

## Architectures

Common model architectures are also provided, such as `GaussianNetwork` for continuous policies with a diagonal multivariate Gaussian distribution, and `CovGaussianNetwork` for a full covariance matrix (currently very slow on GPUs).
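
Below is a very rough sketch of how such an architecture is typically assembled. It assumes `GaussianNetwork`'s keyword constructor with `pre`, `μ`, and `σ` sub-networks; check the autodocs above for the exact fields and call conventions.

```julia
using Flux
using ReinforcementLearningCore

# Shared trunk plus separate heads for the mean and standard deviation
# (layer sizes are illustrative).
policy_net = GaussianNetwork(
    pre = Chain(Dense(4, 64, relu)),
    μ   = Dense(64, 2),
    σ   = Dense(64, 2, softplus),
)

# Calling the network on a batch of states yields the parameters of the
# Gaussian policy (or samples, depending on keyword arguments; see the autodocs).
states = rand(Float32, 4, 32)
out = policy_net(states)
```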
4 changes: 2 additions & 2 deletions src/ReinforcementLearningCore/Project.toml
@@ -1,6 +1,6 @@
name = "ReinforcementLearningCore"
uuid = "de1b191a-4ae0-4afa-a27b-92d07f46b2d6"
version = "0.13.0"
version = "0.14.0"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
@@ -43,7 +43,7 @@ Parsers = "2"
ProgressMeter = "1"
Reexport = "1"
ReinforcementLearningBase = "0.12"
ReinforcementLearningTrajectories = "^0.3.3"
ReinforcementLearningTrajectories = "0.3.4"
StatsBase = "0.32, 0.33, 0.34"
TimerOutputs = "0.5"
UnicodePlots = "1.3, 2, 3"
81 changes: 81 additions & 0 deletions src/ReinforcementLearningCore/src/policies/approximator.jl
@@ -0,0 +1,81 @@
export Approximator, TargetNetwork, target, model

using Flux


"""
Approximator(model, optimiser)

Wraps a Flux trainable model and implements the `RLBase.optimise!(::Approximator, ::Gradient)`
interface. See the RLCore documentation for more information on proper usage.
"""
Base.@kwdef mutable struct Approximator{M,O}
    model::M
    optimiser::O
end

Base.show(io::IO, m::MIME"text/plain", A::Approximator) = show(io, m, convert(AnnotatedStructTree, A))

@functor Approximator (model,)

forward(A::Approximator, args...; kwargs...) = A.model(args...; kwargs...)

RLBase.optimise!(A::Approximator, gs) = Flux.Optimise.update!(A.optimiser, Flux.params(A), gs)

target(ap::Approximator) = ap.model #see TargetNetwork
model(ap::Approximator) = ap.model #see TargetNetwork

"""
TargetNetwork(network::Approximator; sync_freq::Int = 1, ρ::Float32 = 0f0)

Wraps an Approximator to hold a target network that is updated towards the model of the
approximator.
- `sync_freq` is the number of updates of `network` between each update of the `target`.
- `ρ` (\rho) is the fraction of the target that is kept when updating it.

The two common usages of TargetNetwork are
- use `ρ = 0` to fully replace `target` with `network` every `sync_freq` updates.
- use `ρ < 1` (but close to one) and `sync_freq = 1` to let the target follow `network` with Polyak averaging.

Implements the `RLBase.optimise!(::TargetNetwork, ::Gradient)` interface to update the model with the gradient
and the target with weights replacement or Polyak averaging.

Note to developers: `model(::TargetNetwork)` returns the trainable Flux model
and `target(::TargetNetwork)` returns the target model. `target(::Approximator)`
simply returns the wrapped model, so the same learner code works with or without
a target network. See the RLCore documentation.
"""
mutable struct TargetNetwork{M}
    network::Approximator{M}
    target::M
    sync_freq::Int
    ρ::Float32
    n_optimise::Int
end

function TargetNetwork(x; sync_freq = 1, ρ = 0f0)
    @assert 0 <= ρ <= 1 "ρ must be in [0,1]"
    TargetNetwork(x, deepcopy(x.model), sync_freq, ρ, 0)
end

@functor TargetNetwork (network, target)

Flux.trainable(model::TargetNetwork) = (model.network,)

forward(tn::TargetNetwork, args...) = forward(tn.network, args...)

model(tn::TargetNetwork) = model(tn.network)
target(tn::TargetNetwork) = tn.target

function RLBase.optimise!(tn::TargetNetwork, gs)
    A = tn.network
    Flux.Optimise.update!(A.optimiser, Flux.params(A), gs)
    tn.n_optimise += 1

    if tn.n_optimise % tn.sync_freq == 0
        # Polyak averaging; reduces to a hard copy of the weights when ρ == 0
        for (dest, src) in zip(Flux.params(target(tn)), Flux.params(tn.network))
            dest .= tn.ρ .* dest .+ (1 - tn.ρ) .* src
        end
        tn.n_optimise = 0
    end
end
@@ -3,8 +3,7 @@ export AbstractExplorer
using FillArrays: Trues

"""
RLBase.plan!(p::AbstractExplorer, x)
RLBase.plan!(p::AbstractExplorer, x, mask)
RLBase.plan!(p::AbstractExplorer, x[, mask])

Define how to select an action based on action values.
"""
31 changes: 15 additions & 16 deletions src/ReinforcementLearningCore/src/policies/learners.jl
@@ -1,27 +1,26 @@
export AbstractLearner, Approximator
export AbstractLearner

import Flux
using Functors: @functor

abstract type AbstractLearner end

Base.show(io::IO, m::MIME"text/plain", L::AbstractLearner) = show(io, m, convert(AnnotatedStructTree, L))

# Take Learner and Environment, get state, send to RLCore.forward(Learner, State)
forward(L::Le, env::E) where {Le <: AbstractLearner, E <: AbstractEnv} = env |> state |> send_to_device(L.approximator) |> x -> forward(L, x) |> send_to_device(env)

function RLBase.optimise!(::AbstractLearner, ::AbstractStage, ::Trajectory) end

Base.@kwdef mutable struct Approximator{M,O}
model::M
optimiser::O
function RLBase.plan!(explorer::AbstractExplorer, learner::AbstractLearner, env::AbstractEnv)
legal_action_space_ = RLBase.legal_action_space_mask(env)
RLBase.plan!(explorer, forward(learner, env), legal_action_space_)
end

Base.show(io::IO, m::MIME"text/plain", A::Approximator) = show(io, m, convert(AnnotatedStructTree, A))
function RLBase.plan!(explorer::AbstractExplorer, learner::AbstractLearner, env::AbstractEnv, player::Symbol)
legal_action_space_ = RLBase.legal_action_space_mask(env, player)
return RLBase.plan!(explorer, forward(learner, env), legal_action_space_)
end

@functor Approximator (model,)
# Take Learner and Environment, get state, send to RLCore.forward(Learner, State)
function forward(L::AbstractLearner, env::AbstractEnv)
s = state(env) |> send_to_device(L.approximator)
forward(L,s) |> send_to_device(env)
end

forward(A::Approximator, args...; kwargs...) = A.model(args...; kwargs...)
function RLBase.optimise!(::AbstractLearner, ::AbstractStage, ::Trajectory) end

RLBase.optimise!(A::Approximator, gs) =
Flux.Optimise.update!(A.optimiser, Flux.params(A), gs)
Base.show(io::IO, m::MIME"text/plain", L::AbstractLearner) = show(io, m, convert(AnnotatedStructTree, L))
3 changes: 3 additions & 0 deletions src/ReinforcementLearningCore/src/policies/policies.jl
@@ -1,3 +1,6 @@
include("agent/agent.jl")
include("random_policy.jl")
include("explorers/explorers.jl")
include("learners.jl")
include("q_based_policy.jl")
include("approximator.jl")
12 changes: 0 additions & 12 deletions src/ReinforcementLearningCore/src/policies/q_based_policy.jl
@@ -1,8 +1,5 @@
export QBasedPolicy

include("learners.jl")
include("explorers/explorers.jl")

using Functors: @functor

"""
@@ -26,19 +23,10 @@ function RLBase.plan!(p::QBasedPolicy{L,Ex}, env::E) where {Ex<:AbstractExplorer
RLBase.plan!(p.explorer, p.learner, env)
end

function RLBase.plan!(explorer::Ex, learner::L, env::E) where {Ex<:AbstractExplorer,L<:AbstractLearner,E<:AbstractEnv}
RLBase.plan!(explorer, forward(learner, env), legal_action_space_mask(env))
end

function RLBase.plan!(p::QBasedPolicy{L,Ex}, env::E, player::Symbol) where {Ex<:AbstractExplorer,L<:AbstractLearner,E<:AbstractEnv}
RLBase.plan!(p.explorer, p.learner, env, player)
end

function RLBase.plan!(explorer::Ex, learner::L, env::E, player::Symbol) where {Ex<:AbstractExplorer,L<:AbstractLearner,E<:AbstractEnv}
legal_action_space_ = RLBase.legal_action_space_mask(env, player)
return RLBase.plan!(explorer, forward(learner, env), legal_action_space_)
end

RLBase.prob(p::QBasedPolicy{L,Ex}, env::AbstractEnv) where {L<:AbstractLearner,Ex<:AbstractExplorer} =
prob(p.explorer, forward(p.learner, env), legal_action_space_mask(env))

35 changes: 0 additions & 35 deletions src/ReinforcementLearningCore/src/utils/networks.jl
@@ -510,38 +510,3 @@ function vae_loss(model::VAE, state, action)
kl_loss = -0.5f0 * mean(1.0f0 .+ log.(σ .^ 2) .- μ .^ 2 .- σ .^ 2)
return recon_loss, kl_loss
end

#####
# TwinNetwork
#####

export TwinNetwork

Base.@kwdef mutable struct TwinNetwork{S,T}
source::S
target::T
sync_freq::Int = 1
ρ::Float32 = 0.0f0
n_optimise::Int = 0
end

TwinNetwork(x; kw...) = TwinNetwork(; source=x, target=deepcopy(x), kw...)

@functor TwinNetwork (source, target)

Flux.trainable(model::TwinNetwork) = (model.source,)

(model::TwinNetwork)(args...) = model.source(args...)

function RLBase.optimise!(A::Approximator{<:TwinNetwork}, gs)
Flux.Optimise.update!(A.optimiser, Flux.params(A), gs)
M = A.model
M.n_optimise += 1

if M.n_optimise % M.sync_freq == 0
# polyak averaging
for (dest, src) in zip(Flux.params(M.target), Flux.params(M.source))
dest .= M.ρ .* dest .+ (1 - M.ρ) .* src
end
end
end
28 changes: 28 additions & 0 deletions src/ReinforcementLearningCore/test/policies/approximators.jl
@@ -0,0 +1,28 @@

@testset "approximators.jl" begin
@testset "TargetNetwork" begin
m = Chain(Dense(4,1))
app = Approximator(model = m, optimiser = Flux.Adam())
tn = TargetNetwork(app, sync_freq = 3)
@test typeof(model(tn)) == typeof(target(tn))
p1 = Flux.destructure(model(tn))[1]
pt1 = Flux.destructure(target(tn))[1]
@test p1 == pt1
gs = Flux.Zygote.gradient(Flux.params(tn)) do
sum(RLCore.forward(tn, ones(Float32, 4)))
end
RLCore.optimise!(tn, gs)
@test p1 != Flux.destructure(model(tn))[1]
@test p1 == Flux.destructure(target(tn))[1]
RLCore.optimise!(tn, gs)
@test p1 != Flux.destructure(model(tn))[1]
@test p1 == Flux.destructure(target(tn))[1]
RLCore.optimise!(tn, gs)
@test Flux.destructure(target(tn))[1] == Flux.destructure(model(tn))[1]
@test p1 != Flux.destructure(target(tn))[1]
p2 = Flux.destructure(model(tn))[1]
RLCore.optimise!(tn, gs)
@test p2 != Flux.destructure(model(tn))[1]
@test p2 == Flux.destructure(target(tn))[1]
end
end
1 change: 1 addition & 0 deletions src/ReinforcementLearningCore/test/policies/policies.jl
@@ -1,2 +1,3 @@
include("agent.jl")
include("multi_agent.jl")
include("approximators.jl")
2 changes: 1 addition & 1 deletion src/ReinforcementLearningEnvironments/Project.toml
@@ -23,7 +23,7 @@ DelimitedFiles = "1"
IntervalSets = "0.7"
MacroTools = "0.5"
ReinforcementLearningBase = "0.12"
ReinforcementLearningCore = "0.12, 0.13"
ReinforcementLearningCore = "0.12, 0.13, 0.14"
Requires = "1.0"
StatsBase = "0.32, 0.33, 0.34"
julia = "1.3"
6 changes: 3 additions & 3 deletions src/ReinforcementLearningExperiments/Project.toml
@@ -1,7 +1,7 @@
name = "ReinforcementLearningExperiments"
uuid = "6bd458e5-1694-412f-b601-3a888375c491"
authors = ["Jun Tian <[email protected]>"]
version = "0.3.5"
version = "0.4"

[deps]
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
@@ -21,9 +21,9 @@ Distributions = "0.25"
Flux = "0.13, 0.14"
Reexport = "1"
ReinforcementLearningBase = "0.12"
ReinforcementLearningCore = "0.12, 0.13"
ReinforcementLearningCore = "0.14"
ReinforcementLearningEnvironments = "0.8"
ReinforcementLearningZoo = "^0.8.3"
ReinforcementLearningZoo = "0.9"
StableRNGs = "1"
Weave = "0.10"
cuDNN = "1"
@@ -22,23 +22,22 @@ function RLCore.Experiment(
policy = Agent(
QBasedPolicy(
learner=DQNLearner(
approximator=Approximator(
model=TwinNetwork(
Chain(
Dense(ns, 128, relu; init = glorot_uniform(rng)),
Dense(128, 128, relu; init = glorot_uniform(rng)),
Dense(128, na; init = glorot_uniform(rng)),
);
approximator=TargetNetwork(
Approximator(
model = Chain(
Dense(ns, 128, relu; init=glorot_uniform(rng)),
Dense(128, 128, relu; init=glorot_uniform(rng)),
Dense(128, na; init=glorot_uniform(rng)),
),
optimiser=Adam()
),
sync_freq=100
),
optimiser=Adam(),
) |> gpu,
n=n,
γ=γ,
is_enable_double_DQN=true,
loss_func=huber_loss,
rng=rng,
),
n=n,
γ=γ,
loss_func=huber_loss,
rng=rng,
),
explorer=EpsilonGreedyExplorer(
kind=:exp,
ϵ_stable=0.01,