From 38e6f593085b34578ec910ead02219cbcc1770c4 Mon Sep 17 00:00:00 2001
From: "github-actions[bot]"
<41898282+github-actions[bot]@users.noreply.github.com>
Date: Thu, 29 Jul 2021 00:30:16 +0000
Subject: [PATCH] Format .jl files
---
.../CFR/JuliaRL_DeepCFR_OpenSpiel.jl | 10 +-
.../CFR/JuliaRL_TabularCFR_OpenSpiel.jl | 10 +-
.../experiments/DQN/Dopamine_DQN_Atari.jl | 74 ++-
.../experiments/DQN/Dopamine_IQN_Atari.jl | 73 +--
.../experiments/DQN/Dopamine_Rainbow_Atari.jl | 71 ++-
.../DQN/JuliaRL_BasicDQN_CartPole.jl | 2 +-
.../DQN/JuliaRL_BasicDQN_MountainCar.jl | 2 +-
.../JuliaRL_BasicDQN_SingleRoomUndirected.jl | 48 +-
.../experiments/DQN/JuliaRL_DQN_CartPole.jl | 12 +-
.../DQN/JuliaRL_DQN_MountainCar.jl | 2 +-
.../experiments/DQN/JuliaRL_IQN_CartPole.jl | 2 +-
.../experiments/DQN/JuliaRL_QRDQN_Cartpole.jl | 52 +-
.../DQN/JuliaRL_REMDQN_CartPole.jl | 4 +-
.../DQN/JuliaRL_Rainbow_CartPole.jl | 2 +-
.../Offline/JuliaRL_BC_CartPole.jl | 10 +-
.../JuliaRL_A2CGAE_CartPole.jl | 4 +-
.../Policy Gradient/JuliaRL_A2C_CartPole.jl | 4 +-
.../Policy Gradient/JuliaRL_DDPG_Pendulum.jl | 4 +-
.../Policy Gradient/JuliaRL_MAC_CartPole.jl | 4 +-
.../Policy Gradient/JuliaRL_PPO_CartPole.jl | 4 +-
.../Policy Gradient/JuliaRL_PPO_Pendulum.jl | 7 +-
.../Policy Gradient/JuliaRL_SAC_Pendulum.jl | 11 +-
.../Policy Gradient/JuliaRL_TD3_Pendulum.jl | 2 +-
.../Policy Gradient/JuliaRL_VPG_CartPole.jl | 2 +-
.../Policy Gradient/rlpyt_A2C_Atari.jl | 10 +-
.../Policy Gradient/rlpyt_PPO_Atari.jl | 12 +-
.../Search/JuliaRL_Minimax_OpenSpiel.jl | 8 +-
docs/homepage/utils.jl | 31 +-
docs/make.jl | 14 +-
.../src/actor_model.jl | 17 +-
.../src/core.jl | 13 +-
.../src/extensions.jl | 2 +-
.../test/actor.jl | 72 +--
.../test/core.jl | 311 +++++-----
.../test/runtests.jl | 4 +-
.../src/CommonRLInterface.jl | 5 +-
.../src/interface.jl | 9 +-
.../test/CommonRLInterface.jl | 50 +-
.../test/runtests.jl | 4 +-
.../src/core/hooks.jl | 29 +-
.../src/extensions/ArrayInterface.jl | 7 +-
.../trajectories/trajectory_extension.jl | 2 +-
.../neural_network_approximator.jl | 26 +-
.../test/components/trajectories.jl | 4 +-
.../test/core/core.jl | 12 +-
.../test/core/stop_conditions_test.jl | 5 +-
.../src/ReinforcementLearningDatasets.jl | 2 +-
.../src/d4rl/d4rl_dataset.jl | 70 ++-
.../src/d4rl/register.jl | 583 ++++++++++--------
.../test/d4rl/d4rl_dataset.jl | 12 +-
.../src/ReinforcementLearningEnvironments.jl | 6 +-
.../src/environments/3rd_party/AcrobotEnv.jl | 42 +-
.../src/environments/3rd_party/open_spiel.jl | 44 +-
.../src/environments/3rd_party/structs.jl | 4 +-
.../wrappers/ActionTransformedEnv.jl | 2 +-
.../wrappers/DefaultStateStyle.jl | 9 +-
.../environments/wrappers/SequentialEnv.jl | 6 +-
.../environments/wrappers/StateCachedEnv.jl | 2 +-
.../wrappers/StateTransformedEnv.jl | 4 +-
.../src/plots.jl | 64 +-
.../test/environments/wrappers/wrappers.jl | 20 +-
.../deps/build.jl | 7 +-
.../src/ReinforcementLearningExperiments.jl | 3 +-
.../src/algorithms/dqns/common.jl | 5 +-
.../src/algorithms/dqns/dqn.jl | 8 +-
.../src/algorithms/dqns/dqns.jl | 2 +-
.../src/algorithms/dqns/qr_dqn.jl | 51 +-
.../src/algorithms/dqns/rem_dqn.jl | 7 +-
.../algorithms/offline_rl/behavior_cloning.jl | 23 +-
.../src/algorithms/policy_gradient/ddpg.jl | 9 +-
.../src/algorithms/policy_gradient/ppo.jl | 11 +-
.../src/algorithms/policy_gradient/sac.jl | 15 +-
.../src/algorithms/tabular/tabular_policy.jl | 2 +-
test/runtests.jl | 3 +-
74 files changed, 1139 insertions(+), 940 deletions(-)
diff --git a/docs/experiments/experiments/CFR/JuliaRL_DeepCFR_OpenSpiel.jl b/docs/experiments/experiments/CFR/JuliaRL_DeepCFR_OpenSpiel.jl
index ff647a3a4..2eeecbd74 100644
--- a/docs/experiments/experiments/CFR/JuliaRL_DeepCFR_OpenSpiel.jl
+++ b/docs/experiments/experiments/CFR/JuliaRL_DeepCFR_OpenSpiel.jl
@@ -61,5 +61,11 @@ function RL.Experiment(
batch_size_Π = 2048,
initializer = glorot_normal(CUDA.CURAND.default_rng()),
)
- Experiment(p, env, StopAfterStep(500, is_show_progress=!haskey(ENV, "CI")), EmptyHook(), "# run DeepcCFR on leduc_poker")
-end
\ No newline at end of file
+ Experiment(
+ p,
+ env,
+ StopAfterStep(500, is_show_progress = !haskey(ENV, "CI")),
+ EmptyHook(),
+ "# run DeepcCFR on leduc_poker",
+ )
+end
diff --git a/docs/experiments/experiments/CFR/JuliaRL_TabularCFR_OpenSpiel.jl b/docs/experiments/experiments/CFR/JuliaRL_TabularCFR_OpenSpiel.jl
index edfd7f199..d89cabb16 100644
--- a/docs/experiments/experiments/CFR/JuliaRL_TabularCFR_OpenSpiel.jl
+++ b/docs/experiments/experiments/CFR/JuliaRL_TabularCFR_OpenSpiel.jl
@@ -23,8 +23,14 @@ function RL.Experiment(
π = TabularCFRPolicy(; rng = rng)
description = "# Play `$game` in OpenSpiel with TabularCFRPolicy"
- Experiment(π, env, StopAfterStep(300, is_show_progress=!haskey(ENV, "CI")), EmptyHook(), description)
+ Experiment(
+ π,
+ env,
+ StopAfterStep(300, is_show_progress = !haskey(ENV, "CI")),
+ EmptyHook(),
+ description,
+ )
end
ex = E`JuliaRL_TabularCFR_OpenSpiel(kuhn_poker)`
-run(ex)
\ No newline at end of file
+run(ex)
diff --git a/docs/experiments/experiments/DQN/Dopamine_DQN_Atari.jl b/docs/experiments/experiments/DQN/Dopamine_DQN_Atari.jl
index f51b4a2a9..8cee680cc 100644
--- a/docs/experiments/experiments/DQN/Dopamine_DQN_Atari.jl
+++ b/docs/experiments/experiments/DQN/Dopamine_DQN_Atari.jl
@@ -79,39 +79,35 @@ function atari_env_factory(
repeat_action_probability = 0.25,
n_replica = nothing,
)
- init(seed) =
- RewardOverriddenEnv(
- StateCachedEnv(
- StateTransformedEnv(
- AtariEnv(;
- name = string(name),
- grayscale_obs = true,
- noop_max = 30,
- frame_skip = 4,
- terminal_on_life_loss = false,
- repeat_action_probability = repeat_action_probability,
- max_num_frames_per_episode = n_frames * max_episode_steps,
- color_averaging = false,
- full_action_space = false,
- seed = seed,
- );
- state_mapping=Chain(
- ResizeImage(state_size...),
- StackFrames(state_size..., n_frames)
- ),
- state_space_mapping= _ -> Space(fill(0..256, state_size..., n_frames))
- )
+ init(seed) = RewardOverriddenEnv(
+ StateCachedEnv(
+ StateTransformedEnv(
+ AtariEnv(;
+ name = string(name),
+ grayscale_obs = true,
+ noop_max = 30,
+ frame_skip = 4,
+ terminal_on_life_loss = false,
+ repeat_action_probability = repeat_action_probability,
+ max_num_frames_per_episode = n_frames * max_episode_steps,
+ color_averaging = false,
+ full_action_space = false,
+ seed = seed,
+ );
+ state_mapping = Chain(
+ ResizeImage(state_size...),
+ StackFrames(state_size..., n_frames),
+ ),
+ state_space_mapping = _ -> Space(fill(0..256, state_size..., n_frames)),
),
- r -> clamp(r, -1, 1)
- )
+ ),
+ r -> clamp(r, -1, 1),
+ )
if isnothing(n_replica)
init(seed)
else
- envs = [
- init(isnothing(seed) ? nothing : hash(seed + i))
- for i in 1:n_replica
- ]
+ envs = [init(isnothing(seed) ? nothing : hash(seed + i)) for i in 1:n_replica]
states = Flux.batch(state.(envs))
rewards = reward.(envs)
terminals = is_terminated.(envs)
@@ -172,7 +168,7 @@ function RL.Experiment(
::Val{:Atari},
name::AbstractString;
save_dir = nothing,
- seed = nothing
+ seed = nothing,
)
rng = Random.GLOBAL_RNG
Random.seed!(rng, seed)
@@ -190,7 +186,7 @@ function RL.Experiment(
name,
STATE_SIZE,
N_FRAMES;
- seed = isnothing(seed) ? nothing : hash(seed + 1)
+ seed = isnothing(seed) ? nothing : hash(seed + 1),
)
N_ACTIONS = length(action_space(env))
init = glorot_uniform(rng)
@@ -254,17 +250,15 @@ function RL.Experiment(
end,
DoEveryNEpisode() do t, agent, env
with_logger(lg) do
- @info "training" episode_length = step_per_episode.steps[end] reward = reward_per_episode.rewards[end] log_step_increment = 0
+ @info "training" episode_length = step_per_episode.steps[end] reward =
+ reward_per_episode.rewards[end] log_step_increment = 0
end
end,
- DoEveryNStep(;n=EVALUATION_FREQ) do t, agent, env
+ DoEveryNStep(; n = EVALUATION_FREQ) do t, agent, env
@info "evaluating agent at $t step..."
p = agent.policy
p = @set p.explorer = EpsilonGreedyExplorer(0.001; rng = rng) # set evaluation epsilon
- h = ComposedHook(
- TotalOriginalRewardPerEpisode(),
- StepsPerEpisode(),
- )
+ h = ComposedHook(TotalOriginalRewardPerEpisode(), StepsPerEpisode())
s = @elapsed run(
p,
atari_env_factory(
@@ -281,16 +275,18 @@ function RL.Experiment(
avg_score = mean(h[1].rewards[1:end-1])
avg_length = mean(h[2].steps[1:end-1])
- @info "finished evaluating agent in $s seconds" avg_length = avg_length avg_score = avg_score
+ @info "finished evaluating agent in $s seconds" avg_length = avg_length avg_score =
+ avg_score
with_logger(lg) do
- @info "evaluating" avg_length = avg_length avg_score = avg_score log_step_increment = 0
+ @info "evaluating" avg_length = avg_length avg_score = avg_score log_step_increment =
+ 0
end
end,
)
stop_condition = StopAfterStep(
haskey(ENV, "CI") ? 1_000 : 50_000_000,
- is_show_progress=!haskey(ENV, "CI")
+ is_show_progress = !haskey(ENV, "CI"),
)
Experiment(agent, env, stop_condition, hook, "# DQN <-> Atari($name)")
end
diff --git a/docs/experiments/experiments/DQN/Dopamine_IQN_Atari.jl b/docs/experiments/experiments/DQN/Dopamine_IQN_Atari.jl
index 6e0305ae5..7f5c8e2e7 100644
--- a/docs/experiments/experiments/DQN/Dopamine_IQN_Atari.jl
+++ b/docs/experiments/experiments/DQN/Dopamine_IQN_Atari.jl
@@ -84,39 +84,35 @@ function atari_env_factory(
repeat_action_probability = 0.25,
n_replica = nothing,
)
- init(seed) =
- RewardOverriddenEnv(
- StateCachedEnv(
- StateTransformedEnv(
- AtariEnv(;
- name = string(name),
- grayscale_obs = true,
- noop_max = 30,
- frame_skip = 4,
- terminal_on_life_loss = false,
- repeat_action_probability = repeat_action_probability,
- max_num_frames_per_episode = n_frames * max_episode_steps,
- color_averaging = false,
- full_action_space = false,
- seed = seed,
- );
- state_mapping=Chain(
- ResizeImage(state_size...),
- StackFrames(state_size..., n_frames)
- ),
- state_space_mapping= _ -> Space(fill(0..256, state_size..., n_frames))
- )
+ init(seed) = RewardOverriddenEnv(
+ StateCachedEnv(
+ StateTransformedEnv(
+ AtariEnv(;
+ name = string(name),
+ grayscale_obs = true,
+ noop_max = 30,
+ frame_skip = 4,
+ terminal_on_life_loss = false,
+ repeat_action_probability = repeat_action_probability,
+ max_num_frames_per_episode = n_frames * max_episode_steps,
+ color_averaging = false,
+ full_action_space = false,
+ seed = seed,
+ );
+ state_mapping = Chain(
+ ResizeImage(state_size...),
+ StackFrames(state_size..., n_frames),
+ ),
+ state_space_mapping = _ -> Space(fill(0..256, state_size..., n_frames)),
),
- r -> clamp(r, -1, 1)
- )
+ ),
+ r -> clamp(r, -1, 1),
+ )
if isnothing(n_replica)
init(seed)
else
- envs = [
- init(isnothing(seed) ? nothing : hash(seed + i))
- for i in 1:n_replica
- ]
+ envs = [init(isnothing(seed) ? nothing : hash(seed + i)) for i in 1:n_replica]
states = Flux.batch(state.(envs))
rewards = reward.(envs)
terminals = is_terminated.(envs)
@@ -195,7 +191,12 @@ function RL.Experiment(
N_FRAMES = 4
STATE_SIZE = (84, 84)
- env = atari_env_factory(name, STATE_SIZE, N_FRAMES; seed = isnothing(seed) ? nothing : hash(seed + 2))
+ env = atari_env_factory(
+ name,
+ STATE_SIZE,
+ N_FRAMES;
+ seed = isnothing(seed) ? nothing : hash(seed + 2),
+ )
N_ACTIONS = length(action_space(env))
Nₑₘ = 64
@@ -250,7 +251,7 @@ function RL.Experiment(
),
),
trajectory = CircularArraySARTTrajectory(
- capacity = haskey(ENV, "CI") : 1_000 : 1_000_000,
+ capacity = haskey(ENV, "CI"):1_000:1_000_000,
state = Matrix{Float32} => STATE_SIZE,
),
)
@@ -274,7 +275,7 @@ function RL.Experiment(
steps_per_episode.steps[end] log_step_increment = 0
end
end,
- DoEveryNStep(;n=EVALUATION_FREQ) do t, agent, env
+ DoEveryNStep(; n = EVALUATION_FREQ) do t, agent, env
@info "evaluating agent at $t step..."
p = agent.policy
p = @set p.explorer = EpsilonGreedyExplorer(0.001; rng = rng) # set evaluation epsilon
@@ -286,7 +287,7 @@ function RL.Experiment(
STATE_SIZE,
N_FRAMES,
MAX_EPISODE_STEPS_EVAL;
- seed = isnothing(seed) ? nothing : hash(seed + t)
+ seed = isnothing(seed) ? nothing : hash(seed + t),
),
StopAfterStep(125_000; is_show_progress = false),
h,
@@ -295,16 +296,18 @@ function RL.Experiment(
avg_score = mean(h[1].rewards[1:end-1])
avg_length = mean(h[2].steps[1:end-1])
- @info "finished evaluating agent in $s seconds" avg_length = avg_length avg_score = avg_score
+ @info "finished evaluating agent in $s seconds" avg_length = avg_length avg_score =
+ avg_score
with_logger(lg) do
- @info "evaluating" avg_length = avg_length avg_score = avg_score log_step_increment = 0
+ @info "evaluating" avg_length = avg_length avg_score = avg_score log_step_increment =
+ 0
end
end,
)
stop_condition = StopAfterStep(
haskey(ENV, "CI") ? 10_000 : 50_000_000,
- is_show_progress=!haskey(ENV, "CI")
+ is_show_progress = !haskey(ENV, "CI"),
)
Experiment(agent, env, stop_condition, hook, "# IQN <-> Atari($name)")
end
diff --git a/docs/experiments/experiments/DQN/Dopamine_Rainbow_Atari.jl b/docs/experiments/experiments/DQN/Dopamine_Rainbow_Atari.jl
index 432e110e4..fd6d76c66 100644
--- a/docs/experiments/experiments/DQN/Dopamine_Rainbow_Atari.jl
+++ b/docs/experiments/experiments/DQN/Dopamine_Rainbow_Atari.jl
@@ -83,39 +83,35 @@ function atari_env_factory(
repeat_action_probability = 0.25,
n_replica = nothing,
)
- init(seed) =
- RewardOverriddenEnv(
- StateCachedEnv(
- StateTransformedEnv(
- AtariEnv(;
- name = string(name),
- grayscale_obs = true,
- noop_max = 30,
- frame_skip = 4,
- terminal_on_life_loss = false,
- repeat_action_probability = repeat_action_probability,
- max_num_frames_per_episode = n_frames * max_episode_steps,
- color_averaging = false,
- full_action_space = false,
- seed = seed,
- );
- state_mapping=Chain(
- ResizeImage(state_size...),
- StackFrames(state_size..., n_frames)
- ),
- state_space_mapping= _ -> Space(fill(0..256, state_size..., n_frames))
- )
+ init(seed) = RewardOverriddenEnv(
+ StateCachedEnv(
+ StateTransformedEnv(
+ AtariEnv(;
+ name = string(name),
+ grayscale_obs = true,
+ noop_max = 30,
+ frame_skip = 4,
+ terminal_on_life_loss = false,
+ repeat_action_probability = repeat_action_probability,
+ max_num_frames_per_episode = n_frames * max_episode_steps,
+ color_averaging = false,
+ full_action_space = false,
+ seed = seed,
+ );
+ state_mapping = Chain(
+ ResizeImage(state_size...),
+ StackFrames(state_size..., n_frames),
+ ),
+ state_space_mapping = _ -> Space(fill(0..256, state_size..., n_frames)),
),
- r -> clamp(r, -1, 1)
- )
+ ),
+ r -> clamp(r, -1, 1),
+ )
if isnothing(n_replica)
init(seed)
else
- envs = [
- init(isnothing(seed) ? nothing : hash(seed + i))
- for i in 1:n_replica
- ]
+ envs = [init(isnothing(seed) ? nothing : hash(seed + i)) for i in 1:n_replica]
states = Flux.batch(state.(envs))
rewards = reward.(envs)
terminals = is_terminated.(envs)
@@ -191,7 +187,12 @@ function RL.Experiment(
N_FRAMES = 4
STATE_SIZE = (84, 84)
- env = atari_env_factory(name, STATE_SIZE, N_FRAMES; seed = isnothing(seed) ? nothing : hash(seed + 1))
+ env = atari_env_factory(
+ name,
+ STATE_SIZE,
+ N_FRAMES;
+ seed = isnothing(seed) ? nothing : hash(seed + 1),
+ )
N_ACTIONS = length(action_space(env))
N_ATOMS = 51
init = glorot_uniform(rng)
@@ -238,7 +239,7 @@ function RL.Experiment(
),
),
trajectory = CircularArrayPSARTTrajectory(
- capacity = haskey(ENV, "CI") : 1_000 : 1_000_000,
+ capacity = haskey(ENV, "CI"):1_000:1_000_000,
state = Matrix{Float32} => STATE_SIZE,
),
)
@@ -262,7 +263,7 @@ function RL.Experiment(
steps_per_episode.steps[end] log_step_increment = 0
end
end,
- DoEveryNStep(;n=EVALUATION_FREQ) do t, agent, env
+ DoEveryNStep(; n = EVALUATION_FREQ) do t, agent, env
@info "evaluating agent at $t step..."
p = agent.policy
p = @set p.explorer = EpsilonGreedyExplorer(0.001; rng = rng) # set evaluation epsilon
@@ -282,16 +283,18 @@ function RL.Experiment(
avg_length = mean(h[2].steps[1:end-1])
avg_score = mean(h[1].rewards[1:end-1])
- @info "finished evaluating agent in $s seconds" avg_length = avg_length avg_score = avg_score
+ @info "finished evaluating agent in $s seconds" avg_length = avg_length avg_score =
+ avg_score
with_logger(lg) do
- @info "evaluating" avg_length = avg_length avg_score = avg_score log_step_increment = 0
+ @info "evaluating" avg_length = avg_length avg_score = avg_score log_step_increment =
+ 0
end
end,
)
stop_condition = StopAfterStep(
haskey(ENV, "CI") ? 10_000 : 50_000_000,
- is_show_progress=!haskey(ENV, "CI")
+ is_show_progress = !haskey(ENV, "CI"),
)
Experiment(agent, env, stop_condition, hook, "# Rainbow <-> Atari($name)")
diff --git a/docs/experiments/experiments/DQN/JuliaRL_BasicDQN_CartPole.jl b/docs/experiments/experiments/DQN/JuliaRL_BasicDQN_CartPole.jl
index 9f32c48d6..d5ba9c9c7 100644
--- a/docs/experiments/experiments/DQN/JuliaRL_BasicDQN_CartPole.jl
+++ b/docs/experiments/experiments/DQN/JuliaRL_BasicDQN_CartPole.jl
@@ -51,7 +51,7 @@ function RL.Experiment(
state = Vector{Float32} => (ns,),
),
)
- stop_condition = StopAfterStep(10_000, is_show_progress=!haskey(ENV, "CI"))
+ stop_condition = StopAfterStep(10_000, is_show_progress = !haskey(ENV, "CI"))
hook = TotalRewardPerEpisode()
Experiment(policy, env, stop_condition, hook, "# BasicDQN <-> CartPole")
end
diff --git a/docs/experiments/experiments/DQN/JuliaRL_BasicDQN_MountainCar.jl b/docs/experiments/experiments/DQN/JuliaRL_BasicDQN_MountainCar.jl
index ae8f02cb5..bc79c94be 100644
--- a/docs/experiments/experiments/DQN/JuliaRL_BasicDQN_MountainCar.jl
+++ b/docs/experiments/experiments/DQN/JuliaRL_BasicDQN_MountainCar.jl
@@ -51,7 +51,7 @@ function RL.Experiment(
),
)
- stop_condition = StopAfterStep(70_000, is_show_progress=!haskey(ENV, "CI"))
+ stop_condition = StopAfterStep(70_000, is_show_progress = !haskey(ENV, "CI"))
hook = TotalRewardPerEpisode()
Experiment(agent, env, stop_condition, hook, "")
diff --git a/docs/experiments/experiments/DQN/JuliaRL_BasicDQN_SingleRoomUndirected.jl b/docs/experiments/experiments/DQN/JuliaRL_BasicDQN_SingleRoomUndirected.jl
index 4b7f2a5cd..39748e307 100644
--- a/docs/experiments/experiments/DQN/JuliaRL_BasicDQN_SingleRoomUndirected.jl
+++ b/docs/experiments/experiments/DQN/JuliaRL_BasicDQN_SingleRoomUndirected.jl
@@ -18,47 +18,47 @@ function RL.Experiment(
::Val{:BasicDQN},
::Val{:SingleRoomUndirected},
::Nothing;
- seed=123,
+ seed = 123,
)
rng = StableRNG(seed)
- env = GridWorlds.SingleRoomUndirectedModule.SingleRoomUndirected(rng=rng)
+ env = GridWorlds.SingleRoomUndirectedModule.SingleRoomUndirected(rng = rng)
env = GridWorlds.RLBaseEnv(env)
- env = RLEnvs.StateTransformedEnv(env;state_mapping=x -> vec(Float32.(x)))
+ env = RLEnvs.StateTransformedEnv(env; state_mapping = x -> vec(Float32.(x)))
env = RewardOverriddenEnv(env, x -> x - convert(typeof(x), 0.01))
env = MaxTimeoutEnv(env, 240)
ns, na = length(state(env)), length(action_space(env))
agent = Agent(
- policy=QBasedPolicy(
- learner=BasicDQNLearner(
- approximator=NeuralNetworkApproximator(
- model=Chain(
- Dense(ns, 128, relu; init=glorot_uniform(rng)),
- Dense(128, 128, relu; init=glorot_uniform(rng)),
- Dense(128, na; init=glorot_uniform(rng)),
+ policy = QBasedPolicy(
+ learner = BasicDQNLearner(
+ approximator = NeuralNetworkApproximator(
+ model = Chain(
+ Dense(ns, 128, relu; init = glorot_uniform(rng)),
+ Dense(128, 128, relu; init = glorot_uniform(rng)),
+ Dense(128, na; init = glorot_uniform(rng)),
) |> cpu,
- optimizer=ADAM(),
+ optimizer = ADAM(),
),
- batch_size=32,
- min_replay_history=100,
- loss_func=huber_loss,
- rng=rng,
+ batch_size = 32,
+ min_replay_history = 100,
+ loss_func = huber_loss,
+ rng = rng,
),
- explorer=EpsilonGreedyExplorer(
- kind=:exp,
- ϵ_stable=0.01,
- decay_steps=500,
- rng=rng,
+ explorer = EpsilonGreedyExplorer(
+ kind = :exp,
+ ϵ_stable = 0.01,
+ decay_steps = 500,
+ rng = rng,
),
),
- trajectory=CircularArraySARTTrajectory(
- capacity=1000,
- state=Vector{Float32} => (ns,),
+ trajectory = CircularArraySARTTrajectory(
+ capacity = 1000,
+ state = Vector{Float32} => (ns,),
),
)
- stop_condition = StopAfterStep(10_000, is_show_progress=!haskey(ENV, "CI"))
+ stop_condition = StopAfterStep(10_000, is_show_progress = !haskey(ENV, "CI"))
hook = TotalRewardPerEpisode()
Experiment(agent, env, stop_condition, hook, "")
end
diff --git a/docs/experiments/experiments/DQN/JuliaRL_DQN_CartPole.jl b/docs/experiments/experiments/DQN/JuliaRL_DQN_CartPole.jl
index 7e922e13e..7e2e218f6 100644
--- a/docs/experiments/experiments/DQN/JuliaRL_DQN_CartPole.jl
+++ b/docs/experiments/experiments/DQN/JuliaRL_DQN_CartPole.jl
@@ -14,13 +14,13 @@ using Flux.Losses
function build_dueling_network(network::Chain)
lm = length(network)
- if !(network[lm] isa Dense) || !(network[lm-1] isa Dense)
+ if !(network[lm] isa Dense) || !(network[lm-1] isa Dense)
error("The Qnetwork provided is incompatible with dueling.")
end
- base = Chain([deepcopy(network[i]) for i=1:lm-2]...)
+ base = Chain([deepcopy(network[i]) for i in 1:lm-2]...)
last_layer_dims = size(network[lm].weight, 2)
val = Chain(deepcopy(network[lm-1]), Dense(last_layer_dims, 1))
- adv = Chain([deepcopy(network[i]) for i=lm-1:lm]...)
+ adv = Chain([deepcopy(network[i]) for i in lm-1:lm]...)
return DuelingNetwork(base, val, adv)
end
@@ -37,8 +37,8 @@ function RL.Experiment(
base_model = Chain(
Dense(ns, 128, relu; init = glorot_uniform(rng)),
Dense(128, 128, relu; init = glorot_uniform(rng)),
- Dense(128, na; init = glorot_uniform(rng))
- )
+ Dense(128, na; init = glorot_uniform(rng)),
+ )
agent = Agent(
policy = QBasedPolicy(
@@ -72,7 +72,7 @@ function RL.Experiment(
),
)
- stop_condition = StopAfterStep(10_000, is_show_progress=!haskey(ENV, "CI"))
+ stop_condition = StopAfterStep(10_000, is_show_progress = !haskey(ENV, "CI"))
hook = TotalRewardPerEpisode()
Experiment(agent, env, stop_condition, hook, "")
end
diff --git a/docs/experiments/experiments/DQN/JuliaRL_DQN_MountainCar.jl b/docs/experiments/experiments/DQN/JuliaRL_DQN_MountainCar.jl
index d8b1eb633..f74bcaea1 100644
--- a/docs/experiments/experiments/DQN/JuliaRL_DQN_MountainCar.jl
+++ b/docs/experiments/experiments/DQN/JuliaRL_DQN_MountainCar.jl
@@ -64,7 +64,7 @@ function RL.Experiment(
),
)
- stop_condition = StopAfterStep(40_000, is_show_progress=!haskey(ENV, "CI"))
+ stop_condition = StopAfterStep(40_000, is_show_progress = !haskey(ENV, "CI"))
hook = TotalRewardPerEpisode()
Experiment(agent, env, stop_condition, hook, "")
end
diff --git a/docs/experiments/experiments/DQN/JuliaRL_IQN_CartPole.jl b/docs/experiments/experiments/DQN/JuliaRL_IQN_CartPole.jl
index f3ab3c98f..cba0fcea5 100644
--- a/docs/experiments/experiments/DQN/JuliaRL_IQN_CartPole.jl
+++ b/docs/experiments/experiments/DQN/JuliaRL_IQN_CartPole.jl
@@ -71,7 +71,7 @@ function RL.Experiment(
),
)
- stop_condition = StopAfterStep(10_000, is_show_progress=!haskey(ENV, "CI"))
+ stop_condition = StopAfterStep(10_000, is_show_progress = !haskey(ENV, "CI"))
hook = TotalRewardPerEpisode()
Experiment(agent, env, stop_condition, hook, "")
end
diff --git a/docs/experiments/experiments/DQN/JuliaRL_QRDQN_Cartpole.jl b/docs/experiments/experiments/DQN/JuliaRL_QRDQN_Cartpole.jl
index 7fb238d28..c45bb0b03 100644
--- a/docs/experiments/experiments/DQN/JuliaRL_QRDQN_Cartpole.jl
+++ b/docs/experiments/experiments/DQN/JuliaRL_QRDQN_Cartpole.jl
@@ -17,58 +17,58 @@ function RL.Experiment(
::Val{:QRDQN},
::Val{:CartPole},
::Nothing;
- seed=123,
+ seed = 123,
)
N = 10
rng = StableRNG(seed)
- env = CartPoleEnv(; T=Float32, rng=rng)
+ env = CartPoleEnv(; T = Float32, rng = rng)
ns, na = length(state(env)), length(action_space(env))
init = glorot_uniform(rng)
agent = Agent(
- policy=QBasedPolicy(
- learner=QRDQNLearner(
- approximator=NeuralNetworkApproximator(
- model=Chain(
+ policy = QBasedPolicy(
+ learner = QRDQNLearner(
+ approximator = NeuralNetworkApproximator(
+ model = Chain(
Dense(ns, 128, relu; init = init),
Dense(128, 128, relu; init = init),
Dense(128, N * na; init = init),
) |> cpu,
- optimizer=ADAM(),
+ optimizer = ADAM(),
),
- target_approximator=NeuralNetworkApproximator(
- model=Chain(
+ target_approximator = NeuralNetworkApproximator(
+ model = Chain(
Dense(ns, 128, relu; init = init),
Dense(128, 128, relu; init = init),
Dense(128, N * na; init = init),
) |> cpu,
),
- stack_size=nothing,
- batch_size=32,
- update_horizon=1,
- min_replay_history=100,
- update_freq=1,
- target_update_freq=100,
- n_quantile=N,
- rng=rng,
+ stack_size = nothing,
+ batch_size = 32,
+ update_horizon = 1,
+ min_replay_history = 100,
+ update_freq = 1,
+ target_update_freq = 100,
+ n_quantile = N,
+ rng = rng,
),
- explorer=EpsilonGreedyExplorer(
- kind=:exp,
- ϵ_stable=0.01,
- decay_steps=500,
- rng=rng,
+ explorer = EpsilonGreedyExplorer(
+ kind = :exp,
+ ϵ_stable = 0.01,
+ decay_steps = 500,
+ rng = rng,
),
),
- trajectory=CircularArraySARTTrajectory(
- capacity=1000,
- state=Vector{Float32} => (ns,),
+ trajectory = CircularArraySARTTrajectory(
+ capacity = 1000,
+ state = Vector{Float32} => (ns,),
),
)
- stop_condition = StopAfterStep(10_000, is_show_progress=!haskey(ENV, "CI"))
+ stop_condition = StopAfterStep(10_000, is_show_progress = !haskey(ENV, "CI"))
hook = TotalRewardPerEpisode()
Experiment(agent, env, stop_condition, hook, "")
end
diff --git a/docs/experiments/experiments/DQN/JuliaRL_REMDQN_CartPole.jl b/docs/experiments/experiments/DQN/JuliaRL_REMDQN_CartPole.jl
index fdf473a83..7f74dd096 100644
--- a/docs/experiments/experiments/DQN/JuliaRL_REMDQN_CartPole.jl
+++ b/docs/experiments/experiments/DQN/JuliaRL_REMDQN_CartPole.jl
@@ -52,7 +52,7 @@ function RL.Experiment(
update_freq = 1,
target_update_freq = 100,
ensemble_num = ensemble_num,
- ensemble_method = :rand,
+ ensemble_method = :rand,
rng = rng,
),
explorer = EpsilonGreedyExplorer(
@@ -68,7 +68,7 @@ function RL.Experiment(
),
)
- stop_condition = StopAfterStep(10_000, is_show_progress=!haskey(ENV, "CI"))
+ stop_condition = StopAfterStep(10_000, is_show_progress = !haskey(ENV, "CI"))
hook = TotalRewardPerEpisode()
Experiment(agent, env, stop_condition, hook, "")
end
diff --git a/docs/experiments/experiments/DQN/JuliaRL_Rainbow_CartPole.jl b/docs/experiments/experiments/DQN/JuliaRL_Rainbow_CartPole.jl
index f367d1cf5..d8f5d2437 100644
--- a/docs/experiments/experiments/DQN/JuliaRL_Rainbow_CartPole.jl
+++ b/docs/experiments/experiments/DQN/JuliaRL_Rainbow_CartPole.jl
@@ -71,7 +71,7 @@ function RL.Experiment(
),
)
- stop_condition = StopAfterStep(10_000, is_show_progress=!haskey(ENV, "CI"))
+ stop_condition = StopAfterStep(10_000, is_show_progress = !haskey(ENV, "CI"))
hook = TotalRewardPerEpisode()
Experiment(agent, env, stop_condition, hook, "")
end
diff --git a/docs/experiments/experiments/Offline/JuliaRL_BC_CartPole.jl b/docs/experiments/experiments/Offline/JuliaRL_BC_CartPole.jl
index f1e42bb51..808719818 100644
--- a/docs/experiments/experiments/Offline/JuliaRL_BC_CartPole.jl
+++ b/docs/experiments/experiments/Offline/JuliaRL_BC_CartPole.jl
@@ -61,7 +61,7 @@ function RL.Experiment(
),
)
- stop_condition = StopAfterStep(10_000, is_show_progress=!haskey(ENV, "CI"))
+ stop_condition = StopAfterStep(10_000, is_show_progress = !haskey(ENV, "CI"))
hook = RecordStateAction()
run(agent, env, stop_condition, hook)
@@ -84,7 +84,13 @@ function RL.Experiment(
end
hook = TotalRewardPerEpisode()
- Experiment(bc, env, StopAfterEpisode(100, is_show_progress=!haskey(ENV, "CI")), hook, "BehaviorCloning <-> CartPole")
+ Experiment(
+ bc,
+ env,
+ StopAfterEpisode(100, is_show_progress = !haskey(ENV, "CI")),
+ hook,
+ "BehaviorCloning <-> CartPole",
+ )
end
#+ tangle=false
diff --git a/docs/experiments/experiments/Policy Gradient/JuliaRL_A2CGAE_CartPole.jl b/docs/experiments/experiments/Policy Gradient/JuliaRL_A2CGAE_CartPole.jl
index eeed90609..2ea78b92d 100644
--- a/docs/experiments/experiments/Policy Gradient/JuliaRL_A2CGAE_CartPole.jl
+++ b/docs/experiments/experiments/Policy Gradient/JuliaRL_A2CGAE_CartPole.jl
@@ -63,7 +63,7 @@ function RL.Experiment(
terminal = Vector{Bool} => (N_ENV,),
),
)
- stop_condition = StopAfterStep(50_000, is_show_progress=!haskey(ENV, "CI"))
+ stop_condition = StopAfterStep(50_000, is_show_progress = !haskey(ENV, "CI"))
hook = TotalBatchRewardPerEpisode(N_ENV)
Experiment(agent, env, stop_condition, hook, "# A2CGAE with CartPole")
end
@@ -78,7 +78,7 @@ run(ex)
n = minimum(map(length, ex.hook.rewards))
m = mean([@view(x[1:n]) for x in ex.hook.rewards])
s = std([@view(x[1:n]) for x in ex.hook.rewards])
-plot(m,ribbon=s)
+plot(m, ribbon = s)
savefig("assets/JuliaRL_A2CGAE_CartPole.png") #hide
# ![](assets/JuliaRL_A2CGAE_CartPole.png)
diff --git a/docs/experiments/experiments/Policy Gradient/JuliaRL_A2C_CartPole.jl b/docs/experiments/experiments/Policy Gradient/JuliaRL_A2C_CartPole.jl
index cb03dd025..e56733a92 100644
--- a/docs/experiments/experiments/Policy Gradient/JuliaRL_A2C_CartPole.jl
+++ b/docs/experiments/experiments/Policy Gradient/JuliaRL_A2C_CartPole.jl
@@ -59,7 +59,7 @@ function RL.Experiment(
),
)
- stop_condition = StopAfterStep(50_000, is_show_progress=true)
+ stop_condition = StopAfterStep(50_000, is_show_progress = true)
hook = TotalBatchRewardPerEpisode(N_ENV)
Experiment(agent, env, stop_condition, hook, "# A2C with CartPole")
end
@@ -73,7 +73,7 @@ run(ex)
n = minimum(map(length, ex.hook.rewards))
m = mean([@view(x[1:n]) for x in ex.hook.rewards])
s = std([@view(x[1:n]) for x in ex.hook.rewards])
-plot(m,ribbon=s)
+plot(m, ribbon = s)
savefig("assets/JuliaRL_A2C_CartPole.png") #hide
# ![](assets/JuliaRL_A2C_CartPole.png)
diff --git a/docs/experiments/experiments/Policy Gradient/JuliaRL_DDPG_Pendulum.jl b/docs/experiments/experiments/Policy Gradient/JuliaRL_DDPG_Pendulum.jl
index 4e04ad1b6..9a87476f7 100644
--- a/docs/experiments/experiments/Policy Gradient/JuliaRL_DDPG_Pendulum.jl
+++ b/docs/experiments/experiments/Policy Gradient/JuliaRL_DDPG_Pendulum.jl
@@ -79,11 +79,11 @@ function RL.Experiment(
trajectory = CircularArraySARTTrajectory(
capacity = 10000,
state = Vector{Float32} => (ns,),
- action = Float32 => (na, ),
+ action = Float32 => (na,),
),
)
- stop_condition = StopAfterStep(10_000, is_show_progress=!haskey(ENV, "CI"))
+ stop_condition = StopAfterStep(10_000, is_show_progress = !haskey(ENV, "CI"))
hook = TotalRewardPerEpisode()
Experiment(agent, env, stop_condition, hook, "# Play Pendulum with DDPG")
end
diff --git a/docs/experiments/experiments/Policy Gradient/JuliaRL_MAC_CartPole.jl b/docs/experiments/experiments/Policy Gradient/JuliaRL_MAC_CartPole.jl
index 3559b1b03..c526fcfe3 100644
--- a/docs/experiments/experiments/Policy Gradient/JuliaRL_MAC_CartPole.jl
+++ b/docs/experiments/experiments/Policy Gradient/JuliaRL_MAC_CartPole.jl
@@ -64,7 +64,7 @@ function RL.Experiment(
),
)
- stop_condition = StopAfterStep(50_000, is_show_progress=!haskey(ENV, "CI"))
+ stop_condition = StopAfterStep(50_000, is_show_progress = !haskey(ENV, "CI"))
hook = TotalBatchRewardPerEpisode(N_ENV)
Experiment(agent, env, stop_condition, hook, "# MAC with CartPole")
end
@@ -78,7 +78,7 @@ run(ex)
n = minimum(map(length, ex.hook.rewards))
m = mean([@view(x[1:n]) for x in ex.hook.rewards])
s = std([@view(x[1:n]) for x in ex.hook.rewards])
-plot(m,ribbon=s)
+plot(m, ribbon = s)
savefig("assets/JuliaRL_MAC_CartPole.png") #hide
# ![](assets/JuliaRL_MAC_CartPole.png)
diff --git a/docs/experiments/experiments/Policy Gradient/JuliaRL_PPO_CartPole.jl b/docs/experiments/experiments/Policy Gradient/JuliaRL_PPO_CartPole.jl
index 45cdd2e13..cbc0e3340 100644
--- a/docs/experiments/experiments/Policy Gradient/JuliaRL_PPO_CartPole.jl
+++ b/docs/experiments/experiments/Policy Gradient/JuliaRL_PPO_CartPole.jl
@@ -62,7 +62,7 @@ function RL.Experiment(
),
)
- stop_condition = StopAfterStep(10_000, is_show_progress=!haskey(ENV, "CI"))
+ stop_condition = StopAfterStep(10_000, is_show_progress = !haskey(ENV, "CI"))
hook = TotalBatchRewardPerEpisode(N_ENV)
Experiment(agent, env, stop_condition, hook, "# PPO with CartPole")
end
@@ -76,7 +76,7 @@ run(ex)
n = minimum(map(length, ex.hook.rewards))
m = mean([@view(x[1:n]) for x in ex.hook.rewards])
s = std([@view(x[1:n]) for x in ex.hook.rewards])
-plot(m,ribbon=s)
+plot(m, ribbon = s)
savefig("assets/JuliaRL_PPO_CartPole.png") #hide
# ![](assets/JuliaRL_PPO_CartPole.png)
diff --git a/docs/experiments/experiments/Policy Gradient/JuliaRL_PPO_Pendulum.jl b/docs/experiments/experiments/Policy Gradient/JuliaRL_PPO_Pendulum.jl
index 80625e104..1fd85f10c 100644
--- a/docs/experiments/experiments/Policy Gradient/JuliaRL_PPO_Pendulum.jl
+++ b/docs/experiments/experiments/Policy Gradient/JuliaRL_PPO_Pendulum.jl
@@ -32,7 +32,8 @@ function RL.Experiment(
UPDATE_FREQ = 2048
env = MultiThreadEnv([
PendulumEnv(T = Float32, rng = StableRNG(hash(seed + i))) |>
- env -> ActionTransformedEnv(env, action_mapping = x -> clamp(x * 2, low, high)) for i in 1:N_ENV
+ env -> ActionTransformedEnv(env, action_mapping = x -> clamp(x * 2, low, high))
+ for i in 1:N_ENV
])
init = glorot_uniform(rng)
@@ -78,7 +79,7 @@ function RL.Experiment(
),
)
- stop_condition = StopAfterStep(50_000, is_show_progress=!haskey(ENV, "CI"))
+ stop_condition = StopAfterStep(50_000, is_show_progress = !haskey(ENV, "CI"))
hook = TotalBatchRewardPerEpisode(N_ENV)
Experiment(agent, env, stop_condition, hook, "# Play Pendulum with PPO")
end
@@ -92,7 +93,7 @@ run(ex)
n = minimum(map(length, ex.hook.rewards))
m = mean([@view(x[1:n]) for x in ex.hook.rewards])
s = std([@view(x[1:n]) for x in ex.hook.rewards])
-plot(m,ribbon=s)
+plot(m, ribbon = s)
savefig("assets/JuliaRL_PPO_Pendulum.png") #hide
# ![](assets/JuliaRL_PPO_Pendulum.png)
diff --git a/docs/experiments/experiments/Policy Gradient/JuliaRL_SAC_Pendulum.jl b/docs/experiments/experiments/Policy Gradient/JuliaRL_SAC_Pendulum.jl
index 7df613dc1..8fad88732 100644
--- a/docs/experiments/experiments/Policy Gradient/JuliaRL_SAC_Pendulum.jl
+++ b/docs/experiments/experiments/Policy Gradient/JuliaRL_SAC_Pendulum.jl
@@ -38,12 +38,11 @@ function RL.Experiment(
create_policy_net() = NeuralNetworkApproximator(
model = GaussianNetwork(
- pre = Chain(
- Dense(ns, 30, relu),
- Dense(30, 30, relu),
- ),
+ pre = Chain(Dense(ns, 30, relu), Dense(30, 30, relu)),
μ = Chain(Dense(30, na, init = init)),
- logσ = Chain(Dense(30, na, x -> clamp.(x, typeof(x)(-10), typeof(x)(2)), init = init)),
+ logσ = Chain(
+ Dense(30, na, x -> clamp.(x, typeof(x)(-10), typeof(x)(2)), init = init),
+ ),
),
optimizer = ADAM(0.003),
)
@@ -84,7 +83,7 @@ function RL.Experiment(
),
)
- stop_condition = StopAfterStep(10_000, is_show_progress=!haskey(ENV, "CI"))
+ stop_condition = StopAfterStep(10_000, is_show_progress = !haskey(ENV, "CI"))
hook = TotalRewardPerEpisode()
Experiment(agent, env, stop_condition, hook, "# Play Pendulum with SAC")
end
diff --git a/docs/experiments/experiments/Policy Gradient/JuliaRL_TD3_Pendulum.jl b/docs/experiments/experiments/Policy Gradient/JuliaRL_TD3_Pendulum.jl
index f0545222f..f82bd8ee3 100644
--- a/docs/experiments/experiments/Policy Gradient/JuliaRL_TD3_Pendulum.jl
+++ b/docs/experiments/experiments/Policy Gradient/JuliaRL_TD3_Pendulum.jl
@@ -86,7 +86,7 @@ function RL.Experiment(
),
)
- stop_condition = StopAfterStep(10_000, is_show_progress=!haskey(ENV, "CI"))
+ stop_condition = StopAfterStep(10_000, is_show_progress = !haskey(ENV, "CI"))
hook = TotalRewardPerEpisode()
Experiment(agent, env, stop_condition, hook, "# Play Pendulum with TD3")
end
diff --git a/docs/experiments/experiments/Policy Gradient/JuliaRL_VPG_CartPole.jl b/docs/experiments/experiments/Policy Gradient/JuliaRL_VPG_CartPole.jl
index 87130f3dc..8cfbe125d 100644
--- a/docs/experiments/experiments/Policy Gradient/JuliaRL_VPG_CartPole.jl
+++ b/docs/experiments/experiments/Policy Gradient/JuliaRL_VPG_CartPole.jl
@@ -49,7 +49,7 @@ function RL.Experiment(
),
trajectory = ElasticSARTTrajectory(state = Vector{Float32} => (ns,)),
)
- stop_condition = StopAfterEpisode(500, is_show_progress=!haskey(ENV, "CI"))
+ stop_condition = StopAfterEpisode(500, is_show_progress = !haskey(ENV, "CI"))
hook = TotalRewardPerEpisode()
description = "# Play CartPole with VPG"
diff --git a/docs/experiments/experiments/Policy Gradient/rlpyt_A2C_Atari.jl b/docs/experiments/experiments/Policy Gradient/rlpyt_A2C_Atari.jl
index 7bf2add71..56eb5d90e 100644
--- a/docs/experiments/experiments/Policy Gradient/rlpyt_A2C_Atari.jl
+++ b/docs/experiments/experiments/Policy Gradient/rlpyt_A2C_Atari.jl
@@ -83,7 +83,7 @@ function RL.Experiment(
hook = ComposedHook(
total_batch_reward_per_episode,
batch_steps_per_episode,
- DoEveryNStep(;n=UPDATE_FREQ) do t, agent, env
+ DoEveryNStep(; n = UPDATE_FREQ) do t, agent, env
learner = agent.policy.policy.learner
with_logger(lg) do
@info "training" loss = learner.loss actor_loss = learner.actor_loss critic_loss =
@@ -94,20 +94,22 @@ function RL.Experiment(
DoEveryNStep() do t, agent, env
with_logger(lg) do
rewards = [
- total_batch_reward_per_episode.rewards[i][end] for i in 1:length(env) if is_terminated(env[i])
+ total_batch_reward_per_episode.rewards[i][end] for
+ i in 1:length(env) if is_terminated(env[i])
]
if length(rewards) > 0
@info "training" rewards = mean(rewards) log_step_increment = 0
end
steps = [
- batch_steps_per_episode.steps[i][end] for i in 1:length(env) if is_terminated(env[i])
+ batch_steps_per_episode.steps[i][end] for
+ i in 1:length(env) if is_terminated(env[i])
]
if length(steps) > 0
@info "training" steps = mean(steps) log_step_increment = 0
end
end
end,
- DoEveryNStep(;n=EVALUATION_FREQ) do t, agent, env
+ DoEveryNStep(; n = EVALUATION_FREQ) do t, agent, env
@info "evaluating agent at $t step..."
h = TotalBatchOriginalRewardPerEpisode(N_ENV)
s = @elapsed run(
diff --git a/docs/experiments/experiments/Policy Gradient/rlpyt_PPO_Atari.jl b/docs/experiments/experiments/Policy Gradient/rlpyt_PPO_Atari.jl
index e9eb574a3..4b33dc27f 100644
--- a/docs/experiments/experiments/Policy Gradient/rlpyt_PPO_Atari.jl
+++ b/docs/experiments/experiments/Policy Gradient/rlpyt_PPO_Atari.jl
@@ -85,7 +85,7 @@ function RL.Experiment(
hook = ComposedHook(
total_batch_reward_per_episode,
batch_steps_per_episode,
- DoEveryNStep(;n=UPDATE_FREQ) do t, agent, env
+ DoEveryNStep(; n = UPDATE_FREQ) do t, agent, env
p = agent.policy
with_logger(lg) do
@info "training" loss = mean(p.loss) actor_loss = mean(p.actor_loss) critic_loss =
@@ -93,7 +93,7 @@ function RL.Experiment(
mean(p.norm) log_step_increment = UPDATE_FREQ
end
end,
- DoEveryNStep(;n=UPDATE_FREQ) do t, agent, env
+ DoEveryNStep(; n = UPDATE_FREQ) do t, agent, env
decay = (N_TRAINING_STEPS - t) / N_TRAINING_STEPS
agent.policy.approximator.optimizer.eta = INIT_LEARNING_RATE * decay
agent.policy.clip_range = INIT_CLIP_RANGE * Float32(decay)
@@ -101,20 +101,22 @@ function RL.Experiment(
DoEveryNStep() do t, agent, env
with_logger(lg) do
rewards = [
- total_batch_reward_per_episode.rewards[i][end] for i in 1:length(env) if is_terminated(env[i])
+ total_batch_reward_per_episode.rewards[i][end] for
+ i in 1:length(env) if is_terminated(env[i])
]
if length(rewards) > 0
@info "training" rewards = mean(rewards) log_step_increment = 0
end
steps = [
- batch_steps_per_episode.steps[i][end] for i in 1:length(env) if is_terminated(env[i])
+ batch_steps_per_episode.steps[i][end] for
+ i in 1:length(env) if is_terminated(env[i])
]
if length(steps) > 0
@info "training" steps = mean(steps) log_step_increment = 0
end
end
end,
- DoEveryNStep(;n=EVALUATION_FREQ) do t, agent, env
+ DoEveryNStep(; n = EVALUATION_FREQ) do t, agent, env
@info "evaluating agent at $t step..."
## switch to GreedyExplorer?
h = TotalBatchRewardPerEpisode(N_ENV)
diff --git a/docs/experiments/experiments/Search/JuliaRL_Minimax_OpenSpiel.jl b/docs/experiments/experiments/Search/JuliaRL_Minimax_OpenSpiel.jl
index f31098c00..a6bfc1d4b 100644
--- a/docs/experiments/experiments/Search/JuliaRL_Minimax_OpenSpiel.jl
+++ b/docs/experiments/experiments/Search/JuliaRL_Minimax_OpenSpiel.jl
@@ -18,7 +18,13 @@ function RL.Experiment(::Val{:JuliaRL}, ::Val{:Minimax}, ::Val{:OpenSpiel}, game
)
hooks = MultiAgentHook(0 => TotalRewardPerEpisode(), 1 => TotalRewardPerEpisode())
description = "# Play `$game` in OpenSpiel with Minimax"
- Experiment(agents, env, StopAfterEpisode(1, is_show_progress=!haskey(ENV, "CI")), hooks, description)
+ Experiment(
+ agents,
+ env,
+ StopAfterEpisode(1, is_show_progress = !haskey(ENV, "CI")),
+ hooks,
+ description,
+ )
end
using Plots
diff --git a/docs/homepage/utils.jl b/docs/homepage/utils.jl
index 816810d71..64787a91a 100644
--- a/docs/homepage/utils.jl
+++ b/docs/homepage/utils.jl
@@ -5,7 +5,7 @@ html(s) = "\n~~~$s~~~\n"
function hfun_adddescription()
d = locvar(:description)
- isnothing(d) ? "" : F.fd2html(d, internal=true)
+ isnothing(d) ? "" : F.fd2html(d, internal = true)
end
function hfun_frontmatter()
@@ -28,7 +28,7 @@ function hfun_byline()
if isnothing(fm)
""
else
- ""
+ ""
end
end
@@ -62,7 +62,7 @@ function hfun_appendix()
if isfile(bib_in_cur_folder)
bib_resolved = F.parse_rpath("/" * bib_in_cur_folder)
else
- bib_resolved = F.parse_rpath(bib; canonical=false, code=true)
+ bib_resolved = F.parse_rpath(bib; canonical = false, code = true)
end
bib = ""
end
@@ -74,7 +74,7 @@ function hfun_appendix()
"""
end
-function lx_dcite(lxc,_)
+function lx_dcite(lxc, _)
content = F.content(lxc.braces[1])
"" |> html
end
@@ -92,7 +92,7 @@ end
"""
Possible layouts:
"""
-function lx_dfig(lxc,lxd)
+function lx_dfig(lxc, lxd)
content = F.content(lxc.braces[1])
info = split(content, ';')
layout = info[1]
@@ -111,7 +111,7 @@ function lx_dfig(lxc,lxd)
end
# (case 3) assume it is generated by code
- src = F.parse_rpath(src; canonical=false, code=true)
+ src = F.parse_rpath(src; canonical = false, code = true)
# !!! directly take from `lx_fig` in Franklin.jl
fdir, fext = splitext(src)
@@ -122,11 +122,10 @@ function lx_dfig(lxc,lxd)
# then in both cases there can be a relative path set but the user may mean
# that it's in the subfolder /output/ (if generated by code) so should look
# both in the relpath and if not found and if /output/ not already last dir
- candext = ifelse(isempty(fext),
- (".png", ".jpeg", ".jpg", ".svg", ".gif"), (fext,))
- for ext ∈ candext
+ candext = ifelse(isempty(fext), (".png", ".jpeg", ".jpg", ".svg", ".gif"), (fext,))
+ for ext in candext
candpath = fdir * ext
- syspath = joinpath(F.PATHS[:site], split(candpath, '/')...)
+ syspath = joinpath(F.PATHS[:site], split(candpath, '/')...)
isfile(syspath) && return dfigure(layout, candpath, caption)
end
# now try in the output dir just in case (provided we weren't already
@@ -134,20 +133,20 @@ function lx_dfig(lxc,lxd)
p1, p2 = splitdir(fdir)
@debug "TEST" p1 p2
if splitdir(p1)[2] != "output"
- for ext ∈ candext
+ for ext in candext
candpath = joinpath(splitdir(p1)[1], "output", p2 * ext)
- syspath = joinpath(F.PATHS[:site], split(candpath, '/')...)
+ syspath = joinpath(F.PATHS[:site], split(candpath, '/')...)
isfile(syspath) && return dfigure(layout, candpath, caption)
end
end
end
-function lx_aside(lxc,lxd)
+function lx_aside(lxc, lxd)
content = F.reprocess(F.content(lxc.braces[1]), lxd)
"" |> html
end
-function lx_footnote(lxc,lxd)
+function lx_footnote(lxc, lxd)
content = F.reprocess(F.content(lxc.braces[1]), lxd)
# workaround
if startswith(content, "
")
@@ -156,7 +155,7 @@ function lx_footnote(lxc,lxd)
"$content" |> html
end
-function lx_appendix(lxc,lxd)
+function lx_appendix(lxc, lxd)
content = F.reprocess(F.content(lxc.braces[1]), lxd)
"$content" |> html
-end
\ No newline at end of file
+end
diff --git a/docs/make.jl b/docs/make.jl
index bd61a9b1c..06a580ea6 100644
--- a/docs/make.jl
+++ b/docs/make.jl
@@ -14,11 +14,7 @@ end
experiments, postprocess_cb, experiments_assets = makedemos("experiments")
-assets = [
- "assets/favicon.ico",
- "assets/custom.css",
- experiments_assets
-]
+assets = ["assets/favicon.ico", "assets/custom.css", experiments_assets]
makedocs(
modules = [
@@ -39,8 +35,10 @@ makedocs(
"Home" => "index.md",
"Tutorial" => "tutorial.md",
"Guides" => [
- "How to write a customized environment?" => "How_to_write_a_customized_environment.md",
- "How to implement a new algorithm?" => "How_to_implement_a_new_algorithm.md",
+ "How to write a customized environment?" =>
+ "How_to_write_a_customized_environment.md",
+ "How to implement a new algorithm?" =>
+ "How_to_implement_a_new_algorithm.md",
"How to use hooks?" => "How_to_use_hooks.md",
"Which algorithm should I use?" => "Which_algorithm_should_I_use.md",
],
@@ -53,7 +51,7 @@ makedocs(
"RLEnvs" => "rlenvs.md",
"RLZoo" => "rlzoo.md",
],
- ]
+ ],
)
postprocess_cb()
diff --git a/src/DistributedReinforcementLearning/src/actor_model.jl b/src/DistributedReinforcementLearning/src/actor_model.jl
index 3a2414c5c..af6b05214 100644
--- a/src/DistributedReinforcementLearning/src/actor_model.jl
+++ b/src/DistributedReinforcementLearning/src/actor_model.jl
@@ -1,19 +1,12 @@
-export AbstractMessage,
- StartMsg,
- StopMsg,
- PingMsg,
- PongMsg,
- ProxyMsg,
- actor,
- self
+export AbstractMessage, StartMsg, StopMsg, PingMsg, PongMsg, ProxyMsg, actor, self
abstract type AbstractMessage end
-struct StartMsg{A, K} <: AbstractMessage
+struct StartMsg{A,K} <: AbstractMessage
args::A
kwargs::K
- StartMsg(args...;kwargs...) = new{typeof(args), typeof(kwargs)}(args, kwargs)
+ StartMsg(args...; kwargs...) = new{typeof(args),typeof(kwargs)}(args, kwargs)
end
struct StopMsg <: AbstractMessage end
@@ -45,9 +38,9 @@ const DEFAULT_MAILBOX_SIZE = 32
Create a task to handle messages one-by-one by calling `f(msg)`.
A mailbox (`RemoteChannel`) is returned.
"""
-function actor(f;sz=DEFAULT_MAILBOX_SIZE)
+function actor(f; sz = DEFAULT_MAILBOX_SIZE)
RemoteChannel() do
- Channel(sz;spawn=true) do ch
+ Channel(sz; spawn = true) do ch
task_local_storage("MAILBOX", RemoteChannel(() -> ch))
while true
msg = take!(ch)
diff --git a/src/DistributedReinforcementLearning/src/core.jl b/src/DistributedReinforcementLearning/src/core.jl
index 44ed67937..34a04571a 100644
--- a/src/DistributedReinforcementLearning/src/core.jl
+++ b/src/DistributedReinforcementLearning/src/core.jl
@@ -51,7 +51,7 @@ Base.@kwdef struct Trainer{P,S}
sealer::S = deepcopy
end
-Trainer(p) = Trainer(;policy=p)
+Trainer(p) = Trainer(; policy = p)
function (trainer::Trainer)(msg::BatchDataMsg)
update!(trainer.policy, msg.data)
@@ -94,7 +94,7 @@ mutable struct Worker
end
function (w::Worker)(msg::StartMsg)
- w.experiment = w.init(msg.args...;msg.kwargs...)
+ w.experiment = w.init(msg.args...; msg.kwargs...)
w.task = Threads.@spawn run(w.experiment)
end
@@ -128,7 +128,7 @@ end
function (wp::WorkerProxy)(::FetchParamMsg)
if !wp.is_fetch_msg_sent[]
put!(wp.target, FetchParamMsg(self()))
- wp.is_fetch_msg_sent[] = true
+ wp.is_fetch_msg_sent[] = true
end
end
@@ -172,9 +172,12 @@ function (orc::Orchestrator)(msg::InsertTrajectoryMsg)
put!(orc.trajectory_proxy, BatchSampleMsg(orc.trainer))
L.n_sample += 1
if L.n_sample == (L.n_load + 1) * L.sample_load_ratio
- put!(orc.trajectory_proxy, ProxyMsg(to=orc.trainer, msg=FetchParamMsg(orc.worker)))
+ put!(
+ orc.trajectory_proxy,
+ ProxyMsg(to = orc.trainer, msg = FetchParamMsg(orc.worker)),
+ )
L.n_load += 1
end
end
end
-end
\ No newline at end of file
+end
diff --git a/src/DistributedReinforcementLearning/src/extensions.jl b/src/DistributedReinforcementLearning/src/extensions.jl
index efeda428b..1c6f32e9d 100644
--- a/src/DistributedReinforcementLearning/src/extensions.jl
+++ b/src/DistributedReinforcementLearning/src/extensions.jl
@@ -76,4 +76,4 @@ function (hook::FetchParamsHook)(::PostActStage, agent, env)
end
end
end
-end
\ No newline at end of file
+end
diff --git a/src/DistributedReinforcementLearning/test/actor.jl b/src/DistributedReinforcementLearning/test/actor.jl
index 34881c5b1..e2cc9ae87 100644
--- a/src/DistributedReinforcementLearning/test/actor.jl
+++ b/src/DistributedReinforcementLearning/test/actor.jl
@@ -1,53 +1,53 @@
@testset "basic tests" begin
-Base.@kwdef mutable struct TestActor
- state::Union{Nothing, Int} = nothing
-end
+ Base.@kwdef mutable struct TestActor
+ state::Union{Nothing,Int} = nothing
+ end
-struct CurrentStateMsg <: AbstractMessage
- state
-end
+ struct CurrentStateMsg <: AbstractMessage
+ state::Any
+ end
-Base.@kwdef struct ReadStateMsg <: AbstractMessage
- from = self()
-end
+ Base.@kwdef struct ReadStateMsg <: AbstractMessage
+ from = self()
+ end
-struct IncMsg <: AbstractMessage end
-struct DecMsg <: AbstractMessage end
+ struct IncMsg <: AbstractMessage end
+ struct DecMsg <: AbstractMessage end
-(x::TestActor)(msg::StartMsg{Tuple{Int}}) = x.state = msg.args[1]
-(x::TestActor)(msg::StopMsg) = x.state = nothing
-(x::TestActor)(::IncMsg) = x.state += 1
-(x::TestActor)(::DecMsg) = x.state -= 1
-(x::TestActor)(msg::ReadStateMsg) = put!(msg.from, CurrentStateMsg(x.state))
+ (x::TestActor)(msg::StartMsg{Tuple{Int}}) = x.state = msg.args[1]
+ (x::TestActor)(msg::StopMsg) = x.state = nothing
+ (x::TestActor)(::IncMsg) = x.state += 1
+ (x::TestActor)(::DecMsg) = x.state -= 1
+ (x::TestActor)(msg::ReadStateMsg) = put!(msg.from, CurrentStateMsg(x.state))
-x = actor(TestActor())
-put!(x, StartMsg(0))
+ x = actor(TestActor())
+ put!(x, StartMsg(0))
-put!(x, ReadStateMsg())
-@test take!(self()).state == 0
+ put!(x, ReadStateMsg())
+ @test take!(self()).state == 0
-@sync begin
- for _ in 1:100
- Threads.@spawn put!(x, IncMsg())
- Threads.@spawn put!(x, DecMsg())
- end
- for _ in 1:10
- for _ in 1:10
+ @sync begin
+ for _ in 1:100
Threads.@spawn put!(x, IncMsg())
+ Threads.@spawn put!(x, DecMsg())
end
for _ in 1:10
- Threads.@spawn put!(x, DecMsg())
+ for _ in 1:10
+ Threads.@spawn put!(x, IncMsg())
+ end
+ for _ in 1:10
+ Threads.@spawn put!(x, DecMsg())
+ end
end
end
-end
-put!(x, ReadStateMsg())
-@test take!(self()).state == 0
+ put!(x, ReadStateMsg())
+ @test take!(self()).state == 0
-y = actor(TestActor())
-put!(x, ProxyMsg(;to=y,msg=StartMsg(0)))
-put!(x, ProxyMsg(;to=y,msg=ReadStateMsg()))
-@test take!(self()).state == 0
+ y = actor(TestActor())
+ put!(x, ProxyMsg(; to = y, msg = StartMsg(0)))
+ put!(x, ProxyMsg(; to = y, msg = ReadStateMsg()))
+ @test take!(self()).state == 0
-end
\ No newline at end of file
+end
diff --git a/src/DistributedReinforcementLearning/test/core.jl b/src/DistributedReinforcementLearning/test/core.jl
index 25d5c2c49..d4a196f37 100644
--- a/src/DistributedReinforcementLearning/test/core.jl
+++ b/src/DistributedReinforcementLearning/test/core.jl
@@ -1,181 +1,202 @@
@testset "core.jl" begin
-@testset "Trainer" begin
- _trainer = Trainer(;
- policy=BasicDQNLearner(
- approximator = NeuralNetworkApproximator(
- model = Chain(
- Dense(4, 128, relu; initW = glorot_uniform),
- Dense(128, 128, relu; initW = glorot_uniform),
- Dense(128, 2; initW = glorot_uniform),
- ) |> cpu,
- optimizer = ADAM(),
+ @testset "Trainer" begin
+ _trainer = Trainer(;
+ policy = BasicDQNLearner(
+ approximator = NeuralNetworkApproximator(
+ model = Chain(
+ Dense(4, 128, relu; initW = glorot_uniform),
+ Dense(128, 128, relu; initW = glorot_uniform),
+ Dense(128, 2; initW = glorot_uniform),
+ ) |> cpu,
+ optimizer = ADAM(),
+ ),
+ loss_func = huber_loss,
),
- loss_func = huber_loss,
)
- )
- trainer = actor(_trainer)
+ trainer = actor(_trainer)
- put!(trainer, FetchParamMsg())
- ps = take!(self())
- original_sum = sum(sum, ps.data)
+ put!(trainer, FetchParamMsg())
+ ps = take!(self())
+ original_sum = sum(sum, ps.data)
- for x in ps.data
- fill!(x, 0.)
- end
+ for x in ps.data
+ fill!(x, 0.0)
+ end
- put!(trainer, FetchParamMsg())
- ps = take!(self())
- new_sum = sum(sum, ps.data)
+ put!(trainer, FetchParamMsg())
+ ps = take!(self())
+ new_sum = sum(sum, ps.data)
- # make sure no state sharing between messages
- @test original_sum == new_sum
+ # make sure no state sharing between messages
+ @test original_sum == new_sum
- batch_data = (
- state = rand(4, 32),
- action = rand(1:2, 32),
- reward = rand(32),
- terminal = rand(Bool, 32),
- next_state = rand(4,32),
- next_action = rand(1:2, 32)
- )
+ batch_data = (
+ state = rand(4, 32),
+ action = rand(1:2, 32),
+ reward = rand(32),
+ terminal = rand(Bool, 32),
+ next_state = rand(4, 32),
+ next_action = rand(1:2, 32),
+ )
- put!(trainer, BatchDataMsg(batch_data))
+ put!(trainer, BatchDataMsg(batch_data))
- put!(trainer, FetchParamMsg())
- ps = take!(self())
- updated_sum = sum(sum, ps.data)
- @test original_sum != updated_sum
-end
+ put!(trainer, FetchParamMsg())
+ ps = take!(self())
+ updated_sum = sum(sum, ps.data)
+ @test original_sum != updated_sum
+ end
-@testset "TrajectoryManager" begin
- _trajectory_proxy = TrajectoryManager(
- trajectory = CircularSARTSATrajectory(;capacity=5, state_type=Any, ),
- sampler = UniformBatchSampler(3),
- inserter = NStepInserter(),
- )
+ @testset "TrajectoryManager" begin
+ _trajectory_proxy = TrajectoryManager(
+ trajectory = CircularSARTSATrajectory(; capacity = 5, state_type = Any),
+ sampler = UniformBatchSampler(3),
+ inserter = NStepInserter(),
+ )
- trajectory_proxy = actor(_trajectory_proxy)
+ trajectory_proxy = actor(_trajectory_proxy)
- # 1. init traj for testing
- traj = CircularCompactSARTSATrajectory(
- capacity = 2,
- state_type = Float32,
- state_size = (4,),
- )
- push!(traj;state=rand(Float32, 4), action=rand(1:2))
- push!(traj;reward=rand(), terminal=rand(Bool),state=rand(Float32, 4), action=rand(1:2))
- push!(traj;reward=rand(), terminal=rand(Bool),state=rand(Float32, 4), action=rand(1:2))
+ # 1. init traj for testing
+ traj = CircularCompactSARTSATrajectory(
+ capacity = 2,
+ state_type = Float32,
+ state_size = (4,),
+ )
+ push!(traj; state = rand(Float32, 4), action = rand(1:2))
+ push!(
+ traj;
+ reward = rand(),
+ terminal = rand(Bool),
+ state = rand(Float32, 4),
+ action = rand(1:2),
+ )
+ push!(
+ traj;
+ reward = rand(),
+ terminal = rand(Bool),
+ state = rand(Float32, 4),
+ action = rand(1:2),
+ )
- # 2. insert
- put!(trajectory_proxy, InsertTrajectoryMsg(deepcopy(traj))) #!!! we used deepcopy here
+ # 2. insert
+ put!(trajectory_proxy, InsertTrajectoryMsg(deepcopy(traj))) #!!! we used deepcopy here
- # 3. make sure the above message is already been handled
- put!(trajectory_proxy, PingMsg())
- take!(self())
+        # 3. make sure the above message has already been handled
+ put!(trajectory_proxy, PingMsg())
+ take!(self())
- # 4. test that updating traj will not affect data in trajectory_proxy
- s_tp = _trajectory_proxy.trajectory[:state]
- s_traj = traj[:state]
+ # 4. test that updating traj will not affect data in trajectory_proxy
+ s_tp = _trajectory_proxy.trajectory[:state]
+ s_traj = traj[:state]
- @test s_tp[1] == s_traj[:, 1]
+ @test s_tp[1] == s_traj[:, 1]
- push!(traj;reward=rand(), terminal=rand(Bool),state=rand(Float32, 4), action=rand(1:2))
+ push!(
+ traj;
+ reward = rand(),
+ terminal = rand(Bool),
+ state = rand(Float32, 4),
+ action = rand(1:2),
+ )
- @test s_tp[1] != s_traj[:, 1]
+ @test s_tp[1] != s_traj[:, 1]
- s = sample(_trajectory_proxy.trajectory, _trajectory_proxy.sampler)
- fill!(s[:state], 0.)
- @test any(x -> sum(x) == 0, s_tp) == false # make sure sample create an independent copy
-end
+ s = sample(_trajectory_proxy.trajectory, _trajectory_proxy.sampler)
+ fill!(s[:state], 0.0)
+        @test any(x -> sum(x) == 0, s_tp) == false # make sure sample creates an independent copy
+ end
-@testset "Worker" begin
- _worker = Worker() do worker_proxy
- Experiment(
- Agent(
- policy = StaticPolicy(
+ @testset "Worker" begin
+ _worker = Worker() do worker_proxy
+ Experiment(
+ Agent(
+ policy = StaticPolicy(
QBasedPolicy(
- learner = BasicDQNLearner(
- approximator = NeuralNetworkApproximator(
- model = Chain(
- Dense(4, 128, relu; initW = glorot_uniform),
- Dense(128, 128, relu; initW = glorot_uniform),
- Dense(128, 2; initW = glorot_uniform),
- ) |> cpu,
- optimizer = ADAM(),
+ learner = BasicDQNLearner(
+ approximator = NeuralNetworkApproximator(
+ model = Chain(
+ Dense(4, 128, relu; initW = glorot_uniform),
+ Dense(128, 128, relu; initW = glorot_uniform),
+ Dense(128, 2; initW = glorot_uniform),
+ ) |> cpu,
+ optimizer = ADAM(),
+ ),
+ loss_func = huber_loss,
+ ),
+ explorer = EpsilonGreedyExplorer(
+ kind = :exp,
+ ϵ_stable = 0.01,
+ decay_steps = 500,
),
- loss_func = huber_loss,
- ),
- explorer = EpsilonGreedyExplorer(
- kind = :exp,
- ϵ_stable = 0.01,
- decay_steps = 500,
),
),
+ trajectory = CircularCompactSARTSATrajectory(
+ capacity = 10,
+ state_type = Float32,
+ state_size = (4,),
+ ),
),
- trajectory = CircularCompactSARTSATrajectory(
- capacity = 10,
- state_type = Float32,
- state_size = (4,),
+ CartPoleEnv(; T = Float32),
+ ComposedStopCondition(StopAfterStep(1_000), StopSignal()),
+ ComposedHook(
+ UploadTrajectoryEveryNStep(
+ mailbox = worker_proxy,
+ n = 10,
+ sealer = x -> InsertTrajectoryMsg(deepcopy(x)),
+ ),
+ LoadParamsHook(),
+ TotalRewardPerEpisode(),
),
- ),
- CartPoleEnv(; T = Float32),
- ComposedStopCondition(
- StopAfterStep(1_000),
- StopSignal(),
- ),
- ComposedHook(
- UploadTrajectoryEveryNStep(mailbox=worker_proxy, n=10, sealer=x -> InsertTrajectoryMsg(deepcopy(x))),
- LoadParamsHook(),
- TotalRewardPerEpisode(),
- ),
- "experimenting..."
- )
- end
+ "experimenting...",
+ )
+ end
- worker = actor(_worker)
- tmp_mailbox = Channel(100)
- put!(worker, StartMsg(tmp_mailbox))
-end
-
-@testset "WorkerProxy" begin
- target = RemoteChannel(() -> Channel(10))
- workers = [RemoteChannel(()->Channel(10)) for _ in 1:10]
- _wp = WorkerProxy(workers)
- wp = actor(_wp)
-
- put!(wp, StartMsg(target))
- for w in workers
- # @test take!(w).args[1] === wp
- @test Distributed.channel_from_id(remoteref_id(take!(w).args[1])) === Distributed.channel_from_id(remoteref_id(wp))
+ worker = actor(_worker)
+ tmp_mailbox = Channel(100)
+ put!(worker, StartMsg(tmp_mailbox))
end
- msg = InsertTrajectoryMsg(1)
- put!(wp, msg)
- @test take!(target) === msg
-
- for w in workers
- put!(wp, FetchParamMsg(w))
+ @testset "WorkerProxy" begin
+ target = RemoteChannel(() -> Channel(10))
+ workers = [RemoteChannel(() -> Channel(10)) for _ in 1:10]
+ _wp = WorkerProxy(workers)
+ wp = actor(_wp)
+
+ put!(wp, StartMsg(target))
+ for w in workers
+ # @test take!(w).args[1] === wp
+ @test Distributed.channel_from_id(remoteref_id(take!(w).args[1])) ===
+ Distributed.channel_from_id(remoteref_id(wp))
+ end
+
+ msg = InsertTrajectoryMsg(1)
+ put!(wp, msg)
+ @test take!(target) === msg
+
+ for w in workers
+ put!(wp, FetchParamMsg(w))
+ end
+ # @test take!(target).from === wp
+ @test Distributed.channel_from_id(remoteref_id(take!(target).from)) ===
+ Distributed.channel_from_id(remoteref_id(wp))
+
+ # make sure target only received one FetchParamMsg
+ msg = PingMsg()
+ put!(target, msg)
+ @test take!(target) === msg
+
+ msg = LoadParamMsg([])
+ put!(wp, msg)
+ for w in workers
+ @test take!(w) === msg
+ end
end
- # @test take!(target).from === wp
- @test Distributed.channel_from_id(remoteref_id(take!(target).from)) === Distributed.channel_from_id(remoteref_id(wp))
-
- # make sure target only received one FetchParamMsg
- msg = PingMsg()
- put!(target, msg)
- @test take!(target) === msg
-
- msg = LoadParamMsg([])
- put!(wp, msg)
- for w in workers
- @test take!(w) === msg
+
+ @testset "Orchestrator" begin
+ # TODO
+ # Add an integration test
end
-end
-@testset "Orchestrator" begin
- # TODO
- # Add an integration test
end
-
-end
\ No newline at end of file
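The test above relies on a small message-flushing idiom that is easy to miss. A minimal sketch, using only the test-local names already present in this file (trajectory_proxy, traj, InsertTrajectoryMsg, PingMsg, self):

    # deepcopy before sending, so later push!(traj, ...) cannot mutate the proxy's copy
    put!(trajectory_proxy, InsertTrajectoryMsg(deepcopy(traj)))

    # ping/echo round trip: the proxy handles messages in order, so once the echo
    # arrives back in our own mailbox the insert above has definitely been processed
    put!(trajectory_proxy, PingMsg())
    take!(self())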
diff --git a/src/DistributedReinforcementLearning/test/runtests.jl b/src/DistributedReinforcementLearning/test/runtests.jl
index f9f572d55..a73cd0521 100644
--- a/src/DistributedReinforcementLearning/test/runtests.jl
+++ b/src/DistributedReinforcementLearning/test/runtests.jl
@@ -9,7 +9,7 @@ using Flux
@testset "DistributedReinforcementLearning.jl" begin
-include("actor.jl")
-include("core.jl")
+ include("actor.jl")
+ include("core.jl")
end
diff --git a/src/ReinforcementLearningBase/src/CommonRLInterface.jl b/src/ReinforcementLearningBase/src/CommonRLInterface.jl
index af86a2d83..28c73ec40 100644
--- a/src/ReinforcementLearningBase/src/CommonRLInterface.jl
+++ b/src/ReinforcementLearningBase/src/CommonRLInterface.jl
@@ -41,7 +41,8 @@ end
# !!! may need to be extended by user
CRL.@provide CRL.observe(env::CommonRLEnv) = state(env.env)
-CRL.provided(::typeof(CRL.state), env::CommonRLEnv) = !isnothing(find_state_style(env.env, InternalState))
+CRL.provided(::typeof(CRL.state), env::CommonRLEnv) =
+ !isnothing(find_state_style(env.env, InternalState))
CRL.state(env::CommonRLEnv) = state(env.env, find_state_style(env.env, InternalState))
CRL.@provide CRL.clone(env::CommonRLEnv) = CommonRLEnv(copy(env.env))
@@ -94,4 +95,4 @@ ActionStyle(env::RLBaseEnv) =
CRL.provided(CRL.valid_actions, env.env) ? FullActionSet() : MinimalActionSet()
current_player(env::RLBaseEnv) = CRL.player(env.env)
-players(env::RLBaseEnv) = CRL.players(env.env)
\ No newline at end of file
+players(env::RLBaseEnv) = CRL.players(env.env)
diff --git a/src/ReinforcementLearningBase/src/interface.jl b/src/ReinforcementLearningBase/src/interface.jl
index 8de38c832..ee199aee5 100644
--- a/src/ReinforcementLearningBase/src/interface.jl
+++ b/src/ReinforcementLearningBase/src/interface.jl
@@ -410,12 +410,13 @@ Make an independent copy of `env`,
!!! warning
Only check the state of all players in the env.
"""
-function Base.:(==)(env1::T, env2::T) where T<:AbstractEnv
+function Base.:(==)(env1::T, env2::T) where {T<:AbstractEnv}
len = length(players(env1))
- len == length(players(env2)) &&
- all(state(env1, player) == state(env2, player) for player in players(env1))
+ len == length(players(env2)) &&
+ all(state(env1, player) == state(env2, player) for player in players(env1))
end
-Base.hash(env::AbstractEnv, h::UInt) = hash([state(env, player) for player in players(env)], h)
+Base.hash(env::AbstractEnv, h::UInt) =
+ hash([state(env, player) for player in players(env)], h)
@api nameof(env::AbstractEnv) = nameof(typeof(env))
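The `==`/`hash` pair above compares environments purely by their per-player states, consistent with the copy docstring referenced in the hunk header. A hedged sketch, assuming an RLBase-compatible `env` with a discrete action space:

    env2 = copy(env)                      # independent copy with the same state
    env == env2                           # true: every player's state matches
    hash(env) == hash(env2)               # true: hash is built from the same states
    env2(first(action_space(env2)))       # advance only the copy
    env == env2                           # typically false once the states diverge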
diff --git a/src/ReinforcementLearningBase/test/CommonRLInterface.jl b/src/ReinforcementLearningBase/test/CommonRLInterface.jl
index fc38a102b..7b32dbe08 100644
--- a/src/ReinforcementLearningBase/test/CommonRLInterface.jl
+++ b/src/ReinforcementLearningBase/test/CommonRLInterface.jl
@@ -1,34 +1,34 @@
@testset "CommonRLInterface" begin
-@testset "MDPEnv" begin
- struct RLTestMDP <: MDP{Int, Int} end
+ @testset "MDPEnv" begin
+ struct RLTestMDP <: MDP{Int,Int} end
- POMDPs.actions(m::RLTestMDP) = [-1, 1]
- POMDPs.transition(m::RLTestMDP, s, a) = Deterministic(clamp(s + a, 1, 3))
- POMDPs.initialstate(m::RLTestMDP) = Deterministic(1)
- POMDPs.isterminal(m::RLTestMDP, s) = s == 3
- POMDPs.reward(m::RLTestMDP, s, a, sp) = sp
- POMDPs.states(m::RLTestMDP) = 1:3
+ POMDPs.actions(m::RLTestMDP) = [-1, 1]
+ POMDPs.transition(m::RLTestMDP, s, a) = Deterministic(clamp(s + a, 1, 3))
+ POMDPs.initialstate(m::RLTestMDP) = Deterministic(1)
+ POMDPs.isterminal(m::RLTestMDP, s) = s == 3
+ POMDPs.reward(m::RLTestMDP, s, a, sp) = sp
+ POMDPs.states(m::RLTestMDP) = 1:3
- env = convert(RLBase.AbstractEnv, convert(CRL.AbstractEnv, RLTestMDP()))
- RLBase.test_runnable!(env)
-end
+ env = convert(RLBase.AbstractEnv, convert(CRL.AbstractEnv, RLTestMDP()))
+ RLBase.test_runnable!(env)
+ end
-@testset "POMDPEnv" begin
+ @testset "POMDPEnv" begin
- struct RLTestPOMDP <: POMDP{Int, Int, Int} end
+ struct RLTestPOMDP <: POMDP{Int,Int,Int} end
- POMDPs.actions(m::RLTestPOMDP) = [-1, 1]
- POMDPs.states(m::RLTestPOMDP) = 1:3
- POMDPs.transition(m::RLTestPOMDP, s, a) = Deterministic(clamp(s + a, 1, 3))
- POMDPs.observation(m::RLTestPOMDP, s, a, sp) = Deterministic(sp + 1)
- POMDPs.initialstate(m::RLTestPOMDP) = Deterministic(1)
- POMDPs.initialobs(m::RLTestPOMDP, s) = Deterministic(s + 1)
- POMDPs.isterminal(m::RLTestPOMDP, s) = s == 3
- POMDPs.reward(m::RLTestPOMDP, s, a, sp) = sp
- POMDPs.observations(m::RLTestPOMDP) = 2:4
+ POMDPs.actions(m::RLTestPOMDP) = [-1, 1]
+ POMDPs.states(m::RLTestPOMDP) = 1:3
+ POMDPs.transition(m::RLTestPOMDP, s, a) = Deterministic(clamp(s + a, 1, 3))
+ POMDPs.observation(m::RLTestPOMDP, s, a, sp) = Deterministic(sp + 1)
+ POMDPs.initialstate(m::RLTestPOMDP) = Deterministic(1)
+ POMDPs.initialobs(m::RLTestPOMDP, s) = Deterministic(s + 1)
+ POMDPs.isterminal(m::RLTestPOMDP, s) = s == 3
+ POMDPs.reward(m::RLTestPOMDP, s, a, sp) = sp
+ POMDPs.observations(m::RLTestPOMDP) = 2:4
- env = convert(RLBase.AbstractEnv, convert(CRL.AbstractEnv, RLTestPOMDP()))
+ env = convert(RLBase.AbstractEnv, convert(CRL.AbstractEnv, RLTestPOMDP()))
- RLBase.test_runnable!(env)
+ RLBase.test_runnable!(env)
+ end
end
-end
\ No newline at end of file
diff --git a/src/ReinforcementLearningBase/test/runtests.jl b/src/ReinforcementLearningBase/test/runtests.jl
index 6b44a29b8..d4f743f68 100644
--- a/src/ReinforcementLearningBase/test/runtests.jl
+++ b/src/ReinforcementLearningBase/test/runtests.jl
@@ -8,5 +8,5 @@ using POMDPs
using POMDPModelTools: Deterministic
@testset "ReinforcementLearningBase" begin
-include("CommonRLInterface.jl")
-end
\ No newline at end of file
+ include("CommonRLInterface.jl")
+end
diff --git a/src/ReinforcementLearningCore/src/core/hooks.jl b/src/ReinforcementLearningCore/src/core/hooks.jl
index 11c65e686..a8a58be3e 100644
--- a/src/ReinforcementLearningCore/src/core/hooks.jl
+++ b/src/ReinforcementLearningCore/src/core/hooks.jl
@@ -13,7 +13,7 @@ export AbstractHook,
UploadTrajectoryEveryNStep,
MultiAgentHook
-using UnicodePlots:lineplot, lineplot!
+using UnicodePlots: lineplot, lineplot!
using Statistics
"""
@@ -155,7 +155,14 @@ end
function (hook::TotalRewardPerEpisode)(::PostExperimentStage, agent, env)
if hook.is_display_on_exit
- println(lineplot(hook.rewards, title="Total reward per episode", xlabel="Episode", ylabel="Score"))
+ println(
+ lineplot(
+ hook.rewards,
+ title = "Total reward per episode",
+ xlabel = "Episode",
+ ylabel = "Score",
+ ),
+ )
end
end
@@ -178,8 +185,12 @@ which return a `Vector` of rewards (a typical case with `MultiThreadEnv`).
If `is_display_on_exit` is set to `true`, a ribbon plot will be shown to reflect
the mean and std of rewards.
"""
-function TotalBatchRewardPerEpisode(batch_size::Int; is_display_on_exit=true)
- TotalBatchRewardPerEpisode([Float64[] for _ in 1:batch_size], zeros(batch_size), is_display_on_exit)
+function TotalBatchRewardPerEpisode(batch_size::Int; is_display_on_exit = true)
+ TotalBatchRewardPerEpisode(
+ [Float64[] for _ in 1:batch_size],
+ zeros(batch_size),
+ is_display_on_exit,
+ )
end
function (hook::TotalBatchRewardPerEpisode)(::PostActStage, agent, env)
@@ -198,7 +209,12 @@ function (hook::TotalBatchRewardPerEpisode)(::PostExperimentStage, agent, env)
n = minimum(map(length, hook.rewards))
m = mean([@view(x[1:n]) for x in hook.rewards])
s = std([@view(x[1:n]) for x in hook.rewards])
- p = lineplot(m, title="Avg total reward per episode", xlabel="Episode", ylabel="Score")
+ p = lineplot(
+ m,
+ title = "Avg total reward per episode",
+ xlabel = "Episode",
+ ylabel = "Score",
+ )
lineplot!(p, m .- s)
lineplot!(p, m .+ s)
println(p)
@@ -288,8 +304,7 @@ end
Execute `f(t, agent, env)` every `n` episodes.
`t` is a counter of episodes.
"""
-mutable struct DoEveryNEpisode{S<:Union{PreEpisodeStage,PostEpisodeStage},F} <:
- AbstractHook
+mutable struct DoEveryNEpisode{S<:Union{PreEpisodeStage,PostEpisodeStage},F} <: AbstractHook
f::F
n::Int
t::Int
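For context, the hooks reformatted above are used with the standard `run` loop; a minimal sketch, assuming `TotalRewardPerEpisode` exposes the keyword constructor implied by its fields:

    hook = TotalRewardPerEpisode(; is_display_on_exit = true)
    run(policy, env, StopAfterEpisode(100), hook)
    hook.rewards   # one total reward per finished episode
    # on exit, the lineplot call shown above prints a UnicodePlots chart of hook.rewards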
diff --git a/src/ReinforcementLearningCore/src/extensions/ArrayInterface.jl b/src/ReinforcementLearningCore/src/extensions/ArrayInterface.jl
index d641c615b..507907b8a 100644
--- a/src/ReinforcementLearningCore/src/extensions/ArrayInterface.jl
+++ b/src/ReinforcementLearningCore/src/extensions/ArrayInterface.jl
@@ -1,7 +1,10 @@
using ArrayInterface
-function ArrayInterface.restructure(x::AbstractArray{T1, 0}, y::AbstractArray{T2, 0}) where {T1, T2}
+function ArrayInterface.restructure(
+ x::AbstractArray{T1,0},
+ y::AbstractArray{T2,0},
+) where {T1,T2}
out = similar(x, eltype(y))
out .= y
out
-end
\ No newline at end of file
+end
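The zero-dimensional `restructure` method above simply rebuilds `x`'s (empty) shape with `y`'s element type and value; a small illustrative sketch:

    using ArrayInterface

    x = fill(1)                            # 0-dimensional Int array
    y = fill(2.5)                          # 0-dimensional Float64 array
    z = ArrayInterface.restructure(x, y)   # similar(x, Float64) filled with y
    z[] == 2.5 && eltype(z) == Float64     # true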
diff --git a/src/ReinforcementLearningCore/src/policies/agents/trajectories/trajectory_extension.jl b/src/ReinforcementLearningCore/src/policies/agents/trajectories/trajectory_extension.jl
index a5d143530..8dfb73c12 100644
--- a/src/ReinforcementLearningCore/src/policies/agents/trajectories/trajectory_extension.jl
+++ b/src/ReinforcementLearningCore/src/policies/agents/trajectories/trajectory_extension.jl
@@ -140,7 +140,7 @@ end
function fetch!(
sampler::NStepBatchSampler{traces},
- traj::Union{CircularArraySARTTrajectory, CircularArraySLARTTrajectory},
+ traj::Union{CircularArraySARTTrajectory,CircularArraySLARTTrajectory},
inds::Vector{Int},
) where {traces}
γ, n, bz, sz = sampler.γ, sampler.n, sampler.batch_size, sampler.stack_size
diff --git a/src/ReinforcementLearningCore/src/policies/q_based_policies/learners/approximators/neural_network_approximator.jl b/src/ReinforcementLearningCore/src/policies/q_based_policies/learners/approximators/neural_network_approximator.jl
index 43d70a735..5d9d6a743 100644
--- a/src/ReinforcementLearningCore/src/policies/q_based_policies/learners/approximators/neural_network_approximator.jl
+++ b/src/ReinforcementLearningCore/src/policies/q_based_policies/learners/approximators/neural_network_approximator.jl
@@ -89,9 +89,14 @@ This function is compatible with a multidimensional action space. When outputtin
- `is_sampling::Bool=false`, whether to sample from the obtained normal distribution.
- `is_return_log_prob`, whether to calculate the conditional probability of getting actions in the given state.
"""
-function (model::GaussianNetwork)(rng::AbstractRNG, state; is_sampling::Bool=false, is_return_log_prob::Bool=false)
+function (model::GaussianNetwork)(
+ rng::AbstractRNG,
+ state;
+ is_sampling::Bool = false,
+ is_return_log_prob::Bool = false,
+)
x = model.pre(state)
- μ, logσ = model.μ(x), model.logσ(x)
+ μ, logσ = model.μ(x), model.logσ(x)
if is_sampling
π_dist = Normal.(μ, exp.(logσ))
z = rand.(rng, π_dist)
@@ -107,8 +112,17 @@ function (model::GaussianNetwork)(rng::AbstractRNG, state; is_sampling::Bool=fal
end
end
-function (model::GaussianNetwork)(state; is_sampling::Bool=false, is_return_log_prob::Bool=false)
- model(Random.GLOBAL_RNG, state; is_sampling=is_sampling, is_return_log_prob=is_return_log_prob)
+function (model::GaussianNetwork)(
+ state;
+ is_sampling::Bool = false,
+ is_return_log_prob::Bool = false,
+)
+ model(
+ Random.GLOBAL_RNG,
+ state;
+ is_sampling = is_sampling,
+ is_return_log_prob = is_return_log_prob,
+ )
end
#####
@@ -131,5 +145,5 @@ Flux.@functor DuelingNetwork
function (m::DuelingNetwork)(state)
x = m.base(state)
val = m.val(x)
- return val .+ m.adv(x) .- mean(m.adv(x), dims=1)
-end
\ No newline at end of file
+ return val .+ m.adv(x) .- mean(m.adv(x), dims = 1)
+end
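A hedged construction sketch for the `GaussianNetwork` methods reformatted above. The field names `pre`, `μ`, and `logσ` are taken from the method body; the keyword constructor and the exact return values are assumptions, not verified against the struct definition:

    using Flux, Random

    gn = GaussianNetwork(
        pre = Dense(4, 64, relu),    # shared trunk (assumed field layout)
        μ = Dense(64, 2),            # mean head
        logσ = Dense(64, 2),         # log-std head
    )

    s = rand(Float32, 4, 1)
    # sample an action and (per the docstring above) its conditional log-probability
    a, logp = gn(Random.GLOBAL_RNG, s; is_sampling = true, is_return_log_prob = true)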
diff --git a/src/ReinforcementLearningCore/test/components/trajectories.jl b/src/ReinforcementLearningCore/test/components/trajectories.jl
index 8fb8eb9ae..c7c60e163 100644
--- a/src/ReinforcementLearningCore/test/components/trajectories.jl
+++ b/src/ReinforcementLearningCore/test/components/trajectories.jl
@@ -52,9 +52,9 @@
t = CircularArraySLARTTrajectory(
capacity = 3,
state = Vector{Int} => (4,),
- legal_actions_mask = Vector{Bool} => (4, ),
+ legal_actions_mask = Vector{Bool} => (4,),
)
-
+
# test instance type is same as type
@test isa(t, CircularArraySLARTTrajectory)
diff --git a/src/ReinforcementLearningCore/test/core/core.jl b/src/ReinforcementLearningCore/test/core/core.jl
index fb6f5fde5..56a81b809 100644
--- a/src/ReinforcementLearningCore/test/core/core.jl
+++ b/src/ReinforcementLearningCore/test/core/core.jl
@@ -1,22 +1,18 @@
@testset "simple workflow" begin
- env = StateTransformedEnv(CartPoleEnv{Float32}();state_mapping=deepcopy)
+ env = StateTransformedEnv(CartPoleEnv{Float32}(); state_mapping = deepcopy)
policy = RandomPolicy(action_space(env))
N_EPISODE = 10_000
hook = TotalRewardPerEpisode()
run(policy, env, StopAfterEpisode(N_EPISODE), hook)
- @test isapprox(sum(hook[]) / N_EPISODE, 21; atol=2)
+ @test isapprox(sum(hook[]) / N_EPISODE, 21; atol = 2)
end
@testset "multi agent" begin
# https://github.com/JuliaReinforcementLearning/ReinforcementLearning.jl/issues/393
rps = RockPaperScissorsEnv() |> SequentialEnv
- ma_policy = MultiAgentManager(
- (
- NamedPolicy(p => RandomPolicy())
- for p in players(rps)
- )...
- )
+ ma_policy =
+ MultiAgentManager((NamedPolicy(p => RandomPolicy()) for p in players(rps))...)
run(ma_policy, rps, StopAfterEpisode(10))
end
diff --git a/src/ReinforcementLearningCore/test/core/stop_conditions_test.jl b/src/ReinforcementLearningCore/test/core/stop_conditions_test.jl
index a2d657fa1..fc5b0bfc8 100644
--- a/src/ReinforcementLearningCore/test/core/stop_conditions_test.jl
+++ b/src/ReinforcementLearningCore/test/core/stop_conditions_test.jl
@@ -1,5 +1,5 @@
@testset "test StopAfterNoImprovement" begin
- env = StateTransformedEnv(CartPoleEnv{Float32}();state_mapping=deepcopy)
+ env = StateTransformedEnv(CartPoleEnv{Float32}(); state_mapping = deepcopy)
policy = RandomPolicy(action_space(env))
total_reward_per_episode = TotalRewardPerEpisode()
@@ -14,7 +14,8 @@
hook = ComposedHook(total_reward_per_episode)
run(policy, env, stop_condition, hook)
- @test argmax(total_reward_per_episode.rewards) + patience == length(total_reward_per_episode.rewards)
+ @test argmax(total_reward_per_episode.rewards) + patience ==
+ length(total_reward_per_episode.rewards)
end
@testset "StopAfterNSeconds" begin
diff --git a/src/ReinforcementLearningDatasets/src/ReinforcementLearningDatasets.jl b/src/ReinforcementLearningDatasets/src/ReinforcementLearningDatasets.jl
index 55f721c89..1d8b67f7b 100644
--- a/src/ReinforcementLearningDatasets/src/ReinforcementLearningDatasets.jl
+++ b/src/ReinforcementLearningDatasets/src/ReinforcementLearningDatasets.jl
@@ -8,4 +8,4 @@ using DataDeps
include("d4rl/register.jl")
include("d4rl/d4rl_dataset.jl")
-end
\ No newline at end of file
+end
diff --git a/src/ReinforcementLearningDatasets/src/d4rl/d4rl_dataset.jl b/src/ReinforcementLearningDatasets/src/d4rl/d4rl_dataset.jl
index 6557edce3..4c772b349 100644
--- a/src/ReinforcementLearningDatasets/src/d4rl/d4rl_dataset.jl
+++ b/src/ReinforcementLearningDatasets/src/d4rl/d4rl_dataset.jl
@@ -24,7 +24,7 @@ Represents an iterable dataset from d4rl with the following fields:
`is_shuffle`: Bool, determines if the batches returned by `iterate` are shuffled.
"""
struct D4RLDataSet{T<:AbstractRNG}
- dataset::Dict{Symbol, Any}
+ dataset::Dict{Symbol,Any}
size::Integer
batch_size::Integer
style::Tuple
@@ -51,40 +51,44 @@ The dataset type is an iterable that fetches batches when used in a for loop for
The returned type is an infinite iterator which can be called using `iterate` and will return batches as specified in the dataset.
"""
-function dataset(dataset::String;
- style=SARTS,
- rng = StableRNG(123),
- is_shuffle = true,
- batch_size=256
+function dataset(
+ dataset::String;
+ style = SARTS,
+ rng = StableRNG(123),
+ is_shuffle = true,
+ batch_size = 256,
)
-
- try
- @datadep_str "d4rl-"*dataset
- catch
- throw("The provided dataset is not available")
+
+ try
+ @datadep_str "d4rl-" * dataset
+ catch
+ throw("The provided dataset is not available")
end
-
- path = @datadep_str "d4rl-"*dataset
+
+ path = @datadep_str "d4rl-" * dataset
@assert length(readdir(path)) == 1
file_name = readdir(path)[1]
-
- data = h5open(path*"/"*file_name, "r") do file
+
+ data = h5open(path * "/" * file_name, "r") do file
read(file)
end
# sanity checks on data
verify(data)
- dataset = Dict{Symbol, Any}()
- meta = Dict{String, Any}()
+ dataset = Dict{Symbol,Any}()
+ meta = Dict{String,Any}()
N_samples = size(data["terminals"])[1]
-
- for (key, d_key) in zip(["observations", "actions", "rewards", "terminals"], Symbol.(["state", "action", "reward", "terminal"]))
- dataset[d_key] = data[key]
+
+ for (key, d_key) in zip(
+ ["observations", "actions", "rewards", "terminals"],
+ Symbol.(["state", "action", "reward", "terminal"]),
+ )
+ dataset[d_key] = data[key]
end
-
+
for key in keys(data)
if !(key in ["observations", "actions", "rewards", "terminals"])
meta[key] = data[key]
@@ -104,9 +108,13 @@ function iterate(ds::D4RLDataSet, state = 0)
if is_shuffle
inds = rand(rng, 1:size, batch_size)
- map((x)-> if x <= size x else 1 end, inds)
+ map((x) -> if x <= size
+ x
+ else
+ 1
+ end, inds)
else
- if (state+1) * batch_size <= size
+ if (state + 1) * batch_size <= size
inds = state*batch_size+1:(state+1)*batch_size
else
return nothing
@@ -114,15 +122,17 @@ function iterate(ds::D4RLDataSet, state = 0)
state += 1
end
- batch = (state = copy(ds.dataset[:state][:, inds]),
- action = copy(ds.dataset[:action][:, inds]),
- reward = copy(ds.dataset[:reward][inds]),
- terminal = copy(ds.dataset[:terminal][inds]))
+ batch = (
+ state = copy(ds.dataset[:state][:, inds]),
+ action = copy(ds.dataset[:action][:, inds]),
+ reward = copy(ds.dataset[:reward][inds]),
+ terminal = copy(ds.dataset[:terminal][inds]),
+ )
if style == SARTS
batch = merge(batch, (next_state = copy(ds.dataset[:state][:, (1).+(inds)]),))
end
-
+
return batch, state
end
@@ -132,11 +142,11 @@ length(ds::D4RLDataSet) = ds.size
IteratorEltype(::Type{D4RLDataSet}) = EltypeUnknown() # see if eltype can be known (not sure about carla and adroit)
-function verify(data::Dict{String, Any})
+function verify(data::Dict{String,Any})
for key in ["observations", "actions", "rewards", "terminals"]
@assert (key in keys(data)) "Expected keys not present in data"
end
N_samples = size(data["observations"])[2]
@assert size(data["rewards"]) == (N_samples,)
@assert size(data["terminals"]) == (N_samples,)
-end
\ No newline at end of file
+end
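A usage sketch for the reformatted `dataset` entry point, assuming the corresponding DataDep can be resolved locally (the name must be one of the keys registered in register.jl, the next file in this diff):

    using ReinforcementLearningDatasets

    ds = dataset("hopper-medium-v0"; style = SARTS, batch_size = 64, is_shuffle = true)
    batch, _ = iterate(ds)          # NamedTuple with :state, :action, :reward, :terminal (+ :next_state for SARTS)
    size(batch.reward) == (64,)     # rewards come back as a length-batch_size vector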
diff --git a/src/ReinforcementLearningDatasets/src/d4rl/register.jl b/src/ReinforcementLearningDatasets/src/d4rl/register.jl
index 0ca3aaa39..0e6ead74a 100644
--- a/src/ReinforcementLearningDatasets/src/d4rl/register.jl
+++ b/src/ReinforcementLearningDatasets/src/d4rl/register.jl
@@ -7,265 +7,354 @@ This file holds the registration information for d4rl datasets.
It also registers the information in DataDeps for further use in this package.
"""
-const DATASET_URLS = Dict{String, String}(
- "maze2d-open-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-open-sparse.hdf5",
- "maze2d-umaze-v1" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-umaze-sparse-v1.hdf5",
- "maze2d-medium-v1" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-medium-sparse-v1.hdf5",
- "maze2d-large-v1" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-large-sparse-v1.hdf5",
- "maze2d-eval-umaze-v1" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-eval-umaze-sparse-v1.hdf5",
- "maze2d-eval-medium-v1" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-eval-medium-sparse-v1.hdf5",
- "maze2d-eval-large-v1" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-eval-large-sparse-v1.hdf5",
- "maze2d-open-dense-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-open-dense.hdf5",
- "maze2d-umaze-dense-v1" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-umaze-dense-v1.hdf5",
- "maze2d-medium-dense-v1" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-medium-dense-v1.hdf5",
- "maze2d-large-dense-v1" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-large-dense-v1.hdf5",
- "maze2d-eval-umaze-dense-v1" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-eval-umaze-dense-v1.hdf5",
- "maze2d-eval-medium-dense-v1" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-eval-medium-dense-v1.hdf5",
- "maze2d-eval-large-dense-v1" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-eval-large-dense-v1.hdf5",
- "minigrid-fourrooms-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/minigrid/minigrid4rooms.hdf5",
- "minigrid-fourrooms-random-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/minigrid/minigrid4rooms_random.hdf5",
- "pen-human-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/pen-v0_demos_clipped.hdf5",
- "pen-cloned-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/pen-demos-v0-bc-combined.hdf5",
- "pen-expert-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/pen-v0_expert_clipped.hdf5",
- "hammer-human-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/hammer-v0_demos_clipped.hdf5",
- "hammer-cloned-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/hammer-demos-v0-bc-combined.hdf5",
- "hammer-expert-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/hammer-v0_expert_clipped.hdf5",
- "relocate-human-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/relocate-v0_demos_clipped.hdf5",
- "relocate-cloned-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/relocate-demos-v0-bc-combined.hdf5",
- "relocate-expert-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/relocate-v0_expert_clipped.hdf5",
- "door-human-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/door-v0_demos_clipped.hdf5",
- "door-cloned-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/door-demos-v0-bc-combined.hdf5",
- "door-expert-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/door-v0_expert_clipped.hdf5",
- "halfcheetah-random-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/halfcheetah_random.hdf5",
- "halfcheetah-medium-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/halfcheetah_medium.hdf5",
- "halfcheetah-expert-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/halfcheetah_expert.hdf5",
- "halfcheetah-medium-replay-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/halfcheetah_mixed.hdf5",
- "halfcheetah-medium-expert-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/halfcheetah_medium_expert.hdf5",
- "walker2d-random-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/walker2d_random.hdf5",
- "walker2d-medium-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/walker2d_medium.hdf5",
- "walker2d-expert-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/walker2d_expert.hdf5",
- "walker2d-medium-replay-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/walker_mixed.hdf5",
- "walker2d-medium-expert-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/walker2d_medium_expert.hdf5",
- "hopper-random-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/hopper_random.hdf5",
- "hopper-medium-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/hopper_medium.hdf5",
- "hopper-expert-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/hopper_expert.hdf5",
- "hopper-medium-replay-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/hopper_mixed.hdf5",
- "hopper-medium-expert-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/hopper_medium_expert.hdf5",
- "ant-random-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/ant_random.hdf5",
- "ant-medium-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/ant_medium.hdf5",
- "ant-expert-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/ant_expert.hdf5",
- "ant-medium-replay-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/ant_mixed.hdf5",
- "ant-medium-expert-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/ant_medium_expert.hdf5",
- "ant-random-expert-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/ant_random_expert.hdf5",
- "antmaze-umaze-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_u-maze_noisy_multistart_False_multigoal_False_sparse.hdf5",
- "antmaze-umaze-diverse-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_u-maze_noisy_multistart_True_multigoal_True_sparse.hdf5",
- "antmaze-medium-play-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_big-maze_noisy_multistart_True_multigoal_False_sparse.hdf5",
- "antmaze-medium-diverse-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_big-maze_noisy_multistart_True_multigoal_True_sparse.hdf5",
- "antmaze-large-play-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_hardest-maze_noisy_multistart_True_multigoal_False_sparse.hdf5",
- "antmaze-large-diverse-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_hardest-maze_noisy_multistart_True_multigoal_True_sparse.hdf5",
- "flow-ring-random-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/flow/flow-ring-v0-random.hdf5",
- "flow-ring-controller-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/flow/flow-ring-v0-idm.hdf5",
- "flow-merge-random-v0"=>"http://rail.eecs.berkeley.edu/datasets/offline_rl/flow/flow-merge-v0-random.hdf5",
- "flow-merge-controller-v0"=>"http://rail.eecs.berkeley.edu/datasets/offline_rl/flow/flow-merge-v0-idm.hdf5",
- "kitchen-complete-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/kitchen/mini_kitchen_microwave_kettle_light_slider-v0.hdf5",
- "kitchen-partial-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/kitchen/kitchen_microwave_kettle_light_slider-v0.hdf5",
- "kitchen-mixed-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/kitchen/kitchen_microwave_kettle_bottomburner_light-v0.hdf5",
- "carla-lane-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/carla/carla_lane_follow_flat-v0.hdf5",
- "carla-town-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/carla/carla_town_subsamp_flat-v0.hdf5",
- "carla-town-full-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/carla/carla_town_flat-v0.hdf5",
- "bullet-halfcheetah-random-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-halfcheetah_random.hdf5",
- "bullet-halfcheetah-medium-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-halfcheetah_medium.hdf5",
- "bullet-halfcheetah-expert-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-halfcheetah_expert.hdf5",
- "bullet-halfcheetah-medium-expert-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-halfcheetah_medium_expert.hdf5",
- "bullet-halfcheetah-medium-replay-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-halfcheetah_medium_replay.hdf5",
- "bullet-hopper-random-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-hopper_random.hdf5",
- "bullet-hopper-medium-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-hopper_medium.hdf5",
- "bullet-hopper-expert-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-hopper_expert.hdf5",
- "bullet-hopper-medium-expert-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-hopper_medium_expert.hdf5",
- "bullet-hopper-medium-replay-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-hopper_medium_replay.hdf5",
- "bullet-ant-random-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-ant_random.hdf5",
- "bullet-ant-medium-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-ant_medium.hdf5",
- "bullet-ant-expert-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-ant_expert.hdf5",
- "bullet-ant-medium-expert-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-ant_medium_expert.hdf5",
- "bullet-ant-medium-replay-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-ant_medium_replay.hdf5",
- "bullet-walker2d-random-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-walker2d_random.hdf5",
- "bullet-walker2d-medium-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-walker2d_medium.hdf5",
- "bullet-walker2d-expert-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-walker2d_expert.hdf5",
- "bullet-walker2d-medium-expert-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-walker2d_medium_expert.hdf5",
- "bullet-walker2d-medium-replay-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-walker2d_medium_replay.hdf5",
- "bullet-maze2d-open-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-maze2d-open-sparse.hdf5",
- "bullet-maze2d-umaze-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-maze2d-umaze-sparse.hdf5",
- "bullet-maze2d-medium-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-maze2d-medium-sparse.hdf5",
- "bullet-maze2d-large-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-maze2d-large-sparse.hdf5",
+const DATASET_URLS = Dict{String,String}(
+ "maze2d-open-v0" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-open-sparse.hdf5",
+ "maze2d-umaze-v1" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-umaze-sparse-v1.hdf5",
+ "maze2d-medium-v1" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-medium-sparse-v1.hdf5",
+ "maze2d-large-v1" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-large-sparse-v1.hdf5",
+ "maze2d-eval-umaze-v1" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-eval-umaze-sparse-v1.hdf5",
+ "maze2d-eval-medium-v1" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-eval-medium-sparse-v1.hdf5",
+ "maze2d-eval-large-v1" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-eval-large-sparse-v1.hdf5",
+ "maze2d-open-dense-v0" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-open-dense.hdf5",
+ "maze2d-umaze-dense-v1" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-umaze-dense-v1.hdf5",
+ "maze2d-medium-dense-v1" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-medium-dense-v1.hdf5",
+ "maze2d-large-dense-v1" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-large-dense-v1.hdf5",
+ "maze2d-eval-umaze-dense-v1" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-eval-umaze-dense-v1.hdf5",
+ "maze2d-eval-medium-dense-v1" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-eval-medium-dense-v1.hdf5",
+ "maze2d-eval-large-dense-v1" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-eval-large-dense-v1.hdf5",
+ "minigrid-fourrooms-v0" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/minigrid/minigrid4rooms.hdf5",
+ "minigrid-fourrooms-random-v0" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/minigrid/minigrid4rooms_random.hdf5",
+ "pen-human-v0" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/pen-v0_demos_clipped.hdf5",
+ "pen-cloned-v0" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/pen-demos-v0-bc-combined.hdf5",
+ "pen-expert-v0" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/pen-v0_expert_clipped.hdf5",
+ "hammer-human-v0" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/hammer-v0_demos_clipped.hdf5",
+ "hammer-cloned-v0" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/hammer-demos-v0-bc-combined.hdf5",
+ "hammer-expert-v0" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/hammer-v0_expert_clipped.hdf5",
+ "relocate-human-v0" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/relocate-v0_demos_clipped.hdf5",
+ "relocate-cloned-v0" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/relocate-demos-v0-bc-combined.hdf5",
+ "relocate-expert-v0" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/relocate-v0_expert_clipped.hdf5",
+ "door-human-v0" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/door-v0_demos_clipped.hdf5",
+ "door-cloned-v0" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/door-demos-v0-bc-combined.hdf5",
+ "door-expert-v0" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/hand_dapg/door-v0_expert_clipped.hdf5",
+ "halfcheetah-random-v0" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/halfcheetah_random.hdf5",
+ "halfcheetah-medium-v0" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/halfcheetah_medium.hdf5",
+ "halfcheetah-expert-v0" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/halfcheetah_expert.hdf5",
+ "halfcheetah-medium-replay-v0" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/halfcheetah_mixed.hdf5",
+ "halfcheetah-medium-expert-v0" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/halfcheetah_medium_expert.hdf5",
+ "walker2d-random-v0" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/walker2d_random.hdf5",
+ "walker2d-medium-v0" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/walker2d_medium.hdf5",
+ "walker2d-expert-v0" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/walker2d_expert.hdf5",
+ "walker2d-medium-replay-v0" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/walker_mixed.hdf5",
+ "walker2d-medium-expert-v0" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/walker2d_medium_expert.hdf5",
+ "hopper-random-v0" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/hopper_random.hdf5",
+ "hopper-medium-v0" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/hopper_medium.hdf5",
+ "hopper-expert-v0" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/hopper_expert.hdf5",
+ "hopper-medium-replay-v0" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/hopper_mixed.hdf5",
+ "hopper-medium-expert-v0" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/hopper_medium_expert.hdf5",
+ "ant-random-v0" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/ant_random.hdf5",
+ "ant-medium-v0" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/ant_medium.hdf5",
+ "ant-expert-v0" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/ant_expert.hdf5",
+ "ant-medium-replay-v0" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/ant_mixed.hdf5",
+ "ant-medium-expert-v0" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/ant_medium_expert.hdf5",
+ "ant-random-expert-v0" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/gym_mujoco/ant_random_expert.hdf5",
+ "antmaze-umaze-v0" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_u-maze_noisy_multistart_False_multigoal_False_sparse.hdf5",
+ "antmaze-umaze-diverse-v0" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_u-maze_noisy_multistart_True_multigoal_True_sparse.hdf5",
+ "antmaze-medium-play-v0" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_big-maze_noisy_multistart_True_multigoal_False_sparse.hdf5",
+ "antmaze-medium-diverse-v0" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_big-maze_noisy_multistart_True_multigoal_True_sparse.hdf5",
+ "antmaze-large-play-v0" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_hardest-maze_noisy_multistart_True_multigoal_False_sparse.hdf5",
+ "antmaze-large-diverse-v0" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_hardest-maze_noisy_multistart_True_multigoal_True_sparse.hdf5",
+ "flow-ring-random-v0" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/flow/flow-ring-v0-random.hdf5",
+ "flow-ring-controller-v0" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/flow/flow-ring-v0-idm.hdf5",
+ "flow-merge-random-v0" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/flow/flow-merge-v0-random.hdf5",
+ "flow-merge-controller-v0" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/flow/flow-merge-v0-idm.hdf5",
+ "kitchen-complete-v0" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/kitchen/mini_kitchen_microwave_kettle_light_slider-v0.hdf5",
+ "kitchen-partial-v0" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/kitchen/kitchen_microwave_kettle_light_slider-v0.hdf5",
+ "kitchen-mixed-v0" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/kitchen/kitchen_microwave_kettle_bottomburner_light-v0.hdf5",
+ "carla-lane-v0" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/carla/carla_lane_follow_flat-v0.hdf5",
+ "carla-town-v0" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/carla/carla_town_subsamp_flat-v0.hdf5",
+ "carla-town-full-v0" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/carla/carla_town_flat-v0.hdf5",
+ "bullet-halfcheetah-random-v0" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-halfcheetah_random.hdf5",
+ "bullet-halfcheetah-medium-v0" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-halfcheetah_medium.hdf5",
+ "bullet-halfcheetah-expert-v0" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-halfcheetah_expert.hdf5",
+ "bullet-halfcheetah-medium-expert-v0" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-halfcheetah_medium_expert.hdf5",
+ "bullet-halfcheetah-medium-replay-v0" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-halfcheetah_medium_replay.hdf5",
+ "bullet-hopper-random-v0" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-hopper_random.hdf5",
+ "bullet-hopper-medium-v0" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-hopper_medium.hdf5",
+ "bullet-hopper-expert-v0" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-hopper_expert.hdf5",
+ "bullet-hopper-medium-expert-v0" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-hopper_medium_expert.hdf5",
+ "bullet-hopper-medium-replay-v0" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-hopper_medium_replay.hdf5",
+ "bullet-ant-random-v0" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-ant_random.hdf5",
+ "bullet-ant-medium-v0" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-ant_medium.hdf5",
+ "bullet-ant-expert-v0" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-ant_expert.hdf5",
+ "bullet-ant-medium-expert-v0" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-ant_medium_expert.hdf5",
+ "bullet-ant-medium-replay-v0" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-ant_medium_replay.hdf5",
+ "bullet-walker2d-random-v0" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-walker2d_random.hdf5",
+ "bullet-walker2d-medium-v0" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-walker2d_medium.hdf5",
+ "bullet-walker2d-expert-v0" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-walker2d_expert.hdf5",
+ "bullet-walker2d-medium-expert-v0" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-walker2d_medium_expert.hdf5",
+ "bullet-walker2d-medium-replay-v0" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-walker2d_medium_replay.hdf5",
+ "bullet-maze2d-open-v0" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-maze2d-open-sparse.hdf5",
+ "bullet-maze2d-umaze-v0" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-maze2d-umaze-sparse.hdf5",
+ "bullet-maze2d-medium-v0" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-maze2d-medium-sparse.hdf5",
+ "bullet-maze2d-large-v0" =>
+ "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-maze2d-large-sparse.hdf5",
)
-const REF_MIN_SCORE = Dict{String, Float32}(
- "maze2d-open-v0" => 0.01 ,
- "maze2d-umaze-v1" => 23.85 ,
- "maze2d-medium-v1" => 13.13 ,
- "maze2d-large-v1" => 6.7 ,
- "maze2d-open-dense-v0" => 11.17817 ,
- "maze2d-umaze-dense-v1" => 68.537689 ,
- "maze2d-medium-dense-v1" => 44.264742 ,
- "maze2d-large-dense-v1" => 30.569041 ,
- "minigrid-fourrooms-v0" => 0.01442 ,
- "minigrid-fourrooms-random-v0" => 0.01442 ,
- "pen-human-v0" => 96.262799 ,
- "pen-cloned-v0" => 96.262799 ,
- "pen-expert-v0" => 96.262799 ,
- "hammer-human-v0" => -274.856578 ,
- "hammer-cloned-v0" => -274.856578 ,
- "hammer-expert-v0" => -274.856578 ,
- "relocate-human-v0" => -6.425911 ,
- "relocate-cloned-v0" => -6.425911 ,
- "relocate-expert-v0" => -6.425911 ,
- "door-human-v0" => -56.512833 ,
- "door-cloned-v0" => -56.512833 ,
- "door-expert-v0" => -56.512833 ,
- "halfcheetah-random-v0" => -280.178953 ,
- "halfcheetah-medium-v0" => -280.178953 ,
- "halfcheetah-expert-v0" => -280.178953 ,
- "halfcheetah-medium-replay-v0" => -280.178953 ,
- "halfcheetah-medium-expert-v0" => -280.178953 ,
- "walker2d-random-v0" => 1.629008 ,
- "walker2d-medium-v0" => 1.629008 ,
- "walker2d-expert-v0" => 1.629008 ,
- "walker2d-medium-replay-v0" => 1.629008 ,
- "walker2d-medium-expert-v0" => 1.629008 ,
- "hopper-random-v0" => -20.272305 ,
- "hopper-medium-v0" => -20.272305 ,
- "hopper-expert-v0" => -20.272305 ,
- "hopper-medium-replay-v0" => -20.272305 ,
- "hopper-medium-expert-v0" => -20.272305 ,
+const REF_MIN_SCORE = Dict{String,Float32}(
+ "maze2d-open-v0" => 0.01,
+ "maze2d-umaze-v1" => 23.85,
+ "maze2d-medium-v1" => 13.13,
+ "maze2d-large-v1" => 6.7,
+ "maze2d-open-dense-v0" => 11.17817,
+ "maze2d-umaze-dense-v1" => 68.537689,
+ "maze2d-medium-dense-v1" => 44.264742,
+ "maze2d-large-dense-v1" => 30.569041,
+ "minigrid-fourrooms-v0" => 0.01442,
+ "minigrid-fourrooms-random-v0" => 0.01442,
+ "pen-human-v0" => 96.262799,
+ "pen-cloned-v0" => 96.262799,
+ "pen-expert-v0" => 96.262799,
+ "hammer-human-v0" => -274.856578,
+ "hammer-cloned-v0" => -274.856578,
+ "hammer-expert-v0" => -274.856578,
+ "relocate-human-v0" => -6.425911,
+ "relocate-cloned-v0" => -6.425911,
+ "relocate-expert-v0" => -6.425911,
+ "door-human-v0" => -56.512833,
+ "door-cloned-v0" => -56.512833,
+ "door-expert-v0" => -56.512833,
+ "halfcheetah-random-v0" => -280.178953,
+ "halfcheetah-medium-v0" => -280.178953,
+ "halfcheetah-expert-v0" => -280.178953,
+ "halfcheetah-medium-replay-v0" => -280.178953,
+ "halfcheetah-medium-expert-v0" => -280.178953,
+ "walker2d-random-v0" => 1.629008,
+ "walker2d-medium-v0" => 1.629008,
+ "walker2d-expert-v0" => 1.629008,
+ "walker2d-medium-replay-v0" => 1.629008,
+ "walker2d-medium-expert-v0" => 1.629008,
+ "hopper-random-v0" => -20.272305,
+ "hopper-medium-v0" => -20.272305,
+ "hopper-expert-v0" => -20.272305,
+ "hopper-medium-replay-v0" => -20.272305,
+ "hopper-medium-expert-v0" => -20.272305,
"ant-random-v0" => -325.6,
"ant-medium-v0" => -325.6,
"ant-expert-v0" => -325.6,
"ant-medium-replay-v0" => -325.6,
"ant-medium-expert-v0" => -325.6,
- "antmaze-umaze-v0" => 0.0 ,
- "antmaze-umaze-diverse-v0" => 0.0 ,
- "antmaze-medium-play-v0" => 0.0 ,
- "antmaze-medium-diverse-v0" => 0.0 ,
- "antmaze-large-play-v0" => 0.0 ,
- "antmaze-large-diverse-v0" => 0.0 ,
- "kitchen-complete-v0" => 0.0 ,
- "kitchen-partial-v0" => 0.0 ,
- "kitchen-mixed-v0" => 0.0 ,
- "flow-ring-random-v0" => -165.22 ,
- "flow-ring-controller-v0" => -165.22 ,
- "flow-merge-random-v0" => 118.67993 ,
- "flow-merge-controller-v0" => 118.67993 ,
- "carla-lane-v0"=> -0.8503839912088142,
- "carla-town-v0"=> -114.81579500772153, # random score
- "bullet-halfcheetah-random-v0"=> -1275.766996,
- "bullet-halfcheetah-medium-v0"=> -1275.766996,
- "bullet-halfcheetah-expert-v0"=> -1275.766996,
- "bullet-halfcheetah-medium-expert-v0"=> -1275.766996,
- "bullet-halfcheetah-medium-replay-v0"=> -1275.766996,
- "bullet-hopper-random-v0"=> 20.058972,
- "bullet-hopper-medium-v0"=> 20.058972,
- "bullet-hopper-expert-v0"=> 20.058972,
- "bullet-hopper-medium-expert-v0"=> 20.058972,
- "bullet-hopper-medium-replay-v0"=> 20.058972,
- "bullet-ant-random-v0"=> 373.705955,
- "bullet-ant-medium-v0"=> 373.705955,
- "bullet-ant-expert-v0"=> 373.705955,
- "bullet-ant-medium-expert-v0"=> 373.705955,
- "bullet-ant-medium-replay-v0"=> 373.705955,
- "bullet-walker2d-random-v0"=> 16.523877,
- "bullet-walker2d-medium-v0"=> 16.523877,
- "bullet-walker2d-expert-v0"=> 16.523877,
- "bullet-walker2d-medium-expert-v0"=> 16.523877,
- "bullet-walker2d-medium-replay-v0"=> 16.523877,
- "bullet-maze2d-open-v0"=> 8.750000,
- "bullet-maze2d-umaze-v0"=> 32.460000,
- "bullet-maze2d-medium-v0"=> 14.870000,
- "bullet-maze2d-large-v0"=> 1.820000,
+ "antmaze-umaze-v0" => 0.0,
+ "antmaze-umaze-diverse-v0" => 0.0,
+ "antmaze-medium-play-v0" => 0.0,
+ "antmaze-medium-diverse-v0" => 0.0,
+ "antmaze-large-play-v0" => 0.0,
+ "antmaze-large-diverse-v0" => 0.0,
+ "kitchen-complete-v0" => 0.0,
+ "kitchen-partial-v0" => 0.0,
+ "kitchen-mixed-v0" => 0.0,
+ "flow-ring-random-v0" => -165.22,
+ "flow-ring-controller-v0" => -165.22,
+ "flow-merge-random-v0" => 118.67993,
+ "flow-merge-controller-v0" => 118.67993,
+ "carla-lane-v0" => -0.8503839912088142,
+ "carla-town-v0" => -114.81579500772153, # random score
+ "bullet-halfcheetah-random-v0" => -1275.766996,
+ "bullet-halfcheetah-medium-v0" => -1275.766996,
+ "bullet-halfcheetah-expert-v0" => -1275.766996,
+ "bullet-halfcheetah-medium-expert-v0" => -1275.766996,
+ "bullet-halfcheetah-medium-replay-v0" => -1275.766996,
+ "bullet-hopper-random-v0" => 20.058972,
+ "bullet-hopper-medium-v0" => 20.058972,
+ "bullet-hopper-expert-v0" => 20.058972,
+ "bullet-hopper-medium-expert-v0" => 20.058972,
+ "bullet-hopper-medium-replay-v0" => 20.058972,
+ "bullet-ant-random-v0" => 373.705955,
+ "bullet-ant-medium-v0" => 373.705955,
+ "bullet-ant-expert-v0" => 373.705955,
+ "bullet-ant-medium-expert-v0" => 373.705955,
+ "bullet-ant-medium-replay-v0" => 373.705955,
+ "bullet-walker2d-random-v0" => 16.523877,
+ "bullet-walker2d-medium-v0" => 16.523877,
+ "bullet-walker2d-expert-v0" => 16.523877,
+ "bullet-walker2d-medium-expert-v0" => 16.523877,
+ "bullet-walker2d-medium-replay-v0" => 16.523877,
+ "bullet-maze2d-open-v0" => 8.750000,
+ "bullet-maze2d-umaze-v0" => 32.460000,
+ "bullet-maze2d-medium-v0" => 14.870000,
+ "bullet-maze2d-large-v0" => 1.820000,
)
-const REF_MAX_SCORE = Dict{String, Float32}(
- "maze2d-open-v0" => 20.66 ,
- "maze2d-umaze-v1" => 161.86 ,
- "maze2d-medium-v1" => 277.39 ,
- "maze2d-large-v1" => 273.99 ,
- "maze2d-open-dense-v0" => 27.166538620695782 ,
- "maze2d-umaze-dense-v1" => 193.66285642381482 ,
- "maze2d-medium-dense-v1" => 297.4552547777125 ,
- "maze2d-large-dense-v1" => 303.4857382709002 ,
- "minigrid-fourrooms-v0" => 2.89685 ,
- "minigrid-fourrooms-random-v0" => 2.89685 ,
- "pen-human-v0" => 3076.8331017826877 ,
- "pen-cloned-v0" => 3076.8331017826877 ,
- "pen-expert-v0" => 3076.8331017826877 ,
- "hammer-human-v0" => 12794.134825156867 ,
- "hammer-cloned-v0" => 12794.134825156867 ,
- "hammer-expert-v0" => 12794.134825156867 ,
- "relocate-human-v0" => 4233.877797728884 ,
- "relocate-cloned-v0" => 4233.877797728884 ,
- "relocate-expert-v0" => 4233.877797728884 ,
- "door-human-v0" => 2880.5693087298737 ,
- "door-cloned-v0" => 2880.5693087298737 ,
- "door-expert-v0" => 2880.5693087298737 ,
- "halfcheetah-random-v0" => 12135.0 ,
- "halfcheetah-medium-v0" => 12135.0 ,
- "halfcheetah-expert-v0" => 12135.0 ,
- "halfcheetah-medium-replay-v0" => 12135.0 ,
- "halfcheetah-medium-expert-v0" => 12135.0 ,
- "walker2d-random-v0" => 4592.3 ,
- "walker2d-medium-v0" => 4592.3 ,
- "walker2d-expert-v0" => 4592.3 ,
- "walker2d-medium-replay-v0" => 4592.3 ,
- "walker2d-medium-expert-v0" => 4592.3 ,
- "hopper-random-v0" => 3234.3 ,
- "hopper-medium-v0" => 3234.3 ,
- "hopper-expert-v0" => 3234.3 ,
- "hopper-medium-replay-v0" => 3234.3 ,
- "hopper-medium-expert-v0" => 3234.3 ,
+const REF_MAX_SCORE = Dict{String,Float32}(
+ "maze2d-open-v0" => 20.66,
+ "maze2d-umaze-v1" => 161.86,
+ "maze2d-medium-v1" => 277.39,
+ "maze2d-large-v1" => 273.99,
+ "maze2d-open-dense-v0" => 27.166538620695782,
+ "maze2d-umaze-dense-v1" => 193.66285642381482,
+ "maze2d-medium-dense-v1" => 297.4552547777125,
+ "maze2d-large-dense-v1" => 303.4857382709002,
+ "minigrid-fourrooms-v0" => 2.89685,
+ "minigrid-fourrooms-random-v0" => 2.89685,
+ "pen-human-v0" => 3076.8331017826877,
+ "pen-cloned-v0" => 3076.8331017826877,
+ "pen-expert-v0" => 3076.8331017826877,
+ "hammer-human-v0" => 12794.134825156867,
+ "hammer-cloned-v0" => 12794.134825156867,
+ "hammer-expert-v0" => 12794.134825156867,
+ "relocate-human-v0" => 4233.877797728884,
+ "relocate-cloned-v0" => 4233.877797728884,
+ "relocate-expert-v0" => 4233.877797728884,
+ "door-human-v0" => 2880.5693087298737,
+ "door-cloned-v0" => 2880.5693087298737,
+ "door-expert-v0" => 2880.5693087298737,
+ "halfcheetah-random-v0" => 12135.0,
+ "halfcheetah-medium-v0" => 12135.0,
+ "halfcheetah-expert-v0" => 12135.0,
+ "halfcheetah-medium-replay-v0" => 12135.0,
+ "halfcheetah-medium-expert-v0" => 12135.0,
+ "walker2d-random-v0" => 4592.3,
+ "walker2d-medium-v0" => 4592.3,
+ "walker2d-expert-v0" => 4592.3,
+ "walker2d-medium-replay-v0" => 4592.3,
+ "walker2d-medium-expert-v0" => 4592.3,
+ "hopper-random-v0" => 3234.3,
+ "hopper-medium-v0" => 3234.3,
+ "hopper-expert-v0" => 3234.3,
+ "hopper-medium-replay-v0" => 3234.3,
+ "hopper-medium-expert-v0" => 3234.3,
"ant-random-v0" => 3879.7,
"ant-medium-v0" => 3879.7,
"ant-expert-v0" => 3879.7,
"ant-medium-replay-v0" => 3879.7,
"ant-medium-expert-v0" => 3879.7,
- "antmaze-umaze-v0" => 1.0 ,
- "antmaze-umaze-diverse-v0" => 1.0 ,
- "antmaze-medium-play-v0" => 1.0 ,
- "antmaze-medium-diverse-v0" => 1.0 ,
- "antmaze-large-play-v0" => 1.0 ,
- "antmaze-large-diverse-v0" => 1.0 ,
- "kitchen-complete-v0" => 4.0 ,
- "kitchen-partial-v0" => 4.0 ,
- "kitchen-mixed-v0" => 4.0 ,
- "flow-ring-random-v0" => 24.42 ,
- "flow-ring-controller-v0" => 24.42 ,
- "flow-merge-random-v0" => 330.03179 ,
- "flow-merge-controller-v0" => 330.03179 ,
- "carla-lane-v0"=> 1023.5784385429523,
- "carla-town-v0"=> 2440.1772022247314, # avg dataset score
- "bullet-halfcheetah-random-v0"=> 2381.6725,
- "bullet-halfcheetah-medium-v0"=> 2381.6725,
- "bullet-halfcheetah-expert-v0"=> 2381.6725,
- "bullet-halfcheetah-medium-expert-v0"=> 2381.6725,
- "bullet-halfcheetah-medium-replay-v0"=> 2381.6725,
- "bullet-hopper-random-v0"=> 1441.8059623430963,
- "bullet-hopper-medium-v0"=> 1441.8059623430963,
- "bullet-hopper-expert-v0"=> 1441.8059623430963,
- "bullet-hopper-medium-expert-v0"=> 1441.8059623430963,
- "bullet-hopper-medium-replay-v0"=> 1441.8059623430963,
- "bullet-ant-random-v0"=> 2650.495,
- "bullet-ant-medium-v0"=> 2650.495,
- "bullet-ant-expert-v0"=> 2650.495,
- "bullet-ant-medium-expert-v0"=> 2650.495,
- "bullet-ant-medium-replay-v0"=> 2650.495,
- "bullet-walker2d-random-v0"=> 1623.6476303317536,
- "bullet-walker2d-medium-v0"=> 1623.6476303317536,
- "bullet-walker2d-expert-v0"=> 1623.6476303317536,
- "bullet-walker2d-medium-expert-v0"=> 1623.6476303317536,
- "bullet-walker2d-medium-replay-v0"=> 1623.6476303317536,
- "bullet-maze2d-open-v0"=> 64.15,
- "bullet-maze2d-umaze-v0"=> 153.99,
- "bullet-maze2d-medium-v0"=> 238.05,
- "bullet-maze2d-large-v0"=> 285.92,
+ "antmaze-umaze-v0" => 1.0,
+ "antmaze-umaze-diverse-v0" => 1.0,
+ "antmaze-medium-play-v0" => 1.0,
+ "antmaze-medium-diverse-v0" => 1.0,
+ "antmaze-large-play-v0" => 1.0,
+ "antmaze-large-diverse-v0" => 1.0,
+ "kitchen-complete-v0" => 4.0,
+ "kitchen-partial-v0" => 4.0,
+ "kitchen-mixed-v0" => 4.0,
+ "flow-ring-random-v0" => 24.42,
+ "flow-ring-controller-v0" => 24.42,
+ "flow-merge-random-v0" => 330.03179,
+ "flow-merge-controller-v0" => 330.03179,
+ "carla-lane-v0" => 1023.5784385429523,
+ "carla-town-v0" => 2440.1772022247314, # avg dataset score
+ "bullet-halfcheetah-random-v0" => 2381.6725,
+ "bullet-halfcheetah-medium-v0" => 2381.6725,
+ "bullet-halfcheetah-expert-v0" => 2381.6725,
+ "bullet-halfcheetah-medium-expert-v0" => 2381.6725,
+ "bullet-halfcheetah-medium-replay-v0" => 2381.6725,
+ "bullet-hopper-random-v0" => 1441.8059623430963,
+ "bullet-hopper-medium-v0" => 1441.8059623430963,
+ "bullet-hopper-expert-v0" => 1441.8059623430963,
+ "bullet-hopper-medium-expert-v0" => 1441.8059623430963,
+ "bullet-hopper-medium-replay-v0" => 1441.8059623430963,
+ "bullet-ant-random-v0" => 2650.495,
+ "bullet-ant-medium-v0" => 2650.495,
+ "bullet-ant-expert-v0" => 2650.495,
+ "bullet-ant-medium-expert-v0" => 2650.495,
+ "bullet-ant-medium-replay-v0" => 2650.495,
+ "bullet-walker2d-random-v0" => 1623.6476303317536,
+ "bullet-walker2d-medium-v0" => 1623.6476303317536,
+ "bullet-walker2d-expert-v0" => 1623.6476303317536,
+ "bullet-walker2d-medium-expert-v0" => 1623.6476303317536,
+ "bullet-walker2d-medium-replay-v0" => 1623.6476303317536,
+ "bullet-maze2d-open-v0" => 64.15,
+ "bullet-maze2d-umaze-v0" => 153.99,
+ "bullet-maze2d-medium-v0" => 238.05,
+ "bullet-maze2d-large-v0" => 285.92,
)
# give a prompt for flow and carla tasks
@@ -280,15 +369,15 @@ function __init__()
Credits: https://arxiv.org/abs/2004.07219
The following dataset is fetched from d4rl.
The dataset is fetched and modified into a form that is useful for the RL.jl package.
-
+
Dataset information:
Name: $(ds)
$(if ds in keys(REF_MAX_SCORE) "MAXIMUM_SCORE: " * string(REF_MAX_SCORE[ds]) end)
$(if ds in keys(REF_MIN_SCORE) "MINIMUM_SCORE: " * string(REF_MIN_SCORE[ds]) end)
""", #check if the MAX and MIN score part is even necessary and make the log file prettier
DATASET_URLS[ds],
- )
+ ),
)
end
nothing
-end
\ No newline at end of file
+end
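The REF_MIN_SCORE / REF_MAX_SCORE tables above mirror d4rl's reference scores, which are normally used to report normalized returns. A small helper sketch; `normalized_score` is illustrative and not part of the package:

    # hypothetical helper, following d4rl's (score - min) / (max - min) convention
    normalized_score(ds::String, score) =
        (score - REF_MIN_SCORE[ds]) / (REF_MAX_SCORE[ds] - REF_MIN_SCORE[ds])

    normalized_score("hopper-medium-v0", 1600.0)   # ≈ 0.5 on d4rl's 0-1 scale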
diff --git a/src/ReinforcementLearningDatasets/test/d4rl/d4rl_dataset.jl b/src/ReinforcementLearningDatasets/test/d4rl/d4rl_dataset.jl
index dd9da8304..7ed5b9a89 100644
--- a/src/ReinforcementLearningDatasets/test/d4rl/d4rl_dataset.jl
+++ b/src/ReinforcementLearningDatasets/test/d4rl/d4rl_dataset.jl
@@ -11,7 +11,7 @@ rng = StableRNG(123)
style = style,
rng = rng,
is_shuffle = true,
- batch_size = batch_size
+ batch_size = batch_size,
)
data_dict = ds.dataset
@@ -24,7 +24,9 @@ rng = StableRNG(123)
i = 1
for sample in ds
- if i > 5 break end
+ if i > 5
+ break
+ end
@test typeof(sample) <: NamedTuple
i += 1
end
@@ -60,7 +62,7 @@ end
style = style,
rng = rng,
is_shuffle = false,
- batch_size = batch_size
+ batch_size = batch_size,
)
@@ -74,7 +76,9 @@ end
i = 1
for sample in ds
- if i > 5 break end
+ if i > 5
+ break
+ end
@test typeof(sample) <: NamedTuple
i += 1
end
diff --git a/src/ReinforcementLearningEnvironments/src/ReinforcementLearningEnvironments.jl b/src/ReinforcementLearningEnvironments/src/ReinforcementLearningEnvironments.jl
index f3d2dd7e9..dfd1f0ee8 100644
--- a/src/ReinforcementLearningEnvironments/src/ReinforcementLearningEnvironments.jl
+++ b/src/ReinforcementLearningEnvironments/src/ReinforcementLearningEnvironments.jl
@@ -4,7 +4,7 @@ using ReinforcementLearningBase
using Random
using Requires
using IntervalSets
-using Base.Threads:@spawn
+using Base.Threads: @spawn
using Markdown
const RLEnvs = ReinforcementLearningEnvironments
@@ -31,9 +31,7 @@ function __init__()
@require OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" include(
"environments/3rd_party/AcrobotEnv.jl",
)
- @require Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" include(
- "plots.jl",
- )
+ @require Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" include("plots.jl")
end
diff --git a/src/ReinforcementLearningEnvironments/src/environments/3rd_party/AcrobotEnv.jl b/src/ReinforcementLearningEnvironments/src/environments/3rd_party/AcrobotEnv.jl
index c371bb80b..eac7957c7 100644
--- a/src/ReinforcementLearningEnvironments/src/environments/3rd_party/AcrobotEnv.jl
+++ b/src/ReinforcementLearningEnvironments/src/environments/3rd_party/AcrobotEnv.jl
@@ -18,23 +18,23 @@ AcrobotEnv(;kwargs...)
- `avail_torque = [T(-1.), T(0.), T(1.)]`
"""
function AcrobotEnv(;
- T=Float64,
- link_length_a=T(1.0),
- link_length_b=T(1.0),
- link_mass_a=T(1.0),
- link_mass_b=T(1.0),
- link_com_pos_a=T(0.5),
- link_com_pos_b=T(0.5),
- link_moi=T(1.0),
- max_torque_noise=T(0.0),
- max_vel_a=T(4 * π),
- max_vel_b=T(9 * π),
- g=T(9.8),
- dt=T(0.2),
- max_steps=200,
- rng=Random.GLOBAL_RNG,
- book_or_nips="book",
- avail_torque=[T(-1.0), T(0.0), T(1.0)],
+ T = Float64,
+ link_length_a = T(1.0),
+ link_length_b = T(1.0),
+ link_mass_a = T(1.0),
+ link_mass_b = T(1.0),
+ link_com_pos_a = T(0.5),
+ link_com_pos_b = T(0.5),
+ link_moi = T(1.0),
+ max_torque_noise = T(0.0),
+ max_vel_a = T(4 * π),
+ max_vel_b = T(9 * π),
+ g = T(9.8),
+ dt = T(0.2),
+ max_steps = 200,
+ rng = Random.GLOBAL_RNG,
+ book_or_nips = "book",
+ avail_torque = [T(-1.0), T(0.0), T(1.0)],
)
params = AcrobotEnvParams{T}(
@@ -81,7 +81,7 @@ RLBase.is_terminated(env::AcrobotEnv) = env.done
RLBase.state(env::AcrobotEnv) = acrobot_observation(env.state)
RLBase.reward(env::AcrobotEnv) = env.reward
-function RLBase.reset!(env::AcrobotEnv{T}) where {T <: Number}
+function RLBase.reset!(env::AcrobotEnv{T}) where {T<:Number}
env.state[:] = T(0.1) * rand(env.rng, T, 4) .- T(0.05)
env.t = 0
env.action = 2
@@ -91,7 +91,7 @@ function RLBase.reset!(env::AcrobotEnv{T}) where {T <: Number}
end
# governing equations as per python gym
-function (env::AcrobotEnv{T})(a) where {T <: Number}
+function (env::AcrobotEnv{T})(a) where {T<:Number}
env.action = a
env.t += 1
torque = env.avail_torque[a]
@@ -137,7 +137,7 @@ function dsdt(du, s_augmented, env::AcrobotEnv, t)
# extract action and state
a = s_augmented[end]
- s = s_augmented[1:end - 1]
+ s = s_augmented[1:end-1]
# writing in standard form
theta1 = s[1]
@@ -201,7 +201,7 @@ function wrap(x, m, M)
while x < m
x = x + diff
end
-return x
+ return x
end
function bound(x, m, M)
diff --git a/src/ReinforcementLearningEnvironments/src/environments/3rd_party/open_spiel.jl b/src/ReinforcementLearningEnvironments/src/environments/3rd_party/open_spiel.jl
index 3974bb37b..843bd39ee 100644
--- a/src/ReinforcementLearningEnvironments/src/environments/3rd_party/open_spiel.jl
+++ b/src/ReinforcementLearningEnvironments/src/environments/3rd_party/open_spiel.jl
@@ -44,7 +44,7 @@ import .OpenSpiel:
`True` or `False` (instead of `true` or `false`). Another approach is to just
specify parameters in `kwargs` in the Julia style.
"""
-function OpenSpielEnv(name="kuhn_poker"; kwargs...)
+function OpenSpielEnv(name = "kuhn_poker"; kwargs...)
game = load_game(String(name); kwargs...)
state = new_initial_state(game)
OpenSpielEnv(state, game)
@@ -60,7 +60,7 @@ RLBase.current_player(env::OpenSpielEnv) = OpenSpiel.current_player(env.state)
RLBase.chance_player(env::OpenSpielEnv) = convert(Int, OpenSpiel.CHANCE_PLAYER)
function RLBase.players(env::OpenSpielEnv)
- p = 0:(num_players(env.game) - 1)
+ p = 0:(num_players(env.game)-1)
if ChanceStyle(env) === EXPLICIT_STOCHASTIC
(p..., RLBase.chance_player(env))
else
@@ -91,7 +91,7 @@ function RLBase.prob(env::OpenSpielEnv, player)
# @assert player == chance_player(env)
p = zeros(length(action_space(env)))
for (k, v) in chance_outcomes(env.state)
- p[k + 1] = v
+ p[k+1] = v
end
p
end
@@ -102,7 +102,7 @@ function RLBase.legal_action_space_mask(env::OpenSpielEnv, player)
num_distinct_actions(env.game)
mask = BitArray(undef, n)
for a in legal_actions(env.state, player)
- mask[a + 1] = true
+ mask[a+1] = true
end
mask
end
@@ -136,12 +136,16 @@ end
_state(env::OpenSpielEnv, ::RLBase.InformationSet{String}, player) =
information_state_string(env.state, player)
-_state(env::OpenSpielEnv, ::RLBase.InformationSet{Array}, player) =
- reshape(information_state_tensor(env.state, player), reverse(information_state_tensor_shape(env.game))...)
+_state(env::OpenSpielEnv, ::RLBase.InformationSet{Array}, player) = reshape(
+ information_state_tensor(env.state, player),
+ reverse(information_state_tensor_shape(env.game))...,
+)
_state(env::OpenSpielEnv, ::Observation{String}, player) =
observation_string(env.state, player)
-_state(env::OpenSpielEnv, ::Observation{Array}, player) =
- reshape(observation_tensor(env.state, player), reverse(observation_tensor_shape(env.game))...)
+_state(env::OpenSpielEnv, ::Observation{Array}, player) = reshape(
+ observation_tensor(env.state, player),
+ reverse(observation_tensor_shape(env.game))...,
+)
RLBase.state_space(
env::OpenSpielEnv,
@@ -149,16 +153,18 @@ RLBase.state_space(
p,
) = WorldSpace{AbstractString}()
-RLBase.state_space(env::OpenSpielEnv, ::InformationSet{Array},
- p,
-) = Space(
- fill(typemin(Float64)..typemax(Float64), reverse(information_state_tensor_shape(env.game))...),
+RLBase.state_space(env::OpenSpielEnv, ::InformationSet{Array}, p) = Space(
+ fill(
+ typemin(Float64)..typemax(Float64),
+ reverse(information_state_tensor_shape(env.game))...,
+ ),
)
-RLBase.state_space(env::OpenSpielEnv, ::Observation{Array},
- p,
-) = Space(
- fill(typemin(Float64)..typemax(Float64), reverse(observation_tensor_shape(env.game))...),
+RLBase.state_space(env::OpenSpielEnv, ::Observation{Array}, p) = Space(
+ fill(
+ typemin(Float64)..typemax(Float64),
+ reverse(observation_tensor_shape(env.game))...,
+ ),
)
Random.seed!(env::OpenSpielEnv, s) = @warn "seed!(OpenSpielEnv) is not supported currently."
@@ -199,7 +205,9 @@ RLBase.RewardStyle(env::OpenSpielEnv) =
reward_model(get_type(env.game)) == OpenSpiel.REWARDS ? RLBase.STEP_REWARD :
RLBase.TERMINAL_REWARD
-RLBase.StateStyle(env::OpenSpielEnv) = (RLBase.InformationSet{String}(),
+RLBase.StateStyle(env::OpenSpielEnv) = (
+ RLBase.InformationSet{String}(),
RLBase.InformationSet{Array}(),
RLBase.Observation{String}(),
- RLBase.Observation{Array}(),)
+ RLBase.Observation{Array}(),
+)
diff --git a/src/ReinforcementLearningEnvironments/src/environments/3rd_party/structs.jl b/src/ReinforcementLearningEnvironments/src/environments/3rd_party/structs.jl
index 83586f4e3..0acd51427 100644
--- a/src/ReinforcementLearningEnvironments/src/environments/3rd_party/structs.jl
+++ b/src/ReinforcementLearningEnvironments/src/environments/3rd_party/structs.jl
@@ -6,7 +6,7 @@ struct GymEnv{T,Ta,To,P} <: AbstractEnv
end
export GymEnv
-mutable struct AtariEnv{IsGrayScale,TerminalOnLifeLoss,N,S <: AbstractRNG} <: AbstractEnv
+mutable struct AtariEnv{IsGrayScale,TerminalOnLifeLoss,N,S<:AbstractRNG} <: AbstractEnv
ale::Ptr{Nothing}
name::String
screens::Tuple{Array{UInt8,N},Array{UInt8,N}} # for max-pooling
@@ -65,7 +65,7 @@ end
export AcrobotEnvParams
-mutable struct AcrobotEnv{T,R <: AbstractRNG} <: AbstractEnv
+mutable struct AcrobotEnv{T,R<:AbstractRNG} <: AbstractEnv
params::AcrobotEnvParams{T}
state::Vector{T}
action::Int
diff --git a/src/ReinforcementLearningEnvironments/src/environments/wrappers/ActionTransformedEnv.jl b/src/ReinforcementLearningEnvironments/src/environments/wrappers/ActionTransformedEnv.jl
index b178b6ecf..3e5bd2264 100644
--- a/src/ReinforcementLearningEnvironments/src/environments/wrappers/ActionTransformedEnv.jl
+++ b/src/ReinforcementLearningEnvironments/src/environments/wrappers/ActionTransformedEnv.jl
@@ -13,7 +13,7 @@ end
`legal_action_space(env)`. `action_mapping` will be applied to `action` before
feeding it into `env`.
"""
-ActionTransformedEnv(env; action_mapping = identity, action_space_mapping = identity) =
+ActionTransformedEnv(env; action_mapping = identity, action_space_mapping = identity) =
ActionTransformedEnv(env, action_mapping, action_space_mapping)
RLBase.action_space(env::ActionTransformedEnv, args...) =
diff --git a/src/ReinforcementLearningEnvironments/src/environments/wrappers/DefaultStateStyle.jl b/src/ReinforcementLearningEnvironments/src/environments/wrappers/DefaultStateStyle.jl
index 3e6996c52..95ce36788 100644
--- a/src/ReinforcementLearningEnvironments/src/environments/wrappers/DefaultStateStyle.jl
+++ b/src/ReinforcementLearningEnvironments/src/environments/wrappers/DefaultStateStyle.jl
@@ -13,9 +13,10 @@ DefaultStateStyleEnv{S}(env::E) where {S,E} = DefaultStateStyleEnv{S,E}(env)
RLBase.DefaultStateStyle(::DefaultStateStyleEnv{S}) where {S} = S
-RLBase.state(env::DefaultStateStyleEnv{S}) where S = state(env.env, S)
+RLBase.state(env::DefaultStateStyleEnv{S}) where {S} = state(env.env, S)
RLBase.state(env::DefaultStateStyleEnv, ss::RLBase.AbstractStateStyle) = state(env.env, ss)
-RLBase.state(env::DefaultStateStyleEnv{S}, player) where S = state(env.env, S, player)
+RLBase.state(env::DefaultStateStyleEnv{S}, player) where {S} = state(env.env, S, player)
-RLBase.state_space(env::DefaultStateStyleEnv{S}) where S = state_space(env.env, S)
-RLBase.state_space(env::DefaultStateStyleEnv, ss::RLBase.AbstractStateStyle) = state_space(env.env, ss)
\ No newline at end of file
+RLBase.state_space(env::DefaultStateStyleEnv{S}) where {S} = state_space(env.env, S)
+RLBase.state_space(env::DefaultStateStyleEnv, ss::RLBase.AbstractStateStyle) =
+ state_space(env.env, ss)
diff --git a/src/ReinforcementLearningEnvironments/src/environments/wrappers/SequentialEnv.jl b/src/ReinforcementLearningEnvironments/src/environments/wrappers/SequentialEnv.jl
index 4f18af426..2a88903dd 100644
--- a/src/ReinforcementLearningEnvironments/src/environments/wrappers/SequentialEnv.jl
+++ b/src/ReinforcementLearningEnvironments/src/environments/wrappers/SequentialEnv.jl
@@ -9,7 +9,7 @@ mutable struct SequentialEnv{E<:AbstractEnv} <: AbstractEnvWrapper
env::E
current_player_idx::Int
actions::Vector{Any}
- function SequentialEnv(env::T) where T<:AbstractEnv
+ function SequentialEnv(env::T) where {T<:AbstractEnv}
@assert DynamicStyle(env) === SIMULTANEOUS "The SequentialEnv wrapper can only be applied to SIMULTANEOUS environments"
new{T}(env, 1, Vector{Any}(undef, length(players(env))))
end
@@ -32,7 +32,8 @@ end
RLBase.reward(env::SequentialEnv) = reward(env, current_player(env))
-RLBase.reward(env::SequentialEnv, player) = current_player(env) == 1 ? reward(env.env, player) : 0
+RLBase.reward(env::SequentialEnv, player) =
+ current_player(env) == 1 ? reward(env.env, player) : 0
function (env::SequentialEnv)(action)
env.actions[env.current_player_idx] = action
@@ -43,4 +44,3 @@ function (env::SequentialEnv)(action)
env.current_player_idx += 1
end
end
-
diff --git a/src/ReinforcementLearningEnvironments/src/environments/wrappers/StateCachedEnv.jl b/src/ReinforcementLearningEnvironments/src/environments/wrappers/StateCachedEnv.jl
index e8626a3b8..97e18e928 100644
--- a/src/ReinforcementLearningEnvironments/src/environments/wrappers/StateCachedEnv.jl
+++ b/src/ReinforcementLearningEnvironments/src/environments/wrappers/StateCachedEnv.jl
@@ -6,7 +6,7 @@ the next interaction with `env`. This function is useful because some
environments are stateful during each `state(env)`. For example:
`StateTransformedEnv(StackFrames(...))`.
"""
-mutable struct StateCachedEnv{S,E <: AbstractEnv} <: AbstractEnvWrapper
+mutable struct StateCachedEnv{S,E<:AbstractEnv} <: AbstractEnvWrapper
s::S
env::E
is_state_cached::Bool
diff --git a/src/ReinforcementLearningEnvironments/src/environments/wrappers/StateTransformedEnv.jl b/src/ReinforcementLearningEnvironments/src/environments/wrappers/StateTransformedEnv.jl
index dfe90bddd..840c9ecdb 100644
--- a/src/ReinforcementLearningEnvironments/src/environments/wrappers/StateTransformedEnv.jl
+++ b/src/ReinforcementLearningEnvironments/src/environments/wrappers/StateTransformedEnv.jl
@@ -12,11 +12,11 @@ end
`state_mapping` will be applied on the original state when calling `state(env)`,
and similarly `state_space_mapping` will be applied when calling `state_space(env)`.
"""
-StateTransformedEnv(env; state_mapping=identity, state_space_mapping=identity) =
+StateTransformedEnv(env; state_mapping = identity, state_space_mapping = identity) =
StateTransformedEnv(env, state_mapping, state_space_mapping)
RLBase.state(env::StateTransformedEnv, args...; kwargs...) =
env.state_mapping(state(env.env, args...; kwargs...))
-RLBase.state_space(env::StateTransformedEnv, args...; kwargs...) =
+RLBase.state_space(env::StateTransformedEnv, args...; kwargs...) =
env.state_space_mapping(state_space(env.env, args...; kwargs...))
diff --git a/src/ReinforcementLearningEnvironments/src/plots.jl b/src/ReinforcementLearningEnvironments/src/plots.jl
index c43a59f4a..35126cab4 100644
--- a/src/ReinforcementLearningEnvironments/src/plots.jl
+++ b/src/ReinforcementLearningEnvironments/src/plots.jl
@@ -8,35 +8,35 @@ function plot(env::CartPoleEnv)
xthreshold = env.params.xthreshold
# set the frame
p = plot(
- xlims=(-xthreshold, xthreshold),
- ylims=(-.1, l + 0.1),
- legend=false,
- border=:none,
+ xlims = (-xthreshold, xthreshold),
+ ylims = (-.1, l + 0.1),
+ legend = false,
+ border = :none,
)
# plot the cart
- plot!([x - 0.5, x - 0.5, x + 0.5, x + 0.5], [-.05, 0, 0, -.05];
- seriestype=:shape,
- )
+ plot!([x - 0.5, x - 0.5, x + 0.5, x + 0.5], [-.05, 0, 0, -.05]; seriestype = :shape)
# plot the pole
- plot!([x, x + l * sin(theta)], [0, l * cos(theta)];
- linewidth=3,
- )
+ plot!([x, x + l * sin(theta)], [0, l * cos(theta)]; linewidth = 3)
# plot the arrow
- plot!([x + (a == 1) - 0.5, x + 1.4 * (a == 1)-0.7], [ -.025, -.025];
- linewidth=3,
- arrow=true,
- color=2,
+ plot!(
+ [x + (a == 1) - 0.5, x + 1.4 * (a == 1) - 0.7],
+ [-.025, -.025];
+ linewidth = 3,
+ arrow = true,
+ color = 2,
)
# if done plot pink circle in top right
if d
- plot!([xthreshold - 0.2], [l];
- marker=:circle,
- markersize=20,
- markerstrokewidth=0.,
- color=:pink,
+ plot!(
+ [xthreshold - 0.2],
+ [l];
+ marker = :circle,
+ markersize = 20,
+ markerstrokewidth = 0.0,
+ color = :pink,
)
end
-
+
p
end
@@ -51,10 +51,10 @@ function plot(env::MountainCarEnv)
d = env.done
p = plot(
- xlims=(env.params.min_pos - 0.1, env.params.max_pos + 0.2),
- ylims=(-.1, height(env.params.max_pos) + 0.2),
- legend=false,
- border=:none,
+ xlims = (env.params.min_pos - 0.1, env.params.max_pos + 0.2),
+ ylims = (-.1, height(env.params.max_pos) + 0.2),
+ legend = false,
+ border = :none,
)
# plot the terrain
xs = LinRange(env.params.min_pos, env.params.max_pos, 100)
@@ -72,17 +72,19 @@ function plot(env::MountainCarEnv)
ys .+= clearance
xs, ys = rotate(xs, ys, θ)
xs, ys = translate(xs, ys, [x, height(x)])
- plot!(xs, ys; seriestype=:shape)
+ plot!(xs, ys; seriestype = :shape)
# if done plot pink circle in top right
if d
- plot!([xthreshold - 0.2], [l];
- marker=:circle,
- markersize=20,
- markerstrokewidth=0.,
- color=:pink,
+ plot!(
+ [xthreshold - 0.2],
+ [l];
+ marker = :circle,
+ markersize = 20,
+ markerstrokewidth = 0.0,
+ color = :pink,
)
end
p
- end
+end
diff --git a/src/ReinforcementLearningEnvironments/test/environments/wrappers/wrappers.jl b/src/ReinforcementLearningEnvironments/test/environments/wrappers/wrappers.jl
index 208307a67..049defa22 100644
--- a/src/ReinforcementLearningEnvironments/test/environments/wrappers/wrappers.jl
+++ b/src/ReinforcementLearningEnvironments/test/environments/wrappers/wrappers.jl
@@ -1,11 +1,11 @@
@testset "wrappers" begin
@testset "ActionTransformedEnv" begin
- env = TigerProblemEnv(; rng=StableRNG(123))
+ env = TigerProblemEnv(; rng = StableRNG(123))
env′ = ActionTransformedEnv(
env;
- action_space_mapping=x -> Base.OneTo(3),
- action_mapping=i -> action_space(env)[i],
+ action_space_mapping = x -> Base.OneTo(3),
+ action_mapping = i -> action_space(env)[i],
)
RLBase.test_interfaces!(env′)
@@ -14,7 +14,7 @@
@testset "DefaultStateStyleEnv" begin
rng = StableRNG(123)
- env = TigerProblemEnv(; rng=rng)
+ env = TigerProblemEnv(; rng = rng)
S = InternalState{Int}()
env′ = DefaultStateStyleEnv{S}(env)
@test DefaultStateStyle(env′) === S
@@ -35,7 +35,7 @@
@testset "MaxTimeoutEnv" begin
rng = StableRNG(123)
- env = TigerProblemEnv(; rng=rng)
+ env = TigerProblemEnv(; rng = rng)
n = 100
env′ = MaxTimeoutEnv(env, n)
@@ -55,7 +55,7 @@
@testset "RewardOverriddenEnv" begin
rng = StableRNG(123)
- env = TigerProblemEnv(; rng=rng)
+ env = TigerProblemEnv(; rng = rng)
env′ = RewardOverriddenEnv(env, x -> sign(x))
RLBase.test_interfaces!(env′)
@@ -69,7 +69,7 @@
@testset "StateCachedEnv" begin
rng = StableRNG(123)
- env = CartPoleEnv(; rng=rng)
+ env = CartPoleEnv(; rng = rng)
env′ = StateCachedEnv(env)
RLBase.test_interfaces!(env′)
@@ -85,7 +85,7 @@
@testset "StateTransformedEnv" begin
rng = StableRNG(123)
- env = TigerProblemEnv(; rng=rng)
+ env = TigerProblemEnv(; rng = rng)
# S = (:door1, :door2, :door3, :none)
# env′ = StateTransformedEnv(env, state_mapping=s -> s+1)
# RLBase.state_space(env::typeof(env′), ::RLBase.AbstractStateStyle, ::Any) = S
@@ -97,14 +97,14 @@
@testset "StochasticEnv" begin
env = KuhnPokerEnv()
rng = StableRNG(123)
- env′ = StochasticEnv(env; rng=rng)
+ env′ = StochasticEnv(env; rng = rng)
RLBase.test_interfaces!(env′)
RLBase.test_runnable!(env′)
end
@testset "SequentialEnv" begin
- env = RockPaperScissorsEnv()
+ env = RockPaperScissorsEnv()
env′ = SequentialEnv(env)
RLBase.test_interfaces!(env′)
RLBase.test_runnable!(env′)
diff --git a/src/ReinforcementLearningExperiments/deps/build.jl b/src/ReinforcementLearningExperiments/deps/build.jl
index def42b559..6dfde6249 100644
--- a/src/ReinforcementLearningExperiments/deps/build.jl
+++ b/src/ReinforcementLearningExperiments/deps/build.jl
@@ -2,10 +2,11 @@ using Weave
const DEST_DIR = joinpath(@__DIR__, "..", "src", "experiments")
-for (root, dirs, files) in walkdir(joinpath(@__DIR__, "..", "..", "..", "docs", "experiments"))
+for (root, dirs, files) in
+ walkdir(joinpath(@__DIR__, "..", "..", "..", "docs", "experiments"))
for f in files
if splitext(f)[2] == ".jl"
- tangle(joinpath(root,f);informat="script", out_path=DEST_DIR)
+ tangle(joinpath(root, f); informat = "script", out_path = DEST_DIR)
end
end
-end
\ No newline at end of file
+end
diff --git a/src/ReinforcementLearningExperiments/src/ReinforcementLearningExperiments.jl b/src/ReinforcementLearningExperiments/src/ReinforcementLearningExperiments.jl
index 0f615ba3f..2a273443a 100644
--- a/src/ReinforcementLearningExperiments/src/ReinforcementLearningExperiments.jl
+++ b/src/ReinforcementLearningExperiments/src/ReinforcementLearningExperiments.jl
@@ -23,7 +23,6 @@ for f in readdir(EXPERIMENTS_DIR)
end
# dynamic loading environments
-function __init__()
-end
+function __init__() end
end # module
diff --git a/src/ReinforcementLearningZoo/src/algorithms/dqns/common.jl b/src/ReinforcementLearningZoo/src/algorithms/dqns/common.jl
index ddd2c6ace..793eaa897 100644
--- a/src/ReinforcementLearningZoo/src/algorithms/dqns/common.jl
+++ b/src/ReinforcementLearningZoo/src/algorithms/dqns/common.jl
@@ -4,7 +4,10 @@
const PERLearners = Union{PrioritizedDQNLearner,RainbowLearner,IQNLearner}
-function RLBase.update!(learner::Union{DQNLearner,QRDQNLearner,REMDQNLearner,PERLearners}, t::AbstractTrajectory)
+function RLBase.update!(
+ learner::Union{DQNLearner,QRDQNLearner,REMDQNLearner,PERLearners},
+ t::AbstractTrajectory,
+)
length(t[:terminal]) - learner.sampler.n <= learner.min_replay_history && return
learner.update_step += 1
diff --git a/src/ReinforcementLearningZoo/src/algorithms/dqns/dqn.jl b/src/ReinforcementLearningZoo/src/algorithms/dqns/dqn.jl
index b2ebab93c..ad2808f26 100644
--- a/src/ReinforcementLearningZoo/src/algorithms/dqns/dqn.jl
+++ b/src/ReinforcementLearningZoo/src/algorithms/dqns/dqn.jl
@@ -55,7 +55,7 @@ function DQNLearner(;
traces = SARTS,
update_step = 0,
rng = Random.GLOBAL_RNG,
- is_enable_double_DQN::Bool = true
+ is_enable_double_DQN::Bool = true,
) where {Tq,Tt,Tf}
copyto!(approximator, target_approximator)
sampler = NStepBatchSampler{traces}(;
@@ -75,7 +75,7 @@ function DQNLearner(;
sampler,
rng,
0.0f0,
- is_enable_double_DQN
+ is_enable_double_DQN,
)
end
@@ -117,14 +117,14 @@ function RLBase.update!(learner::DQNLearner, batch::NamedTuple)
else
q_values = Qₜ(s′)
end
-
+
if haskey(batch, :next_legal_actions_mask)
l′ = send_to_device(D, batch[:next_legal_actions_mask])
q_values .+= ifelse.(l′, 0.0f0, typemin(Float32))
end
if is_enable_double_DQN
- selected_actions = dropdims(argmax(q_values, dims=1), dims=1)
+ selected_actions = dropdims(argmax(q_values, dims = 1), dims = 1)
q′ = Qₜ(s′)[selected_actions]
else
q′ = dropdims(maximum(q_values; dims = 1), dims = 1)
diff --git a/src/ReinforcementLearningZoo/src/algorithms/dqns/dqns.jl b/src/ReinforcementLearningZoo/src/algorithms/dqns/dqns.jl
index 4ec47c5ba..8f190ba2c 100644
--- a/src/ReinforcementLearningZoo/src/algorithms/dqns/dqns.jl
+++ b/src/ReinforcementLearningZoo/src/algorithms/dqns/dqns.jl
@@ -5,4 +5,4 @@ include("qr_dqn.jl")
include("rem_dqn.jl")
include("rainbow.jl")
include("iqn.jl")
-include("common.jl")
\ No newline at end of file
+include("common.jl")
diff --git a/src/ReinforcementLearningZoo/src/algorithms/dqns/qr_dqn.jl b/src/ReinforcementLearningZoo/src/algorithms/dqns/qr_dqn.jl
index 2832b1905..0f6c1b519 100644
--- a/src/ReinforcementLearningZoo/src/algorithms/dqns/qr_dqn.jl
+++ b/src/ReinforcementLearningZoo/src/algorithms/dqns/qr_dqn.jl
@@ -1,6 +1,6 @@
export QRDQNLearner, quantile_huber_loss
-function quantile_huber_loss(ŷ, y; κ=1.0f0)
+function quantile_huber_loss(ŷ, y; κ = 1.0f0)
N, B = size(y)
Δ = reshape(y, N, 1, B) .- reshape(ŷ, 1, N, B)
abs_error = abs.(Δ)
@@ -8,12 +8,13 @@ function quantile_huber_loss(ŷ, y; κ=1.0f0)
linear = abs_error .- quadratic
huber_loss = 0.5f0 .* quadratic .* quadratic .+ κ .* linear
- cum_prob = send_to_device(device(y), range(0.5f0 / N; length=N, step=1.0f0 / N))
+ cum_prob = send_to_device(device(y), range(0.5f0 / N; length = N, step = 1.0f0 / N))
loss = Zygote.dropgrad(abs.(cum_prob .- (Δ .< 0))) .* huber_loss
- mean(sum(loss;dims=1))
+ mean(sum(loss; dims = 1))
end
-mutable struct QRDQNLearner{Tq <: AbstractApproximator,Tt <: AbstractApproximator,Tf,R} <: AbstractLearner
+mutable struct QRDQNLearner{Tq<:AbstractApproximator,Tt<:AbstractApproximator,Tf,R} <:
+ AbstractLearner
approximator::Tq
target_approximator::Tt
min_replay_history::Int
@@ -51,25 +52,25 @@ See paper: [Distributional Reinforcement Learning with Quantile Regression](http
function QRDQNLearner(;
approximator,
target_approximator,
- stack_size::Union{Int,Nothing}=nothing,
- γ::Float32=0.99f0,
- batch_size::Int=32,
- update_horizon::Int=1,
- min_replay_history::Int=32,
- update_freq::Int=1,
- n_quantile::Int=1,
- target_update_freq::Int=100,
- traces=SARTS,
- update_step=0,
- loss_func=quantile_huber_loss,
- rng=Random.GLOBAL_RNG
+ stack_size::Union{Int,Nothing} = nothing,
+ γ::Float32 = 0.99f0,
+ batch_size::Int = 32,
+ update_horizon::Int = 1,
+ min_replay_history::Int = 32,
+ update_freq::Int = 1,
+ n_quantile::Int = 1,
+ target_update_freq::Int = 100,
+ traces = SARTS,
+ update_step = 0,
+ loss_func = quantile_huber_loss,
+ rng = Random.GLOBAL_RNG,
)
copyto!(approximator, target_approximator)
sampler = NStepBatchSampler{traces}(;
- γ=γ,
- n=update_horizon,
- stack_size=stack_size,
- batch_size=batch_size,
+ γ = γ,
+ n = update_horizon,
+ stack_size = stack_size,
+ batch_size = batch_size,
)
N = n_quantile
@@ -100,7 +101,7 @@ function (learner::QRDQNLearner)(env)
s = send_to_device(device(learner.approximator), state(env))
s = Flux.unsqueeze(s, ndims(s) + 1)
q = reshape(learner.approximator(s), learner.n_quantile, :)
- vec(mean(q, dims=1)) |> send_to_host
+ vec(mean(q, dims = 1)) |> send_to_host
end
function RLBase.update!(learner::QRDQNLearner, batch::NamedTuple)
@@ -117,10 +118,12 @@ function RLBase.update!(learner::QRDQNLearner, batch::NamedTuple)
a = CartesianIndex.(a, 1:batch_size)
target_quantiles = reshape(Qₜ(s′), N, :, batch_size)
- qₜ = dropdims(mean(target_quantiles; dims=1); dims=1)
- aₜ = dropdims(argmax(qₜ, dims=1); dims=1)
+ qₜ = dropdims(mean(target_quantiles; dims = 1); dims = 1)
+ aₜ = dropdims(argmax(qₜ, dims = 1); dims = 1)
@views target_quantile_aₜ = target_quantiles[:, aₜ]
- y = reshape(r, 1, batch_size) .+ γ .* reshape(1 .- t, 1, batch_size) .* target_quantile_aₜ
+ y =
+ reshape(r, 1, batch_size) .+
+ γ .* reshape(1 .- t, 1, batch_size) .* target_quantile_aₜ
gs = gradient(params(Q)) do
q = reshape(Q(s), N, :, batch_size)
diff --git a/src/ReinforcementLearningZoo/src/algorithms/dqns/rem_dqn.jl b/src/ReinforcementLearningZoo/src/algorithms/dqns/rem_dqn.jl
index 182ce253f..08f270429 100644
--- a/src/ReinforcementLearningZoo/src/algorithms/dqns/rem_dqn.jl
+++ b/src/ReinforcementLearningZoo/src/algorithms/dqns/rem_dqn.jl
@@ -120,7 +120,7 @@ function RLBase.update!(learner::REMDQNLearner, batch::NamedTuple)
target_q = Qₜ(s′)
target_q = convex_polygon .* reshape(target_q, :, ensemble_num, batch_size)
- target_q = dropdims(sum(target_q, dims=2), dims=2)
+ target_q = dropdims(sum(target_q, dims = 2), dims = 2)
if haskey(batch, :next_legal_actions_mask)
l′ = send_to_device(D, batch[:next_legal_actions_mask])
@@ -133,7 +133,7 @@ function RLBase.update!(learner::REMDQNLearner, batch::NamedTuple)
gs = gradient(params(Q)) do
q = Q(s)
q = convex_polygon .* reshape(q, :, ensemble_num, batch_size)
- q = dropdims(sum(q, dims=2), dims=2)[a]
+ q = dropdims(sum(q, dims = 2), dims = 2)[a]
loss = loss_func(G, q)
ignore() do
@@ -143,5 +143,4 @@ function RLBase.update!(learner::REMDQNLearner, batch::NamedTuple)
end
update!(Q, gs)
-end
-
+end
diff --git a/src/ReinforcementLearningZoo/src/algorithms/offline_rl/behavior_cloning.jl b/src/ReinforcementLearningZoo/src/algorithms/offline_rl/behavior_cloning.jl
index 20aa98d79..b0f76c2de 100644
--- a/src/ReinforcementLearningZoo/src/algorithms/offline_rl/behavior_cloning.jl
+++ b/src/ReinforcementLearningZoo/src/algorithms/offline_rl/behavior_cloning.jl
@@ -19,19 +19,14 @@ end
- `rng = Random.GLOBAL_RNG`
"""
function BehaviorCloningPolicy(;
- approximator::A,
- explorer::Any = GreedyExplorer(),
- batch_size::Int = 32,
- min_reservoir_history::Int = 100,
- rng = Random.GLOBAL_RNG
+ approximator::A,
+ explorer::Any = GreedyExplorer(),
+ batch_size::Int = 32,
+ min_reservoir_history::Int = 100,
+ rng = Random.GLOBAL_RNG,
) where {A}
sampler = BatchSampler{(:state, :action)}(batch_size; rng = rng)
- BehaviorCloningPolicy(
- approximator,
- explorer,
- sampler,
- min_reservoir_history,
- )
+ BehaviorCloningPolicy(approximator, explorer, sampler, min_reservoir_history)
end
function (p::BehaviorCloningPolicy)(env::AbstractEnv)
@@ -39,7 +34,8 @@ function (p::BehaviorCloningPolicy)(env::AbstractEnv)
s_batch = Flux.unsqueeze(s, ndims(s) + 1)
s_batch = send_to_device(device(p.approximator), s_batch)
logits = p.approximator(s_batch) |> vec |> send_to_host # drop dimension
- typeof(ActionStyle(env)) == MinimalActionSet ? p.explorer(logits) : p.explorer(logits, legal_action_space_mask(env))
+ typeof(ActionStyle(env)) == MinimalActionSet ? p.explorer(logits) :
+ p.explorer(logits, legal_action_space_mask(env))
end
function RLBase.update!(p::BehaviorCloningPolicy, batch::NamedTuple{(:state, :action)})
@@ -64,7 +60,8 @@ function RLBase.prob(p::BehaviorCloningPolicy, env::AbstractEnv)
s = state(env)
s_batch = Flux.unsqueeze(s, ndims(s) + 1)
values = p.approximator(s_batch) |> vec |> send_to_host
- typeof(ActionStyle(env)) == MinimalActionSet ? prob(p.explorer, values) : prob(p.explorer, values, legal_action_space_mask(env))
+ typeof(ActionStyle(env)) == MinimalActionSet ? prob(p.explorer, values) :
+ prob(p.explorer, values, legal_action_space_mask(env))
end
function RLBase.prob(p::BehaviorCloningPolicy, env::AbstractEnv, action)
diff --git a/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/ddpg.jl b/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/ddpg.jl
index 641418b61..cba65bf4e 100644
--- a/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/ddpg.jl
+++ b/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/ddpg.jl
@@ -118,7 +118,12 @@ function (p::DDPGPolicy)(env)
s = state(env)
s = Flux.unsqueeze(s, ndims(s) + 1)
actions = p.behavior_actor(send_to_device(D, s)) |> vec |> send_to_host
- c = clamp.(actions .+ randn(p.rng, p.na) .* repeat([p.act_noise], p.na), -p.act_limit, p.act_limit)
+ c =
+ clamp.(
+ actions .+ randn(p.rng, p.na) .* repeat([p.act_noise], p.na),
+ -p.act_limit,
+ p.act_limit,
+ )
p.na == 1 && return c[1]
c
end
@@ -154,7 +159,7 @@ function RLBase.update!(p::DDPGPolicy, batch::NamedTuple{SARTS})
a′ = Aₜ(s′)
qₜ = Cₜ(vcat(s′, a′)) |> vec
y = r .+ γ .* (1 .- t) .* qₜ
- a = Flux.unsqueeze(a, ndims(a)+1)
+ a = Flux.unsqueeze(a, ndims(a) + 1)
gs1 = gradient(Flux.params(C)) do
q = C(vcat(s, a)) |> vec
diff --git a/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/ppo.jl b/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/ppo.jl
index fa06a9132..bf9e8f27e 100644
--- a/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/ppo.jl
+++ b/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/ppo.jl
@@ -148,7 +148,9 @@ function RLBase.prob(
if p.update_step < p.n_random_start
@error "todo"
else
- μ, logσ = p.approximator.actor(send_to_device(device(p.approximator), state)) |> send_to_host
+ μ, logσ =
+ p.approximator.actor(send_to_device(device(p.approximator), state)) |>
+ send_to_host
StructArray{Normal}((μ, exp.(logσ)))
end
end
@@ -256,11 +258,11 @@ function _update!(p::PPOPolicy, t::AbstractTrajectory)
end
s = send_to_device(D, select_last_dim(states_flatten, inds)) # !!! performance critical
a = send_to_device(D, select_last_dim(actions_flatten, inds))
-
+
if eltype(a) === Int
a = CartesianIndex.(a, 1:length(a))
end
-
+
r = send_to_device(D, vec(returns)[inds])
log_p = send_to_device(D, vec(action_log_probs)[inds])
adv = send_to_device(D, vec(advantages)[inds])
@@ -275,7 +277,8 @@ function _update!(p::PPOPolicy, t::AbstractTrajectory)
else
log_p′ₐ = normlogpdf(μ, exp.(logσ), a)
end
- entropy_loss = mean(size(logσ, 1) * (log(2.0f0π) + 1) .+ sum(logσ; dims = 1)) / 2
+ entropy_loss =
+ mean(size(logσ, 1) * (log(2.0f0π) + 1) .+ sum(logσ; dims = 1)) / 2
else
# actor is assumed to return discrete logits
logit′ = AC.actor(s)
diff --git a/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/sac.jl b/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/sac.jl
index f567fac76..73b5ef83d 100644
--- a/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/sac.jl
+++ b/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/sac.jl
@@ -104,8 +104,8 @@ function SACPolicy(;
Float32(-action_dims),
step,
rng,
- 0f0,
- 0f0,
+ 0.0f0,
+ 0.0f0,
)
end
@@ -120,7 +120,10 @@ function (p::SACPolicy)(env)
s = state(env)
s = Flux.unsqueeze(s, ndims(s) + 1)
# trainmode:
- action = dropdims(p.policy.model(s; is_sampling=true, is_return_log_prob=true)[1], dims=2) # Single action vec, drop second dim
+ action = dropdims(
+ p.policy.model(s; is_sampling = true, is_return_log_prob = true)[1],
+ dims = 2,
+ ) # Single action vec, drop second dim
# testmode:
# if testing dont sample an action, but act deterministically by
@@ -146,7 +149,7 @@ function RLBase.update!(p::SACPolicy, batch::NamedTuple{SARTS})
γ, τ, α = p.γ, p.τ, p.α
- a′, log_π = p.policy.model(s′; is_sampling=true, is_return_log_prob=true)
+ a′, log_π = p.policy.model(s′; is_sampling = true, is_return_log_prob = true)
q′_input = vcat(s′, a′)
q′ = min.(p.target_qnetwork1(q′_input), p.target_qnetwork2(q′_input))
@@ -168,12 +171,12 @@ function RLBase.update!(p::SACPolicy, batch::NamedTuple{SARTS})
# Train Policy
p_grad = gradient(Flux.params(p.policy)) do
- a, log_π = p.policy.model(s; is_sampling=true, is_return_log_prob=true)
+ a, log_π = p.policy.model(s; is_sampling = true, is_return_log_prob = true)
q_input = vcat(s, a)
q = min.(p.qnetwork1(q_input), p.qnetwork2(q_input))
reward = mean(q)
entropy = mean(log_π)
- ignore() do
+ ignore() do
p.reward_term = reward
p.entropy_term = entropy
end
diff --git a/src/ReinforcementLearningZoo/src/algorithms/tabular/tabular_policy.jl b/src/ReinforcementLearningZoo/src/algorithms/tabular/tabular_policy.jl
index a91c22b62..d0bbedfa3 100644
--- a/src/ReinforcementLearningZoo/src/algorithms/tabular/tabular_policy.jl
+++ b/src/ReinforcementLearningZoo/src/algorithms/tabular/tabular_policy.jl
@@ -11,7 +11,7 @@ A `Dict` is used internally to store the mapping from state to action.
"""
Base.@kwdef struct TabularPolicy{S,A} <: AbstractPolicy
table::Dict{S,A} = Dict{Int,Int}()
- n_action::Union{Int, Nothing} = nothing
+ n_action::Union{Int,Nothing} = nothing
end
(p::TabularPolicy)(env::AbstractEnv) = p(state(env))
diff --git a/test/runtests.jl b/test/runtests.jl
index 04be42cbd..d6dd95c9e 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -1,5 +1,4 @@
using Test
using ReinforcementLearning
-@testset "ReinforcementLearning" begin
-end
+@testset "ReinforcementLearning" begin end