diff --git a/docs/experiments/experiments/CFR/JuliaRL_DeepCFR_OpenSpiel.jl b/docs/experiments/experiments/CFR/JuliaRL_DeepCFR_OpenSpiel.jl
index ff647a3a4..2eeecbd74 100644
--- a/docs/experiments/experiments/CFR/JuliaRL_DeepCFR_OpenSpiel.jl
+++ b/docs/experiments/experiments/CFR/JuliaRL_DeepCFR_OpenSpiel.jl
@@ -61,5 +61,11 @@ function RL.Experiment(
batch_size_Π = 2048,
initializer = glorot_normal(CUDA.CURAND.default_rng()),
)
- Experiment(p, env, StopAfterStep(500, is_show_progress=!haskey(ENV, "CI")), EmptyHook(), "# run DeepcCFR on leduc_poker")
-end
\ No newline at end of file
+ Experiment(
+ p,
+ env,
+ StopAfterStep(500, is_show_progress = !haskey(ENV, "CI")),
+ EmptyHook(),
+ "# run DeepcCFR on leduc_poker",
+ )
+end
diff --git a/docs/experiments/experiments/CFR/JuliaRL_TabularCFR_OpenSpiel.jl b/docs/experiments/experiments/CFR/JuliaRL_TabularCFR_OpenSpiel.jl
index edfd7f199..d89cabb16 100644
--- a/docs/experiments/experiments/CFR/JuliaRL_TabularCFR_OpenSpiel.jl
+++ b/docs/experiments/experiments/CFR/JuliaRL_TabularCFR_OpenSpiel.jl
@@ -23,8 +23,14 @@ function RL.Experiment(
π = TabularCFRPolicy(; rng = rng)
description = "# Play `$game` in OpenSpiel with TabularCFRPolicy"
- Experiment(π, env, StopAfterStep(300, is_show_progress=!haskey(ENV, "CI")), EmptyHook(), description)
+ Experiment(
+ π,
+ env,
+ StopAfterStep(300, is_show_progress = !haskey(ENV, "CI")),
+ EmptyHook(),
+ description,
+ )
end
ex = E`JuliaRL_TabularCFR_OpenSpiel(kuhn_poker)`
-run(ex)
\ No newline at end of file
+run(ex)
diff --git a/docs/experiments/experiments/DQN/Dopamine_DQN_Atari.jl b/docs/experiments/experiments/DQN/Dopamine_DQN_Atari.jl
index f51b4a2a9..59574a7a2 100644
--- a/docs/experiments/experiments/DQN/Dopamine_DQN_Atari.jl
+++ b/docs/experiments/experiments/DQN/Dopamine_DQN_Atari.jl
@@ -79,44 +79,41 @@ function atari_env_factory(
repeat_action_probability = 0.25,
n_replica = nothing,
)
- init(seed) =
- RewardOverriddenEnv(
- StateCachedEnv(
- StateTransformedEnv(
- AtariEnv(;
- name = string(name),
- grayscale_obs = true,
- noop_max = 30,
- frame_skip = 4,
- terminal_on_life_loss = false,
- repeat_action_probability = repeat_action_probability,
- max_num_frames_per_episode = n_frames * max_episode_steps,
- color_averaging = false,
- full_action_space = false,
- seed = seed,
- );
- state_mapping=Chain(
- ResizeImage(state_size...),
- StackFrames(state_size..., n_frames)
- ),
- state_space_mapping= _ -> Space(fill(0..256, state_size..., n_frames))
- )
+ init(seed) = RewardOverriddenEnv(
+ StateCachedEnv(
+ StateTransformedEnv(
+ AtariEnv(;
+ name = string(name),
+ grayscale_obs = true,
+ noop_max = 30,
+ frame_skip = 4,
+ terminal_on_life_loss = false,
+ repeat_action_probability = repeat_action_probability,
+ max_num_frames_per_episode = n_frames * max_episode_steps,
+ color_averaging = false,
+ full_action_space = false,
+ seed = seed,
+ );
+ state_mapping = Chain(
+ ResizeImage(state_size...),
+ StackFrames(state_size..., n_frames),
+ ),
+ state_space_mapping = _ ->
+ Space(fill(0 .. 256, state_size..., n_frames)),
),
- r -> clamp(r, -1, 1)
- )
+ ),
+ r -> clamp(r, -1, 1),
+ )
if isnothing(n_replica)
init(seed)
else
- envs = [
- init(isnothing(seed) ? nothing : hash(seed + i))
- for i in 1:n_replica
- ]
+ envs = [init(isnothing(seed) ? nothing : hash(seed + i)) for i in 1:n_replica]
states = Flux.batch(state.(envs))
rewards = reward.(envs)
terminals = is_terminated.(envs)
A = Space([action_space(x) for x in envs])
- S = Space(fill(0..255, size(states)))
+ S = Space(fill(0 .. 255, size(states)))
MultiThreadEnv(envs, states, rewards, terminals, A, S, nothing)
end
end
@@ -172,7 +169,7 @@ function RL.Experiment(
::Val{:Atari},
name::AbstractString;
save_dir = nothing,
- seed = nothing
+ seed = nothing,
)
rng = Random.GLOBAL_RNG
Random.seed!(rng, seed)
@@ -190,7 +187,7 @@ function RL.Experiment(
name,
STATE_SIZE,
N_FRAMES;
- seed = isnothing(seed) ? nothing : hash(seed + 1)
+ seed = isnothing(seed) ? nothing : hash(seed + 1),
)
N_ACTIONS = length(action_space(env))
init = glorot_uniform(rng)
@@ -254,17 +251,15 @@ function RL.Experiment(
end,
DoEveryNEpisode() do t, agent, env
with_logger(lg) do
- @info "training" episode_length = step_per_episode.steps[end] reward = reward_per_episode.rewards[end] log_step_increment = 0
+ @info "training" episode_length = step_per_episode.steps[end] reward =
+ reward_per_episode.rewards[end] log_step_increment = 0
end
end,
- DoEveryNStep(;n=EVALUATION_FREQ) do t, agent, env
+ DoEveryNStep(; n = EVALUATION_FREQ) do t, agent, env
@info "evaluating agent at $t step..."
p = agent.policy
p = @set p.explorer = EpsilonGreedyExplorer(0.001; rng = rng) # set evaluation epsilon
- h = ComposedHook(
- TotalOriginalRewardPerEpisode(),
- StepsPerEpisode(),
- )
+ h = ComposedHook(TotalOriginalRewardPerEpisode(), StepsPerEpisode())
s = @elapsed run(
p,
atari_env_factory(
@@ -281,16 +276,18 @@ function RL.Experiment(
avg_score = mean(h[1].rewards[1:end-1])
avg_length = mean(h[2].steps[1:end-1])
- @info "finished evaluating agent in $s seconds" avg_length = avg_length avg_score = avg_score
+ @info "finished evaluating agent in $s seconds" avg_length = avg_length avg_score =
+ avg_score
with_logger(lg) do
- @info "evaluating" avg_length = avg_length avg_score = avg_score log_step_increment = 0
+ @info "evaluating" avg_length = avg_length avg_score = avg_score log_step_increment =
+ 0
end
end,
)
stop_condition = StopAfterStep(
haskey(ENV, "CI") ? 1_000 : 50_000_000,
- is_show_progress=!haskey(ENV, "CI")
+ is_show_progress = !haskey(ENV, "CI"),
)
Experiment(agent, env, stop_condition, hook, "# DQN <-> Atari($name)")
end
diff --git a/docs/experiments/experiments/DQN/Dopamine_IQN_Atari.jl b/docs/experiments/experiments/DQN/Dopamine_IQN_Atari.jl
index 6e0305ae5..f0a2b1bdb 100644
--- a/docs/experiments/experiments/DQN/Dopamine_IQN_Atari.jl
+++ b/docs/experiments/experiments/DQN/Dopamine_IQN_Atari.jl
@@ -84,44 +84,41 @@ function atari_env_factory(
repeat_action_probability = 0.25,
n_replica = nothing,
)
- init(seed) =
- RewardOverriddenEnv(
- StateCachedEnv(
- StateTransformedEnv(
- AtariEnv(;
- name = string(name),
- grayscale_obs = true,
- noop_max = 30,
- frame_skip = 4,
- terminal_on_life_loss = false,
- repeat_action_probability = repeat_action_probability,
- max_num_frames_per_episode = n_frames * max_episode_steps,
- color_averaging = false,
- full_action_space = false,
- seed = seed,
- );
- state_mapping=Chain(
- ResizeImage(state_size...),
- StackFrames(state_size..., n_frames)
- ),
- state_space_mapping= _ -> Space(fill(0..256, state_size..., n_frames))
- )
+ init(seed) = RewardOverriddenEnv(
+ StateCachedEnv(
+ StateTransformedEnv(
+ AtariEnv(;
+ name = string(name),
+ grayscale_obs = true,
+ noop_max = 30,
+ frame_skip = 4,
+ terminal_on_life_loss = false,
+ repeat_action_probability = repeat_action_probability,
+ max_num_frames_per_episode = n_frames * max_episode_steps,
+ color_averaging = false,
+ full_action_space = false,
+ seed = seed,
+ );
+ state_mapping = Chain(
+ ResizeImage(state_size...),
+ StackFrames(state_size..., n_frames),
+ ),
+ state_space_mapping = _ ->
+ Space(fill(0 .. 256, state_size..., n_frames)),
),
- r -> clamp(r, -1, 1)
- )
+ ),
+ r -> clamp(r, -1, 1),
+ )
if isnothing(n_replica)
init(seed)
else
- envs = [
- init(isnothing(seed) ? nothing : hash(seed + i))
- for i in 1:n_replica
- ]
+ envs = [init(isnothing(seed) ? nothing : hash(seed + i)) for i in 1:n_replica]
states = Flux.batch(state.(envs))
rewards = reward.(envs)
terminals = is_terminated.(envs)
A = Space([action_space(x) for x in envs])
- S = Space(fill(0..255, size(states)))
+ S = Space(fill(0 .. 255, size(states)))
MultiThreadEnv(envs, states, rewards, terminals, A, S, nothing)
end
end
@@ -195,7 +192,12 @@ function RL.Experiment(
N_FRAMES = 4
STATE_SIZE = (84, 84)
- env = atari_env_factory(name, STATE_SIZE, N_FRAMES; seed = isnothing(seed) ? nothing : hash(seed + 2))
+ env = atari_env_factory(
+ name,
+ STATE_SIZE,
+ N_FRAMES;
+ seed = isnothing(seed) ? nothing : hash(seed + 2),
+ )
N_ACTIONS = length(action_space(env))
Nₑₘ = 64
@@ -250,7 +252,7 @@ function RL.Experiment(
),
),
trajectory = CircularArraySARTTrajectory(
- capacity = haskey(ENV, "CI") : 1_000 : 1_000_000,
+ capacity = haskey(ENV, "CI"):1_000:1_000_000,
state = Matrix{Float32} => STATE_SIZE,
),
)
@@ -274,7 +276,7 @@ function RL.Experiment(
steps_per_episode.steps[end] log_step_increment = 0
end
end,
- DoEveryNStep(;n=EVALUATION_FREQ) do t, agent, env
+ DoEveryNStep(; n = EVALUATION_FREQ) do t, agent, env
@info "evaluating agent at $t step..."
p = agent.policy
p = @set p.explorer = EpsilonGreedyExplorer(0.001; rng = rng) # set evaluation epsilon
@@ -286,7 +288,7 @@ function RL.Experiment(
STATE_SIZE,
N_FRAMES,
MAX_EPISODE_STEPS_EVAL;
- seed = isnothing(seed) ? nothing : hash(seed + t)
+ seed = isnothing(seed) ? nothing : hash(seed + t),
),
StopAfterStep(125_000; is_show_progress = false),
h,
@@ -295,16 +297,18 @@ function RL.Experiment(
avg_score = mean(h[1].rewards[1:end-1])
avg_length = mean(h[2].steps[1:end-1])
- @info "finished evaluating agent in $s seconds" avg_length = avg_length avg_score = avg_score
+ @info "finished evaluating agent in $s seconds" avg_length = avg_length avg_score =
+ avg_score
with_logger(lg) do
- @info "evaluating" avg_length = avg_length avg_score = avg_score log_step_increment = 0
+ @info "evaluating" avg_length = avg_length avg_score = avg_score log_step_increment =
+ 0
end
end,
)
stop_condition = StopAfterStep(
haskey(ENV, "CI") ? 10_000 : 50_000_000,
- is_show_progress=!haskey(ENV, "CI")
+ is_show_progress = !haskey(ENV, "CI"),
)
Experiment(agent, env, stop_condition, hook, "# IQN <-> Atari($name)")
end
diff --git a/docs/experiments/experiments/DQN/Dopamine_Rainbow_Atari.jl b/docs/experiments/experiments/DQN/Dopamine_Rainbow_Atari.jl
index 432e110e4..a49efb658 100644
--- a/docs/experiments/experiments/DQN/Dopamine_Rainbow_Atari.jl
+++ b/docs/experiments/experiments/DQN/Dopamine_Rainbow_Atari.jl
@@ -83,44 +83,41 @@ function atari_env_factory(
repeat_action_probability = 0.25,
n_replica = nothing,
)
- init(seed) =
- RewardOverriddenEnv(
- StateCachedEnv(
- StateTransformedEnv(
- AtariEnv(;
- name = string(name),
- grayscale_obs = true,
- noop_max = 30,
- frame_skip = 4,
- terminal_on_life_loss = false,
- repeat_action_probability = repeat_action_probability,
- max_num_frames_per_episode = n_frames * max_episode_steps,
- color_averaging = false,
- full_action_space = false,
- seed = seed,
- );
- state_mapping=Chain(
- ResizeImage(state_size...),
- StackFrames(state_size..., n_frames)
- ),
- state_space_mapping= _ -> Space(fill(0..256, state_size..., n_frames))
- )
+ init(seed) = RewardOverriddenEnv(
+ StateCachedEnv(
+ StateTransformedEnv(
+ AtariEnv(;
+ name = string(name),
+ grayscale_obs = true,
+ noop_max = 30,
+ frame_skip = 4,
+ terminal_on_life_loss = false,
+ repeat_action_probability = repeat_action_probability,
+ max_num_frames_per_episode = n_frames * max_episode_steps,
+ color_averaging = false,
+ full_action_space = false,
+ seed = seed,
+ );
+ state_mapping = Chain(
+ ResizeImage(state_size...),
+ StackFrames(state_size..., n_frames),
+ ),
+ state_space_mapping = _ ->
+ Space(fill(0 .. 256, state_size..., n_frames)),
),
- r -> clamp(r, -1, 1)
- )
+ ),
+ r -> clamp(r, -1, 1),
+ )
if isnothing(n_replica)
init(seed)
else
- envs = [
- init(isnothing(seed) ? nothing : hash(seed + i))
- for i in 1:n_replica
- ]
+ envs = [init(isnothing(seed) ? nothing : hash(seed + i)) for i in 1:n_replica]
states = Flux.batch(state.(envs))
rewards = reward.(envs)
terminals = is_terminated.(envs)
A = Space([action_space(x) for x in envs])
- S = Space(fill(0..255, size(states)))
+ S = Space(fill(0 .. 255, size(states)))
MultiThreadEnv(envs, states, rewards, terminals, A, S, nothing)
end
end
@@ -191,7 +188,12 @@ function RL.Experiment(
N_FRAMES = 4
STATE_SIZE = (84, 84)
- env = atari_env_factory(name, STATE_SIZE, N_FRAMES; seed = isnothing(seed) ? nothing : hash(seed + 1))
+ env = atari_env_factory(
+ name,
+ STATE_SIZE,
+ N_FRAMES;
+ seed = isnothing(seed) ? nothing : hash(seed + 1),
+ )
N_ACTIONS = length(action_space(env))
N_ATOMS = 51
init = glorot_uniform(rng)
@@ -238,7 +240,7 @@ function RL.Experiment(
),
),
trajectory = CircularArrayPSARTTrajectory(
- capacity = haskey(ENV, "CI") : 1_000 : 1_000_000,
+ capacity = haskey(ENV, "CI"):1_000:1_000_000,
state = Matrix{Float32} => STATE_SIZE,
),
)
@@ -262,7 +264,7 @@ function RL.Experiment(
steps_per_episode.steps[end] log_step_increment = 0
end
end,
- DoEveryNStep(;n=EVALUATION_FREQ) do t, agent, env
+ DoEveryNStep(; n = EVALUATION_FREQ) do t, agent, env
@info "evaluating agent at $t step..."
p = agent.policy
p = @set p.explorer = EpsilonGreedyExplorer(0.001; rng = rng) # set evaluation epsilon
@@ -282,16 +284,18 @@ function RL.Experiment(
avg_length = mean(h[2].steps[1:end-1])
avg_score = mean(h[1].rewards[1:end-1])
- @info "finished evaluating agent in $s seconds" avg_length = avg_length avg_score = avg_score
+ @info "finished evaluating agent in $s seconds" avg_length = avg_length avg_score =
+ avg_score
with_logger(lg) do
- @info "evaluating" avg_length = avg_length avg_score = avg_score log_step_increment = 0
+ @info "evaluating" avg_length = avg_length avg_score = avg_score log_step_increment =
+ 0
end
end,
)
stop_condition = StopAfterStep(
haskey(ENV, "CI") ? 10_000 : 50_000_000,
- is_show_progress=!haskey(ENV, "CI")
+ is_show_progress = !haskey(ENV, "CI"),
)
Experiment(agent, env, stop_condition, hook, "# Rainbow <-> Atari($name)")
diff --git a/docs/experiments/experiments/DQN/JuliaRL_BasicDQN_CartPole.jl b/docs/experiments/experiments/DQN/JuliaRL_BasicDQN_CartPole.jl
index 9f32c48d6..d5ba9c9c7 100644
--- a/docs/experiments/experiments/DQN/JuliaRL_BasicDQN_CartPole.jl
+++ b/docs/experiments/experiments/DQN/JuliaRL_BasicDQN_CartPole.jl
@@ -51,7 +51,7 @@ function RL.Experiment(
state = Vector{Float32} => (ns,),
),
)
- stop_condition = StopAfterStep(10_000, is_show_progress=!haskey(ENV, "CI"))
+ stop_condition = StopAfterStep(10_000, is_show_progress = !haskey(ENV, "CI"))
hook = TotalRewardPerEpisode()
Experiment(policy, env, stop_condition, hook, "# BasicDQN <-> CartPole")
end
diff --git a/docs/experiments/experiments/DQN/JuliaRL_BasicDQN_MountainCar.jl b/docs/experiments/experiments/DQN/JuliaRL_BasicDQN_MountainCar.jl
index ae8f02cb5..bc79c94be 100644
--- a/docs/experiments/experiments/DQN/JuliaRL_BasicDQN_MountainCar.jl
+++ b/docs/experiments/experiments/DQN/JuliaRL_BasicDQN_MountainCar.jl
@@ -51,7 +51,7 @@ function RL.Experiment(
),
)
- stop_condition = StopAfterStep(70_000, is_show_progress=!haskey(ENV, "CI"))
+ stop_condition = StopAfterStep(70_000, is_show_progress = !haskey(ENV, "CI"))
hook = TotalRewardPerEpisode()
Experiment(agent, env, stop_condition, hook, "")
diff --git a/docs/experiments/experiments/DQN/JuliaRL_BasicDQN_SingleRoomUndirected.jl b/docs/experiments/experiments/DQN/JuliaRL_BasicDQN_SingleRoomUndirected.jl
index 4b7f2a5cd..39748e307 100644
--- a/docs/experiments/experiments/DQN/JuliaRL_BasicDQN_SingleRoomUndirected.jl
+++ b/docs/experiments/experiments/DQN/JuliaRL_BasicDQN_SingleRoomUndirected.jl
@@ -18,47 +18,47 @@ function RL.Experiment(
::Val{:BasicDQN},
::Val{:SingleRoomUndirected},
::Nothing;
- seed=123,
+ seed = 123,
)
rng = StableRNG(seed)
- env = GridWorlds.SingleRoomUndirectedModule.SingleRoomUndirected(rng=rng)
+ env = GridWorlds.SingleRoomUndirectedModule.SingleRoomUndirected(rng = rng)
env = GridWorlds.RLBaseEnv(env)
- env = RLEnvs.StateTransformedEnv(env;state_mapping=x -> vec(Float32.(x)))
+ env = RLEnvs.StateTransformedEnv(env; state_mapping = x -> vec(Float32.(x)))
env = RewardOverriddenEnv(env, x -> x - convert(typeof(x), 0.01))
env = MaxTimeoutEnv(env, 240)
ns, na = length(state(env)), length(action_space(env))
agent = Agent(
- policy=QBasedPolicy(
- learner=BasicDQNLearner(
- approximator=NeuralNetworkApproximator(
- model=Chain(
- Dense(ns, 128, relu; init=glorot_uniform(rng)),
- Dense(128, 128, relu; init=glorot_uniform(rng)),
- Dense(128, na; init=glorot_uniform(rng)),
+ policy = QBasedPolicy(
+ learner = BasicDQNLearner(
+ approximator = NeuralNetworkApproximator(
+ model = Chain(
+ Dense(ns, 128, relu; init = glorot_uniform(rng)),
+ Dense(128, 128, relu; init = glorot_uniform(rng)),
+ Dense(128, na; init = glorot_uniform(rng)),
) |> cpu,
- optimizer=ADAM(),
+ optimizer = ADAM(),
),
- batch_size=32,
- min_replay_history=100,
- loss_func=huber_loss,
- rng=rng,
+ batch_size = 32,
+ min_replay_history = 100,
+ loss_func = huber_loss,
+ rng = rng,
),
- explorer=EpsilonGreedyExplorer(
- kind=:exp,
- ϵ_stable=0.01,
- decay_steps=500,
- rng=rng,
+ explorer = EpsilonGreedyExplorer(
+ kind = :exp,
+ ϵ_stable = 0.01,
+ decay_steps = 500,
+ rng = rng,
),
),
- trajectory=CircularArraySARTTrajectory(
- capacity=1000,
- state=Vector{Float32} => (ns,),
+ trajectory = CircularArraySARTTrajectory(
+ capacity = 1000,
+ state = Vector{Float32} => (ns,),
),
)
- stop_condition = StopAfterStep(10_000, is_show_progress=!haskey(ENV, "CI"))
+ stop_condition = StopAfterStep(10_000, is_show_progress = !haskey(ENV, "CI"))
hook = TotalRewardPerEpisode()
Experiment(agent, env, stop_condition, hook, "")
end
diff --git a/docs/experiments/experiments/DQN/JuliaRL_DQN_CartPole.jl b/docs/experiments/experiments/DQN/JuliaRL_DQN_CartPole.jl
index 7e922e13e..7e2e218f6 100644
--- a/docs/experiments/experiments/DQN/JuliaRL_DQN_CartPole.jl
+++ b/docs/experiments/experiments/DQN/JuliaRL_DQN_CartPole.jl
@@ -14,13 +14,13 @@ using Flux.Losses
function build_dueling_network(network::Chain)
lm = length(network)
- if !(network[lm] isa Dense) || !(network[lm-1] isa Dense)
+ if !(network[lm] isa Dense) || !(network[lm-1] isa Dense)
error("The Qnetwork provided is incompatible with dueling.")
end
- base = Chain([deepcopy(network[i]) for i=1:lm-2]...)
+ base = Chain([deepcopy(network[i]) for i in 1:lm-2]...)
last_layer_dims = size(network[lm].weight, 2)
val = Chain(deepcopy(network[lm-1]), Dense(last_layer_dims, 1))
- adv = Chain([deepcopy(network[i]) for i=lm-1:lm]...)
+ adv = Chain([deepcopy(network[i]) for i in lm-1:lm]...)
return DuelingNetwork(base, val, adv)
end
@@ -37,8 +37,8 @@ function RL.Experiment(
base_model = Chain(
Dense(ns, 128, relu; init = glorot_uniform(rng)),
Dense(128, 128, relu; init = glorot_uniform(rng)),
- Dense(128, na; init = glorot_uniform(rng))
- )
+ Dense(128, na; init = glorot_uniform(rng)),
+ )
agent = Agent(
policy = QBasedPolicy(
@@ -72,7 +72,7 @@ function RL.Experiment(
),
)
- stop_condition = StopAfterStep(10_000, is_show_progress=!haskey(ENV, "CI"))
+ stop_condition = StopAfterStep(10_000, is_show_progress = !haskey(ENV, "CI"))
hook = TotalRewardPerEpisode()
Experiment(agent, env, stop_condition, hook, "")
end
diff --git a/docs/experiments/experiments/DQN/JuliaRL_DQN_MountainCar.jl b/docs/experiments/experiments/DQN/JuliaRL_DQN_MountainCar.jl
index d8b1eb633..f74bcaea1 100644
--- a/docs/experiments/experiments/DQN/JuliaRL_DQN_MountainCar.jl
+++ b/docs/experiments/experiments/DQN/JuliaRL_DQN_MountainCar.jl
@@ -64,7 +64,7 @@ function RL.Experiment(
),
)
- stop_condition = StopAfterStep(40_000, is_show_progress=!haskey(ENV, "CI"))
+ stop_condition = StopAfterStep(40_000, is_show_progress = !haskey(ENV, "CI"))
hook = TotalRewardPerEpisode()
Experiment(agent, env, stop_condition, hook, "")
end
diff --git a/docs/experiments/experiments/DQN/JuliaRL_IQN_CartPole.jl b/docs/experiments/experiments/DQN/JuliaRL_IQN_CartPole.jl
index f3ab3c98f..cba0fcea5 100644
--- a/docs/experiments/experiments/DQN/JuliaRL_IQN_CartPole.jl
+++ b/docs/experiments/experiments/DQN/JuliaRL_IQN_CartPole.jl
@@ -71,7 +71,7 @@ function RL.Experiment(
),
)
- stop_condition = StopAfterStep(10_000, is_show_progress=!haskey(ENV, "CI"))
+ stop_condition = StopAfterStep(10_000, is_show_progress = !haskey(ENV, "CI"))
hook = TotalRewardPerEpisode()
Experiment(agent, env, stop_condition, hook, "")
end
diff --git a/docs/experiments/experiments/DQN/JuliaRL_QRDQN_Cartpole.jl b/docs/experiments/experiments/DQN/JuliaRL_QRDQN_Cartpole.jl
index 7fb238d28..c45bb0b03 100644
--- a/docs/experiments/experiments/DQN/JuliaRL_QRDQN_Cartpole.jl
+++ b/docs/experiments/experiments/DQN/JuliaRL_QRDQN_Cartpole.jl
@@ -17,58 +17,58 @@ function RL.Experiment(
::Val{:QRDQN},
::Val{:CartPole},
::Nothing;
- seed=123,
+ seed = 123,
)
N = 10
rng = StableRNG(seed)
- env = CartPoleEnv(; T=Float32, rng=rng)
+ env = CartPoleEnv(; T = Float32, rng = rng)
ns, na = length(state(env)), length(action_space(env))
init = glorot_uniform(rng)
agent = Agent(
- policy=QBasedPolicy(
- learner=QRDQNLearner(
- approximator=NeuralNetworkApproximator(
- model=Chain(
+ policy = QBasedPolicy(
+ learner = QRDQNLearner(
+ approximator = NeuralNetworkApproximator(
+ model = Chain(
Dense(ns, 128, relu; init = init),
Dense(128, 128, relu; init = init),
Dense(128, N * na; init = init),
) |> cpu,
- optimizer=ADAM(),
+ optimizer = ADAM(),
),
- target_approximator=NeuralNetworkApproximator(
- model=Chain(
+ target_approximator = NeuralNetworkApproximator(
+ model = Chain(
Dense(ns, 128, relu; init = init),
Dense(128, 128, relu; init = init),
Dense(128, N * na; init = init),
) |> cpu,
),
- stack_size=nothing,
- batch_size=32,
- update_horizon=1,
- min_replay_history=100,
- update_freq=1,
- target_update_freq=100,
- n_quantile=N,
- rng=rng,
+ stack_size = nothing,
+ batch_size = 32,
+ update_horizon = 1,
+ min_replay_history = 100,
+ update_freq = 1,
+ target_update_freq = 100,
+ n_quantile = N,
+ rng = rng,
),
- explorer=EpsilonGreedyExplorer(
- kind=:exp,
- ϵ_stable=0.01,
- decay_steps=500,
- rng=rng,
+ explorer = EpsilonGreedyExplorer(
+ kind = :exp,
+ ϵ_stable = 0.01,
+ decay_steps = 500,
+ rng = rng,
),
),
- trajectory=CircularArraySARTTrajectory(
- capacity=1000,
- state=Vector{Float32} => (ns,),
+ trajectory = CircularArraySARTTrajectory(
+ capacity = 1000,
+ state = Vector{Float32} => (ns,),
),
)
- stop_condition = StopAfterStep(10_000, is_show_progress=!haskey(ENV, "CI"))
+ stop_condition = StopAfterStep(10_000, is_show_progress = !haskey(ENV, "CI"))
hook = TotalRewardPerEpisode()
Experiment(agent, env, stop_condition, hook, "")
end
diff --git a/docs/experiments/experiments/DQN/JuliaRL_REMDQN_CartPole.jl b/docs/experiments/experiments/DQN/JuliaRL_REMDQN_CartPole.jl
index fdf473a83..7f74dd096 100644
--- a/docs/experiments/experiments/DQN/JuliaRL_REMDQN_CartPole.jl
+++ b/docs/experiments/experiments/DQN/JuliaRL_REMDQN_CartPole.jl
@@ -52,7 +52,7 @@ function RL.Experiment(
update_freq = 1,
target_update_freq = 100,
ensemble_num = ensemble_num,
- ensemble_method = :rand,
+ ensemble_method = :rand,
rng = rng,
),
explorer = EpsilonGreedyExplorer(
@@ -68,7 +68,7 @@ function RL.Experiment(
),
)
- stop_condition = StopAfterStep(10_000, is_show_progress=!haskey(ENV, "CI"))
+ stop_condition = StopAfterStep(10_000, is_show_progress = !haskey(ENV, "CI"))
hook = TotalRewardPerEpisode()
Experiment(agent, env, stop_condition, hook, "")
end
diff --git a/docs/experiments/experiments/DQN/JuliaRL_Rainbow_CartPole.jl b/docs/experiments/experiments/DQN/JuliaRL_Rainbow_CartPole.jl
index f367d1cf5..d8f5d2437 100644
--- a/docs/experiments/experiments/DQN/JuliaRL_Rainbow_CartPole.jl
+++ b/docs/experiments/experiments/DQN/JuliaRL_Rainbow_CartPole.jl
@@ -71,7 +71,7 @@ function RL.Experiment(
),
)
- stop_condition = StopAfterStep(10_000, is_show_progress=!haskey(ENV, "CI"))
+ stop_condition = StopAfterStep(10_000, is_show_progress = !haskey(ENV, "CI"))
hook = TotalRewardPerEpisode()
Experiment(agent, env, stop_condition, hook, "")
end
diff --git a/docs/experiments/experiments/ED/JuliaRL_ED_OpenSpiel.jl b/docs/experiments/experiments/ED/JuliaRL_ED_OpenSpiel.jl
index c7e0bb2fa..83549aab0 100644
--- a/docs/experiments/experiments/ED/JuliaRL_ED_OpenSpiel.jl
+++ b/docs/experiments/experiments/ED/JuliaRL_ED_OpenSpiel.jl
@@ -19,41 +19,37 @@ end
function (hook::KuhnOpenEDHook)(::PreEpisodeStage, policy, env)
## get nash_conv of the current policy.
push!(hook.results, RLZoo.nash_conv(policy, env))
-
+
## update agents' learning rate.
for (_, agent) in policy.agents
agent.learner.optimizer[2].eta = 1.0 / sqrt(length(hook.results))
end
end
-function RL.Experiment(
- ::Val{:JuliaRL},
- ::Val{:ED},
- ::Val{:OpenSpiel},
- game;
- seed = 123,
-)
+function RL.Experiment(::Val{:JuliaRL}, ::Val{:ED}, ::Val{:OpenSpiel}, game; seed = 123)
rng = StableRNG(seed)
-
+
env = OpenSpielEnv(game)
wrapped_env = ActionTransformedEnv(
env,
- action_mapping = a -> RLBase.current_player(env) == chance_player(env) ? a : Int(a - 1),
- action_space_mapping = as -> RLBase.current_player(env) == chance_player(env) ?
- as : Base.OneTo(num_distinct_actions(env.game)),
+ action_mapping = a ->
+ RLBase.current_player(env) == chance_player(env) ? a : Int(a - 1),
+ action_space_mapping = as ->
+ RLBase.current_player(env) == chance_player(env) ? as :
+ Base.OneTo(num_distinct_actions(env.game)),
)
wrapped_env = DefaultStateStyleEnv{InformationSet{Array}()}(wrapped_env)
player = 0 # or 1
ns, na = length(state(wrapped_env, player)), length(action_space(wrapped_env, player))
create_network() = Chain(
- Dense(ns, 64, relu;init = glorot_uniform(rng)),
- Dense(64, na;init = glorot_uniform(rng))
+ Dense(ns, 64, relu; init = glorot_uniform(rng)),
+ Dense(64, na; init = glorot_uniform(rng)),
)
create_learner() = NeuralNetworkApproximator(
model = create_network(),
- optimizer = Flux.Optimise.Optimiser(WeightDecay(0.001), Descent())
+ optimizer = Flux.Optimise.Optimiser(WeightDecay(0.001), Descent()),
)
EDmanager = EDManager(
@@ -63,20 +59,26 @@ function RL.Experiment(
create_learner(), # neural network learner
WeightedSoftmaxExplorer(), # explorer
) for player in players(env) if player != chance_player(env)
- )
+ ),
)
- stop_condition = StopAfterEpisode(500, is_show_progress=!haskey(ENV, "CI"))
+ stop_condition = StopAfterEpisode(500, is_show_progress = !haskey(ENV, "CI"))
hook = KuhnOpenEDHook([])
- Experiment(EDmanager, wrapped_env, stop_condition, hook, "# play OpenSpiel $game with ED algorithm")
+ Experiment(
+ EDmanager,
+ wrapped_env,
+ stop_condition,
+ hook,
+ "# play OpenSpiel $game with ED algorithm",
+ )
end
using Plots
ex = E`JuliaRL_ED_OpenSpiel(kuhn_poker)`
results = run(ex)
-plot(ex.hook.results, xlabel="episode", ylabel="nash_conv")
+plot(ex.hook.results, xlabel = "episode", ylabel = "nash_conv")
savefig("assets/JuliaRL_ED_OpenSpiel(kuhn_poker).png")#hide
-# ![](assets/JuliaRL_NFSP_OpenSpiel(kuhn_poker).png)
\ No newline at end of file
+# ![](assets/JuliaRL_ED_OpenSpiel(kuhn_poker).png)
diff --git a/docs/experiments/experiments/NFSP/JuliaRL_NFSP_KuhnPoker.jl b/docs/experiments/experiments/NFSP/JuliaRL_NFSP_KuhnPoker.jl
index 7c69f4f63..d234d78e3 100644
--- a/docs/experiments/experiments/NFSP/JuliaRL_NFSP_KuhnPoker.jl
+++ b/docs/experiments/experiments/NFSP/JuliaRL_NFSP_KuhnPoker.jl
@@ -35,14 +35,15 @@ function RL.Experiment(
seed = 123,
)
rng = StableRNG(seed)
-
+
## Encode the KuhnPokerEnv's states for training.
env = KuhnPokerEnv()
wrapped_env = StateTransformedEnv(
env;
state_mapping = s -> [findfirst(==(s), state_space(env))],
- state_space_mapping = ss -> [[findfirst(==(s), state_space(env))] for s in state_space(env)]
- )
+ state_space_mapping = ss ->
+ [[findfirst(==(s), state_space(env))] for s in state_space(env)],
+ )
player = 1 # or 2
ns, na = length(state(wrapped_env, player)), length(action_space(wrapped_env, player))
@@ -53,14 +54,14 @@ function RL.Experiment(
approximator = NeuralNetworkApproximator(
model = Chain(
Dense(ns, 64, relu; init = glorot_normal(rng)),
- Dense(64, na; init = glorot_normal(rng))
+ Dense(64, na; init = glorot_normal(rng)),
) |> cpu,
optimizer = Descent(0.01),
),
target_approximator = NeuralNetworkApproximator(
model = Chain(
Dense(ns, 64, relu; init = glorot_normal(rng)),
- Dense(64, na; init = glorot_normal(rng))
+ Dense(64, na; init = glorot_normal(rng)),
) |> cpu,
),
γ = 1.0f0,
@@ -81,7 +82,7 @@ function RL.Experiment(
),
trajectory = CircularArraySARTTrajectory(
capacity = 200_000,
- state = Vector{Int} => (ns, ),
+ state = Vector{Int} => (ns,),
),
)
@@ -89,9 +90,9 @@ function RL.Experiment(
policy = BehaviorCloningPolicy(;
approximator = NeuralNetworkApproximator(
model = Chain(
- Dense(ns, 64, relu; init = glorot_normal(rng)),
- Dense(64, na; init = glorot_normal(rng))
- ) |> cpu,
+ Dense(ns, 64, relu; init = glorot_normal(rng)),
+ Dense(64, na; init = glorot_normal(rng)),
+ ) |> cpu,
optimizer = Descent(0.01),
),
explorer = WeightedSoftmaxExplorer(),
@@ -111,19 +112,22 @@ function RL.Experiment(
η = 0.1 # anticipatory parameter
nfsp = NFSPAgentManager(
Dict(
- (player, NFSPAgent(
- deepcopy(rl_agent),
- deepcopy(sl_agent),
- η,
- rng,
- 128, # update_freq
- 0, # initial update_step
- true, # initial NFSPAgent's training mode
- )) for player in players(wrapped_env) if player != chance_player(wrapped_env)
- )
+ (
+ player,
+ NFSPAgent(
+ deepcopy(rl_agent),
+ deepcopy(sl_agent),
+ η,
+ rng,
+ 128, # update_freq
+ 0, # initial update_step
+ true, # initial NFSPAgent's training mode
+ ),
+ ) for player in players(wrapped_env) if player != chance_player(wrapped_env)
+ ),
)
- stop_condition = StopAfterEpisode(1_200_000, is_show_progress=!haskey(ENV, "CI"))
+ stop_condition = StopAfterEpisode(1_200_000, is_show_progress = !haskey(ENV, "CI"))
hook = KuhnNFSPHook(10_000, 0, [], [])
Experiment(nfsp, wrapped_env, stop_condition, hook, "# run NFSP on KuhnPokerEnv")
@@ -133,8 +137,14 @@ end
using Plots
ex = E`JuliaRL_NFSP_KuhnPoker`
run(ex)
-plot(ex.hook.episode, ex.hook.results, xaxis=:log, xlabel="episode", ylabel="nash_conv")
+plot(
+ ex.hook.episode,
+ ex.hook.results,
+ xaxis = :log,
+ xlabel = "episode",
+ ylabel = "nash_conv",
+)
savefig("assets/JuliaRL_NFSP_KuhnPoker.png")#hide
-# ![](assets/JuliaRL_NFSP_KuhnPoker.png)
\ No newline at end of file
+# ![](assets/JuliaRL_NFSP_KuhnPoker.png)
diff --git a/docs/experiments/experiments/NFSP/JuliaRL_NFSP_OpenSpiel.jl b/docs/experiments/experiments/NFSP/JuliaRL_NFSP_OpenSpiel.jl
index 1e8509f57..d0a8a099c 100644
--- a/docs/experiments/experiments/NFSP/JuliaRL_NFSP_OpenSpiel.jl
+++ b/docs/experiments/experiments/NFSP/JuliaRL_NFSP_OpenSpiel.jl
@@ -29,21 +29,17 @@ function (hook::KuhnOpenNFSPHook)(::PostEpisodeStage, policy, env)
end
end
-function RL.Experiment(
- ::Val{:JuliaRL},
- ::Val{:NFSP},
- ::Val{:OpenSpiel},
- game;
- seed = 123,
-)
+function RL.Experiment(::Val{:JuliaRL}, ::Val{:NFSP}, ::Val{:OpenSpiel}, game; seed = 123)
rng = StableRNG(seed)
-
+
env = OpenSpielEnv(game)
wrapped_env = ActionTransformedEnv(
env,
- action_mapping = a -> RLBase.current_player(env) == chance_player(env) ? a : Int(a - 1),
- action_space_mapping = as -> RLBase.current_player(env) == chance_player(env) ?
- as : Base.OneTo(num_distinct_actions(env.game)),
+ action_mapping = a ->
+ RLBase.current_player(env) == chance_player(env) ? a : Int(a - 1),
+ action_space_mapping = as ->
+ RLBase.current_player(env) == chance_player(env) ? as :
+ Base.OneTo(num_distinct_actions(env.game)),
)
wrapped_env = DefaultStateStyleEnv{InformationSet{Array}()}(wrapped_env)
player = 0 # or 1
@@ -56,14 +52,14 @@ function RL.Experiment(
approximator = NeuralNetworkApproximator(
model = Chain(
Dense(ns, 128, relu; init = glorot_normal(rng)),
- Dense(128, na; init = glorot_normal(rng))
+ Dense(128, na; init = glorot_normal(rng)),
) |> cpu,
optimizer = Descent(0.01),
),
target_approximator = NeuralNetworkApproximator(
model = Chain(
Dense(ns, 128, relu; init = glorot_normal(rng)),
- Dense(128, na; init = glorot_normal(rng))
+ Dense(128, na; init = glorot_normal(rng)),
) |> cpu,
),
γ = 1.0f0,
@@ -84,7 +80,7 @@ function RL.Experiment(
),
trajectory = CircularArraySARTTrajectory(
capacity = 200_000,
- state = Vector{Float64} => (ns, ),
+ state = Vector{Float64} => (ns,),
),
)
@@ -92,9 +88,9 @@ function RL.Experiment(
policy = BehaviorCloningPolicy(;
approximator = NeuralNetworkApproximator(
model = Chain(
- Dense(ns, 128, relu; init = glorot_normal(rng)),
- Dense(128, na; init = glorot_normal(rng))
- ) |> cpu,
+ Dense(ns, 128, relu; init = glorot_normal(rng)),
+ Dense(128, na; init = glorot_normal(rng)),
+ ) |> cpu,
optimizer = Descent(0.01),
),
explorer = WeightedSoftmaxExplorer(),
@@ -114,29 +110,44 @@ function RL.Experiment(
η = 0.1 # anticipatory parameter
nfsp = NFSPAgentManager(
Dict(
- (player, NFSPAgent(
- deepcopy(rl_agent),
- deepcopy(sl_agent),
- η,
- rng,
- 128, # update_freq
- 0, # initial update_step
- true, # initial NFSPAgent's training mode
- )) for player in players(wrapped_env) if player != chance_player(wrapped_env)
- )
+ (
+ player,
+ NFSPAgent(
+ deepcopy(rl_agent),
+ deepcopy(sl_agent),
+ η,
+ rng,
+ 128, # update_freq
+ 0, # initial update_step
+ true, # initial NFSPAgent's training mode
+ ),
+ ) for player in players(wrapped_env) if player != chance_player(wrapped_env)
+ ),
)
- stop_condition = StopAfterEpisode(1_200_000, is_show_progress=!haskey(ENV, "CI"))
+ stop_condition = StopAfterEpisode(1_200_000, is_show_progress = !haskey(ENV, "CI"))
hook = KuhnOpenNFSPHook(10_000, 0, [], [])
- Experiment(nfsp, wrapped_env, stop_condition, hook, "# Play kuhn_poker in OpenSpiel with NFSP")
+ Experiment(
+ nfsp,
+ wrapped_env,
+ stop_condition,
+ hook,
+ "# Play kuhn_poker in OpenSpiel with NFSP",
+ )
end
using Plots
ex = E`JuliaRL_NFSP_OpenSpiel(kuhn_poker)`
run(ex)
-plot(ex.hook.episode, ex.hook.results, xaxis=:log, xlabel="episode", ylabel="nash_conv")
+plot(
+ ex.hook.episode,
+ ex.hook.results,
+ xaxis = :log,
+ xlabel = "episode",
+ ylabel = "nash_conv",
+)
savefig("assets/JuliaRL_NFSP_OpenSpiel(kuhn_poker).png")#hide
-# ![](assets/JuliaRL_NFSP_OpenSpiel(kuhn_poker).png)
\ No newline at end of file
+# ![](assets/JuliaRL_NFSP_OpenSpiel(kuhn_poker).png)
diff --git a/docs/experiments/experiments/Offline/JuliaRL_BC_CartPole.jl b/docs/experiments/experiments/Offline/JuliaRL_BC_CartPole.jl
index f1e42bb51..808719818 100644
--- a/docs/experiments/experiments/Offline/JuliaRL_BC_CartPole.jl
+++ b/docs/experiments/experiments/Offline/JuliaRL_BC_CartPole.jl
@@ -61,7 +61,7 @@ function RL.Experiment(
),
)
- stop_condition = StopAfterStep(10_000, is_show_progress=!haskey(ENV, "CI"))
+ stop_condition = StopAfterStep(10_000, is_show_progress = !haskey(ENV, "CI"))
hook = RecordStateAction()
run(agent, env, stop_condition, hook)
@@ -84,7 +84,13 @@ function RL.Experiment(
end
hook = TotalRewardPerEpisode()
- Experiment(bc, env, StopAfterEpisode(100, is_show_progress=!haskey(ENV, "CI")), hook, "BehaviorCloning <-> CartPole")
+ Experiment(
+ bc,
+ env,
+ StopAfterEpisode(100, is_show_progress = !haskey(ENV, "CI")),
+ hook,
+ "BehaviorCloning <-> CartPole",
+ )
end
#+ tangle=false
diff --git a/docs/experiments/experiments/Policy Gradient/JuliaRL_A2CGAE_CartPole.jl b/docs/experiments/experiments/Policy Gradient/JuliaRL_A2CGAE_CartPole.jl
index eeed90609..2ea78b92d 100644
--- a/docs/experiments/experiments/Policy Gradient/JuliaRL_A2CGAE_CartPole.jl
+++ b/docs/experiments/experiments/Policy Gradient/JuliaRL_A2CGAE_CartPole.jl
@@ -63,7 +63,7 @@ function RL.Experiment(
terminal = Vector{Bool} => (N_ENV,),
),
)
- stop_condition = StopAfterStep(50_000, is_show_progress=!haskey(ENV, "CI"))
+ stop_condition = StopAfterStep(50_000, is_show_progress = !haskey(ENV, "CI"))
hook = TotalBatchRewardPerEpisode(N_ENV)
Experiment(agent, env, stop_condition, hook, "# A2CGAE with CartPole")
end
@@ -78,7 +78,7 @@ run(ex)
n = minimum(map(length, ex.hook.rewards))
m = mean([@view(x[1:n]) for x in ex.hook.rewards])
s = std([@view(x[1:n]) for x in ex.hook.rewards])
-plot(m,ribbon=s)
+plot(m, ribbon = s)
savefig("assets/JuliaRL_A2CGAE_CartPole.png") #hide
# ![](assets/JuliaRL_A2CGAE_CartPole.png)
diff --git a/docs/experiments/experiments/Policy Gradient/JuliaRL_A2C_CartPole.jl b/docs/experiments/experiments/Policy Gradient/JuliaRL_A2C_CartPole.jl
index cb03dd025..e56733a92 100644
--- a/docs/experiments/experiments/Policy Gradient/JuliaRL_A2C_CartPole.jl
+++ b/docs/experiments/experiments/Policy Gradient/JuliaRL_A2C_CartPole.jl
@@ -59,7 +59,7 @@ function RL.Experiment(
),
)
- stop_condition = StopAfterStep(50_000, is_show_progress=true)
+ stop_condition = StopAfterStep(50_000, is_show_progress = true)
hook = TotalBatchRewardPerEpisode(N_ENV)
Experiment(agent, env, stop_condition, hook, "# A2C with CartPole")
end
@@ -73,7 +73,7 @@ run(ex)
n = minimum(map(length, ex.hook.rewards))
m = mean([@view(x[1:n]) for x in ex.hook.rewards])
s = std([@view(x[1:n]) for x in ex.hook.rewards])
-plot(m,ribbon=s)
+plot(m, ribbon = s)
savefig("assets/JuliaRL_A2C_CartPole.png") #hide
# ![](assets/JuliaRL_A2C_CartPole.png)
diff --git a/docs/experiments/experiments/Policy Gradient/JuliaRL_DDPG_Pendulum.jl b/docs/experiments/experiments/Policy Gradient/JuliaRL_DDPG_Pendulum.jl
index 927cdc2d2..0120abecd 100644
--- a/docs/experiments/experiments/Policy Gradient/JuliaRL_DDPG_Pendulum.jl
+++ b/docs/experiments/experiments/Policy Gradient/JuliaRL_DDPG_Pendulum.jl
@@ -69,7 +69,7 @@ function RL.Experiment(
na = 1,
batch_size = 64,
start_steps = 1000,
- start_policy = RandomPolicy(-1.0..1.0; rng = rng),
+ start_policy = RandomPolicy(-1.0 .. 1.0; rng = rng),
update_after = 1000,
update_freq = 1,
act_limit = 1.0,
@@ -79,11 +79,11 @@ function RL.Experiment(
trajectory = CircularArraySARTTrajectory(
capacity = 10000,
state = Vector{Float32} => (ns,),
- action = Float32 => (na, ),
+ action = Float32 => (na,),
),
)
- stop_condition = StopAfterStep(10_000, is_show_progress=!haskey(ENV, "CI"))
+ stop_condition = StopAfterStep(10_000, is_show_progress = !haskey(ENV, "CI"))
hook = TotalRewardPerEpisode()
Experiment(agent, env, stop_condition, hook, "# Play Pendulum with DDPG")
end
diff --git a/docs/experiments/experiments/Policy Gradient/JuliaRL_MAC_CartPole.jl b/docs/experiments/experiments/Policy Gradient/JuliaRL_MAC_CartPole.jl
index 3559b1b03..c526fcfe3 100644
--- a/docs/experiments/experiments/Policy Gradient/JuliaRL_MAC_CartPole.jl
+++ b/docs/experiments/experiments/Policy Gradient/JuliaRL_MAC_CartPole.jl
@@ -64,7 +64,7 @@ function RL.Experiment(
),
)
- stop_condition = StopAfterStep(50_000, is_show_progress=!haskey(ENV, "CI"))
+ stop_condition = StopAfterStep(50_000, is_show_progress = !haskey(ENV, "CI"))
hook = TotalBatchRewardPerEpisode(N_ENV)
Experiment(agent, env, stop_condition, hook, "# MAC with CartPole")
end
@@ -78,7 +78,7 @@ run(ex)
n = minimum(map(length, ex.hook.rewards))
m = mean([@view(x[1:n]) for x in ex.hook.rewards])
s = std([@view(x[1:n]) for x in ex.hook.rewards])
-plot(m,ribbon=s)
+plot(m, ribbon = s)
savefig("assets/JuliaRL_MAC_CartPole.png") #hide
# ![](assets/JuliaRL_MAC_CartPole.png)
diff --git a/docs/experiments/experiments/Policy Gradient/JuliaRL_MADDPG_KuhnPoker.jl b/docs/experiments/experiments/Policy Gradient/JuliaRL_MADDPG_KuhnPoker.jl
index 9d13ec830..542d22100 100644
--- a/docs/experiments/experiments/Policy Gradient/JuliaRL_MADDPG_KuhnPoker.jl
+++ b/docs/experiments/experiments/Policy Gradient/JuliaRL_MADDPG_KuhnPoker.jl
@@ -32,7 +32,7 @@ function RL.Experiment(
::Val{:MADDPG},
::Val{:KuhnPoker},
::Nothing;
- seed=123,
+ seed = 123,
)
rng = StableRNG(seed)
env = KuhnPokerEnv()
@@ -40,10 +40,12 @@ function RL.Experiment(
StateTransformedEnv(
env;
state_mapping = s -> [findfirst(==(s), state_space(env))],
- state_space_mapping = ss -> [[findfirst(==(s), state_space(env))] for s in state_space(env)]
- ),
+ state_space_mapping = ss ->
+ [[findfirst(==(s), state_space(env))] for s in state_space(env)],
+ ),
## drop the dummy action of the other agent.
- action_mapping = x -> length(x) == 1 ? x : Int(ceil(x[current_player(env)]) + 1),
+ action_mapping = x ->
+ length(x) == 1 ? x : Int(ceil(x[current_player(env)]) + 1),
)
ns, na = 1, 1 # dimension of the state and action.
n_players = 2 # number of players
@@ -51,18 +53,18 @@ function RL.Experiment(
init = glorot_uniform(rng)
create_actor() = Chain(
- Dense(ns, 64, relu; init = init),
- Dense(64, 64, relu; init = init),
- Dense(64, na, tanh; init = init),
- )
+ Dense(ns, 64, relu; init = init),
+ Dense(64, 64, relu; init = init),
+ Dense(64, na, tanh; init = init),
+ )
create_critic() = Chain(
Dense(n_players * ns + n_players * na, 64, relu; init = init),
Dense(64, 64, relu; init = init),
Dense(64, 1; init = init),
- )
+ )
+
-
policy = DDPGPolicy(
behavior_actor = NeuralNetworkApproximator(
model = create_actor(),
@@ -84,31 +86,36 @@ function RL.Experiment(
ρ = 0.99f0,
na = na,
start_steps = 1000,
- start_policy = RandomPolicy(-0.99..0.99; rng = rng),
+ start_policy = RandomPolicy(-0.99 .. 0.99; rng = rng),
update_after = 1000,
act_limit = 0.99,
- act_noise = 0.,
+ act_noise = 0.0,
rng = rng,
)
trajectory = CircularArraySARTTrajectory(
capacity = 100_000, # replay buffer capacity
- state = Vector{Int} => (ns, ),
- action = Float32 => (na, ),
+ state = Vector{Int} => (ns,),
+ action = Float32 => (na,),
)
agents = MADDPGManager(
- Dict((player, Agent(
- policy = NamedPolicy(player, deepcopy(policy)),
- trajectory = deepcopy(trajectory),
- )) for player in players(env) if player != chance_player(env)),
+ Dict(
+ (
+ player,
+ Agent(
+ policy = NamedPolicy(player, deepcopy(policy)),
+ trajectory = deepcopy(trajectory),
+ ),
+ ) for player in players(env) if player != chance_player(env)
+ ),
SARTS, # trace's type
512, # batch_size
100, # update_freq
0, # initial update_step
- rng
+ rng,
)
- stop_condition = StopAfterEpisode(100_000, is_show_progress=!haskey(ENV, "CI"))
+ stop_condition = StopAfterEpisode(100_000, is_show_progress = !haskey(ENV, "CI"))
hook = KuhnMADDPGHook(1000, 0, [], [])
Experiment(agents, wrapped_env, stop_condition, hook, "# play KuhnPoker with MADDPG")
end
@@ -117,7 +124,13 @@ end
using Plots
ex = E`JuliaRL_MADDPG_KuhnPoker`
run(ex)
-scatter(ex.hook.episode, ex.hook.results, xaxis=:log, xlabel="episode", ylabel="reward of player 1")
+scatter(
+ ex.hook.episode,
+ ex.hook.results,
+ xaxis = :log,
+ xlabel = "episode",
+ ylabel = "reward of player 1",
+)
savefig("assets/JuliaRL_MADDPG_KuhnPoker.png") #hide
diff --git a/docs/experiments/experiments/Policy Gradient/JuliaRL_MADDPG_SpeakerListener.jl b/docs/experiments/experiments/Policy Gradient/JuliaRL_MADDPG_SpeakerListener.jl
index 6f00a1dc2..116e4f6d1 100644
--- a/docs/experiments/experiments/Policy Gradient/JuliaRL_MADDPG_SpeakerListener.jl
+++ b/docs/experiments/experiments/Policy Gradient/JuliaRL_MADDPG_SpeakerListener.jl
@@ -43,53 +43,51 @@ function RL.Experiment(
::Val{:MADDPG},
::Val{:SpeakerListener},
::Nothing;
- seed=123,
+ seed = 123,
)
rng = StableRNG(seed)
env = SpeakerListenerEnv(max_steps = 25)
init = glorot_uniform(rng)
- critic_dim = sum(length(state(env, p)) + length(action_space(env, p)) for p in (:Speaker, :Listener))
+ critic_dim = sum(
+ length(state(env, p)) + length(action_space(env, p)) for p in (:Speaker, :Listener)
+ )
create_actor(player) = Chain(
Dense(length(state(env, player)), 64, relu; init = init),
Dense(64, 64, relu; init = init),
- Dense(64, length(action_space(env, player)); init = init)
- )
+ Dense(64, length(action_space(env, player)); init = init),
+ )
create_critic(critic_dim) = Chain(
Dense(critic_dim, 64, relu; init = init),
Dense(64, 64, relu; init = init),
Dense(64, 1; init = init),
- )
+ )
create_policy(player) = DDPGPolicy(
- behavior_actor = NeuralNetworkApproximator(
- model = create_actor(player),
- optimizer = Flux.Optimise.Optimiser(ClipNorm(0.5), ADAM(1e-2)),
- ),
- behavior_critic = NeuralNetworkApproximator(
- model = create_critic(critic_dim),
- optimizer = Flux.Optimise.Optimiser(ClipNorm(0.5), ADAM(1e-2)),
- ),
- target_actor = NeuralNetworkApproximator(
- model = create_actor(player),
- ),
- target_critic = NeuralNetworkApproximator(
- model = create_critic(critic_dim),
- ),
- γ = 0.95f0,
- ρ = 0.99f0,
- na = length(action_space(env, player)),
- start_steps = 0,
- start_policy = nothing,
- update_after = 512 * env.max_steps, # batch_size * env.max_steps
- act_limit = 1.0,
- act_noise = 0.,
- )
+ behavior_actor = NeuralNetworkApproximator(
+ model = create_actor(player),
+ optimizer = Flux.Optimise.Optimiser(ClipNorm(0.5), ADAM(1e-2)),
+ ),
+ behavior_critic = NeuralNetworkApproximator(
+ model = create_critic(critic_dim),
+ optimizer = Flux.Optimise.Optimiser(ClipNorm(0.5), ADAM(1e-2)),
+ ),
+ target_actor = NeuralNetworkApproximator(model = create_actor(player)),
+ target_critic = NeuralNetworkApproximator(model = create_critic(critic_dim)),
+ γ = 0.95f0,
+ ρ = 0.99f0,
+ na = length(action_space(env, player)),
+ start_steps = 0,
+ start_policy = nothing,
+ update_after = 512 * env.max_steps, # batch_size * env.max_steps
+ act_limit = 1.0,
+ act_noise = 0.0,
+ )
create_trajectory(player) = CircularArraySARTTrajectory(
- capacity = 1_000_000, # replay buffer capacity
- state = Vector{Float64} => (length(state(env, player)), ),
- action = Vector{Float64} => (length(action_space(env, player)), ),
- )
+ capacity = 1_000_000, # replay buffer capacity
+ state = Vector{Float64} => (length(state(env, player)),),
+ action = Vector{Float64} => (length(action_space(env, player)),),
+ )
agents = MADDPGManager(
Dict(
@@ -102,10 +100,10 @@ function RL.Experiment(
512, # batch_size
100, # update_freq
0, # initial update_step
- rng
+ rng,
)
- stop_condition = StopAfterEpisode(8_000, is_show_progress=!haskey(ENV, "CI"))
+ stop_condition = StopAfterEpisode(8_000, is_show_progress = !haskey(ENV, "CI"))
hook = MeanRewardHook(0, 800, 100, [], [])
Experiment(agents, env, stop_condition, hook, "# play SpeakerListener with MADDPG")
end
@@ -114,7 +112,12 @@ end
using Plots
ex = E`JuliaRL_MADDPG_SpeakerListener`
run(ex)
-plot(ex.hook.episodes, ex.hook.mean_rewards, xlabel="episode", ylabel="mean episode reward")
+plot(
+ ex.hook.episodes,
+ ex.hook.mean_rewards,
+ xlabel = "episode",
+ ylabel = "mean episode reward",
+)
savefig("assets/JuliaRL_MADDPG_SpeakerListenerEnv.png") #hide
diff --git a/docs/experiments/experiments/Policy Gradient/JuliaRL_PPO_CartPole.jl b/docs/experiments/experiments/Policy Gradient/JuliaRL_PPO_CartPole.jl
index 45cdd2e13..cbc0e3340 100644
--- a/docs/experiments/experiments/Policy Gradient/JuliaRL_PPO_CartPole.jl
+++ b/docs/experiments/experiments/Policy Gradient/JuliaRL_PPO_CartPole.jl
@@ -62,7 +62,7 @@ function RL.Experiment(
),
)
- stop_condition = StopAfterStep(10_000, is_show_progress=!haskey(ENV, "CI"))
+ stop_condition = StopAfterStep(10_000, is_show_progress = !haskey(ENV, "CI"))
hook = TotalBatchRewardPerEpisode(N_ENV)
Experiment(agent, env, stop_condition, hook, "# PPO with CartPole")
end
@@ -76,7 +76,7 @@ run(ex)
n = minimum(map(length, ex.hook.rewards))
m = mean([@view(x[1:n]) for x in ex.hook.rewards])
s = std([@view(x[1:n]) for x in ex.hook.rewards])
-plot(m,ribbon=s)
+plot(m, ribbon = s)
savefig("assets/JuliaRL_PPO_CartPole.png") #hide
# ![](assets/JuliaRL_PPO_CartPole.png)
diff --git a/docs/experiments/experiments/Policy Gradient/JuliaRL_PPO_Pendulum.jl b/docs/experiments/experiments/Policy Gradient/JuliaRL_PPO_Pendulum.jl
index 80625e104..1fd85f10c 100644
--- a/docs/experiments/experiments/Policy Gradient/JuliaRL_PPO_Pendulum.jl
+++ b/docs/experiments/experiments/Policy Gradient/JuliaRL_PPO_Pendulum.jl
@@ -32,7 +32,8 @@ function RL.Experiment(
UPDATE_FREQ = 2048
env = MultiThreadEnv([
PendulumEnv(T = Float32, rng = StableRNG(hash(seed + i))) |>
- env -> ActionTransformedEnv(env, action_mapping = x -> clamp(x * 2, low, high)) for i in 1:N_ENV
+ env -> ActionTransformedEnv(env, action_mapping = x -> clamp(x * 2, low, high))
+ for i in 1:N_ENV
])
init = glorot_uniform(rng)
@@ -78,7 +79,7 @@ function RL.Experiment(
),
)
- stop_condition = StopAfterStep(50_000, is_show_progress=!haskey(ENV, "CI"))
+ stop_condition = StopAfterStep(50_000, is_show_progress = !haskey(ENV, "CI"))
hook = TotalBatchRewardPerEpisode(N_ENV)
Experiment(agent, env, stop_condition, hook, "# Play Pendulum with PPO")
end
@@ -92,7 +93,7 @@ run(ex)
n = minimum(map(length, ex.hook.rewards))
m = mean([@view(x[1:n]) for x in ex.hook.rewards])
s = std([@view(x[1:n]) for x in ex.hook.rewards])
-plot(m,ribbon=s)
+plot(m, ribbon = s)
savefig("assets/JuliaRL_PPO_Pendulum.png") #hide
# ![](assets/JuliaRL_PPO_Pendulum.png)
diff --git a/docs/experiments/experiments/Policy Gradient/JuliaRL_SAC_Pendulum.jl b/docs/experiments/experiments/Policy Gradient/JuliaRL_SAC_Pendulum.jl
index c15e3c07d..ffd1b93eb 100644
--- a/docs/experiments/experiments/Policy Gradient/JuliaRL_SAC_Pendulum.jl
+++ b/docs/experiments/experiments/Policy Gradient/JuliaRL_SAC_Pendulum.jl
@@ -38,12 +38,11 @@ function RL.Experiment(
create_policy_net() = NeuralNetworkApproximator(
model = GaussianNetwork(
- pre = Chain(
- Dense(ns, 30, relu),
- Dense(30, 30, relu),
- ),
+ pre = Chain(Dense(ns, 30, relu), Dense(30, 30, relu)),
μ = Chain(Dense(30, na, init = init)),
- logσ = Chain(Dense(30, na, x -> clamp.(x, typeof(x)(-10), typeof(x)(2)), init = init)),
+ logσ = Chain(
+ Dense(30, na, x -> clamp.(x, typeof(x)(-10), typeof(x)(2)), init = init),
+ ),
),
optimizer = ADAM(0.003),
)
@@ -69,7 +68,7 @@ function RL.Experiment(
α = 0.2f0,
batch_size = 64,
start_steps = 1000,
- start_policy = RandomPolicy(Space([-1.0..1.0 for _ in 1:na]); rng = rng),
+ start_policy = RandomPolicy(Space([-1.0 .. 1.0 for _ in 1:na]); rng = rng),
update_after = 1000,
update_freq = 1,
automatic_entropy_tuning = true,
@@ -84,7 +83,7 @@ function RL.Experiment(
),
)
- stop_condition = StopAfterStep(10_000, is_show_progress=!haskey(ENV, "CI"))
+ stop_condition = StopAfterStep(10_000, is_show_progress = !haskey(ENV, "CI"))
hook = TotalRewardPerEpisode()
Experiment(agent, env, stop_condition, hook, "# Play Pendulum with SAC")
end
diff --git a/docs/experiments/experiments/Policy Gradient/JuliaRL_TD3_Pendulum.jl b/docs/experiments/experiments/Policy Gradient/JuliaRL_TD3_Pendulum.jl
index 49bc26c64..5d24117b4 100644
--- a/docs/experiments/experiments/Policy Gradient/JuliaRL_TD3_Pendulum.jl
+++ b/docs/experiments/experiments/Policy Gradient/JuliaRL_TD3_Pendulum.jl
@@ -69,7 +69,7 @@ function RL.Experiment(
ρ = 0.99f0,
batch_size = 64,
start_steps = 1000,
- start_policy = RandomPolicy(-1.0..1.0; rng = rng),
+ start_policy = RandomPolicy(-1.0 .. 1.0; rng = rng),
update_after = 1000,
update_freq = 1,
policy_freq = 2,
@@ -86,7 +86,7 @@ function RL.Experiment(
),
)
- stop_condition = StopAfterStep(10_000, is_show_progress=!haskey(ENV, "CI"))
+ stop_condition = StopAfterStep(10_000, is_show_progress = !haskey(ENV, "CI"))
hook = TotalRewardPerEpisode()
Experiment(agent, env, stop_condition, hook, "# Play Pendulum with TD3")
end
diff --git a/docs/experiments/experiments/Policy Gradient/JuliaRL_VMPO_CartPole.jl b/docs/experiments/experiments/Policy Gradient/JuliaRL_VMPO_CartPole.jl
index 506bba5d1..005c29f33 100644
--- a/docs/experiments/experiments/Policy Gradient/JuliaRL_VMPO_CartPole.jl
+++ b/docs/experiments/experiments/Policy Gradient/JuliaRL_VMPO_CartPole.jl
@@ -55,7 +55,7 @@ function RL.Experiment(
),
)
- stop_condition = StopAfterStep(50_000, is_show_progress=!haskey(ENV, "CI"))
+ stop_condition = StopAfterStep(50_000, is_show_progress = !haskey(ENV, "CI"))
hook = TotalRewardPerEpisode()
Experiment(agent, env, stop_condition, hook, "# VMPO with CartPole")
diff --git a/docs/experiments/experiments/Policy Gradient/JuliaRL_VPG_CartPole.jl b/docs/experiments/experiments/Policy Gradient/JuliaRL_VPG_CartPole.jl
index 87130f3dc..8cfbe125d 100644
--- a/docs/experiments/experiments/Policy Gradient/JuliaRL_VPG_CartPole.jl
+++ b/docs/experiments/experiments/Policy Gradient/JuliaRL_VPG_CartPole.jl
@@ -49,7 +49,7 @@ function RL.Experiment(
),
trajectory = ElasticSARTTrajectory(state = Vector{Float32} => (ns,)),
)
- stop_condition = StopAfterEpisode(500, is_show_progress=!haskey(ENV, "CI"))
+ stop_condition = StopAfterEpisode(500, is_show_progress = !haskey(ENV, "CI"))
hook = TotalRewardPerEpisode()
description = "# Play CartPole with VPG"
diff --git a/docs/experiments/experiments/Policy Gradient/rlpyt_A2C_Atari.jl b/docs/experiments/experiments/Policy Gradient/rlpyt_A2C_Atari.jl
index 7bf2add71..56eb5d90e 100644
--- a/docs/experiments/experiments/Policy Gradient/rlpyt_A2C_Atari.jl
+++ b/docs/experiments/experiments/Policy Gradient/rlpyt_A2C_Atari.jl
@@ -83,7 +83,7 @@ function RL.Experiment(
hook = ComposedHook(
total_batch_reward_per_episode,
batch_steps_per_episode,
- DoEveryNStep(;n=UPDATE_FREQ) do t, agent, env
+ DoEveryNStep(; n = UPDATE_FREQ) do t, agent, env
learner = agent.policy.policy.learner
with_logger(lg) do
@info "training" loss = learner.loss actor_loss = learner.actor_loss critic_loss =
@@ -94,20 +94,22 @@ function RL.Experiment(
DoEveryNStep() do t, agent, env
with_logger(lg) do
rewards = [
- total_batch_reward_per_episode.rewards[i][end] for i in 1:length(env) if is_terminated(env[i])
+ total_batch_reward_per_episode.rewards[i][end] for
+ i in 1:length(env) if is_terminated(env[i])
]
if length(rewards) > 0
@info "training" rewards = mean(rewards) log_step_increment = 0
end
steps = [
- batch_steps_per_episode.steps[i][end] for i in 1:length(env) if is_terminated(env[i])
+ batch_steps_per_episode.steps[i][end] for
+ i in 1:length(env) if is_terminated(env[i])
]
if length(steps) > 0
@info "training" steps = mean(steps) log_step_increment = 0
end
end
end,
- DoEveryNStep(;n=EVALUATION_FREQ) do t, agent, env
+ DoEveryNStep(; n = EVALUATION_FREQ) do t, agent, env
@info "evaluating agent at $t step..."
h = TotalBatchOriginalRewardPerEpisode(N_ENV)
s = @elapsed run(
diff --git a/docs/experiments/experiments/Policy Gradient/rlpyt_PPO_Atari.jl b/docs/experiments/experiments/Policy Gradient/rlpyt_PPO_Atari.jl
index e9eb574a3..4b33dc27f 100644
--- a/docs/experiments/experiments/Policy Gradient/rlpyt_PPO_Atari.jl
+++ b/docs/experiments/experiments/Policy Gradient/rlpyt_PPO_Atari.jl
@@ -85,7 +85,7 @@ function RL.Experiment(
hook = ComposedHook(
total_batch_reward_per_episode,
batch_steps_per_episode,
- DoEveryNStep(;n=UPDATE_FREQ) do t, agent, env
+ DoEveryNStep(; n = UPDATE_FREQ) do t, agent, env
p = agent.policy
with_logger(lg) do
@info "training" loss = mean(p.loss) actor_loss = mean(p.actor_loss) critic_loss =
@@ -93,7 +93,7 @@ function RL.Experiment(
mean(p.norm) log_step_increment = UPDATE_FREQ
end
end,
- DoEveryNStep(;n=UPDATE_FREQ) do t, agent, env
+ DoEveryNStep(; n = UPDATE_FREQ) do t, agent, env
decay = (N_TRAINING_STEPS - t) / N_TRAINING_STEPS
agent.policy.approximator.optimizer.eta = INIT_LEARNING_RATE * decay
agent.policy.clip_range = INIT_CLIP_RANGE * Float32(decay)
@@ -101,20 +101,22 @@ function RL.Experiment(
DoEveryNStep() do t, agent, env
with_logger(lg) do
rewards = [
- total_batch_reward_per_episode.rewards[i][end] for i in 1:length(env) if is_terminated(env[i])
+ total_batch_reward_per_episode.rewards[i][end] for
+ i in 1:length(env) if is_terminated(env[i])
]
if length(rewards) > 0
@info "training" rewards = mean(rewards) log_step_increment = 0
end
steps = [
- batch_steps_per_episode.steps[i][end] for i in 1:length(env) if is_terminated(env[i])
+ batch_steps_per_episode.steps[i][end] for
+ i in 1:length(env) if is_terminated(env[i])
]
if length(steps) > 0
@info "training" steps = mean(steps) log_step_increment = 0
end
end
end,
- DoEveryNStep(;n=EVALUATION_FREQ) do t, agent, env
+ DoEveryNStep(; n = EVALUATION_FREQ) do t, agent, env
@info "evaluating agent at $t step..."
## switch to GreedyExplorer?
h = TotalBatchRewardPerEpisode(N_ENV)
diff --git a/docs/experiments/experiments/Search/JuliaRL_Minimax_OpenSpiel.jl b/docs/experiments/experiments/Search/JuliaRL_Minimax_OpenSpiel.jl
index f31098c00..a6bfc1d4b 100644
--- a/docs/experiments/experiments/Search/JuliaRL_Minimax_OpenSpiel.jl
+++ b/docs/experiments/experiments/Search/JuliaRL_Minimax_OpenSpiel.jl
@@ -18,7 +18,13 @@ function RL.Experiment(::Val{:JuliaRL}, ::Val{:Minimax}, ::Val{:OpenSpiel}, game
)
hooks = MultiAgentHook(0 => TotalRewardPerEpisode(), 1 => TotalRewardPerEpisode())
description = "# Play `$game` in OpenSpiel with Minimax"
- Experiment(agents, env, StopAfterEpisode(1, is_show_progress=!haskey(ENV, "CI")), hooks, description)
+ Experiment(
+ agents,
+ env,
+ StopAfterEpisode(1, is_show_progress = !haskey(ENV, "CI")),
+ hooks,
+ description,
+ )
end
using Plots
diff --git a/docs/homepage/utils.jl b/docs/homepage/utils.jl
index 816810d71..64787a91a 100644
--- a/docs/homepage/utils.jl
+++ b/docs/homepage/utils.jl
@@ -5,7 +5,7 @@ html(s) = "\n~~~$s~~~\n"
function hfun_adddescription()
d = locvar(:description)
- isnothing(d) ? "" : F.fd2html(d, internal=true)
+ isnothing(d) ? "" : F.fd2html(d, internal = true)
end
function hfun_frontmatter()
@@ -28,7 +28,7 @@ function hfun_byline()
if isnothing(fm)
""
else
- ""
+ ""
end
end
@@ -62,7 +62,7 @@ function hfun_appendix()
if isfile(bib_in_cur_folder)
bib_resolved = F.parse_rpath("/" * bib_in_cur_folder)
else
- bib_resolved = F.parse_rpath(bib; canonical=false, code=true)
+ bib_resolved = F.parse_rpath(bib; canonical = false, code = true)
end
bib = ""
end
@@ -74,7 +74,7 @@ function hfun_appendix()
"""
end
-function lx_dcite(lxc,_)
+function lx_dcite(lxc, _)
content = F.content(lxc.braces[1])
"" |> html
end
@@ -92,7 +92,7 @@ end
"""
Possible layouts:
"""
-function lx_dfig(lxc,lxd)
+function lx_dfig(lxc, lxd)
content = F.content(lxc.braces[1])
info = split(content, ';')
layout = info[1]
@@ -111,7 +111,7 @@ function lx_dfig(lxc,lxd)
end
# (case 3) assume it is generated by code
- src = F.parse_rpath(src; canonical=false, code=true)
+ src = F.parse_rpath(src; canonical = false, code = true)
# !!! directly take from `lx_fig` in Franklin.jl
fdir, fext = splitext(src)
@@ -122,11 +122,10 @@ function lx_dfig(lxc,lxd)
# then in both cases there can be a relative path set but the user may mean
# that it's in the subfolder /output/ (if generated by code) so should look
# both in the relpath and if not found and if /output/ not already last dir
- candext = ifelse(isempty(fext),
- (".png", ".jpeg", ".jpg", ".svg", ".gif"), (fext,))
- for ext ∈ candext
+ candext = ifelse(isempty(fext), (".png", ".jpeg", ".jpg", ".svg", ".gif"), (fext,))
+ for ext in candext
candpath = fdir * ext
- syspath = joinpath(F.PATHS[:site], split(candpath, '/')...)
+ syspath = joinpath(F.PATHS[:site], split(candpath, '/')...)
isfile(syspath) && return dfigure(layout, candpath, caption)
end
# now try in the output dir just in case (provided we weren't already
@@ -134,20 +133,20 @@ function lx_dfig(lxc,lxd)
p1, p2 = splitdir(fdir)
@debug "TEST" p1 p2
if splitdir(p1)[2] != "output"
- for ext ∈ candext
+ for ext in candext
candpath = joinpath(splitdir(p1)[1], "output", p2 * ext)
- syspath = joinpath(F.PATHS[:site], split(candpath, '/')...)
+ syspath = joinpath(F.PATHS[:site], split(candpath, '/')...)
isfile(syspath) && return dfigure(layout, candpath, caption)
end
end
end
-function lx_aside(lxc,lxd)
+function lx_aside(lxc, lxd)
content = F.reprocess(F.content(lxc.braces[1]), lxd)
"" |> html
end
-function lx_footnote(lxc,lxd)
+function lx_footnote(lxc, lxd)
content = F.reprocess(F.content(lxc.braces[1]), lxd)
# workaround
if startswith(content, "
")
@@ -156,7 +155,7 @@ function lx_footnote(lxc,lxd)
"$content" |> html
end
-function lx_appendix(lxc,lxd)
+function lx_appendix(lxc, lxd)
content = F.reprocess(F.content(lxc.braces[1]), lxd)
"$content" |> html
-end
\ No newline at end of file
+end
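
`lx_dfig` above resolves a figure path by trying a list of candidate extensions and then falling back to a sibling `output/` directory. A self-contained sketch of that lookup order using plain filesystem calls instead of Franklin's `F.PATHS` machinery (the function name and arguments here are illustrative, not Franklin's API):

# Mirror of the search order in lx_dfig; same logic, no Franklin internals.
function find_figure(site_dir::AbstractString, rpath::AbstractString)
    fdir, fext = splitext(rpath)
    candext = isempty(fext) ? (".png", ".jpeg", ".jpg", ".svg", ".gif") : (fext,)
    # 1. try the path as given, with each candidate extension
    for ext in candext
        p = joinpath(site_dir, split(fdir * ext, '/')...)
        isfile(p) && return p
    end
    # 2. fall back to <parent>/output/<name><ext>, unless already inside output/
    p1, p2 = splitdir(fdir)
    if splitdir(p1)[2] != "output"
        for ext in candext
            candpath = joinpath(splitdir(p1)[1], "output", p2 * ext)
            p = joinpath(site_dir, split(candpath, '/')...)
            isfile(p) && return p
        end
    end
    return nothing
end
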
diff --git a/docs/make.jl b/docs/make.jl
index 8f947388d..4d7fd750d 100644
--- a/docs/make.jl
+++ b/docs/make.jl
@@ -15,11 +15,7 @@ end
experiments, postprocess_cb, experiments_assets = makedemos("experiments")
-assets = [
- "assets/favicon.ico",
- "assets/custom.css",
- experiments_assets
-]
+assets = ["assets/favicon.ico", "assets/custom.css", experiments_assets]
makedocs(
modules = [
@@ -56,7 +52,7 @@ makedocs(
"RLZoo" => "rlzoo.md",
"RLDatasets" => "rldatasets.md",
],
- ]
+ ],
)
postprocess_cb()
diff --git a/src/DistributedReinforcementLearning/src/actor_model.jl b/src/DistributedReinforcementLearning/src/actor_model.jl
index 3a2414c5c..af6b05214 100644
--- a/src/DistributedReinforcementLearning/src/actor_model.jl
+++ b/src/DistributedReinforcementLearning/src/actor_model.jl
@@ -1,19 +1,12 @@
-export AbstractMessage,
- StartMsg,
- StopMsg,
- PingMsg,
- PongMsg,
- ProxyMsg,
- actor,
- self
+export AbstractMessage, StartMsg, StopMsg, PingMsg, PongMsg, ProxyMsg, actor, self
abstract type AbstractMessage end
-struct StartMsg{A, K} <: AbstractMessage
+struct StartMsg{A,K} <: AbstractMessage
args::A
kwargs::K
- StartMsg(args...;kwargs...) = new{typeof(args), typeof(kwargs)}(args, kwargs)
+ StartMsg(args...; kwargs...) = new{typeof(args),typeof(kwargs)}(args, kwargs)
end
struct StopMsg <: AbstractMessage end
@@ -45,9 +38,9 @@ const DEFAULT_MAILBOX_SIZE = 32
Create a task to handle messages one-by-one by calling `f(msg)`.
A mailbox (`RemoteChannel`) is returned.
"""
-function actor(f;sz=DEFAULT_MAILBOX_SIZE)
+function actor(f; sz = DEFAULT_MAILBOX_SIZE)
RemoteChannel() do
- Channel(sz;spawn=true) do ch
+ Channel(sz; spawn = true) do ch
task_local_storage("MAILBOX", RemoteChannel(() -> ch))
while true
msg = take!(ch)
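
For readers unfamiliar with the actor API being reformatted here: `actor(f)` spawns a task that handles each queued message by calling `f(msg)` and returns the mailbox to `put!` messages into. A hedged usage sketch; the `EchoMsg` type is hypothetical, and only `actor`, `AbstractMessage`, and the mailbox semantics come from this module:

# Hypothetical message type for illustration only.
struct EchoMsg <: AbstractMessage
    payload::Any
end

# The handler runs on its own task; messages are processed one at a time.
echo = actor(; sz = 8) do msg
    msg isa EchoMsg && println("echo: ", msg.payload)
end

put!(echo, EchoMsg("hello"))   # handled asynchronously by the actor task
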
diff --git a/src/DistributedReinforcementLearning/src/core.jl b/src/DistributedReinforcementLearning/src/core.jl
index 44ed67937..34a04571a 100644
--- a/src/DistributedReinforcementLearning/src/core.jl
+++ b/src/DistributedReinforcementLearning/src/core.jl
@@ -51,7 +51,7 @@ Base.@kwdef struct Trainer{P,S}
sealer::S = deepcopy
end
-Trainer(p) = Trainer(;policy=p)
+Trainer(p) = Trainer(; policy = p)
function (trainer::Trainer)(msg::BatchDataMsg)
update!(trainer.policy, msg.data)
@@ -94,7 +94,7 @@ mutable struct Worker
end
function (w::Worker)(msg::StartMsg)
- w.experiment = w.init(msg.args...;msg.kwargs...)
+ w.experiment = w.init(msg.args...; msg.kwargs...)
w.task = Threads.@spawn run(w.experiment)
end
@@ -128,7 +128,7 @@ end
function (wp::WorkerProxy)(::FetchParamMsg)
if !wp.is_fetch_msg_sent[]
put!(wp.target, FetchParamMsg(self()))
- wp.is_fetch_msg_sent[] = true
+ wp.is_fetch_msg_sent[] = true
end
end
@@ -172,9 +172,12 @@ function (orc::Orchestrator)(msg::InsertTrajectoryMsg)
put!(orc.trajectory_proxy, BatchSampleMsg(orc.trainer))
L.n_sample += 1
if L.n_sample == (L.n_load + 1) * L.sample_load_ratio
- put!(orc.trajectory_proxy, ProxyMsg(to=orc.trainer, msg=FetchParamMsg(orc.worker)))
+ put!(
+ orc.trajectory_proxy,
+ ProxyMsg(to = orc.trainer, msg = FetchParamMsg(orc.worker)),
+ )
L.n_load += 1
end
end
end
-end
\ No newline at end of file
+end
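
The `n_sample`/`n_load` bookkeeping in the `Orchestrator` above forwards fresh parameters to the workers once every `sample_load_ratio` batch samples. A toy trace of that counter logic (the ratio below is an assumption, not a value from the source):

# Stand-alone illustration of the load schedule; no package code involved.
function loads_after(n_samples; sample_load_ratio = 100)
    n_load = 0
    for n_sample in 1:n_samples
        if n_sample == (n_load + 1) * sample_load_ratio
            n_load += 1   # here the real code forwards a FetchParamMsg via ProxyMsg
        end
    end
    return n_load
end

loads_after(350)   # == 3: parameters are refreshed after samples 100, 200, and 300
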
diff --git a/src/DistributedReinforcementLearning/src/extensions.jl b/src/DistributedReinforcementLearning/src/extensions.jl
index efeda428b..1c6f32e9d 100644
--- a/src/DistributedReinforcementLearning/src/extensions.jl
+++ b/src/DistributedReinforcementLearning/src/extensions.jl
@@ -76,4 +76,4 @@ function (hook::FetchParamsHook)(::PostActStage, agent, env)
end
end
end
-end
\ No newline at end of file
+end
diff --git a/src/DistributedReinforcementLearning/test/actor.jl b/src/DistributedReinforcementLearning/test/actor.jl
index 34881c5b1..e2cc9ae87 100644
--- a/src/DistributedReinforcementLearning/test/actor.jl
+++ b/src/DistributedReinforcementLearning/test/actor.jl
@@ -1,53 +1,53 @@
@testset "basic tests" begin
-Base.@kwdef mutable struct TestActor
- state::Union{Nothing, Int} = nothing
-end
+ Base.@kwdef mutable struct TestActor
+ state::Union{Nothing,Int} = nothing
+ end
-struct CurrentStateMsg <: AbstractMessage
- state
-end
+ struct CurrentStateMsg <: AbstractMessage
+ state::Any
+ end
-Base.@kwdef struct ReadStateMsg <: AbstractMessage
- from = self()
-end
+ Base.@kwdef struct ReadStateMsg <: AbstractMessage
+ from = self()
+ end
-struct IncMsg <: AbstractMessage end
-struct DecMsg <: AbstractMessage end
+ struct IncMsg <: AbstractMessage end
+ struct DecMsg <: AbstractMessage end
-(x::TestActor)(msg::StartMsg{Tuple{Int}}) = x.state = msg.args[1]
-(x::TestActor)(msg::StopMsg) = x.state = nothing
-(x::TestActor)(::IncMsg) = x.state += 1
-(x::TestActor)(::DecMsg) = x.state -= 1
-(x::TestActor)(msg::ReadStateMsg) = put!(msg.from, CurrentStateMsg(x.state))
+ (x::TestActor)(msg::StartMsg{Tuple{Int}}) = x.state = msg.args[1]
+ (x::TestActor)(msg::StopMsg) = x.state = nothing
+ (x::TestActor)(::IncMsg) = x.state += 1
+ (x::TestActor)(::DecMsg) = x.state -= 1
+ (x::TestActor)(msg::ReadStateMsg) = put!(msg.from, CurrentStateMsg(x.state))
-x = actor(TestActor())
-put!(x, StartMsg(0))
+ x = actor(TestActor())
+ put!(x, StartMsg(0))
-put!(x, ReadStateMsg())
-@test take!(self()).state == 0
+ put!(x, ReadStateMsg())
+ @test take!(self()).state == 0
-@sync begin
- for _ in 1:100
- Threads.@spawn put!(x, IncMsg())
- Threads.@spawn put!(x, DecMsg())
- end
- for _ in 1:10
- for _ in 1:10
+ @sync begin
+ for _ in 1:100
Threads.@spawn put!(x, IncMsg())
+ Threads.@spawn put!(x, DecMsg())
end
for _ in 1:10
- Threads.@spawn put!(x, DecMsg())
+ for _ in 1:10
+ Threads.@spawn put!(x, IncMsg())
+ end
+ for _ in 1:10
+ Threads.@spawn put!(x, DecMsg())
+ end
end
end
-end
-put!(x, ReadStateMsg())
-@test take!(self()).state == 0
+ put!(x, ReadStateMsg())
+ @test take!(self()).state == 0
-y = actor(TestActor())
-put!(x, ProxyMsg(;to=y,msg=StartMsg(0)))
-put!(x, ProxyMsg(;to=y,msg=ReadStateMsg()))
-@test take!(self()).state == 0
+ y = actor(TestActor())
+ put!(x, ProxyMsg(; to = y, msg = StartMsg(0)))
+ put!(x, ProxyMsg(; to = y, msg = ReadStateMsg()))
+ @test take!(self()).state == 0
-end
\ No newline at end of file
+end
diff --git a/src/DistributedReinforcementLearning/test/core.jl b/src/DistributedReinforcementLearning/test/core.jl
index 25d5c2c49..d4a196f37 100644
--- a/src/DistributedReinforcementLearning/test/core.jl
+++ b/src/DistributedReinforcementLearning/test/core.jl
@@ -1,181 +1,202 @@
@testset "core.jl" begin
-@testset "Trainer" begin
- _trainer = Trainer(;
- policy=BasicDQNLearner(
- approximator = NeuralNetworkApproximator(
- model = Chain(
- Dense(4, 128, relu; initW = glorot_uniform),
- Dense(128, 128, relu; initW = glorot_uniform),
- Dense(128, 2; initW = glorot_uniform),
- ) |> cpu,
- optimizer = ADAM(),
+ @testset "Trainer" begin
+ _trainer = Trainer(;
+ policy = BasicDQNLearner(
+ approximator = NeuralNetworkApproximator(
+ model = Chain(
+ Dense(4, 128, relu; initW = glorot_uniform),
+ Dense(128, 128, relu; initW = glorot_uniform),
+ Dense(128, 2; initW = glorot_uniform),
+ ) |> cpu,
+ optimizer = ADAM(),
+ ),
+ loss_func = huber_loss,
),
- loss_func = huber_loss,
)
- )
- trainer = actor(_trainer)
+ trainer = actor(_trainer)
- put!(trainer, FetchParamMsg())
- ps = take!(self())
- original_sum = sum(sum, ps.data)
+ put!(trainer, FetchParamMsg())
+ ps = take!(self())
+ original_sum = sum(sum, ps.data)
- for x in ps.data
- fill!(x, 0.)
- end
+ for x in ps.data
+ fill!(x, 0.0)
+ end
- put!(trainer, FetchParamMsg())
- ps = take!(self())
- new_sum = sum(sum, ps.data)
+ put!(trainer, FetchParamMsg())
+ ps = take!(self())
+ new_sum = sum(sum, ps.data)
- # make sure no state sharing between messages
- @test original_sum == new_sum
+ # make sure no state sharing between messages
+ @test original_sum == new_sum
- batch_data = (
- state = rand(4, 32),
- action = rand(1:2, 32),
- reward = rand(32),
- terminal = rand(Bool, 32),
- next_state = rand(4,32),
- next_action = rand(1:2, 32)
- )
+ batch_data = (
+ state = rand(4, 32),
+ action = rand(1:2, 32),
+ reward = rand(32),
+ terminal = rand(Bool, 32),
+ next_state = rand(4, 32),
+ next_action = rand(1:2, 32),
+ )
- put!(trainer, BatchDataMsg(batch_data))
+ put!(trainer, BatchDataMsg(batch_data))
- put!(trainer, FetchParamMsg())
- ps = take!(self())
- updated_sum = sum(sum, ps.data)
- @test original_sum != updated_sum
-end
+ put!(trainer, FetchParamMsg())
+ ps = take!(self())
+ updated_sum = sum(sum, ps.data)
+ @test original_sum != updated_sum
+ end
-@testset "TrajectoryManager" begin
- _trajectory_proxy = TrajectoryManager(
- trajectory = CircularSARTSATrajectory(;capacity=5, state_type=Any, ),
- sampler = UniformBatchSampler(3),
- inserter = NStepInserter(),
- )
+ @testset "TrajectoryManager" begin
+ _trajectory_proxy = TrajectoryManager(
+ trajectory = CircularSARTSATrajectory(; capacity = 5, state_type = Any),
+ sampler = UniformBatchSampler(3),
+ inserter = NStepInserter(),
+ )
- trajectory_proxy = actor(_trajectory_proxy)
+ trajectory_proxy = actor(_trajectory_proxy)
- # 1. init traj for testing
- traj = CircularCompactSARTSATrajectory(
- capacity = 2,
- state_type = Float32,
- state_size = (4,),
- )
- push!(traj;state=rand(Float32, 4), action=rand(1:2))
- push!(traj;reward=rand(), terminal=rand(Bool),state=rand(Float32, 4), action=rand(1:2))
- push!(traj;reward=rand(), terminal=rand(Bool),state=rand(Float32, 4), action=rand(1:2))
+ # 1. init traj for testing
+ traj = CircularCompactSARTSATrajectory(
+ capacity = 2,
+ state_type = Float32,
+ state_size = (4,),
+ )
+ push!(traj; state = rand(Float32, 4), action = rand(1:2))
+ push!(
+ traj;
+ reward = rand(),
+ terminal = rand(Bool),
+ state = rand(Float32, 4),
+ action = rand(1:2),
+ )
+ push!(
+ traj;
+ reward = rand(),
+ terminal = rand(Bool),
+ state = rand(Float32, 4),
+ action = rand(1:2),
+ )
- # 2. insert
- put!(trajectory_proxy, InsertTrajectoryMsg(deepcopy(traj))) #!!! we used deepcopy here
+ # 2. insert
+ put!(trajectory_proxy, InsertTrajectoryMsg(deepcopy(traj))) #!!! we used deepcopy here
- # 3. make sure the above message is already been handled
- put!(trajectory_proxy, PingMsg())
- take!(self())
+        # 3. make sure the above message has already been handled
+ put!(trajectory_proxy, PingMsg())
+ take!(self())
- # 4. test that updating traj will not affect data in trajectory_proxy
- s_tp = _trajectory_proxy.trajectory[:state]
- s_traj = traj[:state]
+ # 4. test that updating traj will not affect data in trajectory_proxy
+ s_tp = _trajectory_proxy.trajectory[:state]
+ s_traj = traj[:state]
- @test s_tp[1] == s_traj[:, 1]
+ @test s_tp[1] == s_traj[:, 1]
- push!(traj;reward=rand(), terminal=rand(Bool),state=rand(Float32, 4), action=rand(1:2))
+ push!(
+ traj;
+ reward = rand(),
+ terminal = rand(Bool),
+ state = rand(Float32, 4),
+ action = rand(1:2),
+ )
- @test s_tp[1] != s_traj[:, 1]
+ @test s_tp[1] != s_traj[:, 1]
- s = sample(_trajectory_proxy.trajectory, _trajectory_proxy.sampler)
- fill!(s[:state], 0.)
- @test any(x -> sum(x) == 0, s_tp) == false # make sure sample create an independent copy
-end
+ s = sample(_trajectory_proxy.trajectory, _trajectory_proxy.sampler)
+ fill!(s[:state], 0.0)
+        @test any(x -> sum(x) == 0, s_tp) == false # make sure sample creates an independent copy
+ end
-@testset "Worker" begin
- _worker = Worker() do worker_proxy
- Experiment(
- Agent(
- policy = StaticPolicy(
+ @testset "Worker" begin
+ _worker = Worker() do worker_proxy
+ Experiment(
+ Agent(
+ policy = StaticPolicy(
QBasedPolicy(
- learner = BasicDQNLearner(
- approximator = NeuralNetworkApproximator(
- model = Chain(
- Dense(4, 128, relu; initW = glorot_uniform),
- Dense(128, 128, relu; initW = glorot_uniform),
- Dense(128, 2; initW = glorot_uniform),
- ) |> cpu,
- optimizer = ADAM(),
+ learner = BasicDQNLearner(
+ approximator = NeuralNetworkApproximator(
+ model = Chain(
+ Dense(4, 128, relu; initW = glorot_uniform),
+ Dense(128, 128, relu; initW = glorot_uniform),
+ Dense(128, 2; initW = glorot_uniform),
+ ) |> cpu,
+ optimizer = ADAM(),
+ ),
+ loss_func = huber_loss,
+ ),
+ explorer = EpsilonGreedyExplorer(
+ kind = :exp,
+ ϵ_stable = 0.01,
+ decay_steps = 500,
),
- loss_func = huber_loss,
- ),
- explorer = EpsilonGreedyExplorer(
- kind = :exp,
- ϵ_stable = 0.01,
- decay_steps = 500,
),
),
+ trajectory = CircularCompactSARTSATrajectory(
+ capacity = 10,
+ state_type = Float32,
+ state_size = (4,),
+ ),
),
- trajectory = CircularCompactSARTSATrajectory(
- capacity = 10,
- state_type = Float32,
- state_size = (4,),
+ CartPoleEnv(; T = Float32),
+ ComposedStopCondition(StopAfterStep(1_000), StopSignal()),
+ ComposedHook(
+ UploadTrajectoryEveryNStep(
+ mailbox = worker_proxy,
+ n = 10,
+ sealer = x -> InsertTrajectoryMsg(deepcopy(x)),
+ ),
+ LoadParamsHook(),
+ TotalRewardPerEpisode(),
),
- ),
- CartPoleEnv(; T = Float32),
- ComposedStopCondition(
- StopAfterStep(1_000),
- StopSignal(),
- ),
- ComposedHook(
- UploadTrajectoryEveryNStep(mailbox=worker_proxy, n=10, sealer=x -> InsertTrajectoryMsg(deepcopy(x))),
- LoadParamsHook(),
- TotalRewardPerEpisode(),
- ),
- "experimenting..."
- )
- end
+ "experimenting...",
+ )
+ end
- worker = actor(_worker)
- tmp_mailbox = Channel(100)
- put!(worker, StartMsg(tmp_mailbox))
-end
-
-@testset "WorkerProxy" begin
- target = RemoteChannel(() -> Channel(10))
- workers = [RemoteChannel(()->Channel(10)) for _ in 1:10]
- _wp = WorkerProxy(workers)
- wp = actor(_wp)
-
- put!(wp, StartMsg(target))
- for w in workers
- # @test take!(w).args[1] === wp
- @test Distributed.channel_from_id(remoteref_id(take!(w).args[1])) === Distributed.channel_from_id(remoteref_id(wp))
+ worker = actor(_worker)
+ tmp_mailbox = Channel(100)
+ put!(worker, StartMsg(tmp_mailbox))
end
- msg = InsertTrajectoryMsg(1)
- put!(wp, msg)
- @test take!(target) === msg
-
- for w in workers
- put!(wp, FetchParamMsg(w))
+ @testset "WorkerProxy" begin
+ target = RemoteChannel(() -> Channel(10))
+ workers = [RemoteChannel(() -> Channel(10)) for _ in 1:10]
+ _wp = WorkerProxy(workers)
+ wp = actor(_wp)
+
+ put!(wp, StartMsg(target))
+ for w in workers
+ # @test take!(w).args[1] === wp
+ @test Distributed.channel_from_id(remoteref_id(take!(w).args[1])) ===
+ Distributed.channel_from_id(remoteref_id(wp))
+ end
+
+ msg = InsertTrajectoryMsg(1)
+ put!(wp, msg)
+ @test take!(target) === msg
+
+ for w in workers
+ put!(wp, FetchParamMsg(w))
+ end
+ # @test take!(target).from === wp
+ @test Distributed.channel_from_id(remoteref_id(take!(target).from)) ===
+ Distributed.channel_from_id(remoteref_id(wp))
+
+ # make sure target only received one FetchParamMsg
+ msg = PingMsg()
+ put!(target, msg)
+ @test take!(target) === msg
+
+ msg = LoadParamMsg([])
+ put!(wp, msg)
+ for w in workers
+ @test take!(w) === msg
+ end
end
- # @test take!(target).from === wp
- @test Distributed.channel_from_id(remoteref_id(take!(target).from)) === Distributed.channel_from_id(remoteref_id(wp))
-
- # make sure target only received one FetchParamMsg
- msg = PingMsg()
- put!(target, msg)
- @test take!(target) === msg
-
- msg = LoadParamMsg([])
- put!(wp, msg)
- for w in workers
- @test take!(w) === msg
+
+ @testset "Orchestrator" begin
+ # TODO
+ # Add an integration test
end
-end
-@testset "Orchestrator" begin
- # TODO
- # Add an integration test
end
-
-end
\ No newline at end of file
diff --git a/src/DistributedReinforcementLearning/test/runtests.jl b/src/DistributedReinforcementLearning/test/runtests.jl
index f9f572d55..a73cd0521 100644
--- a/src/DistributedReinforcementLearning/test/runtests.jl
+++ b/src/DistributedReinforcementLearning/test/runtests.jl
@@ -9,7 +9,7 @@ using Flux
@testset "DistributedReinforcementLearning.jl" begin
-include("actor.jl")
-include("core.jl")
+ include("actor.jl")
+ include("core.jl")
end
diff --git a/src/ReinforcementLearningBase/src/CommonRLInterface.jl b/src/ReinforcementLearningBase/src/CommonRLInterface.jl
index af86a2d83..28c73ec40 100644
--- a/src/ReinforcementLearningBase/src/CommonRLInterface.jl
+++ b/src/ReinforcementLearningBase/src/CommonRLInterface.jl
@@ -41,7 +41,8 @@ end
# !!! may need to be extended by user
CRL.@provide CRL.observe(env::CommonRLEnv) = state(env.env)
-CRL.provided(::typeof(CRL.state), env::CommonRLEnv) = !isnothing(find_state_style(env.env, InternalState))
+CRL.provided(::typeof(CRL.state), env::CommonRLEnv) =
+ !isnothing(find_state_style(env.env, InternalState))
CRL.state(env::CommonRLEnv) = state(env.env, find_state_style(env.env, InternalState))
CRL.@provide CRL.clone(env::CommonRLEnv) = CommonRLEnv(copy(env.env))
@@ -94,4 +95,4 @@ ActionStyle(env::RLBaseEnv) =
CRL.provided(CRL.valid_actions, env.env) ? FullActionSet() : MinimalActionSet()
current_player(env::RLBaseEnv) = CRL.player(env.env)
-players(env::RLBaseEnv) = CRL.players(env.env)
\ No newline at end of file
+players(env::RLBaseEnv) = CRL.players(env.env)
diff --git a/src/ReinforcementLearningBase/src/interface.jl b/src/ReinforcementLearningBase/src/interface.jl
index 7c7b16888..3f64d6ad7 100644
--- a/src/ReinforcementLearningBase/src/interface.jl
+++ b/src/ReinforcementLearningBase/src/interface.jl
@@ -410,12 +410,13 @@ Make an independent copy of `env`,
!!! warning
Only check the state of all players in the env.
"""
-function Base.:(==)(env1::T, env2::T) where T<:AbstractEnv
+function Base.:(==)(env1::T, env2::T) where {T<:AbstractEnv}
len = length(players(env1))
- len == length(players(env2)) &&
- all(state(env1, player) == state(env2, player) for player in players(env1))
+ len == length(players(env2)) &&
+ all(state(env1, player) == state(env2, player) for player in players(env1))
end
-Base.hash(env::AbstractEnv, h::UInt) = hash([state(env, player) for player in players(env)], h)
+Base.hash(env::AbstractEnv, h::UInt) =
+ hash([state(env, player) for player in players(env)], h)
@api nameof(env::AbstractEnv) = nameof(typeof(env))
diff --git a/src/ReinforcementLearningBase/test/CommonRLInterface.jl b/src/ReinforcementLearningBase/test/CommonRLInterface.jl
index fc38a102b..7b32dbe08 100644
--- a/src/ReinforcementLearningBase/test/CommonRLInterface.jl
+++ b/src/ReinforcementLearningBase/test/CommonRLInterface.jl
@@ -1,34 +1,34 @@
@testset "CommonRLInterface" begin
-@testset "MDPEnv" begin
- struct RLTestMDP <: MDP{Int, Int} end
+ @testset "MDPEnv" begin
+ struct RLTestMDP <: MDP{Int,Int} end
- POMDPs.actions(m::RLTestMDP) = [-1, 1]
- POMDPs.transition(m::RLTestMDP, s, a) = Deterministic(clamp(s + a, 1, 3))
- POMDPs.initialstate(m::RLTestMDP) = Deterministic(1)
- POMDPs.isterminal(m::RLTestMDP, s) = s == 3
- POMDPs.reward(m::RLTestMDP, s, a, sp) = sp
- POMDPs.states(m::RLTestMDP) = 1:3
+ POMDPs.actions(m::RLTestMDP) = [-1, 1]
+ POMDPs.transition(m::RLTestMDP, s, a) = Deterministic(clamp(s + a, 1, 3))
+ POMDPs.initialstate(m::RLTestMDP) = Deterministic(1)
+ POMDPs.isterminal(m::RLTestMDP, s) = s == 3
+ POMDPs.reward(m::RLTestMDP, s, a, sp) = sp
+ POMDPs.states(m::RLTestMDP) = 1:3
- env = convert(RLBase.AbstractEnv, convert(CRL.AbstractEnv, RLTestMDP()))
- RLBase.test_runnable!(env)
-end
+ env = convert(RLBase.AbstractEnv, convert(CRL.AbstractEnv, RLTestMDP()))
+ RLBase.test_runnable!(env)
+ end
-@testset "POMDPEnv" begin
+ @testset "POMDPEnv" begin
- struct RLTestPOMDP <: POMDP{Int, Int, Int} end
+ struct RLTestPOMDP <: POMDP{Int,Int,Int} end
- POMDPs.actions(m::RLTestPOMDP) = [-1, 1]
- POMDPs.states(m::RLTestPOMDP) = 1:3
- POMDPs.transition(m::RLTestPOMDP, s, a) = Deterministic(clamp(s + a, 1, 3))
- POMDPs.observation(m::RLTestPOMDP, s, a, sp) = Deterministic(sp + 1)
- POMDPs.initialstate(m::RLTestPOMDP) = Deterministic(1)
- POMDPs.initialobs(m::RLTestPOMDP, s) = Deterministic(s + 1)
- POMDPs.isterminal(m::RLTestPOMDP, s) = s == 3
- POMDPs.reward(m::RLTestPOMDP, s, a, sp) = sp
- POMDPs.observations(m::RLTestPOMDP) = 2:4
+ POMDPs.actions(m::RLTestPOMDP) = [-1, 1]
+ POMDPs.states(m::RLTestPOMDP) = 1:3
+ POMDPs.transition(m::RLTestPOMDP, s, a) = Deterministic(clamp(s + a, 1, 3))
+ POMDPs.observation(m::RLTestPOMDP, s, a, sp) = Deterministic(sp + 1)
+ POMDPs.initialstate(m::RLTestPOMDP) = Deterministic(1)
+ POMDPs.initialobs(m::RLTestPOMDP, s) = Deterministic(s + 1)
+ POMDPs.isterminal(m::RLTestPOMDP, s) = s == 3
+ POMDPs.reward(m::RLTestPOMDP, s, a, sp) = sp
+ POMDPs.observations(m::RLTestPOMDP) = 2:4
- env = convert(RLBase.AbstractEnv, convert(CRL.AbstractEnv, RLTestPOMDP()))
+ env = convert(RLBase.AbstractEnv, convert(CRL.AbstractEnv, RLTestPOMDP()))
- RLBase.test_runnable!(env)
+ RLBase.test_runnable!(env)
+ end
end
-end
\ No newline at end of file
diff --git a/src/ReinforcementLearningBase/test/runtests.jl b/src/ReinforcementLearningBase/test/runtests.jl
index 6b44a29b8..d4f743f68 100644
--- a/src/ReinforcementLearningBase/test/runtests.jl
+++ b/src/ReinforcementLearningBase/test/runtests.jl
@@ -8,5 +8,5 @@ using POMDPs
using POMDPModelTools: Deterministic
@testset "ReinforcementLearningBase" begin
-include("CommonRLInterface.jl")
-end
\ No newline at end of file
+ include("CommonRLInterface.jl")
+end
diff --git a/src/ReinforcementLearningCore/src/core/experiment.jl b/src/ReinforcementLearningCore/src/core/experiment.jl
index 8240f36e3..f753f6d62 100644
--- a/src/ReinforcementLearningCore/src/core/experiment.jl
+++ b/src/ReinforcementLearningCore/src/core/experiment.jl
@@ -24,7 +24,7 @@ end
function Base.show(io::IO, x::Experiment)
display(Markdown.parse(x.description))
- AbstractTrees.print_tree(io, StructTree(x), maxdepth=get(io, :max_depth, 10))
+ AbstractTrees.print_tree(io, StructTree(x), maxdepth = get(io, :max_depth, 10))
end
macro experiment_cmd(s)
@@ -51,7 +51,7 @@ function Experiment(s::String)
)
end
-function Base.run(x::Experiment; describe::Bool=true)
+function Base.run(x::Experiment; describe::Bool = true)
describe && display(Markdown.parse(x.description))
run(x.policy, x.env, x.stop_condition, x.hook)
x
diff --git a/src/ReinforcementLearningCore/src/core/hooks.jl b/src/ReinforcementLearningCore/src/core/hooks.jl
index 11c65e686..a8a58be3e 100644
--- a/src/ReinforcementLearningCore/src/core/hooks.jl
+++ b/src/ReinforcementLearningCore/src/core/hooks.jl
@@ -13,7 +13,7 @@ export AbstractHook,
UploadTrajectoryEveryNStep,
MultiAgentHook
-using UnicodePlots:lineplot, lineplot!
+using UnicodePlots: lineplot, lineplot!
using Statistics
"""
@@ -155,7 +155,14 @@ end
function (hook::TotalRewardPerEpisode)(::PostExperimentStage, agent, env)
if hook.is_display_on_exit
- println(lineplot(hook.rewards, title="Total reward per episode", xlabel="Episode", ylabel="Score"))
+ println(
+ lineplot(
+ hook.rewards,
+ title = "Total reward per episode",
+ xlabel = "Episode",
+ ylabel = "Score",
+ ),
+ )
end
end
@@ -178,8 +185,12 @@ which return a `Vector` of rewards (a typical case with `MultiThreadEnv`).
If `is_display_on_exit` is set to `true`, a ribbon plot will be shown to reflect
the mean and std of rewards.
"""
-function TotalBatchRewardPerEpisode(batch_size::Int; is_display_on_exit=true)
- TotalBatchRewardPerEpisode([Float64[] for _ in 1:batch_size], zeros(batch_size), is_display_on_exit)
+function TotalBatchRewardPerEpisode(batch_size::Int; is_display_on_exit = true)
+ TotalBatchRewardPerEpisode(
+ [Float64[] for _ in 1:batch_size],
+ zeros(batch_size),
+ is_display_on_exit,
+ )
end
function (hook::TotalBatchRewardPerEpisode)(::PostActStage, agent, env)
@@ -198,7 +209,12 @@ function (hook::TotalBatchRewardPerEpisode)(::PostExperimentStage, agent, env)
n = minimum(map(length, hook.rewards))
m = mean([@view(x[1:n]) for x in hook.rewards])
s = std([@view(x[1:n]) for x in hook.rewards])
- p = lineplot(m, title="Avg total reward per episode", xlabel="Episode", ylabel="Score")
+ p = lineplot(
+ m,
+ title = "Avg total reward per episode",
+ xlabel = "Episode",
+ ylabel = "Score",
+ )
lineplot!(p, m .- s)
lineplot!(p, m .+ s)
println(p)
@@ -288,8 +304,7 @@ end
Execute `f(t, agent, env)` every `n` episode.
`t` is a counter of episodes.
"""
-mutable struct DoEveryNEpisode{S<:Union{PreEpisodeStage,PostEpisodeStage},F} <:
- AbstractHook
+mutable struct DoEveryNEpisode{S<:Union{PreEpisodeStage,PostEpisodeStage},F} <: AbstractHook
f::F
n::Int
t::Int
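
As a usage note for the hooks touched above: they compose with `ComposedHook` and are driven by `run`, the same way the test files later in this diff use them. A hedged sketch; the keyword form of `DoEveryNEpisode` is assumed to mirror the `DoEveryNStep(; n = ...)` calls shown earlier, and `CartPoleEnv`/`RandomPolicy` come from the same test setup:

env = CartPoleEnv{Float32}()
hook = ComposedHook(
    TotalRewardPerEpisode(),
    DoEveryNEpisode(; n = 10) do t, agent, env
        @info "finished episode" t
    end,
)
run(RandomPolicy(action_space(env)), env, StopAfterEpisode(100), hook)
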
diff --git a/src/ReinforcementLearningCore/src/extensions/ArrayInterface.jl b/src/ReinforcementLearningCore/src/extensions/ArrayInterface.jl
index d641c615b..507907b8a 100644
--- a/src/ReinforcementLearningCore/src/extensions/ArrayInterface.jl
+++ b/src/ReinforcementLearningCore/src/extensions/ArrayInterface.jl
@@ -1,7 +1,10 @@
using ArrayInterface
-function ArrayInterface.restructure(x::AbstractArray{T1, 0}, y::AbstractArray{T2, 0}) where {T1, T2}
+function ArrayInterface.restructure(
+ x::AbstractArray{T1,0},
+ y::AbstractArray{T2,0},
+) where {T1,T2}
out = similar(x, eltype(y))
out .= y
out
-end
\ No newline at end of file
+end
diff --git a/src/ReinforcementLearningCore/src/extensions/ReinforcementLearningBase.jl b/src/ReinforcementLearningCore/src/extensions/ReinforcementLearningBase.jl
index dd4445eb7..d76810301 100644
--- a/src/ReinforcementLearningCore/src/extensions/ReinforcementLearningBase.jl
+++ b/src/ReinforcementLearningCore/src/extensions/ReinforcementLearningBase.jl
@@ -1,6 +1,6 @@
using AbstractTrees
Base.show(io::IO, p::AbstractPolicy) =
- AbstractTrees.print_tree(io, StructTree(p), maxdepth=get(io, :max_depth, 10))
+ AbstractTrees.print_tree(io, StructTree(p), maxdepth = get(io, :max_depth, 10))
is_expand(::AbstractEnv) = false
diff --git a/src/ReinforcementLearningCore/src/policies/agents/agent.jl b/src/ReinforcementLearningCore/src/policies/agents/agent.jl
index fb7862052..59f415063 100644
--- a/src/ReinforcementLearningCore/src/policies/agents/agent.jl
+++ b/src/ReinforcementLearningCore/src/policies/agents/agent.jl
@@ -139,7 +139,9 @@ function RLBase.update!(
# TODO: how to inject a local rng here to avoid polluting the global rng
s = policy isa NamedPolicy ? state(env, nameof(policy)) : state(env)
- a = policy isa NamedPolicy ? rand(action_space(env, nameof(policy))) : rand(action_space(env))
+ a =
+ policy isa NamedPolicy ? rand(action_space(env, nameof(policy))) :
+ rand(action_space(env))
push!(trajectory[:state], s)
push!(trajectory[:action], a)
if haskey(trajectory, :legal_actions_mask)
diff --git a/src/ReinforcementLearningCore/src/policies/agents/multi_agent.jl b/src/ReinforcementLearningCore/src/policies/agents/multi_agent.jl
index 7bc0ab255..6a496dd34 100644
--- a/src/ReinforcementLearningCore/src/policies/agents/multi_agent.jl
+++ b/src/ReinforcementLearningCore/src/policies/agents/multi_agent.jl
@@ -23,7 +23,8 @@ of `SIMULTANEOUS` style, please wrap it with [`SequentialEnv`](@ref) first.
MultiAgentManager(policies...) =
MultiAgentManager(Dict{Any,Any}(nameof(p) => p for p in policies))
-RLBase.prob(A::MultiAgentManager, env::AbstractEnv, args...) = prob(A[current_player(env)].policy, env, args...)
+RLBase.prob(A::MultiAgentManager, env::AbstractEnv, args...) =
+ prob(A[current_player(env)].policy, env, args...)
(A::MultiAgentManager)(env::AbstractEnv) = A(env, DynamicStyle(env))
diff --git a/src/ReinforcementLearningCore/src/policies/agents/named_policy.jl b/src/ReinforcementLearningCore/src/policies/agents/named_policy.jl
index 215ba5639..a692037df 100644
--- a/src/ReinforcementLearningCore/src/policies/agents/named_policy.jl
+++ b/src/ReinforcementLearningCore/src/policies/agents/named_policy.jl
@@ -42,6 +42,7 @@ function RLBase.update!(
end
-(p::NamedPolicy)(env::AbstractEnv) = DynamicStyle(env) == SEQUENTIAL ? p.policy(env) : p.policy(env, p.name)
+(p::NamedPolicy)(env::AbstractEnv) =
+ DynamicStyle(env) == SEQUENTIAL ? p.policy(env) : p.policy(env, p.name)
(p::NamedPolicy)(s::AbstractStage, env::AbstractEnv) = p.policy(s, env)
(p::NamedPolicy)(s::PreActStage, env::AbstractEnv, action) = p.policy(s, env, action)
diff --git a/src/ReinforcementLearningCore/src/policies/agents/trajectories/trajectory_extension.jl b/src/ReinforcementLearningCore/src/policies/agents/trajectories/trajectory_extension.jl
index 7818e24ce..0050151c9 100644
--- a/src/ReinforcementLearningCore/src/policies/agents/trajectories/trajectory_extension.jl
+++ b/src/ReinforcementLearningCore/src/policies/agents/trajectories/trajectory_extension.jl
@@ -85,7 +85,11 @@ function fetch!(s::BatchSampler, t::AbstractTrajectory, inds::Vector{Int})
end
end
-function fetch!(s::BatchSampler{traces}, t::Union{CircularArraySARTTrajectory, CircularArraySLARTTrajectory}, inds::Vector{Int}) where {traces}
+function fetch!(
+ s::BatchSampler{traces},
+ t::Union{CircularArraySARTTrajectory,CircularArraySLARTTrajectory},
+ inds::Vector{Int},
+) where {traces}
if traces == SARTS
batch = NamedTuple{SARTS}((
(consecutive_view(t[x], inds) for x in SART)...,
@@ -100,7 +104,7 @@ function fetch!(s::BatchSampler{traces}, t::Union{CircularArraySARTTrajectory, C
else
@error "unsupported traces $traces"
end
-
+
if isnothing(s.cache)
s.cache = map(batch) do x
convert(Array, x)
@@ -151,7 +155,7 @@ end
function fetch!(
sampler::NStepBatchSampler{traces},
- traj::Union{CircularArraySARTTrajectory, CircularArraySLARTTrajectory},
+ traj::Union{CircularArraySARTTrajectory,CircularArraySLARTTrajectory},
inds::Vector{Int},
) where {traces}
γ, n, bz, sz = sampler.γ, sampler.n, sampler.batch_size, sampler.stack_size
diff --git a/src/ReinforcementLearningCore/src/policies/q_based_policies/learners/abstract_learner.jl b/src/ReinforcementLearningCore/src/policies/q_based_policies/learners/abstract_learner.jl
index f176fbdab..95e2b689a 100644
--- a/src/ReinforcementLearningCore/src/policies/q_based_policies/learners/abstract_learner.jl
+++ b/src/ReinforcementLearningCore/src/policies/q_based_policies/learners/abstract_learner.jl
@@ -17,7 +17,7 @@ function (learner::AbstractLearner)(env) end
function RLBase.priority(p::AbstractLearner, experience) end
Base.show(io::IO, p::AbstractLearner) =
- AbstractTrees.print_tree(io, StructTree(p), maxdepth=get(io, :max_depth, 10))
+ AbstractTrees.print_tree(io, StructTree(p), maxdepth = get(io, :max_depth, 10))
function RLBase.update!(
L::AbstractLearner,
diff --git a/src/ReinforcementLearningCore/src/policies/q_based_policies/learners/approximators/neural_network_approximator.jl b/src/ReinforcementLearningCore/src/policies/q_based_policies/learners/approximators/neural_network_approximator.jl
index 926d0d40c..81ab4e41d 100644
--- a/src/ReinforcementLearningCore/src/policies/q_based_policies/learners/approximators/neural_network_approximator.jl
+++ b/src/ReinforcementLearningCore/src/policies/q_based_policies/learners/approximators/neural_network_approximator.jl
@@ -1,4 +1,5 @@
-export NeuralNetworkApproximator, ActorCritic, GaussianNetwork, DuelingNetwork, PerturbationNetwork
+export NeuralNetworkApproximator,
+ ActorCritic, GaussianNetwork, DuelingNetwork, PerturbationNetwork
export VAE, decode, vae_loss
using Flux
@@ -79,7 +80,7 @@ Base.@kwdef struct GaussianNetwork{P,U,S}
pre::P = identity
μ::U
logσ::S
- min_σ::Float32 = 0f0
+ min_σ::Float32 = 0.0f0
max_σ::Float32 = Inf32
end
@@ -92,15 +93,24 @@ This function is compatible with a multidimensional action space. When outputtin
- `is_sampling::Bool=false`, whether to sample from the obtained normal distribution.
- `is_return_log_prob::Bool=false`, whether to calculate the conditional probability of getting actions in the given state.
"""
-function (model::GaussianNetwork)(rng::AbstractRNG, state; is_sampling::Bool=false, is_return_log_prob::Bool=false)
+function (model::GaussianNetwork)(
+ rng::AbstractRNG,
+ state;
+ is_sampling::Bool = false,
+ is_return_log_prob::Bool = false,
+)
x = model.pre(state)
- μ, raw_logσ = model.μ(x), model.logσ(x)
+ μ, raw_logσ = model.μ(x), model.logσ(x)
logσ = clamp.(raw_logσ, log(model.min_σ), log(model.max_σ))
if is_sampling
σ = exp.(logσ)
z = μ .+ σ .* send_to_device(device(model), randn(rng, Float32, size(μ)))
if is_return_log_prob
- logp_π = sum(normlogpdf(μ, σ, z) .- (2.0f0 .* (log(2.0f0) .- z .- softplus.(-2.0f0 .* z))), dims = 1)
+ logp_π = sum(
+ normlogpdf(μ, σ, z) .-
+ (2.0f0 .* (log(2.0f0) .- z .- softplus.(-2.0f0 .* z))),
+ dims = 1,
+ )
return tanh.(z), logp_π
else
return tanh.(z)
@@ -110,16 +120,29 @@ function (model::GaussianNetwork)(rng::AbstractRNG, state; is_sampling::Bool=fal
end
end
-function (model::GaussianNetwork)(state; is_sampling::Bool=false, is_return_log_prob::Bool=false)
- model(Random.GLOBAL_RNG, state; is_sampling=is_sampling, is_return_log_prob=is_return_log_prob)
+function (model::GaussianNetwork)(
+ state;
+ is_sampling::Bool = false,
+ is_return_log_prob::Bool = false,
+)
+ model(
+ Random.GLOBAL_RNG,
+ state;
+ is_sampling = is_sampling,
+ is_return_log_prob = is_return_log_prob,
+ )
end
function (model::GaussianNetwork)(state, action)
x = model.pre(state)
- μ, raw_logσ = model.μ(x), model.logσ(x)
+ μ, raw_logσ = model.μ(x), model.logσ(x)
logσ = clamp.(raw_logσ, log(model.min_σ), log(model.max_σ))
σ = exp.(logσ)
- logp_π = sum(normlogpdf(μ, σ, action) .- (2.0f0 .* (log(2.0f0) .- action .- softplus.(-2.0f0 .* action))), dims = 1)
+ logp_π = sum(
+ normlogpdf(μ, σ, action) .-
+ (2.0f0 .* (log(2.0f0) .- action .- softplus.(-2.0f0 .* action))),
+ dims = 1,
+ )
return logp_π
end
@@ -143,7 +166,7 @@ Flux.@functor DuelingNetwork
function (m::DuelingNetwork)(state)
x = m.base(state)
val = m.val(x)
- return val .+ m.adv(x) .- mean(m.adv(x), dims=1)
+ return val .+ m.adv(x) .- mean(m.adv(x), dims = 1)
end
#####
@@ -183,7 +206,7 @@ end
"""
VAE(;encoder, decoder, latent_dims)
"""
-Base.@kwdef struct VAE{E, D}
+Base.@kwdef struct VAE{E,D}
encoder::E
decoder::D
latent_dims::Int
@@ -207,9 +230,14 @@ function reparamaterize(rng, μ, σ)
return Float32(rand(rng, Normal(0, 1))) * σ + μ
end
-function decode(rng::AbstractRNG, model::VAE, state, z=nothing; is_normalize::Bool=true)
+function decode(rng::AbstractRNG, model::VAE, state, z = nothing; is_normalize::Bool = true)
if z === nothing
- z = clamp.(randn(rng, Float32, (model.latent_dims, size(state)[2:end]...)), -0.5f0, 0.5f0)
+ z =
+ clamp.(
+ randn(rng, Float32, (model.latent_dims, size(state)[2:end]...)),
+ -0.5f0,
+ 0.5f0,
+ )
end
a = model.decoder(vcat(state, z))
if is_normalize
@@ -218,7 +246,7 @@ function decode(rng::AbstractRNG, model::VAE, state, z=nothing; is_normalize::Bo
return a
end
-function decode(model::VAE, state, z=nothing; is_normalize::Bool=true)
+function decode(model::VAE, state, z = nothing; is_normalize::Bool = true)
decode(Random.GLOBAL_RNG, model, state, z; is_normalize)
end
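
The `logp_π` expressions reformatted above compute the log-density of a tanh-squashed Gaussian: the Normal log-pdf of the pre-squash sample minus the Jacobian term log(1 - tanh(z)^2), written in the numerically stable form 2(log 2 - z - softplus(-2z)). A quick check of that identity in plain Julia, independent of the package code:

using Flux: softplus   # softplus(x) = log(1 + exp(x))

z = Float32[-2, -0.5, 0, 0.5, 2]
lhs = log.(1 .- tanh.(z) .^ 2)
rhs = 2.0f0 .* (log(2.0f0) .- z .- softplus.(-2.0f0 .* z))
maximum(abs.(lhs .- rhs))   # ≈ 0 up to Float32 rounding
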
diff --git a/src/ReinforcementLearningCore/test/components/trajectories.jl b/src/ReinforcementLearningCore/test/components/trajectories.jl
index 8fb8eb9ae..c7c60e163 100644
--- a/src/ReinforcementLearningCore/test/components/trajectories.jl
+++ b/src/ReinforcementLearningCore/test/components/trajectories.jl
@@ -52,9 +52,9 @@
t = CircularArraySLARTTrajectory(
capacity = 3,
state = Vector{Int} => (4,),
- legal_actions_mask = Vector{Bool} => (4, ),
+ legal_actions_mask = Vector{Bool} => (4,),
)
-
+
# test instance type is same as type
@test isa(t, CircularArraySLARTTrajectory)
diff --git a/src/ReinforcementLearningCore/test/core/core.jl b/src/ReinforcementLearningCore/test/core/core.jl
index fb6f5fde5..56a81b809 100644
--- a/src/ReinforcementLearningCore/test/core/core.jl
+++ b/src/ReinforcementLearningCore/test/core/core.jl
@@ -1,22 +1,18 @@
@testset "simple workflow" begin
- env = StateTransformedEnv(CartPoleEnv{Float32}();state_mapping=deepcopy)
+ env = StateTransformedEnv(CartPoleEnv{Float32}(); state_mapping = deepcopy)
policy = RandomPolicy(action_space(env))
N_EPISODE = 10_000
hook = TotalRewardPerEpisode()
run(policy, env, StopAfterEpisode(N_EPISODE), hook)
- @test isapprox(sum(hook[]) / N_EPISODE, 21; atol=2)
+ @test isapprox(sum(hook[]) / N_EPISODE, 21; atol = 2)
end
@testset "multi agent" begin
# https://github.com/JuliaReinforcementLearning/ReinforcementLearning.jl/issues/393
rps = RockPaperScissorsEnv() |> SequentialEnv
- ma_policy = MultiAgentManager(
- (
- NamedPolicy(p => RandomPolicy())
- for p in players(rps)
- )...
- )
+ ma_policy =
+ MultiAgentManager((NamedPolicy(p => RandomPolicy()) for p in players(rps))...)
run(ma_policy, rps, StopAfterEpisode(10))
end
diff --git a/src/ReinforcementLearningCore/test/core/stop_conditions_test.jl b/src/ReinforcementLearningCore/test/core/stop_conditions_test.jl
index a2d657fa1..fc5b0bfc8 100644
--- a/src/ReinforcementLearningCore/test/core/stop_conditions_test.jl
+++ b/src/ReinforcementLearningCore/test/core/stop_conditions_test.jl
@@ -1,5 +1,5 @@
@testset "test StopAfterNoImprovement" begin
- env = StateTransformedEnv(CartPoleEnv{Float32}();state_mapping=deepcopy)
+ env = StateTransformedEnv(CartPoleEnv{Float32}(); state_mapping = deepcopy)
policy = RandomPolicy(action_space(env))
total_reward_per_episode = TotalRewardPerEpisode()
@@ -14,7 +14,8 @@
hook = ComposedHook(total_reward_per_episode)
run(policy, env, stop_condition, hook)
- @test argmax(total_reward_per_episode.rewards) + patience == length(total_reward_per_episode.rewards)
+ @test argmax(total_reward_per_episode.rewards) + patience ==
+ length(total_reward_per_episode.rewards)
end
@testset "StopAfterNSeconds" begin
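
The assertion in the test above encodes the stop condition's contract: with a patience of `k`, the run ends exactly `k` episodes after the best reward seen so far. The same invariant on a toy reward trace (values are illustrative):

patience = 3
rewards = [10, 12, 15, 14, 13, 11]   # best at episode 3, then `patience` non-improving episodes
argmax(rewards) + patience == length(rewards)   # true: the run stops 3 episodes after the peak
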
diff --git a/src/ReinforcementLearningDatasets/docs/make.jl b/src/ReinforcementLearningDatasets/docs/make.jl
index 5fdf1f97f..2a5eca27c 100644
--- a/src/ReinforcementLearningDatasets/docs/make.jl
+++ b/src/ReinforcementLearningDatasets/docs/make.jl
@@ -1,4 +1,4 @@
-push!(LOAD_PATH,"../src/")
+push!(LOAD_PATH, "../src/")
using Documenter, ReinforcementLearningDatasets
-makedocs(sitename="ReinforcementLearningDatasets")
\ No newline at end of file
+makedocs(sitename = "ReinforcementLearningDatasets")
diff --git a/src/ReinforcementLearningDatasets/src/ReinforcementLearningDatasets.jl b/src/ReinforcementLearningDatasets/src/ReinforcementLearningDatasets.jl
index 4ecc99e29..c0d813fc2 100644
--- a/src/ReinforcementLearningDatasets/src/ReinforcementLearningDatasets.jl
+++ b/src/ReinforcementLearningDatasets/src/ReinforcementLearningDatasets.jl
@@ -28,4 +28,4 @@ include("deep_ope/d4rl/d4rl_policy.jl")
include("deep_ope/d4rl/evaluate.jl")
-end
\ No newline at end of file
+end
diff --git a/src/ReinforcementLearningDatasets/src/atari/atari_dataset.jl b/src/ReinforcementLearningDatasets/src/atari/atari_dataset.jl
index 53e65c436..da84bba32 100644
--- a/src/ReinforcementLearningDatasets/src/atari/atari_dataset.jl
+++ b/src/ReinforcementLearningDatasets/src/atari/atari_dataset.jl
@@ -16,8 +16,8 @@ Represents an `Iterable` dataset with the following fields:
- `meta::Dict`: the metadata provided along with the dataset.
- `is_shuffle::Bool`: determines if the batches returned by `iterate` are shuffled.
"""
-struct AtariDataSet{T<:AbstractRNG} <:RLDataSet
- dataset::Dict{Symbol, Any}
+struct AtariDataSet{T<:AbstractRNG} <: RLDataSet
+ dataset::Dict{Symbol,Any}
epochs::Vector{Int}
repo::String
length::Int
@@ -61,29 +61,29 @@ function dataset(
game::String,
index::Int,
epochs::Vector{Int};
- style::NTuple=SARTS,
- repo::String="atari-replay-datasets",
- rng::AbstractRNG=MersenneTwister(123),
- is_shuffle::Bool=true,
- batch_size::Int=256
+ style::NTuple = SARTS,
+ repo::String = "atari-replay-datasets",
+ rng::AbstractRNG = MersenneTwister(123),
+ is_shuffle::Bool = true,
+ batch_size::Int = 256,
)
-
- try
+
+ try
@datadep_str "$repo-$game-$index"
catch e
if isa(e, KeyError)
throw("Invalid params, check out `atari_params()`")
end
end
-
- path = @datadep_str "$repo-$game-$index"
+
+ path = @datadep_str "$repo-$game-$index"
@assert length(readdir(path)) == 1
folder_name = readdir(path)[1]
-
+
folder_path = "$path/$folder_name"
files = readdir(folder_path)
- file_prefixes = collect(Set(map(x->join(split(x,"_")[1:2], "_"), files)))
+ file_prefixes = collect(Set(map(x -> join(split(x, "_")[1:2], "_"), files)))
fields = map(collect(file_prefixes)) do x
if split(x, "_")[1] == "\$store\$"
x = split(x, "_")[2]
@@ -93,7 +93,7 @@ function dataset(
end
s_epochs = Set(epochs)
-
+
dataset = Dict()
for (prefix, field) in zip(file_prefixes, fields)
@@ -110,9 +110,9 @@ function dataset(
if haskey(dataset, field)
if field == "observation"
- dataset[field] = cat(dataset[field], data, dims=3)
+ dataset[field] = cat(dataset[field], data, dims = 3)
else
- dataset[field] = cat(dataset[field], data, dims=1)
+ dataset[field] = cat(dataset[field], data, dims = 1)
end
else
dataset[field] = data
@@ -122,24 +122,37 @@ function dataset(
num_epochs = length(s_epochs)
- atari_verify(dataset, num_epochs)
+ atari_verify(dataset, num_epochs)
N_samples = size(dataset["observation"])[3]
-
- final_dataset = Dict{Symbol, Any}()
- meta = Dict{String, Any}()
- for (key, d_key) in zip(["observation", "action", "reward", "terminal"], Symbol.(["state", "action", "reward", "terminal"]))
- final_dataset[d_key] = dataset[key]
+ final_dataset = Dict{Symbol,Any}()
+ meta = Dict{String,Any}()
+
+ for (key, d_key) in zip(
+ ["observation", "action", "reward", "terminal"],
+ Symbol.(["state", "action", "reward", "terminal"]),
+ )
+ final_dataset[d_key] = dataset[key]
end
-
+
for key in keys(dataset)
if !(key in ["observation", "action", "reward", "terminal"])
meta[key] = dataset[key]
end
end
- return AtariDataSet(final_dataset, epochs, repo, N_samples, batch_size, style, rng, meta, is_shuffle)
+ return AtariDataSet(
+ final_dataset,
+ epochs,
+ repo,
+ N_samples,
+ batch_size,
+ style,
+ rng,
+ meta,
+ is_shuffle,
+ )
end
@@ -153,7 +166,7 @@ function iterate(ds::AtariDataSet, state = 0)
if is_shuffle
inds = rand(rng, 1:length-1, batch_size)
else
- if (state+1) * batch_size <= length
+ if (state + 1) * batch_size <= length
inds = state*batch_size+1:(state+1)*batch_size
else
return nothing
@@ -161,15 +174,17 @@ function iterate(ds::AtariDataSet, state = 0)
state += 1
end
- batch = (state = view(ds.dataset[:state], :, :, inds),
- action = view(ds.dataset[:action], inds),
- reward = view(ds.dataset[:reward], inds),
- terminal = view(ds.dataset[:terminal], inds))
+ batch = (
+ state = view(ds.dataset[:state], :, :, inds),
+ action = view(ds.dataset[:action], inds),
+ reward = view(ds.dataset[:reward], inds),
+ terminal = view(ds.dataset[:terminal], inds),
+ )
if style == SARTS
- batch = merge(batch, (next_state = view(ds.dataset[:state], :, :, (1).+(inds)),))
+ batch = merge(batch, (next_state = view(ds.dataset[:state], :, :, (1) .+ (inds)),))
end
-
+
return batch, state
end
@@ -179,7 +194,8 @@ length(ds::AtariDataSet) = ds.length
IteratorEltype(::Type{AtariDataSet}) = EltypeUnknown() # see if eltype can be known (not sure about carla and adroit)
function atari_verify(dataset::Dict, num_epochs::Int)
- @assert size(dataset["observation"]) == (atari_frame_size, atari_frame_size, num_epochs*samples_per_epoch)
+ @assert size(dataset["observation"]) ==
+ (atari_frame_size, atari_frame_size, num_epochs * samples_per_epoch)
@assert size(dataset["action"]) == (num_epochs * samples_per_epoch,)
@assert size(dataset["reward"]) == (num_epochs * samples_per_epoch,)
@assert size(dataset["terminal"]) == (num_epochs * samples_per_epoch,)
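
A hedged usage sketch of the loader reformatted above, following the `dataset(game, index, epochs; ...)` signature shown earlier. The game name, epoch list, and batch size are illustrative, and fetching the data dep requires `gsutil` (see `common.jl` below); each batch is a named tuple of views into the stacked replay buffer:

ds = dataset("pong", 1, [1]; is_shuffle = false, batch_size = 64)

for batch in Iterators.take(ds, 2)
    # batch.state is a (frame_height, frame_width, 64) view; action/reward/terminal are 64-element views
    @show size(batch.state) length(batch.action)
    # with the default SARTS style, batch.next_state holds the same frames shifted by one step
end
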
diff --git a/src/ReinforcementLearningDatasets/src/atari/register.jl b/src/ReinforcementLearningDatasets/src/atari/register.jl
index 8b4ddb81b..485a0ac31 100644
--- a/src/ReinforcementLearningDatasets/src/atari/register.jl
+++ b/src/ReinforcementLearningDatasets/src/atari/register.jl
@@ -9,19 +9,66 @@ function atari_params()
end
const ATARI_GAMES = [
- "air-raid", "alien", "amidar", "assault", "asterix",
- "asteroids", "atlantis", "bank-heist", "battle-zone", "beam-rider",
- "berzerk", "bowling", "boxing", "breakout", "carnival", "centipede",
- "chopper-command", "crazy-climber", "demon-attack",
- "double-dunk", "elevator-action", "enduro", "fishing-derby", "freeway",
- "frostbite", "gopher", "gravitar", "hero", "ice-hockey", "jamesbond",
- "journey-escape", "kangaroo", "krull", "kung-fu-master",
- "montezuma-revenge", "ms-pacman", "name-this-game", "phoenix",
- "pitfall", "pong", "pooyan", "private-eye", "qbert", "riverraid",
- "road-runner", "robotank", "seaquest", "skiing", "solaris",
- "space-invaders", "star-gunner", "tennis", "time-pilot", "tutankham",
- "up-n-down", "venture", "video-pinball", "wizard-of-wor",
- "yars-revenge", "zaxxon"
+ "air-raid",
+ "alien",
+ "amidar",
+ "assault",
+ "asterix",
+ "asteroids",
+ "atlantis",
+ "bank-heist",
+ "battle-zone",
+ "beam-rider",
+ "berzerk",
+ "bowling",
+ "boxing",
+ "breakout",
+ "carnival",
+ "centipede",
+ "chopper-command",
+ "crazy-climber",
+ "demon-attack",
+ "double-dunk",
+ "elevator-action",
+ "enduro",
+ "fishing-derby",
+ "freeway",
+ "frostbite",
+ "gopher",
+ "gravitar",
+ "hero",
+ "ice-hockey",
+ "jamesbond",
+ "journey-escape",
+ "kangaroo",
+ "krull",
+ "kung-fu-master",
+ "montezuma-revenge",
+ "ms-pacman",
+ "name-this-game",
+ "phoenix",
+ "pitfall",
+ "pong",
+ "pooyan",
+ "private-eye",
+ "qbert",
+ "riverraid",
+ "road-runner",
+ "robotank",
+ "seaquest",
+ "skiing",
+ "solaris",
+ "space-invaders",
+ "star-gunner",
+ "tennis",
+ "time-pilot",
+ "tutankham",
+ "up-n-down",
+ "venture",
+ "video-pinball",
+ "wizard-of-wor",
+ "yars-revenge",
+ "zaxxon",
]
game_name(game) = join(titlecase.(split(game, "-")))
@@ -48,9 +95,9 @@ function atari_init()
encountered during training into 5 replay datasets per game, resulting in a total of 300 datasets.
""",
"gs://atari-replay-datasets/dqn/$(game_name(game))/$index/replay_logs/";
- fetch_method = fetch_gc_bucket
- )
+ fetch_method = fetch_gc_bucket,
+ ),
)
end
end
-end
\ No newline at end of file
+end
diff --git a/src/ReinforcementLearningDatasets/src/common.jl b/src/ReinforcementLearningDatasets/src/common.jl
index bb3347000..ac574e031 100644
--- a/src/ReinforcementLearningDatasets/src/common.jl
+++ b/src/ReinforcementLearningDatasets/src/common.jl
@@ -31,10 +31,18 @@ fetch a gc bucket from `src` to `dest`.
"""
function fetch_gc_bucket(src, dest)
if Sys.iswindows()
- try run(`cmd /C gsutil -v`) catch x throw("gsutil not found, install gsutil to proceed further") end
+ try
+ run(`cmd /C gsutil -v`)
+ catch x
+ throw("gsutil not found, install gsutil to proceed further")
+ end
run(`cmd /C gsutil -m cp -r $src $dest`)
else
- try run(`gsutil -v`) catch x throw("gsutil not found, install gsutil to proceed further") end
+ try
+ run(`gsutil -v`)
+ catch x
+ throw("gsutil not found, install gsutil to proceed further")
+ end
run(`gsutil -m cp -r $src $dest`)
end
return dest
@@ -45,11 +53,19 @@ fetch a gc file from `src` to `dest`.
"""
function fetch_gc_file(src, dest)
if Sys.iswindows()
- try run(`cmd /C gsutil -v`) catch x throw("gsutil not found, install gsutil to proceed further") end
+ try
+ run(`cmd /C gsutil -v`)
+ catch x
+ throw("gsutil not found, install gsutil to proceed further")
+ end
run(`cmd /C gsutil -m cp $src $dest`)
else
- try run(`gsutil -v`) catch x throw("gsutil not found, install gsutil to proceed further") end
+ try
+ run(`gsutil -v`)
+ catch x
+ throw("gsutil not found, install gsutil to proceed further")
+ end
run(`gsutil -m cp $src $dest`)
end
return dest
-end
\ No newline at end of file
+end
diff --git a/src/ReinforcementLearningDatasets/src/d4rl/d4rl/register.jl b/src/ReinforcementLearningDatasets/src/d4rl/d4rl/register.jl
index 7d976bd66..eb5716068 100644
--- a/src/ReinforcementLearningDatasets/src/d4rl/d4rl/register.jl
+++ b/src/ReinforcementLearningDatasets/src/d4rl/d4rl/register.jl
@@ -1,4 +1,4 @@
-export d4rl_dataset_params
+export d4rl_dataset_params
function d4rl_dataset_params()
dataset = keys(D4RL_DATASET_URLS)
@@ -6,7 +6,7 @@ function d4rl_dataset_params()
@info dataset repo
end
-const D4RL_DATASET_URLS = Dict{String, String}(
+const D4RL_DATASET_URLS = Dict{String,String}(
"maze2d-open-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-open-sparse.hdf5",
"maze2d-umaze-v1" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-umaze-sparse-v1.hdf5",
"maze2d-medium-v1" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-medium-sparse-v1.hdf5",
@@ -62,209 +62,209 @@ const D4RL_DATASET_URLS = Dict{String, String}(
"antmaze-medium-diverse-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_big-maze_noisy_multistart_True_multigoal_True_sparse.hdf5",
"antmaze-large-play-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_hardest-maze_noisy_multistart_True_multigoal_False_sparse.hdf5",
"antmaze-large-diverse-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_hardest-maze_noisy_multistart_True_multigoal_True_sparse.hdf5",
- "flow-ring-random-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/flow/flow-ring-v0-random.hdf5",
- "flow-ring-controller-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/flow/flow-ring-v0-idm.hdf5",
- "flow-merge-random-v0"=>"http://rail.eecs.berkeley.edu/datasets/offline_rl/flow/flow-merge-v0-random.hdf5",
- "flow-merge-controller-v0"=>"http://rail.eecs.berkeley.edu/datasets/offline_rl/flow/flow-merge-v0-idm.hdf5",
+ "flow-ring-random-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/flow/flow-ring-v0-random.hdf5",
+ "flow-ring-controller-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/flow/flow-ring-v0-idm.hdf5",
+ "flow-merge-random-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/flow/flow-merge-v0-random.hdf5",
+ "flow-merge-controller-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/flow/flow-merge-v0-idm.hdf5",
"kitchen-complete-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/kitchen/mini_kitchen_microwave_kettle_light_slider-v0.hdf5",
"kitchen-partial-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/kitchen/kitchen_microwave_kettle_light_slider-v0.hdf5",
"kitchen-mixed-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/kitchen/kitchen_microwave_kettle_bottomburner_light-v0.hdf5",
- "carla-lane-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/carla/carla_lane_follow_flat-v0.hdf5",
- "carla-town-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/carla/carla_town_subsamp_flat-v0.hdf5",
- "carla-town-full-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/carla/carla_town_flat-v0.hdf5",
- "bullet-halfcheetah-random-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-halfcheetah_random.hdf5",
- "bullet-halfcheetah-medium-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-halfcheetah_medium.hdf5",
- "bullet-halfcheetah-expert-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-halfcheetah_expert.hdf5",
- "bullet-halfcheetah-medium-expert-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-halfcheetah_medium_expert.hdf5",
- "bullet-halfcheetah-medium-replay-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-halfcheetah_medium_replay.hdf5",
- "bullet-hopper-random-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-hopper_random.hdf5",
- "bullet-hopper-medium-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-hopper_medium.hdf5",
- "bullet-hopper-expert-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-hopper_expert.hdf5",
- "bullet-hopper-medium-expert-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-hopper_medium_expert.hdf5",
- "bullet-hopper-medium-replay-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-hopper_medium_replay.hdf5",
- "bullet-ant-random-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-ant_random.hdf5",
- "bullet-ant-medium-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-ant_medium.hdf5",
- "bullet-ant-expert-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-ant_expert.hdf5",
- "bullet-ant-medium-expert-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-ant_medium_expert.hdf5",
- "bullet-ant-medium-replay-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-ant_medium_replay.hdf5",
- "bullet-walker2d-random-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-walker2d_random.hdf5",
- "bullet-walker2d-medium-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-walker2d_medium.hdf5",
- "bullet-walker2d-expert-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-walker2d_expert.hdf5",
- "bullet-walker2d-medium-expert-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-walker2d_medium_expert.hdf5",
- "bullet-walker2d-medium-replay-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-walker2d_medium_replay.hdf5",
- "bullet-maze2d-open-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-maze2d-open-sparse.hdf5",
- "bullet-maze2d-umaze-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-maze2d-umaze-sparse.hdf5",
- "bullet-maze2d-medium-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-maze2d-medium-sparse.hdf5",
- "bullet-maze2d-large-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-maze2d-large-sparse.hdf5",
+ "carla-lane-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/carla/carla_lane_follow_flat-v0.hdf5",
+ "carla-town-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/carla/carla_town_subsamp_flat-v0.hdf5",
+ "carla-town-full-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/carla/carla_town_flat-v0.hdf5",
+ "bullet-halfcheetah-random-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-halfcheetah_random.hdf5",
+ "bullet-halfcheetah-medium-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-halfcheetah_medium.hdf5",
+ "bullet-halfcheetah-expert-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-halfcheetah_expert.hdf5",
+ "bullet-halfcheetah-medium-expert-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-halfcheetah_medium_expert.hdf5",
+ "bullet-halfcheetah-medium-replay-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-halfcheetah_medium_replay.hdf5",
+ "bullet-hopper-random-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-hopper_random.hdf5",
+ "bullet-hopper-medium-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-hopper_medium.hdf5",
+ "bullet-hopper-expert-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-hopper_expert.hdf5",
+ "bullet-hopper-medium-expert-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-hopper_medium_expert.hdf5",
+ "bullet-hopper-medium-replay-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-hopper_medium_replay.hdf5",
+ "bullet-ant-random-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-ant_random.hdf5",
+ "bullet-ant-medium-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-ant_medium.hdf5",
+ "bullet-ant-expert-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-ant_expert.hdf5",
+ "bullet-ant-medium-expert-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-ant_medium_expert.hdf5",
+ "bullet-ant-medium-replay-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-ant_medium_replay.hdf5",
+ "bullet-walker2d-random-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-walker2d_random.hdf5",
+ "bullet-walker2d-medium-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-walker2d_medium.hdf5",
+ "bullet-walker2d-expert-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-walker2d_expert.hdf5",
+ "bullet-walker2d-medium-expert-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-walker2d_medium_expert.hdf5",
+ "bullet-walker2d-medium-replay-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-walker2d_medium_replay.hdf5",
+ "bullet-maze2d-open-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-maze2d-open-sparse.hdf5",
+ "bullet-maze2d-umaze-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-maze2d-umaze-sparse.hdf5",
+ "bullet-maze2d-medium-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-maze2d-medium-sparse.hdf5",
+ "bullet-maze2d-large-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-maze2d-large-sparse.hdf5",
)
-const D4RL_REF_MIN_SCORE = Dict{String, Float32}(
- "maze2d-open-v0" => 0.01 ,
- "maze2d-umaze-v1" => 23.85 ,
- "maze2d-medium-v1" => 13.13 ,
- "maze2d-large-v1" => 6.7 ,
- "maze2d-open-dense-v0" => 11.17817 ,
- "maze2d-umaze-dense-v1" => 68.537689 ,
- "maze2d-medium-dense-v1" => 44.264742 ,
- "maze2d-large-dense-v1" => 30.569041 ,
- "minigrid-fourrooms-v0" => 0.01442 ,
- "minigrid-fourrooms-random-v0" => 0.01442 ,
- "pen-human-v0" => 96.262799 ,
- "pen-cloned-v0" => 96.262799 ,
- "pen-expert-v0" => 96.262799 ,
- "hammer-human-v0" => -274.856578 ,
- "hammer-cloned-v0" => -274.856578 ,
- "hammer-expert-v0" => -274.856578 ,
- "relocate-human-v0" => -6.425911 ,
- "relocate-cloned-v0" => -6.425911 ,
- "relocate-expert-v0" => -6.425911 ,
- "door-human-v0" => -56.512833 ,
- "door-cloned-v0" => -56.512833 ,
- "door-expert-v0" => -56.512833 ,
- "halfcheetah-random-v0" => -280.178953 ,
- "halfcheetah-medium-v0" => -280.178953 ,
- "halfcheetah-expert-v0" => -280.178953 ,
- "halfcheetah-medium-replay-v0" => -280.178953 ,
- "halfcheetah-medium-expert-v0" => -280.178953 ,
- "walker2d-random-v0" => 1.629008 ,
- "walker2d-medium-v0" => 1.629008 ,
- "walker2d-expert-v0" => 1.629008 ,
- "walker2d-medium-replay-v0" => 1.629008 ,
- "walker2d-medium-expert-v0" => 1.629008 ,
- "hopper-random-v0" => -20.272305 ,
- "hopper-medium-v0" => -20.272305 ,
- "hopper-expert-v0" => -20.272305 ,
- "hopper-medium-replay-v0" => -20.272305 ,
- "hopper-medium-expert-v0" => -20.272305 ,
+const D4RL_REF_MIN_SCORE = Dict{String,Float32}(
+ "maze2d-open-v0" => 0.01,
+ "maze2d-umaze-v1" => 23.85,
+ "maze2d-medium-v1" => 13.13,
+ "maze2d-large-v1" => 6.7,
+ "maze2d-open-dense-v0" => 11.17817,
+ "maze2d-umaze-dense-v1" => 68.537689,
+ "maze2d-medium-dense-v1" => 44.264742,
+ "maze2d-large-dense-v1" => 30.569041,
+ "minigrid-fourrooms-v0" => 0.01442,
+ "minigrid-fourrooms-random-v0" => 0.01442,
+ "pen-human-v0" => 96.262799,
+ "pen-cloned-v0" => 96.262799,
+ "pen-expert-v0" => 96.262799,
+ "hammer-human-v0" => -274.856578,
+ "hammer-cloned-v0" => -274.856578,
+ "hammer-expert-v0" => -274.856578,
+ "relocate-human-v0" => -6.425911,
+ "relocate-cloned-v0" => -6.425911,
+ "relocate-expert-v0" => -6.425911,
+ "door-human-v0" => -56.512833,
+ "door-cloned-v0" => -56.512833,
+ "door-expert-v0" => -56.512833,
+ "halfcheetah-random-v0" => -280.178953,
+ "halfcheetah-medium-v0" => -280.178953,
+ "halfcheetah-expert-v0" => -280.178953,
+ "halfcheetah-medium-replay-v0" => -280.178953,
+ "halfcheetah-medium-expert-v0" => -280.178953,
+ "walker2d-random-v0" => 1.629008,
+ "walker2d-medium-v0" => 1.629008,
+ "walker2d-expert-v0" => 1.629008,
+ "walker2d-medium-replay-v0" => 1.629008,
+ "walker2d-medium-expert-v0" => 1.629008,
+ "hopper-random-v0" => -20.272305,
+ "hopper-medium-v0" => -20.272305,
+ "hopper-expert-v0" => -20.272305,
+ "hopper-medium-replay-v0" => -20.272305,
+ "hopper-medium-expert-v0" => -20.272305,
"ant-random-v0" => -325.6,
"ant-medium-v0" => -325.6,
"ant-expert-v0" => -325.6,
"ant-medium-replay-v0" => -325.6,
"ant-medium-expert-v0" => -325.6,
- "antmaze-umaze-v0" => 0.0 ,
- "antmaze-umaze-diverse-v0" => 0.0 ,
- "antmaze-medium-play-v0" => 0.0 ,
- "antmaze-medium-diverse-v0" => 0.0 ,
- "antmaze-large-play-v0" => 0.0 ,
- "antmaze-large-diverse-v0" => 0.0 ,
- "kitchen-complete-v0" => 0.0 ,
- "kitchen-partial-v0" => 0.0 ,
- "kitchen-mixed-v0" => 0.0 ,
- "flow-ring-random-v0" => -165.22 ,
- "flow-ring-controller-v0" => -165.22 ,
- "flow-merge-random-v0" => 118.67993 ,
- "flow-merge-controller-v0" => 118.67993 ,
- "carla-lane-v0"=> -0.8503839912088142,
- "carla-town-v0"=> -114.81579500772153, # random score
- "bullet-halfcheetah-random-v0"=> -1275.766996,
- "bullet-halfcheetah-medium-v0"=> -1275.766996,
- "bullet-halfcheetah-expert-v0"=> -1275.766996,
- "bullet-halfcheetah-medium-expert-v0"=> -1275.766996,
- "bullet-halfcheetah-medium-replay-v0"=> -1275.766996,
- "bullet-hopper-random-v0"=> 20.058972,
- "bullet-hopper-medium-v0"=> 20.058972,
- "bullet-hopper-expert-v0"=> 20.058972,
- "bullet-hopper-medium-expert-v0"=> 20.058972,
- "bullet-hopper-medium-replay-v0"=> 20.058972,
- "bullet-ant-random-v0"=> 373.705955,
- "bullet-ant-medium-v0"=> 373.705955,
- "bullet-ant-expert-v0"=> 373.705955,
- "bullet-ant-medium-expert-v0"=> 373.705955,
- "bullet-ant-medium-replay-v0"=> 373.705955,
- "bullet-walker2d-random-v0"=> 16.523877,
- "bullet-walker2d-medium-v0"=> 16.523877,
- "bullet-walker2d-expert-v0"=> 16.523877,
- "bullet-walker2d-medium-expert-v0"=> 16.523877,
- "bullet-walker2d-medium-replay-v0"=> 16.523877,
- "bullet-maze2d-open-v0"=> 8.750000,
- "bullet-maze2d-umaze-v0"=> 32.460000,
- "bullet-maze2d-medium-v0"=> 14.870000,
- "bullet-maze2d-large-v0"=> 1.820000,
+ "antmaze-umaze-v0" => 0.0,
+ "antmaze-umaze-diverse-v0" => 0.0,
+ "antmaze-medium-play-v0" => 0.0,
+ "antmaze-medium-diverse-v0" => 0.0,
+ "antmaze-large-play-v0" => 0.0,
+ "antmaze-large-diverse-v0" => 0.0,
+ "kitchen-complete-v0" => 0.0,
+ "kitchen-partial-v0" => 0.0,
+ "kitchen-mixed-v0" => 0.0,
+ "flow-ring-random-v0" => -165.22,
+ "flow-ring-controller-v0" => -165.22,
+ "flow-merge-random-v0" => 118.67993,
+ "flow-merge-controller-v0" => 118.67993,
+ "carla-lane-v0" => -0.8503839912088142,
+ "carla-town-v0" => -114.81579500772153, # random score
+ "bullet-halfcheetah-random-v0" => -1275.766996,
+ "bullet-halfcheetah-medium-v0" => -1275.766996,
+ "bullet-halfcheetah-expert-v0" => -1275.766996,
+ "bullet-halfcheetah-medium-expert-v0" => -1275.766996,
+ "bullet-halfcheetah-medium-replay-v0" => -1275.766996,
+ "bullet-hopper-random-v0" => 20.058972,
+ "bullet-hopper-medium-v0" => 20.058972,
+ "bullet-hopper-expert-v0" => 20.058972,
+ "bullet-hopper-medium-expert-v0" => 20.058972,
+ "bullet-hopper-medium-replay-v0" => 20.058972,
+ "bullet-ant-random-v0" => 373.705955,
+ "bullet-ant-medium-v0" => 373.705955,
+ "bullet-ant-expert-v0" => 373.705955,
+ "bullet-ant-medium-expert-v0" => 373.705955,
+ "bullet-ant-medium-replay-v0" => 373.705955,
+ "bullet-walker2d-random-v0" => 16.523877,
+ "bullet-walker2d-medium-v0" => 16.523877,
+ "bullet-walker2d-expert-v0" => 16.523877,
+ "bullet-walker2d-medium-expert-v0" => 16.523877,
+ "bullet-walker2d-medium-replay-v0" => 16.523877,
+ "bullet-maze2d-open-v0" => 8.750000,
+ "bullet-maze2d-umaze-v0" => 32.460000,
+ "bullet-maze2d-medium-v0" => 14.870000,
+ "bullet-maze2d-large-v0" => 1.820000,
)
-const D4RL_REF_MAX_SCORE = Dict{String, Float32}(
- "maze2d-open-v0" => 20.66 ,
- "maze2d-umaze-v1" => 161.86 ,
- "maze2d-medium-v1" => 277.39 ,
- "maze2d-large-v1" => 273.99 ,
- "maze2d-open-dense-v0" => 27.166538620695782 ,
- "maze2d-umaze-dense-v1" => 193.66285642381482 ,
- "maze2d-medium-dense-v1" => 297.4552547777125 ,
- "maze2d-large-dense-v1" => 303.4857382709002 ,
- "minigrid-fourrooms-v0" => 2.89685 ,
- "minigrid-fourrooms-random-v0" => 2.89685 ,
- "pen-human-v0" => 3076.8331017826877 ,
- "pen-cloned-v0" => 3076.8331017826877 ,
- "pen-expert-v0" => 3076.8331017826877 ,
- "hammer-human-v0" => 12794.134825156867 ,
- "hammer-cloned-v0" => 12794.134825156867 ,
- "hammer-expert-v0" => 12794.134825156867 ,
- "relocate-human-v0" => 4233.877797728884 ,
- "relocate-cloned-v0" => 4233.877797728884 ,
- "relocate-expert-v0" => 4233.877797728884 ,
- "door-human-v0" => 2880.5693087298737 ,
- "door-cloned-v0" => 2880.5693087298737 ,
- "door-expert-v0" => 2880.5693087298737 ,
- "halfcheetah-random-v0" => 12135.0 ,
- "halfcheetah-medium-v0" => 12135.0 ,
- "halfcheetah-expert-v0" => 12135.0 ,
- "halfcheetah-medium-replay-v0" => 12135.0 ,
- "halfcheetah-medium-expert-v0" => 12135.0 ,
- "walker2d-random-v0" => 4592.3 ,
- "walker2d-medium-v0" => 4592.3 ,
- "walker2d-expert-v0" => 4592.3 ,
- "walker2d-medium-replay-v0" => 4592.3 ,
- "walker2d-medium-expert-v0" => 4592.3 ,
- "hopper-random-v0" => 3234.3 ,
- "hopper-medium-v0" => 3234.3 ,
- "hopper-expert-v0" => 3234.3 ,
- "hopper-medium-replay-v0" => 3234.3 ,
- "hopper-medium-expert-v0" => 3234.3 ,
+const D4RL_REF_MAX_SCORE = Dict{String,Float32}(
+ "maze2d-open-v0" => 20.66,
+ "maze2d-umaze-v1" => 161.86,
+ "maze2d-medium-v1" => 277.39,
+ "maze2d-large-v1" => 273.99,
+ "maze2d-open-dense-v0" => 27.166538620695782,
+ "maze2d-umaze-dense-v1" => 193.66285642381482,
+ "maze2d-medium-dense-v1" => 297.4552547777125,
+ "maze2d-large-dense-v1" => 303.4857382709002,
+ "minigrid-fourrooms-v0" => 2.89685,
+ "minigrid-fourrooms-random-v0" => 2.89685,
+ "pen-human-v0" => 3076.8331017826877,
+ "pen-cloned-v0" => 3076.8331017826877,
+ "pen-expert-v0" => 3076.8331017826877,
+ "hammer-human-v0" => 12794.134825156867,
+ "hammer-cloned-v0" => 12794.134825156867,
+ "hammer-expert-v0" => 12794.134825156867,
+ "relocate-human-v0" => 4233.877797728884,
+ "relocate-cloned-v0" => 4233.877797728884,
+ "relocate-expert-v0" => 4233.877797728884,
+ "door-human-v0" => 2880.5693087298737,
+ "door-cloned-v0" => 2880.5693087298737,
+ "door-expert-v0" => 2880.5693087298737,
+ "halfcheetah-random-v0" => 12135.0,
+ "halfcheetah-medium-v0" => 12135.0,
+ "halfcheetah-expert-v0" => 12135.0,
+ "halfcheetah-medium-replay-v0" => 12135.0,
+ "halfcheetah-medium-expert-v0" => 12135.0,
+ "walker2d-random-v0" => 4592.3,
+ "walker2d-medium-v0" => 4592.3,
+ "walker2d-expert-v0" => 4592.3,
+ "walker2d-medium-replay-v0" => 4592.3,
+ "walker2d-medium-expert-v0" => 4592.3,
+ "hopper-random-v0" => 3234.3,
+ "hopper-medium-v0" => 3234.3,
+ "hopper-expert-v0" => 3234.3,
+ "hopper-medium-replay-v0" => 3234.3,
+ "hopper-medium-expert-v0" => 3234.3,
"ant-random-v0" => 3879.7,
"ant-medium-v0" => 3879.7,
"ant-expert-v0" => 3879.7,
"ant-medium-replay-v0" => 3879.7,
"ant-medium-expert-v0" => 3879.7,
- "antmaze-umaze-v0" => 1.0 ,
- "antmaze-umaze-diverse-v0" => 1.0 ,
- "antmaze-medium-play-v0" => 1.0 ,
- "antmaze-medium-diverse-v0" => 1.0 ,
- "antmaze-large-play-v0" => 1.0 ,
- "antmaze-large-diverse-v0" => 1.0 ,
- "kitchen-complete-v0" => 4.0 ,
- "kitchen-partial-v0" => 4.0 ,
- "kitchen-mixed-v0" => 4.0 ,
- "flow-ring-random-v0" => 24.42 ,
- "flow-ring-controller-v0" => 24.42 ,
- "flow-merge-random-v0" => 330.03179 ,
- "flow-merge-controller-v0" => 330.03179 ,
- "carla-lane-v0"=> 1023.5784385429523,
- "carla-town-v0"=> 2440.1772022247314, # avg dataset score
- "bullet-halfcheetah-random-v0"=> 2381.6725,
- "bullet-halfcheetah-medium-v0"=> 2381.6725,
- "bullet-halfcheetah-expert-v0"=> 2381.6725,
- "bullet-halfcheetah-medium-expert-v0"=> 2381.6725,
- "bullet-halfcheetah-medium-replay-v0"=> 2381.6725,
- "bullet-hopper-random-v0"=> 1441.8059623430963,
- "bullet-hopper-medium-v0"=> 1441.8059623430963,
- "bullet-hopper-expert-v0"=> 1441.8059623430963,
- "bullet-hopper-medium-expert-v0"=> 1441.8059623430963,
- "bullet-hopper-medium-replay-v0"=> 1441.8059623430963,
- "bullet-ant-random-v0"=> 2650.495,
- "bullet-ant-medium-v0"=> 2650.495,
- "bullet-ant-expert-v0"=> 2650.495,
- "bullet-ant-medium-expert-v0"=> 2650.495,
- "bullet-ant-medium-replay-v0"=> 2650.495,
- "bullet-walker2d-random-v0"=> 1623.6476303317536,
- "bullet-walker2d-medium-v0"=> 1623.6476303317536,
- "bullet-walker2d-expert-v0"=> 1623.6476303317536,
- "bullet-walker2d-medium-expert-v0"=> 1623.6476303317536,
- "bullet-walker2d-medium-replay-v0"=> 1623.6476303317536,
- "bullet-maze2d-open-v0"=> 64.15,
- "bullet-maze2d-umaze-v0"=> 153.99,
- "bullet-maze2d-medium-v0"=> 238.05,
- "bullet-maze2d-large-v0"=> 285.92,
+ "antmaze-umaze-v0" => 1.0,
+ "antmaze-umaze-diverse-v0" => 1.0,
+ "antmaze-medium-play-v0" => 1.0,
+ "antmaze-medium-diverse-v0" => 1.0,
+ "antmaze-large-play-v0" => 1.0,
+ "antmaze-large-diverse-v0" => 1.0,
+ "kitchen-complete-v0" => 4.0,
+ "kitchen-partial-v0" => 4.0,
+ "kitchen-mixed-v0" => 4.0,
+ "flow-ring-random-v0" => 24.42,
+ "flow-ring-controller-v0" => 24.42,
+ "flow-merge-random-v0" => 330.03179,
+ "flow-merge-controller-v0" => 330.03179,
+ "carla-lane-v0" => 1023.5784385429523,
+ "carla-town-v0" => 2440.1772022247314, # avg dataset score
+ "bullet-halfcheetah-random-v0" => 2381.6725,
+ "bullet-halfcheetah-medium-v0" => 2381.6725,
+ "bullet-halfcheetah-expert-v0" => 2381.6725,
+ "bullet-halfcheetah-medium-expert-v0" => 2381.6725,
+ "bullet-halfcheetah-medium-replay-v0" => 2381.6725,
+ "bullet-hopper-random-v0" => 1441.8059623430963,
+ "bullet-hopper-medium-v0" => 1441.8059623430963,
+ "bullet-hopper-expert-v0" => 1441.8059623430963,
+ "bullet-hopper-medium-expert-v0" => 1441.8059623430963,
+ "bullet-hopper-medium-replay-v0" => 1441.8059623430963,
+ "bullet-ant-random-v0" => 2650.495,
+ "bullet-ant-medium-v0" => 2650.495,
+ "bullet-ant-expert-v0" => 2650.495,
+ "bullet-ant-medium-expert-v0" => 2650.495,
+ "bullet-ant-medium-replay-v0" => 2650.495,
+ "bullet-walker2d-random-v0" => 1623.6476303317536,
+ "bullet-walker2d-medium-v0" => 1623.6476303317536,
+ "bullet-walker2d-expert-v0" => 1623.6476303317536,
+ "bullet-walker2d-medium-expert-v0" => 1623.6476303317536,
+ "bullet-walker2d-medium-replay-v0" => 1623.6476303317536,
+ "bullet-maze2d-open-v0" => 64.15,
+ "bullet-maze2d-umaze-v0" => 153.99,
+ "bullet-maze2d-medium-v0" => 238.05,
+ "bullet-maze2d-large-v0" => 285.92,
)
# give a prompt for flow and carla tasks
@@ -274,20 +274,20 @@ function d4rl_init()
for ds in keys(D4RL_DATASET_URLS)
register(
DataDep(
- repo*"-"* ds,
+ repo * "-" * ds,
"""
Credits: https://arxiv.org/abs/2004.07219
The following dataset is fetched from d4rl.
It is modified into a form that is convenient for the RL.jl package.
-
+
Dataset information:
Name: $(ds)
$(if ds in keys(D4RL_REF_MAX_SCORE) "MAXIMUM_SCORE: " * string(D4RL_REF_MAX_SCORE[ds]) end)
$(if ds in keys(D4RL_REF_MIN_SCORE) "MINIMUM_SCORE: " * string(D4RL_REF_MIN_SCORE[ds]) end)
""", #check if the MAX and MIN score part is even necessary and make the log file prettier
D4RL_DATASET_URLS[ds],
- )
+ ),
)
end
nothing
-end
\ No newline at end of file
+end
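The `D4RL_REF_MIN_SCORE` / `D4RL_REF_MAX_SCORE` tables reformatted above mirror d4rl's reference returns and are typically used to map a raw episode return onto the 0-100 normalized scale from the D4RL paper. A minimal sketch of that computation, assuming `ds` is a key of both dictionaries (the helper name `normalized_score` is hypothetical and not part of this patch):

```julia
# Hypothetical helper: map a raw return onto the 0-100 D4RL scale using the
# reference tables defined above.
function normalized_score(ds::String, raw_return::Real)
    lo = D4RL_REF_MIN_SCORE[ds]   # return of a random policy
    hi = D4RL_REF_MAX_SCORE[ds]   # return of the reference (expert) policy
    return 100 * (raw_return - lo) / (hi - lo)
end

normalized_score("hopper-medium-v0", 1600.0)   # ≈ 49.8
```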
diff --git a/src/ReinforcementLearningDatasets/src/d4rl/d4rl_dataset.jl b/src/ReinforcementLearningDatasets/src/d4rl/d4rl_dataset.jl
index 105b8e65f..0ed98335e 100644
--- a/src/ReinforcementLearningDatasets/src/d4rl/d4rl_dataset.jl
+++ b/src/ReinforcementLearningDatasets/src/d4rl/d4rl_dataset.jl
@@ -21,7 +21,7 @@ Represents an `Iterable` dataset with the following fields:
- `is_shuffle::Bool`: determines if the batches returned by `iterate` are shuffled.
"""
struct D4RLDataSet{T<:AbstractRNG} <: RLDataSet
- dataset::Dict{Symbol, Any}
+ dataset::Dict{Symbol,Any}
repo::String
dataset_size::Integer
batch_size::Integer
@@ -58,42 +58,47 @@ been tested in this package yet.
"""
function dataset(
dataset::String;
- repo::String="d4rl",
- style::NTuple=SARTS,
- rng::AbstractRNG=MersenneTwister(123),
- is_shuffle::Bool=true,
- batch_size::Int=256
+ repo::String = "d4rl",
+ style::NTuple = SARTS,
+ rng::AbstractRNG = MersenneTwister(123),
+ is_shuffle::Bool = true,
+ batch_size::Int = 256,
)
-
- try
- @datadep_str repo*"-"*dataset
+
+ try
+ @datadep_str repo * "-" * dataset
catch e
if isa(e, KeyError)
- throw("Invalid params, check out d4rl_pybullet_dataset_params() or d4rl_dataset_params()")
+ throw(
+ "Invalid params, check out d4rl_pybullet_dataset_params() or d4rl_dataset_params()",
+ )
end
end
-
- path = @datadep_str repo*"-"*dataset
+
+ path = @datadep_str repo * "-" * dataset
@assert length(readdir(path)) == 1
file_name = readdir(path)[1]
-
- data = h5open(path*"/"*file_name, "r") do file
+
+ data = h5open(path * "/" * file_name, "r") do file
read(file)
end
# sanity checks on data
d4rl_verify(data)
- dataset = Dict{Symbol, Any}()
- meta = Dict{String, Any}()
+ dataset = Dict{Symbol,Any}()
+ meta = Dict{String,Any}()
N_samples = size(data["observations"])[2]
-
- for (key, d_key) in zip(["observations", "actions", "rewards", "terminals"], Symbol.(["state", "action", "reward", "terminal"]))
- dataset[d_key] = data[key]
+
+ for (key, d_key) in zip(
+ ["observations", "actions", "rewards", "terminals"],
+ Symbol.(["state", "action", "reward", "terminal"]),
+ )
+ dataset[d_key] = data[key]
end
-
+
for key in keys(data)
if !(key in ["observations", "actions", "rewards", "terminals"])
meta[key] = data[key]
@@ -113,9 +118,13 @@ function iterate(ds::D4RLDataSet, state = 0)
if is_shuffle
inds = rand(rng, 1:size, batch_size)
- map((x)-> if x <= size x else 1 end, inds)
+ map((x) -> if x <= size
+ x
+ else
+ 1
+ end, inds)
else
- if (state+1) * batch_size <= size
+ if (state + 1) * batch_size <= size
inds = state*batch_size+1:(state+1)*batch_size
else
return nothing
@@ -123,15 +132,17 @@ function iterate(ds::D4RLDataSet, state = 0)
state += 1
end
- batch = (state = copy(ds.dataset[:state][:, inds]),
- action = copy(ds.dataset[:action][:, inds]),
- reward = copy(ds.dataset[:reward][inds]),
- terminal = copy(ds.dataset[:terminal][inds]))
+ batch = (
+ state = copy(ds.dataset[:state][:, inds]),
+ action = copy(ds.dataset[:action][:, inds]),
+ reward = copy(ds.dataset[:reward][inds]),
+ terminal = copy(ds.dataset[:terminal][inds]),
+ )
if style == SARTS
batch = merge(batch, (next_state = copy(ds.dataset[:state][:, (1).+(inds)]),))
end
-
+
return batch, state
end
@@ -141,11 +152,12 @@ length(ds::D4RLDataSet) = ds.dataset_size
IteratorEltype(::Type{D4RLDataSet}) = EltypeUnknown() # see if eltype can be known (not sure about carla and adroit)
-function d4rl_verify(data::Dict{String, Any})
+function d4rl_verify(data::Dict{String,Any})
for key in ["observations", "actions", "rewards", "terminals"]
@assert (key in keys(data)) "Expected keys not present in data"
end
N_samples = size(data["observations"])[2]
@assert size(data["rewards"]) == (N_samples,) || size(data["rewards"]) == (1, N_samples)
- @assert size(data["terminals"]) == (N_samples,) || size(data["terminals"]) == (1, N_samples)
+ @assert size(data["terminals"]) == (N_samples,) ||
+ size(data["terminals"]) == (1, N_samples)
end
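For reference, a minimal usage sketch of the reformatted `dataset` entry point (the dataset name and keyword values are illustrative; it assumes `dataset` is exported and that `D4RLDataSet` hooks into `Base.iterate`, as its docstring suggests):

```julia
using ReinforcementLearningDatasets

# Download (via DataDeps) and wrap the offline data as an iterable of batches.
ds = dataset("hopper-medium-v0"; repo = "d4rl", is_shuffle = true, batch_size = 256)

for batch in ds
    # Each batch is a NamedTuple with :state, :action, :reward and :terminal,
    # plus :next_state when `style == SARTS`.
    @assert size(batch.state, 2) == 256
    break
end
```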
diff --git a/src/ReinforcementLearningDatasets/src/d4rl/d4rl_pybullet/register.jl b/src/ReinforcementLearningDatasets/src/d4rl/d4rl_pybullet/register.jl
index a8aba9748..7355c90eb 100644
--- a/src/ReinforcementLearningDatasets/src/d4rl/d4rl_pybullet/register.jl
+++ b/src/ReinforcementLearningDatasets/src/d4rl/d4rl_pybullet/register.jl
@@ -7,18 +7,18 @@ function d4rl_pybullet_dataset_params()
end
const D4RL_PYBULLET_URLS = Dict(
- "hopper-bullet-mixed-v0" => "https://www.dropbox.com/s/xv3p0h7dzgxt8xb/hopper-bullet-mixed-v0.hdf5?dl=1",
- "walker2d-bullet-random-v0" => "https://www.dropbox.com/s/1gwcfl2nmx6878m/walker2d-bullet-random-v0.hdf5?dl=1",
- "hopper-bullet-medium-v0" => "https://www.dropbox.com/s/w22kgzldn6eng7j/hopper-bullet-medium-v0.hdf5?dl=1",
+ "hopper-bullet-mixed-v0" => "https://www.dropbox.com/s/xv3p0h7dzgxt8xb/hopper-bullet-mixed-v0.hdf5?dl=1",
+ "walker2d-bullet-random-v0" => "https://www.dropbox.com/s/1gwcfl2nmx6878m/walker2d-bullet-random-v0.hdf5?dl=1",
+ "hopper-bullet-medium-v0" => "https://www.dropbox.com/s/w22kgzldn6eng7j/hopper-bullet-medium-v0.hdf5?dl=1",
"walker2d-bullet-mixed-v0" => "https://www.dropbox.com/s/i4u2ii0d85iblou/walker2d-bullet-mixed-v0.hdf5?dl=1",
- "halfcheetah-bullet-mixed-v0" => "https://www.dropbox.com/s/scj1rqun963aw90/halfcheetah-bullet-mixed-v0.hdf5?dl=1",
+ "halfcheetah-bullet-mixed-v0" => "https://www.dropbox.com/s/scj1rqun963aw90/halfcheetah-bullet-mixed-v0.hdf5?dl=1",
"halfcheetah-bullet-random-v0" => "https://www.dropbox.com/s/jnvpb1hp60zt2ak/halfcheetah-bullet-random-v0.hdf5?dl=1",
- "walker2d-bullet-medium-v0" => "https://www.dropbox.com/s/v0f2kz48b1hw6or/walker2d-bullet-medium-v0.hdf5?dl=1",
- "hopper-bullet-random-v0" => "https://www.dropbox.com/s/bino8ojd7iq4p4d/hopper-bullet-random-v0.hdf5?dl=1",
- "ant-bullet-random-v0" => "https://www.dropbox.com/s/2xpmh4wk2m7i8xh/ant-bullet-random-v0.hdf5?dl=1",
- "halfcheetah-bullet-medium-v0" => "https://www.dropbox.com/s/v4xgssp1w968a9l/halfcheetah-bullet-medium-v0.hdf5?dl=1",
- "ant-bullet-medium-v0" => "https://www.dropbox.com/s/6n79kwd94xthr1t/ant-bullet-medium-v0.hdf5?dl=1",
- "ant-bullet-mixed-v0" => "https://www.dropbox.com/s/pmy3dzab35g4whk/ant-bullet-mixed-v0.hdf5?dl=1"
+ "walker2d-bullet-medium-v0" => "https://www.dropbox.com/s/v0f2kz48b1hw6or/walker2d-bullet-medium-v0.hdf5?dl=1",
+ "hopper-bullet-random-v0" => "https://www.dropbox.com/s/bino8ojd7iq4p4d/hopper-bullet-random-v0.hdf5?dl=1",
+ "ant-bullet-random-v0" => "https://www.dropbox.com/s/2xpmh4wk2m7i8xh/ant-bullet-random-v0.hdf5?dl=1",
+ "halfcheetah-bullet-medium-v0" => "https://www.dropbox.com/s/v4xgssp1w968a9l/halfcheetah-bullet-medium-v0.hdf5?dl=1",
+ "ant-bullet-medium-v0" => "https://www.dropbox.com/s/6n79kwd94xthr1t/ant-bullet-medium-v0.hdf5?dl=1",
+ "ant-bullet-mixed-v0" => "https://www.dropbox.com/s/pmy3dzab35g4whk/ant-bullet-mixed-v0.hdf5?dl=1",
)
function d4rl_pybullet_init()
@@ -26,14 +26,14 @@ function d4rl_pybullet_init()
for ds in keys(D4RL_PYBULLET_URLS)
register(
DataDep(
- repo* "-" * ds,
+ repo * "-" * ds,
"""
Credits: https://github.com/takuseno/d4rl-pybullet
The following dataset is fetched from d4rl-pybullet.
- """,
+ """,
D4RL_PYBULLET_URLS[ds],
- )
+ ),
)
end
nothing
-end
\ No newline at end of file
+end
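The pybullet datasets registered here are retrieved through the same `dataset` entry point by switching the `repo` keyword, which matches the `repo * "-" * ds` DataDep naming used in both register files. A one-line sketch, assuming the registered repo prefix is `d4rl-pybullet`:

```julia
# Resolves the DataDep "d4rl-pybullet-hopper-bullet-medium-v0" registered above
# (assuming the repo prefix is "d4rl-pybullet").
ds = dataset("hopper-bullet-medium-v0"; repo = "d4rl-pybullet")
```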
diff --git a/src/ReinforcementLearningDatasets/src/deep_ope/d4rl/d4rl_policies.jl b/src/ReinforcementLearningDatasets/src/deep_ope/d4rl/d4rl_policies.jl
index eaa1d1d41..c8e7f8281 100644
--- a/src/ReinforcementLearningDatasets/src/deep_ope/d4rl/d4rl_policies.jl
+++ b/src/ReinforcementLearningDatasets/src/deep_ope/d4rl/d4rl_policies.jl
@@ -2,7 +2,7 @@ export d4rl_policy_params
function d4rl_policy_params()
d4rl_policy_paths = [split(policy["policy_path"], "/")[2] for policy in D4RL_POLICIES]
- env = Set(join.(map(x->x[1:end-2], split.(d4rl_policy_paths, "_")), "_"))
+ env = Set(join.(map(x -> x[1:end-2], split.(d4rl_policy_paths, "_")), "_"))
agent = ["dapg", "online"]
epoch = 0:10
@@ -12,322 +12,234 @@ end
const D4RL_POLICIES = [
Dict(
"policy_path" => "antmaze_large/antmaze_large_dapg_0.pkl",
- "task.task_names" => [
- "antmaze-large-play-v0",
- "antmaze-large-diverse-v0"
- ],
+ "task.task_names" => ["antmaze-large-play-v0", "antmaze-large-diverse-v0"],
"agent_name" => "BC",
"return_mean" => 0.0,
- "return_std =>" => 0.0
+ "return_std =>" => 0.0,
),
Dict(
"policy_path" => "antmaze_large/antmaze_large_dapg_10.pkl",
- "task.task_names" => [
- "antmaze-large-play-v0",
- "antmaze-large-diverse-v0"
- ],
+ "task.task_names" => ["antmaze-large-play-v0", "antmaze-large-diverse-v0"],
"agent_name" => "BC",
"return_mean" => 0.48,
- "return_std =>" => 0.4995998398718718
+ "return_std =>" => 0.4995998398718718,
),
Dict(
"policy_path" => "antmaze_large/antmaze_large_dapg_1.pkl",
- "task.task_names" => [
- "antmaze-large-play-v0",
- "antmaze-large-diverse-v0"
- ],
+ "task.task_names" => ["antmaze-large-play-v0", "antmaze-large-diverse-v0"],
"agent_name" => "BC",
"return_mean" => 0.0,
- "return_std =>" => 0.0
+ "return_std =>" => 0.0,
),
Dict(
"policy_path" => "antmaze_large/antmaze_large_dapg_2.pkl",
- "task.task_names" => [
- "antmaze-large-play-v0",
- "antmaze-large-diverse-v0"
- ],
+ "task.task_names" => ["antmaze-large-play-v0", "antmaze-large-diverse-v0"],
"agent_name" => "BC",
"return_mean" => 0.0,
- "return_std =>" => 0.0
+ "return_std =>" => 0.0,
),
Dict(
"policy_path" => "antmaze_large/antmaze_large_dapg_3.pkl",
- "task.task_names" => [
- "antmaze-large-play-v0",
- "antmaze-large-diverse-v0"
- ],
+ "task.task_names" => ["antmaze-large-play-v0", "antmaze-large-diverse-v0"],
"agent_name" => "BC",
"return_mean" => 0.0,
- "return_std =>" => 0.0
+ "return_std =>" => 0.0,
),
Dict(
"policy_path" => "antmaze_large/antmaze_large_dapg_4.pkl",
- "task.task_names" => [
- "antmaze-large-play-v0",
- "antmaze-large-diverse-v0"
- ],
+ "task.task_names" => ["antmaze-large-play-v0", "antmaze-large-diverse-v0"],
"agent_name" => "BC",
"return_mean" => 0.01,
- "return_std =>" => 0.09949874371066199
+ "return_std =>" => 0.09949874371066199,
),
Dict(
"policy_path" => "antmaze_large/antmaze_large_dapg_5.pkl",
- "task.task_names" => [
- "antmaze-large-play-v0",
- "antmaze-large-diverse-v0"
- ],
+ "task.task_names" => ["antmaze-large-play-v0", "antmaze-large-diverse-v0"],
"agent_name" => "BC",
"return_mean" => 0.13,
- "return_std =>" => 0.33630343441600474
+ "return_std =>" => 0.33630343441600474,
),
Dict(
"policy_path" => "antmaze_large/antmaze_large_dapg_6.pkl",
- "task.task_names" => [
- "antmaze-large-play-v0",
- "antmaze-large-diverse-v0"
- ],
+ "task.task_names" => ["antmaze-large-play-v0", "antmaze-large-diverse-v0"],
"agent_name" => "BC",
"return_mean" => 0.22,
- "return_std =>" => 0.41424630354415964
+ "return_std =>" => 0.41424630354415964,
),
Dict(
"policy_path" => "antmaze_large/antmaze_large_dapg_7.pkl",
- "task.task_names" => [
- "antmaze-large-play-v0",
- "antmaze-large-diverse-v0"
- ],
+ "task.task_names" => ["antmaze-large-play-v0", "antmaze-large-diverse-v0"],
"agent_name" => "BC",
"return_mean" => 0.12,
- "return_std =>" => 0.32496153618543844
+ "return_std =>" => 0.32496153618543844,
),
Dict(
"policy_path" => "antmaze_large/antmaze_large_dapg_8.pkl",
- "task.task_names" => [
- "antmaze-large-play-v0",
- "antmaze-large-diverse-v0"
- ],
+ "task.task_names" => ["antmaze-large-play-v0", "antmaze-large-diverse-v0"],
"agent_name" => "BC",
"return_mean" => 0.39,
- "return_std =>" => 0.487749935930288
+ "return_std =>" => 0.487749935930288,
),
Dict(
"policy_path" => "antmaze_large/antmaze_large_dapg_9.pkl",
- "task.task_names" => [
- "antmaze-large-play-v0",
- "antmaze-large-diverse-v0"
- ],
+ "task.task_names" => ["antmaze-large-play-v0", "antmaze-large-diverse-v0"],
"agent_name" => "BC",
"return_mean" => 0.49,
- "return_std =>" => 0.4998999899979995
+ "return_std =>" => 0.4998999899979995,
),
Dict(
"policy_path" => "antmaze_medium/antmaze_medium_dapg_0.pkl",
- "task.task_names" => [
- "antmaze-medium-play-v0",
- "antmaze-medium-diverse-v0"
- ],
+ "task.task_names" => ["antmaze-medium-play-v0", "antmaze-medium-diverse-v0"],
"agent_name" => "DAPG",
"return_mean" => 0.66,
- "return_std =>" => 0.4737087712930805
+ "return_std =>" => 0.4737087712930805,
),
Dict(
"policy_path" => "antmaze_medium/antmaze_medium_dapg_10.pkl",
- "task.task_names" => [
- "antmaze-medium-play-v0",
- "antmaze-medium-diverse-v0"
- ],
+ "task.task_names" => ["antmaze-medium-play-v0", "antmaze-medium-diverse-v0"],
"agent_name" => "DAPG",
"return_mean" => 0.12,
- "return_std =>" => 0.32496153618543844
+ "return_std =>" => 0.32496153618543844,
),
Dict(
"policy_path" => "antmaze_medium/antmaze_medium_dapg_1.pkl",
- "task.task_names" => [
- "antmaze-medium-play-v0",
- "antmaze-medium-diverse-v0"
- ],
+ "task.task_names" => ["antmaze-medium-play-v0", "antmaze-medium-diverse-v0"],
"agent_name" => "DAPG",
"return_mean" => 0.53,
- "return_std =>" => 0.49909918853871116
+ "return_std =>" => 0.49909918853871116,
),
Dict(
"policy_path" => "antmaze_medium/antmaze_medium_dapg_2.pkl",
- "task.task_names" => [
- "antmaze-medium-play-v0",
- "antmaze-medium-diverse-v0"
- ],
+ "task.task_names" => ["antmaze-medium-play-v0", "antmaze-medium-diverse-v0"],
"agent_name" => "DAPG",
"return_mean" => 0.66,
- "return_std =>" => 0.4737087712930805
+ "return_std =>" => 0.4737087712930805,
),
Dict(
"policy_path" => "antmaze_medium/antmaze_medium_dapg_3.pkl",
- "task.task_names" => [
- "antmaze-medium-play-v0",
- "antmaze-medium-diverse-v0"
- ],
+ "task.task_names" => ["antmaze-medium-play-v0", "antmaze-medium-diverse-v0"],
"agent_name" => "DAPG",
"return_mean" => 0.57,
- "return_std =>" => 0.49507575177946245
+ "return_std =>" => 0.49507575177946245,
),
Dict(
"policy_path" => "antmaze_medium/antmaze_medium_dapg_4.pkl",
- "task.task_names" => [
- "antmaze-medium-play-v0",
- "antmaze-medium-diverse-v0"
- ],
+ "task.task_names" => ["antmaze-medium-play-v0", "antmaze-medium-diverse-v0"],
"agent_name" => "DAPG",
"return_mean" => 0.58,
- "return_std =>" => 0.49355850717012273
+ "return_std =>" => 0.49355850717012273,
),
Dict(
"policy_path" => "antmaze_medium/antmaze_medium_dapg_5.pkl",
- "task.task_names" => [
- "antmaze-medium-play-v0",
- "antmaze-medium-diverse-v0"
- ],
+ "task.task_names" => ["antmaze-medium-play-v0", "antmaze-medium-diverse-v0"],
"agent_name" => "DAPG",
"return_mean" => 0.42,
- "return_std =>" => 0.49355850717012273
+ "return_std =>" => 0.49355850717012273,
),
Dict(
"policy_path" => "antmaze_medium/antmaze_medium_dapg_6.pkl",
- "task.task_names" => [
- "antmaze-medium-play-v0",
- "antmaze-medium-diverse-v0"
- ],
+ "task.task_names" => ["antmaze-medium-play-v0", "antmaze-medium-diverse-v0"],
"agent_name" => "DAPG",
"return_mean" => 0.45,
- "return_std =>" => 0.49749371855331004
+ "return_std =>" => 0.49749371855331004,
),
Dict(
"policy_path" => "antmaze_medium/antmaze_medium_dapg_7.pkl",
- "task.task_names" => [
- "antmaze-medium-play-v0",
- "antmaze-medium-diverse-v0"
- ],
+ "task.task_names" => ["antmaze-medium-play-v0", "antmaze-medium-diverse-v0"],
"agent_name" => "DAPG",
"return_mean" => 0.27,
- "return_std =>" => 0.4439594576084623
+ "return_std =>" => 0.4439594576084623,
),
Dict(
"policy_path" => "antmaze_medium/antmaze_medium_dapg_8.pkl",
- "task.task_names" => [
- "antmaze-medium-play-v0",
- "antmaze-medium-diverse-v0"
- ],
+ "task.task_names" => ["antmaze-medium-play-v0", "antmaze-medium-diverse-v0"],
"agent_name" => "DAPG",
"return_mean" => 0.1,
- "return_std =>" => 0.29999999999999993
+ "return_std =>" => 0.29999999999999993,
),
Dict(
"policy_path" => "antmaze_medium/antmaze_medium_dapg_9.pkl",
- "task.task_names" => [
- "antmaze-medium-play-v0",
- "antmaze-medium-diverse-v0"
- ],
+ "task.task_names" => ["antmaze-medium-play-v0", "antmaze-medium-diverse-v0"],
"agent_name" => "DAPG",
"return_mean" => 0.15,
- "return_std =>" => 0.3570714214271425
+ "return_std =>" => 0.3570714214271425,
),
Dict(
"policy_path" => "antmaze_umaze/antmaze_umaze_dapg_0.pkl",
- "task.task_names" => [
- "antmaze-umaze-v0"
- ],
+ "task.task_names" => ["antmaze-umaze-v0"],
"agent_name" => "DAPG",
"return_mean" => 0.11,
- "return_std =>" => 0.31288975694324034
+ "return_std =>" => 0.31288975694324034,
),
Dict(
"policy_path" => "antmaze_umaze/antmaze_umaze_dapg_10.pkl",
- "task.task_names" => [
- "antmaze-umaze-v0"
- ],
+ "task.task_names" => ["antmaze-umaze-v0"],
"agent_name" => "DAPG",
"return_mean" => 0.84,
- "return_std =>" => 0.36660605559646725
+ "return_std =>" => 0.36660605559646725,
),
Dict(
"policy_path" => "antmaze_umaze/antmaze_umaze_dapg_1.pkl",
- "task.task_names" => [
- "antmaze-umaze-v0"
- ],
+ "task.task_names" => ["antmaze-umaze-v0"],
"agent_name" => "DAPG",
"return_mean" => 0.15,
- "return_std =>" => 0.3570714214271425
+ "return_std =>" => 0.3570714214271425,
),
Dict(
"policy_path" => "antmaze_umaze/antmaze_umaze_dapg_2.pkl",
- "task.task_names" => [
- "antmaze-umaze-v0"
- ],
+ "task.task_names" => ["antmaze-umaze-v0"],
"agent_name" => "DAPG",
"return_mean" => 0.08,
- "return_std =>" => 0.2712931993250107
+ "return_std =>" => 0.2712931993250107,
),
Dict(
"policy_path" => "antmaze_umaze/antmaze_umaze_dapg_3.pkl",
- "task.task_names" => [
- "antmaze-umaze-v0"
- ],
+ "task.task_names" => ["antmaze-umaze-v0"],
"agent_name" => "DAPG",
"return_mean" => 0.13,
- "return_std =>" => 0.33630343441600474
+ "return_std =>" => 0.33630343441600474,
),
Dict(
"policy_path" => "antmaze_umaze/antmaze_umaze_dapg_4.pkl",
- "task.task_names" => [
- "antmaze-umaze-v0"
- ],
+ "task.task_names" => ["antmaze-umaze-v0"],
"agent_name" => "DAPG",
"return_mean" => 0.19,
- "return_std =>" => 0.3923009049186606
+ "return_std =>" => 0.3923009049186606,
),
Dict(
"policy_path" => "antmaze_umaze/antmaze_umaze_dapg_5.pkl",
- "task.task_names" => [
- "antmaze-umaze-v0"
- ],
+ "task.task_names" => ["antmaze-umaze-v0"],
"agent_name" => "DAPG",
"return_mean" => 0.27,
- "return_std =>" => 0.4439594576084623
+ "return_std =>" => 0.4439594576084623,
),
Dict(
"policy_path" => "antmaze_umaze/antmaze_umaze_dapg_6.pkl",
- "task.task_names" => [
- "antmaze-umaze-v0"
- ],
+ "task.task_names" => ["antmaze-umaze-v0"],
"agent_name" => "DAPG",
"return_mean" => 0.41,
- "return_std =>" => 0.4918333050943175
+ "return_std =>" => 0.4918333050943175,
),
Dict(
"policy_path" => "antmaze_umaze/antmaze_umaze_dapg_7.pkl",
- "task.task_names" => [
- "antmaze-umaze-v0"
- ],
+ "task.task_names" => ["antmaze-umaze-v0"],
"agent_name" => "DAPG",
"return_mean" => 0.66,
- "return_std =>" => 0.4737087712930805
+ "return_std =>" => 0.4737087712930805,
),
Dict(
"policy_path" => "antmaze_umaze/antmaze_umaze_dapg_8.pkl",
- "task.task_names" => [
- "antmaze-umaze-v0"
- ],
+ "task.task_names" => ["antmaze-umaze-v0"],
"agent_name" => "DAPG",
"return_mean" => 0.72,
- "return_std =>" => 0.4489988864128729
+ "return_std =>" => 0.4489988864128729,
),
Dict(
"policy_path" => "antmaze_umaze/antmaze_umaze_dapg_9.pkl",
- "task.task_names" => [
- "antmaze-umaze-v0"
- ],
+ "task.task_names" => ["antmaze-umaze-v0"],
"agent_name" => "DAPG",
"return_mean" => 0.45,
- "return_std =>" => 0.49749371855331
+ "return_std =>" => 0.49749371855331,
),
Dict(
"policy_path" => "ant/ant_online_0.pkl",
@@ -336,11 +248,11 @@ const D4RL_POLICIES = [
"ant-random-v0",
"ant-expert-v0",
"ant-medium-replay-v0",
- "ant-medium-expert-v0"
+ "ant-medium-expert-v0",
],
"agent_name" => "SAC",
"return_mean" => -61.02055183305979,
- "return_std =>" => 118.86259895376526
+ "return_std =>" => 118.86259895376526,
),
Dict(
"policy_path" => "ant/ant_online_10.pkl",
@@ -349,11 +261,11 @@ const D4RL_POLICIES = [
"ant-random-v0",
"ant-expert-v0",
"ant-medium-replay-v0",
- "ant-medium-expert-v0"
+ "ant-medium-expert-v0",
],
"agent_name" => "SAC",
"return_mean" => 5226.071929273204,
- "return_std =>" => 1351.489114884685
+ "return_std =>" => 1351.489114884685,
),
Dict(
"policy_path" => "ant/ant_online_1.pkl",
@@ -362,11 +274,11 @@ const D4RL_POLICIES = [
"ant-random-v0",
"ant-expert-v0",
"ant-medium-replay-v0",
- "ant-medium-expert-v0"
+ "ant-medium-expert-v0",
],
"agent_name" => "SAC",
"return_mean" => 1128.957315814236,
- "return_std =>" => 545.9910621405912
+ "return_std =>" => 545.9910621405912,
),
Dict(
"policy_path" => "ant/ant_online_2.pkl",
@@ -375,11 +287,11 @@ const D4RL_POLICIES = [
"ant-random-v0",
"ant-expert-v0",
"ant-medium-replay-v0",
- "ant-medium-expert-v0"
+ "ant-medium-expert-v0",
],
"agent_name" => "SAC",
"return_mean" => 1874.9426222623788,
- "return_std =>" => 821.523301172575
+ "return_std =>" => 821.523301172575,
),
Dict(
"policy_path" => "ant/ant_online_3.pkl",
@@ -388,11 +300,11 @@ const D4RL_POLICIES = [
"ant-random-v0",
"ant-expert-v0",
"ant-medium-replay-v0",
- "ant-medium-expert-v0"
+ "ant-medium-expert-v0",
],
"agent_name" => "SAC",
"return_mean" => 2694.0050365558186,
- "return_std =>" => 829.1251729756312
+ "return_std =>" => 829.1251729756312,
),
Dict(
"policy_path" => "ant/ant_online_4.pkl",
@@ -401,11 +313,11 @@ const D4RL_POLICIES = [
"ant-random-v0",
"ant-expert-v0",
"ant-medium-replay-v0",
- "ant-medium-expert-v0"
+ "ant-medium-expert-v0",
],
"agent_name" => "SAC",
"return_mean" => 2927.728155987557,
- "return_std =>" => 1218.962159178784
+ "return_std =>" => 1218.962159178784,
),
Dict(
"policy_path" => "ant/ant_online_5.pkl",
@@ -414,11 +326,11 @@ const D4RL_POLICIES = [
"ant-random-v0",
"ant-expert-v0",
"ant-medium-replay-v0",
- "ant-medium-expert-v0"
+ "ant-medium-expert-v0",
],
"agent_name" => "SAC",
"return_mean" => -271.0455967662947,
- "return_std =>" => 181.7343490946006
+ "return_std =>" => 181.7343490946006,
),
Dict(
"policy_path" => "ant/ant_online_6.pkl",
@@ -427,11 +339,11 @@ const D4RL_POLICIES = [
"ant-random-v0",
"ant-expert-v0",
"ant-medium-replay-v0",
- "ant-medium-expert-v0"
+ "ant-medium-expert-v0",
],
"agent_name" => "SAC",
"return_mean" => 3923.0820284011284,
- "return_std =>" => 1384.459574872169
+ "return_std =>" => 1384.459574872169,
),
Dict(
"policy_path" => "ant/ant_online_7.pkl",
@@ -440,11 +352,11 @@ const D4RL_POLICIES = [
"ant-random-v0",
"ant-expert-v0",
"ant-medium-replay-v0",
- "ant-medium-expert-v0"
+ "ant-medium-expert-v0",
],
"agent_name" => "SAC",
"return_mean" => 4564.024787293475,
- "return_std =>" => 1207.181426135141
+ "return_std =>" => 1207.181426135141,
),
Dict(
"policy_path" => "ant/ant_online_8.pkl",
@@ -453,11 +365,11 @@ const D4RL_POLICIES = [
"ant-random-v0",
"ant-expert-v0",
"ant-medium-replay-v0",
- "ant-medium-expert-v0"
+ "ant-medium-expert-v0",
],
"agent_name" => "SAC",
"return_mean" => 5116.58562094113,
- "return_std =>" => 962.8694737383373
+ "return_std =>" => 962.8694737383373,
),
Dict(
"policy_path" => "ant/ant_online_9.pkl",
@@ -466,132 +378,88 @@ const D4RL_POLICIES = [
"ant-random-v0",
"ant-expert-v0",
"ant-medium-replay-v0",
- "ant-medium-expert-v0"
+ "ant-medium-expert-v0",
],
"agent_name" => "SAC",
"return_mean" => 5176.548960934259,
- "return_std =>" => 1000.122269767824
+ "return_std =>" => 1000.122269767824,
),
Dict(
"policy_path" => "door/door_dapg_0.pkl",
- "task.task_names" => [
- "door-cloned-v0",
- "door-expert-v0",
- "door-human-v0"
- ],
+ "task.task_names" => ["door-cloned-v0", "door-expert-v0", "door-human-v0"],
"agent_name" => "DAPG",
"return_mean" => -53.63337645679012,
- "return_std =>" => 2.0058239428094895
+ "return_std =>" => 2.0058239428094895,
),
Dict(
"policy_path" => "door/door_dapg_10.pkl",
- "task.task_names" => [
- "door-cloned-v0",
- "door-expert-v0",
- "door-human-v0"
- ],
+ "task.task_names" => ["door-cloned-v0", "door-expert-v0", "door-human-v0"],
"agent_name" => "DAPG",
"return_mean" => 2974.9306587121887,
- "return_std =>" => 52.48250668645121
+ "return_std =>" => 52.48250668645121,
),
Dict(
"policy_path" => "door/door_dapg_1.pkl",
- "task.task_names" => [
- "door-cloned-v0",
- "door-expert-v0",
- "door-human-v0"
- ],
+ "task.task_names" => ["door-cloned-v0", "door-expert-v0", "door-human-v0"],
"agent_name" => "DAPG",
"return_mean" => -51.41658735064874,
- "return_std =>" => 0.6978335854285623
+ "return_std =>" => 0.6978335854285623,
),
Dict(
"policy_path" => "door/door_dapg_2.pkl",
- "task.task_names" => [
- "door-cloned-v0",
- "door-expert-v0",
- "door-human-v0"
- ],
+ "task.task_names" => ["door-cloned-v0", "door-expert-v0", "door-human-v0"],
"agent_name" => "DAPG",
"return_mean" => 86.28632719532406,
- "return_std =>" => 256.30747202806475
+ "return_std =>" => 256.30747202806475,
),
Dict(
"policy_path" => "door/door_dapg_3.pkl",
- "task.task_names" => [
- "door-cloned-v0",
- "door-expert-v0",
- "door-human-v0"
- ],
+ "task.task_names" => ["door-cloned-v0", "door-expert-v0", "door-human-v0"],
"agent_name" => "DAPG",
"return_mean" => 1282.0275007615646,
- "return_std =>" => 633.9669441391286
+ "return_std =>" => 633.9669441391286,
),
Dict(
"policy_path" => "door/door_dapg_4.pkl",
- "task.task_names" => [
- "door-cloned-v0",
- "door-expert-v0",
- "door-human-v0"
- ],
+ "task.task_names" => ["door-cloned-v0", "door-expert-v0", "door-human-v0"],
"agent_name" => "DAPG",
"return_mean" => 1607.4255566289276,
- "return_std =>" => 499.58651630841575
+ "return_std =>" => 499.58651630841575,
),
Dict(
"policy_path" => "door/door_dapg_5.pkl",
- "task.task_names" => [
- "door-cloned-v0",
- "door-expert-v0",
- "door-human-v0"
- ],
+ "task.task_names" => ["door-cloned-v0", "door-expert-v0", "door-human-v0"],
"agent_name" => "DAPG",
"return_mean" => 2142.36638691816,
- "return_std =>" => 442.0537003890031
+ "return_std =>" => 442.0537003890031,
),
Dict(
"policy_path" => "door/door_dapg_6.pkl",
- "task.task_names" => [
- "door-cloned-v0",
- "door-expert-v0",
- "door-human-v0"
- ],
+ "task.task_names" => ["door-cloned-v0", "door-expert-v0", "door-human-v0"],
"agent_name" => "DAPG",
"return_mean" => 2525.495218483574,
- "return_std =>" => 160.8683834534215
+ "return_std =>" => 160.8683834534215,
),
Dict(
"policy_path" => "door/door_dapg_7.pkl",
- "task.task_names" => [
- "door-cloned-v0",
- "door-expert-v0",
- "door-human-v0"
- ],
+ "task.task_names" => ["door-cloned-v0", "door-expert-v0", "door-human-v0"],
"agent_name" => "DAPG",
"return_mean" => 2794.653907232321,
- "return_std =>" => 62.78226619278986
+ "return_std =>" => 62.78226619278986,
),
Dict(
"policy_path" => "door/door_dapg_8.pkl",
- "task.task_names" => [
- "door-cloned-v0",
- "door-expert-v0",
- "door-human-v0"
- ],
+ "task.task_names" => ["door-cloned-v0", "door-expert-v0", "door-human-v0"],
"agent_name" => "DAPG",
"return_mean" => 2870.85173247603,
- "return_std =>" => 37.96052715176604
+ "return_std =>" => 37.96052715176604,
),
Dict(
"policy_path" => "door/door_dapg_9.pkl",
- "task.task_names" => [
- "door-cloned-v0",
- "door-expert-v0",
- "door-human-v0"
- ],
+ "task.task_names" => ["door-cloned-v0", "door-expert-v0", "door-human-v0"],
"agent_name" => "DAPG",
"return_mean" => 2959.4718836123457,
- "return_std =>" => 53.31391818495784
+ "return_std =>" => 53.31391818495784,
),
Dict(
"policy_path" => "halfcheetah/halfcheetah_online_0.pkl",
@@ -600,11 +468,11 @@ const D4RL_POLICIES = [
"halfcheetah-random-v0",
"halfcheetah-expert-v0",
"halfcheetah-medium-replay-v0",
- "halfcheetah-medium-expert-v0"
+ "halfcheetah-medium-expert-v0",
],
"agent_name" => "SAC",
"return_mean" => -309.2417932614121,
- "return_std =>" => 91.3640277992432
+ "return_std =>" => 91.3640277992432,
),
Dict(
"policy_path" => "halfcheetah/halfcheetah_online_10.pkl",
@@ -613,11 +481,11 @@ const D4RL_POLICIES = [
"halfcheetah-random-v0",
"halfcheetah-expert-v0",
"halfcheetah-medium-replay-v0",
- "halfcheetah-medium-expert-v0"
+ "halfcheetah-medium-expert-v0",
],
"agent_name" => "SAC",
"return_mean" => 12695.696030461002,
- "return_std =>" => 209.98612023443096
+ "return_std =>" => 209.98612023443096,
),
Dict(
"policy_path" => "halfcheetah/halfcheetah_online_1.pkl",
@@ -626,11 +494,11 @@ const D4RL_POLICIES = [
"halfcheetah-random-v0",
"halfcheetah-expert-v0",
"halfcheetah-medium-replay-v0",
- "halfcheetah-medium-expert-v0"
+ "halfcheetah-medium-expert-v0",
],
"agent_name" => "SAC",
"return_mean" => 5686.148033603298,
- "return_std =>" => 77.60317050580818
+ "return_std =>" => 77.60317050580818,
),
Dict(
"policy_path" => "halfcheetah/halfcheetah_online_2.pkl",
@@ -639,11 +507,11 @@ const D4RL_POLICIES = [
"halfcheetah-random-v0",
"halfcheetah-expert-v0",
"halfcheetah-medium-replay-v0",
- "halfcheetah-medium-expert-v0"
+ "halfcheetah-medium-expert-v0",
],
"agent_name" => "SAC",
"return_mean" => 6898.252473142946,
- "return_std =>" => 131.2808199171071
+ "return_std =>" => 131.2808199171071,
),
Dict(
"policy_path" => "halfcheetah/halfcheetah_online_3.pkl",
@@ -652,11 +520,11 @@ const D4RL_POLICIES = [
"halfcheetah-random-v0",
"halfcheetah-expert-v0",
"halfcheetah-medium-replay-v0",
- "halfcheetah-medium-expert-v0"
+ "halfcheetah-medium-expert-v0",
],
"agent_name" => "SAC",
"return_mean" => 7843.345957832609,
- "return_std =>" => 119.82879594969056
+ "return_std =>" => 119.82879594969056,
),
Dict(
"policy_path" => "halfcheetah/halfcheetah_online_4.pkl",
@@ -665,11 +533,11 @@ const D4RL_POLICIES = [
"halfcheetah-random-v0",
"halfcheetah-expert-v0",
"halfcheetah-medium-replay-v0",
- "halfcheetah-medium-expert-v0"
+ "halfcheetah-medium-expert-v0",
],
"agent_name" => "SAC",
"return_mean" => 8661.367146815282,
- "return_std =>" => 142.1433195543218
+ "return_std =>" => 142.1433195543218,
),
Dict(
"policy_path" => "halfcheetah/halfcheetah_online_5.pkl",
@@ -678,11 +546,11 @@ const D4RL_POLICIES = [
"halfcheetah-random-v0",
"halfcheetah-expert-v0",
"halfcheetah-medium-replay-v0",
- "halfcheetah-medium-expert-v0"
+ "halfcheetah-medium-expert-v0",
],
"agent_name" => "SAC",
"return_mean" => 9197.889639800613,
- "return_std =>" => 125.40543058761767
+ "return_std =>" => 125.40543058761767,
),
Dict(
"policy_path" => "halfcheetah/halfcheetah_online_6.pkl",
@@ -691,11 +559,11 @@ const D4RL_POLICIES = [
"halfcheetah-random-v0",
"halfcheetah-expert-v0",
"halfcheetah-medium-replay-v0",
- "halfcheetah-medium-expert-v0"
+ "halfcheetah-medium-expert-v0",
],
"agent_name" => "SAC",
"return_mean" => 9623.789519132608,
- "return_std =>" => 130.91946985245835
+ "return_std =>" => 130.91946985245835,
),
Dict(
"policy_path" => "halfcheetah/halfcheetah_online_7.pkl",
@@ -704,11 +572,11 @@ const D4RL_POLICIES = [
"halfcheetah-random-v0",
"halfcheetah-expert-v0",
"halfcheetah-medium-replay-v0",
- "halfcheetah-medium-expert-v0"
+ "halfcheetah-medium-expert-v0",
],
"agent_name" => "SAC",
"return_mean" => 10255.26711299773,
- "return_std =>" => 173.52116806555978
+ "return_std =>" => 173.52116806555978,
),
Dict(
"policy_path" => "halfcheetah/halfcheetah_online_8.pkl",
@@ -717,11 +585,11 @@ const D4RL_POLICIES = [
"halfcheetah-random-v0",
"halfcheetah-expert-v0",
"halfcheetah-medium-replay-v0",
- "halfcheetah-medium-expert-v0"
+ "halfcheetah-medium-expert-v0",
],
"agent_name" => "SAC",
"return_mean" => 10899.460856799158,
- "return_std =>" => 324.2557642475202
+ "return_std =>" => 324.2557642475202,
),
Dict(
"policy_path" => "halfcheetah/halfcheetah_online_9.pkl",
@@ -730,132 +598,99 @@ const D4RL_POLICIES = [
"halfcheetah-random-v0",
"halfcheetah-expert-v0",
"halfcheetah-medium-replay-v0",
- "halfcheetah-medium-expert-v0"
+ "halfcheetah-medium-expert-v0",
],
"agent_name" => "SAC",
"return_mean" => 11829.054827593913,
- "return_std =>" => 240.63510160394745
+ "return_std =>" => 240.63510160394745,
),
Dict(
"policy_path" => "hammer/hammer_dapg_0.pkl",
- "task.task_names" => [
- "hammer-cloned-v0",
- "hammer-expert-v0",
- "hammer-human-v0"
- ],
+ "task.task_names" =>
+ ["hammer-cloned-v0", "hammer-expert-v0", "hammer-human-v0"],
"agent_name" => "DAPG",
"return_mean" => -236.37114898868305,
- "return_std =>" => 5.2941436284324075
+ "return_std =>" => 5.2941436284324075,
),
Dict(
"policy_path" => "hammer/hammer_dapg_10.pkl",
- "task.task_names" => [
- "hammer-cloned-v0",
- "hammer-expert-v0",
- "hammer-human-v0"
- ],
+ "task.task_names" =>
+ ["hammer-cloned-v0", "hammer-expert-v0", "hammer-human-v0"],
"agent_name" => "DAPG",
"return_mean" => 17585.58837262877,
- "return_std =>" => 96.53489547795978
+ "return_std =>" => 96.53489547795978,
),
Dict(
"policy_path" => "hammer/hammer_dapg_1.pkl",
- "task.task_names" => [
- "hammer-cloned-v0",
- "hammer-expert-v0",
- "hammer-human-v0"
- ],
+ "task.task_names" =>
+ ["hammer-cloned-v0", "hammer-expert-v0", "hammer-human-v0"],
"agent_name" => "DAPG",
"return_mean" => 128.60395654435058,
- "return_std =>" => 30.68441678661929
+ "return_std =>" => 30.68441678661929,
),
Dict(
"policy_path" => "hammer/hammer_dapg_2.pkl",
- "task.task_names" => [
- "hammer-cloned-v0",
- "hammer-expert-v0",
- "hammer-human-v0"
- ],
+ "task.task_names" =>
+ ["hammer-cloned-v0", "hammer-expert-v0", "hammer-human-v0"],
"agent_name" => "DAPG",
"return_mean" => 7408.354956936379,
- "return_std =>" => 7294.096332941535
+ "return_std =>" => 7294.096332941535,
),
Dict(
"policy_path" => "hammer/hammer_dapg_3.pkl",
- "task.task_names" => [
- "hammer-cloned-v0",
- "hammer-expert-v0",
- "hammer-human-v0"
- ],
+ "task.task_names" =>
+ ["hammer-cloned-v0", "hammer-expert-v0", "hammer-human-v0"],
"agent_name" => "DAPG",
"return_mean" => 15594.112899701715,
- "return_std =>" => 197.28904701529942
+ "return_std =>" => 197.28904701529942,
),
Dict(
"policy_path" => "hammer/hammer_dapg_4.pkl",
- "task.task_names" => [
- "hammer-cloned-v0",
- "hammer-expert-v0",
- "hammer-human-v0"
- ],
+ "task.task_names" =>
+ ["hammer-cloned-v0", "hammer-expert-v0", "hammer-human-v0"],
"agent_name" => "DAPG",
"return_mean" => 16245.548923178216,
- "return_std =>" => 262.7060238728634
+ "return_std =>" => 262.7060238728634,
),
Dict(
"policy_path" => "hammer/hammer_dapg_5.pkl",
- "task.task_names" => [
- "hammer-cloned-v0",
- "hammer-expert-v0",
- "hammer-human-v0"
- ],
+ "task.task_names" =>
+ ["hammer-cloned-v0", "hammer-expert-v0", "hammer-human-v0"],
"agent_name" => "DAPG",
"return_mean" => 16595.136728219404,
- "return_std =>" => 124.5270089215883
+ "return_std =>" => 124.5270089215883,
),
Dict(
"policy_path" => "hammer/hammer_dapg_6.pkl",
- "task.task_names" => [
- "hammer-cloned-v0",
- "hammer-expert-v0",
- "hammer-human-v0"
- ],
+ "task.task_names" =>
+ ["hammer-cloned-v0", "hammer-expert-v0", "hammer-human-v0"],
"agent_name" => "DAPG",
"return_mean" => 17065.590900836418,
- "return_std =>" => 55.85140116556182
+ "return_std =>" => 55.85140116556182,
),
Dict(
"policy_path" => "hammer/hammer_dapg_7.pkl",
- "task.task_names" => [
- "hammer-cloned-v0",
- "hammer-expert-v0",
- "hammer-human-v0"
- ],
+ "task.task_names" =>
+ ["hammer-cloned-v0", "hammer-expert-v0", "hammer-human-v0"],
"agent_name" => "DAPG",
"return_mean" => 17209.380445590097,
- "return_std =>" => 35.922080086069116
+ "return_std =>" => 35.922080086069116,
),
Dict(
"policy_path" => "hammer/hammer_dapg_8.pkl",
- "task.task_names" => [
- "hammer-cloned-v0",
- "hammer-expert-v0",
- "hammer-human-v0"
- ],
+ "task.task_names" =>
+ ["hammer-cloned-v0", "hammer-expert-v0", "hammer-human-v0"],
"agent_name" => "DAPG",
"return_mean" => 17388.10343669515,
- "return_std =>" => 71.04818789434533
+ "return_std =>" => 71.04818789434533,
),
Dict(
"policy_path" => "hammer/hammer_dapg_9.pkl",
- "task.task_names" => [
- "hammer-cloned-v0",
- "hammer-expert-v0",
- "hammer-human-v0"
- ],
+ "task.task_names" =>
+ ["hammer-cloned-v0", "hammer-expert-v0", "hammer-human-v0"],
"agent_name" => "DAPG",
"return_mean" => 17565.807571496796,
- "return_std =>" => 83.22119300427666
+ "return_std =>" => 83.22119300427666,
),
Dict(
"policy_path" => "hopper/hopper_online_0.pkl",
@@ -864,11 +699,11 @@ const D4RL_POLICIES = [
"hopper-random-v0",
"hopper-expert-v0",
"hopper-medium-replay-v0",
- "hopper-medium-expert-v0"
+ "hopper-medium-expert-v0",
],
"agent_name" => "SAC",
"return_mean" => 89.08207455972816,
- "return_std =>" => 45.69740377810402
+ "return_std =>" => 45.69740377810402,
),
Dict(
"policy_path" => "hopper/hopper_online_10.pkl",
@@ -877,11 +712,11 @@ const D4RL_POLICIES = [
"hopper-random-v0",
"hopper-expert-v0",
"hopper-medium-replay-v0",
- "hopper-medium-expert-v0"
+ "hopper-medium-expert-v0",
],
"agent_name" => "SAC",
"return_mean" => 1290.7677147248753,
- "return_std =>" => 86.34701290680572
+ "return_std =>" => 86.34701290680572,
),
Dict(
"policy_path" => "hopper/hopper_online_1.pkl",
@@ -890,11 +725,11 @@ const D4RL_POLICIES = [
"hopper-random-v0",
"hopper-expert-v0",
"hopper-medium-replay-v0",
- "hopper-medium-expert-v0"
+ "hopper-medium-expert-v0",
],
"agent_name" => "SAC",
"return_mean" => 1134.244611055915,
- "return_std =>" => 407.6547443287992
+ "return_std =>" => 407.6547443287992,
),
Dict(
"policy_path" => "hopper/hopper_online_2.pkl",
@@ -903,11 +738,11 @@ const D4RL_POLICIES = [
"hopper-random-v0",
"hopper-expert-v0",
"hopper-medium-replay-v0",
- "hopper-medium-expert-v0"
+ "hopper-medium-expert-v0",
],
"agent_name" => "SAC",
"return_mean" => 727.0768143435397,
- "return_std =>" => 92.94955320157855
+ "return_std =>" => 92.94955320157855,
),
Dict(
"policy_path" => "hopper/hopper_online_3.pkl",
@@ -916,11 +751,11 @@ const D4RL_POLICIES = [
"hopper-random-v0",
"hopper-expert-v0",
"hopper-medium-replay-v0",
- "hopper-medium-expert-v0"
+ "hopper-medium-expert-v0",
],
"agent_name" => "SAC",
"return_mean" => 1571.2810005160163,
- "return_std =>" => 447.3216244940128
+ "return_std =>" => 447.3216244940128,
),
Dict(
"policy_path" => "hopper/hopper_online_4.pkl",
@@ -929,11 +764,11 @@ const D4RL_POLICIES = [
"hopper-random-v0",
"hopper-expert-v0",
"hopper-medium-replay-v0",
- "hopper-medium-expert-v0"
+ "hopper-medium-expert-v0",
],
"agent_name" => "SAC",
"return_mean" => 1140.2394986005213,
- "return_std =>" => 671.1379607505328
+ "return_std =>" => 671.1379607505328,
),
Dict(
"policy_path" => "hopper/hopper_online_5.pkl",
@@ -942,11 +777,11 @@ const D4RL_POLICIES = [
"hopper-random-v0",
"hopper-expert-v0",
"hopper-medium-replay-v0",
- "hopper-medium-expert-v0"
+ "hopper-medium-expert-v0",
],
"agent_name" => "SAC",
"return_mean" => 1872.571834592923,
- "return_std =>" => 793.8865779126361
+ "return_std =>" => 793.8865779126361,
),
Dict(
"policy_path" => "hopper/hopper_online_6.pkl",
@@ -955,11 +790,11 @@ const D4RL_POLICIES = [
"hopper-random-v0",
"hopper-expert-v0",
"hopper-medium-replay-v0",
- "hopper-medium-expert-v0"
+ "hopper-medium-expert-v0",
],
"agent_name" => "SAC",
"return_mean" => 3088.2017624993064,
- "return_std =>" => 356.52713477862386
+ "return_std =>" => 356.52713477862386,
),
Dict(
"policy_path" => "hopper/hopper_online_7.pkl",
@@ -968,11 +803,11 @@ const D4RL_POLICIES = [
"hopper-random-v0",
"hopper-expert-v0",
"hopper-medium-replay-v0",
- "hopper-medium-expert-v0"
+ "hopper-medium-expert-v0",
],
"agent_name" => "SAC",
"return_mean" => 1726.0060438089222,
- "return_std =>" => 761.6326666292086
+ "return_std =>" => 761.6326666292086,
),
Dict(
"policy_path" => "hopper/hopper_online_8.pkl",
@@ -981,11 +816,11 @@ const D4RL_POLICIES = [
"hopper-random-v0",
"hopper-expert-v0",
"hopper-medium-replay-v0",
- "hopper-medium-expert-v0"
+ "hopper-medium-expert-v0",
],
"agent_name" => "SAC",
"return_mean" => 2952.957468938808,
- "return_std =>" => 682.5831907733249
+ "return_std =>" => 682.5831907733249,
),
Dict(
"policy_path" => "hopper/hopper_online_9.pkl",
@@ -994,550 +829,407 @@ const D4RL_POLICIES = [
"hopper-random-v0",
"hopper-expert-v0",
"hopper-medium-replay-v0",
- "hopper-medium-expert-v0"
+ "hopper-medium-expert-v0",
],
"agent_name" => "SAC",
"return_mean" => 2369.7998719150673,
- "return_std =>" => 1119.4914225331481
+ "return_std =>" => 1119.4914225331481,
),
Dict(
"policy_path" => "maze2d_large/maze2d_large_dapg_0.pkl",
- "task.task_names" => [
- "maze2d-large-v1"
- ],
+ "task.task_names" => ["maze2d-large-v1"],
"agent_name" => "DAPG",
"return_mean" => 2.21,
- "return_std =>" => 8.873888662812938
+ "return_std =>" => 8.873888662812938,
),
Dict(
"policy_path" => "maze2d_large/maze2d_large_dapg_10.pkl",
- "task.task_names" => [
- "maze2d-large-v1"
- ],
+ "task.task_names" => ["maze2d-large-v1"],
"agent_name" => "DAPG",
"return_mean" => 627.86,
- "return_std =>" => 161.0254650668645
+ "return_std =>" => 161.0254650668645,
),
Dict(
"policy_path" => "maze2d_large/maze2d_large_dapg_1.pkl",
- "task.task_names" => [
- "maze2d-large-v1"
- ],
+ "task.task_names" => ["maze2d-large-v1"],
"agent_name" => "DAPG",
"return_mean" => 41.74,
- "return_std =>" => 72.2068722491149
+ "return_std =>" => 72.2068722491149,
),
Dict(
"policy_path" => "maze2d_large/maze2d_large_dapg_2.pkl",
- "task.task_names" => [
- "maze2d-large-v1"
- ],
+ "task.task_names" => ["maze2d-large-v1"],
"agent_name" => "DAPG",
"return_mean" => 124.9,
- "return_std =>" => 131.5638628195448
+ "return_std =>" => 131.5638628195448,
),
Dict(
"policy_path" => "maze2d_large/maze2d_large_dapg_3.pkl",
- "task.task_names" => [
- "maze2d-large-v1"
- ],
+ "task.task_names" => ["maze2d-large-v1"],
"agent_name" => "DAPG",
"return_mean" => 107.78,
- "return_std =>" => 109.32251186283638
+ "return_std =>" => 109.32251186283638,
),
Dict(
"policy_path" => "maze2d_large/maze2d_large_dapg_4.pkl",
- "task.task_names" => [
- "maze2d-large-v1"
- ],
+ "task.task_names" => ["maze2d-large-v1"],
"agent_name" => "DAPG",
"return_mean" => 289.46,
- "return_std =>" => 262.69070862898826
+ "return_std =>" => 262.69070862898826,
),
Dict(
"policy_path" => "maze2d_large/maze2d_large_dapg_5.pkl",
- "task.task_names" => [
- "maze2d-large-v1"
- ],
+ "task.task_names" => ["maze2d-large-v1"],
"agent_name" => "DAPG",
"return_mean" => 356.17,
- "return_std =>" => 276.9112151936068
+ "return_std =>" => 276.9112151936068,
),
Dict(
"policy_path" => "maze2d_large/maze2d_large_dapg_6.pkl",
- "task.task_names" => [
- "maze2d-large-v1"
- ],
+ "task.task_names" => ["maze2d-large-v1"],
"agent_name" => "DAPG",
"return_mean" => 393.87,
- "return_std =>" => 309.08651394067647
+ "return_std =>" => 309.08651394067647,
),
Dict(
"policy_path" => "maze2d_large/maze2d_large_dapg_7.pkl",
- "task.task_names" => [
- "maze2d-large-v1"
- ],
+ "task.task_names" => ["maze2d-large-v1"],
"agent_name" => "DAPG",
"return_mean" => 517.4,
- "return_std =>" => 274.58688970888613
+ "return_std =>" => 274.58688970888613,
),
Dict(
"policy_path" => "maze2d_large/maze2d_large_dapg_8.pkl",
- "task.task_names" => [
- "maze2d-large-v1"
- ],
+ "task.task_names" => ["maze2d-large-v1"],
"agent_name" => "DAPG",
"return_mean" => 565.42,
- "return_std =>" => 210.94450360225082
+ "return_std =>" => 210.94450360225082,
),
Dict(
"policy_path" => "maze2d_large/maze2d_large_dapg_9.pkl",
- "task.task_names" => [
- "maze2d-large-v1"
- ],
+ "task.task_names" => ["maze2d-large-v1"],
"agent_name" => "DAPG",
"return_mean" => 629.22,
- "return_std =>" => 123.23023817229276
+ "return_std =>" => 123.23023817229276,
),
Dict(
"policy_path" => "maze2d_medium/maze2d_medium_dapg_0.pkl",
- "task.task_names" => [
- "maze2d-medium-v1"
- ],
+ "task.task_names" => ["maze2d-medium-v1"],
"agent_name" => "DAPG",
"return_mean" => 83.15,
- "return_std =>" => 177.59827561099797
+ "return_std =>" => 177.59827561099797,
),
Dict(
"policy_path" => "maze2d_medium/maze2d_medium_dapg_10.pkl",
- "task.task_names" => [
- "maze2d-medium-v1"
- ],
+ "task.task_names" => ["maze2d-medium-v1"],
"agent_name" => "DAPG",
"return_mean" => 442.35,
- "return_std =>" => 161.2205554512203
+ "return_std =>" => 161.2205554512203,
),
Dict(
"policy_path" => "maze2d_medium/maze2d_medium_dapg_1.pkl",
- "task.task_names" => [
- "maze2d-medium-v1"
- ],
+ "task.task_names" => ["maze2d-medium-v1"],
"agent_name" => "DAPG",
"return_mean" => 177.8,
- "return_std =>" => 218.1089635938881
+ "return_std =>" => 218.1089635938881,
),
Dict(
"policy_path" => "maze2d_medium/maze2d_medium_dapg_2.pkl",
- "task.task_names" => [
- "maze2d-medium-v1"
- ],
+ "task.task_names" => ["maze2d-medium-v1"],
"agent_name" => "DAPG",
"return_mean" => 249.33,
- "return_std =>" => 237.2338110388146
+ "return_std =>" => 237.2338110388146,
),
Dict(
"policy_path" => "maze2d_medium/maze2d_medium_dapg_3.pkl",
- "task.task_names" => [
- "maze2d-medium-v1"
- ],
+ "task.task_names" => ["maze2d-medium-v1"],
"agent_name" => "DAPG",
"return_mean" => 214.81,
- "return_std =>" => 246.09809812349224
+ "return_std =>" => 246.09809812349224,
),
Dict(
"policy_path" => "maze2d_medium/maze2d_medium_dapg_4.pkl",
- "task.task_names" => [
- "maze2d-medium-v1"
- ],
+ "task.task_names" => ["maze2d-medium-v1"],
"agent_name" => "DAPG",
"return_mean" => 254.63,
- "return_std =>" => 262.0181541420365
+ "return_std =>" => 262.0181541420365,
),
Dict(
"policy_path" => "maze2d_medium/maze2d_medium_dapg_5.pkl",
- "task.task_names" => [
- "maze2d-medium-v1"
- ],
+ "task.task_names" => ["maze2d-medium-v1"],
"agent_name" => "DAPG",
"return_mean" => 238.76,
- "return_std =>" => 260.3596404975241
+ "return_std =>" => 260.3596404975241,
),
Dict(
"policy_path" => "maze2d_medium/maze2d_medium_dapg_6.pkl",
- "task.task_names" => [
- "maze2d-medium-v1"
- ],
+ "task.task_names" => ["maze2d-medium-v1"],
"agent_name" => "DAPG",
"return_mean" => 374.9,
- "return_std =>" => 222.18107480161314
+ "return_std =>" => 222.18107480161314,
),
Dict(
"policy_path" => "maze2d_medium/maze2d_medium_dapg_7.pkl",
- "task.task_names" => [
- "maze2d-medium-v1"
- ],
+ "task.task_names" => ["maze2d-medium-v1"],
"agent_name" => "DAPG",
"return_mean" => 379.68,
- "return_std =>" => 228.59111443798514
+ "return_std =>" => 228.59111443798514,
),
Dict(
"policy_path" => "maze2d_medium/maze2d_medium_dapg_8.pkl",
- "task.task_names" => [
- "maze2d-medium-v1"
- ],
+ "task.task_names" => ["maze2d-medium-v1"],
"agent_name" => "DAPG",
"return_mean" => 392.9,
- "return_std =>" => 217.99805044999832
+ "return_std =>" => 217.99805044999832,
),
Dict(
"policy_path" => "maze2d_medium/maze2d_medium_dapg_9.pkl",
- "task.task_names" => [
- "maze2d-medium-v1"
- ],
+ "task.task_names" => ["maze2d-medium-v1"],
"agent_name" => "DAPG",
"return_mean" => 432.03,
- "return_std =>" => 173.93714123211294
+ "return_std =>" => 173.93714123211294,
),
Dict(
"policy_path" => "maze2d_umaze/maze2d_umaze_dapg_0.pkl",
- "task.task_names" => [
- "maze2d-umaze-v1"
- ],
+ "task.task_names" => ["maze2d-umaze-v1"],
"agent_name" => "DAPG",
"return_mean" => 22.19,
- "return_std =>" => 25.18320670605711
+ "return_std =>" => 25.18320670605711,
),
Dict(
"policy_path" => "maze2d_umaze/maze2d_umaze_dapg_10.pkl",
- "task.task_names" => [
- "maze2d-umaze-v1"
- ],
+ "task.task_names" => ["maze2d-umaze-v1"],
"agent_name" => "DAPG",
"return_mean" => 250.64,
- "return_std =>" => 36.357810715168206
+ "return_std =>" => 36.357810715168206,
),
Dict(
"policy_path" => "maze2d_umaze/maze2d_umaze_dapg_1.pkl",
- "task.task_names" => [
- "maze2d-umaze-v1"
- ],
+ "task.task_names" => ["maze2d-umaze-v1"],
"agent_name" => "DAPG",
"return_mean" => 43.33,
- "return_std =>" => 66.01621846182951
+ "return_std =>" => 66.01621846182951,
),
Dict(
"policy_path" => "maze2d_umaze/maze2d_umaze_dapg_2.pkl",
- "task.task_names" => [
- "maze2d-umaze-v1"
- ],
+ "task.task_names" => ["maze2d-umaze-v1"],
"agent_name" => "DAPG",
"return_mean" => 100.97,
- "return_std =>" => 95.598060126762
+ "return_std =>" => 95.598060126762,
),
Dict(
"policy_path" => "maze2d_umaze/maze2d_umaze_dapg_3.pkl",
- "task.task_names" => [
- "maze2d-umaze-v1"
- ],
+ "task.task_names" => ["maze2d-umaze-v1"],
"agent_name" => "DAPG",
"return_mean" => 115.26,
- "return_std =>" => 120.07919220247945
+ "return_std =>" => 120.07919220247945,
),
Dict(
"policy_path" => "maze2d_umaze/maze2d_umaze_dapg_4.pkl",
- "task.task_names" => [
- "maze2d-umaze-v1"
- ],
+ "task.task_names" => ["maze2d-umaze-v1"],
"agent_name" => "DAPG",
"return_mean" => 106.56,
- "return_std =>" => 123.82562901112192
+ "return_std =>" => 123.82562901112192,
),
Dict(
"policy_path" => "maze2d_umaze/maze2d_umaze_dapg_5.pkl",
- "task.task_names" => [
- "maze2d-umaze-v1"
- ],
+ "task.task_names" => ["maze2d-umaze-v1"],
"agent_name" => "DAPG",
"return_mean" => 142.5,
- "return_std =>" => 111.55568116416124
+ "return_std =>" => 111.55568116416124,
),
Dict(
"policy_path" => "maze2d_umaze/maze2d_umaze_dapg_6.pkl",
- "task.task_names" => [
- "maze2d-umaze-v1"
- ],
+ "task.task_names" => ["maze2d-umaze-v1"],
"agent_name" => "DAPG",
"return_mean" => 172.13,
- "return_std =>" => 118.24048841238772
+ "return_std =>" => 118.24048841238772,
),
Dict(
"policy_path" => "maze2d_umaze/maze2d_umaze_dapg_7.pkl",
- "task.task_names" => [
- "maze2d-umaze-v1"
- ],
+ "task.task_names" => ["maze2d-umaze-v1"],
"agent_name" => "DAPG",
"return_mean" => 190.98,
- "return_std =>" => 73.81706848690214
+ "return_std =>" => 73.81706848690214,
),
Dict(
"policy_path" => "maze2d_umaze/maze2d_umaze_dapg_8.pkl",
- "task.task_names" => [
- "maze2d-umaze-v1"
- ],
+ "task.task_names" => ["maze2d-umaze-v1"],
"agent_name" => "DAPG",
"return_mean" => 228.17,
- "return_std =>" => 39.635856241539685
+ "return_std =>" => 39.635856241539685,
),
Dict(
"policy_path" => "maze2d_umaze/maze2d_umaze_dapg_9.pkl",
- "task.task_names" => [
- "maze2d-umaze-v1"
- ],
+ "task.task_names" => ["maze2d-umaze-v1"],
"agent_name" => "DAPG",
"return_mean" => 239.34,
- "return_std =>" => 37.597664821102924
+ "return_std =>" => 37.597664821102924,
),
Dict(
"policy_path" => "pen/pen_dapg_0.pkl",
- "task.task_names" => [
- "pen-cloned-v0",
- "pen-expert-v0",
- "pen-human-v0"
- ],
+ "task.task_names" => ["pen-cloned-v0", "pen-expert-v0", "pen-human-v0"],
"agent_name" => "DAPG",
"return_mean" => 1984.096763504694,
- "return_std =>" => 1929.6110474391166
+ "return_std =>" => 1929.6110474391166,
),
Dict(
"policy_path" => "pen/pen_dapg_10.pkl",
- "task.task_names" => [
- "pen-cloned-v0",
- "pen-expert-v0",
- "pen-human-v0"
- ],
+ "task.task_names" => ["pen-cloned-v0", "pen-expert-v0", "pen-human-v0"],
"agent_name" => "DAPG",
"return_mean" => 3808.794849593491,
- "return_std =>" => 1932.9965631785215
+ "return_std =>" => 1932.9965631785215,
),
Dict(
"policy_path" => "pen/pen_dapg_1.pkl",
- "task.task_names" => [
- "pen-cloned-v0",
- "pen-expert-v0",
- "pen-human-v0"
- ],
+ "task.task_names" => ["pen-cloned-v0", "pen-expert-v0", "pen-human-v0"],
"agent_name" => "DAPG",
"return_mean" => 2480.1224231814135,
- "return_std =>" => 2125.5773427152635
+ "return_std =>" => 2125.5773427152635,
),
Dict(
"policy_path" => "pen/pen_dapg_2.pkl",
- "task.task_names" => [
- "pen-cloned-v0",
- "pen-expert-v0",
- "pen-human-v0"
- ],
+ "task.task_names" => ["pen-cloned-v0", "pen-expert-v0", "pen-human-v0"],
"agent_name" => "DAPG",
"return_mean" => 2494.1335875747145,
- "return_std =>" => 2118.0014860996175
+ "return_std =>" => 2118.0014860996175,
),
Dict(
"policy_path" => "pen/pen_dapg_3.pkl",
- "task.task_names" => [
- "pen-cloned-v0",
- "pen-expert-v0",
- "pen-human-v0"
- ],
+ "task.task_names" => ["pen-cloned-v0", "pen-expert-v0", "pen-human-v0"],
"agent_name" => "DAPG",
"return_mean" => 2802.87073294418,
- "return_std =>" => 2120.3981104287323
+ "return_std =>" => 2120.3981104287323,
),
Dict(
"policy_path" => "pen/pen_dapg_4.pkl",
- "task.task_names" => [
- "pen-cloned-v0",
- "pen-expert-v0",
- "pen-human-v0"
- ],
+ "task.task_names" => ["pen-cloned-v0", "pen-expert-v0", "pen-human-v0"],
"agent_name" => "DAPG",
"return_mean" => 3136.18545171068,
- "return_std =>" => 2112.923714191993
+ "return_std =>" => 2112.923714191993,
),
Dict(
"policy_path" => "pen/pen_dapg_5.pkl",
- "task.task_names" => [
- "pen-cloned-v0",
- "pen-expert-v0",
- "pen-human-v0"
- ],
+ "task.task_names" => ["pen-cloned-v0", "pen-expert-v0", "pen-human-v0"],
"agent_name" => "DAPG",
"return_mean" => 3110.619191864754,
- "return_std =>" => 2012.2585161410343
+ "return_std =>" => 2012.2585161410343,
),
Dict(
"policy_path" => "pen/pen_dapg_6.pkl",
- "task.task_names" => [
- "pen-cloned-v0",
- "pen-expert-v0",
- "pen-human-v0"
- ],
+ "task.task_names" => ["pen-cloned-v0", "pen-expert-v0", "pen-human-v0"],
"agent_name" => "DAPG",
"return_mean" => 3410.4384362331157,
- "return_std =>" => 2029.187357465904
+ "return_std =>" => 2029.187357465904,
),
Dict(
"policy_path" => "pen/pen_dapg_7.pkl",
- "task.task_names" => [
- "pen-cloned-v0",
- "pen-expert-v0",
- "pen-human-v0"
- ],
+ "task.task_names" => ["pen-cloned-v0", "pen-expert-v0", "pen-human-v0"],
"agent_name" => "DAPG",
"return_mean" => 3489.353704450997,
- "return_std =>" => 2035.2279026017748
+ "return_std =>" => 2035.2279026017748,
),
Dict(
"policy_path" => "pen/pen_dapg_8.pkl",
- "task.task_names" => [
- "pen-cloned-v0",
- "pen-expert-v0",
- "pen-human-v0"
- ],
+ "task.task_names" => ["pen-cloned-v0", "pen-expert-v0", "pen-human-v0"],
"agent_name" => "DAPG",
"return_mean" => 3673.9622983303598,
- "return_std =>" => 2052.8837762657795
+ "return_std =>" => 2052.8837762657795,
),
Dict(
"policy_path" => "pen/pen_dapg_9.pkl",
- "task.task_names" => [
- "pen-cloned-v0",
- "pen-expert-v0",
- "pen-human-v0"
- ],
+ "task.task_names" => ["pen-cloned-v0", "pen-expert-v0", "pen-human-v0"],
"agent_name" => "DAPG",
"return_mean" => 3683.932983177092,
- "return_std =>" => 2028.9543873822265
+ "return_std =>" => 2028.9543873822265,
),
Dict(
"policy_path" => "relocate/relocate_dapg_0.pkl",
- "task.task_names" => [
- "relocate-cloned-v0",
- "relocate-expert-v0",
- "relocate-human-v0"
- ],
+ "task.task_names" =>
+ ["relocate-cloned-v0", "relocate-expert-v0", "relocate-human-v0"],
"agent_name" => "DAPG",
"return_mean" => -4.4718813284277195,
- "return_std =>" => 0.9021515021945451
+ "return_std =>" => 0.9021515021945451,
),
Dict(
"policy_path" => "relocate/relocate_dapg_10.pkl",
- "task.task_names" => [
- "relocate-cloned-v0",
- "relocate-expert-v0",
- "relocate-human-v0"
- ],
+ "task.task_names" =>
+ ["relocate-cloned-v0", "relocate-expert-v0", "relocate-human-v0"],
"agent_name" => "DAPG",
"return_mean" => 3481.7834354311035,
- "return_std =>" => 813.1857720257618
+ "return_std =>" => 813.1857720257618,
),
Dict(
"policy_path" => "relocate/relocate_dapg_1.pkl",
- "task.task_names" => [
- "relocate-cloned-v0",
- "relocate-expert-v0",
- "relocate-human-v0"
- ],
+ "task.task_names" =>
+ ["relocate-cloned-v0", "relocate-expert-v0", "relocate-human-v0"],
"agent_name" => "DAPG",
"return_mean" => 5.070946470816939,
- "return_std =>" => 31.708695854456067
+ "return_std =>" => 31.708695854456067,
),
Dict(
"policy_path" => "relocate/relocate_dapg_2.pkl",
- "task.task_names" => [
- "relocate-cloned-v0",
- "relocate-expert-v0",
- "relocate-human-v0"
- ],
+ "task.task_names" =>
+ ["relocate-cloned-v0", "relocate-expert-v0", "relocate-human-v0"],
"agent_name" => "DAPG",
"return_mean" => 54.976670129729555,
- "return_std =>" => 140.09635704443158
+ "return_std =>" => 140.09635704443158,
),
Dict(
"policy_path" => "relocate/relocate_dapg_3.pkl",
- "task.task_names" => [
- "relocate-cloned-v0",
- "relocate-expert-v0",
- "relocate-human-v0"
- ],
+ "task.task_names" =>
+ ["relocate-cloned-v0", "relocate-expert-v0", "relocate-human-v0"],
"agent_name" => "DAPG",
"return_mean" => 54.11338525066304,
- "return_std =>" => 146.87277676706216
+ "return_std =>" => 146.87277676706216,
),
Dict(
"policy_path" => "relocate/relocate_dapg_4.pkl",
- "task.task_names" => [
- "relocate-cloned-v0",
- "relocate-expert-v0",
- "relocate-human-v0"
- ],
+ "task.task_names" =>
+ ["relocate-cloned-v0", "relocate-expert-v0", "relocate-human-v0"],
"agent_name" => "DAPG",
"return_mean" => 97.16474411169358,
- "return_std =>" => 164.81156449057102
+ "return_std =>" => 164.81156449057102,
),
Dict(
"policy_path" => "relocate/relocate_dapg_5.pkl",
- "task.task_names" => [
- "relocate-cloned-v0",
- "relocate-expert-v0",
- "relocate-human-v0"
- ],
+ "task.task_names" =>
+ ["relocate-cloned-v0", "relocate-expert-v0", "relocate-human-v0"],
"agent_name" => "DAPG",
"return_mean" => 366.3185681324701,
- "return_std =>" => 581.577837554543
+ "return_std =>" => 581.577837554543,
),
Dict(
"policy_path" => "relocate/relocate_dapg_6.pkl",
- "task.task_names" => [
- "relocate-cloned-v0",
- "relocate-expert-v0",
- "relocate-human-v0"
- ],
+ "task.task_names" =>
+ ["relocate-cloned-v0", "relocate-expert-v0", "relocate-human-v0"],
"agent_name" => "DAPG",
"return_mean" => 1254.0676523894747,
- "return_std =>" => 929.5248207929493
+ "return_std =>" => 929.5248207929493,
),
Dict(
"policy_path" => "relocate/relocate_dapg_7.pkl",
- "task.task_names" => [
- "relocate-cloned-v0",
- "relocate-expert-v0",
- "relocate-human-v0"
- ],
+ "task.task_names" =>
+ ["relocate-cloned-v0", "relocate-expert-v0", "relocate-human-v0"],
"agent_name" => "DAPG",
"return_mean" => 2700.2361856493385,
- "return_std =>" => 1089.9871332809942
+ "return_std =>" => 1089.9871332809942,
),
Dict(
"policy_path" => "relocate/relocate_dapg_8.pkl",
- "task.task_names" => [
- "relocate-cloned-v0",
- "relocate-expert-v0",
- "relocate-human-v0"
- ],
+ "task.task_names" =>
+ ["relocate-cloned-v0", "relocate-expert-v0", "relocate-human-v0"],
"agent_name" => "DAPG",
"return_mean" => 2570.351217370911,
- "return_std =>" => 1266.9305994339466
+ "return_std =>" => 1266.9305994339466,
),
Dict(
"policy_path" => "relocate/relocate_dapg_9.pkl",
- "task.task_names" => [
- "relocate-cloned-v0",
- "relocate-expert-v0",
- "relocate-human-v0"
- ],
+ "task.task_names" =>
+ ["relocate-cloned-v0", "relocate-expert-v0", "relocate-human-v0"],
"agent_name" => "DAPG",
"return_mean" => 3379.424369497742,
- "return_std =>" => 948.6183219418235
+ "return_std =>" => 948.6183219418235,
),
Dict(
"policy_path" => "walker/walker_online_0.pkl",
@@ -1546,11 +1238,11 @@ const D4RL_POLICIES = [
"walker2d-random-v0",
"walker2d-expert-v0",
"walker2d-medium-replay-v0",
- "walker2d-medium-expert-v0"
+ "walker2d-medium-expert-v0",
],
"agent_name" => "SAC",
"return_mean" => 17.57372020467802,
- "return_std =>" => 51.686802739349666
+ "return_std =>" => 51.686802739349666,
),
Dict(
"policy_path" => "walker/walker_online_10.pkl",
@@ -1559,11 +1251,11 @@ const D4RL_POLICIES = [
"walker2d-random-v0",
"walker2d-expert-v0",
"walker2d-medium-replay-v0",
- "walker2d-medium-expert-v0"
+ "walker2d-medium-expert-v0",
],
"agent_name" => "SAC",
"return_mean" => 4120.947079569632,
- "return_std =>" => 468.1515654051671
+ "return_std =>" => 468.1515654051671,
),
Dict(
"policy_path" => "walker/walker_online_1.pkl",
@@ -1572,11 +1264,11 @@ const D4RL_POLICIES = [
"walker2d-random-v0",
"walker2d-expert-v0",
"walker2d-medium-replay-v0",
- "walker2d-medium-expert-v0"
+ "walker2d-medium-expert-v0",
],
"agent_name" => "SAC",
"return_mean" => 193.84631742541606,
- "return_std =>" => 185.16785303932383
+ "return_std =>" => 185.16785303932383,
),
Dict(
"policy_path" => "walker/walker_online_2.pkl",
@@ -1585,11 +1277,11 @@ const D4RL_POLICIES = [
"walker2d-random-v0",
"walker2d-expert-v0",
"walker2d-medium-replay-v0",
- "walker2d-medium-expert-v0"
+ "walker2d-medium-expert-v0",
],
"agent_name" => "SAC",
"return_mean" => 942.6191179097373,
- "return_std =>" => 532.9834162811841
+ "return_std =>" => 532.9834162811841,
),
Dict(
"policy_path" => "walker/walker_online_3.pkl",
@@ -1598,11 +1290,11 @@ const D4RL_POLICIES = [
"walker2d-random-v0",
"walker2d-expert-v0",
"walker2d-medium-replay-v0",
- "walker2d-medium-expert-v0"
+ "walker2d-medium-expert-v0",
],
"agent_name" => "SAC",
"return_mean" => 2786.7497792224794,
- "return_std =>" => 477.5450988462439
+ "return_std =>" => 477.5450988462439,
),
Dict(
"policy_path" => "walker/walker_online_4.pkl",
@@ -1611,11 +1303,11 @@ const D4RL_POLICIES = [
"walker2d-random-v0",
"walker2d-expert-v0",
"walker2d-medium-replay-v0",
- "walker2d-medium-expert-v0"
+ "walker2d-medium-expert-v0",
],
"agent_name" => "SAC",
"return_mean" => 914.4680927038296,
- "return_std =>" => 559.5155757967623
+ "return_std =>" => 559.5155757967623,
),
Dict(
"policy_path" => "walker/walker_online_5.pkl",
@@ -1624,11 +1316,11 @@ const D4RL_POLICIES = [
"walker2d-random-v0",
"walker2d-expert-v0",
"walker2d-medium-replay-v0",
- "walker2d-medium-expert-v0"
+ "walker2d-medium-expert-v0",
],
"agent_name" => "SAC",
"return_mean" => 3481.491012709211,
- "return_std =>" => 87.12729823320758
+ "return_std =>" => 87.12729823320758,
),
Dict(
"policy_path" => "walker/walker_online_6.pkl",
@@ -1637,11 +1329,11 @@ const D4RL_POLICIES = [
"walker2d-random-v0",
"walker2d-expert-v0",
"walker2d-medium-replay-v0",
- "walker2d-medium-expert-v0"
+ "walker2d-medium-expert-v0",
],
"agent_name" => "SAC",
"return_mean" => 2720.2509272083826,
- "return_std =>" => 746.9753406110725
+ "return_std =>" => 746.9753406110725,
),
Dict(
"policy_path" => "walker/walker_online_7.pkl",
@@ -1650,11 +1342,11 @@ const D4RL_POLICIES = [
"walker2d-random-v0",
"walker2d-expert-v0",
"walker2d-medium-replay-v0",
- "walker2d-medium-expert-v0"
+ "walker2d-medium-expert-v0",
],
"agent_name" => "SAC",
"return_mean" => 3926.346852318098,
- "return_std =>" => 365.4230491920236
+ "return_std =>" => 365.4230491920236,
),
Dict(
"policy_path" => "walker/walker_online_8.pkl",
@@ -1663,11 +1355,11 @@ const D4RL_POLICIES = [
"walker2d-random-v0",
"walker2d-expert-v0",
"walker2d-medium-replay-v0",
- "walker2d-medium-expert-v0"
+ "walker2d-medium-expert-v0",
],
"agent_name" => "SAC",
"return_mean" => 3695.4887678612936,
- "return_std =>" => 262.0350155576298
+ "return_std =>" => 262.0350155576298,
),
Dict(
"policy_path" => "walker/walker_online_9.pkl",
@@ -1676,10 +1368,10 @@ const D4RL_POLICIES = [
"walker2d-random-v0",
"walker2d-expert-v0",
"walker2d-medium-replay-v0",
- "walker2d-medium-expert-v0"
+ "walker2d-medium-expert-v0",
],
"agent_name" => "SAC",
"return_mean" => 4122.358396232011,
- "return_std =>" => 107.76279305206488
- )
-]
\ No newline at end of file
+ "return_std =>" => 107.76279305206488,
+ ),
+]
diff --git a/src/ReinforcementLearningDatasets/src/deep_ope/d4rl/d4rl_policy.jl b/src/ReinforcementLearningDatasets/src/deep_ope/d4rl/d4rl_policy.jl
index f47b55624..c50f83366 100644
--- a/src/ReinforcementLearningDatasets/src/deep_ope/d4rl/d4rl_policy.jl
+++ b/src/ReinforcementLearningDatasets/src/deep_ope/d4rl/d4rl_policy.jl
@@ -21,8 +21,8 @@ Flux.@functor D4RLGaussianNetwork
function (model::D4RLGaussianNetwork)(
state::AbstractArray;
- rng::AbstractRNG=MersenneTwister(123),
- noisy::Bool=true
+ rng::AbstractRNG = MersenneTwister(123),
+ noisy::Bool = true,
)
x = model.pre(state)
μ, logσ = model.μ(x), model.logσ(x)
@@ -32,7 +32,7 @@ function (model::D4RLGaussianNetwork)(
a = μ + exp.(logσ)
end
a, μ
-end
+end
"""
d4rl_policy(env, agent, epoch)
@@ -45,11 +45,8 @@ Check [deep_ope](https://github.com/google-research/deep_ope) with preloaded wei
- `agent::String`: can be `dapg` or `online`.
- `epoch::Int`: can be in `0:10`.
"""
-function d4rl_policy(
- env::String,
- agent::String,
- epoch::Int)
-
+function d4rl_policy(env::String, agent::String, epoch::Int)
+
folder_prefix = "deep-ope-d4rl"
try
@datadep_str "$(folder_prefix)-$(env)_$(agent)_$(epoch)"
@@ -60,13 +57,13 @@ function d4rl_policy(
end
policy_folder = @datadep_str "$(folder_prefix)-$(env)_$(agent)_$(epoch)"
policy_file = "$(policy_folder)/$(readdir(policy_folder)[1])"
-
+
model_params = Pickle.npyload(policy_file)
@pipe parse_network_params(model_params) |> build_model(_...)
end
function parse_network_params(model_params::Dict)
- size_dict = Dict{String, Tuple}()
+ size_dict = Dict{String,Tuple}()
nonlinearity = nothing
output_transformation = nothing
for param in model_params
@@ -81,7 +78,7 @@ function parse_network_params(model_params::Dict)
nonlinearity = tanh
end
else
- if param_value == "tanh_gaussian"
+ if param_value == "tanh_gaussian"
output_transformation = tanh
else
output_transformation = identity
@@ -92,29 +89,31 @@ function parse_network_params(model_params::Dict)
model_params, size_dict, nonlinearity, output_transformation
end
-function build_model(model_params::Dict, size_dict::Dict, nonlinearity::Function, output_transformation::Function)
+function build_model(
+ model_params::Dict,
+ size_dict::Dict,
+ nonlinearity::Function,
+ output_transformation::Function,
+)
fc_0 = Dense(size_dict["fc0/weight"]..., nonlinearity)
fc_0 = @set fc_0.weight = model_params["fc0/weight"]
fc_0 = @set fc_0.bias = model_params["fc0/bias"]
-
+
fc_1 = Dense(size_dict["fc1/weight"]..., nonlinearity)
fc_1 = @set fc_1.weight = model_params["fc1/weight"]
fc_1 = @set fc_1.bias = model_params["fc1/bias"]
-
+
μ_fc = Dense(size_dict["last_fc/weight"]...)
μ_fc = @set μ_fc.weight = model_params["last_fc/weight"]
μ_fc = @set μ_fc.bias = model_params["last_fc/bias"]
-
+
log_σ_fc = Dense(size_dict["last_fc_log_std/weight"]...)
log_σ_fc = @set log_σ_fc.weight = model_params["last_fc_log_std/weight"]
log_σ_fc = @set log_σ_fc.bias = model_params["last_fc_log_std/bias"]
-
- pre = Chain(
- fc_0,
- fc_1
- )
+
+ pre = Chain(fc_0, fc_1)
μ = Chain(μ_fc)
log_σ = Chain(log_σ_fc)
-
+
D4RLGaussianNetwork(pre, μ, log_σ)
-end
\ No newline at end of file
+end
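
For reference, a minimal usage sketch of the `d4rl_policy` loader reformatted above. The `"hopper"`/`"online"`/epoch values and the 11-element observation are illustrative assumptions drawn from the docstring and the policy table; the first call downloads the corresponding DataDep.

```julia
# Hedged sketch, not part of the patch: load a pretrained D4RL policy and
# query the returned D4RLGaussianNetwork. Qualify `d4rl_policy` with the
# package module if it is not exported in your version.
using ReinforcementLearningDatasets
using Random

model = d4rl_policy("hopper", "online", 10)   # triggers the DataDep download on first use
s = rand(Float32, 11)                         # hypothetical hopper observation vector
a, μ = model(s; rng = MersenneTwister(123), noisy = false)  # sampled action and mean
```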
diff --git a/src/ReinforcementLearningDatasets/src/deep_ope/d4rl/evaluate.jl b/src/ReinforcementLearningDatasets/src/deep_ope/d4rl/evaluate.jl
index 816736bc4..b3a742428 100644
--- a/src/ReinforcementLearningDatasets/src/deep_ope/d4rl/evaluate.jl
+++ b/src/ReinforcementLearningDatasets/src/deep_ope/d4rl/evaluate.jl
@@ -17,28 +17,30 @@ function deep_ope_d4rl_evaluate(
env_name::String,
agent::String,
epoch::Int;
- gym_env_name::Union{String, Nothing}=nothing,
- rng::AbstractRNG=MersenneTwister(123),
- num_evaluations::Int=10,
- γ::Float64=1.0,
- noisy::Bool=false,
- env_seed::Union{Int, Nothing}=nothing
-)
+ gym_env_name::Union{String,Nothing} = nothing,
+ rng::AbstractRNG = MersenneTwister(123),
+ num_evaluations::Int = 10,
+ γ::Float64 = 1.0,
+ noisy::Bool = false,
+ env_seed::Union{Int,Nothing} = nothing,
+)
policy_folder = "$(env_name)_$(agent)_$(epoch)"
if gym_env_name === nothing
for policy in D4RL_POLICIES
policy_file = split(policy["policy_path"], "/")[end]
- if chop(policy_file, head=0, tail=4) == policy_folder
+ if chop(policy_file, head = 0, tail = 4) == policy_folder
gym_env_name = policy["task.task_names"][1]
break
end
end
- if gym_env_name === nothing error("invalid parameters") end
+ if gym_env_name === nothing
+ error("invalid parameters")
+ end
end
- env = GymEnv(gym_env_name; seed=env_seed)
+ env = GymEnv(gym_env_name; seed = env_seed)
model = d4rl_policy(env_name, agent, epoch)
scores = Vector{Float64}(undef, num_evaluations)
@@ -48,14 +50,23 @@ function deep_ope_d4rl_evaluate(
reset!(env)
while !is_terminated(env)
s = state(env)
- a = model(s;rng=rng, noisy=noisy)[1]
- s, a , env(a)
+ a = model(s; rng = rng, noisy = noisy)[1]
+ s, a, env(a)
r = reward(env)
t = is_terminated(env)
- score += r*γ*(1-t)
+ score += r * γ * (1 - t)
end
scores[eval] = score
end
- plt = lineplot(1:length(scores), scores, title = "$(gym_env_name) scores", name = "scores", xlabel = "episode", canvas = DotCanvas, ylabel = "score", border=:ascii)
+ plt = lineplot(
+ 1:length(scores),
+ scores,
+ title = "$(gym_env_name) scores",
+ name = "scores",
+ xlabel = "episode",
+ canvas = DotCanvas,
+ ylabel = "score",
+ border = :ascii,
+ )
plt
-end
\ No newline at end of file
+end
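
A hedged sketch of calling the evaluator with the keyword arguments reformatted above; the argument values are illustrative, and the result is the UnicodePlots line plot assembled at the end of the function.

```julia
# Illustrative only: evaluate the epoch-10 online hopper policy for a few
# episodes and inspect the per-episode score plot returned above.
plt = deep_ope_d4rl_evaluate(
    "hopper",
    "online",
    10;
    num_evaluations = 5,
    γ = 0.99,
    noisy = false,
    env_seed = 123,
)
plt   # UnicodePlots lineplot of per-episode scores
```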
diff --git a/src/ReinforcementLearningDatasets/src/deep_ope/d4rl/register.jl b/src/ReinforcementLearningDatasets/src/deep_ope/d4rl/register.jl
index 0c0081dc4..d1be2502f 100644
--- a/src/ReinforcementLearningDatasets/src/deep_ope/d4rl/register.jl
+++ b/src/ReinforcementLearningDatasets/src/deep_ope/d4rl/register.jl
@@ -1,11 +1,11 @@
gcs_prefix = "gs://gresearch/deep-ope/d4rl"
-folder_prefix = "deep-ope-d4rl"
+folder_prefix = "deep-ope-d4rl"
policies = D4RL_POLICIES
function deep_ope_d4rl_init()
for policy in policies
gcs_policy_folder = policy["policy_path"]
- local_policy_folder = chop(split(gcs_policy_folder, "/")[end], head=0, tail=4)
+ local_policy_folder = chop(split(gcs_policy_folder, "/")[end], head = 0, tail = 4)
register(
DataDep(
"$(folder_prefix)-$(local_policy_folder)",
@@ -16,7 +16,7 @@ function deep_ope_d4rl_init()
Authors: Justin Fu, Mohammad Norouzi, Ofir Nachum, George Tucker, ziyu wang, Alexander Novikov, Mengjiao Yang, Michael R Zhang,
Yutian Chen, Aviral Kumar, Cosmin Paduraru, Sergey Levine, Thomas Paine
Year: 2021
-
+
Deep OPE contains:
Policies for the tasks in the D4RL, DeepMind Locomotion and Control Suite datasets.
Policies trained with the following algorithms (D4PG, ABM, CRR, SAC, DAPG and BC) and snapshots along the training trajectory. This facilitates
@@ -29,8 +29,8 @@ function deep_ope_d4rl_init()
what datasets are available, please refer to D4RL: Datasets for Deep Data-Driven Reinforcement Learning.
""",
"$(gcs_prefix)/$(gcs_policy_folder)";
- fetch_method=fetch_gc_file
- )
+ fetch_method = fetch_gc_file,
+ ),
)
end
-end
\ No newline at end of file
+end
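
The DataDep names registered above follow `deep-ope-d4rl-<file stem>`, where the stem is the `policy_path` file name without its `.pkl` suffix. A small sketch of resolving one of them directly, using an entry that appears in `D4RL_POLICIES`:

```julia
# Hedged sketch: resolve one registered policy folder by hand.
# "hopper_online_7" is the stem of "hopper/hopper_online_7.pkl" in the table above.
using DataDeps
deep_ope_d4rl_init()                                   # register every policy
folder = @datadep_str "deep-ope-d4rl-hopper_online_7"  # fetched from GCS on first access
pkl = joinpath(folder, readdir(folder)[1])             # path to the downloaded .pkl weights
```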
diff --git a/src/ReinforcementLearningDatasets/src/init.jl b/src/ReinforcementLearningDatasets/src/init.jl
index 90bf6f0bd..fd1ede41d 100644
--- a/src/ReinforcementLearningDatasets/src/init.jl
+++ b/src/ReinforcementLearningDatasets/src/init.jl
@@ -6,4 +6,4 @@ function __init__()
RLDatasets.bsuite_init()
RLDatasets.dm_init()
RLDatasets.deep_ope_d4rl_init()
-end
\ No newline at end of file
+end
diff --git a/src/ReinforcementLearningDatasets/src/rl_unplugged/atari/register.jl b/src/ReinforcementLearningDatasets/src/rl_unplugged/atari/register.jl
index 287e6dbac..ec022d63e 100644
--- a/src/ReinforcementLearningDatasets/src/rl_unplugged/atari/register.jl
+++ b/src/ReinforcementLearningDatasets/src/rl_unplugged/atari/register.jl
@@ -57,13 +57,13 @@ const TESTING_SUITE = [
]
# Total of 45 games.
-const ALL = cat(TUNING_SUITE, TESTING_SUITE, dims=1)
+const ALL = cat(TUNING_SUITE, TESTING_SUITE, dims = 1)
function rl_unplugged_atari_params()
game = ALL
run = 1:5
shards = 0:99
-
+
@info game run shards
end
@@ -98,11 +98,12 @@ function rl_unplugged_atari_init()
on Atari if you are interested in comparing your approach to other state of the
art offline RL methods with discrete actions.
""",
- "gs://rl_unplugged/atari/$game/"*@sprintf("run_%i-%05i-of-%05i", run, index, num_shards);
- fetch_method = fetch_gc_file
- )
+ "gs://rl_unplugged/atari/$game/" *
+ @sprintf("run_%i-%05i-of-%05i", run, index, num_shards);
+ fetch_method = fetch_gc_file,
+ ),
)
end
end
end
-end
\ No newline at end of file
+end
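
For clarity, the object name assembled by the `*`/`@sprintf` expression above expands as shown below (assuming 100 shards per run, matching the `0:99` range in `rl_unplugged_atari_params`; `"Pong"` is an illustrative game name).

```julia
# Name expansion only; no download is performed.
using Printf: @sprintf
"gs://rl_unplugged/atari/Pong/" * @sprintf("run_%i-%05i-of-%05i", 1, 0, 100)
# => "gs://rl_unplugged/atari/Pong/run_1-00000-of-00100"
```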
diff --git a/src/ReinforcementLearningDatasets/src/rl_unplugged/atari/rl_unplugged_atari.jl b/src/ReinforcementLearningDatasets/src/rl_unplugged/atari/rl_unplugged_atari.jl
index 753265e08..08cade692 100644
--- a/src/ReinforcementLearningDatasets/src/rl_unplugged/atari/rl_unplugged_atari.jl
+++ b/src/ReinforcementLearningDatasets/src/rl_unplugged/atari/rl_unplugged_atari.jl
@@ -1,7 +1,7 @@
export rl_unplugged_atari_dataset
using Base.Threads
-using Printf:@sprintf
+using Printf: @sprintf
using Base.Iterators
using TFRecord
using ImageCore
@@ -14,14 +14,14 @@ using PNGFiles
Represent an AtariRLTransition and can also represent a batch.
"""
struct AtariRLTransition <: RLTransition
- state
- action
- reward
- terminal
- next_state
- next_action
- episode_id
- episode_return
+ state::Any
+ action::Any
+ reward::Any
+ terminal::Any
+ next_state::Any
+ next_action::Any
+ episode_id::Any
+ episode_return::Any
end
function decode_frame(bytes)
@@ -29,7 +29,7 @@ function decode_frame(bytes)
end
function decode_state(bytes)
- PermutedDimsArray(StackedView((decode_frame(x) for x in bytes)...), (2,3,1))
+ PermutedDimsArray(StackedView((decode_frame(x) for x in bytes)...), (2, 3, 1))
end
function AtariRLTransition(example::TFRecord.Example)
@@ -70,65 +70,62 @@ function rl_unplugged_atari_dataset(
game::String,
run::Int,
shards::Vector{Int};
- shuffle_buffer_size=10_000,
- tf_reader_bufsize=1*1024*1024,
- tf_reader_sz=10_000,
- batch_size=256,
- n_preallocations=nthreads()*12
+ shuffle_buffer_size = 10_000,
+ tf_reader_bufsize = 1 * 1024 * 1024,
+ tf_reader_sz = 10_000,
+ batch_size = 256,
+ n_preallocations = nthreads() * 12,
)
n = nthreads()
@info "Loading the shards $shards in $run run of $game with $n threads"
folders = [
- @datadep_str "rl-unplugged-atari-$(titlecase(game))-$run-$shard"
- for shard in shards
+ @datadep_str "rl-unplugged-atari-$(titlecase(game))-$run-$shard" for
+ shard in shards
]
-
+
ch_files = Channel{String}(length(folders)) do ch
for folder in cycle(folders)
file = folder * "/$(readdir(folder)[1])"
put!(ch, file)
end
end
-
+
shuffled_files = buffered_shuffle(ch_files, length(folders))
-
+
ch_src = Channel{AtariRLTransition}(n * tf_reader_sz) do ch
for fs in partition(shuffled_files, n)
Threads.foreach(
TFRecord.read(
fs;
- compression=:gzip,
- bufsize=tf_reader_bufsize,
- channel_size=tf_reader_sz,
+ compression = :gzip,
+ bufsize = tf_reader_bufsize,
+ channel_size = tf_reader_sz,
);
- schedule=Threads.StaticSchedule()
+ schedule = Threads.StaticSchedule(),
) do x
put!(ch, AtariRLTransition(x))
end
end
end
-
- transitions = buffered_shuffle(
- ch_src,
- shuffle_buffer_size
- )
-
+
+ transitions = buffered_shuffle(ch_src, shuffle_buffer_size)
+
buffer = AtariRLTransition(
- Array{UInt8, 4}(undef, 84, 84, 4, batch_size),
- Array{Int, 1}(undef, batch_size),
- Array{Float32, 1}(undef, batch_size),
- Array{Bool, 1}(undef, batch_size),
- Array{UInt8, 4}(undef, 84, 84, 4, batch_size),
- Array{Int, 1}(undef, batch_size),
- Array{Int, 1}(undef, batch_size),
- Array{Float32, 1}(undef, batch_size),
+ Array{UInt8,4}(undef, 84, 84, 4, batch_size),
+ Array{Int,1}(undef, batch_size),
+ Array{Float32,1}(undef, batch_size),
+ Array{Bool,1}(undef, batch_size),
+ Array{UInt8,4}(undef, 84, 84, 4, batch_size),
+ Array{Int,1}(undef, batch_size),
+ Array{Int,1}(undef, batch_size),
+ Array{Float32,1}(undef, batch_size),
)
taskref = Ref{Task}()
- res = RingBuffer(buffer;taskref=taskref, sz=n_preallocations) do buff
+ res = RingBuffer(buffer; taskref = taskref, sz = n_preallocations) do buff
Threads.@threads for i in 1:batch_size
batch!(buff, popfirst!(transitions), i)
end
@@ -137,4 +134,4 @@ function rl_unplugged_atari_dataset(
bind(ch_src, taskref[])
bind(ch_files, taskref[])
res
-end
\ No newline at end of file
+end
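
A hedged consumption sketch for the loader above: the returned `RingBuffer` is drained with `take!`, and the batch field shapes follow the preallocated `AtariRLTransition` buffer in this hunk. Game, run, and shard values are illustrative.

```julia
# Illustrative only: stream batches of RL Unplugged Atari transitions.
ds = rl_unplugged_atari_dataset("pong", 1, [0, 1]; batch_size = 256)
batch = take!(ds)          # one AtariRLTransition batch
size(batch.state)          # (84, 84, 4, 256): stacked UInt8 frames
size(batch.reward)         # (256,)
size(batch.terminal)       # (256,)
```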
diff --git a/src/ReinforcementLearningDatasets/src/rl_unplugged/bsuite/bsuite.jl b/src/ReinforcementLearningDatasets/src/rl_unplugged/bsuite/bsuite.jl
index af8ff4af6..713342fa9 100644
--- a/src/ReinforcementLearningDatasets/src/rl_unplugged/bsuite/bsuite.jl
+++ b/src/ReinforcementLearningDatasets/src/rl_unplugged/bsuite/bsuite.jl
@@ -3,11 +3,11 @@ export rl_unplugged_bsuite_dataset
using TFRecord
struct BSuiteRLTransition <: RLTransition
- state
- action
- reward
- terminal
- next_state
+ state::Any
+ action::Any
+ reward::Any
+ terminal::Any
+ next_state::Any
end
function BSuiteRLTransition(example::TFRecord.Example, game::String)
@@ -55,46 +55,43 @@ function rl_unplugged_bsuite_dataset(
game::String,
shards::Vector{Int},
type::String;
- is_shuffle::Bool=true,
- stochasticity::Float64=0.0,
- shuffle_buffer_size::Int=10_000,
- tf_reader_bufsize::Int=10_000,
- tf_reader_sz::Int=10_000,
- batch_size::Int=256,
- n_preallocations::Int=nthreads()*12
-)
+ is_shuffle::Bool = true,
+ stochasticity::Float64 = 0.0,
+ shuffle_buffer_size::Int = 10_000,
+ tf_reader_bufsize::Int = 10_000,
+ tf_reader_sz::Int = 10_000,
+ batch_size::Int = 256,
+ n_preallocations::Int = nthreads() * 12,
+)
n = nthreads()
repo = "rl-unplugged-bsuite"
-
- folders= [
- @datadep_str "$repo-$game-$stochasticity-$shard-$type"
- for shard in shards
- ]
-
+
+ folders = [@datadep_str "$repo-$game-$stochasticity-$shard-$type" for shard in shards]
+
ch_files = Channel{String}(length(folders)) do ch
for folder in cycle(folders)
file = folder * "/$(readdir(folder)[1])"
put!(ch, file)
end
end
-
+
if is_shuffle
files = buffered_shuffle(ch_files, length(folders))
else
files = ch_files
end
-
+
ch_src = Channel{BSuiteRLTransition}(n * tf_reader_sz) do ch
for fs in partition(files, n)
Threads.foreach(
TFRecord.read(
fs;
- compression=:gzip,
- bufsize=tf_reader_bufsize,
- channel_size=tf_reader_sz,
+ compression = :gzip,
+ bufsize = tf_reader_bufsize,
+ channel_size = tf_reader_sz,
);
- schedule=Threads.StaticSchedule()
+ schedule = Threads.StaticSchedule(),
) do x
put!(ch, BSuiteRLTransition(x, game))
end
@@ -102,33 +99,30 @@ function rl_unplugged_bsuite_dataset(
end
if is_shuffle
- transitions = buffered_shuffle(
- ch_src,
- shuffle_buffer_size
- )
+ transitions = buffered_shuffle(ch_src, shuffle_buffer_size)
else
transitions = ch_src
end
-
+
taskref = Ref{Task}()
- ob_size = game=="mountain_car" ? 3 : 6
+ ob_size = game == "mountain_car" ? 3 : 6
if game == "catch"
- obs_template = Array{Float32, 3}(undef, 10, 5, batch_size)
+ obs_template = Array{Float32,3}(undef, 10, 5, batch_size)
else
- obs_template = Array{Float32, 2}(undef, ob_size, batch_size)
+ obs_template = Array{Float32,2}(undef, ob_size, batch_size)
end
buffer = BSuiteRLTransition(
copy(obs_template),
- Array{Int, 1}(undef, batch_size),
- Array{Float32, 1}(undef, batch_size),
- Array{Bool, 1}(undef, batch_size),
+ Array{Int,1}(undef, batch_size),
+ Array{Float32,1}(undef, batch_size),
+ Array{Bool,1}(undef, batch_size),
copy(obs_template),
)
- res = RingBuffer(buffer;taskref=taskref, sz=n_preallocations) do buff
+ res = RingBuffer(buffer; taskref = taskref, sz = n_preallocations) do buff
Threads.@threads for i in 1:batch_size
batch!(buff, take!(transitions), i)
end
@@ -137,4 +131,4 @@ function rl_unplugged_bsuite_dataset(
bind(ch_src, taskref[])
bind(ch_files, taskref[])
res
-end
\ No newline at end of file
+end
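
A hedged sketch for the bsuite loader above. `"catch"` observations are stacked as 10×5 planes per the branch in this hunk, while `"cartpole"` and `"mountain_car"` yield 6- and 3-element state vectors; the shard index and the `"full"` split are illustrative assumptions.

```julia
# Illustrative only: batches for the "catch" task.
ds = rl_unplugged_bsuite_dataset("catch", [0], "full"; batch_size = 128)
batch = take!(ds)
size(batch.state)      # (10, 5, 128)
size(batch.action)     # (128,)
size(batch.terminal)   # (128,)
```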
diff --git a/src/ReinforcementLearningDatasets/src/rl_unplugged/bsuite/register.jl b/src/ReinforcementLearningDatasets/src/rl_unplugged/bsuite/register.jl
index 487a6f72b..488619f77 100644
--- a/src/ReinforcementLearningDatasets/src/rl_unplugged/bsuite/register.jl
+++ b/src/ReinforcementLearningDatasets/src/rl_unplugged/bsuite/register.jl
@@ -1,17 +1,9 @@
repo = "bsuite"
export bsuite_params
-const BSUITE_DATASETS = [
- "cartpole",
- "catch",
- "mountain_car"
-]
+const BSUITE_DATASETS = ["cartpole", "catch", "mountain_car"]
-types = [
- "full",
- "full_train",
- "full_valid"
-]
+types = ["full", "full_train", "full_valid"]
function bsuite_params()
game = BSUITE_DATASETS
@@ -47,11 +39,11 @@ function bsuite_init()
where the stochasticity of the environment is easy to control.
""",
"gs://rl_unplugged/$repo/$env/0_$stochasticity/$(index)_$type-00000-of-00001",
- fetch_method = fetch_gc_file
- )
+ fetch_method = fetch_gc_file,
+ ),
)
end
end
end
end
-end
\ No newline at end of file
+end
diff --git a/src/ReinforcementLearningDatasets/src/rl_unplugged/dm/register.jl b/src/ReinforcementLearningDatasets/src/rl_unplugged/dm/register.jl
index 364555e79..b46fffd43 100644
--- a/src/ReinforcementLearningDatasets/src/rl_unplugged/dm/register.jl
+++ b/src/ReinforcementLearningDatasets/src/rl_unplugged/dm/register.jl
@@ -12,170 +12,166 @@ function dm_params()
@info game shards
end
-const DM_LOCOMOTION_RODENT = Dict{String, String}(
+const DM_LOCOMOTION_RODENT = Dict{String,String}(
"rodent_gaps" => "dm_locomotion/rodent_gaps/seq2",
"rodent_escape" => "dm_locomotion/rodent_bowl_escape/seq2",
"rodent_two_touch" => "dm_locomotion/rodent_two_touch/seq40",
- "rodent_mazes" => "dm_locomotion/rodent_mazes/seq40"
+ "rodent_mazes" => "dm_locomotion/rodent_mazes/seq40",
)
-const DM_LOCOMOTION_RODENT_SIZE = Dict{String, Tuple}(
+const DM_LOCOMOTION_RODENT_SIZE = Dict{String,Tuple}(
"observation/walker/actuator_activation" => (38,),
# "observation/walker/sensors_torque" => (),
# "observation/walker/sensors_force" => (),
"observation/walker/body_height" => (1,),
"observation/walker/end_effectors_pos" => (12,),
- "observation/walker/joints_pos"=> (30,),
- "observation/walker/joints_vel"=> (30,),
- "observation/walker/tendons_pos"=> (8,),
- "observation/walker/tendons_vel"=> (8,),
- "observation/walker/appendages_pos"=> (15,),
- "observation/walker/world_zaxis"=> (3,),
- "observation/walker/sensors_accelerometer"=> (3,),
- "observation/walker/sensors_velocimeter"=> (3,),
- "observation/walker/sensors_gyro" => (3,),
- "observation/walker/sensors_touch"=> (4,),
- "observation/walker/egocentric_camera"=> (64, 64, 3),
- "action"=> (38,),
- "discount"=> (),
- "reward"=> (),
- "step_type"=> ()
+ "observation/walker/joints_pos" => (30,),
+ "observation/walker/joints_vel" => (30,),
+ "observation/walker/tendons_pos" => (8,),
+ "observation/walker/tendons_vel" => (8,),
+ "observation/walker/appendages_pos" => (15,),
+ "observation/walker/world_zaxis" => (3,),
+ "observation/walker/sensors_accelerometer" => (3,),
+ "observation/walker/sensors_velocimeter" => (3,),
+ "observation/walker/sensors_gyro" => (3,),
+ "observation/walker/sensors_touch" => (4,),
+ "observation/walker/egocentric_camera" => (64, 64, 3),
+ "action" => (38,),
+ "discount" => (),
+ "reward" => (),
+ "step_type" => (),
)
-const DM_LOCOMOTION_HUMANOID = Dict{String, String}(
+const DM_LOCOMOTION_HUMANOID = Dict{String,String}(
"humanoid_corridor" => "dm_locomotion/humanoid_corridor/seq2",
"humanoid_gaps" => "dm_locomotion/humanoid_gaps/seq2",
- "humanoid_walls" => "dm_locomotion/humanoid_walls/seq40"
+ "humanoid_walls" => "dm_locomotion/humanoid_walls/seq40",
)
-const DM_LOCOMOTION_HUMANOID_SIZE = Dict{String, Tuple}(
+const DM_LOCOMOTION_HUMANOID_SIZE = Dict{String,Tuple}(
# "observation/walker/actuator_activation" => (0,),
"observation/walker/sensors_torque" => (6,),
# "observation/walker/sensors_force" => (),
- "observation/walker/joints_vel"=> (56,),
- "observation/walker/sensors_velocimeter"=> (3,),
- "observation/walker/sensors_gyro"=> (3,),
- "observation/walker/joints_pos"=> (56,),
+ "observation/walker/joints_vel" => (56,),
+ "observation/walker/sensors_velocimeter" => (3,),
+ "observation/walker/sensors_gyro" => (3,),
+ "observation/walker/joints_pos" => (56,),
"observation/walker/appendages_pos" => (15,),
- "observation/walker/world_zaxis"=> (3,),
- "observation/walker/body_height"=> (1,),
- "observation/walker/sensors_accelerometer"=> (3,),
- "observation/walker/end_effectors_pos"=> (12,),
- "observation/walker/egocentric_camera"=> (
- 64,
- 64,
- 3,
- ),
- "action"=> (56,),
- "discount"=> (),
- "reward"=> (),
+ "observation/walker/world_zaxis" => (3,),
+ "observation/walker/body_height" => (1,),
+ "observation/walker/sensors_accelerometer" => (3,),
+ "observation/walker/end_effectors_pos" => (12,),
+ "observation/walker/egocentric_camera" => (64, 64, 3),
+ "action" => (56,),
+ "discount" => (),
+ "reward" => (),
# "episodic_reward"=> (),
- "step_type"=> ()
+ "step_type" => (),
)
-const DM_CONTROL_SUITE_SIZE = Dict{String, Dict{String, Tuple}}(
- "cartpole_swingup" => Dict{String, Tuple}(
- "observation/position"=> (3,),
- "observation/velocity"=> (2,),
- "action"=> (1,),
- "discount"=> (),
- "reward"=> (),
- "episodic_reward"=> (),
- "step_type"=> ()
+const DM_CONTROL_SUITE_SIZE = Dict{String,Dict{String,Tuple}}(
+ "cartpole_swingup" => Dict{String,Tuple}(
+ "observation/position" => (3,),
+ "observation/velocity" => (2,),
+ "action" => (1,),
+ "discount" => (),
+ "reward" => (),
+ "episodic_reward" => (),
+ "step_type" => (),
),
- "cheetah_run" => Dict{String, Tuple}(
- "observation/position"=> (8,),
- "observation/velocity"=> (9,),
- "action"=> (6,),
- "discount"=> (),
- "reward"=> (),
- "episodic_reward"=> (),
- "step_type"=> ()
+ "cheetah_run" => Dict{String,Tuple}(
+ "observation/position" => (8,),
+ "observation/velocity" => (9,),
+ "action" => (6,),
+ "discount" => (),
+ "reward" => (),
+ "episodic_reward" => (),
+ "step_type" => (),
),
- "finger_turn_hard" => Dict{String, Tuple}(
- "observation/position"=> (4,),
- "observation/velocity"=> (3,),
- "observation/touch"=> (2,),
- "observation/target_position"=> (2,),
- "observation/dist_to_target"=> (1,),
- "action"=> (2,),
- "discount"=> (),
- "reward"=> (),
- "episodic_reward"=> (),
- "step_type"=> ()
+ "finger_turn_hard" => Dict{String,Tuple}(
+ "observation/position" => (4,),
+ "observation/velocity" => (3,),
+ "observation/touch" => (2,),
+ "observation/target_position" => (2,),
+ "observation/dist_to_target" => (1,),
+ "action" => (2,),
+ "discount" => (),
+ "reward" => (),
+ "episodic_reward" => (),
+ "step_type" => (),
),
- "fish_swim" => Dict{String, Tuple}(
- "observation/target"=> (3,),
- "observation/velocity"=> (13,),
- "observation/upright"=> (1,),
- "observation/joint_angles"=> (7,),
- "action"=> (5,),
- "discount"=> (),
- "reward"=> (),
- "episodic_reward"=> (),
- "step_type"=> ()
+ "fish_swim" => Dict{String,Tuple}(
+ "observation/target" => (3,),
+ "observation/velocity" => (13,),
+ "observation/upright" => (1,),
+ "observation/joint_angles" => (7,),
+ "action" => (5,),
+ "discount" => (),
+ "reward" => (),
+ "episodic_reward" => (),
+ "step_type" => (),
),
- "humanoid_run" => Dict{String, Tuple}(
- "observation/velocity"=> (27,),
- "observation/com_velocity"=> (3,),
- "observation/torso_vertical"=> (3,),
- "observation/extremities"=> (12,),
- "observation/head_height"=> (1,),
- "observation/joint_angles"=> (21,),
- "action"=> (21,),
- "discount"=> (),
- "reward"=> (),
- "episodic_reward"=> (),
- "step_type"=> ()
+ "humanoid_run" => Dict{String,Tuple}(
+ "observation/velocity" => (27,),
+ "observation/com_velocity" => (3,),
+ "observation/torso_vertical" => (3,),
+ "observation/extremities" => (12,),
+ "observation/head_height" => (1,),
+ "observation/joint_angles" => (21,),
+ "action" => (21,),
+ "discount" => (),
+ "reward" => (),
+ "episodic_reward" => (),
+ "step_type" => (),
),
- "manipulator_insert_ball" => Dict{String, Tuple}(
- "observation/arm_pos"=> (16,),
- "observation/arm_vel"=> (8,),
- "observation/touch"=> (5,),
- "observation/hand_pos"=> (4,),
- "observation/object_pos"=> (4,),
- "observation/object_vel"=> (3,),
- "observation/target_pos"=> (4,),
- "action"=> (5,),
- "discount"=> (),
- "reward"=> (),
- "episodic_reward"=> (),
- "step_type"=> ()
+ "manipulator_insert_ball" => Dict{String,Tuple}(
+ "observation/arm_pos" => (16,),
+ "observation/arm_vel" => (8,),
+ "observation/touch" => (5,),
+ "observation/hand_pos" => (4,),
+ "observation/object_pos" => (4,),
+ "observation/object_vel" => (3,),
+ "observation/target_pos" => (4,),
+ "action" => (5,),
+ "discount" => (),
+ "reward" => (),
+ "episodic_reward" => (),
+ "step_type" => (),
),
- "manipulator_insert_peg" => Dict{String, Tuple}(
- "observation/arm_pos"=> (16,),
- "observation/arm_vel"=> (8,),
- "observation/touch"=> (5,),
- "observation/hand_pos"=> (4,),
- "observation/object_pos"=> (4,),
- "observation/object_vel"=> (3,),
- "observation/target_pos"=> (4,),
- "episodic_reward"=> (),
- "action"=> (5,),
- "discount"=> (),
- "reward"=> (),
- "step_type"=> ()
+ "manipulator_insert_peg" => Dict{String,Tuple}(
+ "observation/arm_pos" => (16,),
+ "observation/arm_vel" => (8,),
+ "observation/touch" => (5,),
+ "observation/hand_pos" => (4,),
+ "observation/object_pos" => (4,),
+ "observation/object_vel" => (3,),
+ "observation/target_pos" => (4,),
+ "episodic_reward" => (),
+ "action" => (5,),
+ "discount" => (),
+ "reward" => (),
+ "step_type" => (),
),
- "walker_stand" => Dict{String, Tuple}(
- "observation/orientations"=> (14,),
- "observation/velocity"=> (9,),
- "observation/height"=> (1,),
- "action"=> (6,),
- "discount"=> (),
- "reward"=> (),
- "episodic_reward"=> (),
- "step_type"=> ()
+ "walker_stand" => Dict{String,Tuple}(
+ "observation/orientations" => (14,),
+ "observation/velocity" => (9,),
+ "observation/height" => (1,),
+ "action" => (6,),
+ "discount" => (),
+ "reward" => (),
+ "episodic_reward" => (),
+ "step_type" => (),
+ ),
+ "walker_walk" => Dict{String,Tuple}(
+ "observation/orientations" => (14,),
+ "observation/velocity" => (9,),
+ "observation/height" => (1,),
+ "action" => (6,),
+ "discount" => (),
+ "reward" => (),
+ "episodic_reward" => (),
+ "step_type" => (),
),
- "walker_walk" => Dict{String, Tuple}(
- "observation/orientations"=> (14,),
- "observation/velocity"=> (9,),
- "observation/height"=> (1,),
- "action"=> (6,),
- "discount"=> (),
- "reward"=> (),
- "episodic_reward"=> (),
- "step_type"=> ()
- )
)
const DM_LOCOMOTION = merge(DM_LOCOMOTION_HUMANOID, DM_LOCOMOTION_RODENT)
@@ -208,9 +204,10 @@ function dm_init()
please refer to the paper. DeepMind Control Suite is a traditional continuous action RL benchmark. In particular, it is recommended
that you test your approach in DeepMind Control Suite if you are interested in comparing against other state of the art offline RL methods.
""",
- "gs://rl_unplugged/dm_control_suite/$task/"*@sprintf("train-%05i-of-%05i", index, num_shards);
+ "gs://rl_unplugged/dm_control_suite/$task/" *
+ @sprintf("train-%05i-of-%05i", index, num_shards);
fetch_method = fetch_gc_file,
- )
+ ),
)
end
end
@@ -240,10 +237,11 @@ function dm_init()
It is recommended that you to try offline RL methods on DeepMind Locomotion dataset, if you are interested in very challenging
offline RL dataset with continuous action space.
""",
- "gs://rl_unplugged/$(DM_LOCOMOTION[task])/"*@sprintf("train-%05i-of-%05i", index, num_shards);
- fetch_method = fetch_gc_file
- )
+ "gs://rl_unplugged/$(DM_LOCOMOTION[task])/" *
+ @sprintf("train-%05i-of-%05i", index, num_shards);
+ fetch_method = fetch_gc_file,
+ ),
)
end
end
-end
\ No newline at end of file
+end
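
As a quick sanity check on the reformatted feature tables: each entry maps a TFRecord feature name to its per-step shape, with `()` standing for a scalar.

```julia
DM_CONTROL_SUITE_SIZE["cartpole_swingup"]["observation/position"]      # (3,)
DM_LOCOMOTION_HUMANOID_SIZE["observation/walker/egocentric_camera"]    # (64, 64, 3)
DM_CONTROL_SUITE_SIZE["walker_walk"]["reward"]                         # ()
```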
diff --git a/src/ReinforcementLearningDatasets/src/rl_unplugged/dm/rl_unplugged_dm.jl b/src/ReinforcementLearningDatasets/src/rl_unplugged/dm/rl_unplugged_dm.jl
index df3ebad5d..115ba1bde 100644
--- a/src/ReinforcementLearningDatasets/src/rl_unplugged/dm/rl_unplugged_dm.jl
+++ b/src/ReinforcementLearningDatasets/src/rl_unplugged/dm/rl_unplugged_dm.jl
@@ -3,31 +3,41 @@ export rl_unplugged_dm_dataset
using TFRecord
function make_batch_array(type::Type, feature_dims::Int, size::Tuple, batch_size::Int)
- Array{type, feature_dims+1}(undef, size..., batch_size)
+ Array{type,feature_dims + 1}(undef, size..., batch_size)
end
-function dm_buffer_dict(feature_size::Dict{String, Tuple}, batch_size::Int)
- obs_buffer = Dict{Symbol, AbstractArray}()
+function dm_buffer_dict(feature_size::Dict{String,Tuple}, batch_size::Int)
+ obs_buffer = Dict{Symbol,AbstractArray}()
- buffer_dict = Dict{Symbol, Any}()
+ buffer_dict = Dict{Symbol,Any}()
for feature in keys(feature_size)
feature_dims = length(feature_size[feature])
if split(feature, "/")[1] == "observation"
- ob_key = Symbol(chop(feature, head=length("observation")+1, tail=0))
+ ob_key = Symbol(chop(feature, head = length("observation") + 1, tail = 0))
if split(feature, "/")[end] == "egocentric_camera"
- obs_buffer[ob_key] = make_batch_array(UInt8, feature_dims, feature_size[feature], batch_size)
+ obs_buffer[ob_key] =
+ make_batch_array(UInt8, feature_dims, feature_size[feature], batch_size)
else
- obs_buffer[ob_key] = make_batch_array(Float32, feature_dims, feature_size[feature], batch_size)
+ obs_buffer[ob_key] = make_batch_array(
+ Float32,
+ feature_dims,
+ feature_size[feature],
+ batch_size,
+ )
end
elseif feature == "action"
- buffer_dict[:action] = make_batch_array(Float32, feature_dims, feature_size[feature], batch_size)
- buffer_dict[:next_action] = make_batch_array(Float32, feature_dims, feature_size[feature], batch_size)
+ buffer_dict[:action] =
+ make_batch_array(Float32, feature_dims, feature_size[feature], batch_size)
+ buffer_dict[:next_action] =
+ make_batch_array(Float32, feature_dims, feature_size[feature], batch_size)
elseif feature == "step_type"
- buffer_dict[:terminal] = make_batch_array(Bool, feature_dims, feature_size[feature], batch_size)
+ buffer_dict[:terminal] =
+ make_batch_array(Bool, feature_dims, feature_size[feature], batch_size)
else
ob_key = Symbol(feature)
- buffer_dict[ob_key] = make_batch_array(Float32, feature_dims, feature_size[feature], batch_size)
+ buffer_dict[ob_key] =
+ make_batch_array(Float32, feature_dims, feature_size[feature], batch_size)
end
end
@@ -61,28 +71,33 @@ function batch_named_tuple!(dest::NamedTuple, src::NamedTuple, i::Int)
end
end
-function make_transition(example::TFRecord.Example, feature_size::Dict{String, Tuple})
+function make_transition(example::TFRecord.Example, feature_size::Dict{String,Tuple})
f = example.features.feature
-
- observation_dict = Dict{Symbol, AbstractArray}()
- next_observation_dict = Dict{Symbol, AbstractArray}()
- transition_dict = Dict{Symbol, Any}()
+
+ observation_dict = Dict{Symbol,AbstractArray}()
+ next_observation_dict = Dict{Symbol,AbstractArray}()
+ transition_dict = Dict{Symbol,Any}()
for feature in keys(feature_size)
if split(feature, "/")[1] == "observation"
- ob_key = Symbol(chop(feature, head = length("observation")+1, tail=0))
+ ob_key = Symbol(chop(feature, head = length("observation") + 1, tail = 0))
if split(feature, "/")[end] == "egocentric_camera"
cam_feature_size = feature_size[feature]
ob_size = prod(cam_feature_size)
- observation_dict[ob_key] = reshape(f[feature].bytes_list.value[1][1:ob_size], cam_feature_size...)
- next_observation_dict[ob_key] = reshape(f[feature].bytes_list.value[1][ob_size+1:end], cam_feature_size...)
+ observation_dict[ob_key] =
+ reshape(f[feature].bytes_list.value[1][1:ob_size], cam_feature_size...)
+ next_observation_dict[ob_key] = reshape(
+ f[feature].bytes_list.value[1][ob_size+1:end],
+ cam_feature_size...,
+ )
else
if feature_size[feature] == ()
observation_dict[ob_key] = f[feature].float_list.value
else
ob_size = feature_size[feature][1]
observation_dict[ob_key] = f[feature].float_list.value[1:ob_size]
- next_observation_dict[ob_key] = f[feature].float_list.value[ob_size+1:end]
+ next_observation_dict[ob_key] =
+ f[feature].float_list.value[ob_size+1:end]
end
end
elseif feature == "action"
@@ -138,28 +153,25 @@ function rl_unplugged_dm_dataset(
shards;
type = "dm_control_suite",
is_shuffle = true,
- shuffle_buffer_size=10_000,
- tf_reader_bufsize=10_000,
- tf_reader_sz=10_000,
- batch_size=256,
- n_preallocations=nthreads()*12
-)
+ shuffle_buffer_size = 10_000,
+ tf_reader_bufsize = 10_000,
+ tf_reader_sz = 10_000,
+ batch_size = 256,
+ n_preallocations = nthreads() * 12,
+)
n = nthreads()
repo = "rl-unplugged-dm"
-
- folders= [
- @datadep_str "$repo-$game-$shard"
- for shard in shards
- ]
-
+
+ folders = [@datadep_str "$repo-$game-$shard" for shard in shards]
+
ch_files = Channel{String}(length(folders)) do ch
for folder in cycle(folders)
file = folder * "/$(readdir(folder)[1])"
put!(ch, file)
end
end
-
+
if is_shuffle
files = buffered_shuffle(ch_files, length(folders))
else
@@ -173,11 +185,11 @@ function rl_unplugged_dm_dataset(
Threads.foreach(
TFRecord.read(
fs;
- compression=:gzip,
- bufsize=tf_reader_bufsize,
- channel_size=tf_reader_sz,
+ compression = :gzip,
+ bufsize = tf_reader_bufsize,
+ channel_size = tf_reader_sz,
);
- schedule=Threads.StaticSchedule()
+ schedule = Threads.StaticSchedule(),
) do x
put!(ch, make_transition(x, feature_size))
end
@@ -185,21 +197,18 @@ function rl_unplugged_dm_dataset(
end
if is_shuffle
- transitions = buffered_shuffle(
- ch_src,
- shuffle_buffer_size
- )
+ transitions = buffered_shuffle(ch_src, shuffle_buffer_size)
else
transitions = ch_src
end
-
+
taskref = Ref{Task}()
-
+
buffer_dict = dm_buffer_dict(feature_size, batch_size)
buffer = NamedTuple(buffer_dict)
- res = RingBuffer(buffer;taskref=taskref, sz=n_preallocations) do buff
+ res = RingBuffer(buffer; taskref = taskref, sz = n_preallocations) do buff
Threads.@threads for i in 1:batch_size
batch_named_tuple!(buff, take!(transitions), i)
end
@@ -208,4 +217,4 @@ function rl_unplugged_dm_dataset(
bind(ch_src, taskref[])
bind(ch_files, taskref[])
res
-end
\ No newline at end of file
+end
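
A hedged sketch for the DM loader above. Batches come out as `NamedTuple`s whose keys are derived from the feature-size tables in `register.jl` (observation entries, `:action`/`:next_action`, `:terminal`, and so on); the shard index is illustrative, and the exact key names depend on `dm_buffer_dict`.

```julia
# Illustrative only: cartpole_swingup batches from the DM Control Suite shards.
ds = rl_unplugged_dm_dataset(
    "cartpole_swingup",
    [1];
    type = "dm_control_suite",
    batch_size = 64,
)
batch = take!(ds)
keys(batch)           # NamedTuple keys built by dm_buffer_dict
size(batch.action)    # (1, 64): action shape (1,) plus the batch dimension
```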
diff --git a/src/ReinforcementLearningDatasets/src/rl_unplugged/util.jl b/src/ReinforcementLearningDatasets/src/rl_unplugged/util.jl
index 4c09e678e..5a62247de 100644
--- a/src/ReinforcementLearningDatasets/src/rl_unplugged/util.jl
+++ b/src/ReinforcementLearningDatasets/src/rl_unplugged/util.jl
@@ -27,7 +27,7 @@ Therefore, it acts as a channel that holds a shuffled buffer which is of type Ve
- `buffer::Vector{T}`, The shuffled buffer.
- `rng<:AbstractRNG`.
"""
-struct BufferedShuffle{T, R<:AbstractRNG} <: AbstractChannel{T}
+struct BufferedShuffle{T,R<:AbstractRNG} <: AbstractChannel{T}
src::Channel{T}
buffer::Vector{T}
rng::R
@@ -43,7 +43,11 @@ Arguments:
- `buffer_size::Int`. The size of the buffered channel.
- `rng<:AbstractRNG` = Random.GLOBAL_RNG.
"""
-function buffered_shuffle(src::Channel{T}, buffer_size::Int;rng=Random.GLOBAL_RNG) where T
+function buffered_shuffle(
+ src::Channel{T},
+ buffer_size::Int;
+ rng = Random.GLOBAL_RNG,
+) where {T}
buffer = Array{T}(undef, buffer_size)
p = Progress(buffer_size)
Threads.@threads for i in 1:buffer_size
@@ -70,7 +74,7 @@ function Base.take!(b::BufferedShuffle)
end
end
-function Base.iterate(b::BufferedShuffle, state=nothing)
+function Base.iterate(b::BufferedShuffle, state = nothing)
try
return (popfirst!(b), nothing)
catch e
@@ -104,14 +108,14 @@ Return a RingBuffer that gives batches with the specs in `buffer`.
- `buffer::T`: the type containing the batch.
- `sz::Int`:size of the internal buffers.
"""
-function RingBuffer(f!, buffer::T;sz=Threads.nthreads(), taskref=nothing) where T
+function RingBuffer(f!, buffer::T; sz = Threads.nthreads(), taskref = nothing) where {T}
buffers = Channel{T}(sz)
for _ in 1:sz
put!(buffers, deepcopy(buffer))
end
- results = Channel{T}(sz, spawn=true, taskref=taskref) do ch
- Threads.foreach(buffers;schedule=Threads.StaticSchedule()) do x
- # for x in buffers
+ results = Channel{T}(sz, spawn = true, taskref = taskref) do ch
+ Threads.foreach(buffers; schedule = Threads.StaticSchedule()) do x
+ # for x in buffers
f!(x) # in-place operation
put!(ch, x)
end
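
A hedged, self-contained sketch of the two utilities whose signatures are reformatted above. Both are internal to the datasets package, so qualify them with the module name if they are not exported; the toy channel and in-place fill are assumptions for illustration.

```julia
# Illustrative only: feed a channel through buffered_shuffle, then produce
# batches with the RingBuffer do-block form used by the loaders above.
using Random

src = Channel{Int}(100) do ch
    for i in 1:100
        put!(ch, i)
    end
end
shuffled = buffered_shuffle(src, 10; rng = MersenneTwister(0))
first5 = [take!(shuffled) for _ in 1:5]      # five randomly drawn elements

template = zeros(Float32, 4)                 # one preallocated "batch"
rb = RingBuffer(template; sz = 2) do buff
    rand!(buff)                              # fill the recycled buffer in place
end
batch = take!(rb)                            # a freshly filled copy of `template`
```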
diff --git a/src/ReinforcementLearningDatasets/test/atari_dataset.jl b/src/ReinforcementLearningDatasets/test/atari_dataset.jl
index eb6c03175..b89c63ef3 100644
--- a/src/ReinforcementLearningDatasets/test/atari_dataset.jl
+++ b/src/ReinforcementLearningDatasets/test/atari_dataset.jl
@@ -13,11 +13,11 @@ rng = StableRNG(123)
"pong",
index,
epochs;
- repo="atari-replay-datasets",
+ repo = "atari-replay-datasets",
style = style,
rng = rng,
is_shuffle = true,
- batch_size = batch_size
+ batch_size = batch_size,
)
data_dict = ds.dataset
@@ -64,11 +64,11 @@ end
"pong",
index,
epochs;
- repo="atari-replay-datasets",
+ repo = "atari-replay-datasets",
style = style,
rng = rng,
is_shuffle = false,
- batch_size = batch_size
+ batch_size = batch_size,
)
data_dict = ds.dataset
@@ -118,4 +118,4 @@ end
@test data_dict[:reward][batch_size+1:batch_size*2] == iter2[:reward]
@test data_dict[:terminal][batch_size+1:batch_size*2] == iter2[:terminal]
@test data_dict[:state][:, :, batch_size+2:batch_size*2+1] == iter2[:next_state]
-end
\ No newline at end of file
+end
diff --git a/src/ReinforcementLearningDatasets/test/bsuite.jl b/src/ReinforcementLearningDatasets/test/bsuite.jl
index 0bceeb7c0..36fd647e0 100644
--- a/src/ReinforcementLearningDatasets/test/bsuite.jl
+++ b/src/ReinforcementLearningDatasets/test/bsuite.jl
@@ -10,13 +10,13 @@
tf_reader_bufsize = 10_000,
tf_reader_sz = 10_000,
batch_size = 256,
- n_preallocations = Threads.nthreads() * 12
+ n_preallocations = Threads.nthreads() * 12,
)
- @test typeof(ds)<:RingBuffer
+ @test typeof(ds) <: RingBuffer
s_size = 6
-
+
data_1 = take!(ds)
@test size(data_1.state) == (s_size, batch_size)
@@ -25,11 +25,11 @@
@test size(data_1.reward) == (batch_size,)
@test size(data_1.terminal) == (batch_size,)
- @test typeof(data_1.state) == Array{Float32, 2}
- @test typeof(data_1.next_state) == Array{Float32, 2}
- @test typeof(data_1.action) == Array{Int, 1}
- @test typeof(data_1.reward) == Array{Float32, 1}
- @test typeof(data_1.terminal) == Array{Bool, 1}
+ @test typeof(data_1.state) == Array{Float32,2}
+ @test typeof(data_1.next_state) == Array{Float32,2}
+ @test typeof(data_1.action) == Array{Int,1}
+ @test typeof(data_1.reward) == Array{Float32,1}
+ @test typeof(data_1.terminal) == Array{Bool,1}
end
@@ -44,13 +44,13 @@
tf_reader_bufsize = 10_000,
tf_reader_sz = 10_000,
batch_size = 256,
- n_preallocations = Threads.nthreads() * 12
+ n_preallocations = Threads.nthreads() * 12,
)
- @test typeof(ds)<:RingBuffer
+ @test typeof(ds) <: RingBuffer
s_size = 6
-
+
data_1 = take!(ds)
@test size(data_1.state) == (s_size, batch_size)
@@ -59,11 +59,11 @@
@test size(data_1.reward) == (batch_size,)
@test size(data_1.terminal) == (batch_size,)
- @test typeof(data_1.state) == Array{Float32, 2}
- @test typeof(data_1.next_state) == Array{Float32, 2}
- @test typeof(data_1.action) == Array{Int, 1}
- @test typeof(data_1.reward) == Array{Float32, 1}
- @test typeof(data_1.terminal) == Array{Bool, 1}
+ @test typeof(data_1.state) == Array{Float32,2}
+ @test typeof(data_1.next_state) == Array{Float32,2}
+ @test typeof(data_1.action) == Array{Int,1}
+ @test typeof(data_1.reward) == Array{Float32,1}
+ @test typeof(data_1.terminal) == Array{Bool,1}
end
@@ -78,13 +78,13 @@
tf_reader_bufsize = 10_000,
tf_reader_sz = 10_000,
batch_size = 256,
- n_preallocations = Threads.nthreads() * 12
+ n_preallocations = Threads.nthreads() * 12,
)
- @test typeof(ds)<:RingBuffer
+ @test typeof(ds) <: RingBuffer
s_size = (10, 5)
-
+
data_1 = take!(ds)
@test size(data_1.state) == (s_size[1], s_size[2], batch_size)
@@ -93,11 +93,11 @@
@test size(data_1.reward) == (batch_size,)
@test size(data_1.terminal) == (batch_size,)
- @test typeof(data_1.state) == Array{Float32, 3}
- @test typeof(data_1.next_state) == Array{Float32, 3}
- @test typeof(data_1.action) == Array{Int, 1}
- @test typeof(data_1.reward) == Array{Float32, 1}
- @test typeof(data_1.terminal) == Array{Bool, 1}
+ @test typeof(data_1.state) == Array{Float32,3}
+ @test typeof(data_1.next_state) == Array{Float32,3}
+ @test typeof(data_1.action) == Array{Int,1}
+ @test typeof(data_1.reward) == Array{Float32,1}
+ @test typeof(data_1.terminal) == Array{Bool,1}
end
-end
\ No newline at end of file
+end
diff --git a/src/ReinforcementLearningDatasets/test/d4rl_pybullet.jl b/src/ReinforcementLearningDatasets/test/d4rl_pybullet.jl
index a146647c7..245bfae92 100644
--- a/src/ReinforcementLearningDatasets/test/d4rl_pybullet.jl
+++ b/src/ReinforcementLearningDatasets/test/d4rl_pybullet.jl
@@ -2,11 +2,11 @@ using Base: batch_size_err_str
@testset "d4rl_pybullet" begin
ds = dataset(
"hopper-bullet-mixed-v0";
- repo="d4rl-pybullet",
+ repo = "d4rl-pybullet",
style = style,
rng = rng,
is_shuffle = true,
- batch_size = batch_size
+ batch_size = batch_size,
)
n_s = 15
@@ -23,10 +23,12 @@ using Base: batch_size_err_str
for sample in Iterators.take(ds, 3)
@test typeof(sample) <: NamedTuple{SARTS}
- @test size(sample[:state]) == (n_s, batch_size)
- @test size(sample[:action]) == (n_a, batch_size)
- @test size(sample[:reward]) == (1, batch_size) || size(sample[:reward]) == (batch_size,)
- @test size(sample[:terminal]) == (1, batch_size) || size(sample[:terminal]) == (batch_size,)
+ @test size(sample[:state]) == (n_s, batch_size)
+ @test size(sample[:action]) == (n_a, batch_size)
+ @test size(sample[:reward]) == (1, batch_size) ||
+ size(sample[:reward]) == (batch_size,)
+ @test size(sample[:terminal]) == (1, batch_size) ||
+ size(sample[:terminal]) == (batch_size,)
end
-end
\ No newline at end of file
+end
diff --git a/src/ReinforcementLearningDatasets/test/dataset.jl b/src/ReinforcementLearningDatasets/test/dataset.jl
index 9a726c6ab..a2ac0f3a9 100644
--- a/src/ReinforcementLearningDatasets/test/dataset.jl
+++ b/src/ReinforcementLearningDatasets/test/dataset.jl
@@ -8,11 +8,11 @@ rng = MersenneTwister(123)
@testset "dataset_shuffle" begin
ds = dataset(
"hopper-medium-replay-v0";
- repo="d4rl",
+ repo = "d4rl",
style = style,
rng = rng,
is_shuffle = true,
- batch_size = batch_size
+ batch_size = batch_size,
)
data_dict = ds.dataset
@@ -58,7 +58,7 @@ end
style = style,
rng = rng,
is_shuffle = false,
- batch_size = batch_size
+ batch_size = batch_size,
)
data_dict = ds.dataset
diff --git a/src/ReinforcementLearningDatasets/test/deep_ope_d4rl.jl b/src/ReinforcementLearningDatasets/test/deep_ope_d4rl.jl
index 39afb2b32..6105513de 100644
--- a/src/ReinforcementLearningDatasets/test/deep_ope_d4rl.jl
+++ b/src/ReinforcementLearningDatasets/test/deep_ope_d4rl.jl
@@ -6,7 +6,7 @@ using UnicodePlots
@testset "d4rl_policies" begin
model = d4rl_policy("ant", "online", 10)
- @test typeof(model) <: D4RLGaussianNetwork
+ @test typeof(model) <: D4RLGaussianNetwork
env = GymEnv("ant-medium-v0")
@@ -16,6 +16,6 @@ using UnicodePlots
end
@testset "d4rl_policy_evaluate" begin
- plt = deep_ope_d4rl_evaluate("halfcheetah", "online", 10; num_evaluations=100)
+ plt = deep_ope_d4rl_evaluate("halfcheetah", "online", 10; num_evaluations = 100)
@test typeof(plt) <: UnicodePlots.Plot
-end
\ No newline at end of file
+end
diff --git a/src/ReinforcementLearningDatasets/test/rl_unplugged_atari.jl b/src/ReinforcementLearningDatasets/test/rl_unplugged_atari.jl
index 949a22b83..d2b01028f 100644
--- a/src/ReinforcementLearningDatasets/test/rl_unplugged_atari.jl
+++ b/src/ReinforcementLearningDatasets/test/rl_unplugged_atari.jl
@@ -4,13 +4,13 @@
1,
[1, 2];
shuffle_buffer_size = 10_000,
- tf_reader_bufsize = 1*1024*1024,
+ tf_reader_bufsize = 1 * 1024 * 1024,
tf_reader_sz = 10_000,
batch_size = 256,
- n_preallocations = Threads.nthreads() * 12
+ n_preallocations = Threads.nthreads() * 12,
)
- @test typeof(ds)<:RingBuffer
+ @test typeof(ds) <: RingBuffer
data_1 = take!(ds)
@@ -26,8 +26,8 @@
@test size(data_1.episode_id) == (batch_size,)
@test size(data_1.episode_return) == (batch_size,)
- @test typeof(data_1.state) == Array{UInt8, 4}
- @test typeof(data_1.next_state) == Array{UInt8, 4}
+ @test typeof(data_1.state) == Array{UInt8,4}
+ @test typeof(data_1.next_state) == Array{UInt8,4}
@test typeof(data_1.action) == Vector{Int64}
@test typeof(data_1.next_action) == Vector{Int64}
@test typeof(data_1.reward) == Vector{Float32}
@@ -39,4 +39,4 @@
take!(ds)
data_2 = take!(ds)
-end
\ No newline at end of file
+end
diff --git a/src/ReinforcementLearningDatasets/test/rl_unplugged_dm.jl b/src/ReinforcementLearningDatasets/test/rl_unplugged_dm.jl
index 897ce16dd..5c3c7ded7 100644
--- a/src/ReinforcementLearningDatasets/test/rl_unplugged_dm.jl
+++ b/src/ReinforcementLearningDatasets/test/rl_unplugged_dm.jl
@@ -4,37 +4,40 @@ using Base.Threads
ds = rl_unplugged_dm_dataset(
"fish_swim",
[1, 2];
- type="dm_control_suite",
+ type = "dm_control_suite",
is_shuffle = true,
- shuffle_buffer_size=10_000,
- tf_reader_bufsize=10_000,
- tf_reader_sz=10_000,
- batch_size=256,
- n_preallocations=nthreads()*12
+ shuffle_buffer_size = 10_000,
+ tf_reader_bufsize = 10_000,
+ tf_reader_sz = 10_000,
+ batch_size = 256,
+ n_preallocations = nthreads() * 12,
)
- @test typeof(ds)<:RingBuffer
+ @test typeof(ds) <: RingBuffer
data = take!(ds)
-
+
batch_size = 256
feature_size = ReinforcementLearningDatasets.DM_CONTROL_SUITE_SIZE["fish_swim"]
-
+
@test typeof(data.state) <: NamedTuple
@test typeof(data.next_state) <: NamedTuple
-
+
for feature in keys(feature_size)
if split(feature, "/")[1] != "observation"
if feature != "step_type"
ob_key = Symbol(feature)
- @test size(getfield(data, ob_key)) == (feature_size[feature]..., batch_size,)
+ @test size(getfield(data, ob_key)) ==
+ (feature_size[feature]..., batch_size)
end
else
state = data.state
next_state = data.next_state
- ob_key = Symbol(chop(feature, head=length("observation")+1, tail=0))
- @test size(getfield(state, ob_key)) == (feature_size[feature]...,batch_size)
- @test size(getfield(next_state, ob_key)) == (feature_size[feature]..., batch_size,)
+ ob_key = Symbol(chop(feature, head = length("observation") + 1, tail = 0))
+ @test size(getfield(state, ob_key)) ==
+ (feature_size[feature]..., batch_size)
+ @test size(getfield(next_state, ob_key)) ==
+ (feature_size[feature]..., batch_size)
end
end
end
@@ -43,37 +46,40 @@ using Base.Threads
ds = rl_unplugged_dm_dataset(
"humanoid_corridor",
[1, 2];
- type="dm_locomotion_humanoid",
+ type = "dm_locomotion_humanoid",
is_shuffle = true,
- shuffle_buffer_size=10_000,
- tf_reader_bufsize=10_000,
- tf_reader_sz=10_000,
- batch_size=256,
- n_preallocations=nthreads()*12
+ shuffle_buffer_size = 10_000,
+ tf_reader_bufsize = 10_000,
+ tf_reader_sz = 10_000,
+ batch_size = 256,
+ n_preallocations = nthreads() * 12,
)
- @test typeof(ds)<:RingBuffer
+ @test typeof(ds) <: RingBuffer
data = take!(ds)
-
+
batch_size = 256
feature_size = ReinforcementLearningDatasets.DM_LOCOMOTION_HUMANOID_SIZE
-
+
@test typeof(data.state) <: NamedTuple
@test typeof(data.next_state) <: NamedTuple
-
+
for feature in keys(feature_size)
if split(feature, "/")[1] != "observation"
if feature != "step_type"
ob_key = Symbol(feature)
- @test size(getfield(data, ob_key)) == (feature_size[feature]..., batch_size,)
+ @test size(getfield(data, ob_key)) ==
+ (feature_size[feature]..., batch_size)
end
else
state = data.state
next_state = data.next_state
- ob_key = Symbol(chop(feature, head=length("observation")+1, tail=0))
- @test size(getfield(state, ob_key)) == (feature_size[feature]..., batch_size,)
- @test size(getfield(next_state, ob_key)) == (feature_size[feature]..., batch_size,)
+ ob_key = Symbol(chop(feature, head = length("observation") + 1, tail = 0))
+ @test size(getfield(state, ob_key)) ==
+ (feature_size[feature]..., batch_size)
+ @test size(getfield(next_state, ob_key)) ==
+ (feature_size[feature]..., batch_size)
end
end
end
@@ -82,38 +88,41 @@ using Base.Threads
ds = rl_unplugged_dm_dataset(
"rodent_escape",
[1, 2];
- type="dm_locomotion_rodent",
+ type = "dm_locomotion_rodent",
is_shuffle = true,
- shuffle_buffer_size=10_000,
- tf_reader_bufsize=10_000,
- tf_reader_sz=10_000,
- batch_size=256,
- n_preallocations=nthreads()*12
+ shuffle_buffer_size = 10_000,
+ tf_reader_bufsize = 10_000,
+ tf_reader_sz = 10_000,
+ batch_size = 256,
+ n_preallocations = nthreads() * 12,
)
- @test typeof(ds)<:RingBuffer
+ @test typeof(ds) <: RingBuffer
data = take!(ds)
-
+
batch_size = 256
feature_size = ReinforcementLearningDatasets.DM_LOCOMOTION_RODENT_SIZE
-
+
@test typeof(data.state) <: NamedTuple
@test typeof(data.next_state) <: NamedTuple
-
+
for feature in keys(feature_size)
if split(feature, "/")[1] != "observation"
if feature != "step_type"
ob_key = Symbol(feature)
- @test size(getfield(data, ob_key)) == (feature_size[feature]..., batch_size,)
+ @test size(getfield(data, ob_key)) ==
+ (feature_size[feature]..., batch_size)
end
else
state = data.state
next_state = data.next_state
- ob_key = Symbol(chop(feature, head=length("observation")+1, tail=0))
- @test size(getfield(state, ob_key)) == (feature_size[feature]..., batch_size,)
- @test size(getfield(next_state, ob_key)) == (feature_size[feature]..., batch_size,)
+ ob_key = Symbol(chop(feature, head = length("observation") + 1, tail = 0))
+ @test size(getfield(state, ob_key)) ==
+ (feature_size[feature]..., batch_size)
+ @test size(getfield(next_state, ob_key)) ==
+ (feature_size[feature]..., batch_size)
end
end
end
-end
\ No newline at end of file
+end
diff --git a/src/ReinforcementLearningEnvironments/src/ReinforcementLearningEnvironments.jl b/src/ReinforcementLearningEnvironments/src/ReinforcementLearningEnvironments.jl
index f3d2dd7e9..dfd1f0ee8 100644
--- a/src/ReinforcementLearningEnvironments/src/ReinforcementLearningEnvironments.jl
+++ b/src/ReinforcementLearningEnvironments/src/ReinforcementLearningEnvironments.jl
@@ -4,7 +4,7 @@ using ReinforcementLearningBase
using Random
using Requires
using IntervalSets
-using Base.Threads:@spawn
+using Base.Threads: @spawn
using Markdown
const RLEnvs = ReinforcementLearningEnvironments
@@ -31,9 +31,7 @@ function __init__()
@require OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" include(
"environments/3rd_party/AcrobotEnv.jl",
)
- @require Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" include(
- "plots.jl",
- )
+ @require Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" include("plots.jl")
end
diff --git a/src/ReinforcementLearningEnvironments/src/environments/3rd_party/AcrobotEnv.jl b/src/ReinforcementLearningEnvironments/src/environments/3rd_party/AcrobotEnv.jl
index c371bb80b..eac7957c7 100644
--- a/src/ReinforcementLearningEnvironments/src/environments/3rd_party/AcrobotEnv.jl
+++ b/src/ReinforcementLearningEnvironments/src/environments/3rd_party/AcrobotEnv.jl
@@ -18,23 +18,23 @@ AcrobotEnv(;kwargs...)
- `avail_torque = [T(-1.), T(0.), T(1.)]`
"""
function AcrobotEnv(;
- T=Float64,
- link_length_a=T(1.0),
- link_length_b=T(1.0),
- link_mass_a=T(1.0),
- link_mass_b=T(1.0),
- link_com_pos_a=T(0.5),
- link_com_pos_b=T(0.5),
- link_moi=T(1.0),
- max_torque_noise=T(0.0),
- max_vel_a=T(4 * π),
- max_vel_b=T(9 * π),
- g=T(9.8),
- dt=T(0.2),
- max_steps=200,
- rng=Random.GLOBAL_RNG,
- book_or_nips="book",
- avail_torque=[T(-1.0), T(0.0), T(1.0)],
+ T = Float64,
+ link_length_a = T(1.0),
+ link_length_b = T(1.0),
+ link_mass_a = T(1.0),
+ link_mass_b = T(1.0),
+ link_com_pos_a = T(0.5),
+ link_com_pos_b = T(0.5),
+ link_moi = T(1.0),
+ max_torque_noise = T(0.0),
+ max_vel_a = T(4 * π),
+ max_vel_b = T(9 * π),
+ g = T(9.8),
+ dt = T(0.2),
+ max_steps = 200,
+ rng = Random.GLOBAL_RNG,
+ book_or_nips = "book",
+ avail_torque = [T(-1.0), T(0.0), T(1.0)],
)
params = AcrobotEnvParams{T}(
@@ -81,7 +81,7 @@ RLBase.is_terminated(env::AcrobotEnv) = env.done
RLBase.state(env::AcrobotEnv) = acrobot_observation(env.state)
RLBase.reward(env::AcrobotEnv) = env.reward
-function RLBase.reset!(env::AcrobotEnv{T}) where {T <: Number}
+function RLBase.reset!(env::AcrobotEnv{T}) where {T<:Number}
env.state[:] = T(0.1) * rand(env.rng, T, 4) .- T(0.05)
env.t = 0
env.action = 2
@@ -91,7 +91,7 @@ function RLBase.reset!(env::AcrobotEnv{T}) where {T <: Number}
end
# governing equations as per python gym
-function (env::AcrobotEnv{T})(a) where {T <: Number}
+function (env::AcrobotEnv{T})(a) where {T<:Number}
env.action = a
env.t += 1
torque = env.avail_torque[a]
@@ -137,7 +137,7 @@ function dsdt(du, s_augmented, env::AcrobotEnv, t)
# extract action and state
a = s_augmented[end]
- s = s_augmented[1:end - 1]
+ s = s_augmented[1:end-1]
# writing in standard form
theta1 = s[1]
@@ -201,7 +201,7 @@ function wrap(x, m, M)
while x < m
x = x + diff
end
-return x
+ return x
end
function bound(x, m, M)
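
Since `AcrobotEnv` only becomes available once `OrdinaryDiffEq` is loaded (it sits behind a `@require` block), a quick interaction sketch looks roughly like the following; the keyword values are just two of the defaults documented above:

```julia
using OrdinaryDiffEq                     # triggers the @require block that defines AcrobotEnv
using ReinforcementLearningEnvironments
using ReinforcementLearningBase

env = AcrobotEnv(; T = Float32, max_steps = 200)
reset!(env)
env(2)                                   # index 2 of avail_torque, i.e. zero torque
state(env), reward(env), is_terminated(env)
```
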
diff --git a/src/ReinforcementLearningEnvironments/src/environments/3rd_party/gym.jl b/src/ReinforcementLearningEnvironments/src/environments/3rd_party/gym.jl
index 33ac9ea86..95a1b3e8e 100644
--- a/src/ReinforcementLearningEnvironments/src/environments/3rd_party/gym.jl
+++ b/src/ReinforcementLearningEnvironments/src/environments/3rd_party/gym.jl
@@ -1,13 +1,15 @@
using .PyCall
-function GymEnv(name::String; seed::Union{Int, Nothing}=nothing)
+function GymEnv(name::String; seed::Union{Int,Nothing} = nothing)
if !PyCall.pyexists("gym")
error(
"Cannot import module 'gym'.\n\nIf you did not yet install it, try running\n`ReinforcementLearningEnvironments.install_gym()`\n",
)
end
gym = pyimport_conda("gym", "gym")
- if PyCall.pyexists("d4rl") pyimport("d4rl") end
+ if PyCall.pyexists("d4rl")
+ pyimport("d4rl")
+ end
pyenv = try
gym.make(name)
catch e
@@ -15,7 +17,9 @@ function GymEnv(name::String; seed::Union{Int, Nothing}=nothing)
"Gym environment $name not found.\n\nRun `ReinforcementLearningEnvironments.list_gym_env_names()` to find supported environments.\n",
)
end
- if seed !== nothing pyenv.seed(seed) end
+ if seed !== nothing
+ pyenv.seed(seed)
+ end
obs_space = space_transform(pyenv.observation_space)
act_space = space_transform(pyenv.action_space)
obs_type = if obs_space isa Space{<:Union{Array{<:Interval},Array{<:ZeroTo}}}
@@ -139,8 +143,10 @@ function list_gym_env_names(;
"d4rl.gym_bullet.gym_envs",
"d4rl.pointmaze_bullet.bullet_maze", # yet to include flow and carla
],
-)
- if PyCall.pyexists("d4rl") pyimport("d4rl") end
+)
+ if PyCall.pyexists("d4rl")
+ pyimport("d4rl")
+ end
gym = pyimport("gym")
[x.id for x in gym.envs.registry.all() if split(x.entry_point, ':')[1] in modules]
end
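
For reference, the constructor reformatted above is typically exercised as in this rough sketch; it assumes the Python `gym` package is importable through PyCall, that `CartPole-v1` is registered, and that `rand` is defined for the transformed action space (true for the discrete spaces used here):

```julia
using PyCall                             # the Python `gym` package must be importable
using ReinforcementLearningEnvironments
using ReinforcementLearningBase

env = GymEnv("CartPole-v1"; seed = 123)
reset!(env)
env(rand(action_space(env)))             # one random action
state(env), reward(env), is_terminated(env)
```
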
diff --git a/src/ReinforcementLearningEnvironments/src/environments/3rd_party/open_spiel.jl b/src/ReinforcementLearningEnvironments/src/environments/3rd_party/open_spiel.jl
index c325b65d6..de8c15f14 100644
--- a/src/ReinforcementLearningEnvironments/src/environments/3rd_party/open_spiel.jl
+++ b/src/ReinforcementLearningEnvironments/src/environments/3rd_party/open_spiel.jl
@@ -44,7 +44,7 @@ import .OpenSpiel:
`True` or `False` (instead of `true` or `false`). Another approach is to just
specify parameters in `kwargs` in the Julia style.
"""
-function OpenSpielEnv(name="kuhn_poker"; kwargs...)
+function OpenSpielEnv(name = "kuhn_poker"; kwargs...)
game = load_game(String(name); kwargs...)
state = new_initial_state(game)
OpenSpielEnv(state, game)
@@ -60,7 +60,7 @@ RLBase.current_player(env::OpenSpielEnv) = OpenSpiel.current_player(env.state)
RLBase.chance_player(env::OpenSpielEnv) = convert(Int, OpenSpiel.CHANCE_PLAYER)
function RLBase.players(env::OpenSpielEnv)
- p = 0:(num_players(env.game) - 1)
+ p = 0:(num_players(env.game)-1)
if ChanceStyle(env) === EXPLICIT_STOCHASTIC
(p..., RLBase.chance_player(env))
else
@@ -91,7 +91,7 @@ function RLBase.prob(env::OpenSpielEnv, player)
# @assert player == chance_player(env)
p = zeros(length(action_space(env)))
for (k, v) in chance_outcomes(env.state)
- p[k + 1] = v
+ p[k+1] = v
end
p
end
@@ -102,7 +102,7 @@ function RLBase.legal_action_space_mask(env::OpenSpielEnv, player)
num_distinct_actions(env.game)
mask = BitArray(undef, n)
for a in legal_actions(env.state, player)
- mask[a + 1] = true
+ mask[a+1] = true
end
mask
end
@@ -138,12 +138,16 @@ end
_state(env::OpenSpielEnv, ::RLBase.InformationSet{String}, player) =
information_state_string(env.state, player)
-_state(env::OpenSpielEnv, ::RLBase.InformationSet{Array}, player) =
- reshape(information_state_tensor(env.state, player), reverse(information_state_tensor_shape(env.game))...)
+_state(env::OpenSpielEnv, ::RLBase.InformationSet{Array}, player) = reshape(
+ information_state_tensor(env.state, player),
+ reverse(information_state_tensor_shape(env.game))...,
+)
_state(env::OpenSpielEnv, ::Observation{String}, player) =
observation_string(env.state, player)
-_state(env::OpenSpielEnv, ::Observation{Array}, player) =
- reshape(observation_tensor(env.state, player), reverse(observation_tensor_shape(env.game))...)
+_state(env::OpenSpielEnv, ::Observation{Array}, player) = reshape(
+ observation_tensor(env.state, player),
+ reverse(observation_tensor_shape(env.game))...,
+)
RLBase.state_space(
env::OpenSpielEnv,
@@ -151,16 +155,18 @@ RLBase.state_space(
p,
) = WorldSpace{AbstractString}()
-RLBase.state_space(env::OpenSpielEnv, ::InformationSet{Array},
- p,
-) = Space(
- fill(typemin(Float64)..typemax(Float64), reverse(information_state_tensor_shape(env.game))...),
+RLBase.state_space(env::OpenSpielEnv, ::InformationSet{Array}, p) = Space(
+ fill(
+ typemin(Float64) .. typemax(Float64),
+ reverse(information_state_tensor_shape(env.game))...,
+ ),
)
-RLBase.state_space(env::OpenSpielEnv, ::Observation{Array},
- p,
-) = Space(
- fill(typemin(Float64)..typemax(Float64), reverse(observation_tensor_shape(env.game))...),
+RLBase.state_space(env::OpenSpielEnv, ::Observation{Array}, p) = Space(
+ fill(
+ typemin(Float64) .. typemax(Float64),
+ reverse(observation_tensor_shape(env.game))...,
+ ),
)
Random.seed!(env::OpenSpielEnv, s) = @warn "seed!(OpenSpielEnv) is not supported currently."
@@ -201,7 +207,9 @@ RLBase.RewardStyle(env::OpenSpielEnv) =
reward_model(get_type(env.game)) == OpenSpiel.REWARDS ? RLBase.STEP_REWARD :
RLBase.TERMINAL_REWARD
-RLBase.StateStyle(env::OpenSpielEnv) = (RLBase.InformationSet{String}(),
+RLBase.StateStyle(env::OpenSpielEnv) = (
+ RLBase.InformationSet{String}(),
RLBase.InformationSet{Array}(),
RLBase.Observation{String}(),
- RLBase.Observation{Array}(),)
+ RLBase.Observation{Array}(),
+)
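
The state-style tuple being reformatted here is what downstream code queries to pick a state representation. A small sketch of that surface, assuming OpenSpiel.jl is installed so this `@require`-guarded file is loaded:

```julia
using OpenSpiel                          # activates the OpenSpiel integration
using ReinforcementLearningEnvironments
using ReinforcementLearningBase

env = OpenSpielEnv("kuhn_poker")
StateStyle(env)    # (InformationSet{String}(), InformationSet{Array}(), Observation{String}(), Observation{Array}())
players(env)       # 0-based player ids, plus the chance player for explicit-stochastic games
```
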
diff --git a/src/ReinforcementLearningEnvironments/src/environments/3rd_party/snake.jl b/src/ReinforcementLearningEnvironments/src/environments/3rd_party/snake.jl
index df3a9e87f..02cfdc201 100644
--- a/src/ReinforcementLearningEnvironments/src/environments/3rd_party/snake.jl
+++ b/src/ReinforcementLearningEnvironments/src/environments/3rd_party/snake.jl
@@ -43,7 +43,7 @@ end
RLBase.action_space(env::SnakeGameEnv) = 1:4
RLBase.state(env::SnakeGameEnv) = env.game.board
-RLBase.state_space(env::SnakeGameEnv) = Space(fill(false..true, size(env.game.board)))
+RLBase.state_space(env::SnakeGameEnv) = Space(fill(false .. true, size(env.game.board)))
RLBase.reward(env::SnakeGameEnv{<:Any,SINGLE_AGENT}) =
length(env.game.snakes[]) - env.latest_snakes_length[]
RLBase.reward(env::SnakeGameEnv) = length.(env.game.snakes) .- env.latest_snakes_length
diff --git a/src/ReinforcementLearningEnvironments/src/environments/3rd_party/structs.jl b/src/ReinforcementLearningEnvironments/src/environments/3rd_party/structs.jl
index 83586f4e3..0acd51427 100644
--- a/src/ReinforcementLearningEnvironments/src/environments/3rd_party/structs.jl
+++ b/src/ReinforcementLearningEnvironments/src/environments/3rd_party/structs.jl
@@ -6,7 +6,7 @@ struct GymEnv{T,Ta,To,P} <: AbstractEnv
end
export GymEnv
-mutable struct AtariEnv{IsGrayScale,TerminalOnLifeLoss,N,S <: AbstractRNG} <: AbstractEnv
+mutable struct AtariEnv{IsGrayScale,TerminalOnLifeLoss,N,S<:AbstractRNG} <: AbstractEnv
ale::Ptr{Nothing}
name::String
screens::Tuple{Array{UInt8,N},Array{UInt8,N}} # for max-pooling
@@ -65,7 +65,7 @@ end
export AcrobotEnvParams
-mutable struct AcrobotEnv{T,R <: AbstractRNG} <: AbstractEnv
+mutable struct AcrobotEnv{T,R<:AbstractRNG} <: AbstractEnv
params::AcrobotEnvParams{T}
state::Vector{T}
action::Int
diff --git a/src/ReinforcementLearningEnvironments/src/environments/examples/BitFlippingEnv.jl b/src/ReinforcementLearningEnvironments/src/environments/examples/BitFlippingEnv.jl
index 2c491bf63..74d0056c1 100644
--- a/src/ReinforcementLearningEnvironments/src/environments/examples/BitFlippingEnv.jl
+++ b/src/ReinforcementLearningEnvironments/src/environments/examples/BitFlippingEnv.jl
@@ -41,8 +41,8 @@ end
RLBase.state(env::BitFlippingEnv) = state(env::BitFlippingEnv, Observation{BitArray{1}}())
RLBase.state(env::BitFlippingEnv, ::Observation) = env.state
RLBase.state(env::BitFlippingEnv, ::GoalState) = env.goal_state
-RLBase.state_space(env::BitFlippingEnv, ::Observation) = Space(fill(false..true, env.N))
-RLBase.state_space(env::BitFlippingEnv, ::GoalState) = Space(fill(false..true, env.N))
+RLBase.state_space(env::BitFlippingEnv, ::Observation) = Space(fill(false .. true, env.N))
+RLBase.state_space(env::BitFlippingEnv, ::GoalState) = Space(fill(false .. true, env.N))
RLBase.is_terminated(env::BitFlippingEnv) =
(env.state == env.goal_state) || (env.t >= env.max_steps)
diff --git a/src/ReinforcementLearningEnvironments/src/environments/examples/GraphShortestPathEnv.jl b/src/ReinforcementLearningEnvironments/src/environments/examples/GraphShortestPathEnv.jl
index fd8c5af49..dd0cc04e9 100644
--- a/src/ReinforcementLearningEnvironments/src/environments/examples/GraphShortestPathEnv.jl
+++ b/src/ReinforcementLearningEnvironments/src/environments/examples/GraphShortestPathEnv.jl
@@ -5,7 +5,7 @@ using SparseArrays
using LinearAlgebra
-mutable struct GraphShortestPathEnv{G, R} <: AbstractEnv
+mutable struct GraphShortestPathEnv{G,R} <: AbstractEnv
graph::G
pos::Int
goal::Int
@@ -31,7 +31,12 @@ Quoted **A.3** in the paper [Decision Transformer: Reinforcement Learning vi
> lengths and maximizing them corresponds to generating shortest paths.
"""
-function GraphShortestPathEnv(rng=Random.GLOBAL_RNG; n=20, sparsity=0.1, max_steps=10)
+function GraphShortestPathEnv(
+ rng = Random.GLOBAL_RNG;
+ n = 20,
+ sparsity = 0.1,
+ max_steps = 10,
+)
graph = sprand(rng, Bool, n, n, sparsity) .| I(n)
goal = rand(rng, 1:n)
@@ -55,7 +60,8 @@ RLBase.state_space(env::GraphShortestPathEnv) = axes(env.graph, 2)
RLBase.action_space(env::GraphShortestPathEnv) = axes(env.graph, 2)
RLBase.legal_action_space(env::GraphShortestPathEnv) = (env.graph[:, env.pos]).nzind
RLBase.reward(env::GraphShortestPathEnv) = env.reward
-RLBase.is_terminated(env::GraphShortestPathEnv) = env.pos == env.goal || env.step >= env.max_steps
+RLBase.is_terminated(env::GraphShortestPathEnv) =
+ env.pos == env.goal || env.step >= env.max_steps
function RLBase.reset!(env::GraphShortestPathEnv)
env.step = 0
@@ -144,4 +150,4 @@ barplot(1:10, [sum(h[1].steps .== i) for i in 1:10]) # random walk
# 10 ┤■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■ 769
# └ ┘
#
-=#
\ No newline at end of file
+=#
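
A rough random-walk sketch against the environment constructed above (sizes chosen only for illustration; the step method itself is defined elsewhere in this file):

```julia
using Random
using ReinforcementLearningEnvironments
using ReinforcementLearningBase

env = GraphShortestPathEnv(MersenneTwister(1); n = 10, sparsity = 0.3)
reset!(env)
while !is_terminated(env)
    env(rand(legal_action_space(env)))   # hop to a random reachable vertex
end
reward(env)                              # negative path length, per the quoted paper
```
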
diff --git a/src/ReinforcementLearningEnvironments/src/environments/examples/MountainCarEnv.jl b/src/ReinforcementLearningEnvironments/src/environments/examples/MountainCarEnv.jl
index 5cafab01d..18094e54a 100644
--- a/src/ReinforcementLearningEnvironments/src/environments/examples/MountainCarEnv.jl
+++ b/src/ReinforcementLearningEnvironments/src/environments/examples/MountainCarEnv.jl
@@ -77,7 +77,7 @@ function MountainCarEnv(;
env = MountainCarEnv(
params,
action_space,
- Space([params.min_pos..params.max_pos, -params.max_speed..params.max_speed]),
+ Space([params.min_pos .. params.max_pos, -params.max_speed .. params.max_speed]),
zeros(T, 2),
rand(action_space),
false,
diff --git a/src/ReinforcementLearningEnvironments/src/environments/examples/PendulumEnv.jl b/src/ReinforcementLearningEnvironments/src/environments/examples/PendulumEnv.jl
index 72dea511c..bacdc04e1 100644
--- a/src/ReinforcementLearningEnvironments/src/environments/examples/PendulumEnv.jl
+++ b/src/ReinforcementLearningEnvironments/src/environments/examples/PendulumEnv.jl
@@ -53,7 +53,7 @@ function PendulumEnv(;
rng = Random.GLOBAL_RNG,
)
high = T.([1, 1, max_speed])
- action_space = continuous ? -2.0..2.0 : Base.OneTo(n_actions)
+ action_space = continuous ? -2.0 .. 2.0 : Base.OneTo(n_actions)
env = PendulumEnv(
PendulumEnvParams(max_speed, max_torque, g, m, l, dt, max_steps),
action_space,
diff --git a/src/ReinforcementLearningEnvironments/src/environments/examples/PigEnv.jl b/src/ReinforcementLearningEnvironments/src/environments/examples/PigEnv.jl
index 1026ac8cc..0f38fb624 100644
--- a/src/ReinforcementLearningEnvironments/src/environments/examples/PigEnv.jl
+++ b/src/ReinforcementLearningEnvironments/src/environments/examples/PigEnv.jl
@@ -36,7 +36,7 @@ RLBase.prob(env::PigEnv, ::ChancePlayer) = fill(1 / 6, 6) # TODO: uniform distr
RLBase.state(env::PigEnv, ::Observation{Vector{Int}}, p) = env.scores
RLBase.state_space(env::PigEnv, ::Observation, p) =
- Space([0..(PIG_TARGET_SCORE + PIG_N_SIDES - 1) for _ in env.scores])
+ Space([0 .. (PIG_TARGET_SCORE + PIG_N_SIDES - 1) for _ in env.scores])
RLBase.is_terminated(env::PigEnv) = any(s >= PIG_TARGET_SCORE for s in env.scores)
diff --git a/src/ReinforcementLearningEnvironments/src/environments/examples/SpeakerListenerEnv.jl b/src/ReinforcementLearningEnvironments/src/environments/examples/SpeakerListenerEnv.jl
index f4b9572b8..f42d7217b 100644
--- a/src/ReinforcementLearningEnvironments/src/environments/examples/SpeakerListenerEnv.jl
+++ b/src/ReinforcementLearningEnvironments/src/environments/examples/SpeakerListenerEnv.jl
@@ -7,9 +7,9 @@ mutable struct SpeakerListenerEnv{T<:Vector{Float64}} <: AbstractEnv
player_pos::T
landmarks_pos::Vector{T}
landmarks_num::Int
- ϵ
- damping
- max_accel
+ ϵ::Any
+ damping::Any
+ max_accel::Any
space_dim::Int
init_step::Int
play_step::Int
@@ -46,7 +46,8 @@ function SpeakerListenerEnv(;
max_accel = 0.5,
space_dim::Int = 2,
max_steps::Int = 50,
- continuous::Bool = true)
+ continuous::Bool = true,
+)
SpeakerListenerEnv(
zeros(N),
zeros(N),
@@ -74,21 +75,24 @@ function RLBase.reset!(env::SpeakerListenerEnv)
env.landmarks_pos = [zeros(env.space_dim) for _ in Base.OneTo(env.landmarks_num)]
end
-RLBase.is_terminated(env::SpeakerListenerEnv) = (reward(env) > - env.ϵ) || (env.play_step > env.max_steps)
+RLBase.is_terminated(env::SpeakerListenerEnv) =
+ (reward(env) > -env.ϵ) || (env.play_step > env.max_steps)
RLBase.players(::SpeakerListenerEnv) = (:Speaker, :Listener, CHANCE_PLAYER)
-RLBase.state(env::SpeakerListenerEnv, ::Observation{Any}, players::Tuple) = Dict(p => state(env, p) for p in players)
+RLBase.state(env::SpeakerListenerEnv, ::Observation{Any}, players::Tuple) =
+ Dict(p => state(env, p) for p in players)
-RLBase.state(env::SpeakerListenerEnv, ::Observation{Any}, player::Symbol) =
-    # for speaker, it can observe the target and help the listener reach it.
+RLBase.state(env::SpeakerListenerEnv, ::Observation{Any}, player::Symbol) =
+# for speaker, it can observe the target and help the listener reach it.
if player == :Speaker
env.target
- # for listener, it can observe current velocity, relative positions of landmarks, and speaker's conveyed information.
+ # for listener, it can observe current velocity, relative positions of landmarks, and speaker's conveyed information.
elseif player == :Listener
vcat(
env.player_vel...,
(
- vcat((landmark_pos .- env.player_pos)...) for landmark_pos in env.landmarks_pos
+ vcat((landmark_pos .- env.player_pos)...) for
+ landmark_pos in env.landmarks_pos
)...,
env.content...,
)
@@ -96,47 +100,60 @@ RLBase.state(env::SpeakerListenerEnv, ::Observation{Any}, player::Symbol) =
@error "No player $player."
end
-RLBase.state(env::SpeakerListenerEnv, ::Observation{Any}, ::ChancePlayer) = vcat(env.landmarks_pos, [env.player_pos])
+RLBase.state(env::SpeakerListenerEnv, ::Observation{Any}, ::ChancePlayer) =
+ vcat(env.landmarks_pos, [env.player_pos])
-RLBase.state_space(env::SpeakerListenerEnv, ::Observation{Any}, players::Tuple) =
+RLBase.state_space(env::SpeakerListenerEnv, ::Observation{Any}, players::Tuple) =
Space(Dict(player => state_space(env, player) for player in players))
-RLBase.state_space(env::SpeakerListenerEnv, ::Observation{Any}, player::Symbol) =
+RLBase.state_space(env::SpeakerListenerEnv, ::Observation{Any}, player::Symbol) =
if player == :Speaker
# env.target
- Space([[0., 1.] for _ in Base.OneTo(env.landmarks_num)])
+ Space([[0.0, 1.0] for _ in Base.OneTo(env.landmarks_num)])
elseif player == :Listener
- Space(vcat(
- # relative positions of landmarks, no bounds.
- (vcat(
- Space([ClosedInterval(-Inf, Inf) for _ in Base.OneTo(env.space_dim)])...
- ) for _ in Base.OneTo(env.landmarks_num + 1))...,
- # communication content from `Speaker`
- [[0., 1.] for _ in Base.OneTo(env.landmarks_num)],
- ))
+ Space(
+ vcat(
+ # relative positions of landmarks, no bounds.
+ (
+ vcat(
+ Space([
+ ClosedInterval(-Inf, Inf) for _ in Base.OneTo(env.space_dim)
+ ])...,
+ ) for _ in Base.OneTo(env.landmarks_num + 1)
+ )...,
+ # communication content from `Speaker`
+ [[0.0, 1.0] for _ in Base.OneTo(env.landmarks_num)],
+ ),
+ )
else
@error "No player $player."
end
-RLBase.state_space(env::SpeakerListenerEnv, ::Observation{Any}, ::ChancePlayer) =
- Space(
- vcat(
- # landmarks' positions
- (Space([ClosedInterval(-1, 1) for _ in Base.OneTo(env.space_dim)]) for _ in Base.OneTo(env.landmarks_num))...,
- # player's position, no bounds.
- Space([ClosedInterval(-Inf, Inf) for _ in Base.OneTo(env.space_dim)]),
- )
- )
-
-RLBase.action_space(env::SpeakerListenerEnv, players::Tuple) =
- Space(Dict(p => action_space(env, p) for p in players))
-
-RLBase.action_space(env::SpeakerListenerEnv, player::Symbol) =
+RLBase.state_space(env::SpeakerListenerEnv, ::Observation{Any}, ::ChancePlayer) = Space(
+ vcat(
+ # landmarks' positions
+ (
+ Space([ClosedInterval(-1, 1) for _ in Base.OneTo(env.space_dim)]) for
+ _ in Base.OneTo(env.landmarks_num)
+ )...,
+ # player's position, no bounds.
+ Space([ClosedInterval(-Inf, Inf) for _ in Base.OneTo(env.space_dim)]),
+ ),
+)
+
+RLBase.action_space(env::SpeakerListenerEnv, players::Tuple) =
+ Space(Dict(p => action_space(env, p) for p in players))
+
+RLBase.action_space(env::SpeakerListenerEnv, player::Symbol) =
if player == :Speaker
- env.continuous ? Space([ClosedInterval(0, 1) for _ in Base.OneTo(env.landmarks_num)]) : Space([ZeroTo(1) for _ in Base.OneTo(env.landmarks_num)])
+ env.continuous ?
+ Space([ClosedInterval(0, 1) for _ in Base.OneTo(env.landmarks_num)]) :
+ Space([ZeroTo(1) for _ in Base.OneTo(env.landmarks_num)])
elseif player == :Listener
# there are two directions in each dimension.
- env.continuous ? Space([ClosedInterval(0, 1) for _ in Base.OneTo(2 * env.space_dim)]) : Space([ZeroTo(1) for _ in Base.OneTo(2 * env.space_dim)])
+ env.continuous ?
+ Space([ClosedInterval(0, 1) for _ in Base.OneTo(2 * env.space_dim)]) :
+ Space([ZeroTo(1) for _ in Base.OneTo(2 * env.space_dim)])
else
@error "No player $player."
end
@@ -157,7 +174,7 @@ function (env::SpeakerListenerEnv)(action, ::ChancePlayer)
env.player_pos = action
else
@assert action in Base.OneTo(env.landmarks_num) "The target should be assigned to one of the landmarks."
- env.target[action] = 1.
+ env.target[action] = 1.0
end
end
@@ -176,7 +193,7 @@ function (env::SpeakerListenerEnv)(action::Vector, player::Symbol)
elseif player == :Listener
# update velocity, here env.damping is for simulation physical rule.
action = round.(action)
- acceleration = [action[2 * i] - action[2 * i - 1] for i in Base.OneTo(env.space_dim)]
+ acceleration = [action[2*i] - action[2*i-1] for i in Base.OneTo(env.space_dim)]
env.player_vel .*= (1 - env.damping)
env.player_vel .+= (acceleration * env.max_accel)
# update position
@@ -190,14 +207,14 @@ RLBase.reward(::SpeakerListenerEnv, ::ChancePlayer) = -Inf
function RLBase.reward(env::SpeakerListenerEnv, p)
if sum(env.target) == 1
- goal = findfirst(env.target .== 1.)
+ goal = findfirst(env.target .== 1.0)
-sum((env.landmarks_pos[goal] .- env.player_pos) .^ 2)
else
-Inf
end
end
-RLBase.current_player(env::SpeakerListenerEnv) =
+RLBase.current_player(env::SpeakerListenerEnv) =
if env.init_step < env.landmarks_num + 2
CHANCE_PLAYER
else
diff --git a/src/ReinforcementLearningEnvironments/src/environments/examples/StockTradingEnv.jl b/src/ReinforcementLearningEnvironments/src/environments/examples/StockTradingEnv.jl
index 6a3f6a6f6..0dc94ec4b 100644
--- a/src/ReinforcementLearningEnvironments/src/environments/examples/StockTradingEnv.jl
+++ b/src/ReinforcementLearningEnvironments/src/environments/examples/StockTradingEnv.jl
@@ -2,12 +2,12 @@ export StockTradingEnv, StockTradingEnvWithTurbulence
using Pkg.Artifacts
using DelimitedFiles
-using LinearAlgebra:dot
+using LinearAlgebra: dot
using IntervalSets
function load_default_stock_data(s)
if s == "prices.csv" || s == "features.csv"
- data, _ = readdlm(joinpath(artifact"stock_trading_data", s), ',', header=true)
+ data, _ = readdlm(joinpath(artifact"stock_trading_data", s), ',', header = true)
collect(data')
elseif s == "turbulence.csv"
readdlm(joinpath(artifact"stock_trading_data", "turbulence.csv")) |> vec
@@ -16,7 +16,8 @@ function load_default_stock_data(s)
end
end
-mutable struct StockTradingEnv{F<:AbstractMatrix{Float64}, P<:AbstractMatrix{Float64}} <: AbstractEnv
+mutable struct StockTradingEnv{F<:AbstractMatrix{Float64},P<:AbstractMatrix{Float64}} <:
+ AbstractEnv
features::F
prices::P
HMAX_NORMALIZE::Float32
@@ -48,14 +49,14 @@ This environment is originally provided in [Deep Reinforcement Learning for Auto
- `initial_account_balance=1_000_000`.
"""
function StockTradingEnv(;
- initial_account_balance=1_000_000f0,
- features=nothing,
- prices=nothing,
- first_day=nothing,
- last_day=nothing,
- HMAX_NORMALIZE = 100f0,
+ initial_account_balance = 1_000_000.0f0,
+ features = nothing,
+ prices = nothing,
+ first_day = nothing,
+ last_day = nothing,
+ HMAX_NORMALIZE = 100.0f0,
TRANSACTION_FEE_PERCENT = 0.001f0,
- REWARD_SCALING = 1f-4
+ REWARD_SCALING = 1f-4,
)
prices = isnothing(prices) ? load_default_stock_data("prices.csv") : prices
features = isnothing(features) ? load_default_stock_data("features.csv") : features
@@ -77,11 +78,11 @@ function StockTradingEnv(;
REWARD_SCALING,
initial_account_balance,
state,
- 0f0,
+ 0.0f0,
day,
first_day,
last_day,
- 0f0
+ 0.0f0,
)
_balance(env)[] = initial_account_balance
@@ -108,10 +109,10 @@ function (env::StockTradingEnv)(actions)
# then buy
# better to shuffle?
- for (i,b) in enumerate(actions)
+ for (i, b) in enumerate(actions)
if b > 0
max_buy = div(_balance(env)[], _prices(env)[i])
- buy = min(b*env.HMAX_NORMALIZE, max_buy)
+ buy = min(b * env.HMAX_NORMALIZE, max_buy)
_holds(env)[i] += buy
deduction = buy * _prices(env)[i]
cost = deduction * env.TRANSACTION_FEE_PERCENT
@@ -136,12 +137,13 @@ function RLBase.reset!(env::StockTradingEnv)
_balance(env)[] = env.initial_account_balance
_prices(env) .= @view env.prices[:, env.day]
_features(env) .= @view env.features[:, env.day]
- env.total_cost = 0.
- env.daily_reward = 0.
+ env.total_cost = 0.0
+ env.daily_reward = 0.0
end
-RLBase.state_space(env::StockTradingEnv) = Space(fill(-Inf32..Inf32, length(state(env))))
-RLBase.action_space(env::StockTradingEnv) = Space(fill(-1f0..1f0, length(_holds(env))))
+RLBase.state_space(env::StockTradingEnv) = Space(fill(-Inf32 .. Inf32, length(state(env))))
+RLBase.action_space(env::StockTradingEnv) =
+ Space(fill(-1.0f0 .. 1.0f0, length(_holds(env))))
RLBase.ChanceStyle(::StockTradingEnv) = DETERMINISTIC
@@ -154,16 +156,16 @@ struct StockTradingEnvWithTurbulence{E<:StockTradingEnv} <: AbstractEnvWrapper
end
function StockTradingEnvWithTurbulence(;
- turbulence_threshold=140.,
- turbulences=nothing,
- kw...
+ turbulence_threshold = 140.0,
+ turbulences = nothing,
+ kw...,
)
turbulences = isnothing(turbulences) && load_default_stock_data("turbulence.csv")
StockTradingEnvWithTurbulence(
- StockTradingEnv(;kw...),
+ StockTradingEnv(; kw...),
turbulences,
- turbulence_threshold
+ turbulence_threshold,
)
end
@@ -172,4 +174,4 @@ function (w::StockTradingEnvWithTurbulence)(actions)
actions .= ifelse.(actions .< 0, -Inf32, 0)
end
w.env(actions)
-end
\ No newline at end of file
+end
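
A short interaction sketch for the constructor reformatted above; it relies on the bundled `stock_trading_data` artifact being downloadable and on `rand` being defined for the interval-valued action space:

```julia
using ReinforcementLearningEnvironments
using ReinforcementLearningBase

env = StockTradingEnv()                  # prices/features default to the bundled artifact
reset!(env)
a = rand(action_space(env))              # one value in -1.0f0 .. 1.0f0 per stock
env(a)
reward(env), is_terminated(env)
```
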
diff --git a/src/ReinforcementLearningEnvironments/src/environments/examples/TicTacToeEnv.jl b/src/ReinforcementLearningEnvironments/src/environments/examples/TicTacToeEnv.jl
index 455488619..d8301904e 100644
--- a/src/ReinforcementLearningEnvironments/src/environments/examples/TicTacToeEnv.jl
+++ b/src/ReinforcementLearningEnvironments/src/environments/examples/TicTacToeEnv.jl
@@ -76,7 +76,7 @@ RLBase.players(env::TicTacToeEnv) = (CROSS, NOUGHT)
RLBase.state(env::TicTacToeEnv, ::Observation{BitArray{3}}, p) = env.board
RLBase.state_space(env::TicTacToeEnv, ::Observation{BitArray{3}}, p) =
- Space(fill(false..true, 3, 3, 3))
+ Space(fill(false .. true, 3, 3, 3))
RLBase.state(env::TicTacToeEnv, ::Observation{Int}, p) =
get_tic_tac_toe_state_info()[env].index
RLBase.state_space(env::TicTacToeEnv, ::Observation{Int}, p) =
diff --git a/src/ReinforcementLearningEnvironments/src/environments/non_interactive/pendulum.jl b/src/ReinforcementLearningEnvironments/src/environments/non_interactive/pendulum.jl
index 8e6aef27d..a1621f1e5 100644
--- a/src/ReinforcementLearningEnvironments/src/environments/non_interactive/pendulum.jl
+++ b/src/ReinforcementLearningEnvironments/src/environments/non_interactive/pendulum.jl
@@ -70,7 +70,7 @@ RLBase.reward(env::PendulumNonInteractiveEnv) = 0
RLBase.is_terminated(env::PendulumNonInteractiveEnv) = env.done
RLBase.state(env::PendulumNonInteractiveEnv) = env.state
RLBase.state_space(env::PendulumNonInteractiveEnv{T}) where {T} =
- Space([typemin(T)..typemax(T), typemin(T)..typemax(T)])
+ Space([typemin(T) .. typemax(T), typemin(T) .. typemax(T)])
function RLBase.reset!(env::PendulumNonInteractiveEnv{Fl}) where {Fl}
env.state .= (Fl(2 * pi) * rand(env.rng, Fl), randn(env.rng, Fl))
diff --git a/src/ReinforcementLearningEnvironments/src/environments/wrappers/ActionTransformedEnv.jl b/src/ReinforcementLearningEnvironments/src/environments/wrappers/ActionTransformedEnv.jl
index 47efe8628..831f80f7d 100644
--- a/src/ReinforcementLearningEnvironments/src/environments/wrappers/ActionTransformedEnv.jl
+++ b/src/ReinforcementLearningEnvironments/src/environments/wrappers/ActionTransformedEnv.jl
@@ -13,15 +13,14 @@ end
`legal_action_space(env)`. `action_mapping` will be applied to `action` before
feeding it into `env`.
"""
-ActionTransformedEnv(env; action_mapping = identity, action_space_mapping = identity) =
+ActionTransformedEnv(env; action_mapping = identity, action_space_mapping = identity) =
ActionTransformedEnv(env, action_mapping, action_space_mapping)
-Base.copy(env::ActionTransformedEnv) =
- ActionTransformedEnv(
- copy(env.env),
- action_mapping = env.action_mapping,
- action_space_mapping = env.action_space_mapping
- )
+Base.copy(env::ActionTransformedEnv) = ActionTransformedEnv(
+ copy(env.env),
+ action_mapping = env.action_mapping,
+ action_space_mapping = env.action_space_mapping,
+)
RLBase.action_space(env::ActionTransformedEnv, args...) =
env.action_space_mapping(action_space(env.env, args...))
diff --git a/src/ReinforcementLearningEnvironments/src/environments/wrappers/DefaultStateStyle.jl b/src/ReinforcementLearningEnvironments/src/environments/wrappers/DefaultStateStyle.jl
index 5aff4a669..d75af9c34 100644
--- a/src/ReinforcementLearningEnvironments/src/environments/wrappers/DefaultStateStyle.jl
+++ b/src/ReinforcementLearningEnvironments/src/environments/wrappers/DefaultStateStyle.jl
@@ -13,11 +13,12 @@ DefaultStateStyleEnv{S}(env::E) where {S,E} = DefaultStateStyleEnv{S,E}(env)
RLBase.DefaultStateStyle(::DefaultStateStyleEnv{S}) where {S} = S
-Base.copy(env::DefaultStateStyleEnv{S}) where S = DefaultStateStyleEnv{S}(copy(env.env))
+Base.copy(env::DefaultStateStyleEnv{S}) where {S} = DefaultStateStyleEnv{S}(copy(env.env))
-RLBase.state(env::DefaultStateStyleEnv{S}) where S = state(env.env, S)
+RLBase.state(env::DefaultStateStyleEnv{S}) where {S} = state(env.env, S)
RLBase.state(env::DefaultStateStyleEnv, ss::RLBase.AbstractStateStyle) = state(env.env, ss)
-RLBase.state(env::DefaultStateStyleEnv{S}, player) where S = state(env.env, S, player)
+RLBase.state(env::DefaultStateStyleEnv{S}, player) where {S} = state(env.env, S, player)
-RLBase.state_space(env::DefaultStateStyleEnv{S}) where S = state_space(env.env, S)
-RLBase.state_space(env::DefaultStateStyleEnv, ss::RLBase.AbstractStateStyle) = state_space(env.env, ss)
\ No newline at end of file
+RLBase.state_space(env::DefaultStateStyleEnv{S}) where {S} = state_space(env.env, S)
+RLBase.state_space(env::DefaultStateStyleEnv, ss::RLBase.AbstractStateStyle) =
+ state_space(env.env, ss)
diff --git a/src/ReinforcementLearningEnvironments/src/environments/wrappers/SequentialEnv.jl b/src/ReinforcementLearningEnvironments/src/environments/wrappers/SequentialEnv.jl
index 4f18af426..2a88903dd 100644
--- a/src/ReinforcementLearningEnvironments/src/environments/wrappers/SequentialEnv.jl
+++ b/src/ReinforcementLearningEnvironments/src/environments/wrappers/SequentialEnv.jl
@@ -9,7 +9,7 @@ mutable struct SequentialEnv{E<:AbstractEnv} <: AbstractEnvWrapper
env::E
current_player_idx::Int
actions::Vector{Any}
- function SequentialEnv(env::T) where T<:AbstractEnv
+ function SequentialEnv(env::T) where {T<:AbstractEnv}
@assert DynamicStyle(env) === SIMULTANEOUS "The SequentialEnv wrapper can only be applied to SIMULTANEOUS environments"
new{T}(env, 1, Vector{Any}(undef, length(players(env))))
end
@@ -32,7 +32,8 @@ end
RLBase.reward(env::SequentialEnv) = reward(env, current_player(env))
-RLBase.reward(env::SequentialEnv, player) = current_player(env) == 1 ? reward(env.env, player) : 0
+RLBase.reward(env::SequentialEnv, player) =
+ current_player(env) == 1 ? reward(env.env, player) : 0
function (env::SequentialEnv)(action)
env.actions[env.current_player_idx] = action
@@ -43,4 +44,3 @@ function (env::SequentialEnv)(action)
env.current_player_idx += 1
end
end
-
diff --git a/src/ReinforcementLearningEnvironments/src/environments/wrappers/StateCachedEnv.jl b/src/ReinforcementLearningEnvironments/src/environments/wrappers/StateCachedEnv.jl
index e8626a3b8..97e18e928 100644
--- a/src/ReinforcementLearningEnvironments/src/environments/wrappers/StateCachedEnv.jl
+++ b/src/ReinforcementLearningEnvironments/src/environments/wrappers/StateCachedEnv.jl
@@ -6,7 +6,7 @@ the next interaction with `env`. This function is useful because some
environments are stateful during each `state(env)`. For example:
`StateTransformedEnv(StackFrames(...))`.
"""
-mutable struct StateCachedEnv{S,E <: AbstractEnv} <: AbstractEnvWrapper
+mutable struct StateCachedEnv{S,E<:AbstractEnv} <: AbstractEnvWrapper
s::S
env::E
is_state_cached::Bool
diff --git a/src/ReinforcementLearningEnvironments/src/environments/wrappers/StateTransformedEnv.jl b/src/ReinforcementLearningEnvironments/src/environments/wrappers/StateTransformedEnv.jl
index dfe90bddd..840c9ecdb 100644
--- a/src/ReinforcementLearningEnvironments/src/environments/wrappers/StateTransformedEnv.jl
+++ b/src/ReinforcementLearningEnvironments/src/environments/wrappers/StateTransformedEnv.jl
@@ -12,11 +12,11 @@ end
`state_mapping` will be applied on the original state when calling `state(env)`,
and similarly `state_space_mapping` will be applied when calling `state_space(env)`.
"""
-StateTransformedEnv(env; state_mapping=identity, state_space_mapping=identity) =
+StateTransformedEnv(env; state_mapping = identity, state_space_mapping = identity) =
StateTransformedEnv(env, state_mapping, state_space_mapping)
RLBase.state(env::StateTransformedEnv, args...; kwargs...) =
env.state_mapping(state(env.env, args...; kwargs...))
-RLBase.state_space(env::StateTransformedEnv, args...; kwargs...) =
+RLBase.state_space(env::StateTransformedEnv, args...; kwargs...) =
env.state_space_mapping(state_space(env.env, args...; kwargs...))
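
The docstring above is easiest to read next to a tiny example. Here is a minimal sketch that wraps `CartPoleEnv` so every observation is converted to `Float32`; the mapping is arbitrary and chosen only for illustration:

```julia
using ReinforcementLearningEnvironments
using ReinforcementLearningBase

env = StateTransformedEnv(
    CartPoleEnv();
    state_mapping = s -> Float32.(s),    # applied on every `state(env)` call
)
eltype(state(env))                       # Float32
```
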
diff --git a/src/ReinforcementLearningEnvironments/src/plots.jl b/src/ReinforcementLearningEnvironments/src/plots.jl
index a2c3ce7bc..0e9e5cb61 100644
--- a/src/ReinforcementLearningEnvironments/src/plots.jl
+++ b/src/ReinforcementLearningEnvironments/src/plots.jl
@@ -8,36 +8,36 @@ function plot(env::CartPoleEnv; kwargs...)
xthreshold = env.params.xthreshold
# set the frame
plot(
- xlims=(-xthreshold, xthreshold),
- ylims=(-.1, l + 0.1),
- legend=false,
- border=:none,
+ xlims = (-xthreshold, xthreshold),
+ ylims = (-.1, l + 0.1),
+ legend = false,
+ border = :none,
)
# plot the cart
- plot!([x - 0.5, x - 0.5, x + 0.5, x + 0.5], [-.05, 0, 0, -.05];
- seriestype=:shape,
- )
+ plot!([x - 0.5, x - 0.5, x + 0.5, x + 0.5], [-.05, 0, 0, -.05]; seriestype = :shape)
# plot the pole
- plot!([x, x + l * sin(theta)], [0, l * cos(theta)];
- linewidth=3,
- )
+ plot!([x, x + l * sin(theta)], [0, l * cos(theta)]; linewidth = 3)
# plot the arrow
- plot!([x + (a == 1) - 0.5, x + 1.4 * (a == 1)-0.7], [ -.025, -.025];
- linewidth=3,
- arrow=true,
- color=2,
+ plot!(
+ [x + (a == 1) - 0.5, x + 1.4 * (a == 1) - 0.7],
+ [-.025, -.025];
+ linewidth = 3,
+ arrow = true,
+ color = 2,
)
# if done plot pink circle in top right
if d
- plot!([xthreshold - 0.2], [l];
- marker=:circle,
- markersize=20,
- markerstrokewidth=0.,
- color=:pink,
+ plot!(
+ [xthreshold - 0.2],
+ [l];
+ marker = :circle,
+ markersize = 20,
+ markerstrokewidth = 0.0,
+ color = :pink,
)
end
-
- plot!(;kwargs...)
+
+ plot!(; kwargs...)
end
@@ -51,10 +51,10 @@ function plot(env::MountainCarEnv; kwargs...)
d = env.done
plot(
- xlims=(env.params.min_pos - 0.1, env.params.max_pos + 0.2),
- ylims=(-.1, height(env.params.max_pos) + 0.2),
- legend=false,
- border=:none,
+ xlims = (env.params.min_pos - 0.1, env.params.max_pos + 0.2),
+ ylims = (-.1, height(env.params.max_pos) + 0.2),
+ legend = false,
+ border = :none,
)
# plot the terrain
xs = LinRange(env.params.min_pos, env.params.max_pos, 100)
@@ -72,17 +72,19 @@ function plot(env::MountainCarEnv; kwargs...)
ys .+= clearance
xs, ys = rotate(xs, ys, θ)
xs, ys = translate(xs, ys, [x, height(x)])
- plot!(xs, ys; seriestype=:shape)
+ plot!(xs, ys; seriestype = :shape)
# if done plot pink circle in top right
if d
- plot!([xthreshold - 0.2], [l];
- marker=:circle,
- markersize=20,
- markerstrokewidth=0.,
- color=:pink,
+ plot!(
+ [xthreshold - 0.2],
+ [l];
+ marker = :circle,
+ markersize = 20,
+ markerstrokewidth = 0.0,
+ color = :pink,
)
end
- plot!(;kwargs...)
- end
+ plot!(; kwargs...)
+end
diff --git a/src/ReinforcementLearningEnvironments/test/environments/3rd_party/gym.jl b/src/ReinforcementLearningEnvironments/test/environments/3rd_party/gym.jl
index c7b3ab787..9c8b44de0 100644
--- a/src/ReinforcementLearningEnvironments/test/environments/3rd_party/gym.jl
+++ b/src/ReinforcementLearningEnvironments/test/environments/3rd_party/gym.jl
@@ -1,10 +1,6 @@
@testset "gym envs" begin
gym_env_names = ReinforcementLearningEnvironments.list_gym_env_names(
- modules = [
- "gym.envs.algorithmic",
- "gym.envs.classic_control",
- "gym.envs.unittest",
- ],
+ modules = ["gym.envs.algorithmic", "gym.envs.classic_control", "gym.envs.unittest"],
) # mujoco, box2d, robotics are not tested here
for x in gym_env_names
diff --git a/src/ReinforcementLearningEnvironments/test/environments/examples/graph_shortest_path_env.jl b/src/ReinforcementLearningEnvironments/test/environments/examples/graph_shortest_path_env.jl
index a912ecc24..e47ce9047 100644
--- a/src/ReinforcementLearningEnvironments/test/environments/examples/graph_shortest_path_env.jl
+++ b/src/ReinforcementLearningEnvironments/test/environments/examples/graph_shortest_path_env.jl
@@ -6,4 +6,3 @@
RLBase.test_runnable!(env)
end
-
diff --git a/src/ReinforcementLearningEnvironments/test/environments/examples/stock_trading_env.jl b/src/ReinforcementLearningEnvironments/test/environments/examples/stock_trading_env.jl
index 7cb138328..c826e0e4e 100644
--- a/src/ReinforcementLearningEnvironments/test/environments/examples/stock_trading_env.jl
+++ b/src/ReinforcementLearningEnvironments/test/environments/examples/stock_trading_env.jl
@@ -5,4 +5,3 @@
RLBase.test_interfaces!(env)
RLBase.test_runnable!(env)
end
-
diff --git a/src/ReinforcementLearningEnvironments/test/environments/wrappers/wrappers.jl b/src/ReinforcementLearningEnvironments/test/environments/wrappers/wrappers.jl
index 208307a67..049defa22 100644
--- a/src/ReinforcementLearningEnvironments/test/environments/wrappers/wrappers.jl
+++ b/src/ReinforcementLearningEnvironments/test/environments/wrappers/wrappers.jl
@@ -1,11 +1,11 @@
@testset "wrappers" begin
@testset "ActionTransformedEnv" begin
- env = TigerProblemEnv(; rng=StableRNG(123))
+ env = TigerProblemEnv(; rng = StableRNG(123))
env′ = ActionTransformedEnv(
env;
- action_space_mapping=x -> Base.OneTo(3),
- action_mapping=i -> action_space(env)[i],
+ action_space_mapping = x -> Base.OneTo(3),
+ action_mapping = i -> action_space(env)[i],
)
RLBase.test_interfaces!(env′)
@@ -14,7 +14,7 @@
@testset "DefaultStateStyleEnv" begin
rng = StableRNG(123)
- env = TigerProblemEnv(; rng=rng)
+ env = TigerProblemEnv(; rng = rng)
S = InternalState{Int}()
env′ = DefaultStateStyleEnv{S}(env)
@test DefaultStateStyle(env′) === S
@@ -35,7 +35,7 @@
@testset "MaxTimeoutEnv" begin
rng = StableRNG(123)
- env = TigerProblemEnv(; rng=rng)
+ env = TigerProblemEnv(; rng = rng)
n = 100
env′ = MaxTimeoutEnv(env, n)
@@ -55,7 +55,7 @@
@testset "RewardOverriddenEnv" begin
rng = StableRNG(123)
- env = TigerProblemEnv(; rng=rng)
+ env = TigerProblemEnv(; rng = rng)
env′ = RewardOverriddenEnv(env, x -> sign(x))
RLBase.test_interfaces!(env′)
@@ -69,7 +69,7 @@
@testset "StateCachedEnv" begin
rng = StableRNG(123)
- env = CartPoleEnv(; rng=rng)
+ env = CartPoleEnv(; rng = rng)
env′ = StateCachedEnv(env)
RLBase.test_interfaces!(env′)
@@ -85,7 +85,7 @@
@testset "StateTransformedEnv" begin
rng = StableRNG(123)
- env = TigerProblemEnv(; rng=rng)
+ env = TigerProblemEnv(; rng = rng)
# S = (:door1, :door2, :door3, :none)
# env′ = StateTransformedEnv(env, state_mapping=s -> s+1)
# RLBase.state_space(env::typeof(env′), ::RLBase.AbstractStateStyle, ::Any) = S
@@ -97,14 +97,14 @@
@testset "StochasticEnv" begin
env = KuhnPokerEnv()
rng = StableRNG(123)
- env′ = StochasticEnv(env; rng=rng)
+ env′ = StochasticEnv(env; rng = rng)
RLBase.test_interfaces!(env′)
RLBase.test_runnable!(env′)
end
@testset "SequentialEnv" begin
- env = RockPaperScissorsEnv()
+ env = RockPaperScissorsEnv()
env′ = SequentialEnv(env)
RLBase.test_interfaces!(env′)
RLBase.test_runnable!(env′)
diff --git a/src/ReinforcementLearningExperiments/deps/build.jl b/src/ReinforcementLearningExperiments/deps/build.jl
index def42b559..6dfde6249 100644
--- a/src/ReinforcementLearningExperiments/deps/build.jl
+++ b/src/ReinforcementLearningExperiments/deps/build.jl
@@ -2,10 +2,11 @@ using Weave
const DEST_DIR = joinpath(@__DIR__, "..", "src", "experiments")
-for (root, dirs, files) in walkdir(joinpath(@__DIR__, "..", "..", "..", "docs", "experiments"))
+for (root, dirs, files) in
+ walkdir(joinpath(@__DIR__, "..", "..", "..", "docs", "experiments"))
for f in files
if splitext(f)[2] == ".jl"
- tangle(joinpath(root,f);informat="script", out_path=DEST_DIR)
+ tangle(joinpath(root, f); informat = "script", out_path = DEST_DIR)
end
end
-end
\ No newline at end of file
+end
diff --git a/src/ReinforcementLearningExperiments/src/ReinforcementLearningExperiments.jl b/src/ReinforcementLearningExperiments/src/ReinforcementLearningExperiments.jl
index 0f615ba3f..2a273443a 100644
--- a/src/ReinforcementLearningExperiments/src/ReinforcementLearningExperiments.jl
+++ b/src/ReinforcementLearningExperiments/src/ReinforcementLearningExperiments.jl
@@ -23,7 +23,6 @@ for f in readdir(EXPERIMENTS_DIR)
end
# dynamic loading environments
-function __init__()
-end
+function __init__() end
end # module
diff --git a/src/ReinforcementLearningZoo/src/algorithms/dqns/common.jl b/src/ReinforcementLearningZoo/src/algorithms/dqns/common.jl
index ddd2c6ace..793eaa897 100644
--- a/src/ReinforcementLearningZoo/src/algorithms/dqns/common.jl
+++ b/src/ReinforcementLearningZoo/src/algorithms/dqns/common.jl
@@ -4,7 +4,10 @@
const PERLearners = Union{PrioritizedDQNLearner,RainbowLearner,IQNLearner}
-function RLBase.update!(learner::Union{DQNLearner,QRDQNLearner,REMDQNLearner,PERLearners}, t::AbstractTrajectory)
+function RLBase.update!(
+ learner::Union{DQNLearner,QRDQNLearner,REMDQNLearner,PERLearners},
+ t::AbstractTrajectory,
+)
length(t[:terminal]) - learner.sampler.n <= learner.min_replay_history && return
learner.update_step += 1
diff --git a/src/ReinforcementLearningZoo/src/algorithms/dqns/dqn.jl b/src/ReinforcementLearningZoo/src/algorithms/dqns/dqn.jl
index b2ebab93c..ad2808f26 100644
--- a/src/ReinforcementLearningZoo/src/algorithms/dqns/dqn.jl
+++ b/src/ReinforcementLearningZoo/src/algorithms/dqns/dqn.jl
@@ -55,7 +55,7 @@ function DQNLearner(;
traces = SARTS,
update_step = 0,
rng = Random.GLOBAL_RNG,
- is_enable_double_DQN::Bool = true
+ is_enable_double_DQN::Bool = true,
) where {Tq,Tt,Tf}
copyto!(approximator, target_approximator)
sampler = NStepBatchSampler{traces}(;
@@ -75,7 +75,7 @@ function DQNLearner(;
sampler,
rng,
0.0f0,
- is_enable_double_DQN
+ is_enable_double_DQN,
)
end
@@ -117,14 +117,14 @@ function RLBase.update!(learner::DQNLearner, batch::NamedTuple)
else
q_values = Qₜ(s′)
end
-
+
if haskey(batch, :next_legal_actions_mask)
l′ = send_to_device(D, batch[:next_legal_actions_mask])
q_values .+= ifelse.(l′, 0.0f0, typemin(Float32))
end
if is_enable_double_DQN
- selected_actions = dropdims(argmax(q_values, dims=1), dims=1)
+ selected_actions = dropdims(argmax(q_values, dims = 1), dims = 1)
q′ = Qₜ(s′)[selected_actions]
else
q′ = dropdims(maximum(q_values; dims = 1), dims = 1)
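
The double-DQN branch being reformatted here decouples action selection (online network) from action evaluation (target network). A self-contained numeric sketch of just that indexing step, with toy sizes and plain arrays:

```julia
# 3 actions × 5 transitions; q_online stands in for Q(s′), q_target for Qₜ(s′).
q_online = rand(Float32, 3, 5)
q_target = rand(Float32, 3, 5)

selected = dropdims(argmax(q_online, dims = 1), dims = 1)   # one CartesianIndex per column
q′ = q_target[selected]                                     # length-5 vector of evaluated values
```
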
diff --git a/src/ReinforcementLearningZoo/src/algorithms/dqns/dqns.jl b/src/ReinforcementLearningZoo/src/algorithms/dqns/dqns.jl
index 4ec47c5ba..8f190ba2c 100644
--- a/src/ReinforcementLearningZoo/src/algorithms/dqns/dqns.jl
+++ b/src/ReinforcementLearningZoo/src/algorithms/dqns/dqns.jl
@@ -5,4 +5,4 @@ include("qr_dqn.jl")
include("rem_dqn.jl")
include("rainbow.jl")
include("iqn.jl")
-include("common.jl")
\ No newline at end of file
+include("common.jl")
diff --git a/src/ReinforcementLearningZoo/src/algorithms/dqns/qr_dqn.jl b/src/ReinforcementLearningZoo/src/algorithms/dqns/qr_dqn.jl
index 2832b1905..0f6c1b519 100644
--- a/src/ReinforcementLearningZoo/src/algorithms/dqns/qr_dqn.jl
+++ b/src/ReinforcementLearningZoo/src/algorithms/dqns/qr_dqn.jl
@@ -1,6 +1,6 @@
export QRDQNLearner, quantile_huber_loss
-function quantile_huber_loss(ŷ, y; κ=1.0f0)
+function quantile_huber_loss(ŷ, y; κ = 1.0f0)
N, B = size(y)
Δ = reshape(y, N, 1, B) .- reshape(ŷ, 1, N, B)
abs_error = abs.(Δ)
@@ -8,12 +8,13 @@ function quantile_huber_loss(ŷ, y; κ=1.0f0)
linear = abs_error .- quadratic
huber_loss = 0.5f0 .* quadratic .* quadratic .+ κ .* linear
- cum_prob = send_to_device(device(y), range(0.5f0 / N; length=N, step=1.0f0 / N))
+ cum_prob = send_to_device(device(y), range(0.5f0 / N; length = N, step = 1.0f0 / N))
loss = Zygote.dropgrad(abs.(cum_prob .- (Δ .< 0))) .* huber_loss
- mean(sum(loss;dims=1))
+ mean(sum(loss; dims = 1))
end
-mutable struct QRDQNLearner{Tq <: AbstractApproximator,Tt <: AbstractApproximator,Tf,R} <: AbstractLearner
+mutable struct QRDQNLearner{Tq<:AbstractApproximator,Tt<:AbstractApproximator,Tf,R} <:
+ AbstractLearner
approximator::Tq
target_approximator::Tt
min_replay_history::Int
@@ -51,25 +52,25 @@ See paper: [Distributional Reinforcement Learning with Quantile Regression](http
function QRDQNLearner(;
approximator,
target_approximator,
- stack_size::Union{Int,Nothing}=nothing,
- γ::Float32=0.99f0,
- batch_size::Int=32,
- update_horizon::Int=1,
- min_replay_history::Int=32,
- update_freq::Int=1,
- n_quantile::Int=1,
- target_update_freq::Int=100,
- traces=SARTS,
- update_step=0,
- loss_func=quantile_huber_loss,
- rng=Random.GLOBAL_RNG
+ stack_size::Union{Int,Nothing} = nothing,
+ γ::Float32 = 0.99f0,
+ batch_size::Int = 32,
+ update_horizon::Int = 1,
+ min_replay_history::Int = 32,
+ update_freq::Int = 1,
+ n_quantile::Int = 1,
+ target_update_freq::Int = 100,
+ traces = SARTS,
+ update_step = 0,
+ loss_func = quantile_huber_loss,
+ rng = Random.GLOBAL_RNG,
)
copyto!(approximator, target_approximator)
sampler = NStepBatchSampler{traces}(;
- γ=γ,
- n=update_horizon,
- stack_size=stack_size,
- batch_size=batch_size,
+ γ = γ,
+ n = update_horizon,
+ stack_size = stack_size,
+ batch_size = batch_size,
)
N = n_quantile
@@ -100,7 +101,7 @@ function (learner::QRDQNLearner)(env)
s = send_to_device(device(learner.approximator), state(env))
s = Flux.unsqueeze(s, ndims(s) + 1)
q = reshape(learner.approximator(s), learner.n_quantile, :)
- vec(mean(q, dims=1)) |> send_to_host
+ vec(mean(q, dims = 1)) |> send_to_host
end
function RLBase.update!(learner::QRDQNLearner, batch::NamedTuple)
@@ -117,10 +118,12 @@ function RLBase.update!(learner::QRDQNLearner, batch::NamedTuple)
a = CartesianIndex.(a, 1:batch_size)
target_quantiles = reshape(Qₜ(s′), N, :, batch_size)
- qₜ = dropdims(mean(target_quantiles; dims=1); dims=1)
- aₜ = dropdims(argmax(qₜ, dims=1); dims=1)
+ qₜ = dropdims(mean(target_quantiles; dims = 1); dims = 1)
+ aₜ = dropdims(argmax(qₜ, dims = 1); dims = 1)
@views target_quantile_aₜ = target_quantiles[:, aₜ]
- y = reshape(r, 1, batch_size) .+ γ .* reshape(1 .- t, 1, batch_size) .* target_quantile_aₜ
+ y =
+ reshape(r, 1, batch_size) .+
+ γ .* reshape(1 .- t, 1, batch_size) .* target_quantile_aₜ
gs = gradient(params(Q)) do
q = reshape(Q(s), N, :, batch_size)
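
`quantile_huber_loss` above pairs every target quantile with every predicted quantile, applies a Huber penalty to each difference, and weights it by the asymmetric quantile factor |τ̂ − 1(Δ<0)|. A CPU-only re-expression with plain arrays (no Zygote or `send_to_device`), using made-up quantile values:

    using Statistics

    function quantile_huber_loss_sketch(ŷ, y; κ = 1.0f0)
        N, B = size(y)
        Δ = reshape(y, N, 1, B) .- reshape(ŷ, 1, N, B)             # all pairwise differences
        abs_error = abs.(Δ)
        quadratic = min.(abs_error, κ)
        linear = abs_error .- quadratic
        huber = 0.5f0 .* quadratic .* quadratic .+ κ .* linear
        cum_prob = range(0.5f0 / N; length = N, step = 1.0f0 / N)  # quantile midpoints τ̂
        loss = abs.(cum_prob .- (Δ .< 0)) .* huber                 # asymmetric quantile weighting
        mean(sum(loss; dims = 1))
    end

    ŷ = Float32[0.1 0.4; 0.6 0.9]   # 2 quantiles × 2 samples, made-up values
    y = Float32[0.2 0.3; 0.5 1.1]
    quantile_huber_loss_sketch(ŷ, y)
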
diff --git a/src/ReinforcementLearningZoo/src/algorithms/dqns/rem_dqn.jl b/src/ReinforcementLearningZoo/src/algorithms/dqns/rem_dqn.jl
index 182ce253f..08f270429 100644
--- a/src/ReinforcementLearningZoo/src/algorithms/dqns/rem_dqn.jl
+++ b/src/ReinforcementLearningZoo/src/algorithms/dqns/rem_dqn.jl
@@ -120,7 +120,7 @@ function RLBase.update!(learner::REMDQNLearner, batch::NamedTuple)
target_q = Qₜ(s′)
target_q = convex_polygon .* reshape(target_q, :, ensemble_num, batch_size)
- target_q = dropdims(sum(target_q, dims=2), dims=2)
+ target_q = dropdims(sum(target_q, dims = 2), dims = 2)
if haskey(batch, :next_legal_actions_mask)
l′ = send_to_device(D, batch[:next_legal_actions_mask])
@@ -133,7 +133,7 @@ function RLBase.update!(learner::REMDQNLearner, batch::NamedTuple)
gs = gradient(params(Q)) do
q = Q(s)
q = convex_polygon .* reshape(q, :, ensemble_num, batch_size)
- q = dropdims(sum(q, dims=2), dims=2)[a]
+ q = dropdims(sum(q, dims = 2), dims = 2)[a]
loss = loss_func(G, q)
ignore() do
@@ -143,5 +143,4 @@ function RLBase.update!(learner::REMDQNLearner, batch::NamedTuple)
end
update!(Q, gs)
-end
-
+end
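
The `convex_polygon .* reshape(...)` lines above are the core of REM-DQN: the Q-heads of a random ensemble are mixed with a random convex combination before the usual TD update. A toy sketch of that mixing step (hypothetical shapes, no Flux):

    using Random

    rng = MersenneTwister(0)
    n_actions, ensemble_num, batch_size = 3, 4, 2

    # Random convex combination over the ensemble heads (non-negative, sums to 1 along dim 2).
    α = rand(rng, Float32, 1, ensemble_num, 1)
    convex_polygon = α ./ sum(α, dims = 2)

    # Pretend network output: one Q head per ensemble member.
    q_heads = rand(rng, Float32, n_actions, ensemble_num, batch_size)

    # Mix the heads into a single Q estimate per (action, sample).
    q = dropdims(sum(convex_polygon .* q_heads, dims = 2), dims = 2)
    size(q)   # (n_actions, batch_size)
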
diff --git a/src/ReinforcementLearningZoo/src/algorithms/exploitability_descent/EDPolicy.jl b/src/ReinforcementLearningZoo/src/algorithms/exploitability_descent/EDPolicy.jl
index c7a270f3d..d01b0e44a 100644
--- a/src/ReinforcementLearningZoo/src/algorithms/exploitability_descent/EDPolicy.jl
+++ b/src/ReinforcementLearningZoo/src/algorithms/exploitability_descent/EDPolicy.jl
@@ -24,7 +24,7 @@ performs the following update for each player:
[Computing Approximate Equilibria in Sequential Adversarial Games by Exploitability Descent](https://arxiv.org/abs/1903.05614)
"""
-mutable struct EDPolicy{P<:NeuralNetworkApproximator, E<:AbstractExplorer}
+mutable struct EDPolicy{P<:NeuralNetworkApproximator,E<:AbstractExplorer}
opponent::Any
learner::P
explorer::E
@@ -40,16 +40,16 @@ function (π::EDPolicy)(env::AbstractEnv)
s = state(env)
s = send_to_device(device(π.learner), Flux.unsqueeze(s, ndims(s) + 1))
logits = π.learner(s) |> vec |> send_to_host
- ActionStyle(env) isa MinimalActionSet ? π.explorer(logits) :
- π.explorer(logits, legal_action_space_mask(env))
+ ActionStyle(env) isa MinimalActionSet ? π.explorer(logits) :
+ π.explorer(logits, legal_action_space_mask(env))
end
function RLBase.prob(π::EDPolicy, env::AbstractEnv)
s = @ignore state(env) |>
- x-> send_to_device(device(π.learner), Flux.unsqueeze(x, ndims(x) + 1))
+ x -> send_to_device(device(π.learner), Flux.unsqueeze(x, ndims(x) + 1))
logits = π.learner(s) |> vec |> send_to_host
- ActionStyle(env) isa MinimalActionSet ? prob(π.explorer, logits) :
- prob(π.explorer, logits, @ignore legal_action_space_mask(env))
+ ActionStyle(env) isa MinimalActionSet ? prob(π.explorer, logits) :
+ prob(π.explorer, logits, @ignore legal_action_space_mask(env))
end
function RLBase.prob(π::EDPolicy, env::AbstractEnv, action)
@@ -66,12 +66,12 @@ function RLBase.prob(π::EDPolicy, env::AbstractEnv, action)
end
@error "action[$action] is not found in action space[$(action_space(env))]"
end
-end
+end
## update policy
function RLBase.update!(
- π::EDPolicy,
- Opponent_BR::BestResponsePolicy,
+ π::EDPolicy,
+ Opponent_BR::BestResponsePolicy,
env::AbstractEnv,
player::Any,
)
@@ -79,10 +79,7 @@ function RLBase.update!(
# construct policy vs best response
policy_vs_br = PolicyVsBestReponse(
- MultiAgentManager(
- NamedPolicy(player, π),
- NamedPolicy(π.opponent, Opponent_BR),
- ),
+ MultiAgentManager(NamedPolicy(player, π), NamedPolicy(π.opponent, Opponent_BR)),
env,
player,
)
@@ -94,7 +91,7 @@ function RLBase.update!(
    # compute expected reward from the start of `e` with policy vs. best response
# baseline = ∑ₐ πᵢ(s, a) * q(s, a)
baseline = @ignore [values_vs_br(policy_vs_br, e) for e in info_states]
-
+
# Vector of shape `(length(info_states), length(action_space))`
# compute expected reward from the start of `e` when playing each action.
q_values = Flux.stack((q_value(π, policy_vs_br, e) for e in info_states), 1)
@@ -106,18 +103,17 @@ function RLBase.update!(
# get each info_state's loss
# ∑ₐ πᵢ(s, a) * (q(s, a) - baseline), where baseline = ∑ₐ πᵢ(s, a) * q(s, a).
- loss_per_state = - sum(policy_values .* advantage, dims=2)
+ loss_per_state = -sum(policy_values .* advantage, dims = 2)
- sum(loss_per_state .* cfr_reach_prob) |>
- x -> send_to_device(device(π.learner), x)
+ sum(loss_per_state .* cfr_reach_prob) |> x -> send_to_device(device(π.learner), x)
end
update!(π.learner, gs)
end
## Supplementary struct for computing results when the player's policy plays against the opponent's best response.
-struct PolicyVsBestReponse{E, P<:MultiAgentManager}
- info_reach_prob::Dict{E, Float64}
- values_vs_br_cache::Dict{E, Float64}
+struct PolicyVsBestReponse{E,P<:MultiAgentManager}
+ info_reach_prob::Dict{E,Float64}
+ values_vs_br_cache::Dict{E,Float64}
player::Any
policy::P
end
@@ -125,13 +121,8 @@ end
function PolicyVsBestReponse(policy, env, player)
E = typeof(env)
- p = PolicyVsBestReponse(
- Dict{E, Float64}(),
- Dict{E, Float64}(),
- player,
- policy,
- )
-
+ p = PolicyVsBestReponse(Dict{E,Float64}(), Dict{E,Float64}(), player, policy)
+
e = copy(env)
RLBase.reset!(e)
get_cfr_prob!(p, e)
@@ -190,7 +181,7 @@ function values_vs_br(p::PolicyVsBestReponse, env::AbstractEnv)
end
function q_value(π::EDPolicy, p::PolicyVsBestReponse, env::AbstractEnv)
- P, A = prob(π, env) , @ignore action_space(env)
+ P, A = prob(π, env), @ignore action_space(env)
v = []
for (a, pₐ) in zip(A, P)
value = pₐ == 0 ? pₐ : values_vs_br(p, @ignore child(env, a))
diff --git a/src/ReinforcementLearningZoo/src/algorithms/exploitability_descent/exploitability_descent.jl b/src/ReinforcementLearningZoo/src/algorithms/exploitability_descent/exploitability_descent.jl
index 53a3b72be..8b8c02567 100644
--- a/src/ReinforcementLearningZoo/src/algorithms/exploitability_descent/exploitability_descent.jl
+++ b/src/ReinforcementLearningZoo/src/algorithms/exploitability_descent/exploitability_descent.jl
@@ -9,7 +9,7 @@ export EDManager
A special MultiAgentManager in which every agent uses the Exploitability Descent (ED) algorithm to play the game.
"""
mutable struct EDManager <: AbstractPolicy
- agents::Dict{Any, EDPolicy}
+ agents::Dict{Any,EDPolicy}
end
## interactions
@@ -22,7 +22,8 @@ function (π::EDManager)(env::AbstractEnv)
end
end
-RLBase.prob(π::EDManager, env::AbstractEnv, args...) = prob(π.agents[current_player(env)], env, args...)
+RLBase.prob(π::EDManager, env::AbstractEnv, args...) =
+ prob(π.agents[current_player(env)], env, args...)
## run function
function Base.run(
diff --git a/src/ReinforcementLearningZoo/src/algorithms/nfsp/abstract_nfsp.jl b/src/ReinforcementLearningZoo/src/algorithms/nfsp/abstract_nfsp.jl
index 3c100545f..f713417af 100644
--- a/src/ReinforcementLearningZoo/src/algorithms/nfsp/abstract_nfsp.jl
+++ b/src/ReinforcementLearningZoo/src/algorithms/nfsp/abstract_nfsp.jl
@@ -35,4 +35,4 @@ function Base.run(
end
hook(POST_EXPERIMENT_STAGE, policy, env)
hook
-end
\ No newline at end of file
+end
diff --git a/src/ReinforcementLearningZoo/src/algorithms/nfsp/nfsp.jl b/src/ReinforcementLearningZoo/src/algorithms/nfsp/nfsp.jl
index 8a975936f..d7379e8a2 100644
--- a/src/ReinforcementLearningZoo/src/algorithms/nfsp/nfsp.jl
+++ b/src/ReinforcementLearningZoo/src/algorithms/nfsp/nfsp.jl
@@ -20,8 +20,8 @@ See the paper https://arxiv.org/abs/1603.01121 for more details.
mutable struct NFSPAgent <: AbstractPolicy
rl_agent::Agent
sl_agent::Agent
- η
- rng
+ η::Any
+ rng::Any
update_freq::Int
update_step::Int
mode::Bool
@@ -96,7 +96,7 @@ function (π::NFSPAgent)(::PostEpisodeStage, env::AbstractEnv, player::Any)
if haskey(rl.trajectory, :legal_actions_mask)
push!(rl.trajectory[:legal_actions_mask], legal_action_space_mask(env, player))
end
-
+
# update the policy
π.update_step += 1
if π.update_step % π.update_freq == 0
@@ -113,7 +113,7 @@ end
function rl_learn!(policy::QBasedPolicy, t::AbstractTrajectory)
learner = policy.learner
length(t[:terminal]) - learner.sampler.n <= learner.min_replay_history && return
-
+
_, batch = sample(learner.rng, t, learner.sampler)
if t isa PrioritizedTrajectory
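
`NFSPAgent` above couples an RL (best-response) agent with an SL (average-policy) agent; `η` is the anticipatory parameter and `mode` records which of the two acts in the current episode. The place where the library draws this sample lies outside these hunks, so the following is only a sketch of the standard NFSP mode selection:

    using Random

    rng = MersenneTwister(0)
    η = 0.1               # anticipatory parameter

    # At the start of an episode: with probability η act (and learn) with the
    # best-response RL agent, otherwise act with the average-policy SL agent.
    mode = rand(rng) < η
    acting_agent = mode ? :rl_agent : :sl_agent
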
diff --git a/src/ReinforcementLearningZoo/src/algorithms/nfsp/nfsp_manager.jl b/src/ReinforcementLearningZoo/src/algorithms/nfsp/nfsp_manager.jl
index 2e152509c..ba3280e3f 100644
--- a/src/ReinforcementLearningZoo/src/algorithms/nfsp/nfsp_manager.jl
+++ b/src/ReinforcementLearningZoo/src/algorithms/nfsp/nfsp_manager.jl
@@ -7,7 +7,7 @@ export NFSPAgentManager
A special MultiAgentManager in which every agent uses the NFSP policy to play the game.
"""
mutable struct NFSPAgentManager <: AbstractPolicy
- agents::Dict{Any, NFSPAgent}
+ agents::Dict{Any,NFSPAgent}
end
## interactions between the policy and env.
@@ -20,7 +20,8 @@ function (π::NFSPAgentManager)(env::AbstractEnv)
end
end
-RLBase.prob(π::NFSPAgentManager, env::AbstractEnv, args...) = prob(π.agents[current_player(env)], env, args...)
+RLBase.prob(π::NFSPAgentManager, env::AbstractEnv, args...) =
+ prob(π.agents[current_player(env)], env, args...)
## update NFSPAgentManager
function RLBase.update!(π::NFSPAgentManager, env::AbstractEnv)
@@ -30,7 +31,10 @@ function RLBase.update!(π::NFSPAgentManager, env::AbstractEnv)
update!(π.agents[current_player(env)], env)
end
-function (π::NFSPAgentManager)(stage::Union{PreEpisodeStage, PostEpisodeStage}, env::AbstractEnv)
+function (π::NFSPAgentManager)(
+ stage::Union{PreEpisodeStage,PostEpisodeStage},
+ env::AbstractEnv,
+)
@sync for (player, agent) in π.agents
@async agent(stage, env, player)
end
diff --git a/src/ReinforcementLearningZoo/src/algorithms/offline_rl/BCQ.jl b/src/ReinforcementLearningZoo/src/algorithms/offline_rl/BCQ.jl
index ed232f846..856a88837 100644
--- a/src/ReinforcementLearningZoo/src/algorithms/offline_rl/BCQ.jl
+++ b/src/ReinforcementLearningZoo/src/algorithms/offline_rl/BCQ.jl
@@ -99,7 +99,7 @@ end
function (l::BCQLearner)(env)
s = send_to_device(device(l.policy), state(env))
s = Flux.unsqueeze(s, ndims(s) + 1)
- s = repeat(s, outer=(1, 1, l.p))
+ s = repeat(s, outer = (1, 1, l.p))
action = l.policy(s, decode(l.vae.model, s))
q_value = l.qnetwork1(vcat(s, action))
idx = argmax(q_value)
@@ -128,11 +128,15 @@ function update_learner!(l::BCQLearner, batch::NamedTuple{SARTS})
γ, τ, λ = l.γ, l.τ, l.λ
- repeat_s′ = repeat(s′, outer=(1, 1, l.p))
+ repeat_s′ = repeat(s′, outer = (1, 1, l.p))
repeat_a′ = l.target_policy(repeat_s′, decode(l.vae.model, repeat_s′))
q′_input = vcat(repeat_s′, repeat_a′)
- q′ = maximum(λ .* min.(l.target_qnetwork1(q′_input), l.target_qnetwork2(q′_input)) + (1 - λ) .* max.(l.target_qnetwork1(q′_input), l.target_qnetwork2(q′_input)), dims=3)
+ q′ = maximum(
+ λ .* min.(l.target_qnetwork1(q′_input), l.target_qnetwork2(q′_input)) +
+ (1 - λ) .* max.(l.target_qnetwork1(q′_input), l.target_qnetwork2(q′_input)),
+ dims = 3,
+ )
y = r .+ γ .* (1 .- t) .* vec(q′)
@@ -143,7 +147,7 @@ function update_learner!(l::BCQLearner, batch::NamedTuple{SARTS})
q_grad_1 = gradient(Flux.params(l.qnetwork1)) do
q1 = l.qnetwork1(q_input) |> vec
loss = mse(q1, y)
- ignore() do
+ ignore() do
l.critic_loss = loss
end
loss
@@ -153,7 +157,7 @@ function update_learner!(l::BCQLearner, batch::NamedTuple{SARTS})
q_grad_2 = gradient(Flux.params(l.qnetwork2)) do
q2 = l.qnetwork2(q_input) |> vec
loss = mse(q2, y)
- ignore() do
+ ignore() do
l.critic_loss += loss
end
loss
@@ -165,7 +169,7 @@ function update_learner!(l::BCQLearner, batch::NamedTuple{SARTS})
sampled_action = decode(l.vae.model, s)
perturbed_action = l.policy(s, sampled_action)
actor_loss = -mean(l.qnetwork1(vcat(s, perturbed_action)))
- ignore() do
+ ignore() do
l.actor_loss = actor_loss
end
actor_loss
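
The re-wrapped `maximum(...)` expression above is the BCQ critic target: for each of the `p` perturbed candidate actions the two target critics are combined as a λ-weighted mix of their min and max (a soft clipped double-Q estimate), and the best candidate per sample is kept. A plain-array sketch with made-up critic outputs:

    using Random

    rng = MersenneTwister(0)
    λ, p, batch_size = 0.75f0, 3, 2

    # Pretend target-critic outputs: shape (1, batch_size, p), one value per candidate action.
    q1 = rand(rng, Float32, 1, batch_size, p)
    q2 = rand(rng, Float32, 1, batch_size, p)

    mixed = λ .* min.(q1, q2) .+ (1 - λ) .* max.(q1, q2)   # soft clipped double-Q
    q′ = vec(maximum(mixed, dims = 3))                      # best candidate per sample
    # y = r .+ γ .* (1 .- t) .* q′ would then form the TD target.
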
diff --git a/src/ReinforcementLearningZoo/src/algorithms/offline_rl/BEAR.jl b/src/ReinforcementLearningZoo/src/algorithms/offline_rl/BEAR.jl
index 71321fdf3..40844f0c3 100644
--- a/src/ReinforcementLearningZoo/src/algorithms/offline_rl/BEAR.jl
+++ b/src/ReinforcementLearningZoo/src/algorithms/offline_rl/BEAR.jl
@@ -33,7 +33,7 @@ mutable struct BEARLearner{
# Logging
actor_loss::Float32
critic_loss::Float32
- mmd_loss
+ mmd_loss::Any
end
"""
@@ -122,8 +122,8 @@ end
function (l::BEARLearner)(env)
s = send_to_device(device(l.policy), state(env))
s = Flux.unsqueeze(s, ndims(s) + 1)
- s = repeat(s, outer=(1, 1, l.p))
- action = l.policy(l.rng, s; is_sampling=true)
+ s = repeat(s, outer = (1, 1, l.p))
+ action = l.policy(l.rng, s; is_sampling = true)
q_value = l.qnetwork1(vcat(s, action))
idx = argmax(q_value)
action[idx]
@@ -134,13 +134,17 @@ function RLBase.update!(l::BEARLearner, batch::NamedTuple{SARTS})
γ, τ, λ = l.γ, l.τ, l.λ
update_vae!(l, s, a)
-
- repeat_s′ = repeat(s′, outer=(1, 1, l.p))
- repeat_action′ = l.target_policy(l.rng, repeat_s′, is_sampling=true)
+
+ repeat_s′ = repeat(s′, outer = (1, 1, l.p))
+ repeat_action′ = l.target_policy(l.rng, repeat_s′, is_sampling = true)
q′_input = vcat(repeat_s′, repeat_action′)
- q′ = maximum(λ .* min.(l.target_qnetwork1(q′_input), l.target_qnetwork2(q′_input)) + (1 - λ) .* max.(l.target_qnetwork1(q′_input), l.target_qnetwork2(q′_input)), dims=3)
+ q′ = maximum(
+ λ .* min.(l.target_qnetwork1(q′_input), l.target_qnetwork2(q′_input)) +
+ (1 - λ) .* max.(l.target_qnetwork1(q′_input), l.target_qnetwork2(q′_input)),
+ dims = 3,
+ )
y = r .+ γ .* (1 .- t) .* vec(q′)
@@ -151,7 +155,7 @@ function RLBase.update!(l::BEARLearner, batch::NamedTuple{SARTS})
q_grad_1 = gradient(Flux.params(l.qnetwork1)) do
q1 = l.qnetwork1(q_input) |> vec
loss = mse(q1, y)
- ignore() do
+ ignore() do
l.critic_loss = loss
end
loss
@@ -161,30 +165,40 @@ function RLBase.update!(l::BEARLearner, batch::NamedTuple{SARTS})
q_grad_2 = gradient(Flux.params(l.qnetwork2)) do
q2 = l.qnetwork2(q_input) |> vec
loss = mse(q2, y)
- ignore() do
+ ignore() do
l.critic_loss += loss
end
loss
end
update!(l.qnetwork2, q_grad_2)
- repeat_s = repeat(s, outer=(1, 1, l.p))
- repeat_a = repeat(a, outer=(1, 1, l.p))
- repeat_q1 = mean(l.target_qnetwork1(vcat(repeat_s, repeat_a)), dims=(1, 3))
- repeat_q2 = mean(l.target_qnetwork2(vcat(repeat_s, repeat_a)), dims=(1, 3))
+ repeat_s = repeat(s, outer = (1, 1, l.p))
+ repeat_a = repeat(a, outer = (1, 1, l.p))
+ repeat_q1 = mean(l.target_qnetwork1(vcat(repeat_s, repeat_a)), dims = (1, 3))
+ repeat_q2 = mean(l.target_qnetwork2(vcat(repeat_s, repeat_a)), dims = (1, 3))
q = vec(min.(repeat_q1, repeat_q2))
alpha = exp(l.log_α.model[1])
# Train Policy
p_grad = gradient(Flux.params(l.policy)) do
- raw_sample_action = decode(l.vae.model, repeat(s, outer=(1, 1, l.sample_num)); is_normalize=false) # action_dim * batch_size * sample_num
- raw_actor_action = l.policy(repeat(s, outer=(1, 1, l.sample_num)); is_sampling=true) # action_dim * batch_size * sample_num
-
- mmd_loss = maximum_mean_discrepancy_loss(raw_sample_action, raw_actor_action, l.kernel_type, l.mmd_σ)
+ raw_sample_action = decode(
+ l.vae.model,
+ repeat(s, outer = (1, 1, l.sample_num));
+ is_normalize = false,
+ ) # action_dim * batch_size * sample_num
+ raw_actor_action =
+ l.policy(repeat(s, outer = (1, 1, l.sample_num)); is_sampling = true) # action_dim * batch_size * sample_num
+
+ mmd_loss = maximum_mean_discrepancy_loss(
+ raw_sample_action,
+ raw_actor_action,
+ l.kernel_type,
+ l.mmd_σ,
+ )
actor_loss = mean(-q .+ alpha .* mmd_loss)
- ignore() do
+ ignore() do
l.actor_loss = actor_loss
l.mmd_loss = mmd_loss
end
@@ -193,11 +207,11 @@ function RLBase.update!(l::BEARLearner, batch::NamedTuple{SARTS})
update!(l.policy, p_grad)
# Update lagrange multiplier
- l_grad = gradient(Flux.params(l.log_α)) do
+ l_grad = gradient(Flux.params(l.log_α)) do
mean(-q .+ alpha .* (l.mmd_loss .- l.ε))
end
update!(l.log_α, l_grad)
-
+
clamp!(l.log_α.model, -5.0f0, l.max_log_α)
# polyak averaging
diff --git a/src/ReinforcementLearningZoo/src/algorithms/offline_rl/CRR.jl b/src/ReinforcementLearningZoo/src/algorithms/offline_rl/CRR.jl
index 102600f21..2a5ae8031 100644
--- a/src/ReinforcementLearningZoo/src/algorithms/offline_rl/CRR.jl
+++ b/src/ReinforcementLearningZoo/src/algorithms/offline_rl/CRR.jl
@@ -22,11 +22,7 @@ See paper: [Critic Regularized Regression](https://arxiv.org/abs/2006.15134).
- `continuous::Bool`: whether the action space is continuous.
- `rng = Random.GLOBAL_RNG`
"""
-mutable struct CRRLearner{
- Aq<:ActorCritic,
- At<:ActorCritic,
- R<:AbstractRNG,
-} <: AbstractLearner
+mutable struct CRRLearner{Aq<:ActorCritic,At<:ActorCritic,R<:AbstractRNG} <: AbstractLearner
approximator::Aq
target_approximator::At
γ::Float32
@@ -61,7 +57,7 @@ function CRRLearner(;
target_update_freq::Int = 100,
continuous::Bool,
rng = Random.GLOBAL_RNG,
-) where {Aq<:ActorCritic, At<:ActorCritic}
+) where {Aq<:ActorCritic,At<:ActorCritic}
copyto!(approximator, target_approximator)
CRRLearner(
approximator,
@@ -95,7 +91,7 @@ function (learner::CRRLearner)(env)
s = Flux.unsqueeze(s, ndims(s) + 1)
s = send_to_device(device(learner), s)
if learner.continuous
- learner.approximator.actor(s; is_sampling=true) |> vec |> send_to_host
+ learner.approximator.actor(s; is_sampling = true) |> vec |> send_to_host
else
learner.approximator.actor(s) |> vec |> send_to_host
end
@@ -125,7 +121,7 @@ function continuous_update!(learner::CRRLearner, batch::NamedTuple)
r = reshape(r, :, batch_size)
t = reshape(t, :, batch_size)
- target_a_t = target_AC.actor(s′; is_sampling=true)
+ target_a_t = target_AC.actor(s′; is_sampling = true)
target_q_input = vcat(s′, target_a_t)
expected_target_q = target_AC.critic(target_q_input)
@@ -133,7 +129,7 @@ function continuous_update!(learner::CRRLearner, batch::NamedTuple)
q_t = Matrix{Float32}(undef, learner.m, batch_size)
for i in 1:learner.m
- a_sample = AC.actor(s; is_sampling=true)
+ a_sample = AC.actor(s; is_sampling = true)
q_t[i, :] = AC.critic(vcat(s, a_sample))
end
@@ -142,14 +138,14 @@ function continuous_update!(learner::CRRLearner, batch::NamedTuple)
# Critic loss
qa_t = AC.critic(vcat(s, a))
critic_loss = Flux.Losses.mse(qa_t, target)
-
+
# Actor loss
log_π = AC.actor.model(s, a)
if advantage_estimator == :max
- advantage = qa_t .- maximum(q_t, dims=1)
+ advantage = qa_t .- maximum(q_t, dims = 1)
elseif advantage_estimator == :mean
- advantage = qa_t .- mean(q_t, dims=1)
+ advantage = qa_t .- mean(q_t, dims = 1)
else
error("Wrong parameter.")
end
@@ -168,7 +164,7 @@ function continuous_update!(learner::CRRLearner, batch::NamedTuple)
learner.actor_loss = actor_loss
learner.critic_loss = critic_loss
end
-
+
actor_loss + critic_loss
end
@@ -193,7 +189,7 @@ function discrete_update!(learner::CRRLearner, batch::NamedTuple)
target_a_t = softmax(target_AC.actor(s′))
target_q_t = target_AC.critic(s′)
- expected_target_q = sum(target_a_t .* target_q_t, dims=1)
+ expected_target_q = sum(target_a_t .* target_q_t, dims = 1)
target = r .+ γ .* (1 .- t) .* expected_target_q
@@ -203,14 +199,14 @@ function discrete_update!(learner::CRRLearner, batch::NamedTuple)
q_t = AC.critic(s)
qa_t = reshape(q_t[a], :, batch_size)
critic_loss = Flux.Losses.mse(qa_t, target)
-
+
# Actor loss
a_t = softmax(AC.actor(s))
if advantage_estimator == :max
- advantage = qa_t .- maximum(q_t, dims=1)
+ advantage = qa_t .- maximum(q_t, dims = 1)
elseif advantage_estimator == :mean
- advantage = qa_t .- mean(q_t, dims=1)
+ advantage = qa_t .- mean(q_t, dims = 1)
else
error("Wrong parameter.")
end
@@ -222,16 +218,16 @@ function discrete_update!(learner::CRRLearner, batch::NamedTuple)
else
error("Wrong parameter.")
end
-
+
actor_loss = mean(-log.(a_t[a]) .* actor_loss_coef)
ignore() do
learner.actor_loss = actor_loss
learner.critic_loss = critic_loss
end
-
+
actor_loss + critic_loss
end
update!(AC, gs)
-end
\ No newline at end of file
+end
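
Both `continuous_update!` and `discrete_update!` above build an advantage as `qa_t .- maximum(q_t, dims = 1)` or `qa_t .- mean(q_t, dims = 1)` and use it to weight a behavior-cloning term. A toy sketch of the two estimators follows; the weighting functions shown after them (binary indicator, clipped exponential) are the standard CRR choices and are only an assumption about what `actor_loss_coef` looks like here:

    using Statistics

    q_t  = Float32[0.5 1.0; 0.7 0.4; 0.9 0.6]   # m = 3 sampled action values × 2 samples
    qa_t = Float32[0.8 0.7]                      # value of the dataset action, 1 × batch

    advantage_max  = qa_t .- maximum(q_t, dims = 1)   # :max estimator
    advantage_mean = qa_t .- mean(q_t, dims = 1)      # :mean estimator

    # Standard CRR weightings of -log π(a|s) (assumed forms):
    β = 1.0f0
    w_binary = Float32.(advantage_mean .> 0)
    w_exp    = clamp.(exp.(advantage_mean ./ β), 0.0f0, 20.0f0)
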
diff --git a/src/ReinforcementLearningZoo/src/algorithms/offline_rl/DiscreteBCQ.jl b/src/ReinforcementLearningZoo/src/algorithms/offline_rl/DiscreteBCQ.jl
index d1ef288cd..ee7f6be35 100644
--- a/src/ReinforcementLearningZoo/src/algorithms/offline_rl/DiscreteBCQ.jl
+++ b/src/ReinforcementLearningZoo/src/algorithms/offline_rl/DiscreteBCQ.jl
@@ -18,11 +18,8 @@ See paper: [Benchmarking Batch Deep Reinforcement Learning Algorithms](https://a
- `update_step::Int = 0`
- `rng = Random.GLOBAL_RNG`
"""
-mutable struct BCQDLearner{
- Aq<:ActorCritic,
- At<:ActorCritic,
- R<:AbstractRNG,
-} <: AbstractLearner
+mutable struct BCQDLearner{Aq<:ActorCritic,At<:ActorCritic,R<:AbstractRNG} <:
+ AbstractLearner
approximator::Aq
target_approximator::At
γ::Float32
@@ -49,7 +46,7 @@ function BCQDLearner(;
update_freq::Int = 10,
update_step::Int = 0,
rng = Random.GLOBAL_RNG,
-) where {Aq<:ActorCritic, At<:ActorCritic}
+) where {Aq<:ActorCritic,At<:ActorCritic}
copyto!(approximator, target_approximator)
BCQDLearner(
approximator,
@@ -79,8 +76,8 @@ function (learner::BCQDLearner)(env)
s = Flux.unsqueeze(s, ndims(s) + 1)
s = send_to_device(device(learner), s)
q = learner.approximator.critic(s)
- prob = softmax(learner.approximator.actor(s), dims=1)
- mask = Float32.((prob ./ maximum(prob, dims=1)) .> learner.threshold)
+ prob = softmax(learner.approximator.actor(s), dims = 1)
+ mask = Float32.((prob ./ maximum(prob, dims = 1)) .> learner.threshold)
new_q = q .* mask .+ (1.0f0 .- mask) .* -1f8
new_q |> vec |> send_to_host
end
@@ -98,9 +95,9 @@ function RLBase.update!(learner::BCQDLearner, batch::NamedTuple)
t = reshape(t, :, batch_size)
prob = softmax(AC.actor(s′))
- mask = Float32.((prob ./ maximum(prob, dims=1)) .> learner.threshold)
+ mask = Float32.((prob ./ maximum(prob, dims = 1)) .> learner.threshold)
q′ = AC.critic(s′)
- a′ = argmax(q′ .* mask .+ (1.0f0 .- mask) .* -1f8, dims=1)
+ a′ = argmax(q′ .* mask .+ (1.0f0 .- mask) .* -1f8, dims = 1)
target_q = target_AC.critic(s′)
target = r .+ γ .* (1 .- t) .* target_q[a′]
@@ -111,27 +108,25 @@ function RLBase.update!(learner::BCQDLearner, batch::NamedTuple)
q_t = AC.critic(s)
qa_t = reshape(q_t[a], :, batch_size)
critic_loss = Flux.Losses.huber_loss(qa_t, target)
-
+
# Actor loss
logit = AC.actor(s)
- log_prob = -log.(softmax(logit, dims=1))
+ log_prob = -log.(softmax(logit, dims = 1))
actor_loss = mean(log_prob[a])
ignore() do
learner.actor_loss = actor_loss
learner.critic_loss = critic_loss
end
-
+
actor_loss + critic_loss + θ * mean(logit .^ 2)
end
update!(AC, gs)
# polyak averaging
- for (dest, src) in zip(
- Flux.params([learner.target_approximator]),
- Flux.params([learner.approximator]),
- )
+ for (dest, src) in
+ zip(Flux.params([learner.target_approximator]), Flux.params([learner.approximator]))
dest .= (1 - τ) .* dest .+ τ .* src
end
end
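
The `prob ./ maximum(prob, dims = 1) .> threshold` lines above implement discrete BCQ's action filter: actions whose relative probability under the imitation head falls below the threshold are pushed to an effectively -∞ Q-value before the greedy step. A self-contained sketch with made-up logits and Q-values:

    threshold = 0.3f0
    logits = Float32[2.0 0.1; 0.5 0.2; -1.0 3.0]            # 3 actions × 2 states
    prob = exp.(logits) ./ sum(exp.(logits), dims = 1)       # softmax over actions

    # Keep only actions that are sufficiently likely under the imitation policy.
    mask = Float32.((prob ./ maximum(prob, dims = 1)) .> threshold)

    q = Float32[1.0 0.2; 0.8 0.9; 0.3 0.5]
    new_q = q .* mask .+ (1.0f0 .- mask) .* -1.0f8           # masked actions get -1e8
    greedy = argmax(new_q, dims = 1)
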
diff --git a/src/ReinforcementLearningZoo/src/algorithms/offline_rl/FisherBRC.jl b/src/ReinforcementLearningZoo/src/algorithms/offline_rl/FisherBRC.jl
index 34c50913e..89602e3e1 100644
--- a/src/ReinforcementLearningZoo/src/algorithms/offline_rl/FisherBRC.jl
+++ b/src/ReinforcementLearningZoo/src/algorithms/offline_rl/FisherBRC.jl
@@ -91,7 +91,8 @@ function FisherBRCLearner(;
)
copyto!(qnetwork1, target_qnetwork1) # force sync
copyto!(qnetwork2, target_qnetwork2) # force sync
- entropy_behavior_policy = EntropyBC(behavior_policy, 0.0f0, behavior_lr_alpha, Float32(-action_dims), 0.0f0)
+ entropy_behavior_policy =
+ EntropyBC(behavior_policy, 0.0f0, behavior_lr_alpha, Float32(-action_dims), 0.0f0)
FisherBRCLearner(
policy,
entropy_behavior_policy,
@@ -111,8 +112,8 @@ function FisherBRCLearner(;
lr_alpha,
Float32(-action_dims),
rng,
- 0f0,
- 0f0,
+ 0.0f0,
+ 0.0f0,
)
end
@@ -120,7 +121,7 @@ function (l::FisherBRCLearner)(env)
D = device(l.policy)
s = send_to_device(D, state(env))
s = Flux.unsqueeze(s, ndims(s) + 1)
- action = dropdims(l.policy(l.rng, s; is_sampling=true), dims=2)
+ action = dropdims(l.policy(l.rng, s; is_sampling = true), dims = 2)
end
function RLBase.update!(l::FisherBRCLearner, batch::NamedTuple{SARTS})
@@ -137,7 +138,7 @@ function update_behavior_policy!(l::EntropyBC, batch::NamedTuple{SARTS})
ps = Flux.params(l.policy)
gs = gradient(ps) do
log_π = l.policy.model(s, a)
- _, entropy = l.policy.model(s; is_sampling=true, is_return_log_prob=true)
+ _, entropy = l.policy.model(s; is_sampling = true, is_return_log_prob = true)
loss = mean(l.α .* entropy .- log_π)
# Update entropy
ignore() do
@@ -154,7 +155,7 @@ function update_learner!(l::FisherBRCLearner, batch::NamedTuple{SARTS})
r .+= l.reward_bonus
γ, τ, α = l.γ, l.τ, l.α
- a′ = l.policy(l.rng, s′; is_sampling=true)
+ a′ = l.policy(l.rng, s′; is_sampling = true)
q′_input = vcat(s′, a′)
target_q′ = min.(l.target_qnetwork1(q′_input), l.target_qnetwork2(q′_input))
@@ -164,16 +165,16 @@ function update_learner!(l::FisherBRCLearner, batch::NamedTuple{SARTS})
a = reshape(a, :, l.batch_size)
q_input = vcat(s, a)
log_μ = l.behavior_policy.policy.model(s, a) |> vec
- a_policy = l.policy(l.rng, s; is_sampling=true)
+ a_policy = l.policy(l.rng, s; is_sampling = true)
q_grad_1 = gradient(Flux.params(l.qnetwork1)) do
q1 = l.qnetwork1(q_input) |> vec
- q1_grad_norm = gradient(Flux.params([a_policy])) do
+ q1_grad_norm = gradient(Flux.params([a_policy])) do
q1_reg = mean(l.qnetwork1(vcat(s, a_policy)))
end
reg = mean(q1_grad_norm[a_policy] .^ 2)
loss = mse(q1 .+ log_μ, y) + l.f_reg * reg
- ignore() do
+ ignore() do
l.qnetwork_loss = loss
end
loss
@@ -182,12 +183,12 @@ function update_learner!(l::FisherBRCLearner, batch::NamedTuple{SARTS})
q_grad_2 = gradient(Flux.params(l.qnetwork2)) do
q2 = l.qnetwork2(q_input) |> vec
- q2_grad_norm = gradient(Flux.params([a_policy])) do
+ q2_grad_norm = gradient(Flux.params([a_policy])) do
q2_reg = mean(l.qnetwork2(vcat(s, a_policy)))
end
reg = mean(q2_grad_norm[a_policy] .^ 2)
loss = mse(q2 .+ log_μ, y) + l.f_reg * reg
- ignore() do
+ ignore() do
l.qnetwork_loss += loss
end
loss
@@ -196,7 +197,7 @@ function update_learner!(l::FisherBRCLearner, batch::NamedTuple{SARTS})
# Train Policy
p_grad = gradient(Flux.params(l.policy)) do
- a, log_π = l.policy(l.rng, s; is_sampling=true, is_return_log_prob=true)
+ a, log_π = l.policy(l.rng, s; is_sampling = true, is_return_log_prob = true)
q_input = vcat(s, a)
q = min.(l.qnetwork1(q_input), l.qnetwork2(q_input)) .+ log_μ
policy_loss = mean(α .* log_π .- q)
diff --git a/src/ReinforcementLearningZoo/src/algorithms/offline_rl/PLAS.jl b/src/ReinforcementLearningZoo/src/algorithms/offline_rl/PLAS.jl
index 8ced04148..1bcfd6477 100644
--- a/src/ReinforcementLearningZoo/src/algorithms/offline_rl/PLAS.jl
+++ b/src/ReinforcementLearningZoo/src/algorithms/offline_rl/PLAS.jl
@@ -96,7 +96,7 @@ function (l::PLASLearner)(env)
s = send_to_device(device(l.policy), state(env))
s = Flux.unsqueeze(s, ndims(s) + 1)
latent_action = tanh.(l.policy(s))
- action = dropdims(decode(l.vae.model, s, latent_action), dims=2)
+ action = dropdims(decode(l.vae.model, s, latent_action), dims = 2)
end
function RLBase.update!(l::PLASLearner, batch::NamedTuple{SARTS})
@@ -125,7 +125,9 @@ function update_learner!(l::PLASLearner, batch::NamedTuple{SARTS})
latent_action′ = tanh.(l.target_policy(s′))
action′ = decode(l.vae.model, s′, latent_action′)
q′_input = vcat(s′, action′)
- q′ = λ .* min.(l.target_qnetwork1(q′_input), l.target_qnetwork2(q′_input)) + (1 - λ) .* max.(l.target_qnetwork1(q′_input), l.target_qnetwork2(q′_input))
+ q′ =
+ λ .* min.(l.target_qnetwork1(q′_input), l.target_qnetwork2(q′_input)) +
+ (1 - λ) .* max.(l.target_qnetwork1(q′_input), l.target_qnetwork2(q′_input))
y = r .+ γ .* (1 .- t) .* vec(q′)
@@ -136,7 +138,7 @@ function update_learner!(l::PLASLearner, batch::NamedTuple{SARTS})
q_grad_1 = gradient(Flux.params(l.qnetwork1)) do
q1 = l.qnetwork1(q_input) |> vec
loss = mse(q1, y)
- ignore() do
+ ignore() do
l.critic_loss = loss
end
loss
@@ -146,7 +148,7 @@ function update_learner!(l::PLASLearner, batch::NamedTuple{SARTS})
q_grad_2 = gradient(Flux.params(l.qnetwork2)) do
q2 = l.qnetwork2(q_input) |> vec
loss = mse(q2, y)
- ignore() do
+ ignore() do
l.critic_loss += loss
end
loss
@@ -158,7 +160,7 @@ function update_learner!(l::PLASLearner, batch::NamedTuple{SARTS})
latent_action = tanh.(l.policy(s))
action = decode(l.vae.model, s, latent_action)
actor_loss = -mean(l.qnetwork1(vcat(s, action)))
- ignore() do
+ ignore() do
l.actor_loss = actor_loss
end
actor_loss
diff --git a/src/ReinforcementLearningZoo/src/algorithms/offline_rl/behavior_cloning.jl b/src/ReinforcementLearningZoo/src/algorithms/offline_rl/behavior_cloning.jl
index c5b7bdd09..48c021035 100644
--- a/src/ReinforcementLearningZoo/src/algorithms/offline_rl/behavior_cloning.jl
+++ b/src/ReinforcementLearningZoo/src/algorithms/offline_rl/behavior_cloning.jl
@@ -19,19 +19,14 @@ end
- `rng = Random.GLOBAL_RNG`
"""
function BehaviorCloningPolicy(;
- approximator::A,
- explorer::AbstractExplorer = GreedyExplorer(),
- batch_size::Int = 32,
- min_reservoir_history::Int = 100,
- rng = Random.GLOBAL_RNG
+ approximator::A,
+ explorer::AbstractExplorer = GreedyExplorer(),
+ batch_size::Int = 32,
+ min_reservoir_history::Int = 100,
+ rng = Random.GLOBAL_RNG,
) where {A}
sampler = BatchSampler{(:state, :action)}(batch_size; rng = rng)
- BehaviorCloningPolicy(
- approximator,
- explorer,
- sampler,
- min_reservoir_history,
- )
+ BehaviorCloningPolicy(approximator, explorer, sampler, min_reservoir_history)
end
function (p::BehaviorCloningPolicy)(env::AbstractEnv)
@@ -39,7 +34,8 @@ function (p::BehaviorCloningPolicy)(env::AbstractEnv)
s_batch = Flux.unsqueeze(s, ndims(s) + 1)
s_batch = send_to_device(device(p.approximator), s_batch)
logits = p.approximator(s_batch) |> vec |> send_to_host # drop dimension
- typeof(ActionStyle(env)) == MinimalActionSet ? p.explorer(logits) : p.explorer(logits, legal_action_space_mask(env))
+ typeof(ActionStyle(env)) == MinimalActionSet ? p.explorer(logits) :
+ p.explorer(logits, legal_action_space_mask(env))
end
function RLBase.update!(p::BehaviorCloningPolicy, batch::NamedTuple{(:state, :action)})
@@ -65,7 +61,8 @@ function RLBase.prob(p::BehaviorCloningPolicy, env::AbstractEnv)
m = p.approximator
s_batch = send_to_device(device(m), Flux.unsqueeze(s, ndims(s) + 1))
values = m(s_batch) |> vec |> send_to_host
- typeof(ActionStyle(env)) == MinimalActionSet ? prob(p.explorer, values) : prob(p.explorer, values, legal_action_space_mask(env))
+ typeof(ActionStyle(env)) == MinimalActionSet ? prob(p.explorer, values) :
+ prob(p.explorer, values, legal_action_space_mask(env))
end
function RLBase.prob(p::BehaviorCloningPolicy, env::AbstractEnv, action)
diff --git a/src/ReinforcementLearningZoo/src/algorithms/offline_rl/common.jl b/src/ReinforcementLearningZoo/src/algorithms/offline_rl/common.jl
index 3c8401dd6..5bf0e0b7f 100644
--- a/src/ReinforcementLearningZoo/src/algorithms/offline_rl/common.jl
+++ b/src/ReinforcementLearningZoo/src/algorithms/offline_rl/common.jl
@@ -3,11 +3,11 @@ export OfflinePolicy, AtariRLTransition
export calculate_CQL_loss, maximum_mean_discrepancy_loss
struct AtariRLTransition
- state
- action
- reward
- terminal
- next_state
+ state::Any
+ action::Any
+ reward::Any
+ terminal::Any
+ next_state::Any
end
Base.@kwdef struct OfflinePolicy{L,T} <: AbstractPolicy
@@ -26,7 +26,8 @@ function (π::OfflinePolicy)(env, ::MinimalActionSet, ::Base.OneTo)
findmax(π.learner(env))[2]
end
end
-(π::OfflinePolicy)(env, ::FullActionSet, ::Base.OneTo) = findmax(π.learner(env), legal_action_space_mask(env))[2]
+(π::OfflinePolicy)(env, ::FullActionSet, ::Base.OneTo) =
+ findmax(π.learner(env), legal_action_space_mask(env))[2]
function (π::OfflinePolicy)(env, ::MinimalActionSet, A)
if π.continuous
@@ -35,7 +36,8 @@ function (π::OfflinePolicy)(env, ::MinimalActionSet, A)
A[findmax(π.learner(env))[2]]
end
end
-(π::OfflinePolicy)(env, ::FullActionSet, A) = A[findmax(π.learner(env), legal_action_space_mask(env))[2]]
+(π::OfflinePolicy)(env, ::FullActionSet, A) =
+ A[findmax(π.learner(env), legal_action_space_mask(env))[2]]
function RLBase.update!(
p::OfflinePolicy,
@@ -62,7 +64,8 @@ function RLBase.update!(
l = p.learner
l.update_step += 1
- if in(:target_update_freq, fieldnames(typeof(l))) && l.update_step % l.target_update_freq == 0
+ if in(:target_update_freq, fieldnames(typeof(l))) &&
+ l.update_step % l.target_update_freq == 0
copyto!(l.target_approximator, l.approximator)
end
@@ -99,19 +102,30 @@ end
calculate_CQL_loss(q_value, action; method)
See paper: [Conservative Q-Learning for Offline Reinforcement Learning](https://arxiv.org/abs/2006.04779)
"""
-function calculate_CQL_loss(q_value::Matrix{T}, action::Vector{R}; method = "CQL(H)") where {T, R}
+function calculate_CQL_loss(
+ q_value::Matrix{T},
+ action::Vector{R};
+ method = "CQL(H)",
+) where {T,R}
if method == "CQL(H)"
- cql_loss = mean(log.(sum(exp.(q_value), dims=1)) .- q_value[action])
+ cql_loss = mean(log.(sum(exp.(q_value), dims = 1)) .- q_value[action])
else
        @error "Wrong method parameter"
end
return cql_loss
end
-function maximum_mean_discrepancy_loss(raw_sample_action, raw_actor_action, type::Symbol, mmd_σ::Float32=10.0f0)
+function maximum_mean_discrepancy_loss(
+ raw_sample_action,
+ raw_actor_action,
+ type::Symbol,
+ mmd_σ::Float32 = 10.0f0,
+)
A, B, N = size(raw_sample_action)
- diff_xx = reshape(raw_sample_action, A, B, N, 1) .- reshape(raw_sample_action, A, B, 1, N)
- diff_xy = reshape(raw_sample_action, A, B, N, 1) .- reshape(raw_actor_action, A, B, 1, N)
+ diff_xx =
+ reshape(raw_sample_action, A, B, N, 1) .- reshape(raw_sample_action, A, B, 1, N)
+ diff_xy =
+ reshape(raw_sample_action, A, B, N, 1) .- reshape(raw_actor_action, A, B, 1, N)
diff_yy = reshape(raw_actor_action, A, B, N, 1) .- reshape(raw_actor_action, A, B, 1, N)
diff_xx = calculate_sample_distance(diff_xx, type, mmd_σ)
diff_xy = calculate_sample_distance(diff_xy, type, mmd_σ)
@@ -127,5 +141,5 @@ function calculate_sample_distance(diff, type::Symbol, mmd_σ::Float32)
else
error("Wrong parameter.")
end
- return vec(mean(exp.(-sum(diff, dims=1) ./ (2.0f0 * mmd_σ)), dims=(3, 4)))
+ return vec(mean(exp.(-sum(diff, dims = 1) ./ (2.0f0 * mmd_σ)), dims = (3, 4)))
end
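
`calculate_CQL_loss` above is the CQL(H) regularizer: a log-sum-exp over all actions minus the Q-value of the dataset action, averaged over the batch. A plain-array re-expression (the per-sample vectorization is spelled out explicitly here):

    using Statistics

    q_value = Float32[1.0 0.3; 0.2 0.9; 0.5 0.1]     # n_actions × batch_size, made-up
    action  = CartesianIndex.([1, 2], 1:2)            # dataset action per sample

    lse = vec(log.(sum(exp.(q_value), dims = 1)))     # log ∑ₐ exp Q(s, a), per sample
    cql_loss = mean(lse .- q_value[action])           # pushes down all Q, keeps up Q(s, a_data)
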
diff --git a/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/ddpg.jl b/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/ddpg.jl
index d49f971b0..0daac6a2e 100644
--- a/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/ddpg.jl
+++ b/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/ddpg.jl
@@ -118,7 +118,12 @@ function (p::DDPGPolicy)(env, player::Any = nothing)
s = DynamicStyle(env) == SEQUENTIAL ? state(env) : state(env, player)
s = Flux.unsqueeze(s, ndims(s) + 1)
actions = p.behavior_actor(send_to_device(D, s)) |> vec |> send_to_host
- c = clamp.(actions .+ randn(p.rng, p.na) .* repeat([p.act_noise], p.na), -p.act_limit, p.act_limit)
+ c =
+ clamp.(
+ actions .+ randn(p.rng, p.na) .* repeat([p.act_noise], p.na),
+ -p.act_limit,
+ p.act_limit,
+ )
p.na == 1 && return c[1]
c
end
@@ -154,7 +159,7 @@ function RLBase.update!(p::DDPGPolicy, batch::NamedTuple{SARTS})
a′ = Aₜ(s′)
qₜ = Cₜ(vcat(s′, a′)) |> vec
y = r .+ γ .* (1 .- t) .* qₜ
- a = Flux.unsqueeze(a, ndims(a)+1)
+ a = Flux.unsqueeze(a, ndims(a) + 1)
gs1 = gradient(Flux.params(C)) do
q = C(vcat(s, a)) |> vec
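
The re-wrapped `clamp.(...)` call above is DDPG's exploration step: Gaussian noise is added to the deterministic actor output and the result is clipped back into the valid action range. A minimal sketch (toy action vector, no environment):

    using Random

    rng = MersenneTwister(0)
    na, act_noise, act_limit = 2, 0.1f0, 1.0f0
    actions = Float32[0.95, -0.2]        # deterministic actor output, made-up

    noisy = actions .+ randn(rng, na) .* act_noise
    c = clamp.(noisy, -act_limit, act_limit)
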
diff --git a/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/maddpg.jl b/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/maddpg.jl
index 8f2d67046..6357e74c0 100644
--- a/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/maddpg.jl
+++ b/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/maddpg.jl
@@ -16,8 +16,8 @@ See the paper https://arxiv.org/abs/1706.02275 for more details.
- `rng::AbstractRNG`.
"""
mutable struct MADDPGManager <: AbstractPolicy
- agents::Dict{<:Any, <:Agent}
- traces
+ agents::Dict{<:Any,<:Agent}
+ traces::Any
batch_size::Int
update_freq::Int
update_step::Int
@@ -29,12 +29,10 @@ function (π::MADDPGManager)(env::AbstractEnv)
while current_player(env) == chance_player(env)
env |> legal_action_space |> rand |> env
end
- Dict(
- player => agent.policy(env)
- for (player, agent) in π.agents)
+ Dict(player => agent.policy(env) for (player, agent) in π.agents)
end
-function (π::MADDPGManager)(stage::Union{PreEpisodeStage, PostActStage}, env::AbstractEnv)
+function (π::MADDPGManager)(stage::Union{PreEpisodeStage,PostActStage}, env::AbstractEnv)
# only need to update trajectory.
for (_, agent) in π.agents
update!(agent.trajectory, agent.policy, env, stage)
@@ -46,7 +44,7 @@ function (π::MADDPGManager)(stage::PreActStage, env::AbstractEnv, actions)
for (player, agent) in π.agents
update!(agent.trajectory, agent.policy, env, stage, actions[player])
end
-
+
# update policy
update!(π, env)
end
@@ -70,14 +68,18 @@ function RLBase.update!(π::MADDPGManager, env::AbstractEnv)
length(agent.trajectory) > agent.policy.policy.update_after || return
length(agent.trajectory) > π.batch_size || return
end
-
+
# get training data
temp_player = collect(keys(π.agents))[1]
t = π.agents[temp_player].trajectory
inds = rand(π.rng, 1:length(t), π.batch_size)
- batches = Dict((player, RLCore.fetch!(BatchSampler{π.traces}(π.batch_size), agent.trajectory, inds))
- for (player, agent) in π.agents)
-
+ batches = Dict(
+ (
+ player,
+ RLCore.fetch!(BatchSampler{π.traces}(π.batch_size), agent.trajectory, inds),
+ ) for (player, agent) in π.agents
+ )
+
# get s, a, s′ for critic
s = vcat((batches[player][:state] for (player, _) in π.agents)...)
a = vcat((batches[player][:action] for (player, _) in π.agents)...)
@@ -100,17 +102,17 @@ function RLBase.update!(π::MADDPGManager, env::AbstractEnv)
t = batches[player][:terminal]
# for training behavior_actor.
mu_actions = vcat(
- ((
- batches[p][:next_state] |>
- a.policy.policy.behavior_actor
- ) for (p, a) in π.agents)...
+ (
+ (batches[p][:next_state] |> a.policy.policy.behavior_actor) for
+ (p, a) in π.agents
+ )...,
)
# for training behavior_critic.
new_actions = vcat(
- ((
- batches[p][:next_state] |>
- a.policy.policy.target_actor
- ) for (p, a) in π.agents)...
+ (
+ (batches[p][:next_state] |> a.policy.policy.target_actor) for
+ (p, a) in π.agents
+ )...,
)
if π.traces == SLARTSL
@@ -120,18 +122,18 @@ function RLBase.update!(π::MADDPGManager, env::AbstractEnv)
@assert env isa ActionTransformedEnv
mask = batches[player][:next_legal_actions_mask]
- mu_l′ = Flux.batch(
- (begin
+ mu_l′ = Flux.batch((
+ begin
actions = env.action_mapping(mu_actions[:, i])
mask[actions[player]]
- end for i = 1:π.batch_size)
- )
- new_l′ = Flux.batch(
- (begin
+ end for i in 1:π.batch_size
+ ))
+ new_l′ = Flux.batch((
+ begin
actions = env.action_mapping(new_actions[:, i])
mask[actions[player]]
- end for i = 1:π.batch_size)
- )
+ end for i in 1:π.batch_size
+ ))
end
qₜ = Cₜ(vcat(s′, new_actions)) |> vec
@@ -157,7 +159,7 @@ function RLBase.update!(π::MADDPGManager, env::AbstractEnv)
v .+= ifelse.(mu_l′, 0.0f0, typemin(Float32))
end
reg = mean(A(batches[player][:state]) .^ 2)
- loss = -mean(v) + reg * 1e-3
+ loss = -mean(v) + reg * 1e-3
ignore() do
p.actor_loss = loss
end
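
The `vcat((batches[player][:state] ...)...)` lines above build MADDPG's centralized critic input by concatenating every agent's observation and action. A toy sketch with two hypothetical players and single vectors in place of sampled batches:

    states  = Dict(:a => Float32[0.1, 0.2], :b => Float32[0.3, 0.4])
    actions = Dict(:a => Float32[1.0],      :b => Float32[-1.0])

    players = sort(collect(keys(states)))              # fix a player ordering
    s = vcat((states[p] for p in players)...)
    a = vcat((actions[p] for p in players)...)
    critic_input = vcat(s, a)                          # what the centralized critic consumes
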
diff --git a/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/ppo.jl b/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/ppo.jl
index fa06a9132..bf9e8f27e 100644
--- a/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/ppo.jl
+++ b/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/ppo.jl
@@ -148,7 +148,9 @@ function RLBase.prob(
if p.update_step < p.n_random_start
@error "todo"
else
- μ, logσ = p.approximator.actor(send_to_device(device(p.approximator), state)) |> send_to_host
+ μ, logσ =
+ p.approximator.actor(send_to_device(device(p.approximator), state)) |>
+ send_to_host
StructArray{Normal}((μ, exp.(logσ)))
end
end
@@ -256,11 +258,11 @@ function _update!(p::PPOPolicy, t::AbstractTrajectory)
end
s = send_to_device(D, select_last_dim(states_flatten, inds)) # !!! performance critical
a = send_to_device(D, select_last_dim(actions_flatten, inds))
-
+
if eltype(a) === Int
a = CartesianIndex.(a, 1:length(a))
end
-
+
r = send_to_device(D, vec(returns)[inds])
log_p = send_to_device(D, vec(action_log_probs)[inds])
adv = send_to_device(D, vec(advantages)[inds])
@@ -275,7 +277,8 @@ function _update!(p::PPOPolicy, t::AbstractTrajectory)
else
log_p′ₐ = normlogpdf(μ, exp.(logσ), a)
end
- entropy_loss = mean(size(logσ, 1) * (log(2.0f0π) + 1) .+ sum(logσ; dims = 1)) / 2
+ entropy_loss =
+ mean(size(logσ, 1) * (log(2.0f0π) + 1) .+ sum(logσ; dims = 1)) / 2
else
# actor is assumed to return discrete logits
logit′ = AC.actor(s)
diff --git a/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/sac.jl b/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/sac.jl
index b41c0397a..fef3c53d1 100644
--- a/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/sac.jl
+++ b/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/sac.jl
@@ -104,8 +104,8 @@ function SACPolicy(;
Float32(-action_dims),
update_step,
rng,
- 0f0,
- 0f0,
+ 0.0f0,
+ 0.0f0,
)
end
@@ -120,7 +120,7 @@ function (p::SACPolicy)(env)
s = state(env)
s = Flux.unsqueeze(s, ndims(s) + 1)
# trainmode:
- action = dropdims(p.policy(p.rng, s; is_sampling=true), dims=2) # Single action vec, drop second dim
+ action = dropdims(p.policy(p.rng, s; is_sampling = true), dims = 2) # Single action vec, drop second dim
# testmode:
    # if testing don't sample an action, but act deterministically by
@@ -146,7 +146,7 @@ function RLBase.update!(p::SACPolicy, batch::NamedTuple{SARTS})
γ, τ, α = p.γ, p.τ, p.α
- a′, log_π = p.policy(p.rng, s′; is_sampling=true, is_return_log_prob=true)
+ a′, log_π = p.policy(p.rng, s′; is_sampling = true, is_return_log_prob = true)
q′_input = vcat(s′, a′)
q′ = min.(p.target_qnetwork1(q′_input), p.target_qnetwork2(q′_input))
@@ -168,12 +168,12 @@ function RLBase.update!(p::SACPolicy, batch::NamedTuple{SARTS})
# Train Policy
p_grad = gradient(Flux.params(p.policy)) do
- a, log_π = p.policy(p.rng, s; is_sampling=true, is_return_log_prob=true)
+ a, log_π = p.policy(p.rng, s; is_sampling = true, is_return_log_prob = true)
q_input = vcat(s, a)
q = min.(p.qnetwork1(q_input), p.qnetwork2(q_input))
reward = mean(q)
entropy = mean(log_π)
- ignore() do
+ ignore() do
p.reward_term = reward
p.entropy_term = entropy
end
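
The policy-gradient block above samples an action, takes the clipped double-Q minimum, and records `reward = mean(q)` and `entropy = mean(log_π)`; the objective actually returned is outside these hunks, so this sketch assumes the standard SAC policy loss α·E[log π] − E[q]:

    using Random, Statistics

    rng = MersenneTwister(0)
    α, batch_size = 0.2f0, 4

    log_π = randn(rng, Float32, 1, batch_size)   # log-probabilities of sampled actions
    q1    = randn(rng, Float32, 1, batch_size)   # critic 1 estimates at those actions
    q2    = randn(rng, Float32, 1, batch_size)   # critic 2 estimates

    q = min.(q1, q2)                             # clipped double-Q
    reward_term  = mean(q)
    entropy_term = mean(log_π)
    policy_loss  = α * entropy_term - reward_term   # assumed standard SAC objective
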
diff --git a/src/ReinforcementLearningZoo/src/algorithms/tabular/tabular_policy.jl b/src/ReinforcementLearningZoo/src/algorithms/tabular/tabular_policy.jl
index a91c22b62..d0bbedfa3 100644
--- a/src/ReinforcementLearningZoo/src/algorithms/tabular/tabular_policy.jl
+++ b/src/ReinforcementLearningZoo/src/algorithms/tabular/tabular_policy.jl
@@ -11,7 +11,7 @@ A `Dict` is used internally to store the mapping from state to action.
"""
Base.@kwdef struct TabularPolicy{S,A} <: AbstractPolicy
table::Dict{S,A} = Dict{Int,Int}()
- n_action::Union{Int, Nothing} = nothing
+ n_action::Union{Int,Nothing} = nothing
end
(p::TabularPolicy)(env::AbstractEnv) = p(state(env))
diff --git a/test/runtests.jl b/test/runtests.jl
index 04be42cbd..d6dd95c9e 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -1,5 +1,4 @@
using Test
using ReinforcementLearning
-@testset "ReinforcementLearning" begin
-end
+@testset "ReinforcementLearning" begin end