From 301f8a866dff88bf8d69d2686e740c3778cbf8e3 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Tue, 10 Aug 2021 00:33:06 +0000 Subject: [PATCH] Format .jl files --- .../CFR/JuliaRL_DeepCFR_OpenSpiel.jl | 10 +- .../CFR/JuliaRL_TabularCFR_OpenSpiel.jl | 10 +- .../experiments/DQN/Dopamine_DQN_Atari.jl | 74 ++-- .../experiments/DQN/Dopamine_IQN_Atari.jl | 73 ++-- .../experiments/DQN/Dopamine_Rainbow_Atari.jl | 71 ++-- .../DQN/JuliaRL_BasicDQN_CartPole.jl | 2 +- .../DQN/JuliaRL_BasicDQN_MountainCar.jl | 2 +- .../JuliaRL_BasicDQN_SingleRoomUndirected.jl | 48 +-- .../experiments/DQN/JuliaRL_DQN_CartPole.jl | 12 +- .../DQN/JuliaRL_DQN_MountainCar.jl | 2 +- .../experiments/DQN/JuliaRL_IQN_CartPole.jl | 2 +- .../experiments/DQN/JuliaRL_QRDQN_Cartpole.jl | 52 +-- .../DQN/JuliaRL_REMDQN_CartPole.jl | 4 +- .../DQN/JuliaRL_Rainbow_CartPole.jl | 2 +- .../NFSP/JuliaRL_NFSP_KuhnPoker.jl | 54 ++- .../Offline/JuliaRL_BC_CartPole.jl | 10 +- .../JuliaRL_A2CGAE_CartPole.jl | 4 +- .../Policy Gradient/JuliaRL_A2C_CartPole.jl | 4 +- .../Policy Gradient/JuliaRL_DDPG_Pendulum.jl | 4 +- .../Policy Gradient/JuliaRL_MAC_CartPole.jl | 4 +- .../Policy Gradient/JuliaRL_PPO_CartPole.jl | 4 +- .../Policy Gradient/JuliaRL_PPO_Pendulum.jl | 7 +- .../Policy Gradient/JuliaRL_SAC_Pendulum.jl | 11 +- .../Policy Gradient/JuliaRL_TD3_Pendulum.jl | 2 +- .../Policy Gradient/JuliaRL_VPG_CartPole.jl | 2 +- .../Policy Gradient/rlpyt_A2C_Atari.jl | 10 +- .../Policy Gradient/rlpyt_PPO_Atari.jl | 12 +- .../Search/JuliaRL_Minimax_OpenSpiel.jl | 8 +- docs/homepage/utils.jl | 31 +- docs/make.jl | 8 +- .../src/actor_model.jl | 17 +- .../src/core.jl | 13 +- .../src/extensions.jl | 2 +- .../test/actor.jl | 72 ++-- .../test/core.jl | 311 +++++++------- .../test/runtests.jl | 4 +- .../src/CommonRLInterface.jl | 5 +- .../src/interface.jl | 9 +- .../test/CommonRLInterface.jl | 50 +-- .../test/runtests.jl | 4 +- .../src/core/hooks.jl | 29 +- .../src/extensions/ArrayInterface.jl | 7 +- .../trajectories/trajectory_extension.jl | 2 +- .../neural_network_approximator.jl | 32 +- .../test/components/trajectories.jl | 4 +- .../test/core/core.jl | 12 +- .../test/core/stop_conditions_test.jl | 5 +- .../src/ReinforcementLearningDatasets.jl | 2 +- .../src/atari/atari_dataset.jl | 85 ++-- .../src/atari/register.jl | 87 +++- .../src/common.jl | 2 +- .../src/d4rl/d4rl/register.jl | 380 +++++++++--------- .../src/d4rl/d4rl_dataset.jl | 73 ++-- .../src/d4rl/d4rl_pybullet/register.jl | 28 +- src/ReinforcementLearningDatasets/src/init.jl | 2 +- .../test/atari_dataset.jl | 10 +- .../test/d4rl_pybullet.jl | 16 +- .../test/dataset.jl | 6 +- .../src/ReinforcementLearningEnvironments.jl | 6 +- .../src/environments/3rd_party/AcrobotEnv.jl | 42 +- .../src/environments/3rd_party/open_spiel.jl | 44 +- .../src/environments/3rd_party/structs.jl | 4 +- .../environments/examples/StockTradingEnv.jl | 47 +-- .../wrappers/ActionTransformedEnv.jl | 2 +- .../wrappers/DefaultStateStyle.jl | 9 +- .../environments/wrappers/SequentialEnv.jl | 6 +- .../environments/wrappers/StateCachedEnv.jl | 2 +- .../wrappers/StateTransformedEnv.jl | 4 +- .../src/plots.jl | 68 ++-- .../examples/stock_trading_env.jl | 1 - .../test/environments/wrappers/wrappers.jl | 20 +- .../deps/build.jl | 7 +- .../src/ReinforcementLearningExperiments.jl | 3 +- .../src/algorithms/dqns/common.jl | 5 +- .../src/algorithms/dqns/dqn.jl | 8 +- .../src/algorithms/dqns/dqns.jl | 2 +- .../src/algorithms/dqns/qr_dqn.jl | 51 +-- .../src/algorithms/dqns/rem_dqn.jl | 
7 +- .../src/algorithms/nfsp/abstract_nfsp.jl | 2 +- .../src/algorithms/nfsp/nfsp.jl | 11 +- .../src/algorithms/nfsp/nfsp_manager.jl | 10 +- .../algorithms/offline_rl/behavior_cloning.jl | 23 +- .../src/algorithms/policy_gradient/ddpg.jl | 9 +- .../src/algorithms/policy_gradient/ppo.jl | 11 +- .../src/algorithms/policy_gradient/sac.jl | 12 +- .../src/algorithms/tabular/tabular_policy.jl | 2 +- test/runtests.jl | 3 +- 87 files changed, 1210 insertions(+), 1019 deletions(-) diff --git a/docs/experiments/experiments/CFR/JuliaRL_DeepCFR_OpenSpiel.jl b/docs/experiments/experiments/CFR/JuliaRL_DeepCFR_OpenSpiel.jl index ff647a3a4..2eeecbd74 100644 --- a/docs/experiments/experiments/CFR/JuliaRL_DeepCFR_OpenSpiel.jl +++ b/docs/experiments/experiments/CFR/JuliaRL_DeepCFR_OpenSpiel.jl @@ -61,5 +61,11 @@ function RL.Experiment( batch_size_Π = 2048, initializer = glorot_normal(CUDA.CURAND.default_rng()), ) - Experiment(p, env, StopAfterStep(500, is_show_progress=!haskey(ENV, "CI")), EmptyHook(), "# run DeepcCFR on leduc_poker") -end \ No newline at end of file + Experiment( + p, + env, + StopAfterStep(500, is_show_progress = !haskey(ENV, "CI")), + EmptyHook(), + "# run DeepcCFR on leduc_poker", + ) +end diff --git a/docs/experiments/experiments/CFR/JuliaRL_TabularCFR_OpenSpiel.jl b/docs/experiments/experiments/CFR/JuliaRL_TabularCFR_OpenSpiel.jl index edfd7f199..d89cabb16 100644 --- a/docs/experiments/experiments/CFR/JuliaRL_TabularCFR_OpenSpiel.jl +++ b/docs/experiments/experiments/CFR/JuliaRL_TabularCFR_OpenSpiel.jl @@ -23,8 +23,14 @@ function RL.Experiment( π = TabularCFRPolicy(; rng = rng) description = "# Play `$game` in OpenSpiel with TabularCFRPolicy" - Experiment(π, env, StopAfterStep(300, is_show_progress=!haskey(ENV, "CI")), EmptyHook(), description) + Experiment( + π, + env, + StopAfterStep(300, is_show_progress = !haskey(ENV, "CI")), + EmptyHook(), + description, + ) end ex = E`JuliaRL_TabularCFR_OpenSpiel(kuhn_poker)` -run(ex) \ No newline at end of file +run(ex) diff --git a/docs/experiments/experiments/DQN/Dopamine_DQN_Atari.jl b/docs/experiments/experiments/DQN/Dopamine_DQN_Atari.jl index f51b4a2a9..8cee680cc 100644 --- a/docs/experiments/experiments/DQN/Dopamine_DQN_Atari.jl +++ b/docs/experiments/experiments/DQN/Dopamine_DQN_Atari.jl @@ -79,39 +79,35 @@ function atari_env_factory( repeat_action_probability = 0.25, n_replica = nothing, ) - init(seed) = - RewardOverriddenEnv( - StateCachedEnv( - StateTransformedEnv( - AtariEnv(; - name = string(name), - grayscale_obs = true, - noop_max = 30, - frame_skip = 4, - terminal_on_life_loss = false, - repeat_action_probability = repeat_action_probability, - max_num_frames_per_episode = n_frames * max_episode_steps, - color_averaging = false, - full_action_space = false, - seed = seed, - ); - state_mapping=Chain( - ResizeImage(state_size...), - StackFrames(state_size..., n_frames) - ), - state_space_mapping= _ -> Space(fill(0..256, state_size..., n_frames)) - ) + init(seed) = RewardOverriddenEnv( + StateCachedEnv( + StateTransformedEnv( + AtariEnv(; + name = string(name), + grayscale_obs = true, + noop_max = 30, + frame_skip = 4, + terminal_on_life_loss = false, + repeat_action_probability = repeat_action_probability, + max_num_frames_per_episode = n_frames * max_episode_steps, + color_averaging = false, + full_action_space = false, + seed = seed, + ); + state_mapping = Chain( + ResizeImage(state_size...), + StackFrames(state_size..., n_frames), + ), + state_space_mapping = _ -> Space(fill(0..256, state_size..., n_frames)), ), - r -> clamp(r, 
-1, 1) - ) + ), + r -> clamp(r, -1, 1), + ) if isnothing(n_replica) init(seed) else - envs = [ - init(isnothing(seed) ? nothing : hash(seed + i)) - for i in 1:n_replica - ] + envs = [init(isnothing(seed) ? nothing : hash(seed + i)) for i in 1:n_replica] states = Flux.batch(state.(envs)) rewards = reward.(envs) terminals = is_terminated.(envs) @@ -172,7 +168,7 @@ function RL.Experiment( ::Val{:Atari}, name::AbstractString; save_dir = nothing, - seed = nothing + seed = nothing, ) rng = Random.GLOBAL_RNG Random.seed!(rng, seed) @@ -190,7 +186,7 @@ function RL.Experiment( name, STATE_SIZE, N_FRAMES; - seed = isnothing(seed) ? nothing : hash(seed + 1) + seed = isnothing(seed) ? nothing : hash(seed + 1), ) N_ACTIONS = length(action_space(env)) init = glorot_uniform(rng) @@ -254,17 +250,15 @@ function RL.Experiment( end, DoEveryNEpisode() do t, agent, env with_logger(lg) do - @info "training" episode_length = step_per_episode.steps[end] reward = reward_per_episode.rewards[end] log_step_increment = 0 + @info "training" episode_length = step_per_episode.steps[end] reward = + reward_per_episode.rewards[end] log_step_increment = 0 end end, - DoEveryNStep(;n=EVALUATION_FREQ) do t, agent, env + DoEveryNStep(; n = EVALUATION_FREQ) do t, agent, env @info "evaluating agent at $t step..." p = agent.policy p = @set p.explorer = EpsilonGreedyExplorer(0.001; rng = rng) # set evaluation epsilon - h = ComposedHook( - TotalOriginalRewardPerEpisode(), - StepsPerEpisode(), - ) + h = ComposedHook(TotalOriginalRewardPerEpisode(), StepsPerEpisode()) s = @elapsed run( p, atari_env_factory( @@ -281,16 +275,18 @@ function RL.Experiment( avg_score = mean(h[1].rewards[1:end-1]) avg_length = mean(h[2].steps[1:end-1]) - @info "finished evaluating agent in $s seconds" avg_length = avg_length avg_score = avg_score + @info "finished evaluating agent in $s seconds" avg_length = avg_length avg_score = + avg_score with_logger(lg) do - @info "evaluating" avg_length = avg_length avg_score = avg_score log_step_increment = 0 + @info "evaluating" avg_length = avg_length avg_score = avg_score log_step_increment = + 0 end end, ) stop_condition = StopAfterStep( haskey(ENV, "CI") ? 
1_000 : 50_000_000, - is_show_progress=!haskey(ENV, "CI") + is_show_progress = !haskey(ENV, "CI"), ) Experiment(agent, env, stop_condition, hook, "# DQN <-> Atari($name)") end diff --git a/docs/experiments/experiments/DQN/Dopamine_IQN_Atari.jl b/docs/experiments/experiments/DQN/Dopamine_IQN_Atari.jl index 6e0305ae5..7f5c8e2e7 100644 --- a/docs/experiments/experiments/DQN/Dopamine_IQN_Atari.jl +++ b/docs/experiments/experiments/DQN/Dopamine_IQN_Atari.jl @@ -84,39 +84,35 @@ function atari_env_factory( repeat_action_probability = 0.25, n_replica = nothing, ) - init(seed) = - RewardOverriddenEnv( - StateCachedEnv( - StateTransformedEnv( - AtariEnv(; - name = string(name), - grayscale_obs = true, - noop_max = 30, - frame_skip = 4, - terminal_on_life_loss = false, - repeat_action_probability = repeat_action_probability, - max_num_frames_per_episode = n_frames * max_episode_steps, - color_averaging = false, - full_action_space = false, - seed = seed, - ); - state_mapping=Chain( - ResizeImage(state_size...), - StackFrames(state_size..., n_frames) - ), - state_space_mapping= _ -> Space(fill(0..256, state_size..., n_frames)) - ) + init(seed) = RewardOverriddenEnv( + StateCachedEnv( + StateTransformedEnv( + AtariEnv(; + name = string(name), + grayscale_obs = true, + noop_max = 30, + frame_skip = 4, + terminal_on_life_loss = false, + repeat_action_probability = repeat_action_probability, + max_num_frames_per_episode = n_frames * max_episode_steps, + color_averaging = false, + full_action_space = false, + seed = seed, + ); + state_mapping = Chain( + ResizeImage(state_size...), + StackFrames(state_size..., n_frames), + ), + state_space_mapping = _ -> Space(fill(0..256, state_size..., n_frames)), ), - r -> clamp(r, -1, 1) - ) + ), + r -> clamp(r, -1, 1), + ) if isnothing(n_replica) init(seed) else - envs = [ - init(isnothing(seed) ? nothing : hash(seed + i)) - for i in 1:n_replica - ] + envs = [init(isnothing(seed) ? nothing : hash(seed + i)) for i in 1:n_replica] states = Flux.batch(state.(envs)) rewards = reward.(envs) terminals = is_terminated.(envs) @@ -195,7 +191,12 @@ function RL.Experiment( N_FRAMES = 4 STATE_SIZE = (84, 84) - env = atari_env_factory(name, STATE_SIZE, N_FRAMES; seed = isnothing(seed) ? nothing : hash(seed + 2)) + env = atari_env_factory( + name, + STATE_SIZE, + N_FRAMES; + seed = isnothing(seed) ? nothing : hash(seed + 2), + ) N_ACTIONS = length(action_space(env)) Nₑₘ = 64 @@ -250,7 +251,7 @@ function RL.Experiment( ), ), trajectory = CircularArraySARTTrajectory( - capacity = haskey(ENV, "CI") : 1_000 : 1_000_000, + capacity = haskey(ENV, "CI"):1_000:1_000_000, state = Matrix{Float32} => STATE_SIZE, ), ) @@ -274,7 +275,7 @@ function RL.Experiment( steps_per_episode.steps[end] log_step_increment = 0 end end, - DoEveryNStep(;n=EVALUATION_FREQ) do t, agent, env + DoEveryNStep(; n = EVALUATION_FREQ) do t, agent, env @info "evaluating agent at $t step..." p = agent.policy p = @set p.explorer = EpsilonGreedyExplorer(0.001; rng = rng) # set evaluation epsilon @@ -286,7 +287,7 @@ function RL.Experiment( STATE_SIZE, N_FRAMES, MAX_EPISODE_STEPS_EVAL; - seed = isnothing(seed) ? nothing : hash(seed + t) + seed = isnothing(seed) ? 
nothing : hash(seed + t), ), StopAfterStep(125_000; is_show_progress = false), h, @@ -295,16 +296,18 @@ function RL.Experiment( avg_score = mean(h[1].rewards[1:end-1]) avg_length = mean(h[2].steps[1:end-1]) - @info "finished evaluating agent in $s seconds" avg_length = avg_length avg_score = avg_score + @info "finished evaluating agent in $s seconds" avg_length = avg_length avg_score = + avg_score with_logger(lg) do - @info "evaluating" avg_length = avg_length avg_score = avg_score log_step_increment = 0 + @info "evaluating" avg_length = avg_length avg_score = avg_score log_step_increment = + 0 end end, ) stop_condition = StopAfterStep( haskey(ENV, "CI") ? 10_000 : 50_000_000, - is_show_progress=!haskey(ENV, "CI") + is_show_progress = !haskey(ENV, "CI"), ) Experiment(agent, env, stop_condition, hook, "# IQN <-> Atari($name)") end diff --git a/docs/experiments/experiments/DQN/Dopamine_Rainbow_Atari.jl b/docs/experiments/experiments/DQN/Dopamine_Rainbow_Atari.jl index 432e110e4..fd6d76c66 100644 --- a/docs/experiments/experiments/DQN/Dopamine_Rainbow_Atari.jl +++ b/docs/experiments/experiments/DQN/Dopamine_Rainbow_Atari.jl @@ -83,39 +83,35 @@ function atari_env_factory( repeat_action_probability = 0.25, n_replica = nothing, ) - init(seed) = - RewardOverriddenEnv( - StateCachedEnv( - StateTransformedEnv( - AtariEnv(; - name = string(name), - grayscale_obs = true, - noop_max = 30, - frame_skip = 4, - terminal_on_life_loss = false, - repeat_action_probability = repeat_action_probability, - max_num_frames_per_episode = n_frames * max_episode_steps, - color_averaging = false, - full_action_space = false, - seed = seed, - ); - state_mapping=Chain( - ResizeImage(state_size...), - StackFrames(state_size..., n_frames) - ), - state_space_mapping= _ -> Space(fill(0..256, state_size..., n_frames)) - ) + init(seed) = RewardOverriddenEnv( + StateCachedEnv( + StateTransformedEnv( + AtariEnv(; + name = string(name), + grayscale_obs = true, + noop_max = 30, + frame_skip = 4, + terminal_on_life_loss = false, + repeat_action_probability = repeat_action_probability, + max_num_frames_per_episode = n_frames * max_episode_steps, + color_averaging = false, + full_action_space = false, + seed = seed, + ); + state_mapping = Chain( + ResizeImage(state_size...), + StackFrames(state_size..., n_frames), + ), + state_space_mapping = _ -> Space(fill(0..256, state_size..., n_frames)), ), - r -> clamp(r, -1, 1) - ) + ), + r -> clamp(r, -1, 1), + ) if isnothing(n_replica) init(seed) else - envs = [ - init(isnothing(seed) ? nothing : hash(seed + i)) - for i in 1:n_replica - ] + envs = [init(isnothing(seed) ? nothing : hash(seed + i)) for i in 1:n_replica] states = Flux.batch(state.(envs)) rewards = reward.(envs) terminals = is_terminated.(envs) @@ -191,7 +187,12 @@ function RL.Experiment( N_FRAMES = 4 STATE_SIZE = (84, 84) - env = atari_env_factory(name, STATE_SIZE, N_FRAMES; seed = isnothing(seed) ? nothing : hash(seed + 1)) + env = atari_env_factory( + name, + STATE_SIZE, + N_FRAMES; + seed = isnothing(seed) ? 
nothing : hash(seed + 1), + ) N_ACTIONS = length(action_space(env)) N_ATOMS = 51 init = glorot_uniform(rng) @@ -238,7 +239,7 @@ function RL.Experiment( ), ), trajectory = CircularArrayPSARTTrajectory( - capacity = haskey(ENV, "CI") : 1_000 : 1_000_000, + capacity = haskey(ENV, "CI"):1_000:1_000_000, state = Matrix{Float32} => STATE_SIZE, ), ) @@ -262,7 +263,7 @@ function RL.Experiment( steps_per_episode.steps[end] log_step_increment = 0 end end, - DoEveryNStep(;n=EVALUATION_FREQ) do t, agent, env + DoEveryNStep(; n = EVALUATION_FREQ) do t, agent, env @info "evaluating agent at $t step..." p = agent.policy p = @set p.explorer = EpsilonGreedyExplorer(0.001; rng = rng) # set evaluation epsilon @@ -282,16 +283,18 @@ function RL.Experiment( avg_length = mean(h[2].steps[1:end-1]) avg_score = mean(h[1].rewards[1:end-1]) - @info "finished evaluating agent in $s seconds" avg_length = avg_length avg_score = avg_score + @info "finished evaluating agent in $s seconds" avg_length = avg_length avg_score = + avg_score with_logger(lg) do - @info "evaluating" avg_length = avg_length avg_score = avg_score log_step_increment = 0 + @info "evaluating" avg_length = avg_length avg_score = avg_score log_step_increment = + 0 end end, ) stop_condition = StopAfterStep( haskey(ENV, "CI") ? 10_000 : 50_000_000, - is_show_progress=!haskey(ENV, "CI") + is_show_progress = !haskey(ENV, "CI"), ) Experiment(agent, env, stop_condition, hook, "# Rainbow <-> Atari($name)") diff --git a/docs/experiments/experiments/DQN/JuliaRL_BasicDQN_CartPole.jl b/docs/experiments/experiments/DQN/JuliaRL_BasicDQN_CartPole.jl index 9f32c48d6..d5ba9c9c7 100644 --- a/docs/experiments/experiments/DQN/JuliaRL_BasicDQN_CartPole.jl +++ b/docs/experiments/experiments/DQN/JuliaRL_BasicDQN_CartPole.jl @@ -51,7 +51,7 @@ function RL.Experiment( state = Vector{Float32} => (ns,), ), ) - stop_condition = StopAfterStep(10_000, is_show_progress=!haskey(ENV, "CI")) + stop_condition = StopAfterStep(10_000, is_show_progress = !haskey(ENV, "CI")) hook = TotalRewardPerEpisode() Experiment(policy, env, stop_condition, hook, "# BasicDQN <-> CartPole") end diff --git a/docs/experiments/experiments/DQN/JuliaRL_BasicDQN_MountainCar.jl b/docs/experiments/experiments/DQN/JuliaRL_BasicDQN_MountainCar.jl index ae8f02cb5..bc79c94be 100644 --- a/docs/experiments/experiments/DQN/JuliaRL_BasicDQN_MountainCar.jl +++ b/docs/experiments/experiments/DQN/JuliaRL_BasicDQN_MountainCar.jl @@ -51,7 +51,7 @@ function RL.Experiment( ), ) - stop_condition = StopAfterStep(70_000, is_show_progress=!haskey(ENV, "CI")) + stop_condition = StopAfterStep(70_000, is_show_progress = !haskey(ENV, "CI")) hook = TotalRewardPerEpisode() Experiment(agent, env, stop_condition, hook, "") diff --git a/docs/experiments/experiments/DQN/JuliaRL_BasicDQN_SingleRoomUndirected.jl b/docs/experiments/experiments/DQN/JuliaRL_BasicDQN_SingleRoomUndirected.jl index 4b7f2a5cd..39748e307 100644 --- a/docs/experiments/experiments/DQN/JuliaRL_BasicDQN_SingleRoomUndirected.jl +++ b/docs/experiments/experiments/DQN/JuliaRL_BasicDQN_SingleRoomUndirected.jl @@ -18,47 +18,47 @@ function RL.Experiment( ::Val{:BasicDQN}, ::Val{:SingleRoomUndirected}, ::Nothing; - seed=123, + seed = 123, ) rng = StableRNG(seed) - env = GridWorlds.SingleRoomUndirectedModule.SingleRoomUndirected(rng=rng) + env = GridWorlds.SingleRoomUndirectedModule.SingleRoomUndirected(rng = rng) env = GridWorlds.RLBaseEnv(env) - env = RLEnvs.StateTransformedEnv(env;state_mapping=x -> vec(Float32.(x))) + env = RLEnvs.StateTransformedEnv(env; state_mapping = x -> 
vec(Float32.(x))) env = RewardOverriddenEnv(env, x -> x - convert(typeof(x), 0.01)) env = MaxTimeoutEnv(env, 240) ns, na = length(state(env)), length(action_space(env)) agent = Agent( - policy=QBasedPolicy( - learner=BasicDQNLearner( - approximator=NeuralNetworkApproximator( - model=Chain( - Dense(ns, 128, relu; init=glorot_uniform(rng)), - Dense(128, 128, relu; init=glorot_uniform(rng)), - Dense(128, na; init=glorot_uniform(rng)), + policy = QBasedPolicy( + learner = BasicDQNLearner( + approximator = NeuralNetworkApproximator( + model = Chain( + Dense(ns, 128, relu; init = glorot_uniform(rng)), + Dense(128, 128, relu; init = glorot_uniform(rng)), + Dense(128, na; init = glorot_uniform(rng)), ) |> cpu, - optimizer=ADAM(), + optimizer = ADAM(), ), - batch_size=32, - min_replay_history=100, - loss_func=huber_loss, - rng=rng, + batch_size = 32, + min_replay_history = 100, + loss_func = huber_loss, + rng = rng, ), - explorer=EpsilonGreedyExplorer( - kind=:exp, - ϵ_stable=0.01, - decay_steps=500, - rng=rng, + explorer = EpsilonGreedyExplorer( + kind = :exp, + ϵ_stable = 0.01, + decay_steps = 500, + rng = rng, ), ), - trajectory=CircularArraySARTTrajectory( - capacity=1000, - state=Vector{Float32} => (ns,), + trajectory = CircularArraySARTTrajectory( + capacity = 1000, + state = Vector{Float32} => (ns,), ), ) - stop_condition = StopAfterStep(10_000, is_show_progress=!haskey(ENV, "CI")) + stop_condition = StopAfterStep(10_000, is_show_progress = !haskey(ENV, "CI")) hook = TotalRewardPerEpisode() Experiment(agent, env, stop_condition, hook, "") end diff --git a/docs/experiments/experiments/DQN/JuliaRL_DQN_CartPole.jl b/docs/experiments/experiments/DQN/JuliaRL_DQN_CartPole.jl index 7e922e13e..7e2e218f6 100644 --- a/docs/experiments/experiments/DQN/JuliaRL_DQN_CartPole.jl +++ b/docs/experiments/experiments/DQN/JuliaRL_DQN_CartPole.jl @@ -14,13 +14,13 @@ using Flux.Losses function build_dueling_network(network::Chain) lm = length(network) - if !(network[lm] isa Dense) || !(network[lm-1] isa Dense) + if !(network[lm] isa Dense) || !(network[lm-1] isa Dense) error("The Qnetwork provided is incompatible with dueling.") end - base = Chain([deepcopy(network[i]) for i=1:lm-2]...) + base = Chain([deepcopy(network[i]) for i in 1:lm-2]...) last_layer_dims = size(network[lm].weight, 2) val = Chain(deepcopy(network[lm-1]), Dense(last_layer_dims, 1)) - adv = Chain([deepcopy(network[i]) for i=lm-1:lm]...) + adv = Chain([deepcopy(network[i]) for i in lm-1:lm]...) 
return DuelingNetwork(base, val, adv) end @@ -37,8 +37,8 @@ function RL.Experiment( base_model = Chain( Dense(ns, 128, relu; init = glorot_uniform(rng)), Dense(128, 128, relu; init = glorot_uniform(rng)), - Dense(128, na; init = glorot_uniform(rng)) - ) + Dense(128, na; init = glorot_uniform(rng)), + ) agent = Agent( policy = QBasedPolicy( @@ -72,7 +72,7 @@ function RL.Experiment( ), ) - stop_condition = StopAfterStep(10_000, is_show_progress=!haskey(ENV, "CI")) + stop_condition = StopAfterStep(10_000, is_show_progress = !haskey(ENV, "CI")) hook = TotalRewardPerEpisode() Experiment(agent, env, stop_condition, hook, "") end diff --git a/docs/experiments/experiments/DQN/JuliaRL_DQN_MountainCar.jl b/docs/experiments/experiments/DQN/JuliaRL_DQN_MountainCar.jl index d8b1eb633..f74bcaea1 100644 --- a/docs/experiments/experiments/DQN/JuliaRL_DQN_MountainCar.jl +++ b/docs/experiments/experiments/DQN/JuliaRL_DQN_MountainCar.jl @@ -64,7 +64,7 @@ function RL.Experiment( ), ) - stop_condition = StopAfterStep(40_000, is_show_progress=!haskey(ENV, "CI")) + stop_condition = StopAfterStep(40_000, is_show_progress = !haskey(ENV, "CI")) hook = TotalRewardPerEpisode() Experiment(agent, env, stop_condition, hook, "") end diff --git a/docs/experiments/experiments/DQN/JuliaRL_IQN_CartPole.jl b/docs/experiments/experiments/DQN/JuliaRL_IQN_CartPole.jl index f3ab3c98f..cba0fcea5 100644 --- a/docs/experiments/experiments/DQN/JuliaRL_IQN_CartPole.jl +++ b/docs/experiments/experiments/DQN/JuliaRL_IQN_CartPole.jl @@ -71,7 +71,7 @@ function RL.Experiment( ), ) - stop_condition = StopAfterStep(10_000, is_show_progress=!haskey(ENV, "CI")) + stop_condition = StopAfterStep(10_000, is_show_progress = !haskey(ENV, "CI")) hook = TotalRewardPerEpisode() Experiment(agent, env, stop_condition, hook, "") end diff --git a/docs/experiments/experiments/DQN/JuliaRL_QRDQN_Cartpole.jl b/docs/experiments/experiments/DQN/JuliaRL_QRDQN_Cartpole.jl index 7fb238d28..c45bb0b03 100644 --- a/docs/experiments/experiments/DQN/JuliaRL_QRDQN_Cartpole.jl +++ b/docs/experiments/experiments/DQN/JuliaRL_QRDQN_Cartpole.jl @@ -17,58 +17,58 @@ function RL.Experiment( ::Val{:QRDQN}, ::Val{:CartPole}, ::Nothing; - seed=123, + seed = 123, ) N = 10 rng = StableRNG(seed) - env = CartPoleEnv(; T=Float32, rng=rng) + env = CartPoleEnv(; T = Float32, rng = rng) ns, na = length(state(env)), length(action_space(env)) init = glorot_uniform(rng) agent = Agent( - policy=QBasedPolicy( - learner=QRDQNLearner( - approximator=NeuralNetworkApproximator( - model=Chain( + policy = QBasedPolicy( + learner = QRDQNLearner( + approximator = NeuralNetworkApproximator( + model = Chain( Dense(ns, 128, relu; init = init), Dense(128, 128, relu; init = init), Dense(128, N * na; init = init), ) |> cpu, - optimizer=ADAM(), + optimizer = ADAM(), ), - target_approximator=NeuralNetworkApproximator( - model=Chain( + target_approximator = NeuralNetworkApproximator( + model = Chain( Dense(ns, 128, relu; init = init), Dense(128, 128, relu; init = init), Dense(128, N * na; init = init), ) |> cpu, ), - stack_size=nothing, - batch_size=32, - update_horizon=1, - min_replay_history=100, - update_freq=1, - target_update_freq=100, - n_quantile=N, - rng=rng, + stack_size = nothing, + batch_size = 32, + update_horizon = 1, + min_replay_history = 100, + update_freq = 1, + target_update_freq = 100, + n_quantile = N, + rng = rng, ), - explorer=EpsilonGreedyExplorer( - kind=:exp, - ϵ_stable=0.01, - decay_steps=500, - rng=rng, + explorer = EpsilonGreedyExplorer( + kind = :exp, + ϵ_stable = 0.01, + 
decay_steps = 500, + rng = rng, ), ), - trajectory=CircularArraySARTTrajectory( - capacity=1000, - state=Vector{Float32} => (ns,), + trajectory = CircularArraySARTTrajectory( + capacity = 1000, + state = Vector{Float32} => (ns,), ), ) - stop_condition = StopAfterStep(10_000, is_show_progress=!haskey(ENV, "CI")) + stop_condition = StopAfterStep(10_000, is_show_progress = !haskey(ENV, "CI")) hook = TotalRewardPerEpisode() Experiment(agent, env, stop_condition, hook, "") end diff --git a/docs/experiments/experiments/DQN/JuliaRL_REMDQN_CartPole.jl b/docs/experiments/experiments/DQN/JuliaRL_REMDQN_CartPole.jl index fdf473a83..7f74dd096 100644 --- a/docs/experiments/experiments/DQN/JuliaRL_REMDQN_CartPole.jl +++ b/docs/experiments/experiments/DQN/JuliaRL_REMDQN_CartPole.jl @@ -52,7 +52,7 @@ function RL.Experiment( update_freq = 1, target_update_freq = 100, ensemble_num = ensemble_num, - ensemble_method = :rand, + ensemble_method = :rand, rng = rng, ), explorer = EpsilonGreedyExplorer( @@ -68,7 +68,7 @@ function RL.Experiment( ), ) - stop_condition = StopAfterStep(10_000, is_show_progress=!haskey(ENV, "CI")) + stop_condition = StopAfterStep(10_000, is_show_progress = !haskey(ENV, "CI")) hook = TotalRewardPerEpisode() Experiment(agent, env, stop_condition, hook, "") end diff --git a/docs/experiments/experiments/DQN/JuliaRL_Rainbow_CartPole.jl b/docs/experiments/experiments/DQN/JuliaRL_Rainbow_CartPole.jl index f367d1cf5..d8f5d2437 100644 --- a/docs/experiments/experiments/DQN/JuliaRL_Rainbow_CartPole.jl +++ b/docs/experiments/experiments/DQN/JuliaRL_Rainbow_CartPole.jl @@ -71,7 +71,7 @@ function RL.Experiment( ), ) - stop_condition = StopAfterStep(10_000, is_show_progress=!haskey(ENV, "CI")) + stop_condition = StopAfterStep(10_000, is_show_progress = !haskey(ENV, "CI")) hook = TotalRewardPerEpisode() Experiment(agent, env, stop_condition, hook, "") end diff --git a/docs/experiments/experiments/NFSP/JuliaRL_NFSP_KuhnPoker.jl b/docs/experiments/experiments/NFSP/JuliaRL_NFSP_KuhnPoker.jl index fd61dfd42..02555fee9 100644 --- a/docs/experiments/experiments/NFSP/JuliaRL_NFSP_KuhnPoker.jl +++ b/docs/experiments/experiments/NFSP/JuliaRL_NFSP_KuhnPoker.jl @@ -35,14 +35,15 @@ function RL.Experiment( seed = 123, ) rng = StableRNG(seed) - + ## Encode the KuhnPokerEnv's states for training. 
env = KuhnPokerEnv() wrapped_env = StateTransformedEnv( env; state_mapping = s -> [findfirst(==(s), state_space(env))], - state_space_mapping = ss -> [[findfirst(==(s), state_space(env))] for s in state_space(env)] - ) + state_space_mapping = ss -> + [[findfirst(==(s), state_space(env))] for s in state_space(env)], + ) player = 1 # or 2 ns, na = length(state(wrapped_env, player)), length(action_space(wrapped_env, player)) @@ -53,14 +54,14 @@ function RL.Experiment( approximator = NeuralNetworkApproximator( model = Chain( Dense(ns, 64, relu; init = glorot_normal(rng)), - Dense(64, na; init = glorot_normal(rng)) + Dense(64, na; init = glorot_normal(rng)), ) |> cpu, optimizer = Descent(0.01), ), target_approximator = NeuralNetworkApproximator( model = Chain( Dense(ns, 64, relu; init = glorot_normal(rng)), - Dense(64, na; init = glorot_normal(rng)) + Dense(64, na; init = glorot_normal(rng)), ) |> cpu, ), γ = 1.0f0, @@ -81,7 +82,7 @@ function RL.Experiment( ), trajectory = CircularArraySARTTrajectory( capacity = 200_000, - state = Vector{Int} => (ns, ), + state = Vector{Int} => (ns,), ), ) @@ -89,9 +90,9 @@ function RL.Experiment( policy = BehaviorCloningPolicy(; approximator = NeuralNetworkApproximator( model = Chain( - Dense(ns, 64, relu; init = glorot_normal(rng)), - Dense(64, na; init = glorot_normal(rng)) - ) |> cpu, + Dense(ns, 64, relu; init = glorot_normal(rng)), + Dense(64, na; init = glorot_normal(rng)), + ) |> cpu, optimizer = Descent(0.01), ), explorer = WeightedSoftmaxExplorer(), @@ -111,19 +112,22 @@ function RL.Experiment( η = 0.1 # anticipatory parameter nfsp = NFSPAgentManager( Dict( - (player, NFSPAgent( - deepcopy(rl_agent), - deepcopy(sl_agent), - η, - rng, - 128, # update_freq - 0, # initial update_step - true, # initial NFSPAgent's learn mode - )) for player in players(wrapped_env) if player != chance_player(wrapped_env) - ) + ( + player, + NFSPAgent( + deepcopy(rl_agent), + deepcopy(sl_agent), + η, + rng, + 128, # update_freq + 0, # initial update_step + true, # initial NFSPAgent's learn mode + ), + ) for player in players(wrapped_env) if player != chance_player(wrapped_env) + ), ) - stop_condition = StopAfterEpisode(1_200_000, is_show_progress=!haskey(ENV, "CI")) + stop_condition = StopAfterEpisode(1_200_000, is_show_progress = !haskey(ENV, "CI")) hook = ResultNEpisode(10_000, 0, [], []) Experiment(nfsp, wrapped_env, stop_condition, hook, "# run NFSP on KuhnPokerEnv") @@ -133,8 +137,14 @@ end using Plots ex = E`JuliaRL_NFSP_KuhnPoker` run(ex) -plot(ex.hook.episode, ex.hook.results, xaxis=:log, xlabel="episode", ylabel="nash_conv") +plot( + ex.hook.episode, + ex.hook.results, + xaxis = :log, + xlabel = "episode", + ylabel = "nash_conv", +) savefig("assets/JuliaRL_NFSP_KuhnPoker.png")#hide -# ![](assets/JuliaRL_NFSP_KuhnPoker.png) \ No newline at end of file +# ![](assets/JuliaRL_NFSP_KuhnPoker.png) diff --git a/docs/experiments/experiments/Offline/JuliaRL_BC_CartPole.jl b/docs/experiments/experiments/Offline/JuliaRL_BC_CartPole.jl index f1e42bb51..808719818 100644 --- a/docs/experiments/experiments/Offline/JuliaRL_BC_CartPole.jl +++ b/docs/experiments/experiments/Offline/JuliaRL_BC_CartPole.jl @@ -61,7 +61,7 @@ function RL.Experiment( ), ) - stop_condition = StopAfterStep(10_000, is_show_progress=!haskey(ENV, "CI")) + stop_condition = StopAfterStep(10_000, is_show_progress = !haskey(ENV, "CI")) hook = RecordStateAction() run(agent, env, stop_condition, hook) @@ -84,7 +84,13 @@ function RL.Experiment( end hook = TotalRewardPerEpisode() - Experiment(bc, env, 
StopAfterEpisode(100, is_show_progress=!haskey(ENV, "CI")), hook, "BehaviorCloning <-> CartPole") + Experiment( + bc, + env, + StopAfterEpisode(100, is_show_progress = !haskey(ENV, "CI")), + hook, + "BehaviorCloning <-> CartPole", + ) end #+ tangle=false diff --git a/docs/experiments/experiments/Policy Gradient/JuliaRL_A2CGAE_CartPole.jl b/docs/experiments/experiments/Policy Gradient/JuliaRL_A2CGAE_CartPole.jl index eeed90609..2ea78b92d 100644 --- a/docs/experiments/experiments/Policy Gradient/JuliaRL_A2CGAE_CartPole.jl +++ b/docs/experiments/experiments/Policy Gradient/JuliaRL_A2CGAE_CartPole.jl @@ -63,7 +63,7 @@ function RL.Experiment( terminal = Vector{Bool} => (N_ENV,), ), ) - stop_condition = StopAfterStep(50_000, is_show_progress=!haskey(ENV, "CI")) + stop_condition = StopAfterStep(50_000, is_show_progress = !haskey(ENV, "CI")) hook = TotalBatchRewardPerEpisode(N_ENV) Experiment(agent, env, stop_condition, hook, "# A2CGAE with CartPole") end @@ -78,7 +78,7 @@ run(ex) n = minimum(map(length, ex.hook.rewards)) m = mean([@view(x[1:n]) for x in ex.hook.rewards]) s = std([@view(x[1:n]) for x in ex.hook.rewards]) -plot(m,ribbon=s) +plot(m, ribbon = s) savefig("assets/JuliaRL_A2CGAE_CartPole.png") #hide # ![](assets/JuliaRL_A2CGAE_CartPole.png) diff --git a/docs/experiments/experiments/Policy Gradient/JuliaRL_A2C_CartPole.jl b/docs/experiments/experiments/Policy Gradient/JuliaRL_A2C_CartPole.jl index cb03dd025..e56733a92 100644 --- a/docs/experiments/experiments/Policy Gradient/JuliaRL_A2C_CartPole.jl +++ b/docs/experiments/experiments/Policy Gradient/JuliaRL_A2C_CartPole.jl @@ -59,7 +59,7 @@ function RL.Experiment( ), ) - stop_condition = StopAfterStep(50_000, is_show_progress=true) + stop_condition = StopAfterStep(50_000, is_show_progress = true) hook = TotalBatchRewardPerEpisode(N_ENV) Experiment(agent, env, stop_condition, hook, "# A2C with CartPole") end @@ -73,7 +73,7 @@ run(ex) n = minimum(map(length, ex.hook.rewards)) m = mean([@view(x[1:n]) for x in ex.hook.rewards]) s = std([@view(x[1:n]) for x in ex.hook.rewards]) -plot(m,ribbon=s) +plot(m, ribbon = s) savefig("assets/JuliaRL_A2C_CartPole.png") #hide # ![](assets/JuliaRL_A2C_CartPole.png) diff --git a/docs/experiments/experiments/Policy Gradient/JuliaRL_DDPG_Pendulum.jl b/docs/experiments/experiments/Policy Gradient/JuliaRL_DDPG_Pendulum.jl index 927cdc2d2..a4ba28d58 100644 --- a/docs/experiments/experiments/Policy Gradient/JuliaRL_DDPG_Pendulum.jl +++ b/docs/experiments/experiments/Policy Gradient/JuliaRL_DDPG_Pendulum.jl @@ -79,11 +79,11 @@ function RL.Experiment( trajectory = CircularArraySARTTrajectory( capacity = 10000, state = Vector{Float32} => (ns,), - action = Float32 => (na, ), + action = Float32 => (na,), ), ) - stop_condition = StopAfterStep(10_000, is_show_progress=!haskey(ENV, "CI")) + stop_condition = StopAfterStep(10_000, is_show_progress = !haskey(ENV, "CI")) hook = TotalRewardPerEpisode() Experiment(agent, env, stop_condition, hook, "# Play Pendulum with DDPG") end diff --git a/docs/experiments/experiments/Policy Gradient/JuliaRL_MAC_CartPole.jl b/docs/experiments/experiments/Policy Gradient/JuliaRL_MAC_CartPole.jl index 3559b1b03..c526fcfe3 100644 --- a/docs/experiments/experiments/Policy Gradient/JuliaRL_MAC_CartPole.jl +++ b/docs/experiments/experiments/Policy Gradient/JuliaRL_MAC_CartPole.jl @@ -64,7 +64,7 @@ function RL.Experiment( ), ) - stop_condition = StopAfterStep(50_000, is_show_progress=!haskey(ENV, "CI")) + stop_condition = StopAfterStep(50_000, is_show_progress = !haskey(ENV, "CI")) hook = 
TotalBatchRewardPerEpisode(N_ENV) Experiment(agent, env, stop_condition, hook, "# MAC with CartPole") end @@ -78,7 +78,7 @@ run(ex) n = minimum(map(length, ex.hook.rewards)) m = mean([@view(x[1:n]) for x in ex.hook.rewards]) s = std([@view(x[1:n]) for x in ex.hook.rewards]) -plot(m,ribbon=s) +plot(m, ribbon = s) savefig("assets/JuliaRL_MAC_CartPole.png") #hide # ![](assets/JuliaRL_MAC_CartPole.png) diff --git a/docs/experiments/experiments/Policy Gradient/JuliaRL_PPO_CartPole.jl b/docs/experiments/experiments/Policy Gradient/JuliaRL_PPO_CartPole.jl index 45cdd2e13..cbc0e3340 100644 --- a/docs/experiments/experiments/Policy Gradient/JuliaRL_PPO_CartPole.jl +++ b/docs/experiments/experiments/Policy Gradient/JuliaRL_PPO_CartPole.jl @@ -62,7 +62,7 @@ function RL.Experiment( ), ) - stop_condition = StopAfterStep(10_000, is_show_progress=!haskey(ENV, "CI")) + stop_condition = StopAfterStep(10_000, is_show_progress = !haskey(ENV, "CI")) hook = TotalBatchRewardPerEpisode(N_ENV) Experiment(agent, env, stop_condition, hook, "# PPO with CartPole") end @@ -76,7 +76,7 @@ run(ex) n = minimum(map(length, ex.hook.rewards)) m = mean([@view(x[1:n]) for x in ex.hook.rewards]) s = std([@view(x[1:n]) for x in ex.hook.rewards]) -plot(m,ribbon=s) +plot(m, ribbon = s) savefig("assets/JuliaRL_PPO_CartPole.png") #hide # ![](assets/JuliaRL_PPO_CartPole.png) diff --git a/docs/experiments/experiments/Policy Gradient/JuliaRL_PPO_Pendulum.jl b/docs/experiments/experiments/Policy Gradient/JuliaRL_PPO_Pendulum.jl index 80625e104..1fd85f10c 100644 --- a/docs/experiments/experiments/Policy Gradient/JuliaRL_PPO_Pendulum.jl +++ b/docs/experiments/experiments/Policy Gradient/JuliaRL_PPO_Pendulum.jl @@ -32,7 +32,8 @@ function RL.Experiment( UPDATE_FREQ = 2048 env = MultiThreadEnv([ PendulumEnv(T = Float32, rng = StableRNG(hash(seed + i))) |> - env -> ActionTransformedEnv(env, action_mapping = x -> clamp(x * 2, low, high)) for i in 1:N_ENV + env -> ActionTransformedEnv(env, action_mapping = x -> clamp(x * 2, low, high)) + for i in 1:N_ENV ]) init = glorot_uniform(rng) @@ -78,7 +79,7 @@ function RL.Experiment( ), ) - stop_condition = StopAfterStep(50_000, is_show_progress=!haskey(ENV, "CI")) + stop_condition = StopAfterStep(50_000, is_show_progress = !haskey(ENV, "CI")) hook = TotalBatchRewardPerEpisode(N_ENV) Experiment(agent, env, stop_condition, hook, "# Play Pendulum with PPO") end @@ -92,7 +93,7 @@ run(ex) n = minimum(map(length, ex.hook.rewards)) m = mean([@view(x[1:n]) for x in ex.hook.rewards]) s = std([@view(x[1:n]) for x in ex.hook.rewards]) -plot(m,ribbon=s) +plot(m, ribbon = s) savefig("assets/JuliaRL_PPO_Pendulum.png") #hide # ![](assets/JuliaRL_PPO_Pendulum.png) diff --git a/docs/experiments/experiments/Policy Gradient/JuliaRL_SAC_Pendulum.jl b/docs/experiments/experiments/Policy Gradient/JuliaRL_SAC_Pendulum.jl index c15e3c07d..3ef66ddf6 100644 --- a/docs/experiments/experiments/Policy Gradient/JuliaRL_SAC_Pendulum.jl +++ b/docs/experiments/experiments/Policy Gradient/JuliaRL_SAC_Pendulum.jl @@ -38,12 +38,11 @@ function RL.Experiment( create_policy_net() = NeuralNetworkApproximator( model = GaussianNetwork( - pre = Chain( - Dense(ns, 30, relu), - Dense(30, 30, relu), - ), + pre = Chain(Dense(ns, 30, relu), Dense(30, 30, relu)), μ = Chain(Dense(30, na, init = init)), - logσ = Chain(Dense(30, na, x -> clamp.(x, typeof(x)(-10), typeof(x)(2)), init = init)), + logσ = Chain( + Dense(30, na, x -> clamp.(x, typeof(x)(-10), typeof(x)(2)), init = init), + ), ), optimizer = ADAM(0.003), ) @@ -84,7 +83,7 @@ function 
RL.Experiment( ), ) - stop_condition = StopAfterStep(10_000, is_show_progress=!haskey(ENV, "CI")) + stop_condition = StopAfterStep(10_000, is_show_progress = !haskey(ENV, "CI")) hook = TotalRewardPerEpisode() Experiment(agent, env, stop_condition, hook, "# Play Pendulum with SAC") end diff --git a/docs/experiments/experiments/Policy Gradient/JuliaRL_TD3_Pendulum.jl b/docs/experiments/experiments/Policy Gradient/JuliaRL_TD3_Pendulum.jl index 49bc26c64..ab99a7278 100644 --- a/docs/experiments/experiments/Policy Gradient/JuliaRL_TD3_Pendulum.jl +++ b/docs/experiments/experiments/Policy Gradient/JuliaRL_TD3_Pendulum.jl @@ -86,7 +86,7 @@ function RL.Experiment( ), ) - stop_condition = StopAfterStep(10_000, is_show_progress=!haskey(ENV, "CI")) + stop_condition = StopAfterStep(10_000, is_show_progress = !haskey(ENV, "CI")) hook = TotalRewardPerEpisode() Experiment(agent, env, stop_condition, hook, "# Play Pendulum with TD3") end diff --git a/docs/experiments/experiments/Policy Gradient/JuliaRL_VPG_CartPole.jl b/docs/experiments/experiments/Policy Gradient/JuliaRL_VPG_CartPole.jl index 87130f3dc..8cfbe125d 100644 --- a/docs/experiments/experiments/Policy Gradient/JuliaRL_VPG_CartPole.jl +++ b/docs/experiments/experiments/Policy Gradient/JuliaRL_VPG_CartPole.jl @@ -49,7 +49,7 @@ function RL.Experiment( ), trajectory = ElasticSARTTrajectory(state = Vector{Float32} => (ns,)), ) - stop_condition = StopAfterEpisode(500, is_show_progress=!haskey(ENV, "CI")) + stop_condition = StopAfterEpisode(500, is_show_progress = !haskey(ENV, "CI")) hook = TotalRewardPerEpisode() description = "# Play CartPole with VPG" diff --git a/docs/experiments/experiments/Policy Gradient/rlpyt_A2C_Atari.jl b/docs/experiments/experiments/Policy Gradient/rlpyt_A2C_Atari.jl index 7bf2add71..56eb5d90e 100644 --- a/docs/experiments/experiments/Policy Gradient/rlpyt_A2C_Atari.jl +++ b/docs/experiments/experiments/Policy Gradient/rlpyt_A2C_Atari.jl @@ -83,7 +83,7 @@ function RL.Experiment( hook = ComposedHook( total_batch_reward_per_episode, batch_steps_per_episode, - DoEveryNStep(;n=UPDATE_FREQ) do t, agent, env + DoEveryNStep(; n = UPDATE_FREQ) do t, agent, env learner = agent.policy.policy.learner with_logger(lg) do @info "training" loss = learner.loss actor_loss = learner.actor_loss critic_loss = @@ -94,20 +94,22 @@ function RL.Experiment( DoEveryNStep() do t, agent, env with_logger(lg) do rewards = [ - total_batch_reward_per_episode.rewards[i][end] for i in 1:length(env) if is_terminated(env[i]) + total_batch_reward_per_episode.rewards[i][end] for + i in 1:length(env) if is_terminated(env[i]) ] if length(rewards) > 0 @info "training" rewards = mean(rewards) log_step_increment = 0 end steps = [ - batch_steps_per_episode.steps[i][end] for i in 1:length(env) if is_terminated(env[i]) + batch_steps_per_episode.steps[i][end] for + i in 1:length(env) if is_terminated(env[i]) ] if length(steps) > 0 @info "training" steps = mean(steps) log_step_increment = 0 end end end, - DoEveryNStep(;n=EVALUATION_FREQ) do t, agent, env + DoEveryNStep(; n = EVALUATION_FREQ) do t, agent, env @info "evaluating agent at $t step..." 
h = TotalBatchOriginalRewardPerEpisode(N_ENV) s = @elapsed run( diff --git a/docs/experiments/experiments/Policy Gradient/rlpyt_PPO_Atari.jl b/docs/experiments/experiments/Policy Gradient/rlpyt_PPO_Atari.jl index e9eb574a3..4b33dc27f 100644 --- a/docs/experiments/experiments/Policy Gradient/rlpyt_PPO_Atari.jl +++ b/docs/experiments/experiments/Policy Gradient/rlpyt_PPO_Atari.jl @@ -85,7 +85,7 @@ function RL.Experiment( hook = ComposedHook( total_batch_reward_per_episode, batch_steps_per_episode, - DoEveryNStep(;n=UPDATE_FREQ) do t, agent, env + DoEveryNStep(; n = UPDATE_FREQ) do t, agent, env p = agent.policy with_logger(lg) do @info "training" loss = mean(p.loss) actor_loss = mean(p.actor_loss) critic_loss = @@ -93,7 +93,7 @@ function RL.Experiment( mean(p.norm) log_step_increment = UPDATE_FREQ end end, - DoEveryNStep(;n=UPDATE_FREQ) do t, agent, env + DoEveryNStep(; n = UPDATE_FREQ) do t, agent, env decay = (N_TRAINING_STEPS - t) / N_TRAINING_STEPS agent.policy.approximator.optimizer.eta = INIT_LEARNING_RATE * decay agent.policy.clip_range = INIT_CLIP_RANGE * Float32(decay) @@ -101,20 +101,22 @@ function RL.Experiment( DoEveryNStep() do t, agent, env with_logger(lg) do rewards = [ - total_batch_reward_per_episode.rewards[i][end] for i in 1:length(env) if is_terminated(env[i]) + total_batch_reward_per_episode.rewards[i][end] for + i in 1:length(env) if is_terminated(env[i]) ] if length(rewards) > 0 @info "training" rewards = mean(rewards) log_step_increment = 0 end steps = [ - batch_steps_per_episode.steps[i][end] for i in 1:length(env) if is_terminated(env[i]) + batch_steps_per_episode.steps[i][end] for + i in 1:length(env) if is_terminated(env[i]) ] if length(steps) > 0 @info "training" steps = mean(steps) log_step_increment = 0 end end end, - DoEveryNStep(;n=EVALUATION_FREQ) do t, agent, env + DoEveryNStep(; n = EVALUATION_FREQ) do t, agent, env @info "evaluating agent at $t step..." ## switch to GreedyExplorer? h = TotalBatchRewardPerEpisode(N_ENV) diff --git a/docs/experiments/experiments/Search/JuliaRL_Minimax_OpenSpiel.jl b/docs/experiments/experiments/Search/JuliaRL_Minimax_OpenSpiel.jl index f31098c00..a6bfc1d4b 100644 --- a/docs/experiments/experiments/Search/JuliaRL_Minimax_OpenSpiel.jl +++ b/docs/experiments/experiments/Search/JuliaRL_Minimax_OpenSpiel.jl @@ -18,7 +18,13 @@ function RL.Experiment(::Val{:JuliaRL}, ::Val{:Minimax}, ::Val{:OpenSpiel}, game ) hooks = MultiAgentHook(0 => TotalRewardPerEpisode(), 1 => TotalRewardPerEpisode()) description = "# Play `$game` in OpenSpiel with Minimax" - Experiment(agents, env, StopAfterEpisode(1, is_show_progress=!haskey(ENV, "CI")), hooks, description) + Experiment( + agents, + env, + StopAfterEpisode(1, is_show_progress = !haskey(ENV, "CI")), + hooks, + description, + ) end using Plots diff --git a/docs/homepage/utils.jl b/docs/homepage/utils.jl index 816810d71..64787a91a 100644 --- a/docs/homepage/utils.jl +++ b/docs/homepage/utils.jl @@ -5,7 +5,7 @@ html(s) = "\n~~~$s~~~\n" function hfun_adddescription() d = locvar(:description) - isnothing(d) ? "" : F.fd2html(d, internal=true) + isnothing(d) ? 
"" : F.fd2html(d, internal = true) end function hfun_frontmatter() @@ -28,7 +28,7 @@ function hfun_byline() if isnothing(fm) "" else - "" + "" end end @@ -62,7 +62,7 @@ function hfun_appendix() if isfile(bib_in_cur_folder) bib_resolved = F.parse_rpath("/" * bib_in_cur_folder) else - bib_resolved = F.parse_rpath(bib; canonical=false, code=true) + bib_resolved = F.parse_rpath(bib; canonical = false, code = true) end bib = "" end @@ -74,7 +74,7 @@ function hfun_appendix() """ end -function lx_dcite(lxc,_) +function lx_dcite(lxc, _) content = F.content(lxc.braces[1]) "" |> html end @@ -92,7 +92,7 @@ end """ Possible layouts: """ -function lx_dfig(lxc,lxd) +function lx_dfig(lxc, lxd) content = F.content(lxc.braces[1]) info = split(content, ';') layout = info[1] @@ -111,7 +111,7 @@ function lx_dfig(lxc,lxd) end # (case 3) assume it is generated by code - src = F.parse_rpath(src; canonical=false, code=true) + src = F.parse_rpath(src; canonical = false, code = true) # !!! directly take from `lx_fig` in Franklin.jl fdir, fext = splitext(src) @@ -122,11 +122,10 @@ function lx_dfig(lxc,lxd) # then in both cases there can be a relative path set but the user may mean # that it's in the subfolder /output/ (if generated by code) so should look # both in the relpath and if not found and if /output/ not already last dir - candext = ifelse(isempty(fext), - (".png", ".jpeg", ".jpg", ".svg", ".gif"), (fext,)) - for ext ∈ candext + candext = ifelse(isempty(fext), (".png", ".jpeg", ".jpg", ".svg", ".gif"), (fext,)) + for ext in candext candpath = fdir * ext - syspath = joinpath(F.PATHS[:site], split(candpath, '/')...) + syspath = joinpath(F.PATHS[:site], split(candpath, '/')...) isfile(syspath) && return dfigure(layout, candpath, caption) end # now try in the output dir just in case (provided we weren't already @@ -134,20 +133,20 @@ function lx_dfig(lxc,lxd) p1, p2 = splitdir(fdir) @debug "TEST" p1 p2 if splitdir(p1)[2] != "output" - for ext ∈ candext + for ext in candext candpath = joinpath(splitdir(p1)[1], "output", p2 * ext) - syspath = joinpath(F.PATHS[:site], split(candpath, '/')...) + syspath = joinpath(F.PATHS[:site], split(candpath, '/')...) isfile(syspath) && return dfigure(layout, candpath, caption) end end end -function lx_aside(lxc,lxd) +function lx_aside(lxc, lxd) content = F.reprocess(F.content(lxc.braces[1]), lxd) "" |> html end -function lx_footnote(lxc,lxd) +function lx_footnote(lxc, lxd) content = F.reprocess(F.content(lxc.braces[1]), lxd) # workaround if startswith(content, "

") @@ -156,7 +155,7 @@ function lx_footnote(lxc,lxd) "$content" |> html end -function lx_appendix(lxc,lxd) +function lx_appendix(lxc, lxd) content = F.reprocess(F.content(lxc.braces[1]), lxd) "$content" |> html -end \ No newline at end of file +end diff --git a/docs/make.jl b/docs/make.jl index bd61a9b1c..f81451cfd 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -14,11 +14,7 @@ end experiments, postprocess_cb, experiments_assets = makedemos("experiments") -assets = [ - "assets/favicon.ico", - "assets/custom.css", - experiments_assets -] +assets = ["assets/favicon.ico", "assets/custom.css", experiments_assets] makedocs( modules = [ @@ -53,7 +49,7 @@ makedocs( "RLEnvs" => "rlenvs.md", "RLZoo" => "rlzoo.md", ], - ] + ], ) postprocess_cb() diff --git a/src/DistributedReinforcementLearning/src/actor_model.jl b/src/DistributedReinforcementLearning/src/actor_model.jl index 3a2414c5c..af6b05214 100644 --- a/src/DistributedReinforcementLearning/src/actor_model.jl +++ b/src/DistributedReinforcementLearning/src/actor_model.jl @@ -1,19 +1,12 @@ -export AbstractMessage, - StartMsg, - StopMsg, - PingMsg, - PongMsg, - ProxyMsg, - actor, - self +export AbstractMessage, StartMsg, StopMsg, PingMsg, PongMsg, ProxyMsg, actor, self abstract type AbstractMessage end -struct StartMsg{A, K} <: AbstractMessage +struct StartMsg{A,K} <: AbstractMessage args::A kwargs::K - StartMsg(args...;kwargs...) = new{typeof(args), typeof(kwargs)}(args, kwargs) + StartMsg(args...; kwargs...) = new{typeof(args),typeof(kwargs)}(args, kwargs) end struct StopMsg <: AbstractMessage end @@ -45,9 +38,9 @@ const DEFAULT_MAILBOX_SIZE = 32 Create a task to handle messages one-by-one by calling `f(msg)`. A mailbox (`RemoteChannel`) is returned. """ -function actor(f;sz=DEFAULT_MAILBOX_SIZE) +function actor(f; sz = DEFAULT_MAILBOX_SIZE) RemoteChannel() do - Channel(sz;spawn=true) do ch + Channel(sz; spawn = true) do ch task_local_storage("MAILBOX", RemoteChannel(() -> ch)) while true msg = take!(ch) diff --git a/src/DistributedReinforcementLearning/src/core.jl b/src/DistributedReinforcementLearning/src/core.jl index 44ed67937..34a04571a 100644 --- a/src/DistributedReinforcementLearning/src/core.jl +++ b/src/DistributedReinforcementLearning/src/core.jl @@ -51,7 +51,7 @@ Base.@kwdef struct Trainer{P,S} sealer::S = deepcopy end -Trainer(p) = Trainer(;policy=p) +Trainer(p) = Trainer(; policy = p) function (trainer::Trainer)(msg::BatchDataMsg) update!(trainer.policy, msg.data) @@ -94,7 +94,7 @@ mutable struct Worker end function (w::Worker)(msg::StartMsg) - w.experiment = w.init(msg.args...;msg.kwargs...) + w.experiment = w.init(msg.args...; msg.kwargs...) 
w.task = Threads.@spawn run(w.experiment) end @@ -128,7 +128,7 @@ end function (wp::WorkerProxy)(::FetchParamMsg) if !wp.is_fetch_msg_sent[] put!(wp.target, FetchParamMsg(self())) - wp.is_fetch_msg_sent[] = true + wp.is_fetch_msg_sent[] = true end end @@ -172,9 +172,12 @@ function (orc::Orchestrator)(msg::InsertTrajectoryMsg) put!(orc.trajectory_proxy, BatchSampleMsg(orc.trainer)) L.n_sample += 1 if L.n_sample == (L.n_load + 1) * L.sample_load_ratio - put!(orc.trajectory_proxy, ProxyMsg(to=orc.trainer, msg=FetchParamMsg(orc.worker))) + put!( + orc.trajectory_proxy, + ProxyMsg(to = orc.trainer, msg = FetchParamMsg(orc.worker)), + ) L.n_load += 1 end end end -end \ No newline at end of file +end diff --git a/src/DistributedReinforcementLearning/src/extensions.jl b/src/DistributedReinforcementLearning/src/extensions.jl index efeda428b..1c6f32e9d 100644 --- a/src/DistributedReinforcementLearning/src/extensions.jl +++ b/src/DistributedReinforcementLearning/src/extensions.jl @@ -76,4 +76,4 @@ function (hook::FetchParamsHook)(::PostActStage, agent, env) end end end -end \ No newline at end of file +end diff --git a/src/DistributedReinforcementLearning/test/actor.jl b/src/DistributedReinforcementLearning/test/actor.jl index 34881c5b1..e2cc9ae87 100644 --- a/src/DistributedReinforcementLearning/test/actor.jl +++ b/src/DistributedReinforcementLearning/test/actor.jl @@ -1,53 +1,53 @@ @testset "basic tests" begin -Base.@kwdef mutable struct TestActor - state::Union{Nothing, Int} = nothing -end + Base.@kwdef mutable struct TestActor + state::Union{Nothing,Int} = nothing + end -struct CurrentStateMsg <: AbstractMessage - state -end + struct CurrentStateMsg <: AbstractMessage + state::Any + end -Base.@kwdef struct ReadStateMsg <: AbstractMessage - from = self() -end + Base.@kwdef struct ReadStateMsg <: AbstractMessage + from = self() + end -struct IncMsg <: AbstractMessage end -struct DecMsg <: AbstractMessage end + struct IncMsg <: AbstractMessage end + struct DecMsg <: AbstractMessage end -(x::TestActor)(msg::StartMsg{Tuple{Int}}) = x.state = msg.args[1] -(x::TestActor)(msg::StopMsg) = x.state = nothing -(x::TestActor)(::IncMsg) = x.state += 1 -(x::TestActor)(::DecMsg) = x.state -= 1 -(x::TestActor)(msg::ReadStateMsg) = put!(msg.from, CurrentStateMsg(x.state)) + (x::TestActor)(msg::StartMsg{Tuple{Int}}) = x.state = msg.args[1] + (x::TestActor)(msg::StopMsg) = x.state = nothing + (x::TestActor)(::IncMsg) = x.state += 1 + (x::TestActor)(::DecMsg) = x.state -= 1 + (x::TestActor)(msg::ReadStateMsg) = put!(msg.from, CurrentStateMsg(x.state)) -x = actor(TestActor()) -put!(x, StartMsg(0)) + x = actor(TestActor()) + put!(x, StartMsg(0)) -put!(x, ReadStateMsg()) -@test take!(self()).state == 0 + put!(x, ReadStateMsg()) + @test take!(self()).state == 0 -@sync begin - for _ in 1:100 - Threads.@spawn put!(x, IncMsg()) - Threads.@spawn put!(x, DecMsg()) - end - for _ in 1:10 - for _ in 1:10 + @sync begin + for _ in 1:100 Threads.@spawn put!(x, IncMsg()) + Threads.@spawn put!(x, DecMsg()) end for _ in 1:10 - Threads.@spawn put!(x, DecMsg()) + for _ in 1:10 + Threads.@spawn put!(x, IncMsg()) + end + for _ in 1:10 + Threads.@spawn put!(x, DecMsg()) + end end end -end -put!(x, ReadStateMsg()) -@test take!(self()).state == 0 + put!(x, ReadStateMsg()) + @test take!(self()).state == 0 -y = actor(TestActor()) -put!(x, ProxyMsg(;to=y,msg=StartMsg(0))) -put!(x, ProxyMsg(;to=y,msg=ReadStateMsg())) -@test take!(self()).state == 0 + y = actor(TestActor()) + put!(x, ProxyMsg(; to = y, msg = StartMsg(0))) + put!(x, ProxyMsg(; to 
= y, msg = ReadStateMsg())) + @test take!(self()).state == 0 -end \ No newline at end of file +end diff --git a/src/DistributedReinforcementLearning/test/core.jl b/src/DistributedReinforcementLearning/test/core.jl index 25d5c2c49..d4a196f37 100644 --- a/src/DistributedReinforcementLearning/test/core.jl +++ b/src/DistributedReinforcementLearning/test/core.jl @@ -1,181 +1,202 @@ @testset "core.jl" begin -@testset "Trainer" begin - _trainer = Trainer(; - policy=BasicDQNLearner( - approximator = NeuralNetworkApproximator( - model = Chain( - Dense(4, 128, relu; initW = glorot_uniform), - Dense(128, 128, relu; initW = glorot_uniform), - Dense(128, 2; initW = glorot_uniform), - ) |> cpu, - optimizer = ADAM(), + @testset "Trainer" begin + _trainer = Trainer(; + policy = BasicDQNLearner( + approximator = NeuralNetworkApproximator( + model = Chain( + Dense(4, 128, relu; initW = glorot_uniform), + Dense(128, 128, relu; initW = glorot_uniform), + Dense(128, 2; initW = glorot_uniform), + ) |> cpu, + optimizer = ADAM(), + ), + loss_func = huber_loss, ), - loss_func = huber_loss, ) - ) - trainer = actor(_trainer) + trainer = actor(_trainer) - put!(trainer, FetchParamMsg()) - ps = take!(self()) - original_sum = sum(sum, ps.data) + put!(trainer, FetchParamMsg()) + ps = take!(self()) + original_sum = sum(sum, ps.data) - for x in ps.data - fill!(x, 0.) - end + for x in ps.data + fill!(x, 0.0) + end - put!(trainer, FetchParamMsg()) - ps = take!(self()) - new_sum = sum(sum, ps.data) + put!(trainer, FetchParamMsg()) + ps = take!(self()) + new_sum = sum(sum, ps.data) - # make sure no state sharing between messages - @test original_sum == new_sum + # make sure no state sharing between messages + @test original_sum == new_sum - batch_data = ( - state = rand(4, 32), - action = rand(1:2, 32), - reward = rand(32), - terminal = rand(Bool, 32), - next_state = rand(4,32), - next_action = rand(1:2, 32) - ) + batch_data = ( + state = rand(4, 32), + action = rand(1:2, 32), + reward = rand(32), + terminal = rand(Bool, 32), + next_state = rand(4, 32), + next_action = rand(1:2, 32), + ) - put!(trainer, BatchDataMsg(batch_data)) + put!(trainer, BatchDataMsg(batch_data)) - put!(trainer, FetchParamMsg()) - ps = take!(self()) - updated_sum = sum(sum, ps.data) - @test original_sum != updated_sum -end + put!(trainer, FetchParamMsg()) + ps = take!(self()) + updated_sum = sum(sum, ps.data) + @test original_sum != updated_sum + end -@testset "TrajectoryManager" begin - _trajectory_proxy = TrajectoryManager( - trajectory = CircularSARTSATrajectory(;capacity=5, state_type=Any, ), - sampler = UniformBatchSampler(3), - inserter = NStepInserter(), - ) + @testset "TrajectoryManager" begin + _trajectory_proxy = TrajectoryManager( + trajectory = CircularSARTSATrajectory(; capacity = 5, state_type = Any), + sampler = UniformBatchSampler(3), + inserter = NStepInserter(), + ) - trajectory_proxy = actor(_trajectory_proxy) + trajectory_proxy = actor(_trajectory_proxy) - # 1. init traj for testing - traj = CircularCompactSARTSATrajectory( - capacity = 2, - state_type = Float32, - state_size = (4,), - ) - push!(traj;state=rand(Float32, 4), action=rand(1:2)) - push!(traj;reward=rand(), terminal=rand(Bool),state=rand(Float32, 4), action=rand(1:2)) - push!(traj;reward=rand(), terminal=rand(Bool),state=rand(Float32, 4), action=rand(1:2)) + # 1. 
init traj for testing + traj = CircularCompactSARTSATrajectory( + capacity = 2, + state_type = Float32, + state_size = (4,), + ) + push!(traj; state = rand(Float32, 4), action = rand(1:2)) + push!( + traj; + reward = rand(), + terminal = rand(Bool), + state = rand(Float32, 4), + action = rand(1:2), + ) + push!( + traj; + reward = rand(), + terminal = rand(Bool), + state = rand(Float32, 4), + action = rand(1:2), + ) - # 2. insert - put!(trajectory_proxy, InsertTrajectoryMsg(deepcopy(traj))) #!!! we used deepcopy here + # 2. insert + put!(trajectory_proxy, InsertTrajectoryMsg(deepcopy(traj))) #!!! we used deepcopy here - # 3. make sure the above message is already been handled - put!(trajectory_proxy, PingMsg()) - take!(self()) + # 3. make sure the above message is already been handled + put!(trajectory_proxy, PingMsg()) + take!(self()) - # 4. test that updating traj will not affect data in trajectory_proxy - s_tp = _trajectory_proxy.trajectory[:state] - s_traj = traj[:state] + # 4. test that updating traj will not affect data in trajectory_proxy + s_tp = _trajectory_proxy.trajectory[:state] + s_traj = traj[:state] - @test s_tp[1] == s_traj[:, 1] + @test s_tp[1] == s_traj[:, 1] - push!(traj;reward=rand(), terminal=rand(Bool),state=rand(Float32, 4), action=rand(1:2)) + push!( + traj; + reward = rand(), + terminal = rand(Bool), + state = rand(Float32, 4), + action = rand(1:2), + ) - @test s_tp[1] != s_traj[:, 1] + @test s_tp[1] != s_traj[:, 1] - s = sample(_trajectory_proxy.trajectory, _trajectory_proxy.sampler) - fill!(s[:state], 0.) - @test any(x -> sum(x) == 0, s_tp) == false # make sure sample create an independent copy -end + s = sample(_trajectory_proxy.trajectory, _trajectory_proxy.sampler) + fill!(s[:state], 0.0) + @test any(x -> sum(x) == 0, s_tp) == false # make sure sample create an independent copy + end -@testset "Worker" begin - _worker = Worker() do worker_proxy - Experiment( - Agent( - policy = StaticPolicy( + @testset "Worker" begin + _worker = Worker() do worker_proxy + Experiment( + Agent( + policy = StaticPolicy( QBasedPolicy( - learner = BasicDQNLearner( - approximator = NeuralNetworkApproximator( - model = Chain( - Dense(4, 128, relu; initW = glorot_uniform), - Dense(128, 128, relu; initW = glorot_uniform), - Dense(128, 2; initW = glorot_uniform), - ) |> cpu, - optimizer = ADAM(), + learner = BasicDQNLearner( + approximator = NeuralNetworkApproximator( + model = Chain( + Dense(4, 128, relu; initW = glorot_uniform), + Dense(128, 128, relu; initW = glorot_uniform), + Dense(128, 2; initW = glorot_uniform), + ) |> cpu, + optimizer = ADAM(), + ), + loss_func = huber_loss, + ), + explorer = EpsilonGreedyExplorer( + kind = :exp, + ϵ_stable = 0.01, + decay_steps = 500, ), - loss_func = huber_loss, - ), - explorer = EpsilonGreedyExplorer( - kind = :exp, - ϵ_stable = 0.01, - decay_steps = 500, ), ), + trajectory = CircularCompactSARTSATrajectory( + capacity = 10, + state_type = Float32, + state_size = (4,), + ), ), - trajectory = CircularCompactSARTSATrajectory( - capacity = 10, - state_type = Float32, - state_size = (4,), + CartPoleEnv(; T = Float32), + ComposedStopCondition(StopAfterStep(1_000), StopSignal()), + ComposedHook( + UploadTrajectoryEveryNStep( + mailbox = worker_proxy, + n = 10, + sealer = x -> InsertTrajectoryMsg(deepcopy(x)), + ), + LoadParamsHook(), + TotalRewardPerEpisode(), ), - ), - CartPoleEnv(; T = Float32), - ComposedStopCondition( - StopAfterStep(1_000), - StopSignal(), - ), - ComposedHook( - UploadTrajectoryEveryNStep(mailbox=worker_proxy, n=10, sealer=x -> 
InsertTrajectoryMsg(deepcopy(x))), - LoadParamsHook(), - TotalRewardPerEpisode(), - ), - "experimenting..." - ) - end + "experimenting...", + ) + end - worker = actor(_worker) - tmp_mailbox = Channel(100) - put!(worker, StartMsg(tmp_mailbox)) -end - -@testset "WorkerProxy" begin - target = RemoteChannel(() -> Channel(10)) - workers = [RemoteChannel(()->Channel(10)) for _ in 1:10] - _wp = WorkerProxy(workers) - wp = actor(_wp) - - put!(wp, StartMsg(target)) - for w in workers - # @test take!(w).args[1] === wp - @test Distributed.channel_from_id(remoteref_id(take!(w).args[1])) === Distributed.channel_from_id(remoteref_id(wp)) + worker = actor(_worker) + tmp_mailbox = Channel(100) + put!(worker, StartMsg(tmp_mailbox)) end - msg = InsertTrajectoryMsg(1) - put!(wp, msg) - @test take!(target) === msg - - for w in workers - put!(wp, FetchParamMsg(w)) + @testset "WorkerProxy" begin + target = RemoteChannel(() -> Channel(10)) + workers = [RemoteChannel(() -> Channel(10)) for _ in 1:10] + _wp = WorkerProxy(workers) + wp = actor(_wp) + + put!(wp, StartMsg(target)) + for w in workers + # @test take!(w).args[1] === wp + @test Distributed.channel_from_id(remoteref_id(take!(w).args[1])) === + Distributed.channel_from_id(remoteref_id(wp)) + end + + msg = InsertTrajectoryMsg(1) + put!(wp, msg) + @test take!(target) === msg + + for w in workers + put!(wp, FetchParamMsg(w)) + end + # @test take!(target).from === wp + @test Distributed.channel_from_id(remoteref_id(take!(target).from)) === + Distributed.channel_from_id(remoteref_id(wp)) + + # make sure target only received one FetchParamMsg + msg = PingMsg() + put!(target, msg) + @test take!(target) === msg + + msg = LoadParamMsg([]) + put!(wp, msg) + for w in workers + @test take!(w) === msg + end end - # @test take!(target).from === wp - @test Distributed.channel_from_id(remoteref_id(take!(target).from)) === Distributed.channel_from_id(remoteref_id(wp)) - - # make sure target only received one FetchParamMsg - msg = PingMsg() - put!(target, msg) - @test take!(target) === msg - - msg = LoadParamMsg([]) - put!(wp, msg) - for w in workers - @test take!(w) === msg + + @testset "Orchestrator" begin + # TODO + # Add an integration test end -end -@testset "Orchestrator" begin - # TODO - # Add an integration test end - -end \ No newline at end of file diff --git a/src/DistributedReinforcementLearning/test/runtests.jl b/src/DistributedReinforcementLearning/test/runtests.jl index f9f572d55..a73cd0521 100644 --- a/src/DistributedReinforcementLearning/test/runtests.jl +++ b/src/DistributedReinforcementLearning/test/runtests.jl @@ -9,7 +9,7 @@ using Flux @testset "DistributedReinforcementLearning.jl" begin -include("actor.jl") -include("core.jl") + include("actor.jl") + include("core.jl") end diff --git a/src/ReinforcementLearningBase/src/CommonRLInterface.jl b/src/ReinforcementLearningBase/src/CommonRLInterface.jl index af86a2d83..28c73ec40 100644 --- a/src/ReinforcementLearningBase/src/CommonRLInterface.jl +++ b/src/ReinforcementLearningBase/src/CommonRLInterface.jl @@ -41,7 +41,8 @@ end # !!! 
may need to be extended by user CRL.@provide CRL.observe(env::CommonRLEnv) = state(env.env) -CRL.provided(::typeof(CRL.state), env::CommonRLEnv) = !isnothing(find_state_style(env.env, InternalState)) +CRL.provided(::typeof(CRL.state), env::CommonRLEnv) = + !isnothing(find_state_style(env.env, InternalState)) CRL.state(env::CommonRLEnv) = state(env.env, find_state_style(env.env, InternalState)) CRL.@provide CRL.clone(env::CommonRLEnv) = CommonRLEnv(copy(env.env)) @@ -94,4 +95,4 @@ ActionStyle(env::RLBaseEnv) = CRL.provided(CRL.valid_actions, env.env) ? FullActionSet() : MinimalActionSet() current_player(env::RLBaseEnv) = CRL.player(env.env) -players(env::RLBaseEnv) = CRL.players(env.env) \ No newline at end of file +players(env::RLBaseEnv) = CRL.players(env.env) diff --git a/src/ReinforcementLearningBase/src/interface.jl b/src/ReinforcementLearningBase/src/interface.jl index 635bdf87f..d3ade2d75 100644 --- a/src/ReinforcementLearningBase/src/interface.jl +++ b/src/ReinforcementLearningBase/src/interface.jl @@ -410,12 +410,13 @@ Make an independent copy of `env`, !!! warning Only check the state of all players in the env. """ -function Base.:(==)(env1::T, env2::T) where T<:AbstractEnv +function Base.:(==)(env1::T, env2::T) where {T<:AbstractEnv} len = length(players(env1)) - len == length(players(env2)) && - all(state(env1, player) == state(env2, player) for player in players(env1)) + len == length(players(env2)) && + all(state(env1, player) == state(env2, player) for player in players(env1)) end -Base.hash(env::AbstractEnv, h::UInt) = hash([state(env, player) for player in players(env)], h) +Base.hash(env::AbstractEnv, h::UInt) = + hash([state(env, player) for player in players(env)], h) @api nameof(env::AbstractEnv) = nameof(typeof(env)) diff --git a/src/ReinforcementLearningBase/test/CommonRLInterface.jl b/src/ReinforcementLearningBase/test/CommonRLInterface.jl index fc38a102b..7b32dbe08 100644 --- a/src/ReinforcementLearningBase/test/CommonRLInterface.jl +++ b/src/ReinforcementLearningBase/test/CommonRLInterface.jl @@ -1,34 +1,34 @@ @testset "CommonRLInterface" begin -@testset "MDPEnv" begin - struct RLTestMDP <: MDP{Int, Int} end + @testset "MDPEnv" begin + struct RLTestMDP <: MDP{Int,Int} end - POMDPs.actions(m::RLTestMDP) = [-1, 1] - POMDPs.transition(m::RLTestMDP, s, a) = Deterministic(clamp(s + a, 1, 3)) - POMDPs.initialstate(m::RLTestMDP) = Deterministic(1) - POMDPs.isterminal(m::RLTestMDP, s) = s == 3 - POMDPs.reward(m::RLTestMDP, s, a, sp) = sp - POMDPs.states(m::RLTestMDP) = 1:3 + POMDPs.actions(m::RLTestMDP) = [-1, 1] + POMDPs.transition(m::RLTestMDP, s, a) = Deterministic(clamp(s + a, 1, 3)) + POMDPs.initialstate(m::RLTestMDP) = Deterministic(1) + POMDPs.isterminal(m::RLTestMDP, s) = s == 3 + POMDPs.reward(m::RLTestMDP, s, a, sp) = sp + POMDPs.states(m::RLTestMDP) = 1:3 - env = convert(RLBase.AbstractEnv, convert(CRL.AbstractEnv, RLTestMDP())) - RLBase.test_runnable!(env) -end + env = convert(RLBase.AbstractEnv, convert(CRL.AbstractEnv, RLTestMDP())) + RLBase.test_runnable!(env) + end -@testset "POMDPEnv" begin + @testset "POMDPEnv" begin - struct RLTestPOMDP <: POMDP{Int, Int, Int} end + struct RLTestPOMDP <: POMDP{Int,Int,Int} end - POMDPs.actions(m::RLTestPOMDP) = [-1, 1] - POMDPs.states(m::RLTestPOMDP) = 1:3 - POMDPs.transition(m::RLTestPOMDP, s, a) = Deterministic(clamp(s + a, 1, 3)) - POMDPs.observation(m::RLTestPOMDP, s, a, sp) = Deterministic(sp + 1) - POMDPs.initialstate(m::RLTestPOMDP) = Deterministic(1) - POMDPs.initialobs(m::RLTestPOMDP, s) = Deterministic(s + 1) - 
POMDPs.isterminal(m::RLTestPOMDP, s) = s == 3 - POMDPs.reward(m::RLTestPOMDP, s, a, sp) = sp - POMDPs.observations(m::RLTestPOMDP) = 2:4 + POMDPs.actions(m::RLTestPOMDP) = [-1, 1] + POMDPs.states(m::RLTestPOMDP) = 1:3 + POMDPs.transition(m::RLTestPOMDP, s, a) = Deterministic(clamp(s + a, 1, 3)) + POMDPs.observation(m::RLTestPOMDP, s, a, sp) = Deterministic(sp + 1) + POMDPs.initialstate(m::RLTestPOMDP) = Deterministic(1) + POMDPs.initialobs(m::RLTestPOMDP, s) = Deterministic(s + 1) + POMDPs.isterminal(m::RLTestPOMDP, s) = s == 3 + POMDPs.reward(m::RLTestPOMDP, s, a, sp) = sp + POMDPs.observations(m::RLTestPOMDP) = 2:4 - env = convert(RLBase.AbstractEnv, convert(CRL.AbstractEnv, RLTestPOMDP())) + env = convert(RLBase.AbstractEnv, convert(CRL.AbstractEnv, RLTestPOMDP())) - RLBase.test_runnable!(env) + RLBase.test_runnable!(env) + end end -end \ No newline at end of file diff --git a/src/ReinforcementLearningBase/test/runtests.jl b/src/ReinforcementLearningBase/test/runtests.jl index 6b44a29b8..d4f743f68 100644 --- a/src/ReinforcementLearningBase/test/runtests.jl +++ b/src/ReinforcementLearningBase/test/runtests.jl @@ -8,5 +8,5 @@ using POMDPs using POMDPModelTools: Deterministic @testset "ReinforcementLearningBase" begin -include("CommonRLInterface.jl") -end \ No newline at end of file + include("CommonRLInterface.jl") +end diff --git a/src/ReinforcementLearningCore/src/core/hooks.jl b/src/ReinforcementLearningCore/src/core/hooks.jl index 11c65e686..a8a58be3e 100644 --- a/src/ReinforcementLearningCore/src/core/hooks.jl +++ b/src/ReinforcementLearningCore/src/core/hooks.jl @@ -13,7 +13,7 @@ export AbstractHook, UploadTrajectoryEveryNStep, MultiAgentHook -using UnicodePlots:lineplot, lineplot! +using UnicodePlots: lineplot, lineplot! using Statistics """ @@ -155,7 +155,14 @@ end function (hook::TotalRewardPerEpisode)(::PostExperimentStage, agent, env) if hook.is_display_on_exit - println(lineplot(hook.rewards, title="Total reward per episode", xlabel="Episode", ylabel="Score")) + println( + lineplot( + hook.rewards, + title = "Total reward per episode", + xlabel = "Episode", + ylabel = "Score", + ), + ) end end @@ -178,8 +185,12 @@ which return a `Vector` of rewards (a typical case with `MultiThreadEnv`). If `is_display_on_exit` is set to `true`, a ribbon plot will be shown to reflect the mean and std of rewards. """ -function TotalBatchRewardPerEpisode(batch_size::Int; is_display_on_exit=true) - TotalBatchRewardPerEpisode([Float64[] for _ in 1:batch_size], zeros(batch_size), is_display_on_exit) +function TotalBatchRewardPerEpisode(batch_size::Int; is_display_on_exit = true) + TotalBatchRewardPerEpisode( + [Float64[] for _ in 1:batch_size], + zeros(batch_size), + is_display_on_exit, + ) end function (hook::TotalBatchRewardPerEpisode)(::PostActStage, agent, env) @@ -198,7 +209,12 @@ function (hook::TotalBatchRewardPerEpisode)(::PostExperimentStage, agent, env) n = minimum(map(length, hook.rewards)) m = mean([@view(x[1:n]) for x in hook.rewards]) s = std([@view(x[1:n]) for x in hook.rewards]) - p = lineplot(m, title="Avg total reward per episode", xlabel="Episode", ylabel="Score") + p = lineplot( + m, + title = "Avg total reward per episode", + xlabel = "Episode", + ylabel = "Score", + ) lineplot!(p, m .- s) lineplot!(p, m .+ s) println(p) @@ -288,8 +304,7 @@ end Execute `f(t, agent, env)` every `n` episode. `t` is a counter of episodes. 
""" -mutable struct DoEveryNEpisode{S<:Union{PreEpisodeStage,PostEpisodeStage},F} <: - AbstractHook +mutable struct DoEveryNEpisode{S<:Union{PreEpisodeStage,PostEpisodeStage},F} <: AbstractHook f::F n::Int t::Int diff --git a/src/ReinforcementLearningCore/src/extensions/ArrayInterface.jl b/src/ReinforcementLearningCore/src/extensions/ArrayInterface.jl index d641c615b..507907b8a 100644 --- a/src/ReinforcementLearningCore/src/extensions/ArrayInterface.jl +++ b/src/ReinforcementLearningCore/src/extensions/ArrayInterface.jl @@ -1,7 +1,10 @@ using ArrayInterface -function ArrayInterface.restructure(x::AbstractArray{T1, 0}, y::AbstractArray{T2, 0}) where {T1, T2} +function ArrayInterface.restructure( + x::AbstractArray{T1,0}, + y::AbstractArray{T2,0}, +) where {T1,T2} out = similar(x, eltype(y)) out .= y out -end \ No newline at end of file +end diff --git a/src/ReinforcementLearningCore/src/policies/agents/trajectories/trajectory_extension.jl b/src/ReinforcementLearningCore/src/policies/agents/trajectories/trajectory_extension.jl index a5d143530..8dfb73c12 100644 --- a/src/ReinforcementLearningCore/src/policies/agents/trajectories/trajectory_extension.jl +++ b/src/ReinforcementLearningCore/src/policies/agents/trajectories/trajectory_extension.jl @@ -140,7 +140,7 @@ end function fetch!( sampler::NStepBatchSampler{traces}, - traj::Union{CircularArraySARTTrajectory, CircularArraySLARTTrajectory}, + traj::Union{CircularArraySARTTrajectory,CircularArraySLARTTrajectory}, inds::Vector{Int}, ) where {traces} γ, n, bz, sz = sampler.γ, sampler.n, sampler.batch_size, sampler.stack_size diff --git a/src/ReinforcementLearningCore/src/policies/q_based_policies/learners/approximators/neural_network_approximator.jl b/src/ReinforcementLearningCore/src/policies/q_based_policies/learners/approximators/neural_network_approximator.jl index 7f89e2f15..d0adb2015 100644 --- a/src/ReinforcementLearningCore/src/policies/q_based_policies/learners/approximators/neural_network_approximator.jl +++ b/src/ReinforcementLearningCore/src/policies/q_based_policies/learners/approximators/neural_network_approximator.jl @@ -78,7 +78,7 @@ Base.@kwdef struct GaussianNetwork{P,U,S} pre::P = identity μ::U logσ::S - min_σ::Float32 = 0f0 + min_σ::Float32 = 0.0f0 max_σ::Float32 = Inf32 end @@ -91,15 +91,24 @@ This function is compatible with a multidimensional action space. When outputtin - `is_sampling::Bool=false`, whether to sample from the obtained normal distribution. - `is_return_log_prob::Bool=false`, whether to calculate the conditional probability of getting actions in the given state. 
""" -function (model::GaussianNetwork)(rng::AbstractRNG, state; is_sampling::Bool=false, is_return_log_prob::Bool=false) +function (model::GaussianNetwork)( + rng::AbstractRNG, + state; + is_sampling::Bool = false, + is_return_log_prob::Bool = false, +) x = model.pre(state) - μ, raw_logσ = model.μ(x), model.logσ(x) + μ, raw_logσ = model.μ(x), model.logσ(x) logσ = clamp.(raw_logσ, log(model.min_σ), log(model.max_σ)) if is_sampling π_dist = Normal.(μ, exp.(logσ)) z = rand.(rng, π_dist) if is_return_log_prob - logp_π = sum(logpdf.(π_dist, z) .- (2.0f0 .* (log(2.0f0) .- z .- softplus.(-2.0f0 .* z))), dims = 1) + logp_π = sum( + logpdf.(π_dist, z) .- + (2.0f0 .* (log(2.0f0) .- z .- softplus.(-2.0f0 .* z))), + dims = 1, + ) return tanh.(z), logp_π else return tanh.(z) @@ -109,8 +118,17 @@ function (model::GaussianNetwork)(rng::AbstractRNG, state; is_sampling::Bool=fal end end -function (model::GaussianNetwork)(state; is_sampling::Bool=false, is_return_log_prob::Bool=false) - model(Random.GLOBAL_RNG, state; is_sampling=is_sampling, is_return_log_prob=is_return_log_prob) +function (model::GaussianNetwork)( + state; + is_sampling::Bool = false, + is_return_log_prob::Bool = false, +) + model( + Random.GLOBAL_RNG, + state; + is_sampling = is_sampling, + is_return_log_prob = is_return_log_prob, + ) end ##### @@ -133,5 +151,5 @@ Flux.@functor DuelingNetwork function (m::DuelingNetwork)(state) x = m.base(state) val = m.val(x) - return val .+ m.adv(x) .- mean(m.adv(x), dims=1) + return val .+ m.adv(x) .- mean(m.adv(x), dims = 1) end diff --git a/src/ReinforcementLearningCore/test/components/trajectories.jl b/src/ReinforcementLearningCore/test/components/trajectories.jl index 8fb8eb9ae..c7c60e163 100644 --- a/src/ReinforcementLearningCore/test/components/trajectories.jl +++ b/src/ReinforcementLearningCore/test/components/trajectories.jl @@ -52,9 +52,9 @@ t = CircularArraySLARTTrajectory( capacity = 3, state = Vector{Int} => (4,), - legal_actions_mask = Vector{Bool} => (4, ), + legal_actions_mask = Vector{Bool} => (4,), ) - + # test instance type is same as type @test isa(t, CircularArraySLARTTrajectory) diff --git a/src/ReinforcementLearningCore/test/core/core.jl b/src/ReinforcementLearningCore/test/core/core.jl index fb6f5fde5..56a81b809 100644 --- a/src/ReinforcementLearningCore/test/core/core.jl +++ b/src/ReinforcementLearningCore/test/core/core.jl @@ -1,22 +1,18 @@ @testset "simple workflow" begin - env = StateTransformedEnv(CartPoleEnv{Float32}();state_mapping=deepcopy) + env = StateTransformedEnv(CartPoleEnv{Float32}(); state_mapping = deepcopy) policy = RandomPolicy(action_space(env)) N_EPISODE = 10_000 hook = TotalRewardPerEpisode() run(policy, env, StopAfterEpisode(N_EPISODE), hook) - @test isapprox(sum(hook[]) / N_EPISODE, 21; atol=2) + @test isapprox(sum(hook[]) / N_EPISODE, 21; atol = 2) end @testset "multi agent" begin # https://github.com/JuliaReinforcementLearning/ReinforcementLearning.jl/issues/393 rps = RockPaperScissorsEnv() |> SequentialEnv - ma_policy = MultiAgentManager( - ( - NamedPolicy(p => RandomPolicy()) - for p in players(rps) - )... - ) + ma_policy = + MultiAgentManager((NamedPolicy(p => RandomPolicy()) for p in players(rps))...) 
run(ma_policy, rps, StopAfterEpisode(10)) end diff --git a/src/ReinforcementLearningCore/test/core/stop_conditions_test.jl b/src/ReinforcementLearningCore/test/core/stop_conditions_test.jl index a2d657fa1..fc5b0bfc8 100644 --- a/src/ReinforcementLearningCore/test/core/stop_conditions_test.jl +++ b/src/ReinforcementLearningCore/test/core/stop_conditions_test.jl @@ -1,5 +1,5 @@ @testset "test StopAfterNoImprovement" begin - env = StateTransformedEnv(CartPoleEnv{Float32}();state_mapping=deepcopy) + env = StateTransformedEnv(CartPoleEnv{Float32}(); state_mapping = deepcopy) policy = RandomPolicy(action_space(env)) total_reward_per_episode = TotalRewardPerEpisode() @@ -14,7 +14,8 @@ hook = ComposedHook(total_reward_per_episode) run(policy, env, stop_condition, hook) - @test argmax(total_reward_per_episode.rewards) + patience == length(total_reward_per_episode.rewards) + @test argmax(total_reward_per_episode.rewards) + patience == + length(total_reward_per_episode.rewards) end @testset "StopAfterNSeconds" begin diff --git a/src/ReinforcementLearningDatasets/src/ReinforcementLearningDatasets.jl b/src/ReinforcementLearningDatasets/src/ReinforcementLearningDatasets.jl index 0f5569670..7448a7b2f 100644 --- a/src/ReinforcementLearningDatasets/src/ReinforcementLearningDatasets.jl +++ b/src/ReinforcementLearningDatasets/src/ReinforcementLearningDatasets.jl @@ -13,4 +13,4 @@ include("init.jl") include("d4rl/d4rl_dataset.jl") include("atari/atari_dataset.jl") -end \ No newline at end of file +end diff --git a/src/ReinforcementLearningDatasets/src/atari/atari_dataset.jl b/src/ReinforcementLearningDatasets/src/atari/atari_dataset.jl index 9b9ca9bcc..64def20a3 100644 --- a/src/ReinforcementLearningDatasets/src/atari/atari_dataset.jl +++ b/src/ReinforcementLearningDatasets/src/atari/atari_dataset.jl @@ -14,8 +14,8 @@ Represents an iterable dataset of type AtariDataSet with the following fields: `meta`: Dict, the metadata provided along with the dataset `is_shuffle`: Bool, determines if the batches returned by `iterate` are shuffled. """ -struct AtariDataSet{T<:AbstractRNG} <:RLDataSet - dataset::Dict{Symbol, Any} +struct AtariDataSet{T<:AbstractRNG} <: RLDataSet + dataset::Dict{Symbol,Any} epochs::Vector{Int} repo::String length::Integer @@ -48,30 +48,31 @@ The `AtariDataSet` type is an iterable that fetches batches when used in a for l The returned type is an infinite iterator which can be called using `iterate` and will return batches as specified in the dataset. 
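# Illustrative sketch (not from the original patch): the Atari `dataset` entry
# point documented above, following the call used in test/atari_dataset.jl later
# in this diff. "pong", index 1, epochs [1] and batch_size 32 are example values;
# fetching the data requires gsutil and a sizeable download.
using ReinforcementLearningDatasets

ds = dataset("pong", 1, [1]; repo = "atari-replay-datasets", is_shuffle = true, batch_size = 32)
for batch in Iterators.take(ds, 2)
    # each batch is a NamedTuple; with the default SARTS style it also carries :next_state
    @assert size(batch[:state], 3) == 32
end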
""" -function dataset(game::String, +function dataset( + game::String, index::Int, epochs::Vector{Int}; - style=SARTS, + style = SARTS, repo = "atari-replay-datasets", - rng = StableRNG(123), - is_shuffle = true, - batch_size=256 + rng = StableRNG(123), + is_shuffle = true, + batch_size = 256, ) - - try + + try @datadep_str "$repo-$game-$index" catch - throw("The provided dataset is not available") + throw("The provided dataset is not available") end - - path = @datadep_str "$repo-$game-$index" + + path = @datadep_str "$repo-$game-$index" @assert length(readdir(path)) == 1 folder_name = readdir(path)[1] - + folder_path = "$path/$folder_name" files = readdir(folder_path) - file_prefixes = collect(Set(map(x->join(split(x,"_")[1:2], "_"), files))) + file_prefixes = collect(Set(map(x -> join(split(x, "_")[1:2], "_"), files))) fields = map(collect(file_prefixes)) do x if split(x, "_")[1] == "\$store\$" x = split(x, "_")[2] @@ -81,7 +82,7 @@ function dataset(game::String, end s_epochs = Set(epochs) - + dataset = Dict() for (prefix, field) in zip(file_prefixes, fields) @@ -98,9 +99,9 @@ function dataset(game::String, if haskey(dataset, field) if field == "observation" - dataset[field] = cat(dataset[field], data, dims=3) + dataset[field] = cat(dataset[field], data, dims = 3) else - dataset[field] = cat(dataset[field], data, dims=1) + dataset[field] = cat(dataset[field], data, dims = 1) end else dataset[field] = data @@ -110,24 +111,37 @@ function dataset(game::String, num_epochs = length(s_epochs) - atari_verify(dataset, num_epochs) + atari_verify(dataset, num_epochs) N_samples = size(dataset["observation"])[3] - - final_dataset = Dict{Symbol, Any}() - meta = Dict{String, Any}() - for (key, d_key) in zip(["observation", "action", "reward", "terminal"], Symbol.(["state", "action", "reward", "terminal"])) - final_dataset[d_key] = dataset[key] + final_dataset = Dict{Symbol,Any}() + meta = Dict{String,Any}() + + for (key, d_key) in zip( + ["observation", "action", "reward", "terminal"], + Symbol.(["state", "action", "reward", "terminal"]), + ) + final_dataset[d_key] = dataset[key] end - + for key in keys(dataset) if !(key in ["observation", "action", "reward", "terminal"]) meta[key] = dataset[key] end end - return AtariDataSet(final_dataset, epochs, repo, N_samples, batch_size, style, rng, meta, is_shuffle) + return AtariDataSet( + final_dataset, + epochs, + repo, + N_samples, + batch_size, + style, + rng, + meta, + is_shuffle, + ) end @@ -141,7 +155,7 @@ function iterate(ds::AtariDataSet, state = 0) if is_shuffle inds = rand(rng, 1:length-1, batch_size) else - if (state+1) * batch_size <= length + if (state + 1) * batch_size <= length inds = state*batch_size+1:(state+1)*batch_size else return nothing @@ -149,15 +163,17 @@ function iterate(ds::AtariDataSet, state = 0) state += 1 end - batch = (state = view(ds.dataset[:state], :, :, inds), - action = view(ds.dataset[:action], inds), - reward = view(ds.dataset[:reward], inds), - terminal = view(ds.dataset[:terminal], inds)) + batch = ( + state = view(ds.dataset[:state], :, :, inds), + action = view(ds.dataset[:action], inds), + reward = view(ds.dataset[:reward], inds), + terminal = view(ds.dataset[:terminal], inds), + ) if style == SARTS - batch = merge(batch, (next_state = view(ds.dataset[:state], :, :, (1).+(inds)),)) + batch = merge(batch, (next_state = view(ds.dataset[:state], :, :, (1) .+ (inds)),)) end - + return batch, state end @@ -167,8 +183,9 @@ length(ds::AtariDataSet) = ds.length IteratorEltype(::Type{AtariDataSet}) = EltypeUnknown() # see 
if eltype can be known (not sure about carla and adroit) function atari_verify(dataset::Dict, num_epochs::Int) - @assert size(dataset["observation"]) == (atari_frame_size, atari_frame_size, num_epochs*samples_per_epoch) + @assert size(dataset["observation"]) == + (atari_frame_size, atari_frame_size, num_epochs * samples_per_epoch) @assert size(dataset["action"]) == (num_epochs * samples_per_epoch,) @assert size(dataset["reward"]) == (num_epochs * samples_per_epoch,) @assert size(dataset["terminal"]) == (num_epochs * samples_per_epoch,) -end \ No newline at end of file +end diff --git a/src/ReinforcementLearningDatasets/src/atari/register.jl b/src/ReinforcementLearningDatasets/src/atari/register.jl index 61862e380..3479fd30c 100644 --- a/src/ReinforcementLearningDatasets/src/atari/register.jl +++ b/src/ReinforcementLearningDatasets/src/atari/register.jl @@ -4,24 +4,75 @@ export atari_init const atari_checksum = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" const ATARI_GAMES = [ - "air-raid", "alien", "amidar", "assault", "asterix", - "asteroids", "atlantis", "bank-heist", "battle-zone", "beam-rider", - "berzerk", "bowling", "boxing", "breakout", "carnival", "centipede", - "chopper-command", "crazy-climber", "demon-attack", - "double-dunk", "elevator-action", "enduro", "fishing-derby", "freeway", - "frostbite", "gopher", "gravitar", "hero", "ice-hockey", "jamesbond", - "journey-escape", "kangaroo", "krull", "kung-fu-master", - "montezuma-revenge", "ms-pacman", "name-this-game", "phoenix", - "pitfall", "pong", "pooyan", "private-eye", "qbert", "riverraid", - "road-runner", "robotank", "seaquest", "skiing", "solaris", - "space-invaders", "star-gunner", "tennis", "time-pilot", "tutankham", - "up-n-down", "venture", "video-pinball", "wizard-of-wor", - "yars-revenge", "zaxxon" + "air-raid", + "alien", + "amidar", + "assault", + "asterix", + "asteroids", + "atlantis", + "bank-heist", + "battle-zone", + "beam-rider", + "berzerk", + "bowling", + "boxing", + "breakout", + "carnival", + "centipede", + "chopper-command", + "crazy-climber", + "demon-attack", + "double-dunk", + "elevator-action", + "enduro", + "fishing-derby", + "freeway", + "frostbite", + "gopher", + "gravitar", + "hero", + "ice-hockey", + "jamesbond", + "journey-escape", + "kangaroo", + "krull", + "kung-fu-master", + "montezuma-revenge", + "ms-pacman", + "name-this-game", + "phoenix", + "pitfall", + "pong", + "pooyan", + "private-eye", + "qbert", + "riverraid", + "road-runner", + "robotank", + "seaquest", + "skiing", + "solaris", + "space-invaders", + "star-gunner", + "tennis", + "time-pilot", + "tutankham", + "up-n-down", + "venture", + "video-pinball", + "wizard-of-wor", + "yars-revenge", + "zaxxon", ] function fetch_atari_ds(src, dest) - try run(`which gsutil`) catch x throw("gsutil not found, install gsutil to proceed further") end - + try + run(`which gsutil`) + catch x + throw("gsutil not found, install gsutil to proceed further") + end + run(`gsutil -m cp -r $src $dest`) return dest end @@ -50,9 +101,9 @@ function atari_init() encountered during training into 5 replay datasets per game, resulting in a total of 300 datasets. 
""", "gs://atari-replay-datasets/dqn/$(game_name(game))/$index/replay_logs/"; - fetch_method = fetch_atari_ds - ) + fetch_method = fetch_atari_ds, + ), ) end end -end \ No newline at end of file +end diff --git a/src/ReinforcementLearningDatasets/src/common.jl b/src/ReinforcementLearningDatasets/src/common.jl index 587a036bd..8ec1ac432 100644 --- a/src/ReinforcementLearningDatasets/src/common.jl +++ b/src/ReinforcementLearningDatasets/src/common.jl @@ -5,4 +5,4 @@ export RLDataSet abstract type RLDataSet end const SARTS = (:state, :action, :reward, :terminal, :next_state) -const SART = (:state, :action, :reward, :terminal) \ No newline at end of file +const SART = (:state, :action, :reward, :terminal) diff --git a/src/ReinforcementLearningDatasets/src/d4rl/d4rl/register.jl b/src/ReinforcementLearningDatasets/src/d4rl/d4rl/register.jl index d307aa560..c8f598a2b 100644 --- a/src/ReinforcementLearningDatasets/src/d4rl/d4rl/register.jl +++ b/src/ReinforcementLearningDatasets/src/d4rl/d4rl/register.jl @@ -7,7 +7,7 @@ This file holds the registration information for d4rl datasets. It also registers the information in DataDeps for further use in this package. """ -const D4RL_DATASET_URLS = Dict{String, String}( +const D4RL_DATASET_URLS = Dict{String,String}( "maze2d-open-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-open-sparse.hdf5", "maze2d-umaze-v1" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-umaze-sparse-v1.hdf5", "maze2d-medium-v1" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-medium-sparse-v1.hdf5", @@ -63,209 +63,209 @@ const D4RL_DATASET_URLS = Dict{String, String}( "antmaze-medium-diverse-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_big-maze_noisy_multistart_True_multigoal_True_sparse.hdf5", "antmaze-large-play-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_hardest-maze_noisy_multistart_True_multigoal_False_sparse.hdf5", "antmaze-large-diverse-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_hardest-maze_noisy_multistart_True_multigoal_True_sparse.hdf5", - "flow-ring-random-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/flow/flow-ring-v0-random.hdf5", - "flow-ring-controller-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/flow/flow-ring-v0-idm.hdf5", - "flow-merge-random-v0"=>"http://rail.eecs.berkeley.edu/datasets/offline_rl/flow/flow-merge-v0-random.hdf5", - "flow-merge-controller-v0"=>"http://rail.eecs.berkeley.edu/datasets/offline_rl/flow/flow-merge-v0-idm.hdf5", + "flow-ring-random-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/flow/flow-ring-v0-random.hdf5", + "flow-ring-controller-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/flow/flow-ring-v0-idm.hdf5", + "flow-merge-random-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/flow/flow-merge-v0-random.hdf5", + "flow-merge-controller-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/flow/flow-merge-v0-idm.hdf5", "kitchen-complete-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/kitchen/mini_kitchen_microwave_kettle_light_slider-v0.hdf5", "kitchen-partial-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/kitchen/kitchen_microwave_kettle_light_slider-v0.hdf5", "kitchen-mixed-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/kitchen/kitchen_microwave_kettle_bottomburner_light-v0.hdf5", - "carla-lane-v0"=> 
"http://rail.eecs.berkeley.edu/datasets/offline_rl/carla/carla_lane_follow_flat-v0.hdf5", - "carla-town-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/carla/carla_town_subsamp_flat-v0.hdf5", - "carla-town-full-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/carla/carla_town_flat-v0.hdf5", - "bullet-halfcheetah-random-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-halfcheetah_random.hdf5", - "bullet-halfcheetah-medium-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-halfcheetah_medium.hdf5", - "bullet-halfcheetah-expert-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-halfcheetah_expert.hdf5", - "bullet-halfcheetah-medium-expert-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-halfcheetah_medium_expert.hdf5", - "bullet-halfcheetah-medium-replay-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-halfcheetah_medium_replay.hdf5", - "bullet-hopper-random-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-hopper_random.hdf5", - "bullet-hopper-medium-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-hopper_medium.hdf5", - "bullet-hopper-expert-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-hopper_expert.hdf5", - "bullet-hopper-medium-expert-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-hopper_medium_expert.hdf5", - "bullet-hopper-medium-replay-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-hopper_medium_replay.hdf5", - "bullet-ant-random-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-ant_random.hdf5", - "bullet-ant-medium-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-ant_medium.hdf5", - "bullet-ant-expert-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-ant_expert.hdf5", - "bullet-ant-medium-expert-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-ant_medium_expert.hdf5", - "bullet-ant-medium-replay-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-ant_medium_replay.hdf5", - "bullet-walker2d-random-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-walker2d_random.hdf5", - "bullet-walker2d-medium-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-walker2d_medium.hdf5", - "bullet-walker2d-expert-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-walker2d_expert.hdf5", - "bullet-walker2d-medium-expert-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-walker2d_medium_expert.hdf5", - "bullet-walker2d-medium-replay-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-walker2d_medium_replay.hdf5", - "bullet-maze2d-open-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-maze2d-open-sparse.hdf5", - "bullet-maze2d-umaze-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-maze2d-umaze-sparse.hdf5", - "bullet-maze2d-medium-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-maze2d-medium-sparse.hdf5", - "bullet-maze2d-large-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-maze2d-large-sparse.hdf5", + "carla-lane-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/carla/carla_lane_follow_flat-v0.hdf5", + "carla-town-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/carla/carla_town_subsamp_flat-v0.hdf5", + "carla-town-full-v0" => 
"http://rail.eecs.berkeley.edu/datasets/offline_rl/carla/carla_town_flat-v0.hdf5", + "bullet-halfcheetah-random-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-halfcheetah_random.hdf5", + "bullet-halfcheetah-medium-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-halfcheetah_medium.hdf5", + "bullet-halfcheetah-expert-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-halfcheetah_expert.hdf5", + "bullet-halfcheetah-medium-expert-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-halfcheetah_medium_expert.hdf5", + "bullet-halfcheetah-medium-replay-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-halfcheetah_medium_replay.hdf5", + "bullet-hopper-random-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-hopper_random.hdf5", + "bullet-hopper-medium-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-hopper_medium.hdf5", + "bullet-hopper-expert-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-hopper_expert.hdf5", + "bullet-hopper-medium-expert-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-hopper_medium_expert.hdf5", + "bullet-hopper-medium-replay-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-hopper_medium_replay.hdf5", + "bullet-ant-random-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-ant_random.hdf5", + "bullet-ant-medium-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-ant_medium.hdf5", + "bullet-ant-expert-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-ant_expert.hdf5", + "bullet-ant-medium-expert-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-ant_medium_expert.hdf5", + "bullet-ant-medium-replay-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-ant_medium_replay.hdf5", + "bullet-walker2d-random-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-walker2d_random.hdf5", + "bullet-walker2d-medium-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-walker2d_medium.hdf5", + "bullet-walker2d-expert-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-walker2d_expert.hdf5", + "bullet-walker2d-medium-expert-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-walker2d_medium_expert.hdf5", + "bullet-walker2d-medium-replay-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-walker2d_medium_replay.hdf5", + "bullet-maze2d-open-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-maze2d-open-sparse.hdf5", + "bullet-maze2d-umaze-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-maze2d-umaze-sparse.hdf5", + "bullet-maze2d-medium-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-maze2d-medium-sparse.hdf5", + "bullet-maze2d-large-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-maze2d-large-sparse.hdf5", ) -const D4RL_REF_MIN_SCORE = Dict{String, Float32}( - "maze2d-open-v0" => 0.01 , - "maze2d-umaze-v1" => 23.85 , - "maze2d-medium-v1" => 13.13 , - "maze2d-large-v1" => 6.7 , - "maze2d-open-dense-v0" => 11.17817 , - "maze2d-umaze-dense-v1" => 68.537689 , - "maze2d-medium-dense-v1" => 44.264742 , - "maze2d-large-dense-v1" => 30.569041 , - "minigrid-fourrooms-v0" => 0.01442 , - "minigrid-fourrooms-random-v0" => 0.01442 , - "pen-human-v0" => 96.262799 , - "pen-cloned-v0" => 96.262799 , - 
"pen-expert-v0" => 96.262799 , - "hammer-human-v0" => -274.856578 , - "hammer-cloned-v0" => -274.856578 , - "hammer-expert-v0" => -274.856578 , - "relocate-human-v0" => -6.425911 , - "relocate-cloned-v0" => -6.425911 , - "relocate-expert-v0" => -6.425911 , - "door-human-v0" => -56.512833 , - "door-cloned-v0" => -56.512833 , - "door-expert-v0" => -56.512833 , - "halfcheetah-random-v0" => -280.178953 , - "halfcheetah-medium-v0" => -280.178953 , - "halfcheetah-expert-v0" => -280.178953 , - "halfcheetah-medium-replay-v0" => -280.178953 , - "halfcheetah-medium-expert-v0" => -280.178953 , - "walker2d-random-v0" => 1.629008 , - "walker2d-medium-v0" => 1.629008 , - "walker2d-expert-v0" => 1.629008 , - "walker2d-medium-replay-v0" => 1.629008 , - "walker2d-medium-expert-v0" => 1.629008 , - "hopper-random-v0" => -20.272305 , - "hopper-medium-v0" => -20.272305 , - "hopper-expert-v0" => -20.272305 , - "hopper-medium-replay-v0" => -20.272305 , - "hopper-medium-expert-v0" => -20.272305 , +const D4RL_REF_MIN_SCORE = Dict{String,Float32}( + "maze2d-open-v0" => 0.01, + "maze2d-umaze-v1" => 23.85, + "maze2d-medium-v1" => 13.13, + "maze2d-large-v1" => 6.7, + "maze2d-open-dense-v0" => 11.17817, + "maze2d-umaze-dense-v1" => 68.537689, + "maze2d-medium-dense-v1" => 44.264742, + "maze2d-large-dense-v1" => 30.569041, + "minigrid-fourrooms-v0" => 0.01442, + "minigrid-fourrooms-random-v0" => 0.01442, + "pen-human-v0" => 96.262799, + "pen-cloned-v0" => 96.262799, + "pen-expert-v0" => 96.262799, + "hammer-human-v0" => -274.856578, + "hammer-cloned-v0" => -274.856578, + "hammer-expert-v0" => -274.856578, + "relocate-human-v0" => -6.425911, + "relocate-cloned-v0" => -6.425911, + "relocate-expert-v0" => -6.425911, + "door-human-v0" => -56.512833, + "door-cloned-v0" => -56.512833, + "door-expert-v0" => -56.512833, + "halfcheetah-random-v0" => -280.178953, + "halfcheetah-medium-v0" => -280.178953, + "halfcheetah-expert-v0" => -280.178953, + "halfcheetah-medium-replay-v0" => -280.178953, + "halfcheetah-medium-expert-v0" => -280.178953, + "walker2d-random-v0" => 1.629008, + "walker2d-medium-v0" => 1.629008, + "walker2d-expert-v0" => 1.629008, + "walker2d-medium-replay-v0" => 1.629008, + "walker2d-medium-expert-v0" => 1.629008, + "hopper-random-v0" => -20.272305, + "hopper-medium-v0" => -20.272305, + "hopper-expert-v0" => -20.272305, + "hopper-medium-replay-v0" => -20.272305, + "hopper-medium-expert-v0" => -20.272305, "ant-random-v0" => -325.6, "ant-medium-v0" => -325.6, "ant-expert-v0" => -325.6, "ant-medium-replay-v0" => -325.6, "ant-medium-expert-v0" => -325.6, - "antmaze-umaze-v0" => 0.0 , - "antmaze-umaze-diverse-v0" => 0.0 , - "antmaze-medium-play-v0" => 0.0 , - "antmaze-medium-diverse-v0" => 0.0 , - "antmaze-large-play-v0" => 0.0 , - "antmaze-large-diverse-v0" => 0.0 , - "kitchen-complete-v0" => 0.0 , - "kitchen-partial-v0" => 0.0 , - "kitchen-mixed-v0" => 0.0 , - "flow-ring-random-v0" => -165.22 , - "flow-ring-controller-v0" => -165.22 , - "flow-merge-random-v0" => 118.67993 , - "flow-merge-controller-v0" => 118.67993 , - "carla-lane-v0"=> -0.8503839912088142, - "carla-town-v0"=> -114.81579500772153, # random score - "bullet-halfcheetah-random-v0"=> -1275.766996, - "bullet-halfcheetah-medium-v0"=> -1275.766996, - "bullet-halfcheetah-expert-v0"=> -1275.766996, - "bullet-halfcheetah-medium-expert-v0"=> -1275.766996, - "bullet-halfcheetah-medium-replay-v0"=> -1275.766996, - "bullet-hopper-random-v0"=> 20.058972, - "bullet-hopper-medium-v0"=> 20.058972, - "bullet-hopper-expert-v0"=> 20.058972, - 
"bullet-hopper-medium-expert-v0"=> 20.058972, - "bullet-hopper-medium-replay-v0"=> 20.058972, - "bullet-ant-random-v0"=> 373.705955, - "bullet-ant-medium-v0"=> 373.705955, - "bullet-ant-expert-v0"=> 373.705955, - "bullet-ant-medium-expert-v0"=> 373.705955, - "bullet-ant-medium-replay-v0"=> 373.705955, - "bullet-walker2d-random-v0"=> 16.523877, - "bullet-walker2d-medium-v0"=> 16.523877, - "bullet-walker2d-expert-v0"=> 16.523877, - "bullet-walker2d-medium-expert-v0"=> 16.523877, - "bullet-walker2d-medium-replay-v0"=> 16.523877, - "bullet-maze2d-open-v0"=> 8.750000, - "bullet-maze2d-umaze-v0"=> 32.460000, - "bullet-maze2d-medium-v0"=> 14.870000, - "bullet-maze2d-large-v0"=> 1.820000, + "antmaze-umaze-v0" => 0.0, + "antmaze-umaze-diverse-v0" => 0.0, + "antmaze-medium-play-v0" => 0.0, + "antmaze-medium-diverse-v0" => 0.0, + "antmaze-large-play-v0" => 0.0, + "antmaze-large-diverse-v0" => 0.0, + "kitchen-complete-v0" => 0.0, + "kitchen-partial-v0" => 0.0, + "kitchen-mixed-v0" => 0.0, + "flow-ring-random-v0" => -165.22, + "flow-ring-controller-v0" => -165.22, + "flow-merge-random-v0" => 118.67993, + "flow-merge-controller-v0" => 118.67993, + "carla-lane-v0" => -0.8503839912088142, + "carla-town-v0" => -114.81579500772153, # random score + "bullet-halfcheetah-random-v0" => -1275.766996, + "bullet-halfcheetah-medium-v0" => -1275.766996, + "bullet-halfcheetah-expert-v0" => -1275.766996, + "bullet-halfcheetah-medium-expert-v0" => -1275.766996, + "bullet-halfcheetah-medium-replay-v0" => -1275.766996, + "bullet-hopper-random-v0" => 20.058972, + "bullet-hopper-medium-v0" => 20.058972, + "bullet-hopper-expert-v0" => 20.058972, + "bullet-hopper-medium-expert-v0" => 20.058972, + "bullet-hopper-medium-replay-v0" => 20.058972, + "bullet-ant-random-v0" => 373.705955, + "bullet-ant-medium-v0" => 373.705955, + "bullet-ant-expert-v0" => 373.705955, + "bullet-ant-medium-expert-v0" => 373.705955, + "bullet-ant-medium-replay-v0" => 373.705955, + "bullet-walker2d-random-v0" => 16.523877, + "bullet-walker2d-medium-v0" => 16.523877, + "bullet-walker2d-expert-v0" => 16.523877, + "bullet-walker2d-medium-expert-v0" => 16.523877, + "bullet-walker2d-medium-replay-v0" => 16.523877, + "bullet-maze2d-open-v0" => 8.750000, + "bullet-maze2d-umaze-v0" => 32.460000, + "bullet-maze2d-medium-v0" => 14.870000, + "bullet-maze2d-large-v0" => 1.820000, ) -const D4RL_REF_MAX_SCORE = Dict{String, Float32}( - "maze2d-open-v0" => 20.66 , - "maze2d-umaze-v1" => 161.86 , - "maze2d-medium-v1" => 277.39 , - "maze2d-large-v1" => 273.99 , - "maze2d-open-dense-v0" => 27.166538620695782 , - "maze2d-umaze-dense-v1" => 193.66285642381482 , - "maze2d-medium-dense-v1" => 297.4552547777125 , - "maze2d-large-dense-v1" => 303.4857382709002 , - "minigrid-fourrooms-v0" => 2.89685 , - "minigrid-fourrooms-random-v0" => 2.89685 , - "pen-human-v0" => 3076.8331017826877 , - "pen-cloned-v0" => 3076.8331017826877 , - "pen-expert-v0" => 3076.8331017826877 , - "hammer-human-v0" => 12794.134825156867 , - "hammer-cloned-v0" => 12794.134825156867 , - "hammer-expert-v0" => 12794.134825156867 , - "relocate-human-v0" => 4233.877797728884 , - "relocate-cloned-v0" => 4233.877797728884 , - "relocate-expert-v0" => 4233.877797728884 , - "door-human-v0" => 2880.5693087298737 , - "door-cloned-v0" => 2880.5693087298737 , - "door-expert-v0" => 2880.5693087298737 , - "halfcheetah-random-v0" => 12135.0 , - "halfcheetah-medium-v0" => 12135.0 , - "halfcheetah-expert-v0" => 12135.0 , - "halfcheetah-medium-replay-v0" => 12135.0 , - "halfcheetah-medium-expert-v0" => 12135.0 , - 
"walker2d-random-v0" => 4592.3 , - "walker2d-medium-v0" => 4592.3 , - "walker2d-expert-v0" => 4592.3 , - "walker2d-medium-replay-v0" => 4592.3 , - "walker2d-medium-expert-v0" => 4592.3 , - "hopper-random-v0" => 3234.3 , - "hopper-medium-v0" => 3234.3 , - "hopper-expert-v0" => 3234.3 , - "hopper-medium-replay-v0" => 3234.3 , - "hopper-medium-expert-v0" => 3234.3 , +const D4RL_REF_MAX_SCORE = Dict{String,Float32}( + "maze2d-open-v0" => 20.66, + "maze2d-umaze-v1" => 161.86, + "maze2d-medium-v1" => 277.39, + "maze2d-large-v1" => 273.99, + "maze2d-open-dense-v0" => 27.166538620695782, + "maze2d-umaze-dense-v1" => 193.66285642381482, + "maze2d-medium-dense-v1" => 297.4552547777125, + "maze2d-large-dense-v1" => 303.4857382709002, + "minigrid-fourrooms-v0" => 2.89685, + "minigrid-fourrooms-random-v0" => 2.89685, + "pen-human-v0" => 3076.8331017826877, + "pen-cloned-v0" => 3076.8331017826877, + "pen-expert-v0" => 3076.8331017826877, + "hammer-human-v0" => 12794.134825156867, + "hammer-cloned-v0" => 12794.134825156867, + "hammer-expert-v0" => 12794.134825156867, + "relocate-human-v0" => 4233.877797728884, + "relocate-cloned-v0" => 4233.877797728884, + "relocate-expert-v0" => 4233.877797728884, + "door-human-v0" => 2880.5693087298737, + "door-cloned-v0" => 2880.5693087298737, + "door-expert-v0" => 2880.5693087298737, + "halfcheetah-random-v0" => 12135.0, + "halfcheetah-medium-v0" => 12135.0, + "halfcheetah-expert-v0" => 12135.0, + "halfcheetah-medium-replay-v0" => 12135.0, + "halfcheetah-medium-expert-v0" => 12135.0, + "walker2d-random-v0" => 4592.3, + "walker2d-medium-v0" => 4592.3, + "walker2d-expert-v0" => 4592.3, + "walker2d-medium-replay-v0" => 4592.3, + "walker2d-medium-expert-v0" => 4592.3, + "hopper-random-v0" => 3234.3, + "hopper-medium-v0" => 3234.3, + "hopper-expert-v0" => 3234.3, + "hopper-medium-replay-v0" => 3234.3, + "hopper-medium-expert-v0" => 3234.3, "ant-random-v0" => 3879.7, "ant-medium-v0" => 3879.7, "ant-expert-v0" => 3879.7, "ant-medium-replay-v0" => 3879.7, "ant-medium-expert-v0" => 3879.7, - "antmaze-umaze-v0" => 1.0 , - "antmaze-umaze-diverse-v0" => 1.0 , - "antmaze-medium-play-v0" => 1.0 , - "antmaze-medium-diverse-v0" => 1.0 , - "antmaze-large-play-v0" => 1.0 , - "antmaze-large-diverse-v0" => 1.0 , - "kitchen-complete-v0" => 4.0 , - "kitchen-partial-v0" => 4.0 , - "kitchen-mixed-v0" => 4.0 , - "flow-ring-random-v0" => 24.42 , - "flow-ring-controller-v0" => 24.42 , - "flow-merge-random-v0" => 330.03179 , - "flow-merge-controller-v0" => 330.03179 , - "carla-lane-v0"=> 1023.5784385429523, - "carla-town-v0"=> 2440.1772022247314, # avg dataset score - "bullet-halfcheetah-random-v0"=> 2381.6725, - "bullet-halfcheetah-medium-v0"=> 2381.6725, - "bullet-halfcheetah-expert-v0"=> 2381.6725, - "bullet-halfcheetah-medium-expert-v0"=> 2381.6725, - "bullet-halfcheetah-medium-replay-v0"=> 2381.6725, - "bullet-hopper-random-v0"=> 1441.8059623430963, - "bullet-hopper-medium-v0"=> 1441.8059623430963, - "bullet-hopper-expert-v0"=> 1441.8059623430963, - "bullet-hopper-medium-expert-v0"=> 1441.8059623430963, - "bullet-hopper-medium-replay-v0"=> 1441.8059623430963, - "bullet-ant-random-v0"=> 2650.495, - "bullet-ant-medium-v0"=> 2650.495, - "bullet-ant-expert-v0"=> 2650.495, - "bullet-ant-medium-expert-v0"=> 2650.495, - "bullet-ant-medium-replay-v0"=> 2650.495, - "bullet-walker2d-random-v0"=> 1623.6476303317536, - "bullet-walker2d-medium-v0"=> 1623.6476303317536, - "bullet-walker2d-expert-v0"=> 1623.6476303317536, - "bullet-walker2d-medium-expert-v0"=> 1623.6476303317536, - 
"bullet-walker2d-medium-replay-v0"=> 1623.6476303317536, - "bullet-maze2d-open-v0"=> 64.15, - "bullet-maze2d-umaze-v0"=> 153.99, - "bullet-maze2d-medium-v0"=> 238.05, - "bullet-maze2d-large-v0"=> 285.92, + "antmaze-umaze-v0" => 1.0, + "antmaze-umaze-diverse-v0" => 1.0, + "antmaze-medium-play-v0" => 1.0, + "antmaze-medium-diverse-v0" => 1.0, + "antmaze-large-play-v0" => 1.0, + "antmaze-large-diverse-v0" => 1.0, + "kitchen-complete-v0" => 4.0, + "kitchen-partial-v0" => 4.0, + "kitchen-mixed-v0" => 4.0, + "flow-ring-random-v0" => 24.42, + "flow-ring-controller-v0" => 24.42, + "flow-merge-random-v0" => 330.03179, + "flow-merge-controller-v0" => 330.03179, + "carla-lane-v0" => 1023.5784385429523, + "carla-town-v0" => 2440.1772022247314, # avg dataset score + "bullet-halfcheetah-random-v0" => 2381.6725, + "bullet-halfcheetah-medium-v0" => 2381.6725, + "bullet-halfcheetah-expert-v0" => 2381.6725, + "bullet-halfcheetah-medium-expert-v0" => 2381.6725, + "bullet-halfcheetah-medium-replay-v0" => 2381.6725, + "bullet-hopper-random-v0" => 1441.8059623430963, + "bullet-hopper-medium-v0" => 1441.8059623430963, + "bullet-hopper-expert-v0" => 1441.8059623430963, + "bullet-hopper-medium-expert-v0" => 1441.8059623430963, + "bullet-hopper-medium-replay-v0" => 1441.8059623430963, + "bullet-ant-random-v0" => 2650.495, + "bullet-ant-medium-v0" => 2650.495, + "bullet-ant-expert-v0" => 2650.495, + "bullet-ant-medium-expert-v0" => 2650.495, + "bullet-ant-medium-replay-v0" => 2650.495, + "bullet-walker2d-random-v0" => 1623.6476303317536, + "bullet-walker2d-medium-v0" => 1623.6476303317536, + "bullet-walker2d-expert-v0" => 1623.6476303317536, + "bullet-walker2d-medium-expert-v0" => 1623.6476303317536, + "bullet-walker2d-medium-replay-v0" => 1623.6476303317536, + "bullet-maze2d-open-v0" => 64.15, + "bullet-maze2d-umaze-v0" => 153.99, + "bullet-maze2d-medium-v0" => 238.05, + "bullet-maze2d-large-v0" => 285.92, ) # give a prompt for flow and carla tasks @@ -275,20 +275,20 @@ function d4rl_init() for ds in keys(D4RL_DATASET_URLS) register( DataDep( - repo*"-"* ds, + repo * "-" * ds, """ Credits: https://arxiv.org/abs/2004.07219 The following dataset is fetched from the d4rl. The dataset is fetched and modified in a form that is useful for RL.jl package. - + Dataset information: Name: $(ds) $(if ds in keys(D4RL_REF_MAX_SCORE) "MAXIMUM_SCORE: " * string(D4RL_REF_MAX_SCORE[ds]) end) $(if ds in keys(D4RL_REF_MIN_SCORE) "MINIMUM_SCORE: " * string(D4RL_REF_MIN_SCORE[ds]) end) """, #check if the MAX and MIN score part is even necessary and make the log file prettier D4RL_DATASET_URLS[ds], - ) + ), ) end nothing -end \ No newline at end of file +end diff --git a/src/ReinforcementLearningDatasets/src/d4rl/d4rl_dataset.jl b/src/ReinforcementLearningDatasets/src/d4rl/d4rl_dataset.jl index 408e63903..ddaece13f 100644 --- a/src/ReinforcementLearningDatasets/src/d4rl/d4rl_dataset.jl +++ b/src/ReinforcementLearningDatasets/src/d4rl/d4rl_dataset.jl @@ -21,7 +21,7 @@ Represents an iterable dataset of type D4RLDataSet with the following fields: `is_shuffle`: Bool, determines if the batches returned by `iterate` are shuffled. """ struct D4RLDataSet{T<:AbstractRNG} <: RLDataSet - dataset::Dict{Symbol, Any} + dataset::Dict{Symbol,Any} repo::String dataset_size::Integer batch_size::Integer @@ -50,41 +50,45 @@ The `D4RLDataSet` type is an iterable that fetches batches when used in a for lo The returned type is an infinite iterator which can be called using `iterate` and will return batches as specified in the dataset. 
""" -function dataset(dataset::String; - style=SARTS, +function dataset( + dataset::String; + style = SARTS, repo = "d4rl", - rng = StableRNG(123), - is_shuffle = true, - batch_size=256 + rng = StableRNG(123), + is_shuffle = true, + batch_size = 256, ) - - try - @datadep_str repo*"-"*dataset - catch - throw("The provided dataset is not available") + + try + @datadep_str repo * "-" * dataset + catch + throw("The provided dataset is not available") end - - path = @datadep_str repo*"-"*dataset + + path = @datadep_str repo * "-" * dataset @assert length(readdir(path)) == 1 file_name = readdir(path)[1] - - data = h5open(path*"/"*file_name, "r") do file + + data = h5open(path * "/" * file_name, "r") do file read(file) end # sanity checks on data d4rl_verify(data) - dataset = Dict{Symbol, Any}() - meta = Dict{String, Any}() + dataset = Dict{Symbol,Any}() + meta = Dict{String,Any}() N_samples = size(data["observations"])[2] - - for (key, d_key) in zip(["observations", "actions", "rewards", "terminals"], Symbol.(["state", "action", "reward", "terminal"])) - dataset[d_key] = data[key] + + for (key, d_key) in zip( + ["observations", "actions", "rewards", "terminals"], + Symbol.(["state", "action", "reward", "terminal"]), + ) + dataset[d_key] = data[key] end - + for key in keys(data) if !(key in ["observations", "actions", "rewards", "terminals"]) meta[key] = data[key] @@ -104,9 +108,13 @@ function iterate(ds::D4RLDataSet, state = 0) if is_shuffle inds = rand(rng, 1:size, batch_size) - map((x)-> if x <= size x else 1 end, inds) + map((x) -> if x <= size + x + else + 1 + end, inds) else - if (state+1) * batch_size <= size + if (state + 1) * batch_size <= size inds = state*batch_size+1:(state+1)*batch_size else return nothing @@ -114,15 +122,17 @@ function iterate(ds::D4RLDataSet, state = 0) state += 1 end - batch = (state = copy(ds.dataset[:state][:, inds]), - action = copy(ds.dataset[:action][:, inds]), - reward = copy(ds.dataset[:reward][inds]), - terminal = copy(ds.dataset[:terminal][inds])) + batch = ( + state = copy(ds.dataset[:state][:, inds]), + action = copy(ds.dataset[:action][:, inds]), + reward = copy(ds.dataset[:reward][inds]), + terminal = copy(ds.dataset[:terminal][inds]), + ) if style == SARTS batch = merge(batch, (next_state = copy(ds.dataset[:state][:, (1).+(inds)]),)) end - + return batch, state end @@ -132,11 +142,12 @@ length(ds::D4RLDataSet) = ds.dataset_size IteratorEltype(::Type{D4RLDataSet}) = EltypeUnknown() # see if eltype can be known (not sure about carla and adroit) -function d4rl_verify(data::Dict{String, Any}) +function d4rl_verify(data::Dict{String,Any}) for key in ["observations", "actions", "rewards", "terminals"] @assert (key in keys(data)) "Expected keys not present in data" end N_samples = size(data["observations"])[2] @assert size(data["rewards"]) == (N_samples,) || size(data["rewards"]) == (1, N_samples) - @assert size(data["terminals"]) == (N_samples,) || size(data["terminals"]) == (1, N_samples) -end \ No newline at end of file + @assert size(data["terminals"]) == (N_samples,) || + size(data["terminals"]) == (1, N_samples) +end diff --git a/src/ReinforcementLearningDatasets/src/d4rl/d4rl_pybullet/register.jl b/src/ReinforcementLearningDatasets/src/d4rl/d4rl_pybullet/register.jl index 90f828e03..690b13a59 100644 --- a/src/ReinforcementLearningDatasets/src/d4rl/d4rl_pybullet/register.jl +++ b/src/ReinforcementLearningDatasets/src/d4rl/d4rl_pybullet/register.jl @@ -1,18 +1,18 @@ export D4RL_PYBULLET_URLS const D4RL_PYBULLET_URLS = Dict( - "hopper-bullet-mixed-v0" 
=> "https://www.dropbox.com/s/xv3p0h7dzgxt8xb/hopper-bullet-mixed-v0.hdf5?dl=1", - "walker2d-bullet-random-v0" => "https://www.dropbox.com/s/1gwcfl2nmx6878m/walker2d-bullet-random-v0.hdf5?dl=1", - "hopper-bullet-medium-v0" => "https://www.dropbox.com/s/w22kgzldn6eng7j/hopper-bullet-medium-v0.hdf5?dl=1", + "hopper-bullet-mixed-v0" => "https://www.dropbox.com/s/xv3p0h7dzgxt8xb/hopper-bullet-mixed-v0.hdf5?dl=1", + "walker2d-bullet-random-v0" => "https://www.dropbox.com/s/1gwcfl2nmx6878m/walker2d-bullet-random-v0.hdf5?dl=1", + "hopper-bullet-medium-v0" => "https://www.dropbox.com/s/w22kgzldn6eng7j/hopper-bullet-medium-v0.hdf5?dl=1", "walker2d-bullet-mixed-v0" => "https://www.dropbox.com/s/i4u2ii0d85iblou/walker2d-bullet-mixed-v0.hdf5?dl=1", - "halfcheetah-bullet-mixed-v0" => "https://www.dropbox.com/s/scj1rqun963aw90/halfcheetah-bullet-mixed-v0.hdf5?dl=1", + "halfcheetah-bullet-mixed-v0" => "https://www.dropbox.com/s/scj1rqun963aw90/halfcheetah-bullet-mixed-v0.hdf5?dl=1", "halfcheetah-bullet-random-v0" => "https://www.dropbox.com/s/jnvpb1hp60zt2ak/halfcheetah-bullet-random-v0.hdf5?dl=1", - "walker2d-bullet-medium-v0" => "https://www.dropbox.com/s/v0f2kz48b1hw6or/walker2d-bullet-medium-v0.hdf5?dl=1", - "hopper-bullet-random-v0" => "https://www.dropbox.com/s/bino8ojd7iq4p4d/hopper-bullet-random-v0.hdf5?dl=1", - "ant-bullet-random-v0" => "https://www.dropbox.com/s/2xpmh4wk2m7i8xh/ant-bullet-random-v0.hdf5?dl=1", - "halfcheetah-bullet-medium-v0" => "https://www.dropbox.com/s/v4xgssp1w968a9l/halfcheetah-bullet-medium-v0.hdf5?dl=1", - "ant-bullet-medium-v0" => "https://www.dropbox.com/s/6n79kwd94xthr1t/ant-bullet-medium-v0.hdf5?dl=1", - "ant-bullet-mixed-v0" => "https://www.dropbox.com/s/pmy3dzab35g4whk/ant-bullet-mixed-v0.hdf5?dl=1" + "walker2d-bullet-medium-v0" => "https://www.dropbox.com/s/v0f2kz48b1hw6or/walker2d-bullet-medium-v0.hdf5?dl=1", + "hopper-bullet-random-v0" => "https://www.dropbox.com/s/bino8ojd7iq4p4d/hopper-bullet-random-v0.hdf5?dl=1", + "ant-bullet-random-v0" => "https://www.dropbox.com/s/2xpmh4wk2m7i8xh/ant-bullet-random-v0.hdf5?dl=1", + "halfcheetah-bullet-medium-v0" => "https://www.dropbox.com/s/v4xgssp1w968a9l/halfcheetah-bullet-medium-v0.hdf5?dl=1", + "ant-bullet-medium-v0" => "https://www.dropbox.com/s/6n79kwd94xthr1t/ant-bullet-medium-v0.hdf5?dl=1", + "ant-bullet-mixed-v0" => "https://www.dropbox.com/s/pmy3dzab35g4whk/ant-bullet-mixed-v0.hdf5?dl=1", ) function d4rl_pybullet_init() @@ -20,14 +20,14 @@ function d4rl_pybullet_init() for ds in keys(D4RL_PYBULLET_URLS) register( DataDep( - repo* "-" * ds, + repo * "-" * ds, """ Credits: https://github.com/takuseno/d4rl-pybullet The following dataset is fetched from the d4rl-pybullet. 
- """, + """, D4RL_PYBULLET_URLS[ds], - ) + ), ) end nothing -end \ No newline at end of file +end diff --git a/src/ReinforcementLearningDatasets/src/init.jl b/src/ReinforcementLearningDatasets/src/init.jl index 69bfad217..256c8a28d 100644 --- a/src/ReinforcementLearningDatasets/src/init.jl +++ b/src/ReinforcementLearningDatasets/src/init.jl @@ -2,4 +2,4 @@ function __init__() RLDatasets.d4rl_init() RLDatasets.d4rl_pybullet_init() RLDatasets.atari_init() -end \ No newline at end of file +end diff --git a/src/ReinforcementLearningDatasets/test/atari_dataset.jl b/src/ReinforcementLearningDatasets/test/atari_dataset.jl index eb6c03175..b89c63ef3 100644 --- a/src/ReinforcementLearningDatasets/test/atari_dataset.jl +++ b/src/ReinforcementLearningDatasets/test/atari_dataset.jl @@ -13,11 +13,11 @@ rng = StableRNG(123) "pong", index, epochs; - repo="atari-replay-datasets", + repo = "atari-replay-datasets", style = style, rng = rng, is_shuffle = true, - batch_size = batch_size + batch_size = batch_size, ) data_dict = ds.dataset @@ -64,11 +64,11 @@ end "pong", index, epochs; - repo="atari-replay-datasets", + repo = "atari-replay-datasets", style = style, rng = rng, is_shuffle = false, - batch_size = batch_size + batch_size = batch_size, ) data_dict = ds.dataset @@ -118,4 +118,4 @@ end @test data_dict[:reward][batch_size+1:batch_size*2] == iter2[:reward] @test data_dict[:terminal][batch_size+1:batch_size*2] == iter2[:terminal] @test data_dict[:state][:, :, batch_size+2:batch_size*2+1] == iter2[:next_state] -end \ No newline at end of file +end diff --git a/src/ReinforcementLearningDatasets/test/d4rl_pybullet.jl b/src/ReinforcementLearningDatasets/test/d4rl_pybullet.jl index a146647c7..245bfae92 100644 --- a/src/ReinforcementLearningDatasets/test/d4rl_pybullet.jl +++ b/src/ReinforcementLearningDatasets/test/d4rl_pybullet.jl @@ -2,11 +2,11 @@ using Base: batch_size_err_str @testset "d4rl_pybullet" begin ds = dataset( "hopper-bullet-mixed-v0"; - repo="d4rl-pybullet", + repo = "d4rl-pybullet", style = style, rng = rng, is_shuffle = true, - batch_size = batch_size + batch_size = batch_size, ) n_s = 15 @@ -23,10 +23,12 @@ using Base: batch_size_err_str for sample in Iterators.take(ds, 3) @test typeof(sample) <: NamedTuple{SARTS} - @test size(sample[:state]) == (n_s, batch_size) - @test size(sample[:action]) == (n_a, batch_size) - @test size(sample[:reward]) == (1, batch_size) || size(sample[:reward]) == (batch_size,) - @test size(sample[:terminal]) == (1, batch_size) || size(sample[:terminal]) == (batch_size,) + @test size(sample[:state]) == (n_s, batch_size) + @test size(sample[:action]) == (n_a, batch_size) + @test size(sample[:reward]) == (1, batch_size) || + size(sample[:reward]) == (batch_size,) + @test size(sample[:terminal]) == (1, batch_size) || + size(sample[:terminal]) == (batch_size,) end -end \ No newline at end of file +end diff --git a/src/ReinforcementLearningDatasets/test/dataset.jl b/src/ReinforcementLearningDatasets/test/dataset.jl index f0263d847..fe325f44f 100644 --- a/src/ReinforcementLearningDatasets/test/dataset.jl +++ b/src/ReinforcementLearningDatasets/test/dataset.jl @@ -8,11 +8,11 @@ rng = StableRNG(123) @testset "dataset_shuffle" begin ds = dataset( "hopper-medium-replay-v0"; - repo="d4rl", + repo = "d4rl", style = style, rng = rng, is_shuffle = true, - batch_size = batch_size + batch_size = batch_size, ) data_dict = ds.dataset @@ -58,7 +58,7 @@ end style = style, rng = rng, is_shuffle = false, - batch_size = batch_size + batch_size = batch_size, ) data_dict = ds.dataset diff 
--git a/src/ReinforcementLearningEnvironments/src/ReinforcementLearningEnvironments.jl b/src/ReinforcementLearningEnvironments/src/ReinforcementLearningEnvironments.jl index f3d2dd7e9..dfd1f0ee8 100644 --- a/src/ReinforcementLearningEnvironments/src/ReinforcementLearningEnvironments.jl +++ b/src/ReinforcementLearningEnvironments/src/ReinforcementLearningEnvironments.jl @@ -4,7 +4,7 @@ using ReinforcementLearningBase using Random using Requires using IntervalSets -using Base.Threads:@spawn +using Base.Threads: @spawn using Markdown const RLEnvs = ReinforcementLearningEnvironments @@ -31,9 +31,7 @@ function __init__() @require OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" include( "environments/3rd_party/AcrobotEnv.jl", ) - @require Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" include( - "plots.jl", - ) + @require Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" include("plots.jl") end diff --git a/src/ReinforcementLearningEnvironments/src/environments/3rd_party/AcrobotEnv.jl b/src/ReinforcementLearningEnvironments/src/environments/3rd_party/AcrobotEnv.jl index c371bb80b..eac7957c7 100644 --- a/src/ReinforcementLearningEnvironments/src/environments/3rd_party/AcrobotEnv.jl +++ b/src/ReinforcementLearningEnvironments/src/environments/3rd_party/AcrobotEnv.jl @@ -18,23 +18,23 @@ AcrobotEnv(;kwargs...) - `avail_torque = [T(-1.), T(0.), T(1.)]` """ function AcrobotEnv(; - T=Float64, - link_length_a=T(1.0), - link_length_b=T(1.0), - link_mass_a=T(1.0), - link_mass_b=T(1.0), - link_com_pos_a=T(0.5), - link_com_pos_b=T(0.5), - link_moi=T(1.0), - max_torque_noise=T(0.0), - max_vel_a=T(4 * π), - max_vel_b=T(9 * π), - g=T(9.8), - dt=T(0.2), - max_steps=200, - rng=Random.GLOBAL_RNG, - book_or_nips="book", - avail_torque=[T(-1.0), T(0.0), T(1.0)], + T = Float64, + link_length_a = T(1.0), + link_length_b = T(1.0), + link_mass_a = T(1.0), + link_mass_b = T(1.0), + link_com_pos_a = T(0.5), + link_com_pos_b = T(0.5), + link_moi = T(1.0), + max_torque_noise = T(0.0), + max_vel_a = T(4 * π), + max_vel_b = T(9 * π), + g = T(9.8), + dt = T(0.2), + max_steps = 200, + rng = Random.GLOBAL_RNG, + book_or_nips = "book", + avail_torque = [T(-1.0), T(0.0), T(1.0)], ) params = AcrobotEnvParams{T}( @@ -81,7 +81,7 @@ RLBase.is_terminated(env::AcrobotEnv) = env.done RLBase.state(env::AcrobotEnv) = acrobot_observation(env.state) RLBase.reward(env::AcrobotEnv) = env.reward -function RLBase.reset!(env::AcrobotEnv{T}) where {T <: Number} +function RLBase.reset!(env::AcrobotEnv{T}) where {T<:Number} env.state[:] = T(0.1) * rand(env.rng, T, 4) .- T(0.05) env.t = 0 env.action = 2 @@ -91,7 +91,7 @@ function RLBase.reset!(env::AcrobotEnv{T}) where {T <: Number} end # governing equations as per python gym -function (env::AcrobotEnv{T})(a) where {T <: Number} +function (env::AcrobotEnv{T})(a) where {T<:Number} env.action = a env.t += 1 torque = env.avail_torque[a] @@ -137,7 +137,7 @@ function dsdt(du, s_augmented, env::AcrobotEnv, t) # extract action and state a = s_augmented[end] - s = s_augmented[1:end - 1] + s = s_augmented[1:end-1] # writing in standard form theta1 = s[1] @@ -201,7 +201,7 @@ function wrap(x, m, M) while x < m x = x + diff end -return x + return x end function bound(x, m, M) diff --git a/src/ReinforcementLearningEnvironments/src/environments/3rd_party/open_spiel.jl b/src/ReinforcementLearningEnvironments/src/environments/3rd_party/open_spiel.jl index 3974bb37b..843bd39ee 100644 --- a/src/ReinforcementLearningEnvironments/src/environments/3rd_party/open_spiel.jl +++ 
b/src/ReinforcementLearningEnvironments/src/environments/3rd_party/open_spiel.jl @@ -44,7 +44,7 @@ import .OpenSpiel: `True` or `False` (instead of `true` or `false`). Another approach is to just specify parameters in `kwargs` in the Julia style. """ -function OpenSpielEnv(name="kuhn_poker"; kwargs...) +function OpenSpielEnv(name = "kuhn_poker"; kwargs...) game = load_game(String(name); kwargs...) state = new_initial_state(game) OpenSpielEnv(state, game) @@ -60,7 +60,7 @@ RLBase.current_player(env::OpenSpielEnv) = OpenSpiel.current_player(env.state) RLBase.chance_player(env::OpenSpielEnv) = convert(Int, OpenSpiel.CHANCE_PLAYER) function RLBase.players(env::OpenSpielEnv) - p = 0:(num_players(env.game) - 1) + p = 0:(num_players(env.game)-1) if ChanceStyle(env) === EXPLICIT_STOCHASTIC (p..., RLBase.chance_player(env)) else @@ -91,7 +91,7 @@ function RLBase.prob(env::OpenSpielEnv, player) # @assert player == chance_player(env) p = zeros(length(action_space(env))) for (k, v) in chance_outcomes(env.state) - p[k + 1] = v + p[k+1] = v end p end @@ -102,7 +102,7 @@ function RLBase.legal_action_space_mask(env::OpenSpielEnv, player) num_distinct_actions(env.game) mask = BitArray(undef, n) for a in legal_actions(env.state, player) - mask[a + 1] = true + mask[a+1] = true end mask end @@ -136,12 +136,16 @@ end _state(env::OpenSpielEnv, ::RLBase.InformationSet{String}, player) = information_state_string(env.state, player) -_state(env::OpenSpielEnv, ::RLBase.InformationSet{Array}, player) = - reshape(information_state_tensor(env.state, player), reverse(information_state_tensor_shape(env.game))...) +_state(env::OpenSpielEnv, ::RLBase.InformationSet{Array}, player) = reshape( + information_state_tensor(env.state, player), + reverse(information_state_tensor_shape(env.game))..., +) _state(env::OpenSpielEnv, ::Observation{String}, player) = observation_string(env.state, player) -_state(env::OpenSpielEnv, ::Observation{Array}, player) = - reshape(observation_tensor(env.state, player), reverse(observation_tensor_shape(env.game))...) +_state(env::OpenSpielEnv, ::Observation{Array}, player) = reshape( + observation_tensor(env.state, player), + reverse(observation_tensor_shape(env.game))..., +) RLBase.state_space( env::OpenSpielEnv, @@ -149,16 +153,18 @@ RLBase.state_space( p, ) = WorldSpace{AbstractString}() -RLBase.state_space(env::OpenSpielEnv, ::InformationSet{Array}, - p, -) = Space( - fill(typemin(Float64)..typemax(Float64), reverse(information_state_tensor_shape(env.game))...), +RLBase.state_space(env::OpenSpielEnv, ::InformationSet{Array}, p) = Space( + fill( + typemin(Float64)..typemax(Float64), + reverse(information_state_tensor_shape(env.game))..., + ), ) -RLBase.state_space(env::OpenSpielEnv, ::Observation{Array}, - p, -) = Space( - fill(typemin(Float64)..typemax(Float64), reverse(observation_tensor_shape(env.game))...), +RLBase.state_space(env::OpenSpielEnv, ::Observation{Array}, p) = Space( + fill( + typemin(Float64)..typemax(Float64), + reverse(observation_tensor_shape(env.game))..., + ), ) Random.seed!(env::OpenSpielEnv, s) = @warn "seed!(OpenSpielEnv) is not supported currently." @@ -199,7 +205,9 @@ RLBase.RewardStyle(env::OpenSpielEnv) = reward_model(get_type(env.game)) == OpenSpiel.REWARDS ? 
RLBase.STEP_REWARD : RLBase.TERMINAL_REWARD -RLBase.StateStyle(env::OpenSpielEnv) = (RLBase.InformationSet{String}(), +RLBase.StateStyle(env::OpenSpielEnv) = ( + RLBase.InformationSet{String}(), RLBase.InformationSet{Array}(), RLBase.Observation{String}(), - RLBase.Observation{Array}(),) + RLBase.Observation{Array}(), +) diff --git a/src/ReinforcementLearningEnvironments/src/environments/3rd_party/structs.jl b/src/ReinforcementLearningEnvironments/src/environments/3rd_party/structs.jl index 83586f4e3..0acd51427 100644 --- a/src/ReinforcementLearningEnvironments/src/environments/3rd_party/structs.jl +++ b/src/ReinforcementLearningEnvironments/src/environments/3rd_party/structs.jl @@ -6,7 +6,7 @@ struct GymEnv{T,Ta,To,P} <: AbstractEnv end export GymEnv -mutable struct AtariEnv{IsGrayScale,TerminalOnLifeLoss,N,S <: AbstractRNG} <: AbstractEnv +mutable struct AtariEnv{IsGrayScale,TerminalOnLifeLoss,N,S<:AbstractRNG} <: AbstractEnv ale::Ptr{Nothing} name::String screens::Tuple{Array{UInt8,N},Array{UInt8,N}} # for max-pooling @@ -65,7 +65,7 @@ end export AcrobotEnvParams -mutable struct AcrobotEnv{T,R <: AbstractRNG} <: AbstractEnv +mutable struct AcrobotEnv{T,R<:AbstractRNG} <: AbstractEnv params::AcrobotEnvParams{T} state::Vector{T} action::Int diff --git a/src/ReinforcementLearningEnvironments/src/environments/examples/StockTradingEnv.jl b/src/ReinforcementLearningEnvironments/src/environments/examples/StockTradingEnv.jl index 6a3f6a6f6..07f220bb7 100644 --- a/src/ReinforcementLearningEnvironments/src/environments/examples/StockTradingEnv.jl +++ b/src/ReinforcementLearningEnvironments/src/environments/examples/StockTradingEnv.jl @@ -2,12 +2,12 @@ export StockTradingEnv, StockTradingEnvWithTurbulence using Pkg.Artifacts using DelimitedFiles -using LinearAlgebra:dot +using LinearAlgebra: dot using IntervalSets function load_default_stock_data(s) if s == "prices.csv" || s == "features.csv" - data, _ = readdlm(joinpath(artifact"stock_trading_data", s), ',', header=true) + data, _ = readdlm(joinpath(artifact"stock_trading_data", s), ',', header = true) collect(data') elseif s == "turbulence.csv" readdlm(joinpath(artifact"stock_trading_data", "turbulence.csv")) |> vec @@ -16,7 +16,8 @@ function load_default_stock_data(s) end end -mutable struct StockTradingEnv{F<:AbstractMatrix{Float64}, P<:AbstractMatrix{Float64}} <: AbstractEnv +mutable struct StockTradingEnv{F<:AbstractMatrix{Float64},P<:AbstractMatrix{Float64}} <: + AbstractEnv features::F prices::P HMAX_NORMALIZE::Float32 @@ -48,14 +49,14 @@ This environment is originally provided in [Deep Reinforcement Learning for Auto - `initial_account_balance=1_000_000`. """ function StockTradingEnv(; - initial_account_balance=1_000_000f0, - features=nothing, - prices=nothing, - first_day=nothing, - last_day=nothing, - HMAX_NORMALIZE = 100f0, + initial_account_balance = 1_000_000.0f0, + features = nothing, + prices = nothing, + first_day = nothing, + last_day = nothing, + HMAX_NORMALIZE = 100.0f0, TRANSACTION_FEE_PERCENT = 0.001f0, - REWARD_SCALING = 1f-4 + REWARD_SCALING = 1f-4, ) prices = isnothing(prices) ? load_default_stock_data("prices.csv") : prices features = isnothing(features) ? load_default_stock_data("features.csv") : features @@ -77,11 +78,11 @@ function StockTradingEnv(; REWARD_SCALING, initial_account_balance, state, - 0f0, + 0.0f0, day, first_day, last_day, - 0f0 + 0.0f0, ) _balance(env)[] = initial_account_balance @@ -108,10 +109,10 @@ function (env::StockTradingEnv)(actions) # then buy # better to shuffle? 
- for (i,b) in enumerate(actions) + for (i, b) in enumerate(actions) if b > 0 max_buy = div(_balance(env)[], _prices(env)[i]) - buy = min(b*env.HMAX_NORMALIZE, max_buy) + buy = min(b * env.HMAX_NORMALIZE, max_buy) _holds(env)[i] += buy deduction = buy * _prices(env)[i] cost = deduction * env.TRANSACTION_FEE_PERCENT @@ -136,12 +137,12 @@ function RLBase.reset!(env::StockTradingEnv) _balance(env)[] = env.initial_account_balance _prices(env) .= @view env.prices[:, env.day] _features(env) .= @view env.features[:, env.day] - env.total_cost = 0. - env.daily_reward = 0. + env.total_cost = 0.0 + env.daily_reward = 0.0 end RLBase.state_space(env::StockTradingEnv) = Space(fill(-Inf32..Inf32, length(state(env)))) -RLBase.action_space(env::StockTradingEnv) = Space(fill(-1f0..1f0, length(_holds(env)))) +RLBase.action_space(env::StockTradingEnv) = Space(fill(-1.0f0..1.0f0, length(_holds(env)))) RLBase.ChanceStyle(::StockTradingEnv) = DETERMINISTIC @@ -154,16 +155,16 @@ struct StockTradingEnvWithTurbulence{E<:StockTradingEnv} <: AbstractEnvWrapper end function StockTradingEnvWithTurbulence(; - turbulence_threshold=140., - turbulences=nothing, - kw... + turbulence_threshold = 140.0, + turbulences = nothing, + kw..., ) turbulences = isnothing(turbulences) && load_default_stock_data("turbulence.csv") StockTradingEnvWithTurbulence( - StockTradingEnv(;kw...), + StockTradingEnv(; kw...), turbulences, - turbulence_threshold + turbulence_threshold, ) end @@ -172,4 +173,4 @@ function (w::StockTradingEnvWithTurbulence)(actions) actions .= ifelse.(actions .< 0, -Inf32, 0) end w.env(actions) -end \ No newline at end of file +end diff --git a/src/ReinforcementLearningEnvironments/src/environments/wrappers/ActionTransformedEnv.jl b/src/ReinforcementLearningEnvironments/src/environments/wrappers/ActionTransformedEnv.jl index b178b6ecf..3e5bd2264 100644 --- a/src/ReinforcementLearningEnvironments/src/environments/wrappers/ActionTransformedEnv.jl +++ b/src/ReinforcementLearningEnvironments/src/environments/wrappers/ActionTransformedEnv.jl @@ -13,7 +13,7 @@ end `legal_action_space(env)`. `action_mapping` will be applied to `action` before feeding it into `env`. """ -ActionTransformedEnv(env; action_mapping = identity, action_space_mapping = identity) = +ActionTransformedEnv(env; action_mapping = identity, action_space_mapping = identity) = ActionTransformedEnv(env, action_mapping, action_space_mapping) RLBase.action_space(env::ActionTransformedEnv, args...) 
= diff --git a/src/ReinforcementLearningEnvironments/src/environments/wrappers/DefaultStateStyle.jl b/src/ReinforcementLearningEnvironments/src/environments/wrappers/DefaultStateStyle.jl index 3e6996c52..95ce36788 100644 --- a/src/ReinforcementLearningEnvironments/src/environments/wrappers/DefaultStateStyle.jl +++ b/src/ReinforcementLearningEnvironments/src/environments/wrappers/DefaultStateStyle.jl @@ -13,9 +13,10 @@ DefaultStateStyleEnv{S}(env::E) where {S,E} = DefaultStateStyleEnv{S,E}(env) RLBase.DefaultStateStyle(::DefaultStateStyleEnv{S}) where {S} = S -RLBase.state(env::DefaultStateStyleEnv{S}) where S = state(env.env, S) +RLBase.state(env::DefaultStateStyleEnv{S}) where {S} = state(env.env, S) RLBase.state(env::DefaultStateStyleEnv, ss::RLBase.AbstractStateStyle) = state(env.env, ss) -RLBase.state(env::DefaultStateStyleEnv{S}, player) where S = state(env.env, S, player) +RLBase.state(env::DefaultStateStyleEnv{S}, player) where {S} = state(env.env, S, player) -RLBase.state_space(env::DefaultStateStyleEnv{S}) where S = state_space(env.env, S) -RLBase.state_space(env::DefaultStateStyleEnv, ss::RLBase.AbstractStateStyle) = state_space(env.env, ss) \ No newline at end of file +RLBase.state_space(env::DefaultStateStyleEnv{S}) where {S} = state_space(env.env, S) +RLBase.state_space(env::DefaultStateStyleEnv, ss::RLBase.AbstractStateStyle) = + state_space(env.env, ss) diff --git a/src/ReinforcementLearningEnvironments/src/environments/wrappers/SequentialEnv.jl b/src/ReinforcementLearningEnvironments/src/environments/wrappers/SequentialEnv.jl index 4f18af426..2a88903dd 100644 --- a/src/ReinforcementLearningEnvironments/src/environments/wrappers/SequentialEnv.jl +++ b/src/ReinforcementLearningEnvironments/src/environments/wrappers/SequentialEnv.jl @@ -9,7 +9,7 @@ mutable struct SequentialEnv{E<:AbstractEnv} <: AbstractEnvWrapper env::E current_player_idx::Int actions::Vector{Any} - function SequentialEnv(env::T) where T<:AbstractEnv + function SequentialEnv(env::T) where {T<:AbstractEnv} @assert DynamicStyle(env) === SIMULTANEOUS "The SequentialEnv wrapper can only be applied to SIMULTANEOUS environments" new{T}(env, 1, Vector{Any}(undef, length(players(env)))) end @@ -32,7 +32,8 @@ end RLBase.reward(env::SequentialEnv) = reward(env, current_player(env)) -RLBase.reward(env::SequentialEnv, player) = current_player(env) == 1 ? reward(env.env, player) : 0 +RLBase.reward(env::SequentialEnv, player) = + current_player(env) == 1 ? reward(env.env, player) : 0 function (env::SequentialEnv)(action) env.actions[env.current_player_idx] = action @@ -43,4 +44,3 @@ function (env::SequentialEnv)(action) env.current_player_idx += 1 end end - diff --git a/src/ReinforcementLearningEnvironments/src/environments/wrappers/StateCachedEnv.jl b/src/ReinforcementLearningEnvironments/src/environments/wrappers/StateCachedEnv.jl index e8626a3b8..97e18e928 100644 --- a/src/ReinforcementLearningEnvironments/src/environments/wrappers/StateCachedEnv.jl +++ b/src/ReinforcementLearningEnvironments/src/environments/wrappers/StateCachedEnv.jl @@ -6,7 +6,7 @@ the next interaction with `env`. This function is useful because some environments are stateful during each `state(env)`. For example: `StateTransformedEnv(StackFrames(...))`. 
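# Editor's illustrative sketch, not part of the formatting patch: wrapping an
# environment so that repeated `state` calls between two interactions return
# the cached value, as exercised in the wrapper tests further below.
# CartPoleEnv is used here only as a convenient example environment.
using ReinforcementLearning

env = StateCachedEnv(CartPoleEnv())
s1 = state(env)
s2 = state(env)      # same cached value; the wrapped env is not queried again
env(1)               # acting on the env invalidates the cache
s3 = state(env)      # freshly computed state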
""" -mutable struct StateCachedEnv{S,E <: AbstractEnv} <: AbstractEnvWrapper +mutable struct StateCachedEnv{S,E<:AbstractEnv} <: AbstractEnvWrapper s::S env::E is_state_cached::Bool diff --git a/src/ReinforcementLearningEnvironments/src/environments/wrappers/StateTransformedEnv.jl b/src/ReinforcementLearningEnvironments/src/environments/wrappers/StateTransformedEnv.jl index dfe90bddd..840c9ecdb 100644 --- a/src/ReinforcementLearningEnvironments/src/environments/wrappers/StateTransformedEnv.jl +++ b/src/ReinforcementLearningEnvironments/src/environments/wrappers/StateTransformedEnv.jl @@ -12,11 +12,11 @@ end `state_mapping` will be applied on the original state when calling `state(env)`, and similarly `state_space_mapping` will be applied when calling `state_space(env)`. """ -StateTransformedEnv(env; state_mapping=identity, state_space_mapping=identity) = +StateTransformedEnv(env; state_mapping = identity, state_space_mapping = identity) = StateTransformedEnv(env, state_mapping, state_space_mapping) RLBase.state(env::StateTransformedEnv, args...; kwargs...) = env.state_mapping(state(env.env, args...; kwargs...)) -RLBase.state_space(env::StateTransformedEnv, args...; kwargs...) = +RLBase.state_space(env::StateTransformedEnv, args...; kwargs...) = env.state_space_mapping(state_space(env.env, args...; kwargs...)) diff --git a/src/ReinforcementLearningEnvironments/src/plots.jl b/src/ReinforcementLearningEnvironments/src/plots.jl index a2c3ce7bc..0e9e5cb61 100644 --- a/src/ReinforcementLearningEnvironments/src/plots.jl +++ b/src/ReinforcementLearningEnvironments/src/plots.jl @@ -8,36 +8,36 @@ function plot(env::CartPoleEnv; kwargs...) xthreshold = env.params.xthreshold # set the frame plot( - xlims=(-xthreshold, xthreshold), - ylims=(-.1, l + 0.1), - legend=false, - border=:none, + xlims = (-xthreshold, xthreshold), + ylims = (-.1, l + 0.1), + legend = false, + border = :none, ) # plot the cart - plot!([x - 0.5, x - 0.5, x + 0.5, x + 0.5], [-.05, 0, 0, -.05]; - seriestype=:shape, - ) + plot!([x - 0.5, x - 0.5, x + 0.5, x + 0.5], [-.05, 0, 0, -.05]; seriestype = :shape) # plot the pole - plot!([x, x + l * sin(theta)], [0, l * cos(theta)]; - linewidth=3, - ) + plot!([x, x + l * sin(theta)], [0, l * cos(theta)]; linewidth = 3) # plot the arrow - plot!([x + (a == 1) - 0.5, x + 1.4 * (a == 1)-0.7], [ -.025, -.025]; - linewidth=3, - arrow=true, - color=2, + plot!( + [x + (a == 1) - 0.5, x + 1.4 * (a == 1) - 0.7], + [-.025, -.025]; + linewidth = 3, + arrow = true, + color = 2, ) # if done plot pink circle in top right if d - plot!([xthreshold - 0.2], [l]; - marker=:circle, - markersize=20, - markerstrokewidth=0., - color=:pink, + plot!( + [xthreshold - 0.2], + [l]; + marker = :circle, + markersize = 20, + markerstrokewidth = 0.0, + color = :pink, ) end - - plot!(;kwargs...) + + plot!(; kwargs...) end @@ -51,10 +51,10 @@ function plot(env::MountainCarEnv; kwargs...) d = env.done plot( - xlims=(env.params.min_pos - 0.1, env.params.max_pos + 0.2), - ylims=(-.1, height(env.params.max_pos) + 0.2), - legend=false, - border=:none, + xlims = (env.params.min_pos - 0.1, env.params.max_pos + 0.2), + ylims = (-.1, height(env.params.max_pos) + 0.2), + legend = false, + border = :none, ) # plot the terrain xs = LinRange(env.params.min_pos, env.params.max_pos, 100) @@ -72,17 +72,19 @@ function plot(env::MountainCarEnv; kwargs...) 
ys .+= clearance xs, ys = rotate(xs, ys, θ) xs, ys = translate(xs, ys, [x, height(x)]) - plot!(xs, ys; seriestype=:shape) + plot!(xs, ys; seriestype = :shape) # if done plot pink circle in top right if d - plot!([xthreshold - 0.2], [l]; - marker=:circle, - markersize=20, - markerstrokewidth=0., - color=:pink, + plot!( + [xthreshold - 0.2], + [l]; + marker = :circle, + markersize = 20, + markerstrokewidth = 0.0, + color = :pink, ) end - plot!(;kwargs...) - end + plot!(; kwargs...) +end diff --git a/src/ReinforcementLearningEnvironments/test/environments/examples/stock_trading_env.jl b/src/ReinforcementLearningEnvironments/test/environments/examples/stock_trading_env.jl index 7cb138328..c826e0e4e 100644 --- a/src/ReinforcementLearningEnvironments/test/environments/examples/stock_trading_env.jl +++ b/src/ReinforcementLearningEnvironments/test/environments/examples/stock_trading_env.jl @@ -5,4 +5,3 @@ RLBase.test_interfaces!(env) RLBase.test_runnable!(env) end - diff --git a/src/ReinforcementLearningEnvironments/test/environments/wrappers/wrappers.jl b/src/ReinforcementLearningEnvironments/test/environments/wrappers/wrappers.jl index 208307a67..049defa22 100644 --- a/src/ReinforcementLearningEnvironments/test/environments/wrappers/wrappers.jl +++ b/src/ReinforcementLearningEnvironments/test/environments/wrappers/wrappers.jl @@ -1,11 +1,11 @@ @testset "wrappers" begin @testset "ActionTransformedEnv" begin - env = TigerProblemEnv(; rng=StableRNG(123)) + env = TigerProblemEnv(; rng = StableRNG(123)) env′ = ActionTransformedEnv( env; - action_space_mapping=x -> Base.OneTo(3), - action_mapping=i -> action_space(env)[i], + action_space_mapping = x -> Base.OneTo(3), + action_mapping = i -> action_space(env)[i], ) RLBase.test_interfaces!(env′) @@ -14,7 +14,7 @@ @testset "DefaultStateStyleEnv" begin rng = StableRNG(123) - env = TigerProblemEnv(; rng=rng) + env = TigerProblemEnv(; rng = rng) S = InternalState{Int}() env′ = DefaultStateStyleEnv{S}(env) @test DefaultStateStyle(env′) === S @@ -35,7 +35,7 @@ @testset "MaxTimeoutEnv" begin rng = StableRNG(123) - env = TigerProblemEnv(; rng=rng) + env = TigerProblemEnv(; rng = rng) n = 100 env′ = MaxTimeoutEnv(env, n) @@ -55,7 +55,7 @@ @testset "RewardOverriddenEnv" begin rng = StableRNG(123) - env = TigerProblemEnv(; rng=rng) + env = TigerProblemEnv(; rng = rng) env′ = RewardOverriddenEnv(env, x -> sign(x)) RLBase.test_interfaces!(env′) @@ -69,7 +69,7 @@ @testset "StateCachedEnv" begin rng = StableRNG(123) - env = CartPoleEnv(; rng=rng) + env = CartPoleEnv(; rng = rng) env′ = StateCachedEnv(env) RLBase.test_interfaces!(env′) @@ -85,7 +85,7 @@ @testset "StateTransformedEnv" begin rng = StableRNG(123) - env = TigerProblemEnv(; rng=rng) + env = TigerProblemEnv(; rng = rng) # S = (:door1, :door2, :door3, :none) # env′ = StateTransformedEnv(env, state_mapping=s -> s+1) # RLBase.state_space(env::typeof(env′), ::RLBase.AbstractStateStyle, ::Any) = S @@ -97,14 +97,14 @@ @testset "StochasticEnv" begin env = KuhnPokerEnv() rng = StableRNG(123) - env′ = StochasticEnv(env; rng=rng) + env′ = StochasticEnv(env; rng = rng) RLBase.test_interfaces!(env′) RLBase.test_runnable!(env′) end @testset "SequentialEnv" begin - env = RockPaperScissorsEnv() + env = RockPaperScissorsEnv() env′ = SequentialEnv(env) RLBase.test_interfaces!(env′) RLBase.test_runnable!(env′) diff --git a/src/ReinforcementLearningExperiments/deps/build.jl b/src/ReinforcementLearningExperiments/deps/build.jl index def42b559..6dfde6249 100644 --- a/src/ReinforcementLearningExperiments/deps/build.jl +++ 
b/src/ReinforcementLearningExperiments/deps/build.jl @@ -2,10 +2,11 @@ using Weave const DEST_DIR = joinpath(@__DIR__, "..", "src", "experiments") -for (root, dirs, files) in walkdir(joinpath(@__DIR__, "..", "..", "..", "docs", "experiments")) +for (root, dirs, files) in + walkdir(joinpath(@__DIR__, "..", "..", "..", "docs", "experiments")) for f in files if splitext(f)[2] == ".jl" - tangle(joinpath(root,f);informat="script", out_path=DEST_DIR) + tangle(joinpath(root, f); informat = "script", out_path = DEST_DIR) end end -end \ No newline at end of file +end diff --git a/src/ReinforcementLearningExperiments/src/ReinforcementLearningExperiments.jl b/src/ReinforcementLearningExperiments/src/ReinforcementLearningExperiments.jl index 0f615ba3f..2a273443a 100644 --- a/src/ReinforcementLearningExperiments/src/ReinforcementLearningExperiments.jl +++ b/src/ReinforcementLearningExperiments/src/ReinforcementLearningExperiments.jl @@ -23,7 +23,6 @@ for f in readdir(EXPERIMENTS_DIR) end # dynamic loading environments -function __init__() -end +function __init__() end end # module diff --git a/src/ReinforcementLearningZoo/src/algorithms/dqns/common.jl b/src/ReinforcementLearningZoo/src/algorithms/dqns/common.jl index ddd2c6ace..793eaa897 100644 --- a/src/ReinforcementLearningZoo/src/algorithms/dqns/common.jl +++ b/src/ReinforcementLearningZoo/src/algorithms/dqns/common.jl @@ -4,7 +4,10 @@ const PERLearners = Union{PrioritizedDQNLearner,RainbowLearner,IQNLearner} -function RLBase.update!(learner::Union{DQNLearner,QRDQNLearner,REMDQNLearner,PERLearners}, t::AbstractTrajectory) +function RLBase.update!( + learner::Union{DQNLearner,QRDQNLearner,REMDQNLearner,PERLearners}, + t::AbstractTrajectory, +) length(t[:terminal]) - learner.sampler.n <= learner.min_replay_history && return learner.update_step += 1 diff --git a/src/ReinforcementLearningZoo/src/algorithms/dqns/dqn.jl b/src/ReinforcementLearningZoo/src/algorithms/dqns/dqn.jl index b2ebab93c..ad2808f26 100644 --- a/src/ReinforcementLearningZoo/src/algorithms/dqns/dqn.jl +++ b/src/ReinforcementLearningZoo/src/algorithms/dqns/dqn.jl @@ -55,7 +55,7 @@ function DQNLearner(; traces = SARTS, update_step = 0, rng = Random.GLOBAL_RNG, - is_enable_double_DQN::Bool = true + is_enable_double_DQN::Bool = true, ) where {Tq,Tt,Tf} copyto!(approximator, target_approximator) sampler = NStepBatchSampler{traces}(; @@ -75,7 +75,7 @@ function DQNLearner(; sampler, rng, 0.0f0, - is_enable_double_DQN + is_enable_double_DQN, ) end @@ -117,14 +117,14 @@ function RLBase.update!(learner::DQNLearner, batch::NamedTuple) else q_values = Qₜ(s′) end - + if haskey(batch, :next_legal_actions_mask) l′ = send_to_device(D, batch[:next_legal_actions_mask]) q_values .+= ifelse.(l′, 0.0f0, typemin(Float32)) end if is_enable_double_DQN - selected_actions = dropdims(argmax(q_values, dims=1), dims=1) + selected_actions = dropdims(argmax(q_values, dims = 1), dims = 1) q′ = Qₜ(s′)[selected_actions] else q′ = dropdims(maximum(q_values; dims = 1), dims = 1) diff --git a/src/ReinforcementLearningZoo/src/algorithms/dqns/dqns.jl b/src/ReinforcementLearningZoo/src/algorithms/dqns/dqns.jl index 4ec47c5ba..8f190ba2c 100644 --- a/src/ReinforcementLearningZoo/src/algorithms/dqns/dqns.jl +++ b/src/ReinforcementLearningZoo/src/algorithms/dqns/dqns.jl @@ -5,4 +5,4 @@ include("qr_dqn.jl") include("rem_dqn.jl") include("rainbow.jl") include("iqn.jl") -include("common.jl") \ No newline at end of file +include("common.jl") diff --git a/src/ReinforcementLearningZoo/src/algorithms/dqns/qr_dqn.jl 
b/src/ReinforcementLearningZoo/src/algorithms/dqns/qr_dqn.jl index 2832b1905..0f6c1b519 100644 --- a/src/ReinforcementLearningZoo/src/algorithms/dqns/qr_dqn.jl +++ b/src/ReinforcementLearningZoo/src/algorithms/dqns/qr_dqn.jl @@ -1,6 +1,6 @@ export QRDQNLearner, quantile_huber_loss -function quantile_huber_loss(ŷ, y; κ=1.0f0) +function quantile_huber_loss(ŷ, y; κ = 1.0f0) N, B = size(y) Δ = reshape(y, N, 1, B) .- reshape(ŷ, 1, N, B) abs_error = abs.(Δ) @@ -8,12 +8,13 @@ function quantile_huber_loss(ŷ, y; κ=1.0f0) linear = abs_error .- quadratic huber_loss = 0.5f0 .* quadratic .* quadratic .+ κ .* linear - cum_prob = send_to_device(device(y), range(0.5f0 / N; length=N, step=1.0f0 / N)) + cum_prob = send_to_device(device(y), range(0.5f0 / N; length = N, step = 1.0f0 / N)) loss = Zygote.dropgrad(abs.(cum_prob .- (Δ .< 0))) .* huber_loss - mean(sum(loss;dims=1)) + mean(sum(loss; dims = 1)) end -mutable struct QRDQNLearner{Tq <: AbstractApproximator,Tt <: AbstractApproximator,Tf,R} <: AbstractLearner +mutable struct QRDQNLearner{Tq<:AbstractApproximator,Tt<:AbstractApproximator,Tf,R} <: + AbstractLearner approximator::Tq target_approximator::Tt min_replay_history::Int @@ -51,25 +52,25 @@ See paper: [Distributional Reinforcement Learning with Quantile Regression](http function QRDQNLearner(; approximator, target_approximator, - stack_size::Union{Int,Nothing}=nothing, - γ::Float32=0.99f0, - batch_size::Int=32, - update_horizon::Int=1, - min_replay_history::Int=32, - update_freq::Int=1, - n_quantile::Int=1, - target_update_freq::Int=100, - traces=SARTS, - update_step=0, - loss_func=quantile_huber_loss, - rng=Random.GLOBAL_RNG + stack_size::Union{Int,Nothing} = nothing, + γ::Float32 = 0.99f0, + batch_size::Int = 32, + update_horizon::Int = 1, + min_replay_history::Int = 32, + update_freq::Int = 1, + n_quantile::Int = 1, + target_update_freq::Int = 100, + traces = SARTS, + update_step = 0, + loss_func = quantile_huber_loss, + rng = Random.GLOBAL_RNG, ) copyto!(approximator, target_approximator) sampler = NStepBatchSampler{traces}(; - γ=γ, - n=update_horizon, - stack_size=stack_size, - batch_size=batch_size, + γ = γ, + n = update_horizon, + stack_size = stack_size, + batch_size = batch_size, ) N = n_quantile @@ -100,7 +101,7 @@ function (learner::QRDQNLearner)(env) s = send_to_device(device(learner.approximator), state(env)) s = Flux.unsqueeze(s, ndims(s) + 1) q = reshape(learner.approximator(s), learner.n_quantile, :) - vec(mean(q, dims=1)) |> send_to_host + vec(mean(q, dims = 1)) |> send_to_host end function RLBase.update!(learner::QRDQNLearner, batch::NamedTuple) @@ -117,10 +118,12 @@ function RLBase.update!(learner::QRDQNLearner, batch::NamedTuple) a = CartesianIndex.(a, 1:batch_size) target_quantiles = reshape(Qₜ(s′), N, :, batch_size) - qₜ = dropdims(mean(target_quantiles; dims=1); dims=1) - aₜ = dropdims(argmax(qₜ, dims=1); dims=1) + qₜ = dropdims(mean(target_quantiles; dims = 1); dims = 1) + aₜ = dropdims(argmax(qₜ, dims = 1); dims = 1) @views target_quantile_aₜ = target_quantiles[:, aₜ] - y = reshape(r, 1, batch_size) .+ γ .* reshape(1 .- t, 1, batch_size) .* target_quantile_aₜ + y = + reshape(r, 1, batch_size) .+ + γ .* reshape(1 .- t, 1, batch_size) .* target_quantile_aₜ gs = gradient(params(Q)) do q = reshape(Q(s), N, :, batch_size) diff --git a/src/ReinforcementLearningZoo/src/algorithms/dqns/rem_dqn.jl b/src/ReinforcementLearningZoo/src/algorithms/dqns/rem_dqn.jl index 182ce253f..08f270429 100644 --- a/src/ReinforcementLearningZoo/src/algorithms/dqns/rem_dqn.jl +++ 
b/src/ReinforcementLearningZoo/src/algorithms/dqns/rem_dqn.jl @@ -120,7 +120,7 @@ function RLBase.update!(learner::REMDQNLearner, batch::NamedTuple) target_q = Qₜ(s′) target_q = convex_polygon .* reshape(target_q, :, ensemble_num, batch_size) - target_q = dropdims(sum(target_q, dims=2), dims=2) + target_q = dropdims(sum(target_q, dims = 2), dims = 2) if haskey(batch, :next_legal_actions_mask) l′ = send_to_device(D, batch[:next_legal_actions_mask]) @@ -133,7 +133,7 @@ function RLBase.update!(learner::REMDQNLearner, batch::NamedTuple) gs = gradient(params(Q)) do q = Q(s) q = convex_polygon .* reshape(q, :, ensemble_num, batch_size) - q = dropdims(sum(q, dims=2), dims=2)[a] + q = dropdims(sum(q, dims = 2), dims = 2)[a] loss = loss_func(G, q) ignore() do @@ -143,5 +143,4 @@ function RLBase.update!(learner::REMDQNLearner, batch::NamedTuple) end update!(Q, gs) -end - +end diff --git a/src/ReinforcementLearningZoo/src/algorithms/nfsp/abstract_nfsp.jl b/src/ReinforcementLearningZoo/src/algorithms/nfsp/abstract_nfsp.jl index eeda5a271..68a289bb5 100644 --- a/src/ReinforcementLearningZoo/src/algorithms/nfsp/abstract_nfsp.jl +++ b/src/ReinforcementLearningZoo/src/algorithms/nfsp/abstract_nfsp.jl @@ -43,4 +43,4 @@ function Base.run( end hook(POST_EXPERIMENT_STAGE, policy, env) hook -end \ No newline at end of file +end diff --git a/src/ReinforcementLearningZoo/src/algorithms/nfsp/nfsp.jl b/src/ReinforcementLearningZoo/src/algorithms/nfsp/nfsp.jl index d95fbf2f0..4e40b9703 100644 --- a/src/ReinforcementLearningZoo/src/algorithms/nfsp/nfsp.jl +++ b/src/ReinforcementLearningZoo/src/algorithms/nfsp/nfsp.jl @@ -19,8 +19,8 @@ See the paper https://arxiv.org/abs/1603.01121 for more details. mutable struct NFSPAgent <: AbstractPolicy rl_agent::Agent sl_agent::Agent - η - rng + η::Any + rng::Any update_freq::Int update_step::Int mode::Bool @@ -39,7 +39,8 @@ function RLBase.update!(π::NFSPAgent, env::AbstractEnv) π(POST_ACT_STAGE, env, player) end -(π::NFSPAgent)(stage::PreEpisodeStage, env::AbstractEnv, ::Any) = update!(π.rl_agent.trajectory, π.rl_agent.policy, env, stage) +(π::NFSPAgent)(stage::PreEpisodeStage, env::AbstractEnv, ::Any) = + update!(π.rl_agent.trajectory, π.rl_agent.policy, env, stage) function (π::NFSPAgent)(stage::PreActStage, env::AbstractEnv, action) rl = π.rl_agent @@ -88,7 +89,7 @@ function (π::NFSPAgent)(::PostEpisodeStage, env::AbstractEnv, player::Any) if haskey(rl.trajectory, :legal_actions_mask) push!(rl.trajectory[:legal_actions_mask], legal_action_space_mask(env, player)) end - + # update the policy π.update_step += 1 if π.update_step % π.update_freq == 0 @@ -106,7 +107,7 @@ function rl_learn!(policy::QBasedPolicy, t::AbstractTrajectory) # just learn the approximator, not update target_approximator learner = policy.learner length(t[:terminal]) - learner.sampler.n <= learner.min_replay_history && return - + _, batch = sample(learner.rng, t, learner.sampler) if t isa PrioritizedTrajectory diff --git a/src/ReinforcementLearningZoo/src/algorithms/nfsp/nfsp_manager.jl b/src/ReinforcementLearningZoo/src/algorithms/nfsp/nfsp_manager.jl index 7ba996f82..cd09a078d 100644 --- a/src/ReinforcementLearningZoo/src/algorithms/nfsp/nfsp_manager.jl +++ b/src/ReinforcementLearningZoo/src/algorithms/nfsp/nfsp_manager.jl @@ -6,7 +6,7 @@ export NFSPAgentManager A special MultiAgentManager in which all agents use NFSP policy to play the game. 
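# Editor's illustrative sketch, not part of the formatting patch: an
# NFSPAgentManager simply maps every (non-chance) player of the environment to
# its own NFSPAgent (rl_agent, sl_agent, η, rng and update settings as in the
# struct above). `make_nfsp_agent` is a hypothetical helper standing in for
# that full agent construction.
env = KuhnPokerEnv()
nfsp = NFSPAgentManager(
    Dict(
        player => make_nfsp_agent(env, player) for
        player in players(env) if player != chance_player(env)
    ),
)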
""" mutable struct NFSPAgentManager <: AbstractPolicy - agents::Dict{Any, NFSPAgent} + agents::Dict{Any,NFSPAgent} end function (π::NFSPAgentManager)(env::AbstractEnv) @@ -18,7 +18,8 @@ function (π::NFSPAgentManager)(env::AbstractEnv) end end -RLBase.prob(π::NFSPAgentManager, env::AbstractEnv, args...) = prob(π.agents[current_player(env)], env, args...) +RLBase.prob(π::NFSPAgentManager, env::AbstractEnv, args...) = + prob(π.agents[current_player(env)], env, args...) function RLBase.update!(π::NFSPAgentManager, env::AbstractEnv) while current_player(env) == chance_player(env) @@ -27,7 +28,10 @@ function RLBase.update!(π::NFSPAgentManager, env::AbstractEnv) update!(π.agents[current_player(env)], env) end -function (π::NFSPAgentManager)(stage::Union{PreEpisodeStage, PostEpisodeStage}, env::AbstractEnv) +function (π::NFSPAgentManager)( + stage::Union{PreEpisodeStage,PostEpisodeStage}, + env::AbstractEnv, +) @sync for (player, agent) in π.agents @async agent(stage, env, player) end diff --git a/src/ReinforcementLearningZoo/src/algorithms/offline_rl/behavior_cloning.jl b/src/ReinforcementLearningZoo/src/algorithms/offline_rl/behavior_cloning.jl index cbe73fec7..79408aa02 100644 --- a/src/ReinforcementLearningZoo/src/algorithms/offline_rl/behavior_cloning.jl +++ b/src/ReinforcementLearningZoo/src/algorithms/offline_rl/behavior_cloning.jl @@ -19,19 +19,14 @@ end - `rng = Random.GLOBAL_RNG` """ function BehaviorCloningPolicy(; - approximator::A, - explorer::Any = GreedyExplorer(), - batch_size::Int = 32, - min_reservoir_history::Int = 100, - rng = Random.GLOBAL_RNG + approximator::A, + explorer::Any = GreedyExplorer(), + batch_size::Int = 32, + min_reservoir_history::Int = 100, + rng = Random.GLOBAL_RNG, ) where {A} sampler = BatchSampler{(:state, :action)}(batch_size; rng = rng) - BehaviorCloningPolicy( - approximator, - explorer, - sampler, - min_reservoir_history, - ) + BehaviorCloningPolicy(approximator, explorer, sampler, min_reservoir_history) end function (p::BehaviorCloningPolicy)(env::AbstractEnv) @@ -39,7 +34,8 @@ function (p::BehaviorCloningPolicy)(env::AbstractEnv) s_batch = Flux.unsqueeze(s, ndims(s) + 1) s_batch = send_to_device(device(p.approximator), s_batch) logits = p.approximator(s_batch) |> vec |> send_to_host # drop dimension - typeof(ActionStyle(env)) == MinimalActionSet ? p.explorer(logits) : p.explorer(logits, legal_action_space_mask(env)) + typeof(ActionStyle(env)) == MinimalActionSet ? p.explorer(logits) : + p.explorer(logits, legal_action_space_mask(env)) end function RLBase.update!(p::BehaviorCloningPolicy, batch::NamedTuple{(:state, :action)}) @@ -65,7 +61,8 @@ function RLBase.prob(p::BehaviorCloningPolicy, env::AbstractEnv) m = p.approximator s_batch = send_to_device(device(m), Flux.unsqueeze(s, ndims(s) + 1)) values = m(s_batch) |> vec |> send_to_host - typeof(ActionStyle(env)) == MinimalActionSet ? prob(p.explorer, values) : prob(p.explorer, values, legal_action_space_mask(env)) + typeof(ActionStyle(env)) == MinimalActionSet ? 
prob(p.explorer, values) : + prob(p.explorer, values, legal_action_space_mask(env)) end function RLBase.prob(p::BehaviorCloningPolicy, env::AbstractEnv, action) diff --git a/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/ddpg.jl b/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/ddpg.jl index 8464e08ec..1d8d7d17a 100644 --- a/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/ddpg.jl +++ b/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/ddpg.jl @@ -118,7 +118,12 @@ function (p::DDPGPolicy)(env) s = state(env) s = Flux.unsqueeze(s, ndims(s) + 1) actions = p.behavior_actor(send_to_device(D, s)) |> vec |> send_to_host - c = clamp.(actions .+ randn(p.rng, p.na) .* repeat([p.act_noise], p.na), -p.act_limit, p.act_limit) + c = + clamp.( + actions .+ randn(p.rng, p.na) .* repeat([p.act_noise], p.na), + -p.act_limit, + p.act_limit, + ) p.na == 1 && return c[1] c end @@ -154,7 +159,7 @@ function RLBase.update!(p::DDPGPolicy, batch::NamedTuple{SARTS}) a′ = Aₜ(s′) qₜ = Cₜ(vcat(s′, a′)) |> vec y = r .+ γ .* (1 .- t) .* qₜ - a = Flux.unsqueeze(a, ndims(a)+1) + a = Flux.unsqueeze(a, ndims(a) + 1) gs1 = gradient(Flux.params(C)) do q = C(vcat(s, a)) |> vec diff --git a/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/ppo.jl b/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/ppo.jl index fa06a9132..bf9e8f27e 100644 --- a/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/ppo.jl +++ b/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/ppo.jl @@ -148,7 +148,9 @@ function RLBase.prob( if p.update_step < p.n_random_start @error "todo" else - μ, logσ = p.approximator.actor(send_to_device(device(p.approximator), state)) |> send_to_host + μ, logσ = + p.approximator.actor(send_to_device(device(p.approximator), state)) |> + send_to_host StructArray{Normal}((μ, exp.(logσ))) end end @@ -256,11 +258,11 @@ function _update!(p::PPOPolicy, t::AbstractTrajectory) end s = send_to_device(D, select_last_dim(states_flatten, inds)) # !!! 
performance critical a = send_to_device(D, select_last_dim(actions_flatten, inds)) - + if eltype(a) === Int a = CartesianIndex.(a, 1:length(a)) end - + r = send_to_device(D, vec(returns)[inds]) log_p = send_to_device(D, vec(action_log_probs)[inds]) adv = send_to_device(D, vec(advantages)[inds]) @@ -275,7 +277,8 @@ function _update!(p::PPOPolicy, t::AbstractTrajectory) else log_p′ₐ = normlogpdf(μ, exp.(logσ), a) end - entropy_loss = mean(size(logσ, 1) * (log(2.0f0π) + 1) .+ sum(logσ; dims = 1)) / 2 + entropy_loss = + mean(size(logσ, 1) * (log(2.0f0π) + 1) .+ sum(logσ; dims = 1)) / 2 else # actor is assumed to return discrete logits logit′ = AC.actor(s) diff --git a/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/sac.jl b/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/sac.jl index 157a3b1ca..f66157193 100644 --- a/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/sac.jl +++ b/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/sac.jl @@ -104,8 +104,8 @@ function SACPolicy(; Float32(-action_dims), update_step, rng, - 0f0, - 0f0, + 0.0f0, + 0.0f0, ) end @@ -120,7 +120,7 @@ function (p::SACPolicy)(env) s = state(env) s = Flux.unsqueeze(s, ndims(s) + 1) # trainmode: - action = dropdims(p.policy(p.rng, s; is_sampling=true), dims=2) # Single action vec, drop second dim + action = dropdims(p.policy(p.rng, s; is_sampling = true), dims = 2) # Single action vec, drop second dim # testmode: # if testing dont sample an action, but act deterministically by @@ -146,7 +146,7 @@ function RLBase.update!(p::SACPolicy, batch::NamedTuple{SARTS}) γ, τ, α = p.γ, p.τ, p.α - a′, log_π = p.policy(p.rng, s′; is_sampling=true, is_return_log_prob=true) + a′, log_π = p.policy(p.rng, s′; is_sampling = true, is_return_log_prob = true) q′_input = vcat(s′, a′) q′ = min.(p.target_qnetwork1(q′_input), p.target_qnetwork2(q′_input)) @@ -168,12 +168,12 @@ function RLBase.update!(p::SACPolicy, batch::NamedTuple{SARTS}) # Train Policy p_grad = gradient(Flux.params(p.policy)) do - a, log_π = p.policy(p.rng, s; is_sampling=true, is_return_log_prob=true) + a, log_π = p.policy(p.rng, s; is_sampling = true, is_return_log_prob = true) q_input = vcat(s, a) q = min.(p.qnetwork1(q_input), p.qnetwork2(q_input)) reward = mean(q) entropy = mean(log_π) - ignore() do + ignore() do p.reward_term = reward p.entropy_term = entropy end diff --git a/src/ReinforcementLearningZoo/src/algorithms/tabular/tabular_policy.jl b/src/ReinforcementLearningZoo/src/algorithms/tabular/tabular_policy.jl index a91c22b62..d0bbedfa3 100644 --- a/src/ReinforcementLearningZoo/src/algorithms/tabular/tabular_policy.jl +++ b/src/ReinforcementLearningZoo/src/algorithms/tabular/tabular_policy.jl @@ -11,7 +11,7 @@ A `Dict` is used internally to store the mapping from state to action. """ Base.@kwdef struct TabularPolicy{S,A} <: AbstractPolicy table::Dict{S,A} = Dict{Int,Int}() - n_action::Union{Int, Nothing} = nothing + n_action::Union{Int,Nothing} = nothing end (p::TabularPolicy)(env::AbstractEnv) = p(state(env)) diff --git a/test/runtests.jl b/test/runtests.jl index 04be42cbd..d6dd95c9e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,5 +1,4 @@ using Test using ReinforcementLearning -@testset "ReinforcementLearning" begin -end +@testset "ReinforcementLearning" begin end
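For reference, the layout conventions applied throughout this patch (spaces around keyword `=`, trailing commas in multi-line calls, compact `where {T<:Number}` clauses) are what an automated formatter pass over the repository produces. Assuming JuliaFormatter is the tool driving the bot, which is not stated in the patch itself, an equivalent local run would look like the sketch below.

using JuliaFormatter

# Format the repository in place; `verbose = true` reports each file as it is processed.
format("."; verbose = true)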