Fix and refactor SAC #985

Merged 25 commits on Oct 12, 2023
Changes from 21 commits
@@ -15129,7 +15129,7 @@ Two Most Commonly Used Algorithms
            ) |> cpu,
            optimizer = Adam(),
        ),
-       batch_size = 32,
+       batchsize = 32,
        min_replay_history = 100,
        loss_func = huber_loss,
        rng = rng,
14 changes: 7 additions & 7 deletions docs/homepage/blog/ospp_final_term_report_210370741/index.md
@@ -314,7 +314,7 @@ struct D4RLDataSet{T<:AbstractRNG} <: RLDataSet
dataset::Dict{Symbol, Any}
repo::String
dataset_size::Integer
-batch_size::Integer
+batchsize::Integer
style::Tuple
rng::T
meta::Dict
@@ -330,7 +330,7 @@ function dataset(dataset::String;
repo = "d4rl",
rng = StableRNG(123),
is_shuffle = true,
-batch_size=256
+batchsize=256
)
```
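
For orientation, here is a minimal usage sketch of the `dataset` entry point above with the renamed `batchsize` keyword; the dataset name and the `first(...)` call are illustrative assumptions, not part of this PR.

```julia
# Hypothetical sketch: load a D4RL dataset with the renamed `batchsize` keyword.
# The dataset name and the batch-iteration pattern are assumptions for illustration.
using ReinforcementLearningDatasets
using StableRNGs

ds = dataset(
    "hopper-medium-v0";    # illustrative D4RL dataset name
    repo = "d4rl",
    rng = StableRNG(123),
    is_shuffle = true,
    batchsize = 256,
)
batch = first(ds)          # one batch of transitions, assuming the dataset iterates batches
```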

@@ -383,7 +383,7 @@ Multi threaded batching using a parallel loop where each thread loads the batche

```julia
res = Channel{AtariRLTransition}(n_preallocations; taskref=taskref, spawn=true) do ch
-Threads.@threads for i in 1:batch_size
+Threads.@threads for i in 1:batchsize
put!(ch, deepcopy(batch(buffer_template, popfirst!(transitions), i)))
end
end
@@ -472,7 +472,7 @@ end
The datapoints are then put in a `RingBuffer` which is returned.
```julia
res = RingBuffer(buffer;taskref=taskref, sz=n_preallocations) do buff
-Threads.@threads for i in 1:batch_size
+Threads.@threads for i in 1:batchsize
batch!(buff, take!(transitions), i)
end
end
@@ -694,7 +694,7 @@ mutable struct FQE{
target_q_network::C_T
n_evals::Int
γ::Float32
-batch_size::Int
+batchsize::Int
update_freq::Int
update_step::Int
tar_update_freq::Int
@@ -714,7 +714,7 @@ function RLBase.update!(l::FQE, batch::NamedTuple{SARTS})
D = device(Q)
s, a, r, t, s′ = (send_to_device(D, batch[x]) for x in SARTS)
γ = l.γ
-batch_size = l.batch_size
+batchsize = l.batchsize

loss_func = Flux.Losses.mse

@@ -723,7 +723,7 @@ function RLBase.update!(l::FQE, batch::NamedTuple{SARTS})
target = r .+ γ .* (1 .- t) .* q′

gs = gradient(params(Q)) do
-q = Q(vcat(s, reshape(a, :, batch_size))) |> vec
+q = Q(vcat(s, reshape(a, :, batchsize))) |> vec
loss = loss_func(q, target)
Zygote.ignore() do
l.loss = loss
12 changes: 6 additions & 6 deletions docs/homepage/blog/ospp_report_210370190/index.md
@@ -247,7 +247,7 @@ rl_agent = Agent(
),
γ = 1.0f0,
loss_func = huber_loss,
-batch_size = 128,
+batchsize = 128,
update_freq = 128,
min_replay_history = 1000,
target_update_freq = 1000,
@@ -281,7 +281,7 @@ sl_agent = Agent(
optimizer = Descent(0.01),
),
explorer = WeightedSoftmaxExplorer(),
-batch_size = 128,
+batchsize = 128,
min_reservoir_history = 1000,
rng = rng,
),
@@ -351,7 +351,7 @@ Given that the [`DDPGPolicy`](https://juliareinforcementlearning.org/docs/rlzoo/
mutable struct MADDPGManager <: AbstractPolicy
agents::Dict{<:Any, <:Agent}
traces
-batch_size::Int
+batchsize::Int
update_freq::Int
update_step::Int
rng::AbstractRNG
@@ -454,7 +454,7 @@ agents = MADDPGManager(
trajectory = deepcopy(trajectory),
)) for player in players(env) if player != chance_player(env)),
SARTS, # trace's type
-512, # batch_size
+512, # batchsize
100, # update_freq
0, # initial update_step
rng
@@ -508,7 +508,7 @@ create_policy(player) = DDPGPolicy(
na = length(action_space(env, player)),
start_steps = 0,
start_policy = nothing,
-update_after = 512 * env.max_steps, # batch_size * env.max_steps
+update_after = 512 * env.max_steps, # batchsize * env.max_steps
act_limit = 1.0,
act_noise = 0.,
)
@@ -530,7 +530,7 @@ agents = MADDPGManager(
) for player in (:Speaker, :Listener)
),
SARTS, # trace's type
-512, # batch_size
+512, # batchsize
100, # update_freq
0, # initial update_step
rng
@@ -57,10 +57,10 @@ Base.@kwdef struct OfflinePolicy{L,T} <: AbstractPolicy
learner::L
dataset::T
continuous::Bool
-batch_size::Int
+batchsize::Int
end
```
-This implementation of `OfflinePolicy` refers to [`QBasePolicy`](https://juliareinforcementlearning.org/docs/rlcore/#ReinforcementLearningCore.QBasedPolicy). It provides a parameter `continuous` to support different action space types, including continuous and discrete. `learner` is a specific algorithm for learning and providing policy. `dataset` and `batch_size` are used to sample data for learning.
+This implementation of `OfflinePolicy` refers to [`QBasePolicy`](https://juliareinforcementlearning.org/docs/rlcore/#ReinforcementLearningCore.QBasedPolicy). It provides a parameter `continuous` to support different action space types, including continuous and discrete. `learner` is a specific algorithm for learning and providing policy. `dataset` and `batchsize` are used to sample data for learning.

Besides, we implement corresponding functions `π`, `update!` and `sample`. `π` is used to select the action, whose form is determined by the type of action space. `update!` can be used in two stages. In `PreExperiment` stage, we can call this function for pre-training algorithms with `pretrain_step` parameters. In `PreAct` stage, we call this function for training the `learner`. In function `update!`, we need to call function `sample` to sample a batch of data from the dataset. With the development of [ReinforcementLearningDataset.jl](https://github.com/JuliaReinforcementLearning/ReinforcementLearning.jl/tree/master/src/ReinforcementLearningDatasets), the `sample` function will be deprecated.

Expand All @@ -73,7 +73,7 @@ offline_dqn_policy = OfflinePolicy(
),
dataset = dataset,
continuous = false,
-batch_size = 64,
+batchsize = 64,
)
```

@@ -266,7 +266,7 @@ function RLBase.update!(p::OfflinePolicy, traj::AbstractTrajectory, ::AbstractEn
if in(:pretrain_step, fieldnames(typeof(l)))
println("Pretrain...")
for _ in 1:l.pretrain_step
-inds, batch = sample(l.rng, p.dataset, p.batch_size)
+inds, batch = sample(l.rng, p.dataset, p.batchsize)
update!(l, batch)
end
end
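
The hunk above covers only the pretrain (`PreExperiment`) branch of `update!`. For orientation, here is a rough, hedged sketch of the `PreAct` branch described earlier in the post; it paraphrases that description and is not the exact source touched by this PR.

```julia
# Rough sketch of the PreActStage branch described above (a paraphrase, not the exact source).
function RLBase.update!(p::OfflinePolicy, traj::AbstractTrajectory, ::AbstractEnv, ::PreActStage)
    l = p.learner
    _, batch = sample(l.rng, p.dataset, p.batchsize)  # draw `batchsize` transitions from the offline dataset
    update!(l, batch)                                 # one learner update per call
end
```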
2 changes: 1 addition & 1 deletion docs/src/How_to_use_hooks.md
@@ -138,7 +138,7 @@ policy = Agent(
) |> cpu,
optimizer = Adam(),
),
-batch_size = 32,
+batchsize = 32,
min_replay_history = 100,
loss_func = huber_loss,
),
2 changes: 1 addition & 1 deletion src/ReinforcementLearningBase/Project.toml
@@ -1,7 +1,7 @@
name = "ReinforcementLearningBase"
uuid = "e575027e-6cd6-5018-9292-cdc6200d2b44"
authors = ["Johanni Brea <[email protected]>", "Jun Tian <[email protected]>"]
version = "0.12.1"
version = "0.12.2"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
4 changes: 2 additions & 2 deletions src/ReinforcementLearningCore/Project.toml
@@ -1,6 +1,6 @@
name = "ReinforcementLearningCore"
uuid = "de1b191a-4ae0-4afa-a27b-92d07f46b2d6"
version = "0.14.0"
version = "0.15"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
@@ -30,7 +30,7 @@ cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
[compat]
AbstractTrees = "0.3, 0.4"
Adapt = "3"
CUDA = "4"
CUDA = "4, 5"
ChainRulesCore = "1"
CircularArrayBuffers = "0.1"
Crayons = "4"
6 changes: 3 additions & 3 deletions src/ReinforcementLearningCore/src/core/hooks.jl
@@ -217,13 +217,13 @@ end
Base.getindex(h::BatchStepsPerEpisode) = h.steps

"""
-    BatchStepsPerEpisode(batch_size::Int; tag = "TRAINING")
+    BatchStepsPerEpisode(batchsize::Int; tag = "TRAINING")

Similar to [`StepsPerEpisode`](@ref), but is specific to environments
which return a `Vector` of rewards (a typical case with `MultiThreadEnv`).
"""
-function BatchStepsPerEpisode(batch_size::Int)
-    BatchStepsPerEpisode([Int[] for _ = 1:batch_size], zeros(Int, batch_size))
+function BatchStepsPerEpisode(batchsize::Int)
+    BatchStepsPerEpisode([Int[] for _ = 1:batchsize], zeros(Int, batchsize))
end

function Base.push!(hook::BatchStepsPerEpisode,
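
Since the constructor argument is part of the rename, here is a small usage sketch (illustrative, not part of this diff); it relies only on the constructor and the `getindex` method shown above.

```julia
# Illustrative sketch: a hook for a batched environment with 4 parallel instances.
hook = BatchStepsPerEpisode(4)   # one episode-length history and one step counter per environment

# After a run, `getindex` returns the per-environment vectors of completed-episode lengths.
steps = hook[]                   # Vector{Vector{Int}}, one vector per environment
```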
8 changes: 4 additions & 4 deletions src/ReinforcementLearningCore/src/utils/distributions.jl
@@ -54,10 +54,10 @@ end
mvnormlogpdf(μ::A, LorU::A, x::A; ϵ = 1f-8) where A <: AbstractArray

Batch version that takes 3D tensors as input where each slice along the 3rd
-dimension is a batch sample. `μ` is a (action_size x 1 x batch_size) matrix,
-`L` is a (action_size x action_size x batch_size), x is a (action_size x
-action_samples x batch_size). Return a 3D matrix of size (1 x action_samples x
-batch_size).
+dimension is a batch sample. `μ` is a (action_size x 1 x batchsize) matrix,
+`L` is a (action_size x action_size x batchsize), x is a (action_size x
+action_samples x batchsize). Return a 3D matrix of size (1 x action_samples x
+batchsize).
"""
function mvnormlogpdf(μ::A, LorU::A, x::A; ϵ=1.0f-8) where {A<:AbstractArray}
it = zip(eachslice(μ, dims = 3), eachslice(LorU, dims = 3), eachslice(x, dims = 3))
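
To make the documented shapes concrete, here is a hedged sketch exercising the batch method above with identity scale factors; the sizes are arbitrary and the explicit import is an assumption.

```julia
# Illustrative shape check for the batch method documented above.
# Identity matrices are valid lower-triangular Cholesky factors, so LorU is well formed.
using LinearAlgebra: I
using ReinforcementLearningCore: mvnormlogpdf   # assumed import path

action_size, action_samples, batchsize = 3, 5, 32
μ = randn(Float32, action_size, 1, batchsize)
LorU = cat((Matrix{Float32}(I, action_size, action_size) for _ in 1:batchsize)...; dims = 3)
x = randn(Float32, action_size, action_samples, batchsize)

logp = mvnormlogpdf(μ, LorU, x)
@assert size(logp) == (1, action_samples, batchsize)
```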