
Commit 81c161f

Rearrange approximator setup
1 parent 1d3c7da commit 81c161f

File tree: 6 files changed, +95 -48 lines


src/ReinforcementLearningCore/src/policies/learners.jl

-26 lines
This file was deleted.

src/ReinforcementLearningCore/src/policies/learners/abstract_learner.jl (new file)
+41 lines
@@ -0,0 +1,41 @@
+export AbstractLearner, Approximator
+
+using Flux
+using Functors: @functor
+
+abstract type AbstractLearner end
+
+Base.show(io::IO, m::MIME"text/plain", L::AbstractLearner) = show(io, m, convert(AnnotatedStructTree, L))
+
+# Take Learner and Environment, get state, send to RLCore.forward(Learner, State)
+function forward(L::Le, env::E) where {Le <: AbstractLearner, E <: AbstractEnv}
+    env |> state |> Flux.gpu |> (x -> forward(L, x)) |> Flux.cpu
+end
+
+function RLBase.optimise!(::AbstractLearner, ::AbstractStage, ::Trajectory) end
+
+
+"""
+    Approximator(model, optimiser)
+
+Wraps a Flux trainable model and implements the `RLBase.optimise!(::Approximator, ::Gradient)`
+interface. See the RLCore documentation for more information on proper usage.
+"""
+struct Approximator{M,O} <: AbstractLearner
+    model::M
+    optimiser_state::O
+end
+
+function Approximator(; model, optimiser)
+    optimiser_state = Flux.setup(optimiser, model)
+    Approximator(gpu(model), optimiser_state) # Pass model to GPU (if available) upon creation
+end
+
+Base.show(io::IO, m::MIME"text/plain", A::Approximator) = show(io, m, convert(AnnotatedStructTree, A))
+
+@functor Approximator (model,)
+
+forward(A::Approximator, args...; kwargs...) = A.model(args...; kwargs...)
+
+RLBase.optimise!(A::Approximator, grad) =
+    Flux.Optimise.update!(A.model, A.optimiser_state, grad)
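
Not part of the commit: a minimal usage sketch of the keyword constructor defined above. The network shape, learning rate, and the RLCore alias are illustrative assumptions, not taken from this diff.

using Flux
using ReinforcementLearningCore
const RLCore = ReinforcementLearningCore   # assumed alias for qualifying the unexported `forward`

# Hypothetical value network for a 4-dimensional state.
approx = Approximator(
    model = Chain(Dense(4 => 32, relu), Dense(32 => 1)),
    optimiser = Adam(1e-3),                # stored as Flux.setup(optimiser, model) state on construction
)

s = Flux.gpu(rand(Float32, 4))             # match the model's device (gpu is a no-op without a GPU)
v = RLCore.forward(approx, s)              # evaluates the wrapped model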

src/ReinforcementLearningCore/src/policies/learners/learners.jl (new file)
+3 lines
@@ -0,0 +1,3 @@
+include("abstract_learner.jl")
+include("tabular_approximator.jl")
+include("target_network.jl")

src/ReinforcementLearningCore/src/policies/learners/tabular_approximator.jl (new file)
+49 lines
@@ -0,0 +1,49 @@
+export TabularApproximator, TabularVApproximator, TabularQApproximator
+
+using Flux: gpu
+
+"""
+    TabularApproximator(table<:AbstractArray, opt)
+
+For `table` of 1-d, it will serve as a state value approximator.
+For `table` of 2-d, it will serve as a state-action value approximator.
+
+!!! warning
+    For `table` of 2-d, the first dimension is action and the second dimension is state.
+"""
+# TODO: add back missing AbstractApproximator
+struct TabularApproximator{N,A,O} <: AbstractLearner
+    table::A
+    optimizer::O
+    function TabularApproximator(table::A, opt::O) where {A<:AbstractArray,O}
+        n = ndims(table)
+        n <= 2 || throw(ArgumentError("the dimension of table must be <= 2"))
+        new{n,A,O}(table, opt)
+    end
+end
+
+TabularVApproximator(; n_state, init = 0.0, opt = InvDecay(1.0)) =
+    TabularApproximator(fill(init, n_state), opt)
+
+TabularQApproximator(; n_state, n_action, init = 0.0, opt = InvDecay(1.0)) =
+    TabularApproximator(fill(init, n_action, n_state), opt)
+
+# Take Learner and Environment, get state, send to RLCore.forward(Learner, State)
+function forward(L::TabularApproximator, env::E) where {E <: AbstractEnv}
+    env |> state |> (x -> forward(L, x))
+end
+
+RLCore.forward(
+    app::TabularApproximator{1,R,O},
+    s::I,
+) where {R<:AbstractArray,O,I<:Integer} = @views app.table[s]
+
+RLCore.forward(
+    app::TabularApproximator{2,R,O},
+    s::I,
+) where {R<:AbstractArray,O,I<:Integer} = @views app.table[:, s]
+RLCore.forward(
+    app::TabularApproximator{2,R,O},
+    s::I1,
+    a::I2,
+) where {R<:AbstractArray,O,I1<:Integer,I2<:Integer} = @views app.table[a, s]
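
A quick orientation example for the tabular approximators above (not from the diff; the state and action counts are placeholders):

q = TabularQApproximator(n_state = 5, n_action = 2)   # 2×5 table of zeros with an InvDecay(1.0) optimizer
v = TabularVApproximator(n_state = 5)                  # length-5 state-value table

RLCore.forward(q, 3)      # view of the Q-values of every action in state 3 (column 3 of the table)
RLCore.forward(q, 3, 1)   # Q-value of action 1 in state 3
RLCore.forward(v, 3)      # state value of state 3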

src/ReinforcementLearningCore/src/policies/approximator.jl renamed to src/ReinforcementLearningCore/src/policies/learners/target_network.jl

+1 -20 lines
@@ -3,25 +3,6 @@ export Approximator, TargetNetwork, target, model
 using Flux
 
 
-"""
-    Approximator(model, optimiser)
-
-Wraps a Flux trainable model and implements the `RLBase.optimise!(::Approximator, ::Gradient)`
-interface. See the RLCore documentation for more information on proper usage.
-"""
-Base.@kwdef mutable struct Approximator{M,O}
-    model::M
-    optimiser::O
-end
-
-Base.show(io::IO, m::MIME"text/plain", A::Approximator) = show(io, m, convert(AnnotatedStructTree, A))
-
-@functor Approximator (model,)
-
-forward(A::Approximator, args...; kwargs...) = A.model(args...; kwargs...)
-
-RLBase.optimise!(A::Approximator, gs) = Flux.Optimise.update!(A.optimiser, Flux.params(A), gs)
-
 target(ap::Approximator) = ap.model #see TargetNetwork
 model(ap::Approximator) = ap.model #see TargetNetwork
 
@@ -68,7 +49,7 @@ target(tn::TargetNetwork) = tn.target
 
 function RLBase.optimise!(tn::TargetNetwork, gs)
     A = tn.network
-    Flux.Optimise.update!(A.optimiser, Flux.params(A), gs)
+    Flux.Optimise.update!(A.optimiser_state, Flux.params(A), gs)
     tn.n_optimise += 1
 
     if tn.n_optimise % tn.sync_freq == 0
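
Besides moving the Approximator definition out of this file, the optimise! method above now reads the wrapped Approximator's optimiser_state field rather than the removed optimiser field. For reference (not from this commit), the explicit optimiser-state workflow produced by Flux.setup looks roughly like this; the model and data are made up:

using Flux

model = Chain(Dense(4 => 16, relu), Dense(16 => 1))
opt_state = Flux.setup(Adam(1e-3), model)          # the kind of object stored in Approximator.optimiser_state

x, y = rand(Float32, 4, 32), rand(Float32, 1, 32)
grads = Flux.gradient(m -> Flux.mse(m(x), y), model)
Flux.update!(opt_state, model, grads[1])           # explicit-state update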

src/ReinforcementLearningCore/src/policies/policies.jl
@@ -1,6 +1,5 @@
 include("agent/agent.jl")
 include("random_policy.jl")
+include("learners/learners.jl")
 include("explorers/explorers.jl")
-include("learners.jl")
 include("q_based_policy.jl")
-include("approximator.jl")
