diff --git a/docs/experiments/experiments/CFR/JuliaRL_DeepCFR_OpenSpiel.jl b/docs/experiments/experiments/CFR/JuliaRL_DeepCFR_OpenSpiel.jl index ff647a3a4..2eeecbd74 100644 --- a/docs/experiments/experiments/CFR/JuliaRL_DeepCFR_OpenSpiel.jl +++ b/docs/experiments/experiments/CFR/JuliaRL_DeepCFR_OpenSpiel.jl @@ -61,5 +61,11 @@ function RL.Experiment( batch_size_Π = 2048, initializer = glorot_normal(CUDA.CURAND.default_rng()), ) - Experiment(p, env, StopAfterStep(500, is_show_progress=!haskey(ENV, "CI")), EmptyHook(), "# run DeepcCFR on leduc_poker") -end \ No newline at end of file + Experiment( + p, + env, + StopAfterStep(500, is_show_progress = !haskey(ENV, "CI")), + EmptyHook(), + "# run DeepcCFR on leduc_poker", + ) +end diff --git a/docs/experiments/experiments/CFR/JuliaRL_TabularCFR_OpenSpiel.jl b/docs/experiments/experiments/CFR/JuliaRL_TabularCFR_OpenSpiel.jl index edfd7f199..d89cabb16 100644 --- a/docs/experiments/experiments/CFR/JuliaRL_TabularCFR_OpenSpiel.jl +++ b/docs/experiments/experiments/CFR/JuliaRL_TabularCFR_OpenSpiel.jl @@ -23,8 +23,14 @@ function RL.Experiment( π = TabularCFRPolicy(; rng = rng) description = "# Play `$game` in OpenSpiel with TabularCFRPolicy" - Experiment(π, env, StopAfterStep(300, is_show_progress=!haskey(ENV, "CI")), EmptyHook(), description) + Experiment( + π, + env, + StopAfterStep(300, is_show_progress = !haskey(ENV, "CI")), + EmptyHook(), + description, + ) end ex = E`JuliaRL_TabularCFR_OpenSpiel(kuhn_poker)` -run(ex) \ No newline at end of file +run(ex) diff --git a/docs/experiments/experiments/DQN/Dopamine_DQN_Atari.jl b/docs/experiments/experiments/DQN/Dopamine_DQN_Atari.jl index f51b4a2a9..59574a7a2 100644 --- a/docs/experiments/experiments/DQN/Dopamine_DQN_Atari.jl +++ b/docs/experiments/experiments/DQN/Dopamine_DQN_Atari.jl @@ -79,44 +79,41 @@ function atari_env_factory( repeat_action_probability = 0.25, n_replica = nothing, ) - init(seed) = - RewardOverriddenEnv( - StateCachedEnv( - StateTransformedEnv( - AtariEnv(; - name = string(name), - grayscale_obs = true, - noop_max = 30, - frame_skip = 4, - terminal_on_life_loss = false, - repeat_action_probability = repeat_action_probability, - max_num_frames_per_episode = n_frames * max_episode_steps, - color_averaging = false, - full_action_space = false, - seed = seed, - ); - state_mapping=Chain( - ResizeImage(state_size...), - StackFrames(state_size..., n_frames) - ), - state_space_mapping= _ -> Space(fill(0..256, state_size..., n_frames)) - ) + init(seed) = RewardOverriddenEnv( + StateCachedEnv( + StateTransformedEnv( + AtariEnv(; + name = string(name), + grayscale_obs = true, + noop_max = 30, + frame_skip = 4, + terminal_on_life_loss = false, + repeat_action_probability = repeat_action_probability, + max_num_frames_per_episode = n_frames * max_episode_steps, + color_averaging = false, + full_action_space = false, + seed = seed, + ); + state_mapping = Chain( + ResizeImage(state_size...), + StackFrames(state_size..., n_frames), + ), + state_space_mapping = _ -> + Space(fill(0 .. 256, state_size..., n_frames)), ), - r -> clamp(r, -1, 1) - ) + ), + r -> clamp(r, -1, 1), + ) if isnothing(n_replica) init(seed) else - envs = [ - init(isnothing(seed) ? nothing : hash(seed + i)) - for i in 1:n_replica - ] + envs = [init(isnothing(seed) ? nothing : hash(seed + i)) for i in 1:n_replica] states = Flux.batch(state.(envs)) rewards = reward.(envs) terminals = is_terminated.(envs) A = Space([action_space(x) for x in envs]) - S = Space(fill(0..255, size(states))) + S = Space(fill(0 .. 
255, size(states))) MultiThreadEnv(envs, states, rewards, terminals, A, S, nothing) end end @@ -172,7 +169,7 @@ function RL.Experiment( ::Val{:Atari}, name::AbstractString; save_dir = nothing, - seed = nothing + seed = nothing, ) rng = Random.GLOBAL_RNG Random.seed!(rng, seed) @@ -190,7 +187,7 @@ function RL.Experiment( name, STATE_SIZE, N_FRAMES; - seed = isnothing(seed) ? nothing : hash(seed + 1) + seed = isnothing(seed) ? nothing : hash(seed + 1), ) N_ACTIONS = length(action_space(env)) init = glorot_uniform(rng) @@ -254,17 +251,15 @@ function RL.Experiment( end, DoEveryNEpisode() do t, agent, env with_logger(lg) do - @info "training" episode_length = step_per_episode.steps[end] reward = reward_per_episode.rewards[end] log_step_increment = 0 + @info "training" episode_length = step_per_episode.steps[end] reward = + reward_per_episode.rewards[end] log_step_increment = 0 end end, - DoEveryNStep(;n=EVALUATION_FREQ) do t, agent, env + DoEveryNStep(; n = EVALUATION_FREQ) do t, agent, env @info "evaluating agent at $t step..." p = agent.policy p = @set p.explorer = EpsilonGreedyExplorer(0.001; rng = rng) # set evaluation epsilon - h = ComposedHook( - TotalOriginalRewardPerEpisode(), - StepsPerEpisode(), - ) + h = ComposedHook(TotalOriginalRewardPerEpisode(), StepsPerEpisode()) s = @elapsed run( p, atari_env_factory( @@ -281,16 +276,18 @@ function RL.Experiment( avg_score = mean(h[1].rewards[1:end-1]) avg_length = mean(h[2].steps[1:end-1]) - @info "finished evaluating agent in $s seconds" avg_length = avg_length avg_score = avg_score + @info "finished evaluating agent in $s seconds" avg_length = avg_length avg_score = + avg_score with_logger(lg) do - @info "evaluating" avg_length = avg_length avg_score = avg_score log_step_increment = 0 + @info "evaluating" avg_length = avg_length avg_score = avg_score log_step_increment = + 0 end end, ) stop_condition = StopAfterStep( haskey(ENV, "CI") ? 
1_000 : 50_000_000, - is_show_progress=!haskey(ENV, "CI") + is_show_progress = !haskey(ENV, "CI"), ) Experiment(agent, env, stop_condition, hook, "# DQN <-> Atari($name)") end diff --git a/docs/experiments/experiments/DQN/Dopamine_IQN_Atari.jl b/docs/experiments/experiments/DQN/Dopamine_IQN_Atari.jl index 6e0305ae5..f0a2b1bdb 100644 --- a/docs/experiments/experiments/DQN/Dopamine_IQN_Atari.jl +++ b/docs/experiments/experiments/DQN/Dopamine_IQN_Atari.jl @@ -84,44 +84,41 @@ function atari_env_factory( repeat_action_probability = 0.25, n_replica = nothing, ) - init(seed) = - RewardOverriddenEnv( - StateCachedEnv( - StateTransformedEnv( - AtariEnv(; - name = string(name), - grayscale_obs = true, - noop_max = 30, - frame_skip = 4, - terminal_on_life_loss = false, - repeat_action_probability = repeat_action_probability, - max_num_frames_per_episode = n_frames * max_episode_steps, - color_averaging = false, - full_action_space = false, - seed = seed, - ); - state_mapping=Chain( - ResizeImage(state_size...), - StackFrames(state_size..., n_frames) - ), - state_space_mapping= _ -> Space(fill(0..256, state_size..., n_frames)) - ) + init(seed) = RewardOverriddenEnv( + StateCachedEnv( + StateTransformedEnv( + AtariEnv(; + name = string(name), + grayscale_obs = true, + noop_max = 30, + frame_skip = 4, + terminal_on_life_loss = false, + repeat_action_probability = repeat_action_probability, + max_num_frames_per_episode = n_frames * max_episode_steps, + color_averaging = false, + full_action_space = false, + seed = seed, + ); + state_mapping = Chain( + ResizeImage(state_size...), + StackFrames(state_size..., n_frames), + ), + state_space_mapping = _ -> + Space(fill(0 .. 256, state_size..., n_frames)), ), - r -> clamp(r, -1, 1) - ) + ), + r -> clamp(r, -1, 1), + ) if isnothing(n_replica) init(seed) else - envs = [ - init(isnothing(seed) ? nothing : hash(seed + i)) - for i in 1:n_replica - ] + envs = [init(isnothing(seed) ? nothing : hash(seed + i)) for i in 1:n_replica] states = Flux.batch(state.(envs)) rewards = reward.(envs) terminals = is_terminated.(envs) A = Space([action_space(x) for x in envs]) - S = Space(fill(0..255, size(states))) + S = Space(fill(0 .. 255, size(states))) MultiThreadEnv(envs, states, rewards, terminals, A, S, nothing) end end @@ -195,7 +192,12 @@ function RL.Experiment( N_FRAMES = 4 STATE_SIZE = (84, 84) - env = atari_env_factory(name, STATE_SIZE, N_FRAMES; seed = isnothing(seed) ? nothing : hash(seed + 2)) + env = atari_env_factory( + name, + STATE_SIZE, + N_FRAMES; + seed = isnothing(seed) ? nothing : hash(seed + 2), + ) N_ACTIONS = length(action_space(env)) Nₑₘ = 64 @@ -250,7 +252,7 @@ function RL.Experiment( ), ), trajectory = CircularArraySARTTrajectory( - capacity = haskey(ENV, "CI") : 1_000 : 1_000_000, + capacity = haskey(ENV, "CI"):1_000:1_000_000, state = Matrix{Float32} => STATE_SIZE, ), ) @@ -274,7 +276,7 @@ function RL.Experiment( steps_per_episode.steps[end] log_step_increment = 0 end end, - DoEveryNStep(;n=EVALUATION_FREQ) do t, agent, env + DoEveryNStep(; n = EVALUATION_FREQ) do t, agent, env @info "evaluating agent at $t step..." p = agent.policy p = @set p.explorer = EpsilonGreedyExplorer(0.001; rng = rng) # set evaluation epsilon @@ -286,7 +288,7 @@ function RL.Experiment( STATE_SIZE, N_FRAMES, MAX_EPISODE_STEPS_EVAL; - seed = isnothing(seed) ? nothing : hash(seed + t) + seed = isnothing(seed) ? 
nothing : hash(seed + t), ), StopAfterStep(125_000; is_show_progress = false), h, @@ -295,16 +297,18 @@ function RL.Experiment( avg_score = mean(h[1].rewards[1:end-1]) avg_length = mean(h[2].steps[1:end-1]) - @info "finished evaluating agent in $s seconds" avg_length = avg_length avg_score = avg_score + @info "finished evaluating agent in $s seconds" avg_length = avg_length avg_score = + avg_score with_logger(lg) do - @info "evaluating" avg_length = avg_length avg_score = avg_score log_step_increment = 0 + @info "evaluating" avg_length = avg_length avg_score = avg_score log_step_increment = + 0 end end, ) stop_condition = StopAfterStep( haskey(ENV, "CI") ? 10_000 : 50_000_000, - is_show_progress=!haskey(ENV, "CI") + is_show_progress = !haskey(ENV, "CI"), ) Experiment(agent, env, stop_condition, hook, "# IQN <-> Atari($name)") end diff --git a/docs/experiments/experiments/DQN/Dopamine_Rainbow_Atari.jl b/docs/experiments/experiments/DQN/Dopamine_Rainbow_Atari.jl index 432e110e4..a49efb658 100644 --- a/docs/experiments/experiments/DQN/Dopamine_Rainbow_Atari.jl +++ b/docs/experiments/experiments/DQN/Dopamine_Rainbow_Atari.jl @@ -83,44 +83,41 @@ function atari_env_factory( repeat_action_probability = 0.25, n_replica = nothing, ) - init(seed) = - RewardOverriddenEnv( - StateCachedEnv( - StateTransformedEnv( - AtariEnv(; - name = string(name), - grayscale_obs = true, - noop_max = 30, - frame_skip = 4, - terminal_on_life_loss = false, - repeat_action_probability = repeat_action_probability, - max_num_frames_per_episode = n_frames * max_episode_steps, - color_averaging = false, - full_action_space = false, - seed = seed, - ); - state_mapping=Chain( - ResizeImage(state_size...), - StackFrames(state_size..., n_frames) - ), - state_space_mapping= _ -> Space(fill(0..256, state_size..., n_frames)) - ) + init(seed) = RewardOverriddenEnv( + StateCachedEnv( + StateTransformedEnv( + AtariEnv(; + name = string(name), + grayscale_obs = true, + noop_max = 30, + frame_skip = 4, + terminal_on_life_loss = false, + repeat_action_probability = repeat_action_probability, + max_num_frames_per_episode = n_frames * max_episode_steps, + color_averaging = false, + full_action_space = false, + seed = seed, + ); + state_mapping = Chain( + ResizeImage(state_size...), + StackFrames(state_size..., n_frames), + ), + state_space_mapping = _ -> + Space(fill(0 .. 256, state_size..., n_frames)), ), - r -> clamp(r, -1, 1) - ) + ), + r -> clamp(r, -1, 1), + ) if isnothing(n_replica) init(seed) else - envs = [ - init(isnothing(seed) ? nothing : hash(seed + i)) - for i in 1:n_replica - ] + envs = [init(isnothing(seed) ? nothing : hash(seed + i)) for i in 1:n_replica] states = Flux.batch(state.(envs)) rewards = reward.(envs) terminals = is_terminated.(envs) A = Space([action_space(x) for x in envs]) - S = Space(fill(0..255, size(states))) + S = Space(fill(0 .. 255, size(states))) MultiThreadEnv(envs, states, rewards, terminals, A, S, nothing) end end @@ -191,7 +188,12 @@ function RL.Experiment( N_FRAMES = 4 STATE_SIZE = (84, 84) - env = atari_env_factory(name, STATE_SIZE, N_FRAMES; seed = isnothing(seed) ? nothing : hash(seed + 1)) + env = atari_env_factory( + name, + STATE_SIZE, + N_FRAMES; + seed = isnothing(seed) ? 
nothing : hash(seed + 1), + ) N_ACTIONS = length(action_space(env)) N_ATOMS = 51 init = glorot_uniform(rng) @@ -238,7 +240,7 @@ function RL.Experiment( ), ), trajectory = CircularArrayPSARTTrajectory( - capacity = haskey(ENV, "CI") : 1_000 : 1_000_000, + capacity = haskey(ENV, "CI"):1_000:1_000_000, state = Matrix{Float32} => STATE_SIZE, ), ) @@ -262,7 +264,7 @@ function RL.Experiment( steps_per_episode.steps[end] log_step_increment = 0 end end, - DoEveryNStep(;n=EVALUATION_FREQ) do t, agent, env + DoEveryNStep(; n = EVALUATION_FREQ) do t, agent, env @info "evaluating agent at $t step..." p = agent.policy p = @set p.explorer = EpsilonGreedyExplorer(0.001; rng = rng) # set evaluation epsilon @@ -282,16 +284,18 @@ function RL.Experiment( avg_length = mean(h[2].steps[1:end-1]) avg_score = mean(h[1].rewards[1:end-1]) - @info "finished evaluating agent in $s seconds" avg_length = avg_length avg_score = avg_score + @info "finished evaluating agent in $s seconds" avg_length = avg_length avg_score = + avg_score with_logger(lg) do - @info "evaluating" avg_length = avg_length avg_score = avg_score log_step_increment = 0 + @info "evaluating" avg_length = avg_length avg_score = avg_score log_step_increment = + 0 end end, ) stop_condition = StopAfterStep( haskey(ENV, "CI") ? 10_000 : 50_000_000, - is_show_progress=!haskey(ENV, "CI") + is_show_progress = !haskey(ENV, "CI"), ) Experiment(agent, env, stop_condition, hook, "# Rainbow <-> Atari($name)") diff --git a/docs/experiments/experiments/DQN/JuliaRL_BasicDQN_CartPole.jl b/docs/experiments/experiments/DQN/JuliaRL_BasicDQN_CartPole.jl index 9f32c48d6..d5ba9c9c7 100644 --- a/docs/experiments/experiments/DQN/JuliaRL_BasicDQN_CartPole.jl +++ b/docs/experiments/experiments/DQN/JuliaRL_BasicDQN_CartPole.jl @@ -51,7 +51,7 @@ function RL.Experiment( state = Vector{Float32} => (ns,), ), ) - stop_condition = StopAfterStep(10_000, is_show_progress=!haskey(ENV, "CI")) + stop_condition = StopAfterStep(10_000, is_show_progress = !haskey(ENV, "CI")) hook = TotalRewardPerEpisode() Experiment(policy, env, stop_condition, hook, "# BasicDQN <-> CartPole") end diff --git a/docs/experiments/experiments/DQN/JuliaRL_BasicDQN_MountainCar.jl b/docs/experiments/experiments/DQN/JuliaRL_BasicDQN_MountainCar.jl index ae8f02cb5..bc79c94be 100644 --- a/docs/experiments/experiments/DQN/JuliaRL_BasicDQN_MountainCar.jl +++ b/docs/experiments/experiments/DQN/JuliaRL_BasicDQN_MountainCar.jl @@ -51,7 +51,7 @@ function RL.Experiment( ), ) - stop_condition = StopAfterStep(70_000, is_show_progress=!haskey(ENV, "CI")) + stop_condition = StopAfterStep(70_000, is_show_progress = !haskey(ENV, "CI")) hook = TotalRewardPerEpisode() Experiment(agent, env, stop_condition, hook, "") diff --git a/docs/experiments/experiments/DQN/JuliaRL_BasicDQN_SingleRoomUndirected.jl b/docs/experiments/experiments/DQN/JuliaRL_BasicDQN_SingleRoomUndirected.jl index 4b7f2a5cd..39748e307 100644 --- a/docs/experiments/experiments/DQN/JuliaRL_BasicDQN_SingleRoomUndirected.jl +++ b/docs/experiments/experiments/DQN/JuliaRL_BasicDQN_SingleRoomUndirected.jl @@ -18,47 +18,47 @@ function RL.Experiment( ::Val{:BasicDQN}, ::Val{:SingleRoomUndirected}, ::Nothing; - seed=123, + seed = 123, ) rng = StableRNG(seed) - env = GridWorlds.SingleRoomUndirectedModule.SingleRoomUndirected(rng=rng) + env = GridWorlds.SingleRoomUndirectedModule.SingleRoomUndirected(rng = rng) env = GridWorlds.RLBaseEnv(env) - env = RLEnvs.StateTransformedEnv(env;state_mapping=x -> vec(Float32.(x))) + env = RLEnvs.StateTransformedEnv(env; state_mapping = x -> 
vec(Float32.(x))) env = RewardOverriddenEnv(env, x -> x - convert(typeof(x), 0.01)) env = MaxTimeoutEnv(env, 240) ns, na = length(state(env)), length(action_space(env)) agent = Agent( - policy=QBasedPolicy( - learner=BasicDQNLearner( - approximator=NeuralNetworkApproximator( - model=Chain( - Dense(ns, 128, relu; init=glorot_uniform(rng)), - Dense(128, 128, relu; init=glorot_uniform(rng)), - Dense(128, na; init=glorot_uniform(rng)), + policy = QBasedPolicy( + learner = BasicDQNLearner( + approximator = NeuralNetworkApproximator( + model = Chain( + Dense(ns, 128, relu; init = glorot_uniform(rng)), + Dense(128, 128, relu; init = glorot_uniform(rng)), + Dense(128, na; init = glorot_uniform(rng)), ) |> cpu, - optimizer=ADAM(), + optimizer = ADAM(), ), - batch_size=32, - min_replay_history=100, - loss_func=huber_loss, - rng=rng, + batch_size = 32, + min_replay_history = 100, + loss_func = huber_loss, + rng = rng, ), - explorer=EpsilonGreedyExplorer( - kind=:exp, - ϵ_stable=0.01, - decay_steps=500, - rng=rng, + explorer = EpsilonGreedyExplorer( + kind = :exp, + ϵ_stable = 0.01, + decay_steps = 500, + rng = rng, ), ), - trajectory=CircularArraySARTTrajectory( - capacity=1000, - state=Vector{Float32} => (ns,), + trajectory = CircularArraySARTTrajectory( + capacity = 1000, + state = Vector{Float32} => (ns,), ), ) - stop_condition = StopAfterStep(10_000, is_show_progress=!haskey(ENV, "CI")) + stop_condition = StopAfterStep(10_000, is_show_progress = !haskey(ENV, "CI")) hook = TotalRewardPerEpisode() Experiment(agent, env, stop_condition, hook, "") end diff --git a/docs/experiments/experiments/DQN/JuliaRL_DQN_CartPole.jl b/docs/experiments/experiments/DQN/JuliaRL_DQN_CartPole.jl index 7e922e13e..7e2e218f6 100644 --- a/docs/experiments/experiments/DQN/JuliaRL_DQN_CartPole.jl +++ b/docs/experiments/experiments/DQN/JuliaRL_DQN_CartPole.jl @@ -14,13 +14,13 @@ using Flux.Losses function build_dueling_network(network::Chain) lm = length(network) - if !(network[lm] isa Dense) || !(network[lm-1] isa Dense) + if !(network[lm] isa Dense) || !(network[lm-1] isa Dense) error("The Qnetwork provided is incompatible with dueling.") end - base = Chain([deepcopy(network[i]) for i=1:lm-2]...) + base = Chain([deepcopy(network[i]) for i in 1:lm-2]...) last_layer_dims = size(network[lm].weight, 2) val = Chain(deepcopy(network[lm-1]), Dense(last_layer_dims, 1)) - adv = Chain([deepcopy(network[i]) for i=lm-1:lm]...) + adv = Chain([deepcopy(network[i]) for i in lm-1:lm]...) 
return DuelingNetwork(base, val, adv) end @@ -37,8 +37,8 @@ function RL.Experiment( base_model = Chain( Dense(ns, 128, relu; init = glorot_uniform(rng)), Dense(128, 128, relu; init = glorot_uniform(rng)), - Dense(128, na; init = glorot_uniform(rng)) - ) + Dense(128, na; init = glorot_uniform(rng)), + ) agent = Agent( policy = QBasedPolicy( @@ -72,7 +72,7 @@ function RL.Experiment( ), ) - stop_condition = StopAfterStep(10_000, is_show_progress=!haskey(ENV, "CI")) + stop_condition = StopAfterStep(10_000, is_show_progress = !haskey(ENV, "CI")) hook = TotalRewardPerEpisode() Experiment(agent, env, stop_condition, hook, "") end diff --git a/docs/experiments/experiments/DQN/JuliaRL_DQN_MountainCar.jl b/docs/experiments/experiments/DQN/JuliaRL_DQN_MountainCar.jl index d8b1eb633..f74bcaea1 100644 --- a/docs/experiments/experiments/DQN/JuliaRL_DQN_MountainCar.jl +++ b/docs/experiments/experiments/DQN/JuliaRL_DQN_MountainCar.jl @@ -64,7 +64,7 @@ function RL.Experiment( ), ) - stop_condition = StopAfterStep(40_000, is_show_progress=!haskey(ENV, "CI")) + stop_condition = StopAfterStep(40_000, is_show_progress = !haskey(ENV, "CI")) hook = TotalRewardPerEpisode() Experiment(agent, env, stop_condition, hook, "") end diff --git a/docs/experiments/experiments/DQN/JuliaRL_IQN_CartPole.jl b/docs/experiments/experiments/DQN/JuliaRL_IQN_CartPole.jl index f3ab3c98f..cba0fcea5 100644 --- a/docs/experiments/experiments/DQN/JuliaRL_IQN_CartPole.jl +++ b/docs/experiments/experiments/DQN/JuliaRL_IQN_CartPole.jl @@ -71,7 +71,7 @@ function RL.Experiment( ), ) - stop_condition = StopAfterStep(10_000, is_show_progress=!haskey(ENV, "CI")) + stop_condition = StopAfterStep(10_000, is_show_progress = !haskey(ENV, "CI")) hook = TotalRewardPerEpisode() Experiment(agent, env, stop_condition, hook, "") end diff --git a/docs/experiments/experiments/DQN/JuliaRL_QRDQN_Cartpole.jl b/docs/experiments/experiments/DQN/JuliaRL_QRDQN_Cartpole.jl index 7fb238d28..c45bb0b03 100644 --- a/docs/experiments/experiments/DQN/JuliaRL_QRDQN_Cartpole.jl +++ b/docs/experiments/experiments/DQN/JuliaRL_QRDQN_Cartpole.jl @@ -17,58 +17,58 @@ function RL.Experiment( ::Val{:QRDQN}, ::Val{:CartPole}, ::Nothing; - seed=123, + seed = 123, ) N = 10 rng = StableRNG(seed) - env = CartPoleEnv(; T=Float32, rng=rng) + env = CartPoleEnv(; T = Float32, rng = rng) ns, na = length(state(env)), length(action_space(env)) init = glorot_uniform(rng) agent = Agent( - policy=QBasedPolicy( - learner=QRDQNLearner( - approximator=NeuralNetworkApproximator( - model=Chain( + policy = QBasedPolicy( + learner = QRDQNLearner( + approximator = NeuralNetworkApproximator( + model = Chain( Dense(ns, 128, relu; init = init), Dense(128, 128, relu; init = init), Dense(128, N * na; init = init), ) |> cpu, - optimizer=ADAM(), + optimizer = ADAM(), ), - target_approximator=NeuralNetworkApproximator( - model=Chain( + target_approximator = NeuralNetworkApproximator( + model = Chain( Dense(ns, 128, relu; init = init), Dense(128, 128, relu; init = init), Dense(128, N * na; init = init), ) |> cpu, ), - stack_size=nothing, - batch_size=32, - update_horizon=1, - min_replay_history=100, - update_freq=1, - target_update_freq=100, - n_quantile=N, - rng=rng, + stack_size = nothing, + batch_size = 32, + update_horizon = 1, + min_replay_history = 100, + update_freq = 1, + target_update_freq = 100, + n_quantile = N, + rng = rng, ), - explorer=EpsilonGreedyExplorer( - kind=:exp, - ϵ_stable=0.01, - decay_steps=500, - rng=rng, + explorer = EpsilonGreedyExplorer( + kind = :exp, + ϵ_stable = 0.01, + 
decay_steps = 500, + rng = rng, ), ), - trajectory=CircularArraySARTTrajectory( - capacity=1000, - state=Vector{Float32} => (ns,), + trajectory = CircularArraySARTTrajectory( + capacity = 1000, + state = Vector{Float32} => (ns,), ), ) - stop_condition = StopAfterStep(10_000, is_show_progress=!haskey(ENV, "CI")) + stop_condition = StopAfterStep(10_000, is_show_progress = !haskey(ENV, "CI")) hook = TotalRewardPerEpisode() Experiment(agent, env, stop_condition, hook, "") end diff --git a/docs/experiments/experiments/DQN/JuliaRL_REMDQN_CartPole.jl b/docs/experiments/experiments/DQN/JuliaRL_REMDQN_CartPole.jl index fdf473a83..7f74dd096 100644 --- a/docs/experiments/experiments/DQN/JuliaRL_REMDQN_CartPole.jl +++ b/docs/experiments/experiments/DQN/JuliaRL_REMDQN_CartPole.jl @@ -52,7 +52,7 @@ function RL.Experiment( update_freq = 1, target_update_freq = 100, ensemble_num = ensemble_num, - ensemble_method = :rand, + ensemble_method = :rand, rng = rng, ), explorer = EpsilonGreedyExplorer( @@ -68,7 +68,7 @@ function RL.Experiment( ), ) - stop_condition = StopAfterStep(10_000, is_show_progress=!haskey(ENV, "CI")) + stop_condition = StopAfterStep(10_000, is_show_progress = !haskey(ENV, "CI")) hook = TotalRewardPerEpisode() Experiment(agent, env, stop_condition, hook, "") end diff --git a/docs/experiments/experiments/DQN/JuliaRL_Rainbow_CartPole.jl b/docs/experiments/experiments/DQN/JuliaRL_Rainbow_CartPole.jl index f367d1cf5..d8f5d2437 100644 --- a/docs/experiments/experiments/DQN/JuliaRL_Rainbow_CartPole.jl +++ b/docs/experiments/experiments/DQN/JuliaRL_Rainbow_CartPole.jl @@ -71,7 +71,7 @@ function RL.Experiment( ), ) - stop_condition = StopAfterStep(10_000, is_show_progress=!haskey(ENV, "CI")) + stop_condition = StopAfterStep(10_000, is_show_progress = !haskey(ENV, "CI")) hook = TotalRewardPerEpisode() Experiment(agent, env, stop_condition, hook, "") end diff --git a/docs/experiments/experiments/ED/JuliaRL_ED_OpenSpiel.jl b/docs/experiments/experiments/ED/JuliaRL_ED_OpenSpiel.jl index c7e0bb2fa..83549aab0 100644 --- a/docs/experiments/experiments/ED/JuliaRL_ED_OpenSpiel.jl +++ b/docs/experiments/experiments/ED/JuliaRL_ED_OpenSpiel.jl @@ -19,41 +19,37 @@ end function (hook::KuhnOpenEDHook)(::PreEpisodeStage, policy, env) ## get nash_conv of the current policy. push!(hook.results, RLZoo.nash_conv(policy, env)) - + ## update agents' learning rate. for (_, agent) in policy.agents agent.learner.optimizer[2].eta = 1.0 / sqrt(length(hook.results)) end end -function RL.Experiment( - ::Val{:JuliaRL}, - ::Val{:ED}, - ::Val{:OpenSpiel}, - game; - seed = 123, -) +function RL.Experiment(::Val{:JuliaRL}, ::Val{:ED}, ::Val{:OpenSpiel}, game; seed = 123) rng = StableRNG(seed) - + env = OpenSpielEnv(game) wrapped_env = ActionTransformedEnv( env, - action_mapping = a -> RLBase.current_player(env) == chance_player(env) ? a : Int(a - 1), - action_space_mapping = as -> RLBase.current_player(env) == chance_player(env) ? - as : Base.OneTo(num_distinct_actions(env.game)), + action_mapping = a -> + RLBase.current_player(env) == chance_player(env) ? a : Int(a - 1), + action_space_mapping = as -> + RLBase.current_player(env) == chance_player(env) ? 
as : + Base.OneTo(num_distinct_actions(env.game)), ) wrapped_env = DefaultStateStyleEnv{InformationSet{Array}()}(wrapped_env) player = 0 # or 1 ns, na = length(state(wrapped_env, player)), length(action_space(wrapped_env, player)) create_network() = Chain( - Dense(ns, 64, relu;init = glorot_uniform(rng)), - Dense(64, na;init = glorot_uniform(rng)) + Dense(ns, 64, relu; init = glorot_uniform(rng)), + Dense(64, na; init = glorot_uniform(rng)), ) create_learner() = NeuralNetworkApproximator( model = create_network(), - optimizer = Flux.Optimise.Optimiser(WeightDecay(0.001), Descent()) + optimizer = Flux.Optimise.Optimiser(WeightDecay(0.001), Descent()), ) EDmanager = EDManager( @@ -63,20 +59,26 @@ function RL.Experiment( create_learner(), # neural network learner WeightedSoftmaxExplorer(), # explorer ) for player in players(env) if player != chance_player(env) - ) + ), ) - stop_condition = StopAfterEpisode(500, is_show_progress=!haskey(ENV, "CI")) + stop_condition = StopAfterEpisode(500, is_show_progress = !haskey(ENV, "CI")) hook = KuhnOpenEDHook([]) - Experiment(EDmanager, wrapped_env, stop_condition, hook, "# play OpenSpiel $game with ED algorithm") + Experiment( + EDmanager, + wrapped_env, + stop_condition, + hook, + "# play OpenSpiel $game with ED algorithm", + ) end using Plots ex = E`JuliaRL_ED_OpenSpiel(kuhn_poker)` results = run(ex) -plot(ex.hook.results, xlabel="episode", ylabel="nash_conv") +plot(ex.hook.results, xlabel = "episode", ylabel = "nash_conv") savefig("assets/JuliaRL_ED_OpenSpiel(kuhn_poker).png")#hide -# ![](assets/JuliaRL_NFSP_OpenSpiel(kuhn_poker).png) \ No newline at end of file +# ![](assets/JuliaRL_NFSP_OpenSpiel(kuhn_poker).png) diff --git a/docs/experiments/experiments/NFSP/JuliaRL_NFSP_KuhnPoker.jl b/docs/experiments/experiments/NFSP/JuliaRL_NFSP_KuhnPoker.jl index 7c69f4f63..d234d78e3 100644 --- a/docs/experiments/experiments/NFSP/JuliaRL_NFSP_KuhnPoker.jl +++ b/docs/experiments/experiments/NFSP/JuliaRL_NFSP_KuhnPoker.jl @@ -35,14 +35,15 @@ function RL.Experiment( seed = 123, ) rng = StableRNG(seed) - + ## Encode the KuhnPokerEnv's states for training. 
env = KuhnPokerEnv() wrapped_env = StateTransformedEnv( env; state_mapping = s -> [findfirst(==(s), state_space(env))], - state_space_mapping = ss -> [[findfirst(==(s), state_space(env))] for s in state_space(env)] - ) + state_space_mapping = ss -> + [[findfirst(==(s), state_space(env))] for s in state_space(env)], + ) player = 1 # or 2 ns, na = length(state(wrapped_env, player)), length(action_space(wrapped_env, player)) @@ -53,14 +54,14 @@ function RL.Experiment( approximator = NeuralNetworkApproximator( model = Chain( Dense(ns, 64, relu; init = glorot_normal(rng)), - Dense(64, na; init = glorot_normal(rng)) + Dense(64, na; init = glorot_normal(rng)), ) |> cpu, optimizer = Descent(0.01), ), target_approximator = NeuralNetworkApproximator( model = Chain( Dense(ns, 64, relu; init = glorot_normal(rng)), - Dense(64, na; init = glorot_normal(rng)) + Dense(64, na; init = glorot_normal(rng)), ) |> cpu, ), γ = 1.0f0, @@ -81,7 +82,7 @@ function RL.Experiment( ), trajectory = CircularArraySARTTrajectory( capacity = 200_000, - state = Vector{Int} => (ns, ), + state = Vector{Int} => (ns,), ), ) @@ -89,9 +90,9 @@ function RL.Experiment( policy = BehaviorCloningPolicy(; approximator = NeuralNetworkApproximator( model = Chain( - Dense(ns, 64, relu; init = glorot_normal(rng)), - Dense(64, na; init = glorot_normal(rng)) - ) |> cpu, + Dense(ns, 64, relu; init = glorot_normal(rng)), + Dense(64, na; init = glorot_normal(rng)), + ) |> cpu, optimizer = Descent(0.01), ), explorer = WeightedSoftmaxExplorer(), @@ -111,19 +112,22 @@ function RL.Experiment( η = 0.1 # anticipatory parameter nfsp = NFSPAgentManager( Dict( - (player, NFSPAgent( - deepcopy(rl_agent), - deepcopy(sl_agent), - η, - rng, - 128, # update_freq - 0, # initial update_step - true, # initial NFSPAgent's training mode - )) for player in players(wrapped_env) if player != chance_player(wrapped_env) - ) + ( + player, + NFSPAgent( + deepcopy(rl_agent), + deepcopy(sl_agent), + η, + rng, + 128, # update_freq + 0, # initial update_step + true, # initial NFSPAgent's training mode + ), + ) for player in players(wrapped_env) if player != chance_player(wrapped_env) + ), ) - stop_condition = StopAfterEpisode(1_200_000, is_show_progress=!haskey(ENV, "CI")) + stop_condition = StopAfterEpisode(1_200_000, is_show_progress = !haskey(ENV, "CI")) hook = KuhnNFSPHook(10_000, 0, [], []) Experiment(nfsp, wrapped_env, stop_condition, hook, "# run NFSP on KuhnPokerEnv") @@ -133,8 +137,14 @@ end using Plots ex = E`JuliaRL_NFSP_KuhnPoker` run(ex) -plot(ex.hook.episode, ex.hook.results, xaxis=:log, xlabel="episode", ylabel="nash_conv") +plot( + ex.hook.episode, + ex.hook.results, + xaxis = :log, + xlabel = "episode", + ylabel = "nash_conv", +) savefig("assets/JuliaRL_NFSP_KuhnPoker.png")#hide -# ![](assets/JuliaRL_NFSP_KuhnPoker.png) \ No newline at end of file +# ![](assets/JuliaRL_NFSP_KuhnPoker.png) diff --git a/docs/experiments/experiments/NFSP/JuliaRL_NFSP_OpenSpiel.jl b/docs/experiments/experiments/NFSP/JuliaRL_NFSP_OpenSpiel.jl index 1e8509f57..d0a8a099c 100644 --- a/docs/experiments/experiments/NFSP/JuliaRL_NFSP_OpenSpiel.jl +++ b/docs/experiments/experiments/NFSP/JuliaRL_NFSP_OpenSpiel.jl @@ -29,21 +29,17 @@ function (hook::KuhnOpenNFSPHook)(::PostEpisodeStage, policy, env) end end -function RL.Experiment( - ::Val{:JuliaRL}, - ::Val{:NFSP}, - ::Val{:OpenSpiel}, - game; - seed = 123, -) +function RL.Experiment(::Val{:JuliaRL}, ::Val{:NFSP}, ::Val{:OpenSpiel}, game; seed = 123) rng = StableRNG(seed) - + env = OpenSpielEnv(game) wrapped_env = ActionTransformedEnv( 
env, - action_mapping = a -> RLBase.current_player(env) == chance_player(env) ? a : Int(a - 1), - action_space_mapping = as -> RLBase.current_player(env) == chance_player(env) ? - as : Base.OneTo(num_distinct_actions(env.game)), + action_mapping = a -> + RLBase.current_player(env) == chance_player(env) ? a : Int(a - 1), + action_space_mapping = as -> + RLBase.current_player(env) == chance_player(env) ? as : + Base.OneTo(num_distinct_actions(env.game)), ) wrapped_env = DefaultStateStyleEnv{InformationSet{Array}()}(wrapped_env) player = 0 # or 1 @@ -56,14 +52,14 @@ function RL.Experiment( approximator = NeuralNetworkApproximator( model = Chain( Dense(ns, 128, relu; init = glorot_normal(rng)), - Dense(128, na; init = glorot_normal(rng)) + Dense(128, na; init = glorot_normal(rng)), ) |> cpu, optimizer = Descent(0.01), ), target_approximator = NeuralNetworkApproximator( model = Chain( Dense(ns, 128, relu; init = glorot_normal(rng)), - Dense(128, na; init = glorot_normal(rng)) + Dense(128, na; init = glorot_normal(rng)), ) |> cpu, ), γ = 1.0f0, @@ -84,7 +80,7 @@ function RL.Experiment( ), trajectory = CircularArraySARTTrajectory( capacity = 200_000, - state = Vector{Float64} => (ns, ), + state = Vector{Float64} => (ns,), ), ) @@ -92,9 +88,9 @@ function RL.Experiment( policy = BehaviorCloningPolicy(; approximator = NeuralNetworkApproximator( model = Chain( - Dense(ns, 128, relu; init = glorot_normal(rng)), - Dense(128, na; init = glorot_normal(rng)) - ) |> cpu, + Dense(ns, 128, relu; init = glorot_normal(rng)), + Dense(128, na; init = glorot_normal(rng)), + ) |> cpu, optimizer = Descent(0.01), ), explorer = WeightedSoftmaxExplorer(), @@ -114,29 +110,44 @@ function RL.Experiment( η = 0.1 # anticipatory parameter nfsp = NFSPAgentManager( Dict( - (player, NFSPAgent( - deepcopy(rl_agent), - deepcopy(sl_agent), - η, - rng, - 128, # update_freq - 0, # initial update_step - true, # initial NFSPAgent's training mode - )) for player in players(wrapped_env) if player != chance_player(wrapped_env) - ) + ( + player, + NFSPAgent( + deepcopy(rl_agent), + deepcopy(sl_agent), + η, + rng, + 128, # update_freq + 0, # initial update_step + true, # initial NFSPAgent's training mode + ), + ) for player in players(wrapped_env) if player != chance_player(wrapped_env) + ), ) - stop_condition = StopAfterEpisode(1_200_000, is_show_progress=!haskey(ENV, "CI")) + stop_condition = StopAfterEpisode(1_200_000, is_show_progress = !haskey(ENV, "CI")) hook = KuhnOpenNFSPHook(10_000, 0, [], []) - Experiment(nfsp, wrapped_env, stop_condition, hook, "# Play kuhn_poker in OpenSpiel with NFSP") + Experiment( + nfsp, + wrapped_env, + stop_condition, + hook, + "# Play kuhn_poker in OpenSpiel with NFSP", + ) end using Plots ex = E`JuliaRL_NFSP_OpenSpiel(kuhn_poker)` run(ex) -plot(ex.hook.episode, ex.hook.results, xaxis=:log, xlabel="episode", ylabel="nash_conv") +plot( + ex.hook.episode, + ex.hook.results, + xaxis = :log, + xlabel = "episode", + ylabel = "nash_conv", +) savefig("assets/JuliaRL_NFSP_OpenSpiel(kuhn_poker).png")#hide -# ![](assets/JuliaRL_NFSP_OpenSpiel(kuhn_poker).png) \ No newline at end of file +# ![](assets/JuliaRL_NFSP_OpenSpiel(kuhn_poker).png) diff --git a/docs/experiments/experiments/Offline/JuliaRL_BC_CartPole.jl b/docs/experiments/experiments/Offline/JuliaRL_BC_CartPole.jl index f1e42bb51..808719818 100644 --- a/docs/experiments/experiments/Offline/JuliaRL_BC_CartPole.jl +++ b/docs/experiments/experiments/Offline/JuliaRL_BC_CartPole.jl @@ -61,7 +61,7 @@ function RL.Experiment( ), ) - stop_condition = 
StopAfterStep(10_000, is_show_progress=!haskey(ENV, "CI")) + stop_condition = StopAfterStep(10_000, is_show_progress = !haskey(ENV, "CI")) hook = RecordStateAction() run(agent, env, stop_condition, hook) @@ -84,7 +84,13 @@ function RL.Experiment( end hook = TotalRewardPerEpisode() - Experiment(bc, env, StopAfterEpisode(100, is_show_progress=!haskey(ENV, "CI")), hook, "BehaviorCloning <-> CartPole") + Experiment( + bc, + env, + StopAfterEpisode(100, is_show_progress = !haskey(ENV, "CI")), + hook, + "BehaviorCloning <-> CartPole", + ) end #+ tangle=false diff --git a/docs/experiments/experiments/Policy Gradient/JuliaRL_A2CGAE_CartPole.jl b/docs/experiments/experiments/Policy Gradient/JuliaRL_A2CGAE_CartPole.jl index eeed90609..2ea78b92d 100644 --- a/docs/experiments/experiments/Policy Gradient/JuliaRL_A2CGAE_CartPole.jl +++ b/docs/experiments/experiments/Policy Gradient/JuliaRL_A2CGAE_CartPole.jl @@ -63,7 +63,7 @@ function RL.Experiment( terminal = Vector{Bool} => (N_ENV,), ), ) - stop_condition = StopAfterStep(50_000, is_show_progress=!haskey(ENV, "CI")) + stop_condition = StopAfterStep(50_000, is_show_progress = !haskey(ENV, "CI")) hook = TotalBatchRewardPerEpisode(N_ENV) Experiment(agent, env, stop_condition, hook, "# A2CGAE with CartPole") end @@ -78,7 +78,7 @@ run(ex) n = minimum(map(length, ex.hook.rewards)) m = mean([@view(x[1:n]) for x in ex.hook.rewards]) s = std([@view(x[1:n]) for x in ex.hook.rewards]) -plot(m,ribbon=s) +plot(m, ribbon = s) savefig("assets/JuliaRL_A2CGAE_CartPole.png") #hide # ![](assets/JuliaRL_A2CGAE_CartPole.png) diff --git a/docs/experiments/experiments/Policy Gradient/JuliaRL_A2C_CartPole.jl b/docs/experiments/experiments/Policy Gradient/JuliaRL_A2C_CartPole.jl index cb03dd025..e56733a92 100644 --- a/docs/experiments/experiments/Policy Gradient/JuliaRL_A2C_CartPole.jl +++ b/docs/experiments/experiments/Policy Gradient/JuliaRL_A2C_CartPole.jl @@ -59,7 +59,7 @@ function RL.Experiment( ), ) - stop_condition = StopAfterStep(50_000, is_show_progress=true) + stop_condition = StopAfterStep(50_000, is_show_progress = true) hook = TotalBatchRewardPerEpisode(N_ENV) Experiment(agent, env, stop_condition, hook, "# A2C with CartPole") end @@ -73,7 +73,7 @@ run(ex) n = minimum(map(length, ex.hook.rewards)) m = mean([@view(x[1:n]) for x in ex.hook.rewards]) s = std([@view(x[1:n]) for x in ex.hook.rewards]) -plot(m,ribbon=s) +plot(m, ribbon = s) savefig("assets/JuliaRL_A2C_CartPole.png") #hide # ![](assets/JuliaRL_A2C_CartPole.png) diff --git a/docs/experiments/experiments/Policy Gradient/JuliaRL_DDPG_Pendulum.jl b/docs/experiments/experiments/Policy Gradient/JuliaRL_DDPG_Pendulum.jl index 927cdc2d2..0120abecd 100644 --- a/docs/experiments/experiments/Policy Gradient/JuliaRL_DDPG_Pendulum.jl +++ b/docs/experiments/experiments/Policy Gradient/JuliaRL_DDPG_Pendulum.jl @@ -69,7 +69,7 @@ function RL.Experiment( na = 1, batch_size = 64, start_steps = 1000, - start_policy = RandomPolicy(-1.0..1.0; rng = rng), + start_policy = RandomPolicy(-1.0 .. 
1.0; rng = rng), update_after = 1000, update_freq = 1, act_limit = 1.0, @@ -79,11 +79,11 @@ function RL.Experiment( trajectory = CircularArraySARTTrajectory( capacity = 10000, state = Vector{Float32} => (ns,), - action = Float32 => (na, ), + action = Float32 => (na,), ), ) - stop_condition = StopAfterStep(10_000, is_show_progress=!haskey(ENV, "CI")) + stop_condition = StopAfterStep(10_000, is_show_progress = !haskey(ENV, "CI")) hook = TotalRewardPerEpisode() Experiment(agent, env, stop_condition, hook, "# Play Pendulum with DDPG") end diff --git a/docs/experiments/experiments/Policy Gradient/JuliaRL_MAC_CartPole.jl b/docs/experiments/experiments/Policy Gradient/JuliaRL_MAC_CartPole.jl index 3559b1b03..c526fcfe3 100644 --- a/docs/experiments/experiments/Policy Gradient/JuliaRL_MAC_CartPole.jl +++ b/docs/experiments/experiments/Policy Gradient/JuliaRL_MAC_CartPole.jl @@ -64,7 +64,7 @@ function RL.Experiment( ), ) - stop_condition = StopAfterStep(50_000, is_show_progress=!haskey(ENV, "CI")) + stop_condition = StopAfterStep(50_000, is_show_progress = !haskey(ENV, "CI")) hook = TotalBatchRewardPerEpisode(N_ENV) Experiment(agent, env, stop_condition, hook, "# MAC with CartPole") end @@ -78,7 +78,7 @@ run(ex) n = minimum(map(length, ex.hook.rewards)) m = mean([@view(x[1:n]) for x in ex.hook.rewards]) s = std([@view(x[1:n]) for x in ex.hook.rewards]) -plot(m,ribbon=s) +plot(m, ribbon = s) savefig("assets/JuliaRL_MAC_CartPole.png") #hide # ![](assets/JuliaRL_MAC_CartPole.png) diff --git a/docs/experiments/experiments/Policy Gradient/JuliaRL_MADDPG_KuhnPoker.jl b/docs/experiments/experiments/Policy Gradient/JuliaRL_MADDPG_KuhnPoker.jl index 9d13ec830..542d22100 100644 --- a/docs/experiments/experiments/Policy Gradient/JuliaRL_MADDPG_KuhnPoker.jl +++ b/docs/experiments/experiments/Policy Gradient/JuliaRL_MADDPG_KuhnPoker.jl @@ -32,7 +32,7 @@ function RL.Experiment( ::Val{:MADDPG}, ::Val{:KuhnPoker}, ::Nothing; - seed=123, + seed = 123, ) rng = StableRNG(seed) env = KuhnPokerEnv() @@ -40,10 +40,12 @@ function RL.Experiment( StateTransformedEnv( env; state_mapping = s -> [findfirst(==(s), state_space(env))], - state_space_mapping = ss -> [[findfirst(==(s), state_space(env))] for s in state_space(env)] - ), + state_space_mapping = ss -> + [[findfirst(==(s), state_space(env))] for s in state_space(env)], + ), ## drop the dummy action of the other agent. - action_mapping = x -> length(x) == 1 ? x : Int(ceil(x[current_player(env)]) + 1), + action_mapping = x -> + length(x) == 1 ? x : Int(ceil(x[current_player(env)]) + 1), ) ns, na = 1, 1 # dimension of the state and action. n_players = 2 # number of players @@ -51,18 +53,18 @@ function RL.Experiment( init = glorot_uniform(rng) create_actor() = Chain( - Dense(ns, 64, relu; init = init), - Dense(64, 64, relu; init = init), - Dense(64, na, tanh; init = init), - ) + Dense(ns, 64, relu; init = init), + Dense(64, 64, relu; init = init), + Dense(64, na, tanh; init = init), + ) create_critic() = Chain( Dense(n_players * ns + n_players * na, 64, relu; init = init), Dense(64, 64, relu; init = init), Dense(64, 1; init = init), - ) + ) + - policy = DDPGPolicy( behavior_actor = NeuralNetworkApproximator( model = create_actor(), @@ -84,31 +86,36 @@ function RL.Experiment( ρ = 0.99f0, na = na, start_steps = 1000, - start_policy = RandomPolicy(-0.99..0.99; rng = rng), + start_policy = RandomPolicy(-0.99 .. 
0.99; rng = rng), update_after = 1000, act_limit = 0.99, - act_noise = 0., + act_noise = 0.0, rng = rng, ) trajectory = CircularArraySARTTrajectory( capacity = 100_000, # replay buffer capacity - state = Vector{Int} => (ns, ), - action = Float32 => (na, ), + state = Vector{Int} => (ns,), + action = Float32 => (na,), ) agents = MADDPGManager( - Dict((player, Agent( - policy = NamedPolicy(player, deepcopy(policy)), - trajectory = deepcopy(trajectory), - )) for player in players(env) if player != chance_player(env)), + Dict( + ( + player, + Agent( + policy = NamedPolicy(player, deepcopy(policy)), + trajectory = deepcopy(trajectory), + ), + ) for player in players(env) if player != chance_player(env) + ), SARTS, # trace's type 512, # batch_size 100, # update_freq 0, # initial update_step - rng + rng, ) - stop_condition = StopAfterEpisode(100_000, is_show_progress=!haskey(ENV, "CI")) + stop_condition = StopAfterEpisode(100_000, is_show_progress = !haskey(ENV, "CI")) hook = KuhnMADDPGHook(1000, 0, [], []) Experiment(agents, wrapped_env, stop_condition, hook, "# play KuhnPoker with MADDPG") end @@ -117,7 +124,13 @@ end using Plots ex = E`JuliaRL_MADDPG_KuhnPoker` run(ex) -scatter(ex.hook.episode, ex.hook.results, xaxis=:log, xlabel="episode", ylabel="reward of player 1") +scatter( + ex.hook.episode, + ex.hook.results, + xaxis = :log, + xlabel = "episode", + ylabel = "reward of player 1", +) savefig("assets/JuliaRL_MADDPG_KuhnPoker.png") #hide diff --git a/docs/experiments/experiments/Policy Gradient/JuliaRL_MADDPG_SpeakerListener.jl b/docs/experiments/experiments/Policy Gradient/JuliaRL_MADDPG_SpeakerListener.jl index 6f00a1dc2..116e4f6d1 100644 --- a/docs/experiments/experiments/Policy Gradient/JuliaRL_MADDPG_SpeakerListener.jl +++ b/docs/experiments/experiments/Policy Gradient/JuliaRL_MADDPG_SpeakerListener.jl @@ -43,53 +43,51 @@ function RL.Experiment( ::Val{:MADDPG}, ::Val{:SpeakerListener}, ::Nothing; - seed=123, + seed = 123, ) rng = StableRNG(seed) env = SpeakerListenerEnv(max_steps = 25) init = glorot_uniform(rng) - critic_dim = sum(length(state(env, p)) + length(action_space(env, p)) for p in (:Speaker, :Listener)) + critic_dim = sum( + length(state(env, p)) + length(action_space(env, p)) for p in (:Speaker, :Listener) + ) create_actor(player) = Chain( Dense(length(state(env, player)), 64, relu; init = init), Dense(64, 64, relu; init = init), - Dense(64, length(action_space(env, player)); init = init) - ) + Dense(64, length(action_space(env, player)); init = init), + ) create_critic(critic_dim) = Chain( Dense(critic_dim, 64, relu; init = init), Dense(64, 64, relu; init = init), Dense(64, 1; init = init), - ) + ) create_policy(player) = DDPGPolicy( - behavior_actor = NeuralNetworkApproximator( - model = create_actor(player), - optimizer = Flux.Optimise.Optimiser(ClipNorm(0.5), ADAM(1e-2)), - ), - behavior_critic = NeuralNetworkApproximator( - model = create_critic(critic_dim), - optimizer = Flux.Optimise.Optimiser(ClipNorm(0.5), ADAM(1e-2)), - ), - target_actor = NeuralNetworkApproximator( - model = create_actor(player), - ), - target_critic = NeuralNetworkApproximator( - model = create_critic(critic_dim), - ), - γ = 0.95f0, - ρ = 0.99f0, - na = length(action_space(env, player)), - start_steps = 0, - start_policy = nothing, - update_after = 512 * env.max_steps, # batch_size * env.max_steps - act_limit = 1.0, - act_noise = 0., - ) + behavior_actor = NeuralNetworkApproximator( + model = create_actor(player), + optimizer = Flux.Optimise.Optimiser(ClipNorm(0.5), ADAM(1e-2)), + ), + 
behavior_critic = NeuralNetworkApproximator( + model = create_critic(critic_dim), + optimizer = Flux.Optimise.Optimiser(ClipNorm(0.5), ADAM(1e-2)), + ), + target_actor = NeuralNetworkApproximator(model = create_actor(player)), + target_critic = NeuralNetworkApproximator(model = create_critic(critic_dim)), + γ = 0.95f0, + ρ = 0.99f0, + na = length(action_space(env, player)), + start_steps = 0, + start_policy = nothing, + update_after = 512 * env.max_steps, # batch_size * env.max_steps + act_limit = 1.0, + act_noise = 0.0, + ) create_trajectory(player) = CircularArraySARTTrajectory( - capacity = 1_000_000, # replay buffer capacity - state = Vector{Float64} => (length(state(env, player)), ), - action = Vector{Float64} => (length(action_space(env, player)), ), - ) + capacity = 1_000_000, # replay buffer capacity + state = Vector{Float64} => (length(state(env, player)),), + action = Vector{Float64} => (length(action_space(env, player)),), + ) agents = MADDPGManager( Dict( @@ -102,10 +100,10 @@ function RL.Experiment( 512, # batch_size 100, # update_freq 0, # initial update_step - rng + rng, ) - stop_condition = StopAfterEpisode(8_000, is_show_progress=!haskey(ENV, "CI")) + stop_condition = StopAfterEpisode(8_000, is_show_progress = !haskey(ENV, "CI")) hook = MeanRewardHook(0, 800, 100, [], []) Experiment(agents, env, stop_condition, hook, "# play SpeakerListener with MADDPG") end @@ -114,7 +112,12 @@ end using Plots ex = E`JuliaRL_MADDPG_SpeakerListener` run(ex) -plot(ex.hook.episodes, ex.hook.mean_rewards, xlabel="episode", ylabel="mean episode reward") +plot( + ex.hook.episodes, + ex.hook.mean_rewards, + xlabel = "episode", + ylabel = "mean episode reward", +) savefig("assets/JuliaRL_MADDPG_SpeakerListenerEnv.png") #hide diff --git a/docs/experiments/experiments/Policy Gradient/JuliaRL_PPO_CartPole.jl b/docs/experiments/experiments/Policy Gradient/JuliaRL_PPO_CartPole.jl index 45cdd2e13..cbc0e3340 100644 --- a/docs/experiments/experiments/Policy Gradient/JuliaRL_PPO_CartPole.jl +++ b/docs/experiments/experiments/Policy Gradient/JuliaRL_PPO_CartPole.jl @@ -62,7 +62,7 @@ function RL.Experiment( ), ) - stop_condition = StopAfterStep(10_000, is_show_progress=!haskey(ENV, "CI")) + stop_condition = StopAfterStep(10_000, is_show_progress = !haskey(ENV, "CI")) hook = TotalBatchRewardPerEpisode(N_ENV) Experiment(agent, env, stop_condition, hook, "# PPO with CartPole") end @@ -76,7 +76,7 @@ run(ex) n = minimum(map(length, ex.hook.rewards)) m = mean([@view(x[1:n]) for x in ex.hook.rewards]) s = std([@view(x[1:n]) for x in ex.hook.rewards]) -plot(m,ribbon=s) +plot(m, ribbon = s) savefig("assets/JuliaRL_PPO_CartPole.png") #hide # ![](assets/JuliaRL_PPO_CartPole.png) diff --git a/docs/experiments/experiments/Policy Gradient/JuliaRL_PPO_Pendulum.jl b/docs/experiments/experiments/Policy Gradient/JuliaRL_PPO_Pendulum.jl index 80625e104..1fd85f10c 100644 --- a/docs/experiments/experiments/Policy Gradient/JuliaRL_PPO_Pendulum.jl +++ b/docs/experiments/experiments/Policy Gradient/JuliaRL_PPO_Pendulum.jl @@ -32,7 +32,8 @@ function RL.Experiment( UPDATE_FREQ = 2048 env = MultiThreadEnv([ PendulumEnv(T = Float32, rng = StableRNG(hash(seed + i))) |> - env -> ActionTransformedEnv(env, action_mapping = x -> clamp(x * 2, low, high)) for i in 1:N_ENV + env -> ActionTransformedEnv(env, action_mapping = x -> clamp(x * 2, low, high)) + for i in 1:N_ENV ]) init = glorot_uniform(rng) @@ -78,7 +79,7 @@ function RL.Experiment( ), ) - stop_condition = StopAfterStep(50_000, is_show_progress=!haskey(ENV, "CI")) + stop_condition = 
StopAfterStep(50_000, is_show_progress = !haskey(ENV, "CI")) hook = TotalBatchRewardPerEpisode(N_ENV) Experiment(agent, env, stop_condition, hook, "# Play Pendulum with PPO") end @@ -92,7 +93,7 @@ run(ex) n = minimum(map(length, ex.hook.rewards)) m = mean([@view(x[1:n]) for x in ex.hook.rewards]) s = std([@view(x[1:n]) for x in ex.hook.rewards]) -plot(m,ribbon=s) +plot(m, ribbon = s) savefig("assets/JuliaRL_PPO_Pendulum.png") #hide # ![](assets/JuliaRL_PPO_Pendulum.png) diff --git a/docs/experiments/experiments/Policy Gradient/JuliaRL_SAC_Pendulum.jl b/docs/experiments/experiments/Policy Gradient/JuliaRL_SAC_Pendulum.jl index c15e3c07d..ffd1b93eb 100644 --- a/docs/experiments/experiments/Policy Gradient/JuliaRL_SAC_Pendulum.jl +++ b/docs/experiments/experiments/Policy Gradient/JuliaRL_SAC_Pendulum.jl @@ -38,12 +38,11 @@ function RL.Experiment( create_policy_net() = NeuralNetworkApproximator( model = GaussianNetwork( - pre = Chain( - Dense(ns, 30, relu), - Dense(30, 30, relu), - ), + pre = Chain(Dense(ns, 30, relu), Dense(30, 30, relu)), μ = Chain(Dense(30, na, init = init)), - logσ = Chain(Dense(30, na, x -> clamp.(x, typeof(x)(-10), typeof(x)(2)), init = init)), + logσ = Chain( + Dense(30, na, x -> clamp.(x, typeof(x)(-10), typeof(x)(2)), init = init), + ), ), optimizer = ADAM(0.003), ) @@ -69,7 +68,7 @@ function RL.Experiment( α = 0.2f0, batch_size = 64, start_steps = 1000, - start_policy = RandomPolicy(Space([-1.0..1.0 for _ in 1:na]); rng = rng), + start_policy = RandomPolicy(Space([-1.0 .. 1.0 for _ in 1:na]); rng = rng), update_after = 1000, update_freq = 1, automatic_entropy_tuning = true, @@ -84,7 +83,7 @@ function RL.Experiment( ), ) - stop_condition = StopAfterStep(10_000, is_show_progress=!haskey(ENV, "CI")) + stop_condition = StopAfterStep(10_000, is_show_progress = !haskey(ENV, "CI")) hook = TotalRewardPerEpisode() Experiment(agent, env, stop_condition, hook, "# Play Pendulum with SAC") end diff --git a/docs/experiments/experiments/Policy Gradient/JuliaRL_TD3_Pendulum.jl b/docs/experiments/experiments/Policy Gradient/JuliaRL_TD3_Pendulum.jl index 49bc26c64..5d24117b4 100644 --- a/docs/experiments/experiments/Policy Gradient/JuliaRL_TD3_Pendulum.jl +++ b/docs/experiments/experiments/Policy Gradient/JuliaRL_TD3_Pendulum.jl @@ -69,7 +69,7 @@ function RL.Experiment( ρ = 0.99f0, batch_size = 64, start_steps = 1000, - start_policy = RandomPolicy(-1.0..1.0; rng = rng), + start_policy = RandomPolicy(-1.0 .. 
1.0; rng = rng), update_after = 1000, update_freq = 1, policy_freq = 2, @@ -86,7 +86,7 @@ function RL.Experiment( ), ) - stop_condition = StopAfterStep(10_000, is_show_progress=!haskey(ENV, "CI")) + stop_condition = StopAfterStep(10_000, is_show_progress = !haskey(ENV, "CI")) hook = TotalRewardPerEpisode() Experiment(agent, env, stop_condition, hook, "# Play Pendulum with TD3") end diff --git a/docs/experiments/experiments/Policy Gradient/JuliaRL_VMPO_CartPole.jl b/docs/experiments/experiments/Policy Gradient/JuliaRL_VMPO_CartPole.jl index 506bba5d1..005c29f33 100644 --- a/docs/experiments/experiments/Policy Gradient/JuliaRL_VMPO_CartPole.jl +++ b/docs/experiments/experiments/Policy Gradient/JuliaRL_VMPO_CartPole.jl @@ -55,7 +55,7 @@ function RL.Experiment( ), ) - stop_condition = StopAfterStep(50_000, is_show_progress=!haskey(ENV, "CI")) + stop_condition = StopAfterStep(50_000, is_show_progress = !haskey(ENV, "CI")) hook = TotalRewardPerEpisode() Experiment(agent, env, stop_condition, hook, "# VMPO with CartPole") diff --git a/docs/experiments/experiments/Policy Gradient/JuliaRL_VPG_CartPole.jl b/docs/experiments/experiments/Policy Gradient/JuliaRL_VPG_CartPole.jl index 87130f3dc..8cfbe125d 100644 --- a/docs/experiments/experiments/Policy Gradient/JuliaRL_VPG_CartPole.jl +++ b/docs/experiments/experiments/Policy Gradient/JuliaRL_VPG_CartPole.jl @@ -49,7 +49,7 @@ function RL.Experiment( ), trajectory = ElasticSARTTrajectory(state = Vector{Float32} => (ns,)), ) - stop_condition = StopAfterEpisode(500, is_show_progress=!haskey(ENV, "CI")) + stop_condition = StopAfterEpisode(500, is_show_progress = !haskey(ENV, "CI")) hook = TotalRewardPerEpisode() description = "# Play CartPole with VPG" diff --git a/docs/experiments/experiments/Policy Gradient/rlpyt_A2C_Atari.jl b/docs/experiments/experiments/Policy Gradient/rlpyt_A2C_Atari.jl index 7bf2add71..56eb5d90e 100644 --- a/docs/experiments/experiments/Policy Gradient/rlpyt_A2C_Atari.jl +++ b/docs/experiments/experiments/Policy Gradient/rlpyt_A2C_Atari.jl @@ -83,7 +83,7 @@ function RL.Experiment( hook = ComposedHook( total_batch_reward_per_episode, batch_steps_per_episode, - DoEveryNStep(;n=UPDATE_FREQ) do t, agent, env + DoEveryNStep(; n = UPDATE_FREQ) do t, agent, env learner = agent.policy.policy.learner with_logger(lg) do @info "training" loss = learner.loss actor_loss = learner.actor_loss critic_loss = @@ -94,20 +94,22 @@ function RL.Experiment( DoEveryNStep() do t, agent, env with_logger(lg) do rewards = [ - total_batch_reward_per_episode.rewards[i][end] for i in 1:length(env) if is_terminated(env[i]) + total_batch_reward_per_episode.rewards[i][end] for + i in 1:length(env) if is_terminated(env[i]) ] if length(rewards) > 0 @info "training" rewards = mean(rewards) log_step_increment = 0 end steps = [ - batch_steps_per_episode.steps[i][end] for i in 1:length(env) if is_terminated(env[i]) + batch_steps_per_episode.steps[i][end] for + i in 1:length(env) if is_terminated(env[i]) ] if length(steps) > 0 @info "training" steps = mean(steps) log_step_increment = 0 end end end, - DoEveryNStep(;n=EVALUATION_FREQ) do t, agent, env + DoEveryNStep(; n = EVALUATION_FREQ) do t, agent, env @info "evaluating agent at $t step..." 
h = TotalBatchOriginalRewardPerEpisode(N_ENV) s = @elapsed run( diff --git a/docs/experiments/experiments/Policy Gradient/rlpyt_PPO_Atari.jl b/docs/experiments/experiments/Policy Gradient/rlpyt_PPO_Atari.jl index e9eb574a3..4b33dc27f 100644 --- a/docs/experiments/experiments/Policy Gradient/rlpyt_PPO_Atari.jl +++ b/docs/experiments/experiments/Policy Gradient/rlpyt_PPO_Atari.jl @@ -85,7 +85,7 @@ function RL.Experiment( hook = ComposedHook( total_batch_reward_per_episode, batch_steps_per_episode, - DoEveryNStep(;n=UPDATE_FREQ) do t, agent, env + DoEveryNStep(; n = UPDATE_FREQ) do t, agent, env p = agent.policy with_logger(lg) do @info "training" loss = mean(p.loss) actor_loss = mean(p.actor_loss) critic_loss = @@ -93,7 +93,7 @@ function RL.Experiment( mean(p.norm) log_step_increment = UPDATE_FREQ end end, - DoEveryNStep(;n=UPDATE_FREQ) do t, agent, env + DoEveryNStep(; n = UPDATE_FREQ) do t, agent, env decay = (N_TRAINING_STEPS - t) / N_TRAINING_STEPS agent.policy.approximator.optimizer.eta = INIT_LEARNING_RATE * decay agent.policy.clip_range = INIT_CLIP_RANGE * Float32(decay) @@ -101,20 +101,22 @@ function RL.Experiment( DoEveryNStep() do t, agent, env with_logger(lg) do rewards = [ - total_batch_reward_per_episode.rewards[i][end] for i in 1:length(env) if is_terminated(env[i]) + total_batch_reward_per_episode.rewards[i][end] for + i in 1:length(env) if is_terminated(env[i]) ] if length(rewards) > 0 @info "training" rewards = mean(rewards) log_step_increment = 0 end steps = [ - batch_steps_per_episode.steps[i][end] for i in 1:length(env) if is_terminated(env[i]) + batch_steps_per_episode.steps[i][end] for + i in 1:length(env) if is_terminated(env[i]) ] if length(steps) > 0 @info "training" steps = mean(steps) log_step_increment = 0 end end end, - DoEveryNStep(;n=EVALUATION_FREQ) do t, agent, env + DoEveryNStep(; n = EVALUATION_FREQ) do t, agent, env @info "evaluating agent at $t step..." ## switch to GreedyExplorer? h = TotalBatchRewardPerEpisode(N_ENV) diff --git a/docs/experiments/experiments/Search/JuliaRL_Minimax_OpenSpiel.jl b/docs/experiments/experiments/Search/JuliaRL_Minimax_OpenSpiel.jl index f31098c00..a6bfc1d4b 100644 --- a/docs/experiments/experiments/Search/JuliaRL_Minimax_OpenSpiel.jl +++ b/docs/experiments/experiments/Search/JuliaRL_Minimax_OpenSpiel.jl @@ -18,7 +18,13 @@ function RL.Experiment(::Val{:JuliaRL}, ::Val{:Minimax}, ::Val{:OpenSpiel}, game ) hooks = MultiAgentHook(0 => TotalRewardPerEpisode(), 1 => TotalRewardPerEpisode()) description = "# Play `$game` in OpenSpiel with Minimax" - Experiment(agents, env, StopAfterEpisode(1, is_show_progress=!haskey(ENV, "CI")), hooks, description) + Experiment( + agents, + env, + StopAfterEpisode(1, is_show_progress = !haskey(ENV, "CI")), + hooks, + description, + ) end using Plots diff --git a/docs/homepage/utils.jl b/docs/homepage/utils.jl index 816810d71..64787a91a 100644 --- a/docs/homepage/utils.jl +++ b/docs/homepage/utils.jl @@ -5,7 +5,7 @@ html(s) = "\n~~~$s~~~\n" function hfun_adddescription() d = locvar(:description) - isnothing(d) ? "" : F.fd2html(d, internal=true) + isnothing(d) ? 
"" : F.fd2html(d, internal = true) end function hfun_frontmatter() @@ -28,7 +28,7 @@ function hfun_byline() if isnothing(fm) "" else - "" + "" end end @@ -62,7 +62,7 @@ function hfun_appendix() if isfile(bib_in_cur_folder) bib_resolved = F.parse_rpath("/" * bib_in_cur_folder) else - bib_resolved = F.parse_rpath(bib; canonical=false, code=true) + bib_resolved = F.parse_rpath(bib; canonical = false, code = true) end bib = "" end @@ -74,7 +74,7 @@ function hfun_appendix() """ end -function lx_dcite(lxc,_) +function lx_dcite(lxc, _) content = F.content(lxc.braces[1]) "" |> html end @@ -92,7 +92,7 @@ end """ Possible layouts: """ -function lx_dfig(lxc,lxd) +function lx_dfig(lxc, lxd) content = F.content(lxc.braces[1]) info = split(content, ';') layout = info[1] @@ -111,7 +111,7 @@ function lx_dfig(lxc,lxd) end # (case 3) assume it is generated by code - src = F.parse_rpath(src; canonical=false, code=true) + src = F.parse_rpath(src; canonical = false, code = true) # !!! directly take from `lx_fig` in Franklin.jl fdir, fext = splitext(src) @@ -122,11 +122,10 @@ function lx_dfig(lxc,lxd) # then in both cases there can be a relative path set but the user may mean # that it's in the subfolder /output/ (if generated by code) so should look # both in the relpath and if not found and if /output/ not already last dir - candext = ifelse(isempty(fext), - (".png", ".jpeg", ".jpg", ".svg", ".gif"), (fext,)) - for ext ∈ candext + candext = ifelse(isempty(fext), (".png", ".jpeg", ".jpg", ".svg", ".gif"), (fext,)) + for ext in candext candpath = fdir * ext - syspath = joinpath(F.PATHS[:site], split(candpath, '/')...) + syspath = joinpath(F.PATHS[:site], split(candpath, '/')...) isfile(syspath) && return dfigure(layout, candpath, caption) end # now try in the output dir just in case (provided we weren't already @@ -134,20 +133,20 @@ function lx_dfig(lxc,lxd) p1, p2 = splitdir(fdir) @debug "TEST" p1 p2 if splitdir(p1)[2] != "output" - for ext ∈ candext + for ext in candext candpath = joinpath(splitdir(p1)[1], "output", p2 * ext) - syspath = joinpath(F.PATHS[:site], split(candpath, '/')...) + syspath = joinpath(F.PATHS[:site], split(candpath, '/')...) isfile(syspath) && return dfigure(layout, candpath, caption) end end end -function lx_aside(lxc,lxd) +function lx_aside(lxc, lxd) content = F.reprocess(F.content(lxc.braces[1]), lxd) "" |> html end -function lx_footnote(lxc,lxd) +function lx_footnote(lxc, lxd) content = F.reprocess(F.content(lxc.braces[1]), lxd) # workaround if startswith(content, "

") @@ -156,7 +155,7 @@ function lx_footnote(lxc,lxd) "$content" |> html end -function lx_appendix(lxc,lxd) +function lx_appendix(lxc, lxd) content = F.reprocess(F.content(lxc.braces[1]), lxd) "$content" |> html -end \ No newline at end of file +end diff --git a/docs/make.jl b/docs/make.jl index 8f947388d..4d7fd750d 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -15,11 +15,7 @@ end experiments, postprocess_cb, experiments_assets = makedemos("experiments") -assets = [ - "assets/favicon.ico", - "assets/custom.css", - experiments_assets -] +assets = ["assets/favicon.ico", "assets/custom.css", experiments_assets] makedocs( modules = [ @@ -56,7 +52,7 @@ makedocs( "RLZoo" => "rlzoo.md", "RLDatasets" => "rldatasets.md", ], - ] + ], ) postprocess_cb() diff --git a/src/DistributedReinforcementLearning/src/actor_model.jl b/src/DistributedReinforcementLearning/src/actor_model.jl index 3a2414c5c..af6b05214 100644 --- a/src/DistributedReinforcementLearning/src/actor_model.jl +++ b/src/DistributedReinforcementLearning/src/actor_model.jl @@ -1,19 +1,12 @@ -export AbstractMessage, - StartMsg, - StopMsg, - PingMsg, - PongMsg, - ProxyMsg, - actor, - self +export AbstractMessage, StartMsg, StopMsg, PingMsg, PongMsg, ProxyMsg, actor, self abstract type AbstractMessage end -struct StartMsg{A, K} <: AbstractMessage +struct StartMsg{A,K} <: AbstractMessage args::A kwargs::K - StartMsg(args...;kwargs...) = new{typeof(args), typeof(kwargs)}(args, kwargs) + StartMsg(args...; kwargs...) = new{typeof(args),typeof(kwargs)}(args, kwargs) end struct StopMsg <: AbstractMessage end @@ -45,9 +38,9 @@ const DEFAULT_MAILBOX_SIZE = 32 Create a task to handle messages one-by-one by calling `f(msg)`. A mailbox (`RemoteChannel`) is returned. """ -function actor(f;sz=DEFAULT_MAILBOX_SIZE) +function actor(f; sz = DEFAULT_MAILBOX_SIZE) RemoteChannel() do - Channel(sz;spawn=true) do ch + Channel(sz; spawn = true) do ch task_local_storage("MAILBOX", RemoteChannel(() -> ch)) while true msg = take!(ch) diff --git a/src/DistributedReinforcementLearning/src/core.jl b/src/DistributedReinforcementLearning/src/core.jl index 44ed67937..34a04571a 100644 --- a/src/DistributedReinforcementLearning/src/core.jl +++ b/src/DistributedReinforcementLearning/src/core.jl @@ -51,7 +51,7 @@ Base.@kwdef struct Trainer{P,S} sealer::S = deepcopy end -Trainer(p) = Trainer(;policy=p) +Trainer(p) = Trainer(; policy = p) function (trainer::Trainer)(msg::BatchDataMsg) update!(trainer.policy, msg.data) @@ -94,7 +94,7 @@ mutable struct Worker end function (w::Worker)(msg::StartMsg) - w.experiment = w.init(msg.args...;msg.kwargs...) + w.experiment = w.init(msg.args...; msg.kwargs...) 
w.task = Threads.@spawn run(w.experiment) end @@ -128,7 +128,7 @@ end function (wp::WorkerProxy)(::FetchParamMsg) if !wp.is_fetch_msg_sent[] put!(wp.target, FetchParamMsg(self())) - wp.is_fetch_msg_sent[] = true + wp.is_fetch_msg_sent[] = true end end @@ -172,9 +172,12 @@ function (orc::Orchestrator)(msg::InsertTrajectoryMsg) put!(orc.trajectory_proxy, BatchSampleMsg(orc.trainer)) L.n_sample += 1 if L.n_sample == (L.n_load + 1) * L.sample_load_ratio - put!(orc.trajectory_proxy, ProxyMsg(to=orc.trainer, msg=FetchParamMsg(orc.worker))) + put!( + orc.trajectory_proxy, + ProxyMsg(to = orc.trainer, msg = FetchParamMsg(orc.worker)), + ) L.n_load += 1 end end end -end \ No newline at end of file +end diff --git a/src/DistributedReinforcementLearning/src/extensions.jl b/src/DistributedReinforcementLearning/src/extensions.jl index efeda428b..1c6f32e9d 100644 --- a/src/DistributedReinforcementLearning/src/extensions.jl +++ b/src/DistributedReinforcementLearning/src/extensions.jl @@ -76,4 +76,4 @@ function (hook::FetchParamsHook)(::PostActStage, agent, env) end end end -end \ No newline at end of file +end diff --git a/src/DistributedReinforcementLearning/test/actor.jl b/src/DistributedReinforcementLearning/test/actor.jl index 34881c5b1..e2cc9ae87 100644 --- a/src/DistributedReinforcementLearning/test/actor.jl +++ b/src/DistributedReinforcementLearning/test/actor.jl @@ -1,53 +1,53 @@ @testset "basic tests" begin -Base.@kwdef mutable struct TestActor - state::Union{Nothing, Int} = nothing -end + Base.@kwdef mutable struct TestActor + state::Union{Nothing,Int} = nothing + end -struct CurrentStateMsg <: AbstractMessage - state -end + struct CurrentStateMsg <: AbstractMessage + state::Any + end -Base.@kwdef struct ReadStateMsg <: AbstractMessage - from = self() -end + Base.@kwdef struct ReadStateMsg <: AbstractMessage + from = self() + end -struct IncMsg <: AbstractMessage end -struct DecMsg <: AbstractMessage end + struct IncMsg <: AbstractMessage end + struct DecMsg <: AbstractMessage end -(x::TestActor)(msg::StartMsg{Tuple{Int}}) = x.state = msg.args[1] -(x::TestActor)(msg::StopMsg) = x.state = nothing -(x::TestActor)(::IncMsg) = x.state += 1 -(x::TestActor)(::DecMsg) = x.state -= 1 -(x::TestActor)(msg::ReadStateMsg) = put!(msg.from, CurrentStateMsg(x.state)) + (x::TestActor)(msg::StartMsg{Tuple{Int}}) = x.state = msg.args[1] + (x::TestActor)(msg::StopMsg) = x.state = nothing + (x::TestActor)(::IncMsg) = x.state += 1 + (x::TestActor)(::DecMsg) = x.state -= 1 + (x::TestActor)(msg::ReadStateMsg) = put!(msg.from, CurrentStateMsg(x.state)) -x = actor(TestActor()) -put!(x, StartMsg(0)) + x = actor(TestActor()) + put!(x, StartMsg(0)) -put!(x, ReadStateMsg()) -@test take!(self()).state == 0 + put!(x, ReadStateMsg()) + @test take!(self()).state == 0 -@sync begin - for _ in 1:100 - Threads.@spawn put!(x, IncMsg()) - Threads.@spawn put!(x, DecMsg()) - end - for _ in 1:10 - for _ in 1:10 + @sync begin + for _ in 1:100 Threads.@spawn put!(x, IncMsg()) + Threads.@spawn put!(x, DecMsg()) end for _ in 1:10 - Threads.@spawn put!(x, DecMsg()) + for _ in 1:10 + Threads.@spawn put!(x, IncMsg()) + end + for _ in 1:10 + Threads.@spawn put!(x, DecMsg()) + end end end -end -put!(x, ReadStateMsg()) -@test take!(self()).state == 0 + put!(x, ReadStateMsg()) + @test take!(self()).state == 0 -y = actor(TestActor()) -put!(x, ProxyMsg(;to=y,msg=StartMsg(0))) -put!(x, ProxyMsg(;to=y,msg=ReadStateMsg())) -@test take!(self()).state == 0 + y = actor(TestActor()) + put!(x, ProxyMsg(; to = y, msg = StartMsg(0))) + put!(x, ProxyMsg(; to 
= y, msg = ReadStateMsg())) + @test take!(self()).state == 0 -end \ No newline at end of file +end diff --git a/src/DistributedReinforcementLearning/test/core.jl b/src/DistributedReinforcementLearning/test/core.jl index 25d5c2c49..d4a196f37 100644 --- a/src/DistributedReinforcementLearning/test/core.jl +++ b/src/DistributedReinforcementLearning/test/core.jl @@ -1,181 +1,202 @@ @testset "core.jl" begin -@testset "Trainer" begin - _trainer = Trainer(; - policy=BasicDQNLearner( - approximator = NeuralNetworkApproximator( - model = Chain( - Dense(4, 128, relu; initW = glorot_uniform), - Dense(128, 128, relu; initW = glorot_uniform), - Dense(128, 2; initW = glorot_uniform), - ) |> cpu, - optimizer = ADAM(), + @testset "Trainer" begin + _trainer = Trainer(; + policy = BasicDQNLearner( + approximator = NeuralNetworkApproximator( + model = Chain( + Dense(4, 128, relu; initW = glorot_uniform), + Dense(128, 128, relu; initW = glorot_uniform), + Dense(128, 2; initW = glorot_uniform), + ) |> cpu, + optimizer = ADAM(), + ), + loss_func = huber_loss, ), - loss_func = huber_loss, ) - ) - trainer = actor(_trainer) + trainer = actor(_trainer) - put!(trainer, FetchParamMsg()) - ps = take!(self()) - original_sum = sum(sum, ps.data) + put!(trainer, FetchParamMsg()) + ps = take!(self()) + original_sum = sum(sum, ps.data) - for x in ps.data - fill!(x, 0.) - end + for x in ps.data + fill!(x, 0.0) + end - put!(trainer, FetchParamMsg()) - ps = take!(self()) - new_sum = sum(sum, ps.data) + put!(trainer, FetchParamMsg()) + ps = take!(self()) + new_sum = sum(sum, ps.data) - # make sure no state sharing between messages - @test original_sum == new_sum + # make sure no state sharing between messages + @test original_sum == new_sum - batch_data = ( - state = rand(4, 32), - action = rand(1:2, 32), - reward = rand(32), - terminal = rand(Bool, 32), - next_state = rand(4,32), - next_action = rand(1:2, 32) - ) + batch_data = ( + state = rand(4, 32), + action = rand(1:2, 32), + reward = rand(32), + terminal = rand(Bool, 32), + next_state = rand(4, 32), + next_action = rand(1:2, 32), + ) - put!(trainer, BatchDataMsg(batch_data)) + put!(trainer, BatchDataMsg(batch_data)) - put!(trainer, FetchParamMsg()) - ps = take!(self()) - updated_sum = sum(sum, ps.data) - @test original_sum != updated_sum -end + put!(trainer, FetchParamMsg()) + ps = take!(self()) + updated_sum = sum(sum, ps.data) + @test original_sum != updated_sum + end -@testset "TrajectoryManager" begin - _trajectory_proxy = TrajectoryManager( - trajectory = CircularSARTSATrajectory(;capacity=5, state_type=Any, ), - sampler = UniformBatchSampler(3), - inserter = NStepInserter(), - ) + @testset "TrajectoryManager" begin + _trajectory_proxy = TrajectoryManager( + trajectory = CircularSARTSATrajectory(; capacity = 5, state_type = Any), + sampler = UniformBatchSampler(3), + inserter = NStepInserter(), + ) - trajectory_proxy = actor(_trajectory_proxy) + trajectory_proxy = actor(_trajectory_proxy) - # 1. init traj for testing - traj = CircularCompactSARTSATrajectory( - capacity = 2, - state_type = Float32, - state_size = (4,), - ) - push!(traj;state=rand(Float32, 4), action=rand(1:2)) - push!(traj;reward=rand(), terminal=rand(Bool),state=rand(Float32, 4), action=rand(1:2)) - push!(traj;reward=rand(), terminal=rand(Bool),state=rand(Float32, 4), action=rand(1:2)) + # 1. 
init traj for testing + traj = CircularCompactSARTSATrajectory( + capacity = 2, + state_type = Float32, + state_size = (4,), + ) + push!(traj; state = rand(Float32, 4), action = rand(1:2)) + push!( + traj; + reward = rand(), + terminal = rand(Bool), + state = rand(Float32, 4), + action = rand(1:2), + ) + push!( + traj; + reward = rand(), + terminal = rand(Bool), + state = rand(Float32, 4), + action = rand(1:2), + ) - # 2. insert - put!(trajectory_proxy, InsertTrajectoryMsg(deepcopy(traj))) #!!! we used deepcopy here + # 2. insert + put!(trajectory_proxy, InsertTrajectoryMsg(deepcopy(traj))) #!!! we used deepcopy here - # 3. make sure the above message is already been handled - put!(trajectory_proxy, PingMsg()) - take!(self()) + # 3. make sure the above message is already been handled + put!(trajectory_proxy, PingMsg()) + take!(self()) - # 4. test that updating traj will not affect data in trajectory_proxy - s_tp = _trajectory_proxy.trajectory[:state] - s_traj = traj[:state] + # 4. test that updating traj will not affect data in trajectory_proxy + s_tp = _trajectory_proxy.trajectory[:state] + s_traj = traj[:state] - @test s_tp[1] == s_traj[:, 1] + @test s_tp[1] == s_traj[:, 1] - push!(traj;reward=rand(), terminal=rand(Bool),state=rand(Float32, 4), action=rand(1:2)) + push!( + traj; + reward = rand(), + terminal = rand(Bool), + state = rand(Float32, 4), + action = rand(1:2), + ) - @test s_tp[1] != s_traj[:, 1] + @test s_tp[1] != s_traj[:, 1] - s = sample(_trajectory_proxy.trajectory, _trajectory_proxy.sampler) - fill!(s[:state], 0.) - @test any(x -> sum(x) == 0, s_tp) == false # make sure sample create an independent copy -end + s = sample(_trajectory_proxy.trajectory, _trajectory_proxy.sampler) + fill!(s[:state], 0.0) + @test any(x -> sum(x) == 0, s_tp) == false # make sure sample create an independent copy + end -@testset "Worker" begin - _worker = Worker() do worker_proxy - Experiment( - Agent( - policy = StaticPolicy( + @testset "Worker" begin + _worker = Worker() do worker_proxy + Experiment( + Agent( + policy = StaticPolicy( QBasedPolicy( - learner = BasicDQNLearner( - approximator = NeuralNetworkApproximator( - model = Chain( - Dense(4, 128, relu; initW = glorot_uniform), - Dense(128, 128, relu; initW = glorot_uniform), - Dense(128, 2; initW = glorot_uniform), - ) |> cpu, - optimizer = ADAM(), + learner = BasicDQNLearner( + approximator = NeuralNetworkApproximator( + model = Chain( + Dense(4, 128, relu; initW = glorot_uniform), + Dense(128, 128, relu; initW = glorot_uniform), + Dense(128, 2; initW = glorot_uniform), + ) |> cpu, + optimizer = ADAM(), + ), + loss_func = huber_loss, + ), + explorer = EpsilonGreedyExplorer( + kind = :exp, + ϵ_stable = 0.01, + decay_steps = 500, ), - loss_func = huber_loss, - ), - explorer = EpsilonGreedyExplorer( - kind = :exp, - ϵ_stable = 0.01, - decay_steps = 500, ), ), + trajectory = CircularCompactSARTSATrajectory( + capacity = 10, + state_type = Float32, + state_size = (4,), + ), ), - trajectory = CircularCompactSARTSATrajectory( - capacity = 10, - state_type = Float32, - state_size = (4,), + CartPoleEnv(; T = Float32), + ComposedStopCondition(StopAfterStep(1_000), StopSignal()), + ComposedHook( + UploadTrajectoryEveryNStep( + mailbox = worker_proxy, + n = 10, + sealer = x -> InsertTrajectoryMsg(deepcopy(x)), + ), + LoadParamsHook(), + TotalRewardPerEpisode(), ), - ), - CartPoleEnv(; T = Float32), - ComposedStopCondition( - StopAfterStep(1_000), - StopSignal(), - ), - ComposedHook( - UploadTrajectoryEveryNStep(mailbox=worker_proxy, n=10, sealer=x -> 
InsertTrajectoryMsg(deepcopy(x))), - LoadParamsHook(), - TotalRewardPerEpisode(), - ), - "experimenting..." - ) - end + "experimenting...", + ) + end - worker = actor(_worker) - tmp_mailbox = Channel(100) - put!(worker, StartMsg(tmp_mailbox)) -end - -@testset "WorkerProxy" begin - target = RemoteChannel(() -> Channel(10)) - workers = [RemoteChannel(()->Channel(10)) for _ in 1:10] - _wp = WorkerProxy(workers) - wp = actor(_wp) - - put!(wp, StartMsg(target)) - for w in workers - # @test take!(w).args[1] === wp - @test Distributed.channel_from_id(remoteref_id(take!(w).args[1])) === Distributed.channel_from_id(remoteref_id(wp)) + worker = actor(_worker) + tmp_mailbox = Channel(100) + put!(worker, StartMsg(tmp_mailbox)) end - msg = InsertTrajectoryMsg(1) - put!(wp, msg) - @test take!(target) === msg - - for w in workers - put!(wp, FetchParamMsg(w)) + @testset "WorkerProxy" begin + target = RemoteChannel(() -> Channel(10)) + workers = [RemoteChannel(() -> Channel(10)) for _ in 1:10] + _wp = WorkerProxy(workers) + wp = actor(_wp) + + put!(wp, StartMsg(target)) + for w in workers + # @test take!(w).args[1] === wp + @test Distributed.channel_from_id(remoteref_id(take!(w).args[1])) === + Distributed.channel_from_id(remoteref_id(wp)) + end + + msg = InsertTrajectoryMsg(1) + put!(wp, msg) + @test take!(target) === msg + + for w in workers + put!(wp, FetchParamMsg(w)) + end + # @test take!(target).from === wp + @test Distributed.channel_from_id(remoteref_id(take!(target).from)) === + Distributed.channel_from_id(remoteref_id(wp)) + + # make sure target only received one FetchParamMsg + msg = PingMsg() + put!(target, msg) + @test take!(target) === msg + + msg = LoadParamMsg([]) + put!(wp, msg) + for w in workers + @test take!(w) === msg + end end - # @test take!(target).from === wp - @test Distributed.channel_from_id(remoteref_id(take!(target).from)) === Distributed.channel_from_id(remoteref_id(wp)) - - # make sure target only received one FetchParamMsg - msg = PingMsg() - put!(target, msg) - @test take!(target) === msg - - msg = LoadParamMsg([]) - put!(wp, msg) - for w in workers - @test take!(w) === msg + + @testset "Orchestrator" begin + # TODO + # Add an integration test end -end -@testset "Orchestrator" begin - # TODO - # Add an integration test end - -end \ No newline at end of file diff --git a/src/DistributedReinforcementLearning/test/runtests.jl b/src/DistributedReinforcementLearning/test/runtests.jl index f9f572d55..a73cd0521 100644 --- a/src/DistributedReinforcementLearning/test/runtests.jl +++ b/src/DistributedReinforcementLearning/test/runtests.jl @@ -9,7 +9,7 @@ using Flux @testset "DistributedReinforcementLearning.jl" begin -include("actor.jl") -include("core.jl") + include("actor.jl") + include("core.jl") end diff --git a/src/ReinforcementLearningBase/src/CommonRLInterface.jl b/src/ReinforcementLearningBase/src/CommonRLInterface.jl index af86a2d83..28c73ec40 100644 --- a/src/ReinforcementLearningBase/src/CommonRLInterface.jl +++ b/src/ReinforcementLearningBase/src/CommonRLInterface.jl @@ -41,7 +41,8 @@ end # !!! 
may need to be extended by user CRL.@provide CRL.observe(env::CommonRLEnv) = state(env.env) -CRL.provided(::typeof(CRL.state), env::CommonRLEnv) = !isnothing(find_state_style(env.env, InternalState)) +CRL.provided(::typeof(CRL.state), env::CommonRLEnv) = + !isnothing(find_state_style(env.env, InternalState)) CRL.state(env::CommonRLEnv) = state(env.env, find_state_style(env.env, InternalState)) CRL.@provide CRL.clone(env::CommonRLEnv) = CommonRLEnv(copy(env.env)) @@ -94,4 +95,4 @@ ActionStyle(env::RLBaseEnv) = CRL.provided(CRL.valid_actions, env.env) ? FullActionSet() : MinimalActionSet() current_player(env::RLBaseEnv) = CRL.player(env.env) -players(env::RLBaseEnv) = CRL.players(env.env) \ No newline at end of file +players(env::RLBaseEnv) = CRL.players(env.env) diff --git a/src/ReinforcementLearningBase/src/interface.jl b/src/ReinforcementLearningBase/src/interface.jl index 7c7b16888..3f64d6ad7 100644 --- a/src/ReinforcementLearningBase/src/interface.jl +++ b/src/ReinforcementLearningBase/src/interface.jl @@ -410,12 +410,13 @@ Make an independent copy of `env`, !!! warning Only check the state of all players in the env. """ -function Base.:(==)(env1::T, env2::T) where T<:AbstractEnv +function Base.:(==)(env1::T, env2::T) where {T<:AbstractEnv} len = length(players(env1)) - len == length(players(env2)) && - all(state(env1, player) == state(env2, player) for player in players(env1)) + len == length(players(env2)) && + all(state(env1, player) == state(env2, player) for player in players(env1)) end -Base.hash(env::AbstractEnv, h::UInt) = hash([state(env, player) for player in players(env)], h) +Base.hash(env::AbstractEnv, h::UInt) = + hash([state(env, player) for player in players(env)], h) @api nameof(env::AbstractEnv) = nameof(typeof(env)) diff --git a/src/ReinforcementLearningBase/test/CommonRLInterface.jl b/src/ReinforcementLearningBase/test/CommonRLInterface.jl index fc38a102b..7b32dbe08 100644 --- a/src/ReinforcementLearningBase/test/CommonRLInterface.jl +++ b/src/ReinforcementLearningBase/test/CommonRLInterface.jl @@ -1,34 +1,34 @@ @testset "CommonRLInterface" begin -@testset "MDPEnv" begin - struct RLTestMDP <: MDP{Int, Int} end + @testset "MDPEnv" begin + struct RLTestMDP <: MDP{Int,Int} end - POMDPs.actions(m::RLTestMDP) = [-1, 1] - POMDPs.transition(m::RLTestMDP, s, a) = Deterministic(clamp(s + a, 1, 3)) - POMDPs.initialstate(m::RLTestMDP) = Deterministic(1) - POMDPs.isterminal(m::RLTestMDP, s) = s == 3 - POMDPs.reward(m::RLTestMDP, s, a, sp) = sp - POMDPs.states(m::RLTestMDP) = 1:3 + POMDPs.actions(m::RLTestMDP) = [-1, 1] + POMDPs.transition(m::RLTestMDP, s, a) = Deterministic(clamp(s + a, 1, 3)) + POMDPs.initialstate(m::RLTestMDP) = Deterministic(1) + POMDPs.isterminal(m::RLTestMDP, s) = s == 3 + POMDPs.reward(m::RLTestMDP, s, a, sp) = sp + POMDPs.states(m::RLTestMDP) = 1:3 - env = convert(RLBase.AbstractEnv, convert(CRL.AbstractEnv, RLTestMDP())) - RLBase.test_runnable!(env) -end + env = convert(RLBase.AbstractEnv, convert(CRL.AbstractEnv, RLTestMDP())) + RLBase.test_runnable!(env) + end -@testset "POMDPEnv" begin + @testset "POMDPEnv" begin - struct RLTestPOMDP <: POMDP{Int, Int, Int} end + struct RLTestPOMDP <: POMDP{Int,Int,Int} end - POMDPs.actions(m::RLTestPOMDP) = [-1, 1] - POMDPs.states(m::RLTestPOMDP) = 1:3 - POMDPs.transition(m::RLTestPOMDP, s, a) = Deterministic(clamp(s + a, 1, 3)) - POMDPs.observation(m::RLTestPOMDP, s, a, sp) = Deterministic(sp + 1) - POMDPs.initialstate(m::RLTestPOMDP) = Deterministic(1) - POMDPs.initialobs(m::RLTestPOMDP, s) = Deterministic(s + 1) - 
POMDPs.isterminal(m::RLTestPOMDP, s) = s == 3 - POMDPs.reward(m::RLTestPOMDP, s, a, sp) = sp - POMDPs.observations(m::RLTestPOMDP) = 2:4 + POMDPs.actions(m::RLTestPOMDP) = [-1, 1] + POMDPs.states(m::RLTestPOMDP) = 1:3 + POMDPs.transition(m::RLTestPOMDP, s, a) = Deterministic(clamp(s + a, 1, 3)) + POMDPs.observation(m::RLTestPOMDP, s, a, sp) = Deterministic(sp + 1) + POMDPs.initialstate(m::RLTestPOMDP) = Deterministic(1) + POMDPs.initialobs(m::RLTestPOMDP, s) = Deterministic(s + 1) + POMDPs.isterminal(m::RLTestPOMDP, s) = s == 3 + POMDPs.reward(m::RLTestPOMDP, s, a, sp) = sp + POMDPs.observations(m::RLTestPOMDP) = 2:4 - env = convert(RLBase.AbstractEnv, convert(CRL.AbstractEnv, RLTestPOMDP())) + env = convert(RLBase.AbstractEnv, convert(CRL.AbstractEnv, RLTestPOMDP())) - RLBase.test_runnable!(env) + RLBase.test_runnable!(env) + end end -end \ No newline at end of file diff --git a/src/ReinforcementLearningBase/test/runtests.jl b/src/ReinforcementLearningBase/test/runtests.jl index 6b44a29b8..d4f743f68 100644 --- a/src/ReinforcementLearningBase/test/runtests.jl +++ b/src/ReinforcementLearningBase/test/runtests.jl @@ -8,5 +8,5 @@ using POMDPs using POMDPModelTools: Deterministic @testset "ReinforcementLearningBase" begin -include("CommonRLInterface.jl") -end \ No newline at end of file + include("CommonRLInterface.jl") +end diff --git a/src/ReinforcementLearningCore/src/core/experiment.jl b/src/ReinforcementLearningCore/src/core/experiment.jl index 8240f36e3..f753f6d62 100644 --- a/src/ReinforcementLearningCore/src/core/experiment.jl +++ b/src/ReinforcementLearningCore/src/core/experiment.jl @@ -24,7 +24,7 @@ end function Base.show(io::IO, x::Experiment) display(Markdown.parse(x.description)) - AbstractTrees.print_tree(io, StructTree(x), maxdepth=get(io, :max_depth, 10)) + AbstractTrees.print_tree(io, StructTree(x), maxdepth = get(io, :max_depth, 10)) end macro experiment_cmd(s) @@ -51,7 +51,7 @@ function Experiment(s::String) ) end -function Base.run(x::Experiment; describe::Bool=true) +function Base.run(x::Experiment; describe::Bool = true) describe && display(Markdown.parse(x.description)) run(x.policy, x.env, x.stop_condition, x.hook) x diff --git a/src/ReinforcementLearningCore/src/core/hooks.jl b/src/ReinforcementLearningCore/src/core/hooks.jl index 11c65e686..a8a58be3e 100644 --- a/src/ReinforcementLearningCore/src/core/hooks.jl +++ b/src/ReinforcementLearningCore/src/core/hooks.jl @@ -13,7 +13,7 @@ export AbstractHook, UploadTrajectoryEveryNStep, MultiAgentHook -using UnicodePlots:lineplot, lineplot! +using UnicodePlots: lineplot, lineplot! using Statistics """ @@ -155,7 +155,14 @@ end function (hook::TotalRewardPerEpisode)(::PostExperimentStage, agent, env) if hook.is_display_on_exit - println(lineplot(hook.rewards, title="Total reward per episode", xlabel="Episode", ylabel="Score")) + println( + lineplot( + hook.rewards, + title = "Total reward per episode", + xlabel = "Episode", + ylabel = "Score", + ), + ) end end @@ -178,8 +185,12 @@ which return a `Vector` of rewards (a typical case with `MultiThreadEnv`). If `is_display_on_exit` is set to `true`, a ribbon plot will be shown to reflect the mean and std of rewards. 
""" -function TotalBatchRewardPerEpisode(batch_size::Int; is_display_on_exit=true) - TotalBatchRewardPerEpisode([Float64[] for _ in 1:batch_size], zeros(batch_size), is_display_on_exit) +function TotalBatchRewardPerEpisode(batch_size::Int; is_display_on_exit = true) + TotalBatchRewardPerEpisode( + [Float64[] for _ in 1:batch_size], + zeros(batch_size), + is_display_on_exit, + ) end function (hook::TotalBatchRewardPerEpisode)(::PostActStage, agent, env) @@ -198,7 +209,12 @@ function (hook::TotalBatchRewardPerEpisode)(::PostExperimentStage, agent, env) n = minimum(map(length, hook.rewards)) m = mean([@view(x[1:n]) for x in hook.rewards]) s = std([@view(x[1:n]) for x in hook.rewards]) - p = lineplot(m, title="Avg total reward per episode", xlabel="Episode", ylabel="Score") + p = lineplot( + m, + title = "Avg total reward per episode", + xlabel = "Episode", + ylabel = "Score", + ) lineplot!(p, m .- s) lineplot!(p, m .+ s) println(p) @@ -288,8 +304,7 @@ end Execute `f(t, agent, env)` every `n` episode. `t` is a counter of episodes. """ -mutable struct DoEveryNEpisode{S<:Union{PreEpisodeStage,PostEpisodeStage},F} <: - AbstractHook +mutable struct DoEveryNEpisode{S<:Union{PreEpisodeStage,PostEpisodeStage},F} <: AbstractHook f::F n::Int t::Int diff --git a/src/ReinforcementLearningCore/src/extensions/ArrayInterface.jl b/src/ReinforcementLearningCore/src/extensions/ArrayInterface.jl index d641c615b..507907b8a 100644 --- a/src/ReinforcementLearningCore/src/extensions/ArrayInterface.jl +++ b/src/ReinforcementLearningCore/src/extensions/ArrayInterface.jl @@ -1,7 +1,10 @@ using ArrayInterface -function ArrayInterface.restructure(x::AbstractArray{T1, 0}, y::AbstractArray{T2, 0}) where {T1, T2} +function ArrayInterface.restructure( + x::AbstractArray{T1,0}, + y::AbstractArray{T2,0}, +) where {T1,T2} out = similar(x, eltype(y)) out .= y out -end \ No newline at end of file +end diff --git a/src/ReinforcementLearningCore/src/extensions/ReinforcementLearningBase.jl b/src/ReinforcementLearningCore/src/extensions/ReinforcementLearningBase.jl index dd4445eb7..d76810301 100644 --- a/src/ReinforcementLearningCore/src/extensions/ReinforcementLearningBase.jl +++ b/src/ReinforcementLearningCore/src/extensions/ReinforcementLearningBase.jl @@ -1,6 +1,6 @@ using AbstractTrees Base.show(io::IO, p::AbstractPolicy) = - AbstractTrees.print_tree(io, StructTree(p), maxdepth=get(io, :max_depth, 10)) + AbstractTrees.print_tree(io, StructTree(p), maxdepth = get(io, :max_depth, 10)) is_expand(::AbstractEnv) = false diff --git a/src/ReinforcementLearningCore/src/policies/agents/agent.jl b/src/ReinforcementLearningCore/src/policies/agents/agent.jl index fb7862052..59f415063 100644 --- a/src/ReinforcementLearningCore/src/policies/agents/agent.jl +++ b/src/ReinforcementLearningCore/src/policies/agents/agent.jl @@ -139,7 +139,9 @@ function RLBase.update!( # TODO: how to inject a local rng here to avoid polluting the global rng s = policy isa NamedPolicy ? state(env, nameof(policy)) : state(env) - a = policy isa NamedPolicy ? rand(action_space(env, nameof(policy))) : rand(action_space(env)) + a = + policy isa NamedPolicy ? 
rand(action_space(env, nameof(policy))) : + rand(action_space(env)) push!(trajectory[:state], s) push!(trajectory[:action], a) if haskey(trajectory, :legal_actions_mask) diff --git a/src/ReinforcementLearningCore/src/policies/agents/multi_agent.jl b/src/ReinforcementLearningCore/src/policies/agents/multi_agent.jl index 7bc0ab255..6a496dd34 100644 --- a/src/ReinforcementLearningCore/src/policies/agents/multi_agent.jl +++ b/src/ReinforcementLearningCore/src/policies/agents/multi_agent.jl @@ -23,7 +23,8 @@ of `SIMULTANEOUS` style, please wrap it with [`SequentialEnv`](@ref) first. MultiAgentManager(policies...) = MultiAgentManager(Dict{Any,Any}(nameof(p) => p for p in policies)) -RLBase.prob(A::MultiAgentManager, env::AbstractEnv, args...) = prob(A[current_player(env)].policy, env, args...) +RLBase.prob(A::MultiAgentManager, env::AbstractEnv, args...) = + prob(A[current_player(env)].policy, env, args...) (A::MultiAgentManager)(env::AbstractEnv) = A(env, DynamicStyle(env)) diff --git a/src/ReinforcementLearningCore/src/policies/agents/named_policy.jl b/src/ReinforcementLearningCore/src/policies/agents/named_policy.jl index 215ba5639..a692037df 100644 --- a/src/ReinforcementLearningCore/src/policies/agents/named_policy.jl +++ b/src/ReinforcementLearningCore/src/policies/agents/named_policy.jl @@ -42,6 +42,7 @@ function RLBase.update!( end -(p::NamedPolicy)(env::AbstractEnv) = DynamicStyle(env) == SEQUENTIAL ? p.policy(env) : p.policy(env, p.name) +(p::NamedPolicy)(env::AbstractEnv) = + DynamicStyle(env) == SEQUENTIAL ? p.policy(env) : p.policy(env, p.name) (p::NamedPolicy)(s::AbstractStage, env::AbstractEnv) = p.policy(s, env) (p::NamedPolicy)(s::PreActStage, env::AbstractEnv, action) = p.policy(s, env, action) diff --git a/src/ReinforcementLearningCore/src/policies/agents/trajectories/trajectory_extension.jl b/src/ReinforcementLearningCore/src/policies/agents/trajectories/trajectory_extension.jl index 7818e24ce..0050151c9 100644 --- a/src/ReinforcementLearningCore/src/policies/agents/trajectories/trajectory_extension.jl +++ b/src/ReinforcementLearningCore/src/policies/agents/trajectories/trajectory_extension.jl @@ -85,7 +85,11 @@ function fetch!(s::BatchSampler, t::AbstractTrajectory, inds::Vector{Int}) end end -function fetch!(s::BatchSampler{traces}, t::Union{CircularArraySARTTrajectory, CircularArraySLARTTrajectory}, inds::Vector{Int}) where {traces} +function fetch!( + s::BatchSampler{traces}, + t::Union{CircularArraySARTTrajectory,CircularArraySLARTTrajectory}, + inds::Vector{Int}, +) where {traces} if traces == SARTS batch = NamedTuple{SARTS}(( (consecutive_view(t[x], inds) for x in SART)..., @@ -100,7 +104,7 @@ function fetch!(s::BatchSampler{traces}, t::Union{CircularArraySARTTrajectory, C else @error "unsupported traces $traces" end - + if isnothing(s.cache) s.cache = map(batch) do x convert(Array, x) @@ -151,7 +155,7 @@ end function fetch!( sampler::NStepBatchSampler{traces}, - traj::Union{CircularArraySARTTrajectory, CircularArraySLARTTrajectory}, + traj::Union{CircularArraySARTTrajectory,CircularArraySLARTTrajectory}, inds::Vector{Int}, ) where {traces} γ, n, bz, sz = sampler.γ, sampler.n, sampler.batch_size, sampler.stack_size diff --git a/src/ReinforcementLearningCore/src/policies/q_based_policies/learners/abstract_learner.jl b/src/ReinforcementLearningCore/src/policies/q_based_policies/learners/abstract_learner.jl index f176fbdab..95e2b689a 100644 --- a/src/ReinforcementLearningCore/src/policies/q_based_policies/learners/abstract_learner.jl +++ 
b/src/ReinforcementLearningCore/src/policies/q_based_policies/learners/abstract_learner.jl @@ -17,7 +17,7 @@ function (learner::AbstractLearner)(env) end function RLBase.priority(p::AbstractLearner, experience) end Base.show(io::IO, p::AbstractLearner) = - AbstractTrees.print_tree(io, StructTree(p), maxdepth=get(io, :max_depth, 10)) + AbstractTrees.print_tree(io, StructTree(p), maxdepth = get(io, :max_depth, 10)) function RLBase.update!( L::AbstractLearner, diff --git a/src/ReinforcementLearningCore/src/policies/q_based_policies/learners/approximators/neural_network_approximator.jl b/src/ReinforcementLearningCore/src/policies/q_based_policies/learners/approximators/neural_network_approximator.jl index 926d0d40c..81ab4e41d 100644 --- a/src/ReinforcementLearningCore/src/policies/q_based_policies/learners/approximators/neural_network_approximator.jl +++ b/src/ReinforcementLearningCore/src/policies/q_based_policies/learners/approximators/neural_network_approximator.jl @@ -1,4 +1,5 @@ -export NeuralNetworkApproximator, ActorCritic, GaussianNetwork, DuelingNetwork, PerturbationNetwork +export NeuralNetworkApproximator, + ActorCritic, GaussianNetwork, DuelingNetwork, PerturbationNetwork export VAE, decode, vae_loss using Flux @@ -79,7 +80,7 @@ Base.@kwdef struct GaussianNetwork{P,U,S} pre::P = identity μ::U logσ::S - min_σ::Float32 = 0f0 + min_σ::Float32 = 0.0f0 max_σ::Float32 = Inf32 end @@ -92,15 +93,24 @@ This function is compatible with a multidimensional action space. When outputtin - `is_sampling::Bool=false`, whether to sample from the obtained normal distribution. - `is_return_log_prob::Bool=false`, whether to calculate the conditional probability of getting actions in the given state. """ -function (model::GaussianNetwork)(rng::AbstractRNG, state; is_sampling::Bool=false, is_return_log_prob::Bool=false) +function (model::GaussianNetwork)( + rng::AbstractRNG, + state; + is_sampling::Bool = false, + is_return_log_prob::Bool = false, +) x = model.pre(state) - μ, raw_logσ = model.μ(x), model.logσ(x) + μ, raw_logσ = model.μ(x), model.logσ(x) logσ = clamp.(raw_logσ, log(model.min_σ), log(model.max_σ)) if is_sampling σ = exp.(logσ) z = μ .+ σ .* send_to_device(device(model), randn(rng, Float32, size(μ))) if is_return_log_prob - logp_π = sum(normlogpdf(μ, σ, z) .- (2.0f0 .* (log(2.0f0) .- z .- softplus.(-2.0f0 .* z))), dims = 1) + logp_π = sum( + normlogpdf(μ, σ, z) .- + (2.0f0 .* (log(2.0f0) .- z .- softplus.(-2.0f0 .* z))), + dims = 1, + ) return tanh.(z), logp_π else return tanh.(z) @@ -110,16 +120,29 @@ function (model::GaussianNetwork)(rng::AbstractRNG, state; is_sampling::Bool=fal end end -function (model::GaussianNetwork)(state; is_sampling::Bool=false, is_return_log_prob::Bool=false) - model(Random.GLOBAL_RNG, state; is_sampling=is_sampling, is_return_log_prob=is_return_log_prob) +function (model::GaussianNetwork)( + state; + is_sampling::Bool = false, + is_return_log_prob::Bool = false, +) + model( + Random.GLOBAL_RNG, + state; + is_sampling = is_sampling, + is_return_log_prob = is_return_log_prob, + ) end function (model::GaussianNetwork)(state, action) x = model.pre(state) - μ, raw_logσ = model.μ(x), model.logσ(x) + μ, raw_logσ = model.μ(x), model.logσ(x) logσ = clamp.(raw_logσ, log(model.min_σ), log(model.max_σ)) σ = exp.(logσ) - logp_π = sum(normlogpdf(μ, σ, action) .- (2.0f0 .* (log(2.0f0) .- action .- softplus.(-2.0f0 .* action))), dims = 1) + logp_π = sum( + normlogpdf(μ, σ, action) .- + (2.0f0 .* (log(2.0f0) .- action .- softplus.(-2.0f0 .* action))), + dims = 1, + ) return 
logp_π end @@ -143,7 +166,7 @@ Flux.@functor DuelingNetwork function (m::DuelingNetwork)(state) x = m.base(state) val = m.val(x) - return val .+ m.adv(x) .- mean(m.adv(x), dims=1) + return val .+ m.adv(x) .- mean(m.adv(x), dims = 1) end ##### @@ -183,7 +206,7 @@ end """ VAE(;encoder, decoder, latent_dims) """ -Base.@kwdef struct VAE{E, D} +Base.@kwdef struct VAE{E,D} encoder::E decoder::D latent_dims::Int @@ -207,9 +230,14 @@ function reparamaterize(rng, μ, σ) return Float32(rand(rng, Normal(0, 1))) * σ + μ end -function decode(rng::AbstractRNG, model::VAE, state, z=nothing; is_normalize::Bool=true) +function decode(rng::AbstractRNG, model::VAE, state, z = nothing; is_normalize::Bool = true) if z === nothing - z = clamp.(randn(rng, Float32, (model.latent_dims, size(state)[2:end]...)), -0.5f0, 0.5f0) + z = + clamp.( + randn(rng, Float32, (model.latent_dims, size(state)[2:end]...)), + -0.5f0, + 0.5f0, + ) end a = model.decoder(vcat(state, z)) if is_normalize @@ -218,7 +246,7 @@ function decode(rng::AbstractRNG, model::VAE, state, z=nothing; is_normalize::Bo return a end -function decode(model::VAE, state, z=nothing; is_normalize::Bool=true) +function decode(model::VAE, state, z = nothing; is_normalize::Bool = true) decode(Random.GLOBAL_RNG, model, state, z; is_normalize) end diff --git a/src/ReinforcementLearningCore/test/components/trajectories.jl b/src/ReinforcementLearningCore/test/components/trajectories.jl index 8fb8eb9ae..c7c60e163 100644 --- a/src/ReinforcementLearningCore/test/components/trajectories.jl +++ b/src/ReinforcementLearningCore/test/components/trajectories.jl @@ -52,9 +52,9 @@ t = CircularArraySLARTTrajectory( capacity = 3, state = Vector{Int} => (4,), - legal_actions_mask = Vector{Bool} => (4, ), + legal_actions_mask = Vector{Bool} => (4,), ) - + # test instance type is same as type @test isa(t, CircularArraySLARTTrajectory) diff --git a/src/ReinforcementLearningCore/test/core/core.jl b/src/ReinforcementLearningCore/test/core/core.jl index fb6f5fde5..56a81b809 100644 --- a/src/ReinforcementLearningCore/test/core/core.jl +++ b/src/ReinforcementLearningCore/test/core/core.jl @@ -1,22 +1,18 @@ @testset "simple workflow" begin - env = StateTransformedEnv(CartPoleEnv{Float32}();state_mapping=deepcopy) + env = StateTransformedEnv(CartPoleEnv{Float32}(); state_mapping = deepcopy) policy = RandomPolicy(action_space(env)) N_EPISODE = 10_000 hook = TotalRewardPerEpisode() run(policy, env, StopAfterEpisode(N_EPISODE), hook) - @test isapprox(sum(hook[]) / N_EPISODE, 21; atol=2) + @test isapprox(sum(hook[]) / N_EPISODE, 21; atol = 2) end @testset "multi agent" begin # https://github.com/JuliaReinforcementLearning/ReinforcementLearning.jl/issues/393 rps = RockPaperScissorsEnv() |> SequentialEnv - ma_policy = MultiAgentManager( - ( - NamedPolicy(p => RandomPolicy()) - for p in players(rps) - )... - ) + ma_policy = + MultiAgentManager((NamedPolicy(p => RandomPolicy()) for p in players(rps))...) 
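# Descriptive aside (not a line of the upstream patch): the manager built above
# wraps one `NamedPolicy(player => RandomPolicy())` per player of the
# sequential `RockPaperScissorsEnv`; the next line runs it for 10 episodes as a
# smoke test, which is what the linked issue #393 exercises.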
run(ma_policy, rps, StopAfterEpisode(10)) end diff --git a/src/ReinforcementLearningCore/test/core/stop_conditions_test.jl b/src/ReinforcementLearningCore/test/core/stop_conditions_test.jl index a2d657fa1..fc5b0bfc8 100644 --- a/src/ReinforcementLearningCore/test/core/stop_conditions_test.jl +++ b/src/ReinforcementLearningCore/test/core/stop_conditions_test.jl @@ -1,5 +1,5 @@ @testset "test StopAfterNoImprovement" begin - env = StateTransformedEnv(CartPoleEnv{Float32}();state_mapping=deepcopy) + env = StateTransformedEnv(CartPoleEnv{Float32}(); state_mapping = deepcopy) policy = RandomPolicy(action_space(env)) total_reward_per_episode = TotalRewardPerEpisode() @@ -14,7 +14,8 @@ hook = ComposedHook(total_reward_per_episode) run(policy, env, stop_condition, hook) - @test argmax(total_reward_per_episode.rewards) + patience == length(total_reward_per_episode.rewards) + @test argmax(total_reward_per_episode.rewards) + patience == + length(total_reward_per_episode.rewards) end @testset "StopAfterNSeconds" begin diff --git a/src/ReinforcementLearningDatasets/docs/make.jl b/src/ReinforcementLearningDatasets/docs/make.jl index 5fdf1f97f..2a5eca27c 100644 --- a/src/ReinforcementLearningDatasets/docs/make.jl +++ b/src/ReinforcementLearningDatasets/docs/make.jl @@ -1,4 +1,4 @@ -push!(LOAD_PATH,"../src/") +push!(LOAD_PATH, "../src/") using Documenter, ReinforcementLearningDatasets -makedocs(sitename="ReinforcementLearningDatasets") \ No newline at end of file +makedocs(sitename = "ReinforcementLearningDatasets") diff --git a/src/ReinforcementLearningDatasets/src/ReinforcementLearningDatasets.jl b/src/ReinforcementLearningDatasets/src/ReinforcementLearningDatasets.jl index 4ecc99e29..c0d813fc2 100644 --- a/src/ReinforcementLearningDatasets/src/ReinforcementLearningDatasets.jl +++ b/src/ReinforcementLearningDatasets/src/ReinforcementLearningDatasets.jl @@ -28,4 +28,4 @@ include("deep_ope/d4rl/d4rl_policy.jl") include("deep_ope/d4rl/evaluate.jl") -end \ No newline at end of file +end diff --git a/src/ReinforcementLearningDatasets/src/atari/atari_dataset.jl b/src/ReinforcementLearningDatasets/src/atari/atari_dataset.jl index 53e65c436..da84bba32 100644 --- a/src/ReinforcementLearningDatasets/src/atari/atari_dataset.jl +++ b/src/ReinforcementLearningDatasets/src/atari/atari_dataset.jl @@ -16,8 +16,8 @@ Represents an `Iterable` dataset with the following fields: - `meta::Dict`: the metadata provided along with the dataset. - `is_shuffle::Bool`: determines if the batches returned by `iterate` are shuffled. 
""" -struct AtariDataSet{T<:AbstractRNG} <:RLDataSet - dataset::Dict{Symbol, Any} +struct AtariDataSet{T<:AbstractRNG} <: RLDataSet + dataset::Dict{Symbol,Any} epochs::Vector{Int} repo::String length::Int @@ -61,29 +61,29 @@ function dataset( game::String, index::Int, epochs::Vector{Int}; - style::NTuple=SARTS, - repo::String="atari-replay-datasets", - rng::AbstractRNG=MersenneTwister(123), - is_shuffle::Bool=true, - batch_size::Int=256 + style::NTuple = SARTS, + repo::String = "atari-replay-datasets", + rng::AbstractRNG = MersenneTwister(123), + is_shuffle::Bool = true, + batch_size::Int = 256, ) - - try + + try @datadep_str "$repo-$game-$index" catch e if isa(e, KeyError) throw("Invalid params, check out `atari_params()`") end end - - path = @datadep_str "$repo-$game-$index" + + path = @datadep_str "$repo-$game-$index" @assert length(readdir(path)) == 1 folder_name = readdir(path)[1] - + folder_path = "$path/$folder_name" files = readdir(folder_path) - file_prefixes = collect(Set(map(x->join(split(x,"_")[1:2], "_"), files))) + file_prefixes = collect(Set(map(x -> join(split(x, "_")[1:2], "_"), files))) fields = map(collect(file_prefixes)) do x if split(x, "_")[1] == "\$store\$" x = split(x, "_")[2] @@ -93,7 +93,7 @@ function dataset( end s_epochs = Set(epochs) - + dataset = Dict() for (prefix, field) in zip(file_prefixes, fields) @@ -110,9 +110,9 @@ function dataset( if haskey(dataset, field) if field == "observation" - dataset[field] = cat(dataset[field], data, dims=3) + dataset[field] = cat(dataset[field], data, dims = 3) else - dataset[field] = cat(dataset[field], data, dims=1) + dataset[field] = cat(dataset[field], data, dims = 1) end else dataset[field] = data @@ -122,24 +122,37 @@ function dataset( num_epochs = length(s_epochs) - atari_verify(dataset, num_epochs) + atari_verify(dataset, num_epochs) N_samples = size(dataset["observation"])[3] - - final_dataset = Dict{Symbol, Any}() - meta = Dict{String, Any}() - for (key, d_key) in zip(["observation", "action", "reward", "terminal"], Symbol.(["state", "action", "reward", "terminal"])) - final_dataset[d_key] = dataset[key] + final_dataset = Dict{Symbol,Any}() + meta = Dict{String,Any}() + + for (key, d_key) in zip( + ["observation", "action", "reward", "terminal"], + Symbol.(["state", "action", "reward", "terminal"]), + ) + final_dataset[d_key] = dataset[key] end - + for key in keys(dataset) if !(key in ["observation", "action", "reward", "terminal"]) meta[key] = dataset[key] end end - return AtariDataSet(final_dataset, epochs, repo, N_samples, batch_size, style, rng, meta, is_shuffle) + return AtariDataSet( + final_dataset, + epochs, + repo, + N_samples, + batch_size, + style, + rng, + meta, + is_shuffle, + ) end @@ -153,7 +166,7 @@ function iterate(ds::AtariDataSet, state = 0) if is_shuffle inds = rand(rng, 1:length-1, batch_size) else - if (state+1) * batch_size <= length + if (state + 1) * batch_size <= length inds = state*batch_size+1:(state+1)*batch_size else return nothing @@ -161,15 +174,17 @@ function iterate(ds::AtariDataSet, state = 0) state += 1 end - batch = (state = view(ds.dataset[:state], :, :, inds), - action = view(ds.dataset[:action], inds), - reward = view(ds.dataset[:reward], inds), - terminal = view(ds.dataset[:terminal], inds)) + batch = ( + state = view(ds.dataset[:state], :, :, inds), + action = view(ds.dataset[:action], inds), + reward = view(ds.dataset[:reward], inds), + terminal = view(ds.dataset[:terminal], inds), + ) if style == SARTS - batch = merge(batch, (next_state = view(ds.dataset[:state], :, :, 
(1).+(inds)),)) + batch = merge(batch, (next_state = view(ds.dataset[:state], :, :, (1) .+ (inds)),)) end - + return batch, state end @@ -179,7 +194,8 @@ length(ds::AtariDataSet) = ds.length IteratorEltype(::Type{AtariDataSet}) = EltypeUnknown() # see if eltype can be known (not sure about carla and adroit) function atari_verify(dataset::Dict, num_epochs::Int) - @assert size(dataset["observation"]) == (atari_frame_size, atari_frame_size, num_epochs*samples_per_epoch) + @assert size(dataset["observation"]) == + (atari_frame_size, atari_frame_size, num_epochs * samples_per_epoch) @assert size(dataset["action"]) == (num_epochs * samples_per_epoch,) @assert size(dataset["reward"]) == (num_epochs * samples_per_epoch,) @assert size(dataset["terminal"]) == (num_epochs * samples_per_epoch,) diff --git a/src/ReinforcementLearningDatasets/src/atari/register.jl b/src/ReinforcementLearningDatasets/src/atari/register.jl index 8b4ddb81b..485a0ac31 100644 --- a/src/ReinforcementLearningDatasets/src/atari/register.jl +++ b/src/ReinforcementLearningDatasets/src/atari/register.jl @@ -9,19 +9,66 @@ function atari_params() end const ATARI_GAMES = [ - "air-raid", "alien", "amidar", "assault", "asterix", - "asteroids", "atlantis", "bank-heist", "battle-zone", "beam-rider", - "berzerk", "bowling", "boxing", "breakout", "carnival", "centipede", - "chopper-command", "crazy-climber", "demon-attack", - "double-dunk", "elevator-action", "enduro", "fishing-derby", "freeway", - "frostbite", "gopher", "gravitar", "hero", "ice-hockey", "jamesbond", - "journey-escape", "kangaroo", "krull", "kung-fu-master", - "montezuma-revenge", "ms-pacman", "name-this-game", "phoenix", - "pitfall", "pong", "pooyan", "private-eye", "qbert", "riverraid", - "road-runner", "robotank", "seaquest", "skiing", "solaris", - "space-invaders", "star-gunner", "tennis", "time-pilot", "tutankham", - "up-n-down", "venture", "video-pinball", "wizard-of-wor", - "yars-revenge", "zaxxon" + "air-raid", + "alien", + "amidar", + "assault", + "asterix", + "asteroids", + "atlantis", + "bank-heist", + "battle-zone", + "beam-rider", + "berzerk", + "bowling", + "boxing", + "breakout", + "carnival", + "centipede", + "chopper-command", + "crazy-climber", + "demon-attack", + "double-dunk", + "elevator-action", + "enduro", + "fishing-derby", + "freeway", + "frostbite", + "gopher", + "gravitar", + "hero", + "ice-hockey", + "jamesbond", + "journey-escape", + "kangaroo", + "krull", + "kung-fu-master", + "montezuma-revenge", + "ms-pacman", + "name-this-game", + "phoenix", + "pitfall", + "pong", + "pooyan", + "private-eye", + "qbert", + "riverraid", + "road-runner", + "robotank", + "seaquest", + "skiing", + "solaris", + "space-invaders", + "star-gunner", + "tennis", + "time-pilot", + "tutankham", + "up-n-down", + "venture", + "video-pinball", + "wizard-of-wor", + "yars-revenge", + "zaxxon", ] game_name(game) = join(titlecase.(split(game, "-"))) @@ -48,9 +95,9 @@ function atari_init() encountered during training into 5 replay datasets per game, resulting in a total of 300 datasets. 
""", "gs://atari-replay-datasets/dqn/$(game_name(game))/$index/replay_logs/"; - fetch_method = fetch_gc_bucket - ) + fetch_method = fetch_gc_bucket, + ), ) end end -end \ No newline at end of file +end diff --git a/src/ReinforcementLearningDatasets/src/common.jl b/src/ReinforcementLearningDatasets/src/common.jl index bb3347000..ac574e031 100644 --- a/src/ReinforcementLearningDatasets/src/common.jl +++ b/src/ReinforcementLearningDatasets/src/common.jl @@ -31,10 +31,18 @@ fetch a gc bucket from `src` to `dest`. """ function fetch_gc_bucket(src, dest) if Sys.iswindows() - try run(`cmd /C gsutil -v`) catch x throw("gsutil not found, install gsutil to proceed further") end + try + run(`cmd /C gsutil -v`) + catch x + throw("gsutil not found, install gsutil to proceed further") + end run(`cmd /C gsutil -m cp -r $src $dest`) else - try run(`gsutil -v`) catch x throw("gsutil not found, install gsutil to proceed further") end + try + run(`gsutil -v`) + catch x + throw("gsutil not found, install gsutil to proceed further") + end run(`gsutil -m cp -r $src $dest`) end return dest @@ -45,11 +53,19 @@ fetch a gc file from `src` to `dest`. """ function fetch_gc_file(src, dest) if Sys.iswindows() - try run(`cmd /C gsutil -v`) catch x throw("gsutil not found, install gsutil to proceed further") end + try + run(`cmd /C gsutil -v`) + catch x + throw("gsutil not found, install gsutil to proceed further") + end run(`cmd /C gsutil -m cp $src $dest`) else - try run(`gsutil -v`) catch x throw("gsutil not found, install gsutil to proceed further") end + try + run(`gsutil -v`) + catch x + throw("gsutil not found, install gsutil to proceed further") + end run(`gsutil -m cp $src $dest`) end return dest -end \ No newline at end of file +end diff --git a/src/ReinforcementLearningDatasets/src/d4rl/d4rl/register.jl b/src/ReinforcementLearningDatasets/src/d4rl/d4rl/register.jl index 7d976bd66..eb5716068 100644 --- a/src/ReinforcementLearningDatasets/src/d4rl/d4rl/register.jl +++ b/src/ReinforcementLearningDatasets/src/d4rl/d4rl/register.jl @@ -1,4 +1,4 @@ -export d4rl_dataset_params +export d4rl_dataset_params function d4rl_dataset_params() dataset = keys(D4RL_DATASET_URLS) @@ -6,7 +6,7 @@ function d4rl_dataset_params() @info dataset repo end -const D4RL_DATASET_URLS = Dict{String, String}( +const D4RL_DATASET_URLS = Dict{String,String}( "maze2d-open-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-open-sparse.hdf5", "maze2d-umaze-v1" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-umaze-sparse-v1.hdf5", "maze2d-medium-v1" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/maze2d/maze2d-medium-sparse-v1.hdf5", @@ -62,209 +62,209 @@ const D4RL_DATASET_URLS = Dict{String, String}( "antmaze-medium-diverse-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_big-maze_noisy_multistart_True_multigoal_True_sparse.hdf5", "antmaze-large-play-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_hardest-maze_noisy_multistart_True_multigoal_False_sparse.hdf5", "antmaze-large-diverse-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/ant_maze_new/Ant_maze_hardest-maze_noisy_multistart_True_multigoal_True_sparse.hdf5", - "flow-ring-random-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/flow/flow-ring-v0-random.hdf5", - "flow-ring-controller-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/flow/flow-ring-v0-idm.hdf5", - 
"flow-merge-random-v0"=>"http://rail.eecs.berkeley.edu/datasets/offline_rl/flow/flow-merge-v0-random.hdf5", - "flow-merge-controller-v0"=>"http://rail.eecs.berkeley.edu/datasets/offline_rl/flow/flow-merge-v0-idm.hdf5", + "flow-ring-random-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/flow/flow-ring-v0-random.hdf5", + "flow-ring-controller-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/flow/flow-ring-v0-idm.hdf5", + "flow-merge-random-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/flow/flow-merge-v0-random.hdf5", + "flow-merge-controller-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/flow/flow-merge-v0-idm.hdf5", "kitchen-complete-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/kitchen/mini_kitchen_microwave_kettle_light_slider-v0.hdf5", "kitchen-partial-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/kitchen/kitchen_microwave_kettle_light_slider-v0.hdf5", "kitchen-mixed-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/kitchen/kitchen_microwave_kettle_bottomburner_light-v0.hdf5", - "carla-lane-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/carla/carla_lane_follow_flat-v0.hdf5", - "carla-town-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/carla/carla_town_subsamp_flat-v0.hdf5", - "carla-town-full-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/carla/carla_town_flat-v0.hdf5", - "bullet-halfcheetah-random-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-halfcheetah_random.hdf5", - "bullet-halfcheetah-medium-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-halfcheetah_medium.hdf5", - "bullet-halfcheetah-expert-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-halfcheetah_expert.hdf5", - "bullet-halfcheetah-medium-expert-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-halfcheetah_medium_expert.hdf5", - "bullet-halfcheetah-medium-replay-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-halfcheetah_medium_replay.hdf5", - "bullet-hopper-random-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-hopper_random.hdf5", - "bullet-hopper-medium-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-hopper_medium.hdf5", - "bullet-hopper-expert-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-hopper_expert.hdf5", - "bullet-hopper-medium-expert-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-hopper_medium_expert.hdf5", - "bullet-hopper-medium-replay-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-hopper_medium_replay.hdf5", - "bullet-ant-random-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-ant_random.hdf5", - "bullet-ant-medium-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-ant_medium.hdf5", - "bullet-ant-expert-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-ant_expert.hdf5", - "bullet-ant-medium-expert-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-ant_medium_expert.hdf5", - "bullet-ant-medium-replay-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-ant_medium_replay.hdf5", - "bullet-walker2d-random-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-walker2d_random.hdf5", - "bullet-walker2d-medium-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-walker2d_medium.hdf5", - "bullet-walker2d-expert-v0"=> 
"http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-walker2d_expert.hdf5", - "bullet-walker2d-medium-expert-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-walker2d_medium_expert.hdf5", - "bullet-walker2d-medium-replay-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-walker2d_medium_replay.hdf5", - "bullet-maze2d-open-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-maze2d-open-sparse.hdf5", - "bullet-maze2d-umaze-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-maze2d-umaze-sparse.hdf5", - "bullet-maze2d-medium-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-maze2d-medium-sparse.hdf5", - "bullet-maze2d-large-v0"=> "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-maze2d-large-sparse.hdf5", + "carla-lane-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/carla/carla_lane_follow_flat-v0.hdf5", + "carla-town-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/carla/carla_town_subsamp_flat-v0.hdf5", + "carla-town-full-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/carla/carla_town_flat-v0.hdf5", + "bullet-halfcheetah-random-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-halfcheetah_random.hdf5", + "bullet-halfcheetah-medium-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-halfcheetah_medium.hdf5", + "bullet-halfcheetah-expert-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-halfcheetah_expert.hdf5", + "bullet-halfcheetah-medium-expert-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-halfcheetah_medium_expert.hdf5", + "bullet-halfcheetah-medium-replay-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-halfcheetah_medium_replay.hdf5", + "bullet-hopper-random-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-hopper_random.hdf5", + "bullet-hopper-medium-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-hopper_medium.hdf5", + "bullet-hopper-expert-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-hopper_expert.hdf5", + "bullet-hopper-medium-expert-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-hopper_medium_expert.hdf5", + "bullet-hopper-medium-replay-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-hopper_medium_replay.hdf5", + "bullet-ant-random-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-ant_random.hdf5", + "bullet-ant-medium-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-ant_medium.hdf5", + "bullet-ant-expert-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-ant_expert.hdf5", + "bullet-ant-medium-expert-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-ant_medium_expert.hdf5", + "bullet-ant-medium-replay-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-ant_medium_replay.hdf5", + "bullet-walker2d-random-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-walker2d_random.hdf5", + "bullet-walker2d-medium-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-walker2d_medium.hdf5", + "bullet-walker2d-expert-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-walker2d_expert.hdf5", + "bullet-walker2d-medium-expert-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-walker2d_medium_expert.hdf5", + 
"bullet-walker2d-medium-replay-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-walker2d_medium_replay.hdf5", + "bullet-maze2d-open-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-maze2d-open-sparse.hdf5", + "bullet-maze2d-umaze-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-maze2d-umaze-sparse.hdf5", + "bullet-maze2d-medium-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-maze2d-medium-sparse.hdf5", + "bullet-maze2d-large-v0" => "http://rail.eecs.berkeley.edu/datasets/offline_rl/bullet/bullet-maze2d-large-sparse.hdf5", ) -const D4RL_REF_MIN_SCORE = Dict{String, Float32}( - "maze2d-open-v0" => 0.01 , - "maze2d-umaze-v1" => 23.85 , - "maze2d-medium-v1" => 13.13 , - "maze2d-large-v1" => 6.7 , - "maze2d-open-dense-v0" => 11.17817 , - "maze2d-umaze-dense-v1" => 68.537689 , - "maze2d-medium-dense-v1" => 44.264742 , - "maze2d-large-dense-v1" => 30.569041 , - "minigrid-fourrooms-v0" => 0.01442 , - "minigrid-fourrooms-random-v0" => 0.01442 , - "pen-human-v0" => 96.262799 , - "pen-cloned-v0" => 96.262799 , - "pen-expert-v0" => 96.262799 , - "hammer-human-v0" => -274.856578 , - "hammer-cloned-v0" => -274.856578 , - "hammer-expert-v0" => -274.856578 , - "relocate-human-v0" => -6.425911 , - "relocate-cloned-v0" => -6.425911 , - "relocate-expert-v0" => -6.425911 , - "door-human-v0" => -56.512833 , - "door-cloned-v0" => -56.512833 , - "door-expert-v0" => -56.512833 , - "halfcheetah-random-v0" => -280.178953 , - "halfcheetah-medium-v0" => -280.178953 , - "halfcheetah-expert-v0" => -280.178953 , - "halfcheetah-medium-replay-v0" => -280.178953 , - "halfcheetah-medium-expert-v0" => -280.178953 , - "walker2d-random-v0" => 1.629008 , - "walker2d-medium-v0" => 1.629008 , - "walker2d-expert-v0" => 1.629008 , - "walker2d-medium-replay-v0" => 1.629008 , - "walker2d-medium-expert-v0" => 1.629008 , - "hopper-random-v0" => -20.272305 , - "hopper-medium-v0" => -20.272305 , - "hopper-expert-v0" => -20.272305 , - "hopper-medium-replay-v0" => -20.272305 , - "hopper-medium-expert-v0" => -20.272305 , +const D4RL_REF_MIN_SCORE = Dict{String,Float32}( + "maze2d-open-v0" => 0.01, + "maze2d-umaze-v1" => 23.85, + "maze2d-medium-v1" => 13.13, + "maze2d-large-v1" => 6.7, + "maze2d-open-dense-v0" => 11.17817, + "maze2d-umaze-dense-v1" => 68.537689, + "maze2d-medium-dense-v1" => 44.264742, + "maze2d-large-dense-v1" => 30.569041, + "minigrid-fourrooms-v0" => 0.01442, + "minigrid-fourrooms-random-v0" => 0.01442, + "pen-human-v0" => 96.262799, + "pen-cloned-v0" => 96.262799, + "pen-expert-v0" => 96.262799, + "hammer-human-v0" => -274.856578, + "hammer-cloned-v0" => -274.856578, + "hammer-expert-v0" => -274.856578, + "relocate-human-v0" => -6.425911, + "relocate-cloned-v0" => -6.425911, + "relocate-expert-v0" => -6.425911, + "door-human-v0" => -56.512833, + "door-cloned-v0" => -56.512833, + "door-expert-v0" => -56.512833, + "halfcheetah-random-v0" => -280.178953, + "halfcheetah-medium-v0" => -280.178953, + "halfcheetah-expert-v0" => -280.178953, + "halfcheetah-medium-replay-v0" => -280.178953, + "halfcheetah-medium-expert-v0" => -280.178953, + "walker2d-random-v0" => 1.629008, + "walker2d-medium-v0" => 1.629008, + "walker2d-expert-v0" => 1.629008, + "walker2d-medium-replay-v0" => 1.629008, + "walker2d-medium-expert-v0" => 1.629008, + "hopper-random-v0" => -20.272305, + "hopper-medium-v0" => -20.272305, + "hopper-expert-v0" => -20.272305, + "hopper-medium-replay-v0" => -20.272305, + "hopper-medium-expert-v0" => -20.272305, "ant-random-v0" => 
-325.6, "ant-medium-v0" => -325.6, "ant-expert-v0" => -325.6, "ant-medium-replay-v0" => -325.6, "ant-medium-expert-v0" => -325.6, - "antmaze-umaze-v0" => 0.0 , - "antmaze-umaze-diverse-v0" => 0.0 , - "antmaze-medium-play-v0" => 0.0 , - "antmaze-medium-diverse-v0" => 0.0 , - "antmaze-large-play-v0" => 0.0 , - "antmaze-large-diverse-v0" => 0.0 , - "kitchen-complete-v0" => 0.0 , - "kitchen-partial-v0" => 0.0 , - "kitchen-mixed-v0" => 0.0 , - "flow-ring-random-v0" => -165.22 , - "flow-ring-controller-v0" => -165.22 , - "flow-merge-random-v0" => 118.67993 , - "flow-merge-controller-v0" => 118.67993 , - "carla-lane-v0"=> -0.8503839912088142, - "carla-town-v0"=> -114.81579500772153, # random score - "bullet-halfcheetah-random-v0"=> -1275.766996, - "bullet-halfcheetah-medium-v0"=> -1275.766996, - "bullet-halfcheetah-expert-v0"=> -1275.766996, - "bullet-halfcheetah-medium-expert-v0"=> -1275.766996, - "bullet-halfcheetah-medium-replay-v0"=> -1275.766996, - "bullet-hopper-random-v0"=> 20.058972, - "bullet-hopper-medium-v0"=> 20.058972, - "bullet-hopper-expert-v0"=> 20.058972, - "bullet-hopper-medium-expert-v0"=> 20.058972, - "bullet-hopper-medium-replay-v0"=> 20.058972, - "bullet-ant-random-v0"=> 373.705955, - "bullet-ant-medium-v0"=> 373.705955, - "bullet-ant-expert-v0"=> 373.705955, - "bullet-ant-medium-expert-v0"=> 373.705955, - "bullet-ant-medium-replay-v0"=> 373.705955, - "bullet-walker2d-random-v0"=> 16.523877, - "bullet-walker2d-medium-v0"=> 16.523877, - "bullet-walker2d-expert-v0"=> 16.523877, - "bullet-walker2d-medium-expert-v0"=> 16.523877, - "bullet-walker2d-medium-replay-v0"=> 16.523877, - "bullet-maze2d-open-v0"=> 8.750000, - "bullet-maze2d-umaze-v0"=> 32.460000, - "bullet-maze2d-medium-v0"=> 14.870000, - "bullet-maze2d-large-v0"=> 1.820000, + "antmaze-umaze-v0" => 0.0, + "antmaze-umaze-diverse-v0" => 0.0, + "antmaze-medium-play-v0" => 0.0, + "antmaze-medium-diverse-v0" => 0.0, + "antmaze-large-play-v0" => 0.0, + "antmaze-large-diverse-v0" => 0.0, + "kitchen-complete-v0" => 0.0, + "kitchen-partial-v0" => 0.0, + "kitchen-mixed-v0" => 0.0, + "flow-ring-random-v0" => -165.22, + "flow-ring-controller-v0" => -165.22, + "flow-merge-random-v0" => 118.67993, + "flow-merge-controller-v0" => 118.67993, + "carla-lane-v0" => -0.8503839912088142, + "carla-town-v0" => -114.81579500772153, # random score + "bullet-halfcheetah-random-v0" => -1275.766996, + "bullet-halfcheetah-medium-v0" => -1275.766996, + "bullet-halfcheetah-expert-v0" => -1275.766996, + "bullet-halfcheetah-medium-expert-v0" => -1275.766996, + "bullet-halfcheetah-medium-replay-v0" => -1275.766996, + "bullet-hopper-random-v0" => 20.058972, + "bullet-hopper-medium-v0" => 20.058972, + "bullet-hopper-expert-v0" => 20.058972, + "bullet-hopper-medium-expert-v0" => 20.058972, + "bullet-hopper-medium-replay-v0" => 20.058972, + "bullet-ant-random-v0" => 373.705955, + "bullet-ant-medium-v0" => 373.705955, + "bullet-ant-expert-v0" => 373.705955, + "bullet-ant-medium-expert-v0" => 373.705955, + "bullet-ant-medium-replay-v0" => 373.705955, + "bullet-walker2d-random-v0" => 16.523877, + "bullet-walker2d-medium-v0" => 16.523877, + "bullet-walker2d-expert-v0" => 16.523877, + "bullet-walker2d-medium-expert-v0" => 16.523877, + "bullet-walker2d-medium-replay-v0" => 16.523877, + "bullet-maze2d-open-v0" => 8.750000, + "bullet-maze2d-umaze-v0" => 32.460000, + "bullet-maze2d-medium-v0" => 14.870000, + "bullet-maze2d-large-v0" => 1.820000, ) -const D4RL_REF_MAX_SCORE = Dict{String, Float32}( - "maze2d-open-v0" => 20.66 , - "maze2d-umaze-v1" => 161.86 , - 
"maze2d-medium-v1" => 277.39 , - "maze2d-large-v1" => 273.99 , - "maze2d-open-dense-v0" => 27.166538620695782 , - "maze2d-umaze-dense-v1" => 193.66285642381482 , - "maze2d-medium-dense-v1" => 297.4552547777125 , - "maze2d-large-dense-v1" => 303.4857382709002 , - "minigrid-fourrooms-v0" => 2.89685 , - "minigrid-fourrooms-random-v0" => 2.89685 , - "pen-human-v0" => 3076.8331017826877 , - "pen-cloned-v0" => 3076.8331017826877 , - "pen-expert-v0" => 3076.8331017826877 , - "hammer-human-v0" => 12794.134825156867 , - "hammer-cloned-v0" => 12794.134825156867 , - "hammer-expert-v0" => 12794.134825156867 , - "relocate-human-v0" => 4233.877797728884 , - "relocate-cloned-v0" => 4233.877797728884 , - "relocate-expert-v0" => 4233.877797728884 , - "door-human-v0" => 2880.5693087298737 , - "door-cloned-v0" => 2880.5693087298737 , - "door-expert-v0" => 2880.5693087298737 , - "halfcheetah-random-v0" => 12135.0 , - "halfcheetah-medium-v0" => 12135.0 , - "halfcheetah-expert-v0" => 12135.0 , - "halfcheetah-medium-replay-v0" => 12135.0 , - "halfcheetah-medium-expert-v0" => 12135.0 , - "walker2d-random-v0" => 4592.3 , - "walker2d-medium-v0" => 4592.3 , - "walker2d-expert-v0" => 4592.3 , - "walker2d-medium-replay-v0" => 4592.3 , - "walker2d-medium-expert-v0" => 4592.3 , - "hopper-random-v0" => 3234.3 , - "hopper-medium-v0" => 3234.3 , - "hopper-expert-v0" => 3234.3 , - "hopper-medium-replay-v0" => 3234.3 , - "hopper-medium-expert-v0" => 3234.3 , +const D4RL_REF_MAX_SCORE = Dict{String,Float32}( + "maze2d-open-v0" => 20.66, + "maze2d-umaze-v1" => 161.86, + "maze2d-medium-v1" => 277.39, + "maze2d-large-v1" => 273.99, + "maze2d-open-dense-v0" => 27.166538620695782, + "maze2d-umaze-dense-v1" => 193.66285642381482, + "maze2d-medium-dense-v1" => 297.4552547777125, + "maze2d-large-dense-v1" => 303.4857382709002, + "minigrid-fourrooms-v0" => 2.89685, + "minigrid-fourrooms-random-v0" => 2.89685, + "pen-human-v0" => 3076.8331017826877, + "pen-cloned-v0" => 3076.8331017826877, + "pen-expert-v0" => 3076.8331017826877, + "hammer-human-v0" => 12794.134825156867, + "hammer-cloned-v0" => 12794.134825156867, + "hammer-expert-v0" => 12794.134825156867, + "relocate-human-v0" => 4233.877797728884, + "relocate-cloned-v0" => 4233.877797728884, + "relocate-expert-v0" => 4233.877797728884, + "door-human-v0" => 2880.5693087298737, + "door-cloned-v0" => 2880.5693087298737, + "door-expert-v0" => 2880.5693087298737, + "halfcheetah-random-v0" => 12135.0, + "halfcheetah-medium-v0" => 12135.0, + "halfcheetah-expert-v0" => 12135.0, + "halfcheetah-medium-replay-v0" => 12135.0, + "halfcheetah-medium-expert-v0" => 12135.0, + "walker2d-random-v0" => 4592.3, + "walker2d-medium-v0" => 4592.3, + "walker2d-expert-v0" => 4592.3, + "walker2d-medium-replay-v0" => 4592.3, + "walker2d-medium-expert-v0" => 4592.3, + "hopper-random-v0" => 3234.3, + "hopper-medium-v0" => 3234.3, + "hopper-expert-v0" => 3234.3, + "hopper-medium-replay-v0" => 3234.3, + "hopper-medium-expert-v0" => 3234.3, "ant-random-v0" => 3879.7, "ant-medium-v0" => 3879.7, "ant-expert-v0" => 3879.7, "ant-medium-replay-v0" => 3879.7, "ant-medium-expert-v0" => 3879.7, - "antmaze-umaze-v0" => 1.0 , - "antmaze-umaze-diverse-v0" => 1.0 , - "antmaze-medium-play-v0" => 1.0 , - "antmaze-medium-diverse-v0" => 1.0 , - "antmaze-large-play-v0" => 1.0 , - "antmaze-large-diverse-v0" => 1.0 , - "kitchen-complete-v0" => 4.0 , - "kitchen-partial-v0" => 4.0 , - "kitchen-mixed-v0" => 4.0 , - "flow-ring-random-v0" => 24.42 , - "flow-ring-controller-v0" => 24.42 , - "flow-merge-random-v0" => 330.03179 , - 
"flow-merge-controller-v0" => 330.03179 , - "carla-lane-v0"=> 1023.5784385429523, - "carla-town-v0"=> 2440.1772022247314, # avg dataset score - "bullet-halfcheetah-random-v0"=> 2381.6725, - "bullet-halfcheetah-medium-v0"=> 2381.6725, - "bullet-halfcheetah-expert-v0"=> 2381.6725, - "bullet-halfcheetah-medium-expert-v0"=> 2381.6725, - "bullet-halfcheetah-medium-replay-v0"=> 2381.6725, - "bullet-hopper-random-v0"=> 1441.8059623430963, - "bullet-hopper-medium-v0"=> 1441.8059623430963, - "bullet-hopper-expert-v0"=> 1441.8059623430963, - "bullet-hopper-medium-expert-v0"=> 1441.8059623430963, - "bullet-hopper-medium-replay-v0"=> 1441.8059623430963, - "bullet-ant-random-v0"=> 2650.495, - "bullet-ant-medium-v0"=> 2650.495, - "bullet-ant-expert-v0"=> 2650.495, - "bullet-ant-medium-expert-v0"=> 2650.495, - "bullet-ant-medium-replay-v0"=> 2650.495, - "bullet-walker2d-random-v0"=> 1623.6476303317536, - "bullet-walker2d-medium-v0"=> 1623.6476303317536, - "bullet-walker2d-expert-v0"=> 1623.6476303317536, - "bullet-walker2d-medium-expert-v0"=> 1623.6476303317536, - "bullet-walker2d-medium-replay-v0"=> 1623.6476303317536, - "bullet-maze2d-open-v0"=> 64.15, - "bullet-maze2d-umaze-v0"=> 153.99, - "bullet-maze2d-medium-v0"=> 238.05, - "bullet-maze2d-large-v0"=> 285.92, + "antmaze-umaze-v0" => 1.0, + "antmaze-umaze-diverse-v0" => 1.0, + "antmaze-medium-play-v0" => 1.0, + "antmaze-medium-diverse-v0" => 1.0, + "antmaze-large-play-v0" => 1.0, + "antmaze-large-diverse-v0" => 1.0, + "kitchen-complete-v0" => 4.0, + "kitchen-partial-v0" => 4.0, + "kitchen-mixed-v0" => 4.0, + "flow-ring-random-v0" => 24.42, + "flow-ring-controller-v0" => 24.42, + "flow-merge-random-v0" => 330.03179, + "flow-merge-controller-v0" => 330.03179, + "carla-lane-v0" => 1023.5784385429523, + "carla-town-v0" => 2440.1772022247314, # avg dataset score + "bullet-halfcheetah-random-v0" => 2381.6725, + "bullet-halfcheetah-medium-v0" => 2381.6725, + "bullet-halfcheetah-expert-v0" => 2381.6725, + "bullet-halfcheetah-medium-expert-v0" => 2381.6725, + "bullet-halfcheetah-medium-replay-v0" => 2381.6725, + "bullet-hopper-random-v0" => 1441.8059623430963, + "bullet-hopper-medium-v0" => 1441.8059623430963, + "bullet-hopper-expert-v0" => 1441.8059623430963, + "bullet-hopper-medium-expert-v0" => 1441.8059623430963, + "bullet-hopper-medium-replay-v0" => 1441.8059623430963, + "bullet-ant-random-v0" => 2650.495, + "bullet-ant-medium-v0" => 2650.495, + "bullet-ant-expert-v0" => 2650.495, + "bullet-ant-medium-expert-v0" => 2650.495, + "bullet-ant-medium-replay-v0" => 2650.495, + "bullet-walker2d-random-v0" => 1623.6476303317536, + "bullet-walker2d-medium-v0" => 1623.6476303317536, + "bullet-walker2d-expert-v0" => 1623.6476303317536, + "bullet-walker2d-medium-expert-v0" => 1623.6476303317536, + "bullet-walker2d-medium-replay-v0" => 1623.6476303317536, + "bullet-maze2d-open-v0" => 64.15, + "bullet-maze2d-umaze-v0" => 153.99, + "bullet-maze2d-medium-v0" => 238.05, + "bullet-maze2d-large-v0" => 285.92, ) # give a prompt for flow and carla tasks @@ -274,20 +274,20 @@ function d4rl_init() for ds in keys(D4RL_DATASET_URLS) register( DataDep( - repo*"-"* ds, + repo * "-" * ds, """ Credits: https://arxiv.org/abs/2004.07219 The following dataset is fetched from the d4rl. The dataset is fetched and modified in a form that is useful for RL.jl package. 
- + Dataset information: Name: $(ds) $(if ds in keys(D4RL_REF_MAX_SCORE) "MAXIMUM_SCORE: " * string(D4RL_REF_MAX_SCORE[ds]) end) $(if ds in keys(D4RL_REF_MIN_SCORE) "MINIMUM_SCORE: " * string(D4RL_REF_MIN_SCORE[ds]) end) """, #check if the MAX and MIN score part is even necessary and make the log file prettier D4RL_DATASET_URLS[ds], - ) + ), ) end nothing -end \ No newline at end of file +end diff --git a/src/ReinforcementLearningDatasets/src/d4rl/d4rl_dataset.jl b/src/ReinforcementLearningDatasets/src/d4rl/d4rl_dataset.jl index 105b8e65f..0ed98335e 100644 --- a/src/ReinforcementLearningDatasets/src/d4rl/d4rl_dataset.jl +++ b/src/ReinforcementLearningDatasets/src/d4rl/d4rl_dataset.jl @@ -21,7 +21,7 @@ Represents an `Iterable` dataset with the following fields: - `is_shuffle::Bool`: determines if the batches returned by `iterate` are shuffled. """ struct D4RLDataSet{T<:AbstractRNG} <: RLDataSet - dataset::Dict{Symbol, Any} + dataset::Dict{Symbol,Any} repo::String dataset_size::Integer batch_size::Integer @@ -58,42 +58,47 @@ been tested in this package yet. """ function dataset( dataset::String; - repo::String="d4rl", - style::NTuple=SARTS, - rng::AbstractRNG=MersenneTwister(123), - is_shuffle::Bool=true, - batch_size::Int=256 + repo::String = "d4rl", + style::NTuple = SARTS, + rng::AbstractRNG = MersenneTwister(123), + is_shuffle::Bool = true, + batch_size::Int = 256, ) - - try - @datadep_str repo*"-"*dataset + + try + @datadep_str repo * "-" * dataset catch e if isa(e, KeyError) - throw("Invalid params, check out d4rl_pybullet_dataset_params() or d4rl_dataset_params()") + throw( + "Invalid params, check out d4rl_pybullet_dataset_params() or d4rl_dataset_params()", + ) end end - - path = @datadep_str repo*"-"*dataset + + path = @datadep_str repo * "-" * dataset @assert length(readdir(path)) == 1 file_name = readdir(path)[1] - - data = h5open(path*"/"*file_name, "r") do file + + data = h5open(path * "/" * file_name, "r") do file read(file) end # sanity checks on data d4rl_verify(data) - dataset = Dict{Symbol, Any}() - meta = Dict{String, Any}() + dataset = Dict{Symbol,Any}() + meta = Dict{String,Any}() N_samples = size(data["observations"])[2] - - for (key, d_key) in zip(["observations", "actions", "rewards", "terminals"], Symbol.(["state", "action", "reward", "terminal"])) - dataset[d_key] = data[key] + + for (key, d_key) in zip( + ["observations", "actions", "rewards", "terminals"], + Symbol.(["state", "action", "reward", "terminal"]), + ) + dataset[d_key] = data[key] end - + for key in keys(data) if !(key in ["observations", "actions", "rewards", "terminals"]) meta[key] = data[key] @@ -113,9 +118,13 @@ function iterate(ds::D4RLDataSet, state = 0) if is_shuffle inds = rand(rng, 1:size, batch_size) - map((x)-> if x <= size x else 1 end, inds) + map((x) -> if x <= size + x + else + 1 + end, inds) else - if (state+1) * batch_size <= size + if (state + 1) * batch_size <= size inds = state*batch_size+1:(state+1)*batch_size else return nothing @@ -123,15 +132,17 @@ function iterate(ds::D4RLDataSet, state = 0) state += 1 end - batch = (state = copy(ds.dataset[:state][:, inds]), - action = copy(ds.dataset[:action][:, inds]), - reward = copy(ds.dataset[:reward][inds]), - terminal = copy(ds.dataset[:terminal][inds])) + batch = ( + state = copy(ds.dataset[:state][:, inds]), + action = copy(ds.dataset[:action][:, inds]), + reward = copy(ds.dataset[:reward][inds]), + terminal = copy(ds.dataset[:terminal][inds]), + ) if style == SARTS batch = merge(batch, (next_state = copy(ds.dataset[:state][:, 
(1).+(inds)]),)) end - + return batch, state end @@ -141,11 +152,12 @@ length(ds::D4RLDataSet) = ds.dataset_size IteratorEltype(::Type{D4RLDataSet}) = EltypeUnknown() # see if eltype can be known (not sure about carla and adroit) -function d4rl_verify(data::Dict{String, Any}) +function d4rl_verify(data::Dict{String,Any}) for key in ["observations", "actions", "rewards", "terminals"] @assert (key in keys(data)) "Expected keys not present in data" end N_samples = size(data["observations"])[2] @assert size(data["rewards"]) == (N_samples,) || size(data["rewards"]) == (1, N_samples) - @assert size(data["terminals"]) == (N_samples,) || size(data["terminals"]) == (1, N_samples) + @assert size(data["terminals"]) == (N_samples,) || + size(data["terminals"]) == (1, N_samples) end diff --git a/src/ReinforcementLearningDatasets/src/d4rl/d4rl_pybullet/register.jl b/src/ReinforcementLearningDatasets/src/d4rl/d4rl_pybullet/register.jl index a8aba9748..7355c90eb 100644 --- a/src/ReinforcementLearningDatasets/src/d4rl/d4rl_pybullet/register.jl +++ b/src/ReinforcementLearningDatasets/src/d4rl/d4rl_pybullet/register.jl @@ -7,18 +7,18 @@ function d4rl_pybullet_dataset_params() end const D4RL_PYBULLET_URLS = Dict( - "hopper-bullet-mixed-v0" => "https://www.dropbox.com/s/xv3p0h7dzgxt8xb/hopper-bullet-mixed-v0.hdf5?dl=1", - "walker2d-bullet-random-v0" => "https://www.dropbox.com/s/1gwcfl2nmx6878m/walker2d-bullet-random-v0.hdf5?dl=1", - "hopper-bullet-medium-v0" => "https://www.dropbox.com/s/w22kgzldn6eng7j/hopper-bullet-medium-v0.hdf5?dl=1", + "hopper-bullet-mixed-v0" => "https://www.dropbox.com/s/xv3p0h7dzgxt8xb/hopper-bullet-mixed-v0.hdf5?dl=1", + "walker2d-bullet-random-v0" => "https://www.dropbox.com/s/1gwcfl2nmx6878m/walker2d-bullet-random-v0.hdf5?dl=1", + "hopper-bullet-medium-v0" => "https://www.dropbox.com/s/w22kgzldn6eng7j/hopper-bullet-medium-v0.hdf5?dl=1", "walker2d-bullet-mixed-v0" => "https://www.dropbox.com/s/i4u2ii0d85iblou/walker2d-bullet-mixed-v0.hdf5?dl=1", - "halfcheetah-bullet-mixed-v0" => "https://www.dropbox.com/s/scj1rqun963aw90/halfcheetah-bullet-mixed-v0.hdf5?dl=1", + "halfcheetah-bullet-mixed-v0" => "https://www.dropbox.com/s/scj1rqun963aw90/halfcheetah-bullet-mixed-v0.hdf5?dl=1", "halfcheetah-bullet-random-v0" => "https://www.dropbox.com/s/jnvpb1hp60zt2ak/halfcheetah-bullet-random-v0.hdf5?dl=1", - "walker2d-bullet-medium-v0" => "https://www.dropbox.com/s/v0f2kz48b1hw6or/walker2d-bullet-medium-v0.hdf5?dl=1", - "hopper-bullet-random-v0" => "https://www.dropbox.com/s/bino8ojd7iq4p4d/hopper-bullet-random-v0.hdf5?dl=1", - "ant-bullet-random-v0" => "https://www.dropbox.com/s/2xpmh4wk2m7i8xh/ant-bullet-random-v0.hdf5?dl=1", - "halfcheetah-bullet-medium-v0" => "https://www.dropbox.com/s/v4xgssp1w968a9l/halfcheetah-bullet-medium-v0.hdf5?dl=1", - "ant-bullet-medium-v0" => "https://www.dropbox.com/s/6n79kwd94xthr1t/ant-bullet-medium-v0.hdf5?dl=1", - "ant-bullet-mixed-v0" => "https://www.dropbox.com/s/pmy3dzab35g4whk/ant-bullet-mixed-v0.hdf5?dl=1" + "walker2d-bullet-medium-v0" => "https://www.dropbox.com/s/v0f2kz48b1hw6or/walker2d-bullet-medium-v0.hdf5?dl=1", + "hopper-bullet-random-v0" => "https://www.dropbox.com/s/bino8ojd7iq4p4d/hopper-bullet-random-v0.hdf5?dl=1", + "ant-bullet-random-v0" => "https://www.dropbox.com/s/2xpmh4wk2m7i8xh/ant-bullet-random-v0.hdf5?dl=1", + "halfcheetah-bullet-medium-v0" => "https://www.dropbox.com/s/v4xgssp1w968a9l/halfcheetah-bullet-medium-v0.hdf5?dl=1", + "ant-bullet-medium-v0" => "https://www.dropbox.com/s/6n79kwd94xthr1t/ant-bullet-medium-v0.hdf5?dl=1", + 
"ant-bullet-mixed-v0" => "https://www.dropbox.com/s/pmy3dzab35g4whk/ant-bullet-mixed-v0.hdf5?dl=1", ) function d4rl_pybullet_init() @@ -26,14 +26,14 @@ function d4rl_pybullet_init() for ds in keys(D4RL_PYBULLET_URLS) register( DataDep( - repo* "-" * ds, + repo * "-" * ds, """ Credits: https://github.com/takuseno/d4rl-pybullet The following dataset is fetched from the d4rl-pybullet. - """, + """, D4RL_PYBULLET_URLS[ds], - ) + ), ) end nothing -end \ No newline at end of file +end diff --git a/src/ReinforcementLearningDatasets/src/deep_ope/d4rl/d4rl_policies.jl b/src/ReinforcementLearningDatasets/src/deep_ope/d4rl/d4rl_policies.jl index eaa1d1d41..c8e7f8281 100644 --- a/src/ReinforcementLearningDatasets/src/deep_ope/d4rl/d4rl_policies.jl +++ b/src/ReinforcementLearningDatasets/src/deep_ope/d4rl/d4rl_policies.jl @@ -2,7 +2,7 @@ export d4rl_policy_params function d4rl_policy_params() d4rl_policy_paths = [split(policy["policy_path"], "/")[2] for policy in D4RL_POLICIES] - env = Set(join.(map(x->x[1:end-2], split.(d4rl_policy_paths, "_")), "_")) + env = Set(join.(map(x -> x[1:end-2], split.(d4rl_policy_paths, "_")), "_")) agent = ["dapg", "online"] epoch = 0:10 @@ -12,322 +12,234 @@ end const D4RL_POLICIES = [ Dict( "policy_path" => "antmaze_large/antmaze_large_dapg_0.pkl", - "task.task_names" => [ - "antmaze-large-play-v0", - "antmaze-large-diverse-v0" - ], + "task.task_names" => ["antmaze-large-play-v0", "antmaze-large-diverse-v0"], "agent_name" => "BC", "return_mean" => 0.0, - "return_std =>" => 0.0 + "return_std =>" => 0.0, ), Dict( "policy_path" => "antmaze_large/antmaze_large_dapg_10.pkl", - "task.task_names" => [ - "antmaze-large-play-v0", - "antmaze-large-diverse-v0" - ], + "task.task_names" => ["antmaze-large-play-v0", "antmaze-large-diverse-v0"], "agent_name" => "BC", "return_mean" => 0.48, - "return_std =>" => 0.4995998398718718 + "return_std =>" => 0.4995998398718718, ), Dict( "policy_path" => "antmaze_large/antmaze_large_dapg_1.pkl", - "task.task_names" => [ - "antmaze-large-play-v0", - "antmaze-large-diverse-v0" - ], + "task.task_names" => ["antmaze-large-play-v0", "antmaze-large-diverse-v0"], "agent_name" => "BC", "return_mean" => 0.0, - "return_std =>" => 0.0 + "return_std =>" => 0.0, ), Dict( "policy_path" => "antmaze_large/antmaze_large_dapg_2.pkl", - "task.task_names" => [ - "antmaze-large-play-v0", - "antmaze-large-diverse-v0" - ], + "task.task_names" => ["antmaze-large-play-v0", "antmaze-large-diverse-v0"], "agent_name" => "BC", "return_mean" => 0.0, - "return_std =>" => 0.0 + "return_std =>" => 0.0, ), Dict( "policy_path" => "antmaze_large/antmaze_large_dapg_3.pkl", - "task.task_names" => [ - "antmaze-large-play-v0", - "antmaze-large-diverse-v0" - ], + "task.task_names" => ["antmaze-large-play-v0", "antmaze-large-diverse-v0"], "agent_name" => "BC", "return_mean" => 0.0, - "return_std =>" => 0.0 + "return_std =>" => 0.0, ), Dict( "policy_path" => "antmaze_large/antmaze_large_dapg_4.pkl", - "task.task_names" => [ - "antmaze-large-play-v0", - "antmaze-large-diverse-v0" - ], + "task.task_names" => ["antmaze-large-play-v0", "antmaze-large-diverse-v0"], "agent_name" => "BC", "return_mean" => 0.01, - "return_std =>" => 0.09949874371066199 + "return_std =>" => 0.09949874371066199, ), Dict( "policy_path" => "antmaze_large/antmaze_large_dapg_5.pkl", - "task.task_names" => [ - "antmaze-large-play-v0", - "antmaze-large-diverse-v0" - ], + "task.task_names" => ["antmaze-large-play-v0", "antmaze-large-diverse-v0"], "agent_name" => "BC", "return_mean" => 0.13, - "return_std =>" => 
0.33630343441600474 + "return_std =>" => 0.33630343441600474, ), Dict( "policy_path" => "antmaze_large/antmaze_large_dapg_6.pkl", - "task.task_names" => [ - "antmaze-large-play-v0", - "antmaze-large-diverse-v0" - ], + "task.task_names" => ["antmaze-large-play-v0", "antmaze-large-diverse-v0"], "agent_name" => "BC", "return_mean" => 0.22, - "return_std =>" => 0.41424630354415964 + "return_std =>" => 0.41424630354415964, ), Dict( "policy_path" => "antmaze_large/antmaze_large_dapg_7.pkl", - "task.task_names" => [ - "antmaze-large-play-v0", - "antmaze-large-diverse-v0" - ], + "task.task_names" => ["antmaze-large-play-v0", "antmaze-large-diverse-v0"], "agent_name" => "BC", "return_mean" => 0.12, - "return_std =>" => 0.32496153618543844 + "return_std =>" => 0.32496153618543844, ), Dict( "policy_path" => "antmaze_large/antmaze_large_dapg_8.pkl", - "task.task_names" => [ - "antmaze-large-play-v0", - "antmaze-large-diverse-v0" - ], + "task.task_names" => ["antmaze-large-play-v0", "antmaze-large-diverse-v0"], "agent_name" => "BC", "return_mean" => 0.39, - "return_std =>" => 0.487749935930288 + "return_std =>" => 0.487749935930288, ), Dict( "policy_path" => "antmaze_large/antmaze_large_dapg_9.pkl", - "task.task_names" => [ - "antmaze-large-play-v0", - "antmaze-large-diverse-v0" - ], + "task.task_names" => ["antmaze-large-play-v0", "antmaze-large-diverse-v0"], "agent_name" => "BC", "return_mean" => 0.49, - "return_std =>" => 0.4998999899979995 + "return_std =>" => 0.4998999899979995, ), Dict( "policy_path" => "antmaze_medium/antmaze_medium_dapg_0.pkl", - "task.task_names" => [ - "antmaze-medium-play-v0", - "antmaze-medium-diverse-v0" - ], + "task.task_names" => ["antmaze-medium-play-v0", "antmaze-medium-diverse-v0"], "agent_name" => "DAPG", "return_mean" => 0.66, - "return_std =>" => 0.4737087712930805 + "return_std =>" => 0.4737087712930805, ), Dict( "policy_path" => "antmaze_medium/antmaze_medium_dapg_10.pkl", - "task.task_names" => [ - "antmaze-medium-play-v0", - "antmaze-medium-diverse-v0" - ], + "task.task_names" => ["antmaze-medium-play-v0", "antmaze-medium-diverse-v0"], "agent_name" => "DAPG", "return_mean" => 0.12, - "return_std =>" => 0.32496153618543844 + "return_std =>" => 0.32496153618543844, ), Dict( "policy_path" => "antmaze_medium/antmaze_medium_dapg_1.pkl", - "task.task_names" => [ - "antmaze-medium-play-v0", - "antmaze-medium-diverse-v0" - ], + "task.task_names" => ["antmaze-medium-play-v0", "antmaze-medium-diverse-v0"], "agent_name" => "DAPG", "return_mean" => 0.53, - "return_std =>" => 0.49909918853871116 + "return_std =>" => 0.49909918853871116, ), Dict( "policy_path" => "antmaze_medium/antmaze_medium_dapg_2.pkl", - "task.task_names" => [ - "antmaze-medium-play-v0", - "antmaze-medium-diverse-v0" - ], + "task.task_names" => ["antmaze-medium-play-v0", "antmaze-medium-diverse-v0"], "agent_name" => "DAPG", "return_mean" => 0.66, - "return_std =>" => 0.4737087712930805 + "return_std =>" => 0.4737087712930805, ), Dict( "policy_path" => "antmaze_medium/antmaze_medium_dapg_3.pkl", - "task.task_names" => [ - "antmaze-medium-play-v0", - "antmaze-medium-diverse-v0" - ], + "task.task_names" => ["antmaze-medium-play-v0", "antmaze-medium-diverse-v0"], "agent_name" => "DAPG", "return_mean" => 0.57, - "return_std =>" => 0.49507575177946245 + "return_std =>" => 0.49507575177946245, ), Dict( "policy_path" => "antmaze_medium/antmaze_medium_dapg_4.pkl", - "task.task_names" => [ - "antmaze-medium-play-v0", - "antmaze-medium-diverse-v0" - ], + "task.task_names" => ["antmaze-medium-play-v0", 
"antmaze-medium-diverse-v0"], "agent_name" => "DAPG", "return_mean" => 0.58, - "return_std =>" => 0.49355850717012273 + "return_std =>" => 0.49355850717012273, ), Dict( "policy_path" => "antmaze_medium/antmaze_medium_dapg_5.pkl", - "task.task_names" => [ - "antmaze-medium-play-v0", - "antmaze-medium-diverse-v0" - ], + "task.task_names" => ["antmaze-medium-play-v0", "antmaze-medium-diverse-v0"], "agent_name" => "DAPG", "return_mean" => 0.42, - "return_std =>" => 0.49355850717012273 + "return_std =>" => 0.49355850717012273, ), Dict( "policy_path" => "antmaze_medium/antmaze_medium_dapg_6.pkl", - "task.task_names" => [ - "antmaze-medium-play-v0", - "antmaze-medium-diverse-v0" - ], + "task.task_names" => ["antmaze-medium-play-v0", "antmaze-medium-diverse-v0"], "agent_name" => "DAPG", "return_mean" => 0.45, - "return_std =>" => 0.49749371855331004 + "return_std =>" => 0.49749371855331004, ), Dict( "policy_path" => "antmaze_medium/antmaze_medium_dapg_7.pkl", - "task.task_names" => [ - "antmaze-medium-play-v0", - "antmaze-medium-diverse-v0" - ], + "task.task_names" => ["antmaze-medium-play-v0", "antmaze-medium-diverse-v0"], "agent_name" => "DAPG", "return_mean" => 0.27, - "return_std =>" => 0.4439594576084623 + "return_std =>" => 0.4439594576084623, ), Dict( "policy_path" => "antmaze_medium/antmaze_medium_dapg_8.pkl", - "task.task_names" => [ - "antmaze-medium-play-v0", - "antmaze-medium-diverse-v0" - ], + "task.task_names" => ["antmaze-medium-play-v0", "antmaze-medium-diverse-v0"], "agent_name" => "DAPG", "return_mean" => 0.1, - "return_std =>" => 0.29999999999999993 + "return_std =>" => 0.29999999999999993, ), Dict( "policy_path" => "antmaze_medium/antmaze_medium_dapg_9.pkl", - "task.task_names" => [ - "antmaze-medium-play-v0", - "antmaze-medium-diverse-v0" - ], + "task.task_names" => ["antmaze-medium-play-v0", "antmaze-medium-diverse-v0"], "agent_name" => "DAPG", "return_mean" => 0.15, - "return_std =>" => 0.3570714214271425 + "return_std =>" => 0.3570714214271425, ), Dict( "policy_path" => "antmaze_umaze/antmaze_umaze_dapg_0.pkl", - "task.task_names" => [ - "antmaze-umaze-v0" - ], + "task.task_names" => ["antmaze-umaze-v0"], "agent_name" => "DAPG", "return_mean" => 0.11, - "return_std =>" => 0.31288975694324034 + "return_std =>" => 0.31288975694324034, ), Dict( "policy_path" => "antmaze_umaze/antmaze_umaze_dapg_10.pkl", - "task.task_names" => [ - "antmaze-umaze-v0" - ], + "task.task_names" => ["antmaze-umaze-v0"], "agent_name" => "DAPG", "return_mean" => 0.84, - "return_std =>" => 0.36660605559646725 + "return_std =>" => 0.36660605559646725, ), Dict( "policy_path" => "antmaze_umaze/antmaze_umaze_dapg_1.pkl", - "task.task_names" => [ - "antmaze-umaze-v0" - ], + "task.task_names" => ["antmaze-umaze-v0"], "agent_name" => "DAPG", "return_mean" => 0.15, - "return_std =>" => 0.3570714214271425 + "return_std =>" => 0.3570714214271425, ), Dict( "policy_path" => "antmaze_umaze/antmaze_umaze_dapg_2.pkl", - "task.task_names" => [ - "antmaze-umaze-v0" - ], + "task.task_names" => ["antmaze-umaze-v0"], "agent_name" => "DAPG", "return_mean" => 0.08, - "return_std =>" => 0.2712931993250107 + "return_std =>" => 0.2712931993250107, ), Dict( "policy_path" => "antmaze_umaze/antmaze_umaze_dapg_3.pkl", - "task.task_names" => [ - "antmaze-umaze-v0" - ], + "task.task_names" => ["antmaze-umaze-v0"], "agent_name" => "DAPG", "return_mean" => 0.13, - "return_std =>" => 0.33630343441600474 + "return_std =>" => 0.33630343441600474, ), Dict( "policy_path" => "antmaze_umaze/antmaze_umaze_dapg_4.pkl", - "task.task_names" => [ 
- "antmaze-umaze-v0" - ], + "task.task_names" => ["antmaze-umaze-v0"], "agent_name" => "DAPG", "return_mean" => 0.19, - "return_std =>" => 0.3923009049186606 + "return_std =>" => 0.3923009049186606, ), Dict( "policy_path" => "antmaze_umaze/antmaze_umaze_dapg_5.pkl", - "task.task_names" => [ - "antmaze-umaze-v0" - ], + "task.task_names" => ["antmaze-umaze-v0"], "agent_name" => "DAPG", "return_mean" => 0.27, - "return_std =>" => 0.4439594576084623 + "return_std =>" => 0.4439594576084623, ), Dict( "policy_path" => "antmaze_umaze/antmaze_umaze_dapg_6.pkl", - "task.task_names" => [ - "antmaze-umaze-v0" - ], + "task.task_names" => ["antmaze-umaze-v0"], "agent_name" => "DAPG", "return_mean" => 0.41, - "return_std =>" => 0.4918333050943175 + "return_std =>" => 0.4918333050943175, ), Dict( "policy_path" => "antmaze_umaze/antmaze_umaze_dapg_7.pkl", - "task.task_names" => [ - "antmaze-umaze-v0" - ], + "task.task_names" => ["antmaze-umaze-v0"], "agent_name" => "DAPG", "return_mean" => 0.66, - "return_std =>" => 0.4737087712930805 + "return_std =>" => 0.4737087712930805, ), Dict( "policy_path" => "antmaze_umaze/antmaze_umaze_dapg_8.pkl", - "task.task_names" => [ - "antmaze-umaze-v0" - ], + "task.task_names" => ["antmaze-umaze-v0"], "agent_name" => "DAPG", "return_mean" => 0.72, - "return_std =>" => 0.4489988864128729 + "return_std =>" => 0.4489988864128729, ), Dict( "policy_path" => "antmaze_umaze/antmaze_umaze_dapg_9.pkl", - "task.task_names" => [ - "antmaze-umaze-v0" - ], + "task.task_names" => ["antmaze-umaze-v0"], "agent_name" => "DAPG", "return_mean" => 0.45, - "return_std =>" => 0.49749371855331 + "return_std =>" => 0.49749371855331, ), Dict( "policy_path" => "ant/ant_online_0.pkl", @@ -336,11 +248,11 @@ const D4RL_POLICIES = [ "ant-random-v0", "ant-expert-v0", "ant-medium-replay-v0", - "ant-medium-expert-v0" + "ant-medium-expert-v0", ], "agent_name" => "SAC", "return_mean" => -61.02055183305979, - "return_std =>" => 118.86259895376526 + "return_std =>" => 118.86259895376526, ), Dict( "policy_path" => "ant/ant_online_10.pkl", @@ -349,11 +261,11 @@ const D4RL_POLICIES = [ "ant-random-v0", "ant-expert-v0", "ant-medium-replay-v0", - "ant-medium-expert-v0" + "ant-medium-expert-v0", ], "agent_name" => "SAC", "return_mean" => 5226.071929273204, - "return_std =>" => 1351.489114884685 + "return_std =>" => 1351.489114884685, ), Dict( "policy_path" => "ant/ant_online_1.pkl", @@ -362,11 +274,11 @@ const D4RL_POLICIES = [ "ant-random-v0", "ant-expert-v0", "ant-medium-replay-v0", - "ant-medium-expert-v0" + "ant-medium-expert-v0", ], "agent_name" => "SAC", "return_mean" => 1128.957315814236, - "return_std =>" => 545.9910621405912 + "return_std =>" => 545.9910621405912, ), Dict( "policy_path" => "ant/ant_online_2.pkl", @@ -375,11 +287,11 @@ const D4RL_POLICIES = [ "ant-random-v0", "ant-expert-v0", "ant-medium-replay-v0", - "ant-medium-expert-v0" + "ant-medium-expert-v0", ], "agent_name" => "SAC", "return_mean" => 1874.9426222623788, - "return_std =>" => 821.523301172575 + "return_std =>" => 821.523301172575, ), Dict( "policy_path" => "ant/ant_online_3.pkl", @@ -388,11 +300,11 @@ const D4RL_POLICIES = [ "ant-random-v0", "ant-expert-v0", "ant-medium-replay-v0", - "ant-medium-expert-v0" + "ant-medium-expert-v0", ], "agent_name" => "SAC", "return_mean" => 2694.0050365558186, - "return_std =>" => 829.1251729756312 + "return_std =>" => 829.1251729756312, ), Dict( "policy_path" => "ant/ant_online_4.pkl", @@ -401,11 +313,11 @@ const D4RL_POLICIES = [ "ant-random-v0", "ant-expert-v0", "ant-medium-replay-v0", - 
"ant-medium-expert-v0" + "ant-medium-expert-v0", ], "agent_name" => "SAC", "return_mean" => 2927.728155987557, - "return_std =>" => 1218.962159178784 + "return_std =>" => 1218.962159178784, ), Dict( "policy_path" => "ant/ant_online_5.pkl", @@ -414,11 +326,11 @@ const D4RL_POLICIES = [ "ant-random-v0", "ant-expert-v0", "ant-medium-replay-v0", - "ant-medium-expert-v0" + "ant-medium-expert-v0", ], "agent_name" => "SAC", "return_mean" => -271.0455967662947, - "return_std =>" => 181.7343490946006 + "return_std =>" => 181.7343490946006, ), Dict( "policy_path" => "ant/ant_online_6.pkl", @@ -427,11 +339,11 @@ const D4RL_POLICIES = [ "ant-random-v0", "ant-expert-v0", "ant-medium-replay-v0", - "ant-medium-expert-v0" + "ant-medium-expert-v0", ], "agent_name" => "SAC", "return_mean" => 3923.0820284011284, - "return_std =>" => 1384.459574872169 + "return_std =>" => 1384.459574872169, ), Dict( "policy_path" => "ant/ant_online_7.pkl", @@ -440,11 +352,11 @@ const D4RL_POLICIES = [ "ant-random-v0", "ant-expert-v0", "ant-medium-replay-v0", - "ant-medium-expert-v0" + "ant-medium-expert-v0", ], "agent_name" => "SAC", "return_mean" => 4564.024787293475, - "return_std =>" => 1207.181426135141 + "return_std =>" => 1207.181426135141, ), Dict( "policy_path" => "ant/ant_online_8.pkl", @@ -453,11 +365,11 @@ const D4RL_POLICIES = [ "ant-random-v0", "ant-expert-v0", "ant-medium-replay-v0", - "ant-medium-expert-v0" + "ant-medium-expert-v0", ], "agent_name" => "SAC", "return_mean" => 5116.58562094113, - "return_std =>" => 962.8694737383373 + "return_std =>" => 962.8694737383373, ), Dict( "policy_path" => "ant/ant_online_9.pkl", @@ -466,132 +378,88 @@ const D4RL_POLICIES = [ "ant-random-v0", "ant-expert-v0", "ant-medium-replay-v0", - "ant-medium-expert-v0" + "ant-medium-expert-v0", ], "agent_name" => "SAC", "return_mean" => 5176.548960934259, - "return_std =>" => 1000.122269767824 + "return_std =>" => 1000.122269767824, ), Dict( "policy_path" => "door/door_dapg_0.pkl", - "task.task_names" => [ - "door-cloned-v0", - "door-expert-v0", - "door-human-v0" - ], + "task.task_names" => ["door-cloned-v0", "door-expert-v0", "door-human-v0"], "agent_name" => "DAPG", "return_mean" => -53.63337645679012, - "return_std =>" => 2.0058239428094895 + "return_std =>" => 2.0058239428094895, ), Dict( "policy_path" => "door/door_dapg_10.pkl", - "task.task_names" => [ - "door-cloned-v0", - "door-expert-v0", - "door-human-v0" - ], + "task.task_names" => ["door-cloned-v0", "door-expert-v0", "door-human-v0"], "agent_name" => "DAPG", "return_mean" => 2974.9306587121887, - "return_std =>" => 52.48250668645121 + "return_std =>" => 52.48250668645121, ), Dict( "policy_path" => "door/door_dapg_1.pkl", - "task.task_names" => [ - "door-cloned-v0", - "door-expert-v0", - "door-human-v0" - ], + "task.task_names" => ["door-cloned-v0", "door-expert-v0", "door-human-v0"], "agent_name" => "DAPG", "return_mean" => -51.41658735064874, - "return_std =>" => 0.6978335854285623 + "return_std =>" => 0.6978335854285623, ), Dict( "policy_path" => "door/door_dapg_2.pkl", - "task.task_names" => [ - "door-cloned-v0", - "door-expert-v0", - "door-human-v0" - ], + "task.task_names" => ["door-cloned-v0", "door-expert-v0", "door-human-v0"], "agent_name" => "DAPG", "return_mean" => 86.28632719532406, - "return_std =>" => 256.30747202806475 + "return_std =>" => 256.30747202806475, ), Dict( "policy_path" => "door/door_dapg_3.pkl", - "task.task_names" => [ - "door-cloned-v0", - "door-expert-v0", - "door-human-v0" - ], + "task.task_names" => ["door-cloned-v0", "door-expert-v0", 
"door-human-v0"], "agent_name" => "DAPG", "return_mean" => 1282.0275007615646, - "return_std =>" => 633.9669441391286 + "return_std =>" => 633.9669441391286, ), Dict( "policy_path" => "door/door_dapg_4.pkl", - "task.task_names" => [ - "door-cloned-v0", - "door-expert-v0", - "door-human-v0" - ], + "task.task_names" => ["door-cloned-v0", "door-expert-v0", "door-human-v0"], "agent_name" => "DAPG", "return_mean" => 1607.4255566289276, - "return_std =>" => 499.58651630841575 + "return_std =>" => 499.58651630841575, ), Dict( "policy_path" => "door/door_dapg_5.pkl", - "task.task_names" => [ - "door-cloned-v0", - "door-expert-v0", - "door-human-v0" - ], + "task.task_names" => ["door-cloned-v0", "door-expert-v0", "door-human-v0"], "agent_name" => "DAPG", "return_mean" => 2142.36638691816, - "return_std =>" => 442.0537003890031 + "return_std =>" => 442.0537003890031, ), Dict( "policy_path" => "door/door_dapg_6.pkl", - "task.task_names" => [ - "door-cloned-v0", - "door-expert-v0", - "door-human-v0" - ], + "task.task_names" => ["door-cloned-v0", "door-expert-v0", "door-human-v0"], "agent_name" => "DAPG", "return_mean" => 2525.495218483574, - "return_std =>" => 160.8683834534215 + "return_std =>" => 160.8683834534215, ), Dict( "policy_path" => "door/door_dapg_7.pkl", - "task.task_names" => [ - "door-cloned-v0", - "door-expert-v0", - "door-human-v0" - ], + "task.task_names" => ["door-cloned-v0", "door-expert-v0", "door-human-v0"], "agent_name" => "DAPG", "return_mean" => 2794.653907232321, - "return_std =>" => 62.78226619278986 + "return_std =>" => 62.78226619278986, ), Dict( "policy_path" => "door/door_dapg_8.pkl", - "task.task_names" => [ - "door-cloned-v0", - "door-expert-v0", - "door-human-v0" - ], + "task.task_names" => ["door-cloned-v0", "door-expert-v0", "door-human-v0"], "agent_name" => "DAPG", "return_mean" => 2870.85173247603, - "return_std =>" => 37.96052715176604 + "return_std =>" => 37.96052715176604, ), Dict( "policy_path" => "door/door_dapg_9.pkl", - "task.task_names" => [ - "door-cloned-v0", - "door-expert-v0", - "door-human-v0" - ], + "task.task_names" => ["door-cloned-v0", "door-expert-v0", "door-human-v0"], "agent_name" => "DAPG", "return_mean" => 2959.4718836123457, - "return_std =>" => 53.31391818495784 + "return_std =>" => 53.31391818495784, ), Dict( "policy_path" => "halfcheetah/halfcheetah_online_0.pkl", @@ -600,11 +468,11 @@ const D4RL_POLICIES = [ "halfcheetah-random-v0", "halfcheetah-expert-v0", "halfcheetah-medium-replay-v0", - "halfcheetah-medium-expert-v0" + "halfcheetah-medium-expert-v0", ], "agent_name" => "SAC", "return_mean" => -309.2417932614121, - "return_std =>" => 91.3640277992432 + "return_std =>" => 91.3640277992432, ), Dict( "policy_path" => "halfcheetah/halfcheetah_online_10.pkl", @@ -613,11 +481,11 @@ const D4RL_POLICIES = [ "halfcheetah-random-v0", "halfcheetah-expert-v0", "halfcheetah-medium-replay-v0", - "halfcheetah-medium-expert-v0" + "halfcheetah-medium-expert-v0", ], "agent_name" => "SAC", "return_mean" => 12695.696030461002, - "return_std =>" => 209.98612023443096 + "return_std =>" => 209.98612023443096, ), Dict( "policy_path" => "halfcheetah/halfcheetah_online_1.pkl", @@ -626,11 +494,11 @@ const D4RL_POLICIES = [ "halfcheetah-random-v0", "halfcheetah-expert-v0", "halfcheetah-medium-replay-v0", - "halfcheetah-medium-expert-v0" + "halfcheetah-medium-expert-v0", ], "agent_name" => "SAC", "return_mean" => 5686.148033603298, - "return_std =>" => 77.60317050580818 + "return_std =>" => 77.60317050580818, ), Dict( "policy_path" => 
"halfcheetah/halfcheetah_online_2.pkl", @@ -639,11 +507,11 @@ const D4RL_POLICIES = [ "halfcheetah-random-v0", "halfcheetah-expert-v0", "halfcheetah-medium-replay-v0", - "halfcheetah-medium-expert-v0" + "halfcheetah-medium-expert-v0", ], "agent_name" => "SAC", "return_mean" => 6898.252473142946, - "return_std =>" => 131.2808199171071 + "return_std =>" => 131.2808199171071, ), Dict( "policy_path" => "halfcheetah/halfcheetah_online_3.pkl", @@ -652,11 +520,11 @@ const D4RL_POLICIES = [ "halfcheetah-random-v0", "halfcheetah-expert-v0", "halfcheetah-medium-replay-v0", - "halfcheetah-medium-expert-v0" + "halfcheetah-medium-expert-v0", ], "agent_name" => "SAC", "return_mean" => 7843.345957832609, - "return_std =>" => 119.82879594969056 + "return_std =>" => 119.82879594969056, ), Dict( "policy_path" => "halfcheetah/halfcheetah_online_4.pkl", @@ -665,11 +533,11 @@ const D4RL_POLICIES = [ "halfcheetah-random-v0", "halfcheetah-expert-v0", "halfcheetah-medium-replay-v0", - "halfcheetah-medium-expert-v0" + "halfcheetah-medium-expert-v0", ], "agent_name" => "SAC", "return_mean" => 8661.367146815282, - "return_std =>" => 142.1433195543218 + "return_std =>" => 142.1433195543218, ), Dict( "policy_path" => "halfcheetah/halfcheetah_online_5.pkl", @@ -678,11 +546,11 @@ const D4RL_POLICIES = [ "halfcheetah-random-v0", "halfcheetah-expert-v0", "halfcheetah-medium-replay-v0", - "halfcheetah-medium-expert-v0" + "halfcheetah-medium-expert-v0", ], "agent_name" => "SAC", "return_mean" => 9197.889639800613, - "return_std =>" => 125.40543058761767 + "return_std =>" => 125.40543058761767, ), Dict( "policy_path" => "halfcheetah/halfcheetah_online_6.pkl", @@ -691,11 +559,11 @@ const D4RL_POLICIES = [ "halfcheetah-random-v0", "halfcheetah-expert-v0", "halfcheetah-medium-replay-v0", - "halfcheetah-medium-expert-v0" + "halfcheetah-medium-expert-v0", ], "agent_name" => "SAC", "return_mean" => 9623.789519132608, - "return_std =>" => 130.91946985245835 + "return_std =>" => 130.91946985245835, ), Dict( "policy_path" => "halfcheetah/halfcheetah_online_7.pkl", @@ -704,11 +572,11 @@ const D4RL_POLICIES = [ "halfcheetah-random-v0", "halfcheetah-expert-v0", "halfcheetah-medium-replay-v0", - "halfcheetah-medium-expert-v0" + "halfcheetah-medium-expert-v0", ], "agent_name" => "SAC", "return_mean" => 10255.26711299773, - "return_std =>" => 173.52116806555978 + "return_std =>" => 173.52116806555978, ), Dict( "policy_path" => "halfcheetah/halfcheetah_online_8.pkl", @@ -717,11 +585,11 @@ const D4RL_POLICIES = [ "halfcheetah-random-v0", "halfcheetah-expert-v0", "halfcheetah-medium-replay-v0", - "halfcheetah-medium-expert-v0" + "halfcheetah-medium-expert-v0", ], "agent_name" => "SAC", "return_mean" => 10899.460856799158, - "return_std =>" => 324.2557642475202 + "return_std =>" => 324.2557642475202, ), Dict( "policy_path" => "halfcheetah/halfcheetah_online_9.pkl", @@ -730,132 +598,99 @@ const D4RL_POLICIES = [ "halfcheetah-random-v0", "halfcheetah-expert-v0", "halfcheetah-medium-replay-v0", - "halfcheetah-medium-expert-v0" + "halfcheetah-medium-expert-v0", ], "agent_name" => "SAC", "return_mean" => 11829.054827593913, - "return_std =>" => 240.63510160394745 + "return_std =>" => 240.63510160394745, ), Dict( "policy_path" => "hammer/hammer_dapg_0.pkl", - "task.task_names" => [ - "hammer-cloned-v0", - "hammer-expert-v0", - "hammer-human-v0" - ], + "task.task_names" => + ["hammer-cloned-v0", "hammer-expert-v0", "hammer-human-v0"], "agent_name" => "DAPG", "return_mean" => -236.37114898868305, - "return_std =>" => 5.2941436284324075 + "return_std 
=>" => 5.2941436284324075, ), Dict( "policy_path" => "hammer/hammer_dapg_10.pkl", - "task.task_names" => [ - "hammer-cloned-v0", - "hammer-expert-v0", - "hammer-human-v0" - ], + "task.task_names" => + ["hammer-cloned-v0", "hammer-expert-v0", "hammer-human-v0"], "agent_name" => "DAPG", "return_mean" => 17585.58837262877, - "return_std =>" => 96.53489547795978 + "return_std =>" => 96.53489547795978, ), Dict( "policy_path" => "hammer/hammer_dapg_1.pkl", - "task.task_names" => [ - "hammer-cloned-v0", - "hammer-expert-v0", - "hammer-human-v0" - ], + "task.task_names" => + ["hammer-cloned-v0", "hammer-expert-v0", "hammer-human-v0"], "agent_name" => "DAPG", "return_mean" => 128.60395654435058, - "return_std =>" => 30.68441678661929 + "return_std =>" => 30.68441678661929, ), Dict( "policy_path" => "hammer/hammer_dapg_2.pkl", - "task.task_names" => [ - "hammer-cloned-v0", - "hammer-expert-v0", - "hammer-human-v0" - ], + "task.task_names" => + ["hammer-cloned-v0", "hammer-expert-v0", "hammer-human-v0"], "agent_name" => "DAPG", "return_mean" => 7408.354956936379, - "return_std =>" => 7294.096332941535 + "return_std =>" => 7294.096332941535, ), Dict( "policy_path" => "hammer/hammer_dapg_3.pkl", - "task.task_names" => [ - "hammer-cloned-v0", - "hammer-expert-v0", - "hammer-human-v0" - ], + "task.task_names" => + ["hammer-cloned-v0", "hammer-expert-v0", "hammer-human-v0"], "agent_name" => "DAPG", "return_mean" => 15594.112899701715, - "return_std =>" => 197.28904701529942 + "return_std =>" => 197.28904701529942, ), Dict( "policy_path" => "hammer/hammer_dapg_4.pkl", - "task.task_names" => [ - "hammer-cloned-v0", - "hammer-expert-v0", - "hammer-human-v0" - ], + "task.task_names" => + ["hammer-cloned-v0", "hammer-expert-v0", "hammer-human-v0"], "agent_name" => "DAPG", "return_mean" => 16245.548923178216, - "return_std =>" => 262.7060238728634 + "return_std =>" => 262.7060238728634, ), Dict( "policy_path" => "hammer/hammer_dapg_5.pkl", - "task.task_names" => [ - "hammer-cloned-v0", - "hammer-expert-v0", - "hammer-human-v0" - ], + "task.task_names" => + ["hammer-cloned-v0", "hammer-expert-v0", "hammer-human-v0"], "agent_name" => "DAPG", "return_mean" => 16595.136728219404, - "return_std =>" => 124.5270089215883 + "return_std =>" => 124.5270089215883, ), Dict( "policy_path" => "hammer/hammer_dapg_6.pkl", - "task.task_names" => [ - "hammer-cloned-v0", - "hammer-expert-v0", - "hammer-human-v0" - ], + "task.task_names" => + ["hammer-cloned-v0", "hammer-expert-v0", "hammer-human-v0"], "agent_name" => "DAPG", "return_mean" => 17065.590900836418, - "return_std =>" => 55.85140116556182 + "return_std =>" => 55.85140116556182, ), Dict( "policy_path" => "hammer/hammer_dapg_7.pkl", - "task.task_names" => [ - "hammer-cloned-v0", - "hammer-expert-v0", - "hammer-human-v0" - ], + "task.task_names" => + ["hammer-cloned-v0", "hammer-expert-v0", "hammer-human-v0"], "agent_name" => "DAPG", "return_mean" => 17209.380445590097, - "return_std =>" => 35.922080086069116 + "return_std =>" => 35.922080086069116, ), Dict( "policy_path" => "hammer/hammer_dapg_8.pkl", - "task.task_names" => [ - "hammer-cloned-v0", - "hammer-expert-v0", - "hammer-human-v0" - ], + "task.task_names" => + ["hammer-cloned-v0", "hammer-expert-v0", "hammer-human-v0"], "agent_name" => "DAPG", "return_mean" => 17388.10343669515, - "return_std =>" => 71.04818789434533 + "return_std =>" => 71.04818789434533, ), Dict( "policy_path" => "hammer/hammer_dapg_9.pkl", - "task.task_names" => [ - "hammer-cloned-v0", - "hammer-expert-v0", - "hammer-human-v0" - ], + 
"task.task_names" => + ["hammer-cloned-v0", "hammer-expert-v0", "hammer-human-v0"], "agent_name" => "DAPG", "return_mean" => 17565.807571496796, - "return_std =>" => 83.22119300427666 + "return_std =>" => 83.22119300427666, ), Dict( "policy_path" => "hopper/hopper_online_0.pkl", @@ -864,11 +699,11 @@ const D4RL_POLICIES = [ "hopper-random-v0", "hopper-expert-v0", "hopper-medium-replay-v0", - "hopper-medium-expert-v0" + "hopper-medium-expert-v0", ], "agent_name" => "SAC", "return_mean" => 89.08207455972816, - "return_std =>" => 45.69740377810402 + "return_std =>" => 45.69740377810402, ), Dict( "policy_path" => "hopper/hopper_online_10.pkl", @@ -877,11 +712,11 @@ const D4RL_POLICIES = [ "hopper-random-v0", "hopper-expert-v0", "hopper-medium-replay-v0", - "hopper-medium-expert-v0" + "hopper-medium-expert-v0", ], "agent_name" => "SAC", "return_mean" => 1290.7677147248753, - "return_std =>" => 86.34701290680572 + "return_std =>" => 86.34701290680572, ), Dict( "policy_path" => "hopper/hopper_online_1.pkl", @@ -890,11 +725,11 @@ const D4RL_POLICIES = [ "hopper-random-v0", "hopper-expert-v0", "hopper-medium-replay-v0", - "hopper-medium-expert-v0" + "hopper-medium-expert-v0", ], "agent_name" => "SAC", "return_mean" => 1134.244611055915, - "return_std =>" => 407.6547443287992 + "return_std =>" => 407.6547443287992, ), Dict( "policy_path" => "hopper/hopper_online_2.pkl", @@ -903,11 +738,11 @@ const D4RL_POLICIES = [ "hopper-random-v0", "hopper-expert-v0", "hopper-medium-replay-v0", - "hopper-medium-expert-v0" + "hopper-medium-expert-v0", ], "agent_name" => "SAC", "return_mean" => 727.0768143435397, - "return_std =>" => 92.94955320157855 + "return_std =>" => 92.94955320157855, ), Dict( "policy_path" => "hopper/hopper_online_3.pkl", @@ -916,11 +751,11 @@ const D4RL_POLICIES = [ "hopper-random-v0", "hopper-expert-v0", "hopper-medium-replay-v0", - "hopper-medium-expert-v0" + "hopper-medium-expert-v0", ], "agent_name" => "SAC", "return_mean" => 1571.2810005160163, - "return_std =>" => 447.3216244940128 + "return_std =>" => 447.3216244940128, ), Dict( "policy_path" => "hopper/hopper_online_4.pkl", @@ -929,11 +764,11 @@ const D4RL_POLICIES = [ "hopper-random-v0", "hopper-expert-v0", "hopper-medium-replay-v0", - "hopper-medium-expert-v0" + "hopper-medium-expert-v0", ], "agent_name" => "SAC", "return_mean" => 1140.2394986005213, - "return_std =>" => 671.1379607505328 + "return_std =>" => 671.1379607505328, ), Dict( "policy_path" => "hopper/hopper_online_5.pkl", @@ -942,11 +777,11 @@ const D4RL_POLICIES = [ "hopper-random-v0", "hopper-expert-v0", "hopper-medium-replay-v0", - "hopper-medium-expert-v0" + "hopper-medium-expert-v0", ], "agent_name" => "SAC", "return_mean" => 1872.571834592923, - "return_std =>" => 793.8865779126361 + "return_std =>" => 793.8865779126361, ), Dict( "policy_path" => "hopper/hopper_online_6.pkl", @@ -955,11 +790,11 @@ const D4RL_POLICIES = [ "hopper-random-v0", "hopper-expert-v0", "hopper-medium-replay-v0", - "hopper-medium-expert-v0" + "hopper-medium-expert-v0", ], "agent_name" => "SAC", "return_mean" => 3088.2017624993064, - "return_std =>" => 356.52713477862386 + "return_std =>" => 356.52713477862386, ), Dict( "policy_path" => "hopper/hopper_online_7.pkl", @@ -968,11 +803,11 @@ const D4RL_POLICIES = [ "hopper-random-v0", "hopper-expert-v0", "hopper-medium-replay-v0", - "hopper-medium-expert-v0" + "hopper-medium-expert-v0", ], "agent_name" => "SAC", "return_mean" => 1726.0060438089222, - "return_std =>" => 761.6326666292086 + "return_std =>" => 761.6326666292086, ), Dict( 
"policy_path" => "hopper/hopper_online_8.pkl", @@ -981,11 +816,11 @@ const D4RL_POLICIES = [ "hopper-random-v0", "hopper-expert-v0", "hopper-medium-replay-v0", - "hopper-medium-expert-v0" + "hopper-medium-expert-v0", ], "agent_name" => "SAC", "return_mean" => 2952.957468938808, - "return_std =>" => 682.5831907733249 + "return_std =>" => 682.5831907733249, ), Dict( "policy_path" => "hopper/hopper_online_9.pkl", @@ -994,550 +829,407 @@ const D4RL_POLICIES = [ "hopper-random-v0", "hopper-expert-v0", "hopper-medium-replay-v0", - "hopper-medium-expert-v0" + "hopper-medium-expert-v0", ], "agent_name" => "SAC", "return_mean" => 2369.7998719150673, - "return_std =>" => 1119.4914225331481 + "return_std =>" => 1119.4914225331481, ), Dict( "policy_path" => "maze2d_large/maze2d_large_dapg_0.pkl", - "task.task_names" => [ - "maze2d-large-v1" - ], + "task.task_names" => ["maze2d-large-v1"], "agent_name" => "DAPG", "return_mean" => 2.21, - "return_std =>" => 8.873888662812938 + "return_std =>" => 8.873888662812938, ), Dict( "policy_path" => "maze2d_large/maze2d_large_dapg_10.pkl", - "task.task_names" => [ - "maze2d-large-v1" - ], + "task.task_names" => ["maze2d-large-v1"], "agent_name" => "DAPG", "return_mean" => 627.86, - "return_std =>" => 161.0254650668645 + "return_std =>" => 161.0254650668645, ), Dict( "policy_path" => "maze2d_large/maze2d_large_dapg_1.pkl", - "task.task_names" => [ - "maze2d-large-v1" - ], + "task.task_names" => ["maze2d-large-v1"], "agent_name" => "DAPG", "return_mean" => 41.74, - "return_std =>" => 72.2068722491149 + "return_std =>" => 72.2068722491149, ), Dict( "policy_path" => "maze2d_large/maze2d_large_dapg_2.pkl", - "task.task_names" => [ - "maze2d-large-v1" - ], + "task.task_names" => ["maze2d-large-v1"], "agent_name" => "DAPG", "return_mean" => 124.9, - "return_std =>" => 131.5638628195448 + "return_std =>" => 131.5638628195448, ), Dict( "policy_path" => "maze2d_large/maze2d_large_dapg_3.pkl", - "task.task_names" => [ - "maze2d-large-v1" - ], + "task.task_names" => ["maze2d-large-v1"], "agent_name" => "DAPG", "return_mean" => 107.78, - "return_std =>" => 109.32251186283638 + "return_std =>" => 109.32251186283638, ), Dict( "policy_path" => "maze2d_large/maze2d_large_dapg_4.pkl", - "task.task_names" => [ - "maze2d-large-v1" - ], + "task.task_names" => ["maze2d-large-v1"], "agent_name" => "DAPG", "return_mean" => 289.46, - "return_std =>" => 262.69070862898826 + "return_std =>" => 262.69070862898826, ), Dict( "policy_path" => "maze2d_large/maze2d_large_dapg_5.pkl", - "task.task_names" => [ - "maze2d-large-v1" - ], + "task.task_names" => ["maze2d-large-v1"], "agent_name" => "DAPG", "return_mean" => 356.17, - "return_std =>" => 276.9112151936068 + "return_std =>" => 276.9112151936068, ), Dict( "policy_path" => "maze2d_large/maze2d_large_dapg_6.pkl", - "task.task_names" => [ - "maze2d-large-v1" - ], + "task.task_names" => ["maze2d-large-v1"], "agent_name" => "DAPG", "return_mean" => 393.87, - "return_std =>" => 309.08651394067647 + "return_std =>" => 309.08651394067647, ), Dict( "policy_path" => "maze2d_large/maze2d_large_dapg_7.pkl", - "task.task_names" => [ - "maze2d-large-v1" - ], + "task.task_names" => ["maze2d-large-v1"], "agent_name" => "DAPG", "return_mean" => 517.4, - "return_std =>" => 274.58688970888613 + "return_std =>" => 274.58688970888613, ), Dict( "policy_path" => "maze2d_large/maze2d_large_dapg_8.pkl", - "task.task_names" => [ - "maze2d-large-v1" - ], + "task.task_names" => ["maze2d-large-v1"], "agent_name" => "DAPG", "return_mean" => 565.42, - "return_std =>" => 
210.94450360225082 + "return_std =>" => 210.94450360225082, ), Dict( "policy_path" => "maze2d_large/maze2d_large_dapg_9.pkl", - "task.task_names" => [ - "maze2d-large-v1" - ], + "task.task_names" => ["maze2d-large-v1"], "agent_name" => "DAPG", "return_mean" => 629.22, - "return_std =>" => 123.23023817229276 + "return_std =>" => 123.23023817229276, ), Dict( "policy_path" => "maze2d_medium/maze2d_medium_dapg_0.pkl", - "task.task_names" => [ - "maze2d-medium-v1" - ], + "task.task_names" => ["maze2d-medium-v1"], "agent_name" => "DAPG", "return_mean" => 83.15, - "return_std =>" => 177.59827561099797 + "return_std =>" => 177.59827561099797, ), Dict( "policy_path" => "maze2d_medium/maze2d_medium_dapg_10.pkl", - "task.task_names" => [ - "maze2d-medium-v1" - ], + "task.task_names" => ["maze2d-medium-v1"], "agent_name" => "DAPG", "return_mean" => 442.35, - "return_std =>" => 161.2205554512203 + "return_std =>" => 161.2205554512203, ), Dict( "policy_path" => "maze2d_medium/maze2d_medium_dapg_1.pkl", - "task.task_names" => [ - "maze2d-medium-v1" - ], + "task.task_names" => ["maze2d-medium-v1"], "agent_name" => "DAPG", "return_mean" => 177.8, - "return_std =>" => 218.1089635938881 + "return_std =>" => 218.1089635938881, ), Dict( "policy_path" => "maze2d_medium/maze2d_medium_dapg_2.pkl", - "task.task_names" => [ - "maze2d-medium-v1" - ], + "task.task_names" => ["maze2d-medium-v1"], "agent_name" => "DAPG", "return_mean" => 249.33, - "return_std =>" => 237.2338110388146 + "return_std =>" => 237.2338110388146, ), Dict( "policy_path" => "maze2d_medium/maze2d_medium_dapg_3.pkl", - "task.task_names" => [ - "maze2d-medium-v1" - ], + "task.task_names" => ["maze2d-medium-v1"], "agent_name" => "DAPG", "return_mean" => 214.81, - "return_std =>" => 246.09809812349224 + "return_std =>" => 246.09809812349224, ), Dict( "policy_path" => "maze2d_medium/maze2d_medium_dapg_4.pkl", - "task.task_names" => [ - "maze2d-medium-v1" - ], + "task.task_names" => ["maze2d-medium-v1"], "agent_name" => "DAPG", "return_mean" => 254.63, - "return_std =>" => 262.0181541420365 + "return_std =>" => 262.0181541420365, ), Dict( "policy_path" => "maze2d_medium/maze2d_medium_dapg_5.pkl", - "task.task_names" => [ - "maze2d-medium-v1" - ], + "task.task_names" => ["maze2d-medium-v1"], "agent_name" => "DAPG", "return_mean" => 238.76, - "return_std =>" => 260.3596404975241 + "return_std =>" => 260.3596404975241, ), Dict( "policy_path" => "maze2d_medium/maze2d_medium_dapg_6.pkl", - "task.task_names" => [ - "maze2d-medium-v1" - ], + "task.task_names" => ["maze2d-medium-v1"], "agent_name" => "DAPG", "return_mean" => 374.9, - "return_std =>" => 222.18107480161314 + "return_std =>" => 222.18107480161314, ), Dict( "policy_path" => "maze2d_medium/maze2d_medium_dapg_7.pkl", - "task.task_names" => [ - "maze2d-medium-v1" - ], + "task.task_names" => ["maze2d-medium-v1"], "agent_name" => "DAPG", "return_mean" => 379.68, - "return_std =>" => 228.59111443798514 + "return_std =>" => 228.59111443798514, ), Dict( "policy_path" => "maze2d_medium/maze2d_medium_dapg_8.pkl", - "task.task_names" => [ - "maze2d-medium-v1" - ], + "task.task_names" => ["maze2d-medium-v1"], "agent_name" => "DAPG", "return_mean" => 392.9, - "return_std =>" => 217.99805044999832 + "return_std =>" => 217.99805044999832, ), Dict( "policy_path" => "maze2d_medium/maze2d_medium_dapg_9.pkl", - "task.task_names" => [ - "maze2d-medium-v1" - ], + "task.task_names" => ["maze2d-medium-v1"], "agent_name" => "DAPG", "return_mean" => 432.03, - "return_std =>" => 173.93714123211294 + "return_std =>" => 
173.93714123211294, ), Dict( "policy_path" => "maze2d_umaze/maze2d_umaze_dapg_0.pkl", - "task.task_names" => [ - "maze2d-umaze-v1" - ], + "task.task_names" => ["maze2d-umaze-v1"], "agent_name" => "DAPG", "return_mean" => 22.19, - "return_std =>" => 25.18320670605711 + "return_std =>" => 25.18320670605711, ), Dict( "policy_path" => "maze2d_umaze/maze2d_umaze_dapg_10.pkl", - "task.task_names" => [ - "maze2d-umaze-v1" - ], + "task.task_names" => ["maze2d-umaze-v1"], "agent_name" => "DAPG", "return_mean" => 250.64, - "return_std =>" => 36.357810715168206 + "return_std =>" => 36.357810715168206, ), Dict( "policy_path" => "maze2d_umaze/maze2d_umaze_dapg_1.pkl", - "task.task_names" => [ - "maze2d-umaze-v1" - ], + "task.task_names" => ["maze2d-umaze-v1"], "agent_name" => "DAPG", "return_mean" => 43.33, - "return_std =>" => 66.01621846182951 + "return_std =>" => 66.01621846182951, ), Dict( "policy_path" => "maze2d_umaze/maze2d_umaze_dapg_2.pkl", - "task.task_names" => [ - "maze2d-umaze-v1" - ], + "task.task_names" => ["maze2d-umaze-v1"], "agent_name" => "DAPG", "return_mean" => 100.97, - "return_std =>" => 95.598060126762 + "return_std =>" => 95.598060126762, ), Dict( "policy_path" => "maze2d_umaze/maze2d_umaze_dapg_3.pkl", - "task.task_names" => [ - "maze2d-umaze-v1" - ], + "task.task_names" => ["maze2d-umaze-v1"], "agent_name" => "DAPG", "return_mean" => 115.26, - "return_std =>" => 120.07919220247945 + "return_std =>" => 120.07919220247945, ), Dict( "policy_path" => "maze2d_umaze/maze2d_umaze_dapg_4.pkl", - "task.task_names" => [ - "maze2d-umaze-v1" - ], + "task.task_names" => ["maze2d-umaze-v1"], "agent_name" => "DAPG", "return_mean" => 106.56, - "return_std =>" => 123.82562901112192 + "return_std =>" => 123.82562901112192, ), Dict( "policy_path" => "maze2d_umaze/maze2d_umaze_dapg_5.pkl", - "task.task_names" => [ - "maze2d-umaze-v1" - ], + "task.task_names" => ["maze2d-umaze-v1"], "agent_name" => "DAPG", "return_mean" => 142.5, - "return_std =>" => 111.55568116416124 + "return_std =>" => 111.55568116416124, ), Dict( "policy_path" => "maze2d_umaze/maze2d_umaze_dapg_6.pkl", - "task.task_names" => [ - "maze2d-umaze-v1" - ], + "task.task_names" => ["maze2d-umaze-v1"], "agent_name" => "DAPG", "return_mean" => 172.13, - "return_std =>" => 118.24048841238772 + "return_std =>" => 118.24048841238772, ), Dict( "policy_path" => "maze2d_umaze/maze2d_umaze_dapg_7.pkl", - "task.task_names" => [ - "maze2d-umaze-v1" - ], + "task.task_names" => ["maze2d-umaze-v1"], "agent_name" => "DAPG", "return_mean" => 190.98, - "return_std =>" => 73.81706848690214 + "return_std =>" => 73.81706848690214, ), Dict( "policy_path" => "maze2d_umaze/maze2d_umaze_dapg_8.pkl", - "task.task_names" => [ - "maze2d-umaze-v1" - ], + "task.task_names" => ["maze2d-umaze-v1"], "agent_name" => "DAPG", "return_mean" => 228.17, - "return_std =>" => 39.635856241539685 + "return_std =>" => 39.635856241539685, ), Dict( "policy_path" => "maze2d_umaze/maze2d_umaze_dapg_9.pkl", - "task.task_names" => [ - "maze2d-umaze-v1" - ], + "task.task_names" => ["maze2d-umaze-v1"], "agent_name" => "DAPG", "return_mean" => 239.34, - "return_std =>" => 37.597664821102924 + "return_std =>" => 37.597664821102924, ), Dict( "policy_path" => "pen/pen_dapg_0.pkl", - "task.task_names" => [ - "pen-cloned-v0", - "pen-expert-v0", - "pen-human-v0" - ], + "task.task_names" => ["pen-cloned-v0", "pen-expert-v0", "pen-human-v0"], "agent_name" => "DAPG", "return_mean" => 1984.096763504694, - "return_std =>" => 1929.6110474391166 + "return_std =>" => 1929.6110474391166, ), Dict( 
"policy_path" => "pen/pen_dapg_10.pkl", - "task.task_names" => [ - "pen-cloned-v0", - "pen-expert-v0", - "pen-human-v0" - ], + "task.task_names" => ["pen-cloned-v0", "pen-expert-v0", "pen-human-v0"], "agent_name" => "DAPG", "return_mean" => 3808.794849593491, - "return_std =>" => 1932.9965631785215 + "return_std =>" => 1932.9965631785215, ), Dict( "policy_path" => "pen/pen_dapg_1.pkl", - "task.task_names" => [ - "pen-cloned-v0", - "pen-expert-v0", - "pen-human-v0" - ], + "task.task_names" => ["pen-cloned-v0", "pen-expert-v0", "pen-human-v0"], "agent_name" => "DAPG", "return_mean" => 2480.1224231814135, - "return_std =>" => 2125.5773427152635 + "return_std =>" => 2125.5773427152635, ), Dict( "policy_path" => "pen/pen_dapg_2.pkl", - "task.task_names" => [ - "pen-cloned-v0", - "pen-expert-v0", - "pen-human-v0" - ], + "task.task_names" => ["pen-cloned-v0", "pen-expert-v0", "pen-human-v0"], "agent_name" => "DAPG", "return_mean" => 2494.1335875747145, - "return_std =>" => 2118.0014860996175 + "return_std =>" => 2118.0014860996175, ), Dict( "policy_path" => "pen/pen_dapg_3.pkl", - "task.task_names" => [ - "pen-cloned-v0", - "pen-expert-v0", - "pen-human-v0" - ], + "task.task_names" => ["pen-cloned-v0", "pen-expert-v0", "pen-human-v0"], "agent_name" => "DAPG", "return_mean" => 2802.87073294418, - "return_std =>" => 2120.3981104287323 + "return_std =>" => 2120.3981104287323, ), Dict( "policy_path" => "pen/pen_dapg_4.pkl", - "task.task_names" => [ - "pen-cloned-v0", - "pen-expert-v0", - "pen-human-v0" - ], + "task.task_names" => ["pen-cloned-v0", "pen-expert-v0", "pen-human-v0"], "agent_name" => "DAPG", "return_mean" => 3136.18545171068, - "return_std =>" => 2112.923714191993 + "return_std =>" => 2112.923714191993, ), Dict( "policy_path" => "pen/pen_dapg_5.pkl", - "task.task_names" => [ - "pen-cloned-v0", - "pen-expert-v0", - "pen-human-v0" - ], + "task.task_names" => ["pen-cloned-v0", "pen-expert-v0", "pen-human-v0"], "agent_name" => "DAPG", "return_mean" => 3110.619191864754, - "return_std =>" => 2012.2585161410343 + "return_std =>" => 2012.2585161410343, ), Dict( "policy_path" => "pen/pen_dapg_6.pkl", - "task.task_names" => [ - "pen-cloned-v0", - "pen-expert-v0", - "pen-human-v0" - ], + "task.task_names" => ["pen-cloned-v0", "pen-expert-v0", "pen-human-v0"], "agent_name" => "DAPG", "return_mean" => 3410.4384362331157, - "return_std =>" => 2029.187357465904 + "return_std =>" => 2029.187357465904, ), Dict( "policy_path" => "pen/pen_dapg_7.pkl", - "task.task_names" => [ - "pen-cloned-v0", - "pen-expert-v0", - "pen-human-v0" - ], + "task.task_names" => ["pen-cloned-v0", "pen-expert-v0", "pen-human-v0"], "agent_name" => "DAPG", "return_mean" => 3489.353704450997, - "return_std =>" => 2035.2279026017748 + "return_std =>" => 2035.2279026017748, ), Dict( "policy_path" => "pen/pen_dapg_8.pkl", - "task.task_names" => [ - "pen-cloned-v0", - "pen-expert-v0", - "pen-human-v0" - ], + "task.task_names" => ["pen-cloned-v0", "pen-expert-v0", "pen-human-v0"], "agent_name" => "DAPG", "return_mean" => 3673.9622983303598, - "return_std =>" => 2052.8837762657795 + "return_std =>" => 2052.8837762657795, ), Dict( "policy_path" => "pen/pen_dapg_9.pkl", - "task.task_names" => [ - "pen-cloned-v0", - "pen-expert-v0", - "pen-human-v0" - ], + "task.task_names" => ["pen-cloned-v0", "pen-expert-v0", "pen-human-v0"], "agent_name" => "DAPG", "return_mean" => 3683.932983177092, - "return_std =>" => 2028.9543873822265 + "return_std =>" => 2028.9543873822265, ), Dict( "policy_path" => "relocate/relocate_dapg_0.pkl", - 
"task.task_names" => [ - "relocate-cloned-v0", - "relocate-expert-v0", - "relocate-human-v0" - ], + "task.task_names" => + ["relocate-cloned-v0", "relocate-expert-v0", "relocate-human-v0"], "agent_name" => "DAPG", "return_mean" => -4.4718813284277195, - "return_std =>" => 0.9021515021945451 + "return_std =>" => 0.9021515021945451, ), Dict( "policy_path" => "relocate/relocate_dapg_10.pkl", - "task.task_names" => [ - "relocate-cloned-v0", - "relocate-expert-v0", - "relocate-human-v0" - ], + "task.task_names" => + ["relocate-cloned-v0", "relocate-expert-v0", "relocate-human-v0"], "agent_name" => "DAPG", "return_mean" => 3481.7834354311035, - "return_std =>" => 813.1857720257618 + "return_std =>" => 813.1857720257618, ), Dict( "policy_path" => "relocate/relocate_dapg_1.pkl", - "task.task_names" => [ - "relocate-cloned-v0", - "relocate-expert-v0", - "relocate-human-v0" - ], + "task.task_names" => + ["relocate-cloned-v0", "relocate-expert-v0", "relocate-human-v0"], "agent_name" => "DAPG", "return_mean" => 5.070946470816939, - "return_std =>" => 31.708695854456067 + "return_std =>" => 31.708695854456067, ), Dict( "policy_path" => "relocate/relocate_dapg_2.pkl", - "task.task_names" => [ - "relocate-cloned-v0", - "relocate-expert-v0", - "relocate-human-v0" - ], + "task.task_names" => + ["relocate-cloned-v0", "relocate-expert-v0", "relocate-human-v0"], "agent_name" => "DAPG", "return_mean" => 54.976670129729555, - "return_std =>" => 140.09635704443158 + "return_std =>" => 140.09635704443158, ), Dict( "policy_path" => "relocate/relocate_dapg_3.pkl", - "task.task_names" => [ - "relocate-cloned-v0", - "relocate-expert-v0", - "relocate-human-v0" - ], + "task.task_names" => + ["relocate-cloned-v0", "relocate-expert-v0", "relocate-human-v0"], "agent_name" => "DAPG", "return_mean" => 54.11338525066304, - "return_std =>" => 146.87277676706216 + "return_std =>" => 146.87277676706216, ), Dict( "policy_path" => "relocate/relocate_dapg_4.pkl", - "task.task_names" => [ - "relocate-cloned-v0", - "relocate-expert-v0", - "relocate-human-v0" - ], + "task.task_names" => + ["relocate-cloned-v0", "relocate-expert-v0", "relocate-human-v0"], "agent_name" => "DAPG", "return_mean" => 97.16474411169358, - "return_std =>" => 164.81156449057102 + "return_std =>" => 164.81156449057102, ), Dict( "policy_path" => "relocate/relocate_dapg_5.pkl", - "task.task_names" => [ - "relocate-cloned-v0", - "relocate-expert-v0", - "relocate-human-v0" - ], + "task.task_names" => + ["relocate-cloned-v0", "relocate-expert-v0", "relocate-human-v0"], "agent_name" => "DAPG", "return_mean" => 366.3185681324701, - "return_std =>" => 581.577837554543 + "return_std =>" => 581.577837554543, ), Dict( "policy_path" => "relocate/relocate_dapg_6.pkl", - "task.task_names" => [ - "relocate-cloned-v0", - "relocate-expert-v0", - "relocate-human-v0" - ], + "task.task_names" => + ["relocate-cloned-v0", "relocate-expert-v0", "relocate-human-v0"], "agent_name" => "DAPG", "return_mean" => 1254.0676523894747, - "return_std =>" => 929.5248207929493 + "return_std =>" => 929.5248207929493, ), Dict( "policy_path" => "relocate/relocate_dapg_7.pkl", - "task.task_names" => [ - "relocate-cloned-v0", - "relocate-expert-v0", - "relocate-human-v0" - ], + "task.task_names" => + ["relocate-cloned-v0", "relocate-expert-v0", "relocate-human-v0"], "agent_name" => "DAPG", "return_mean" => 2700.2361856493385, - "return_std =>" => 1089.9871332809942 + "return_std =>" => 1089.9871332809942, ), Dict( "policy_path" => "relocate/relocate_dapg_8.pkl", - "task.task_names" => [ - 
"relocate-cloned-v0", - "relocate-expert-v0", - "relocate-human-v0" - ], + "task.task_names" => + ["relocate-cloned-v0", "relocate-expert-v0", "relocate-human-v0"], "agent_name" => "DAPG", "return_mean" => 2570.351217370911, - "return_std =>" => 1266.9305994339466 + "return_std =>" => 1266.9305994339466, ), Dict( "policy_path" => "relocate/relocate_dapg_9.pkl", - "task.task_names" => [ - "relocate-cloned-v0", - "relocate-expert-v0", - "relocate-human-v0" - ], + "task.task_names" => + ["relocate-cloned-v0", "relocate-expert-v0", "relocate-human-v0"], "agent_name" => "DAPG", "return_mean" => 3379.424369497742, - "return_std =>" => 948.6183219418235 + "return_std =>" => 948.6183219418235, ), Dict( "policy_path" => "walker/walker_online_0.pkl", @@ -1546,11 +1238,11 @@ const D4RL_POLICIES = [ "walker2d-random-v0", "walker2d-expert-v0", "walker2d-medium-replay-v0", - "walker2d-medium-expert-v0" + "walker2d-medium-expert-v0", ], "agent_name" => "SAC", "return_mean" => 17.57372020467802, - "return_std =>" => 51.686802739349666 + "return_std =>" => 51.686802739349666, ), Dict( "policy_path" => "walker/walker_online_10.pkl", @@ -1559,11 +1251,11 @@ const D4RL_POLICIES = [ "walker2d-random-v0", "walker2d-expert-v0", "walker2d-medium-replay-v0", - "walker2d-medium-expert-v0" + "walker2d-medium-expert-v0", ], "agent_name" => "SAC", "return_mean" => 4120.947079569632, - "return_std =>" => 468.1515654051671 + "return_std =>" => 468.1515654051671, ), Dict( "policy_path" => "walker/walker_online_1.pkl", @@ -1572,11 +1264,11 @@ const D4RL_POLICIES = [ "walker2d-random-v0", "walker2d-expert-v0", "walker2d-medium-replay-v0", - "walker2d-medium-expert-v0" + "walker2d-medium-expert-v0", ], "agent_name" => "SAC", "return_mean" => 193.84631742541606, - "return_std =>" => 185.16785303932383 + "return_std =>" => 185.16785303932383, ), Dict( "policy_path" => "walker/walker_online_2.pkl", @@ -1585,11 +1277,11 @@ const D4RL_POLICIES = [ "walker2d-random-v0", "walker2d-expert-v0", "walker2d-medium-replay-v0", - "walker2d-medium-expert-v0" + "walker2d-medium-expert-v0", ], "agent_name" => "SAC", "return_mean" => 942.6191179097373, - "return_std =>" => 532.9834162811841 + "return_std =>" => 532.9834162811841, ), Dict( "policy_path" => "walker/walker_online_3.pkl", @@ -1598,11 +1290,11 @@ const D4RL_POLICIES = [ "walker2d-random-v0", "walker2d-expert-v0", "walker2d-medium-replay-v0", - "walker2d-medium-expert-v0" + "walker2d-medium-expert-v0", ], "agent_name" => "SAC", "return_mean" => 2786.7497792224794, - "return_std =>" => 477.5450988462439 + "return_std =>" => 477.5450988462439, ), Dict( "policy_path" => "walker/walker_online_4.pkl", @@ -1611,11 +1303,11 @@ const D4RL_POLICIES = [ "walker2d-random-v0", "walker2d-expert-v0", "walker2d-medium-replay-v0", - "walker2d-medium-expert-v0" + "walker2d-medium-expert-v0", ], "agent_name" => "SAC", "return_mean" => 914.4680927038296, - "return_std =>" => 559.5155757967623 + "return_std =>" => 559.5155757967623, ), Dict( "policy_path" => "walker/walker_online_5.pkl", @@ -1624,11 +1316,11 @@ const D4RL_POLICIES = [ "walker2d-random-v0", "walker2d-expert-v0", "walker2d-medium-replay-v0", - "walker2d-medium-expert-v0" + "walker2d-medium-expert-v0", ], "agent_name" => "SAC", "return_mean" => 3481.491012709211, - "return_std =>" => 87.12729823320758 + "return_std =>" => 87.12729823320758, ), Dict( "policy_path" => "walker/walker_online_6.pkl", @@ -1637,11 +1329,11 @@ const D4RL_POLICIES = [ "walker2d-random-v0", "walker2d-expert-v0", "walker2d-medium-replay-v0", - 
"walker2d-medium-expert-v0" + "walker2d-medium-expert-v0", ], "agent_name" => "SAC", "return_mean" => 2720.2509272083826, - "return_std =>" => 746.9753406110725 + "return_std =>" => 746.9753406110725, ), Dict( "policy_path" => "walker/walker_online_7.pkl", @@ -1650,11 +1342,11 @@ const D4RL_POLICIES = [ "walker2d-random-v0", "walker2d-expert-v0", "walker2d-medium-replay-v0", - "walker2d-medium-expert-v0" + "walker2d-medium-expert-v0", ], "agent_name" => "SAC", "return_mean" => 3926.346852318098, - "return_std =>" => 365.4230491920236 + "return_std =>" => 365.4230491920236, ), Dict( "policy_path" => "walker/walker_online_8.pkl", @@ -1663,11 +1355,11 @@ const D4RL_POLICIES = [ "walker2d-random-v0", "walker2d-expert-v0", "walker2d-medium-replay-v0", - "walker2d-medium-expert-v0" + "walker2d-medium-expert-v0", ], "agent_name" => "SAC", "return_mean" => 3695.4887678612936, - "return_std =>" => 262.0350155576298 + "return_std =>" => 262.0350155576298, ), Dict( "policy_path" => "walker/walker_online_9.pkl", @@ -1676,10 +1368,10 @@ const D4RL_POLICIES = [ "walker2d-random-v0", "walker2d-expert-v0", "walker2d-medium-replay-v0", - "walker2d-medium-expert-v0" + "walker2d-medium-expert-v0", ], "agent_name" => "SAC", "return_mean" => 4122.358396232011, - "return_std =>" => 107.76279305206488 - ) -] \ No newline at end of file + "return_std =>" => 107.76279305206488, + ), +] diff --git a/src/ReinforcementLearningDatasets/src/deep_ope/d4rl/d4rl_policy.jl b/src/ReinforcementLearningDatasets/src/deep_ope/d4rl/d4rl_policy.jl index f47b55624..c50f83366 100644 --- a/src/ReinforcementLearningDatasets/src/deep_ope/d4rl/d4rl_policy.jl +++ b/src/ReinforcementLearningDatasets/src/deep_ope/d4rl/d4rl_policy.jl @@ -21,8 +21,8 @@ Flux.@functor D4RLGaussianNetwork function (model::D4RLGaussianNetwork)( state::AbstractArray; - rng::AbstractRNG=MersenneTwister(123), - noisy::Bool=true + rng::AbstractRNG = MersenneTwister(123), + noisy::Bool = true, ) x = model.pre(state) μ, logσ = model.μ(x), model.logσ(x) @@ -32,7 +32,7 @@ function (model::D4RLGaussianNetwork)( a = μ + exp.(logσ) end a, μ -end +end """ d4rl_policy(env, agent, epoch) @@ -45,11 +45,8 @@ Check [deep_ope](https://github.com/google-research/deep_ope) with preloaded wei - `agent::String`: can be `dapg` or `online`. - `epoch::Int`: can be in `0:10`. """ -function d4rl_policy( - env::String, - agent::String, - epoch::Int) - +function d4rl_policy(env::String, agent::String, epoch::Int) + folder_prefix = "deep-ope-d4rl" try @datadep_str "$(folder_prefix)-$(env)_$(agent)_$(epoch)" @@ -60,13 +57,13 @@ function d4rl_policy( end policy_folder = @datadep_str "$(folder_prefix)-$(env)_$(agent)_$(epoch)" policy_file = "$(policy_folder)/$(readdir(policy_folder)[1])" - + model_params = Pickle.npyload(policy_file) @pipe parse_network_params(model_params) |> build_model(_...) 
end function parse_network_params(model_params::Dict) - size_dict = Dict{String, Tuple}() + size_dict = Dict{String,Tuple}() nonlinearity = nothing output_transformation = nothing for param in model_params @@ -81,7 +78,7 @@ function parse_network_params(model_params::Dict) nonlinearity = tanh end else - if param_value == "tanh_gaussian" + if param_value == "tanh_gaussian" output_transformation = tanh else output_transformation = identity @@ -92,29 +89,31 @@ function parse_network_params(model_params::Dict) model_params, size_dict, nonlinearity, output_transformation end -function build_model(model_params::Dict, size_dict::Dict, nonlinearity::Function, output_transformation::Function) +function build_model( + model_params::Dict, + size_dict::Dict, + nonlinearity::Function, + output_transformation::Function, +) fc_0 = Dense(size_dict["fc0/weight"]..., nonlinearity) fc_0 = @set fc_0.weight = model_params["fc0/weight"] fc_0 = @set fc_0.bias = model_params["fc0/bias"] - + fc_1 = Dense(size_dict["fc1/weight"]..., nonlinearity) fc_1 = @set fc_1.weight = model_params["fc1/weight"] fc_1 = @set fc_1.bias = model_params["fc1/bias"] - + μ_fc = Dense(size_dict["last_fc/weight"]...) μ_fc = @set μ_fc.weight = model_params["last_fc/weight"] μ_fc = @set μ_fc.bias = model_params["last_fc/bias"] - + log_σ_fc = Dense(size_dict["last_fc_log_std/weight"]...) log_σ_fc = @set log_σ_fc.weight = model_params["last_fc_log_std/weight"] log_σ_fc = @set log_σ_fc.bias = model_params["last_fc_log_std/bias"] - - pre = Chain( - fc_0, - fc_1 - ) + + pre = Chain(fc_0, fc_1) μ = Chain(μ_fc) log_σ = Chain(log_σ_fc) - + D4RLGaussianNetwork(pre, μ, log_σ) -end \ No newline at end of file +end diff --git a/src/ReinforcementLearningDatasets/src/deep_ope/d4rl/evaluate.jl b/src/ReinforcementLearningDatasets/src/deep_ope/d4rl/evaluate.jl index 816736bc4..b3a742428 100644 --- a/src/ReinforcementLearningDatasets/src/deep_ope/d4rl/evaluate.jl +++ b/src/ReinforcementLearningDatasets/src/deep_ope/d4rl/evaluate.jl @@ -17,28 +17,30 @@ function deep_ope_d4rl_evaluate( env_name::String, agent::String, epoch::Int; - gym_env_name::Union{String, Nothing}=nothing, - rng::AbstractRNG=MersenneTwister(123), - num_evaluations::Int=10, - γ::Float64=1.0, - noisy::Bool=false, - env_seed::Union{Int, Nothing}=nothing -) + gym_env_name::Union{String,Nothing} = nothing, + rng::AbstractRNG = MersenneTwister(123), + num_evaluations::Int = 10, + γ::Float64 = 1.0, + noisy::Bool = false, + env_seed::Union{Int,Nothing} = nothing, +) policy_folder = "$(env_name)_$(agent)_$(epoch)" if gym_env_name === nothing for policy in D4RL_POLICIES policy_file = split(policy["policy_path"], "/")[end] - if chop(policy_file, head=0, tail=4) == policy_folder + if chop(policy_file, head = 0, tail = 4) == policy_folder gym_env_name = policy["task.task_names"][1] break end end - if gym_env_name === nothing error("invalid parameters") end + if gym_env_name === nothing + error("invalid parameters") + end end - env = GymEnv(gym_env_name; seed=env_seed) + env = GymEnv(gym_env_name; seed = env_seed) model = d4rl_policy(env_name, agent, epoch) scores = Vector{Float64}(undef, num_evaluations) @@ -48,14 +50,23 @@ function deep_ope_d4rl_evaluate( reset!(env) while !is_terminated(env) s = state(env) - a = model(s;rng=rng, noisy=noisy)[1] - s, a , env(a) + a = model(s; rng = rng, noisy = noisy)[1] + s, a, env(a) r = reward(env) t = is_terminated(env) - score += r*γ*(1-t) + score += r * γ * (1 - t) end scores[eval] = score end - plt = lineplot(1:length(scores), scores, title = "$(gym_env_name) 
scores", name = "scores", xlabel = "episode", canvas = DotCanvas, ylabel = "score", border=:ascii) + plt = lineplot( + 1:length(scores), + scores, + title = "$(gym_env_name) scores", + name = "scores", + xlabel = "episode", + canvas = DotCanvas, + ylabel = "score", + border = :ascii, + ) plt -end \ No newline at end of file +end diff --git a/src/ReinforcementLearningDatasets/src/deep_ope/d4rl/register.jl b/src/ReinforcementLearningDatasets/src/deep_ope/d4rl/register.jl index 0c0081dc4..d1be2502f 100644 --- a/src/ReinforcementLearningDatasets/src/deep_ope/d4rl/register.jl +++ b/src/ReinforcementLearningDatasets/src/deep_ope/d4rl/register.jl @@ -1,11 +1,11 @@ gcs_prefix = "gs://gresearch/deep-ope/d4rl" -folder_prefix = "deep-ope-d4rl" +folder_prefix = "deep-ope-d4rl" policies = D4RL_POLICIES function deep_ope_d4rl_init() for policy in policies gcs_policy_folder = policy["policy_path"] - local_policy_folder = chop(split(gcs_policy_folder, "/")[end], head=0, tail=4) + local_policy_folder = chop(split(gcs_policy_folder, "/")[end], head = 0, tail = 4) register( DataDep( "$(folder_prefix)-$(local_policy_folder)", @@ -16,7 +16,7 @@ function deep_ope_d4rl_init() Authors: Justin Fu, Mohammad Norouzi, Ofir Nachum, George Tucker, ziyu wang, Alexander Novikov, Mengjiao Yang, Michael R Zhang, Yutian Chen, Aviral Kumar, Cosmin Paduraru, Sergey Levine, Thomas Paine Year: 2021 - + Deep OPE contains: Policies for the tasks in the D4RL, DeepMind Locomotion and Control Suite datasets. Policies trained with the following algorithms (D4PG, ABM, CRR, SAC, DAPG and BC) and snapshots along the training trajectory. This facilitates @@ -29,8 +29,8 @@ function deep_ope_d4rl_init() what datasets are available, please refer to D4RL: Datasets for Deep Data-Driven Reinforcement Learning. """, "$(gcs_prefix)/$(gcs_policy_folder)"; - fetch_method=fetch_gc_file - ) + fetch_method = fetch_gc_file, + ), ) end -end \ No newline at end of file +end diff --git a/src/ReinforcementLearningDatasets/src/init.jl b/src/ReinforcementLearningDatasets/src/init.jl index 90bf6f0bd..fd1ede41d 100644 --- a/src/ReinforcementLearningDatasets/src/init.jl +++ b/src/ReinforcementLearningDatasets/src/init.jl @@ -6,4 +6,4 @@ function __init__() RLDatasets.bsuite_init() RLDatasets.dm_init() RLDatasets.deep_ope_d4rl_init() -end \ No newline at end of file +end diff --git a/src/ReinforcementLearningDatasets/src/rl_unplugged/atari/register.jl b/src/ReinforcementLearningDatasets/src/rl_unplugged/atari/register.jl index 287e6dbac..ec022d63e 100644 --- a/src/ReinforcementLearningDatasets/src/rl_unplugged/atari/register.jl +++ b/src/ReinforcementLearningDatasets/src/rl_unplugged/atari/register.jl @@ -57,13 +57,13 @@ const TESTING_SUITE = [ ] # Total of 45 games. -const ALL = cat(TUNING_SUITE, TESTING_SUITE, dims=1) +const ALL = cat(TUNING_SUITE, TESTING_SUITE, dims = 1) function rl_unplugged_atari_params() game = ALL run = 1:5 shards = 0:99 - + @info game run shards end @@ -98,11 +98,12 @@ function rl_unplugged_atari_init() on Atari if you are interested in comparing your approach to other state of the art offline RL methods with discrete actions. 
""", - "gs://rl_unplugged/atari/$game/"*@sprintf("run_%i-%05i-of-%05i", run, index, num_shards); - fetch_method = fetch_gc_file - ) + "gs://rl_unplugged/atari/$game/" * + @sprintf("run_%i-%05i-of-%05i", run, index, num_shards); + fetch_method = fetch_gc_file, + ), ) end end end -end \ No newline at end of file +end diff --git a/src/ReinforcementLearningDatasets/src/rl_unplugged/atari/rl_unplugged_atari.jl b/src/ReinforcementLearningDatasets/src/rl_unplugged/atari/rl_unplugged_atari.jl index 753265e08..08cade692 100644 --- a/src/ReinforcementLearningDatasets/src/rl_unplugged/atari/rl_unplugged_atari.jl +++ b/src/ReinforcementLearningDatasets/src/rl_unplugged/atari/rl_unplugged_atari.jl @@ -1,7 +1,7 @@ export rl_unplugged_atari_dataset using Base.Threads -using Printf:@sprintf +using Printf: @sprintf using Base.Iterators using TFRecord using ImageCore @@ -14,14 +14,14 @@ using PNGFiles Represent an AtariRLTransition and can also represent a batch. """ struct AtariRLTransition <: RLTransition - state - action - reward - terminal - next_state - next_action - episode_id - episode_return + state::Any + action::Any + reward::Any + terminal::Any + next_state::Any + next_action::Any + episode_id::Any + episode_return::Any end function decode_frame(bytes) @@ -29,7 +29,7 @@ function decode_frame(bytes) end function decode_state(bytes) - PermutedDimsArray(StackedView((decode_frame(x) for x in bytes)...), (2,3,1)) + PermutedDimsArray(StackedView((decode_frame(x) for x in bytes)...), (2, 3, 1)) end function AtariRLTransition(example::TFRecord.Example) @@ -70,65 +70,62 @@ function rl_unplugged_atari_dataset( game::String, run::Int, shards::Vector{Int}; - shuffle_buffer_size=10_000, - tf_reader_bufsize=1*1024*1024, - tf_reader_sz=10_000, - batch_size=256, - n_preallocations=nthreads()*12 + shuffle_buffer_size = 10_000, + tf_reader_bufsize = 1 * 1024 * 1024, + tf_reader_sz = 10_000, + batch_size = 256, + n_preallocations = nthreads() * 12, ) n = nthreads() @info "Loading the shards $shards in $run run of $game with $n threads" folders = [ - @datadep_str "rl-unplugged-atari-$(titlecase(game))-$run-$shard" - for shard in shards + @datadep_str "rl-unplugged-atari-$(titlecase(game))-$run-$shard" for + shard in shards ] - + ch_files = Channel{String}(length(folders)) do ch for folder in cycle(folders) file = folder * "/$(readdir(folder)[1])" put!(ch, file) end end - + shuffled_files = buffered_shuffle(ch_files, length(folders)) - + ch_src = Channel{AtariRLTransition}(n * tf_reader_sz) do ch for fs in partition(shuffled_files, n) Threads.foreach( TFRecord.read( fs; - compression=:gzip, - bufsize=tf_reader_bufsize, - channel_size=tf_reader_sz, + compression = :gzip, + bufsize = tf_reader_bufsize, + channel_size = tf_reader_sz, ); - schedule=Threads.StaticSchedule() + schedule = Threads.StaticSchedule(), ) do x put!(ch, AtariRLTransition(x)) end end end - - transitions = buffered_shuffle( - ch_src, - shuffle_buffer_size - ) - + + transitions = buffered_shuffle(ch_src, shuffle_buffer_size) + buffer = AtariRLTransition( - Array{UInt8, 4}(undef, 84, 84, 4, batch_size), - Array{Int, 1}(undef, batch_size), - Array{Float32, 1}(undef, batch_size), - Array{Bool, 1}(undef, batch_size), - Array{UInt8, 4}(undef, 84, 84, 4, batch_size), - Array{Int, 1}(undef, batch_size), - Array{Int, 1}(undef, batch_size), - Array{Float32, 1}(undef, batch_size), + Array{UInt8,4}(undef, 84, 84, 4, batch_size), + Array{Int,1}(undef, batch_size), + Array{Float32,1}(undef, batch_size), + Array{Bool,1}(undef, batch_size), + Array{UInt8,4}(undef, 
84, 84, 4, batch_size), + Array{Int,1}(undef, batch_size), + Array{Int,1}(undef, batch_size), + Array{Float32,1}(undef, batch_size), ) taskref = Ref{Task}() - res = RingBuffer(buffer;taskref=taskref, sz=n_preallocations) do buff + res = RingBuffer(buffer; taskref = taskref, sz = n_preallocations) do buff Threads.@threads for i in 1:batch_size batch!(buff, popfirst!(transitions), i) end @@ -137,4 +134,4 @@ function rl_unplugged_atari_dataset( bind(ch_src, taskref[]) bind(ch_files, taskref[]) res -end \ No newline at end of file +end diff --git a/src/ReinforcementLearningDatasets/src/rl_unplugged/bsuite/bsuite.jl b/src/ReinforcementLearningDatasets/src/rl_unplugged/bsuite/bsuite.jl index af8ff4af6..713342fa9 100644 --- a/src/ReinforcementLearningDatasets/src/rl_unplugged/bsuite/bsuite.jl +++ b/src/ReinforcementLearningDatasets/src/rl_unplugged/bsuite/bsuite.jl @@ -3,11 +3,11 @@ export rl_unplugged_bsuite_dataset using TFRecord struct BSuiteRLTransition <: RLTransition - state - action - reward - terminal - next_state + state::Any + action::Any + reward::Any + terminal::Any + next_state::Any end function BSuiteRLTransition(example::TFRecord.Example, game::String) @@ -55,46 +55,43 @@ function rl_unplugged_bsuite_dataset( game::String, shards::Vector{Int}, type::String; - is_shuffle::Bool=true, - stochasticity::Float64=0.0, - shuffle_buffer_size::Int=10_000, - tf_reader_bufsize::Int=10_000, - tf_reader_sz::Int=10_000, - batch_size::Int=256, - n_preallocations::Int=nthreads()*12 -) + is_shuffle::Bool = true, + stochasticity::Float64 = 0.0, + shuffle_buffer_size::Int = 10_000, + tf_reader_bufsize::Int = 10_000, + tf_reader_sz::Int = 10_000, + batch_size::Int = 256, + n_preallocations::Int = nthreads() * 12, +) n = nthreads() repo = "rl-unplugged-bsuite" - - folders= [ - @datadep_str "$repo-$game-$stochasticity-$shard-$type" - for shard in shards - ] - + + folders = [@datadep_str "$repo-$game-$stochasticity-$shard-$type" for shard in shards] + ch_files = Channel{String}(length(folders)) do ch for folder in cycle(folders) file = folder * "/$(readdir(folder)[1])" put!(ch, file) end end - + if is_shuffle files = buffered_shuffle(ch_files, length(folders)) else files = ch_files end - + ch_src = Channel{BSuiteRLTransition}(n * tf_reader_sz) do ch for fs in partition(files, n) Threads.foreach( TFRecord.read( fs; - compression=:gzip, - bufsize=tf_reader_bufsize, - channel_size=tf_reader_sz, + compression = :gzip, + bufsize = tf_reader_bufsize, + channel_size = tf_reader_sz, ); - schedule=Threads.StaticSchedule() + schedule = Threads.StaticSchedule(), ) do x put!(ch, BSuiteRLTransition(x, game)) end @@ -102,33 +99,30 @@ function rl_unplugged_bsuite_dataset( end if is_shuffle - transitions = buffered_shuffle( - ch_src, - shuffle_buffer_size - ) + transitions = buffered_shuffle(ch_src, shuffle_buffer_size) else transitions = ch_src end - + taskref = Ref{Task}() - ob_size = game=="mountain_car" ? 3 : 6 + ob_size = game == "mountain_car" ? 
3 : 6 if game == "catch" - obs_template = Array{Float32, 3}(undef, 10, 5, batch_size) + obs_template = Array{Float32,3}(undef, 10, 5, batch_size) else - obs_template = Array{Float32, 2}(undef, ob_size, batch_size) + obs_template = Array{Float32,2}(undef, ob_size, batch_size) end buffer = BSuiteRLTransition( copy(obs_template), - Array{Int, 1}(undef, batch_size), - Array{Float32, 1}(undef, batch_size), - Array{Bool, 1}(undef, batch_size), + Array{Int,1}(undef, batch_size), + Array{Float32,1}(undef, batch_size), + Array{Bool,1}(undef, batch_size), copy(obs_template), ) - res = RingBuffer(buffer;taskref=taskref, sz=n_preallocations) do buff + res = RingBuffer(buffer; taskref = taskref, sz = n_preallocations) do buff Threads.@threads for i in 1:batch_size batch!(buff, take!(transitions), i) end @@ -137,4 +131,4 @@ function rl_unplugged_bsuite_dataset( bind(ch_src, taskref[]) bind(ch_files, taskref[]) res -end \ No newline at end of file +end diff --git a/src/ReinforcementLearningDatasets/src/rl_unplugged/bsuite/register.jl b/src/ReinforcementLearningDatasets/src/rl_unplugged/bsuite/register.jl index 487a6f72b..488619f77 100644 --- a/src/ReinforcementLearningDatasets/src/rl_unplugged/bsuite/register.jl +++ b/src/ReinforcementLearningDatasets/src/rl_unplugged/bsuite/register.jl @@ -1,17 +1,9 @@ repo = "bsuite" export bsuite_params -const BSUITE_DATASETS = [ - "cartpole", - "catch", - "mountain_car" -] +const BSUITE_DATASETS = ["cartpole", "catch", "mountain_car"] -types = [ - "full", - "full_train", - "full_valid" -] +types = ["full", "full_train", "full_valid"] function bsuite_params() game = BSUITE_DATASETS @@ -47,11 +39,11 @@ function bsuite_init() where the stochasticity of the environment is easy to control. """, "gs://rl_unplugged/$repo/$env/0_$stochasticity/$(index)_$type-00000-of-00001", - fetch_method = fetch_gc_file - ) + fetch_method = fetch_gc_file, + ), ) end end end end -end \ No newline at end of file +end diff --git a/src/ReinforcementLearningDatasets/src/rl_unplugged/dm/register.jl b/src/ReinforcementLearningDatasets/src/rl_unplugged/dm/register.jl index 364555e79..b46fffd43 100644 --- a/src/ReinforcementLearningDatasets/src/rl_unplugged/dm/register.jl +++ b/src/ReinforcementLearningDatasets/src/rl_unplugged/dm/register.jl @@ -12,170 +12,166 @@ function dm_params() @info game shards end -const DM_LOCOMOTION_RODENT = Dict{String, String}( +const DM_LOCOMOTION_RODENT = Dict{String,String}( "rodent_gaps" => "dm_locomotion/rodent_gaps/seq2", "rodent_escape" => "dm_locomotion/rodent_bowl_escape/seq2", "rodent_two_touch" => "dm_locomotion/rodent_two_touch/seq40", - "rodent_mazes" => "dm_locomotion/rodent_mazes/seq40" + "rodent_mazes" => "dm_locomotion/rodent_mazes/seq40", ) -const DM_LOCOMOTION_RODENT_SIZE = Dict{String, Tuple}( +const DM_LOCOMOTION_RODENT_SIZE = Dict{String,Tuple}( "observation/walker/actuator_activation" => (38,), # "observation/walker/sensors_torque" => (), # "observation/walker/sensors_force" => (), "observation/walker/body_height" => (1,), "observation/walker/end_effectors_pos" => (12,), - "observation/walker/joints_pos"=> (30,), - "observation/walker/joints_vel"=> (30,), - "observation/walker/tendons_pos"=> (8,), - "observation/walker/tendons_vel"=> (8,), - "observation/walker/appendages_pos"=> (15,), - "observation/walker/world_zaxis"=> (3,), - "observation/walker/sensors_accelerometer"=> (3,), - "observation/walker/sensors_velocimeter"=> (3,), - "observation/walker/sensors_gyro" => (3,), - "observation/walker/sensors_touch"=> (4,), - 
"observation/walker/egocentric_camera"=> (64, 64, 3), - "action"=> (38,), - "discount"=> (), - "reward"=> (), - "step_type"=> () + "observation/walker/joints_pos" => (30,), + "observation/walker/joints_vel" => (30,), + "observation/walker/tendons_pos" => (8,), + "observation/walker/tendons_vel" => (8,), + "observation/walker/appendages_pos" => (15,), + "observation/walker/world_zaxis" => (3,), + "observation/walker/sensors_accelerometer" => (3,), + "observation/walker/sensors_velocimeter" => (3,), + "observation/walker/sensors_gyro" => (3,), + "observation/walker/sensors_touch" => (4,), + "observation/walker/egocentric_camera" => (64, 64, 3), + "action" => (38,), + "discount" => (), + "reward" => (), + "step_type" => (), ) -const DM_LOCOMOTION_HUMANOID = Dict{String, String}( +const DM_LOCOMOTION_HUMANOID = Dict{String,String}( "humanoid_corridor" => "dm_locomotion/humanoid_corridor/seq2", "humanoid_gaps" => "dm_locomotion/humanoid_gaps/seq2", - "humanoid_walls" => "dm_locomotion/humanoid_walls/seq40" + "humanoid_walls" => "dm_locomotion/humanoid_walls/seq40", ) -const DM_LOCOMOTION_HUMANOID_SIZE = Dict{String, Tuple}( +const DM_LOCOMOTION_HUMANOID_SIZE = Dict{String,Tuple}( # "observation/walker/actuator_activation" => (0,), "observation/walker/sensors_torque" => (6,), # "observation/walker/sensors_force" => (), - "observation/walker/joints_vel"=> (56,), - "observation/walker/sensors_velocimeter"=> (3,), - "observation/walker/sensors_gyro"=> (3,), - "observation/walker/joints_pos"=> (56,), + "observation/walker/joints_vel" => (56,), + "observation/walker/sensors_velocimeter" => (3,), + "observation/walker/sensors_gyro" => (3,), + "observation/walker/joints_pos" => (56,), "observation/walker/appendages_pos" => (15,), - "observation/walker/world_zaxis"=> (3,), - "observation/walker/body_height"=> (1,), - "observation/walker/sensors_accelerometer"=> (3,), - "observation/walker/end_effectors_pos"=> (12,), - "observation/walker/egocentric_camera"=> ( - 64, - 64, - 3, - ), - "action"=> (56,), - "discount"=> (), - "reward"=> (), + "observation/walker/world_zaxis" => (3,), + "observation/walker/body_height" => (1,), + "observation/walker/sensors_accelerometer" => (3,), + "observation/walker/end_effectors_pos" => (12,), + "observation/walker/egocentric_camera" => (64, 64, 3), + "action" => (56,), + "discount" => (), + "reward" => (), # "episodic_reward"=> (), - "step_type"=> () + "step_type" => (), ) -const DM_CONTROL_SUITE_SIZE = Dict{String, Dict{String, Tuple}}( - "cartpole_swingup" => Dict{String, Tuple}( - "observation/position"=> (3,), - "observation/velocity"=> (2,), - "action"=> (1,), - "discount"=> (), - "reward"=> (), - "episodic_reward"=> (), - "step_type"=> () +const DM_CONTROL_SUITE_SIZE = Dict{String,Dict{String,Tuple}}( + "cartpole_swingup" => Dict{String,Tuple}( + "observation/position" => (3,), + "observation/velocity" => (2,), + "action" => (1,), + "discount" => (), + "reward" => (), + "episodic_reward" => (), + "step_type" => (), ), - "cheetah_run" => Dict{String, Tuple}( - "observation/position"=> (8,), - "observation/velocity"=> (9,), - "action"=> (6,), - "discount"=> (), - "reward"=> (), - "episodic_reward"=> (), - "step_type"=> () + "cheetah_run" => Dict{String,Tuple}( + "observation/position" => (8,), + "observation/velocity" => (9,), + "action" => (6,), + "discount" => (), + "reward" => (), + "episodic_reward" => (), + "step_type" => (), ), - "finger_turn_hard" => Dict{String, Tuple}( - "observation/position"=> (4,), - "observation/velocity"=> (3,), - "observation/touch"=> 
(2,), - "observation/target_position"=> (2,), - "observation/dist_to_target"=> (1,), - "action"=> (2,), - "discount"=> (), - "reward"=> (), - "episodic_reward"=> (), - "step_type"=> () + "finger_turn_hard" => Dict{String,Tuple}( + "observation/position" => (4,), + "observation/velocity" => (3,), + "observation/touch" => (2,), + "observation/target_position" => (2,), + "observation/dist_to_target" => (1,), + "action" => (2,), + "discount" => (), + "reward" => (), + "episodic_reward" => (), + "step_type" => (), ), - "fish_swim" => Dict{String, Tuple}( - "observation/target"=> (3,), - "observation/velocity"=> (13,), - "observation/upright"=> (1,), - "observation/joint_angles"=> (7,), - "action"=> (5,), - "discount"=> (), - "reward"=> (), - "episodic_reward"=> (), - "step_type"=> () + "fish_swim" => Dict{String,Tuple}( + "observation/target" => (3,), + "observation/velocity" => (13,), + "observation/upright" => (1,), + "observation/joint_angles" => (7,), + "action" => (5,), + "discount" => (), + "reward" => (), + "episodic_reward" => (), + "step_type" => (), ), - "humanoid_run" => Dict{String, Tuple}( - "observation/velocity"=> (27,), - "observation/com_velocity"=> (3,), - "observation/torso_vertical"=> (3,), - "observation/extremities"=> (12,), - "observation/head_height"=> (1,), - "observation/joint_angles"=> (21,), - "action"=> (21,), - "discount"=> (), - "reward"=> (), - "episodic_reward"=> (), - "step_type"=> () + "humanoid_run" => Dict{String,Tuple}( + "observation/velocity" => (27,), + "observation/com_velocity" => (3,), + "observation/torso_vertical" => (3,), + "observation/extremities" => (12,), + "observation/head_height" => (1,), + "observation/joint_angles" => (21,), + "action" => (21,), + "discount" => (), + "reward" => (), + "episodic_reward" => (), + "step_type" => (), ), - "manipulator_insert_ball" => Dict{String, Tuple}( - "observation/arm_pos"=> (16,), - "observation/arm_vel"=> (8,), - "observation/touch"=> (5,), - "observation/hand_pos"=> (4,), - "observation/object_pos"=> (4,), - "observation/object_vel"=> (3,), - "observation/target_pos"=> (4,), - "action"=> (5,), - "discount"=> (), - "reward"=> (), - "episodic_reward"=> (), - "step_type"=> () + "manipulator_insert_ball" => Dict{String,Tuple}( + "observation/arm_pos" => (16,), + "observation/arm_vel" => (8,), + "observation/touch" => (5,), + "observation/hand_pos" => (4,), + "observation/object_pos" => (4,), + "observation/object_vel" => (3,), + "observation/target_pos" => (4,), + "action" => (5,), + "discount" => (), + "reward" => (), + "episodic_reward" => (), + "step_type" => (), ), - "manipulator_insert_peg" => Dict{String, Tuple}( - "observation/arm_pos"=> (16,), - "observation/arm_vel"=> (8,), - "observation/touch"=> (5,), - "observation/hand_pos"=> (4,), - "observation/object_pos"=> (4,), - "observation/object_vel"=> (3,), - "observation/target_pos"=> (4,), - "episodic_reward"=> (), - "action"=> (5,), - "discount"=> (), - "reward"=> (), - "step_type"=> () + "manipulator_insert_peg" => Dict{String,Tuple}( + "observation/arm_pos" => (16,), + "observation/arm_vel" => (8,), + "observation/touch" => (5,), + "observation/hand_pos" => (4,), + "observation/object_pos" => (4,), + "observation/object_vel" => (3,), + "observation/target_pos" => (4,), + "episodic_reward" => (), + "action" => (5,), + "discount" => (), + "reward" => (), + "step_type" => (), ), - "walker_stand" => Dict{String, Tuple}( - "observation/orientations"=> (14,), - "observation/velocity"=> (9,), - "observation/height"=> (1,), - "action"=> (6,), - 
"discount"=> (), - "reward"=> (), - "episodic_reward"=> (), - "step_type"=> () + "walker_stand" => Dict{String,Tuple}( + "observation/orientations" => (14,), + "observation/velocity" => (9,), + "observation/height" => (1,), + "action" => (6,), + "discount" => (), + "reward" => (), + "episodic_reward" => (), + "step_type" => (), + ), + "walker_walk" => Dict{String,Tuple}( + "observation/orientations" => (14,), + "observation/velocity" => (9,), + "observation/height" => (1,), + "action" => (6,), + "discount" => (), + "reward" => (), + "episodic_reward" => (), + "step_type" => (), ), - "walker_walk" => Dict{String, Tuple}( - "observation/orientations"=> (14,), - "observation/velocity"=> (9,), - "observation/height"=> (1,), - "action"=> (6,), - "discount"=> (), - "reward"=> (), - "episodic_reward"=> (), - "step_type"=> () - ) ) const DM_LOCOMOTION = merge(DM_LOCOMOTION_HUMANOID, DM_LOCOMOTION_RODENT) @@ -208,9 +204,10 @@ function dm_init() please refer to the paper. DeepMind Control Suite is a traditional continuous action RL benchmark. In particular, it is recommended that you test your approach in DeepMind Control Suite if you are interested in comparing against other state of the art offline RL methods. """, - "gs://rl_unplugged/dm_control_suite/$task/"*@sprintf("train-%05i-of-%05i", index, num_shards); + "gs://rl_unplugged/dm_control_suite/$task/" * + @sprintf("train-%05i-of-%05i", index, num_shards); fetch_method = fetch_gc_file, - ) + ), ) end end @@ -240,10 +237,11 @@ function dm_init() It is recommended that you to try offline RL methods on DeepMind Locomotion dataset, if you are interested in very challenging offline RL dataset with continuous action space. """, - "gs://rl_unplugged/$(DM_LOCOMOTION[task])/"*@sprintf("train-%05i-of-%05i", index, num_shards); - fetch_method = fetch_gc_file - ) + "gs://rl_unplugged/$(DM_LOCOMOTION[task])/" * + @sprintf("train-%05i-of-%05i", index, num_shards); + fetch_method = fetch_gc_file, + ), ) end end -end \ No newline at end of file +end diff --git a/src/ReinforcementLearningDatasets/src/rl_unplugged/dm/rl_unplugged_dm.jl b/src/ReinforcementLearningDatasets/src/rl_unplugged/dm/rl_unplugged_dm.jl index df3ebad5d..115ba1bde 100644 --- a/src/ReinforcementLearningDatasets/src/rl_unplugged/dm/rl_unplugged_dm.jl +++ b/src/ReinforcementLearningDatasets/src/rl_unplugged/dm/rl_unplugged_dm.jl @@ -3,31 +3,41 @@ export rl_unplugged_dm_dataset using TFRecord function make_batch_array(type::Type, feature_dims::Int, size::Tuple, batch_size::Int) - Array{type, feature_dims+1}(undef, size..., batch_size) + Array{type,feature_dims + 1}(undef, size..., batch_size) end -function dm_buffer_dict(feature_size::Dict{String, Tuple}, batch_size::Int) - obs_buffer = Dict{Symbol, AbstractArray}() +function dm_buffer_dict(feature_size::Dict{String,Tuple}, batch_size::Int) + obs_buffer = Dict{Symbol,AbstractArray}() - buffer_dict = Dict{Symbol, Any}() + buffer_dict = Dict{Symbol,Any}() for feature in keys(feature_size) feature_dims = length(feature_size[feature]) if split(feature, "/")[1] == "observation" - ob_key = Symbol(chop(feature, head=length("observation")+1, tail=0)) + ob_key = Symbol(chop(feature, head = length("observation") + 1, tail = 0)) if split(feature, "/")[end] == "egocentric_camera" - obs_buffer[ob_key] = make_batch_array(UInt8, feature_dims, feature_size[feature], batch_size) + obs_buffer[ob_key] = + make_batch_array(UInt8, feature_dims, feature_size[feature], batch_size) else - obs_buffer[ob_key] = make_batch_array(Float32, feature_dims, 
feature_size[feature], batch_size) + obs_buffer[ob_key] = make_batch_array( + Float32, + feature_dims, + feature_size[feature], + batch_size, + ) end elseif feature == "action" - buffer_dict[:action] = make_batch_array(Float32, feature_dims, feature_size[feature], batch_size) - buffer_dict[:next_action] = make_batch_array(Float32, feature_dims, feature_size[feature], batch_size) + buffer_dict[:action] = + make_batch_array(Float32, feature_dims, feature_size[feature], batch_size) + buffer_dict[:next_action] = + make_batch_array(Float32, feature_dims, feature_size[feature], batch_size) elseif feature == "step_type" - buffer_dict[:terminal] = make_batch_array(Bool, feature_dims, feature_size[feature], batch_size) + buffer_dict[:terminal] = + make_batch_array(Bool, feature_dims, feature_size[feature], batch_size) else ob_key = Symbol(feature) - buffer_dict[ob_key] = make_batch_array(Float32, feature_dims, feature_size[feature], batch_size) + buffer_dict[ob_key] = + make_batch_array(Float32, feature_dims, feature_size[feature], batch_size) end end @@ -61,28 +71,33 @@ function batch_named_tuple!(dest::NamedTuple, src::NamedTuple, i::Int) end end -function make_transition(example::TFRecord.Example, feature_size::Dict{String, Tuple}) +function make_transition(example::TFRecord.Example, feature_size::Dict{String,Tuple}) f = example.features.feature - - observation_dict = Dict{Symbol, AbstractArray}() - next_observation_dict = Dict{Symbol, AbstractArray}() - transition_dict = Dict{Symbol, Any}() + + observation_dict = Dict{Symbol,AbstractArray}() + next_observation_dict = Dict{Symbol,AbstractArray}() + transition_dict = Dict{Symbol,Any}() for feature in keys(feature_size) if split(feature, "/")[1] == "observation" - ob_key = Symbol(chop(feature, head = length("observation")+1, tail=0)) + ob_key = Symbol(chop(feature, head = length("observation") + 1, tail = 0)) if split(feature, "/")[end] == "egocentric_camera" cam_feature_size = feature_size[feature] ob_size = prod(cam_feature_size) - observation_dict[ob_key] = reshape(f[feature].bytes_list.value[1][1:ob_size], cam_feature_size...) - next_observation_dict[ob_key] = reshape(f[feature].bytes_list.value[1][ob_size+1:end], cam_feature_size...) + observation_dict[ob_key] = + reshape(f[feature].bytes_list.value[1][1:ob_size], cam_feature_size...) 
+ next_observation_dict[ob_key] = reshape( + f[feature].bytes_list.value[1][ob_size+1:end], + cam_feature_size..., + ) else if feature_size[feature] == () observation_dict[ob_key] = f[feature].float_list.value else ob_size = feature_size[feature][1] observation_dict[ob_key] = f[feature].float_list.value[1:ob_size] - next_observation_dict[ob_key] = f[feature].float_list.value[ob_size+1:end] + next_observation_dict[ob_key] = + f[feature].float_list.value[ob_size+1:end] end end elseif feature == "action" @@ -138,28 +153,25 @@ function rl_unplugged_dm_dataset( shards; type = "dm_control_suite", is_shuffle = true, - shuffle_buffer_size=10_000, - tf_reader_bufsize=10_000, - tf_reader_sz=10_000, - batch_size=256, - n_preallocations=nthreads()*12 -) + shuffle_buffer_size = 10_000, + tf_reader_bufsize = 10_000, + tf_reader_sz = 10_000, + batch_size = 256, + n_preallocations = nthreads() * 12, +) n = nthreads() repo = "rl-unplugged-dm" - - folders= [ - @datadep_str "$repo-$game-$shard" - for shard in shards - ] - + + folders = [@datadep_str "$repo-$game-$shard" for shard in shards] + ch_files = Channel{String}(length(folders)) do ch for folder in cycle(folders) file = folder * "/$(readdir(folder)[1])" put!(ch, file) end end - + if is_shuffle files = buffered_shuffle(ch_files, length(folders)) else @@ -173,11 +185,11 @@ function rl_unplugged_dm_dataset( Threads.foreach( TFRecord.read( fs; - compression=:gzip, - bufsize=tf_reader_bufsize, - channel_size=tf_reader_sz, + compression = :gzip, + bufsize = tf_reader_bufsize, + channel_size = tf_reader_sz, ); - schedule=Threads.StaticSchedule() + schedule = Threads.StaticSchedule(), ) do x put!(ch, make_transition(x, feature_size)) end @@ -185,21 +197,18 @@ function rl_unplugged_dm_dataset( end if is_shuffle - transitions = buffered_shuffle( - ch_src, - shuffle_buffer_size - ) + transitions = buffered_shuffle(ch_src, shuffle_buffer_size) else transitions = ch_src end - + taskref = Ref{Task}() - + buffer_dict = dm_buffer_dict(feature_size, batch_size) buffer = NamedTuple(buffer_dict) - res = RingBuffer(buffer;taskref=taskref, sz=n_preallocations) do buff + res = RingBuffer(buffer; taskref = taskref, sz = n_preallocations) do buff Threads.@threads for i in 1:batch_size batch_named_tuple!(buff, take!(transitions), i) end @@ -208,4 +217,4 @@ function rl_unplugged_dm_dataset( bind(ch_src, taskref[]) bind(ch_files, taskref[]) res -end \ No newline at end of file +end diff --git a/src/ReinforcementLearningDatasets/src/rl_unplugged/util.jl b/src/ReinforcementLearningDatasets/src/rl_unplugged/util.jl index 4c09e678e..5a62247de 100644 --- a/src/ReinforcementLearningDatasets/src/rl_unplugged/util.jl +++ b/src/ReinforcementLearningDatasets/src/rl_unplugged/util.jl @@ -27,7 +27,7 @@ Therefore, it acts as a channel that holds a shuffled buffer which is of type Ve - `buffer::Vector{T}`, The shuffled buffer. - `rng<:AbstractRNG`. """ -struct BufferedShuffle{T, R<:AbstractRNG} <: AbstractChannel{T} +struct BufferedShuffle{T,R<:AbstractRNG} <: AbstractChannel{T} src::Channel{T} buffer::Vector{T} rng::R @@ -43,7 +43,11 @@ Arguments: - `buffer_size::Int`. The size of the buffered channel. - `rng<:AbstractRNG` = Random.GLOBAL_RNG. 
""" -function buffered_shuffle(src::Channel{T}, buffer_size::Int;rng=Random.GLOBAL_RNG) where T +function buffered_shuffle( + src::Channel{T}, + buffer_size::Int; + rng = Random.GLOBAL_RNG, +) where {T} buffer = Array{T}(undef, buffer_size) p = Progress(buffer_size) Threads.@threads for i in 1:buffer_size @@ -70,7 +74,7 @@ function Base.take!(b::BufferedShuffle) end end -function Base.iterate(b::BufferedShuffle, state=nothing) +function Base.iterate(b::BufferedShuffle, state = nothing) try return (popfirst!(b), nothing) catch e @@ -104,14 +108,14 @@ Return a RingBuffer that gives batches with the specs in `buffer`. - `buffer::T`: the type containing the batch. - `sz::Int`:size of the internal buffers. """ -function RingBuffer(f!, buffer::T;sz=Threads.nthreads(), taskref=nothing) where T +function RingBuffer(f!, buffer::T; sz = Threads.nthreads(), taskref = nothing) where {T} buffers = Channel{T}(sz) for _ in 1:sz put!(buffers, deepcopy(buffer)) end - results = Channel{T}(sz, spawn=true, taskref=taskref) do ch - Threads.foreach(buffers;schedule=Threads.StaticSchedule()) do x - # for x in buffers + results = Channel{T}(sz, spawn = true, taskref = taskref) do ch + Threads.foreach(buffers; schedule = Threads.StaticSchedule()) do x + # for x in buffers f!(x) # in-place operation put!(ch, x) end diff --git a/src/ReinforcementLearningDatasets/test/atari_dataset.jl b/src/ReinforcementLearningDatasets/test/atari_dataset.jl index eb6c03175..b89c63ef3 100644 --- a/src/ReinforcementLearningDatasets/test/atari_dataset.jl +++ b/src/ReinforcementLearningDatasets/test/atari_dataset.jl @@ -13,11 +13,11 @@ rng = StableRNG(123) "pong", index, epochs; - repo="atari-replay-datasets", + repo = "atari-replay-datasets", style = style, rng = rng, is_shuffle = true, - batch_size = batch_size + batch_size = batch_size, ) data_dict = ds.dataset @@ -64,11 +64,11 @@ end "pong", index, epochs; - repo="atari-replay-datasets", + repo = "atari-replay-datasets", style = style, rng = rng, is_shuffle = false, - batch_size = batch_size + batch_size = batch_size, ) data_dict = ds.dataset @@ -118,4 +118,4 @@ end @test data_dict[:reward][batch_size+1:batch_size*2] == iter2[:reward] @test data_dict[:terminal][batch_size+1:batch_size*2] == iter2[:terminal] @test data_dict[:state][:, :, batch_size+2:batch_size*2+1] == iter2[:next_state] -end \ No newline at end of file +end diff --git a/src/ReinforcementLearningDatasets/test/bsuite.jl b/src/ReinforcementLearningDatasets/test/bsuite.jl index 0bceeb7c0..36fd647e0 100644 --- a/src/ReinforcementLearningDatasets/test/bsuite.jl +++ b/src/ReinforcementLearningDatasets/test/bsuite.jl @@ -10,13 +10,13 @@ tf_reader_bufsize = 10_000, tf_reader_sz = 10_000, batch_size = 256, - n_preallocations = Threads.nthreads() * 12 + n_preallocations = Threads.nthreads() * 12, ) - @test typeof(ds)<:RingBuffer + @test typeof(ds) <: RingBuffer s_size = 6 - + data_1 = take!(ds) @test size(data_1.state) == (s_size, batch_size) @@ -25,11 +25,11 @@ @test size(data_1.reward) == (batch_size,) @test size(data_1.terminal) == (batch_size,) - @test typeof(data_1.state) == Array{Float32, 2} - @test typeof(data_1.next_state) == Array{Float32, 2} - @test typeof(data_1.action) == Array{Int, 1} - @test typeof(data_1.reward) == Array{Float32, 1} - @test typeof(data_1.terminal) == Array{Bool, 1} + @test typeof(data_1.state) == Array{Float32,2} + @test typeof(data_1.next_state) == Array{Float32,2} + @test typeof(data_1.action) == Array{Int,1} + @test typeof(data_1.reward) == Array{Float32,1} + @test typeof(data_1.terminal) == 
Array{Bool,1} end @@ -44,13 +44,13 @@ tf_reader_bufsize = 10_000, tf_reader_sz = 10_000, batch_size = 256, - n_preallocations = Threads.nthreads() * 12 + n_preallocations = Threads.nthreads() * 12, ) - @test typeof(ds)<:RingBuffer + @test typeof(ds) <: RingBuffer s_size = 6 - + data_1 = take!(ds) @test size(data_1.state) == (s_size, batch_size) @@ -59,11 +59,11 @@ @test size(data_1.reward) == (batch_size,) @test size(data_1.terminal) == (batch_size,) - @test typeof(data_1.state) == Array{Float32, 2} - @test typeof(data_1.next_state) == Array{Float32, 2} - @test typeof(data_1.action) == Array{Int, 1} - @test typeof(data_1.reward) == Array{Float32, 1} - @test typeof(data_1.terminal) == Array{Bool, 1} + @test typeof(data_1.state) == Array{Float32,2} + @test typeof(data_1.next_state) == Array{Float32,2} + @test typeof(data_1.action) == Array{Int,1} + @test typeof(data_1.reward) == Array{Float32,1} + @test typeof(data_1.terminal) == Array{Bool,1} end @@ -78,13 +78,13 @@ tf_reader_bufsize = 10_000, tf_reader_sz = 10_000, batch_size = 256, - n_preallocations = Threads.nthreads() * 12 + n_preallocations = Threads.nthreads() * 12, ) - @test typeof(ds)<:RingBuffer + @test typeof(ds) <: RingBuffer s_size = (10, 5) - + data_1 = take!(ds) @test size(data_1.state) == (s_size[1], s_size[2], batch_size) @@ -93,11 +93,11 @@ @test size(data_1.reward) == (batch_size,) @test size(data_1.terminal) == (batch_size,) - @test typeof(data_1.state) == Array{Float32, 3} - @test typeof(data_1.next_state) == Array{Float32, 3} - @test typeof(data_1.action) == Array{Int, 1} - @test typeof(data_1.reward) == Array{Float32, 1} - @test typeof(data_1.terminal) == Array{Bool, 1} + @test typeof(data_1.state) == Array{Float32,3} + @test typeof(data_1.next_state) == Array{Float32,3} + @test typeof(data_1.action) == Array{Int,1} + @test typeof(data_1.reward) == Array{Float32,1} + @test typeof(data_1.terminal) == Array{Bool,1} end -end \ No newline at end of file +end diff --git a/src/ReinforcementLearningDatasets/test/d4rl_pybullet.jl b/src/ReinforcementLearningDatasets/test/d4rl_pybullet.jl index a146647c7..245bfae92 100644 --- a/src/ReinforcementLearningDatasets/test/d4rl_pybullet.jl +++ b/src/ReinforcementLearningDatasets/test/d4rl_pybullet.jl @@ -2,11 +2,11 @@ using Base: batch_size_err_str @testset "d4rl_pybullet" begin ds = dataset( "hopper-bullet-mixed-v0"; - repo="d4rl-pybullet", + repo = "d4rl-pybullet", style = style, rng = rng, is_shuffle = true, - batch_size = batch_size + batch_size = batch_size, ) n_s = 15 @@ -23,10 +23,12 @@ using Base: batch_size_err_str for sample in Iterators.take(ds, 3) @test typeof(sample) <: NamedTuple{SARTS} - @test size(sample[:state]) == (n_s, batch_size) - @test size(sample[:action]) == (n_a, batch_size) - @test size(sample[:reward]) == (1, batch_size) || size(sample[:reward]) == (batch_size,) - @test size(sample[:terminal]) == (1, batch_size) || size(sample[:terminal]) == (batch_size,) + @test size(sample[:state]) == (n_s, batch_size) + @test size(sample[:action]) == (n_a, batch_size) + @test size(sample[:reward]) == (1, batch_size) || + size(sample[:reward]) == (batch_size,) + @test size(sample[:terminal]) == (1, batch_size) || + size(sample[:terminal]) == (batch_size,) end -end \ No newline at end of file +end diff --git a/src/ReinforcementLearningDatasets/test/dataset.jl b/src/ReinforcementLearningDatasets/test/dataset.jl index 9a726c6ab..a2ac0f3a9 100644 --- a/src/ReinforcementLearningDatasets/test/dataset.jl +++ b/src/ReinforcementLearningDatasets/test/dataset.jl @@ -8,11 +8,11 @@ 
rng = MersenneTwister(123) @testset "dataset_shuffle" begin ds = dataset( "hopper-medium-replay-v0"; - repo="d4rl", + repo = "d4rl", style = style, rng = rng, is_shuffle = true, - batch_size = batch_size + batch_size = batch_size, ) data_dict = ds.dataset @@ -58,7 +58,7 @@ end style = style, rng = rng, is_shuffle = false, - batch_size = batch_size + batch_size = batch_size, ) data_dict = ds.dataset diff --git a/src/ReinforcementLearningDatasets/test/deep_ope_d4rl.jl b/src/ReinforcementLearningDatasets/test/deep_ope_d4rl.jl index 39afb2b32..6105513de 100644 --- a/src/ReinforcementLearningDatasets/test/deep_ope_d4rl.jl +++ b/src/ReinforcementLearningDatasets/test/deep_ope_d4rl.jl @@ -6,7 +6,7 @@ using UnicodePlots @testset "d4rl_policies" begin model = d4rl_policy("ant", "online", 10) - @test typeof(model) <: D4RLGaussianNetwork + @test typeof(model) <: D4RLGaussianNetwork env = GymEnv("ant-medium-v0") @@ -16,6 +16,6 @@ using UnicodePlots end @testset "d4rl_policy_evaluate" begin - plt = deep_ope_d4rl_evaluate("halfcheetah", "online", 10; num_evaluations=100) + plt = deep_ope_d4rl_evaluate("halfcheetah", "online", 10; num_evaluations = 100) @test typeof(plt) <: UnicodePlots.Plot -end \ No newline at end of file +end diff --git a/src/ReinforcementLearningDatasets/test/rl_unplugged_atari.jl b/src/ReinforcementLearningDatasets/test/rl_unplugged_atari.jl index 949a22b83..d2b01028f 100644 --- a/src/ReinforcementLearningDatasets/test/rl_unplugged_atari.jl +++ b/src/ReinforcementLearningDatasets/test/rl_unplugged_atari.jl @@ -4,13 +4,13 @@ 1, [1, 2]; shuffle_buffer_size = 10_000, - tf_reader_bufsize = 1*1024*1024, + tf_reader_bufsize = 1 * 1024 * 1024, tf_reader_sz = 10_000, batch_size = 256, - n_preallocations = Threads.nthreads() * 12 + n_preallocations = Threads.nthreads() * 12, ) - @test typeof(ds)<:RingBuffer + @test typeof(ds) <: RingBuffer data_1 = take!(ds) @@ -26,8 +26,8 @@ @test size(data_1.episode_id) == (batch_size,) @test size(data_1.episode_return) == (batch_size,) - @test typeof(data_1.state) == Array{UInt8, 4} - @test typeof(data_1.next_state) == Array{UInt8, 4} + @test typeof(data_1.state) == Array{UInt8,4} + @test typeof(data_1.next_state) == Array{UInt8,4} @test typeof(data_1.action) == Vector{Int64} @test typeof(data_1.next_action) == Vector{Int64} @test typeof(data_1.reward) == Vector{Float32} @@ -39,4 +39,4 @@ take!(ds) data_2 = take!(ds) -end \ No newline at end of file +end diff --git a/src/ReinforcementLearningDatasets/test/rl_unplugged_dm.jl b/src/ReinforcementLearningDatasets/test/rl_unplugged_dm.jl index 897ce16dd..5c3c7ded7 100644 --- a/src/ReinforcementLearningDatasets/test/rl_unplugged_dm.jl +++ b/src/ReinforcementLearningDatasets/test/rl_unplugged_dm.jl @@ -4,37 +4,40 @@ using Base.Threads ds = rl_unplugged_dm_dataset( "fish_swim", [1, 2]; - type="dm_control_suite", + type = "dm_control_suite", is_shuffle = true, - shuffle_buffer_size=10_000, - tf_reader_bufsize=10_000, - tf_reader_sz=10_000, - batch_size=256, - n_preallocations=nthreads()*12 + shuffle_buffer_size = 10_000, + tf_reader_bufsize = 10_000, + tf_reader_sz = 10_000, + batch_size = 256, + n_preallocations = nthreads() * 12, ) - @test typeof(ds)<:RingBuffer + @test typeof(ds) <: RingBuffer data = take!(ds) - + batch_size = 256 feature_size = ReinforcementLearningDatasets.DM_CONTROL_SUITE_SIZE["fish_swim"] - + @test typeof(data.state) <: NamedTuple @test typeof(data.next_state) <: NamedTuple - + for feature in keys(feature_size) if split(feature, "/")[1] != "observation" if feature != "step_type" ob_key = 
Symbol(feature) - @test size(getfield(data, ob_key)) == (feature_size[feature]..., batch_size,) + @test size(getfield(data, ob_key)) == + (feature_size[feature]..., batch_size) end else state = data.state next_state = data.next_state - ob_key = Symbol(chop(feature, head=length("observation")+1, tail=0)) - @test size(getfield(state, ob_key)) == (feature_size[feature]...,batch_size) - @test size(getfield(next_state, ob_key)) == (feature_size[feature]..., batch_size,) + ob_key = Symbol(chop(feature, head = length("observation") + 1, tail = 0)) + @test size(getfield(state, ob_key)) == + (feature_size[feature]..., batch_size) + @test size(getfield(next_state, ob_key)) == + (feature_size[feature]..., batch_size) end end end @@ -43,37 +46,40 @@ using Base.Threads ds = rl_unplugged_dm_dataset( "humanoid_corridor", [1, 2]; - type="dm_locomotion_humanoid", + type = "dm_locomotion_humanoid", is_shuffle = true, - shuffle_buffer_size=10_000, - tf_reader_bufsize=10_000, - tf_reader_sz=10_000, - batch_size=256, - n_preallocations=nthreads()*12 + shuffle_buffer_size = 10_000, + tf_reader_bufsize = 10_000, + tf_reader_sz = 10_000, + batch_size = 256, + n_preallocations = nthreads() * 12, ) - @test typeof(ds)<:RingBuffer + @test typeof(ds) <: RingBuffer data = take!(ds) - + batch_size = 256 feature_size = ReinforcementLearningDatasets.DM_LOCOMOTION_HUMANOID_SIZE - + @test typeof(data.state) <: NamedTuple @test typeof(data.next_state) <: NamedTuple - + for feature in keys(feature_size) if split(feature, "/")[1] != "observation" if feature != "step_type" ob_key = Symbol(feature) - @test size(getfield(data, ob_key)) == (feature_size[feature]..., batch_size,) + @test size(getfield(data, ob_key)) == + (feature_size[feature]..., batch_size) end else state = data.state next_state = data.next_state - ob_key = Symbol(chop(feature, head=length("observation")+1, tail=0)) - @test size(getfield(state, ob_key)) == (feature_size[feature]..., batch_size,) - @test size(getfield(next_state, ob_key)) == (feature_size[feature]..., batch_size,) + ob_key = Symbol(chop(feature, head = length("observation") + 1, tail = 0)) + @test size(getfield(state, ob_key)) == + (feature_size[feature]..., batch_size) + @test size(getfield(next_state, ob_key)) == + (feature_size[feature]..., batch_size) end end end @@ -82,38 +88,41 @@ using Base.Threads ds = rl_unplugged_dm_dataset( "rodent_escape", [1, 2]; - type="dm_locomotion_rodent", + type = "dm_locomotion_rodent", is_shuffle = true, - shuffle_buffer_size=10_000, - tf_reader_bufsize=10_000, - tf_reader_sz=10_000, - batch_size=256, - n_preallocations=nthreads()*12 + shuffle_buffer_size = 10_000, + tf_reader_bufsize = 10_000, + tf_reader_sz = 10_000, + batch_size = 256, + n_preallocations = nthreads() * 12, ) - @test typeof(ds)<:RingBuffer + @test typeof(ds) <: RingBuffer data = take!(ds) - + batch_size = 256 feature_size = ReinforcementLearningDatasets.DM_LOCOMOTION_RODENT_SIZE - + @test typeof(data.state) <: NamedTuple @test typeof(data.next_state) <: NamedTuple - + for feature in keys(feature_size) if split(feature, "/")[1] != "observation" if feature != "step_type" ob_key = Symbol(feature) - @test size(getfield(data, ob_key)) == (feature_size[feature]..., batch_size,) + @test size(getfield(data, ob_key)) == + (feature_size[feature]..., batch_size) end else state = data.state next_state = data.next_state - ob_key = Symbol(chop(feature, head=length("observation")+1, tail=0)) - @test size(getfield(state, ob_key)) == (feature_size[feature]..., batch_size,) - @test size(getfield(next_state, 
ob_key)) == (feature_size[feature]..., batch_size,) + ob_key = Symbol(chop(feature, head = length("observation") + 1, tail = 0)) + @test size(getfield(state, ob_key)) == + (feature_size[feature]..., batch_size) + @test size(getfield(next_state, ob_key)) == + (feature_size[feature]..., batch_size) end end end -end \ No newline at end of file +end diff --git a/src/ReinforcementLearningEnvironments/src/ReinforcementLearningEnvironments.jl b/src/ReinforcementLearningEnvironments/src/ReinforcementLearningEnvironments.jl index f3d2dd7e9..dfd1f0ee8 100644 --- a/src/ReinforcementLearningEnvironments/src/ReinforcementLearningEnvironments.jl +++ b/src/ReinforcementLearningEnvironments/src/ReinforcementLearningEnvironments.jl @@ -4,7 +4,7 @@ using ReinforcementLearningBase using Random using Requires using IntervalSets -using Base.Threads:@spawn +using Base.Threads: @spawn using Markdown const RLEnvs = ReinforcementLearningEnvironments @@ -31,9 +31,7 @@ function __init__() @require OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" include( "environments/3rd_party/AcrobotEnv.jl", ) - @require Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" include( - "plots.jl", - ) + @require Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" include("plots.jl") end diff --git a/src/ReinforcementLearningEnvironments/src/environments/3rd_party/AcrobotEnv.jl b/src/ReinforcementLearningEnvironments/src/environments/3rd_party/AcrobotEnv.jl index c371bb80b..eac7957c7 100644 --- a/src/ReinforcementLearningEnvironments/src/environments/3rd_party/AcrobotEnv.jl +++ b/src/ReinforcementLearningEnvironments/src/environments/3rd_party/AcrobotEnv.jl @@ -18,23 +18,23 @@ AcrobotEnv(;kwargs...) - `avail_torque = [T(-1.), T(0.), T(1.)]` """ function AcrobotEnv(; - T=Float64, - link_length_a=T(1.0), - link_length_b=T(1.0), - link_mass_a=T(1.0), - link_mass_b=T(1.0), - link_com_pos_a=T(0.5), - link_com_pos_b=T(0.5), - link_moi=T(1.0), - max_torque_noise=T(0.0), - max_vel_a=T(4 * π), - max_vel_b=T(9 * π), - g=T(9.8), - dt=T(0.2), - max_steps=200, - rng=Random.GLOBAL_RNG, - book_or_nips="book", - avail_torque=[T(-1.0), T(0.0), T(1.0)], + T = Float64, + link_length_a = T(1.0), + link_length_b = T(1.0), + link_mass_a = T(1.0), + link_mass_b = T(1.0), + link_com_pos_a = T(0.5), + link_com_pos_b = T(0.5), + link_moi = T(1.0), + max_torque_noise = T(0.0), + max_vel_a = T(4 * π), + max_vel_b = T(9 * π), + g = T(9.8), + dt = T(0.2), + max_steps = 200, + rng = Random.GLOBAL_RNG, + book_or_nips = "book", + avail_torque = [T(-1.0), T(0.0), T(1.0)], ) params = AcrobotEnvParams{T}( @@ -81,7 +81,7 @@ RLBase.is_terminated(env::AcrobotEnv) = env.done RLBase.state(env::AcrobotEnv) = acrobot_observation(env.state) RLBase.reward(env::AcrobotEnv) = env.reward -function RLBase.reset!(env::AcrobotEnv{T}) where {T <: Number} +function RLBase.reset!(env::AcrobotEnv{T}) where {T<:Number} env.state[:] = T(0.1) * rand(env.rng, T, 4) .- T(0.05) env.t = 0 env.action = 2 @@ -91,7 +91,7 @@ function RLBase.reset!(env::AcrobotEnv{T}) where {T <: Number} end # governing equations as per python gym -function (env::AcrobotEnv{T})(a) where {T <: Number} +function (env::AcrobotEnv{T})(a) where {T<:Number} env.action = a env.t += 1 torque = env.avail_torque[a] @@ -137,7 +137,7 @@ function dsdt(du, s_augmented, env::AcrobotEnv, t) # extract action and state a = s_augmented[end] - s = s_augmented[1:end - 1] + s = s_augmented[1:end-1] # writing in standard form theta1 = s[1] @@ -201,7 +201,7 @@ function wrap(x, m, M) while x < m x = x + diff end -return x + return x end 
function bound(x, m, M) diff --git a/src/ReinforcementLearningEnvironments/src/environments/3rd_party/gym.jl b/src/ReinforcementLearningEnvironments/src/environments/3rd_party/gym.jl index 33ac9ea86..95a1b3e8e 100644 --- a/src/ReinforcementLearningEnvironments/src/environments/3rd_party/gym.jl +++ b/src/ReinforcementLearningEnvironments/src/environments/3rd_party/gym.jl @@ -1,13 +1,15 @@ using .PyCall -function GymEnv(name::String; seed::Union{Int, Nothing}=nothing) +function GymEnv(name::String; seed::Union{Int,Nothing} = nothing) if !PyCall.pyexists("gym") error( "Cannot import module 'gym'.\n\nIf you did not yet install it, try running\n`ReinforcementLearningEnvironments.install_gym()`\n", ) end gym = pyimport_conda("gym", "gym") - if PyCall.pyexists("d4rl") pyimport("d4rl") end + if PyCall.pyexists("d4rl") + pyimport("d4rl") + end pyenv = try gym.make(name) catch e @@ -15,7 +17,9 @@ function GymEnv(name::String; seed::Union{Int, Nothing}=nothing) "Gym environment $name not found.\n\nRun `ReinforcementLearningEnvironments.list_gym_env_names()` to find supported environments.\n", ) end - if seed !== nothing pyenv.seed(seed) end + if seed !== nothing + pyenv.seed(seed) + end obs_space = space_transform(pyenv.observation_space) act_space = space_transform(pyenv.action_space) obs_type = if obs_space isa Space{<:Union{Array{<:Interval},Array{<:ZeroTo}}} @@ -139,8 +143,10 @@ function list_gym_env_names(; "d4rl.gym_bullet.gym_envs", "d4rl.pointmaze_bullet.bullet_maze", # yet to include flow and carla ], -) - if PyCall.pyexists("d4rl") pyimport("d4rl") end +) + if PyCall.pyexists("d4rl") + pyimport("d4rl") + end gym = pyimport("gym") [x.id for x in gym.envs.registry.all() if split(x.entry_point, ':')[1] in modules] end diff --git a/src/ReinforcementLearningEnvironments/src/environments/3rd_party/open_spiel.jl b/src/ReinforcementLearningEnvironments/src/environments/3rd_party/open_spiel.jl index c325b65d6..de8c15f14 100644 --- a/src/ReinforcementLearningEnvironments/src/environments/3rd_party/open_spiel.jl +++ b/src/ReinforcementLearningEnvironments/src/environments/3rd_party/open_spiel.jl @@ -44,7 +44,7 @@ import .OpenSpiel: `True` or `False` (instead of `true` or `false`). Another approach is to just specify parameters in `kwargs` in the Julia style. """ -function OpenSpielEnv(name="kuhn_poker"; kwargs...) +function OpenSpielEnv(name = "kuhn_poker"; kwargs...) game = load_game(String(name); kwargs...) 
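    # Annotation at this step (illustrative, not part of the patch): `load_game` is the
    # OpenSpiel.jl entry point and receives the Julia-style kwargs described in the
    # docstring above, while `new_initial_state` just below builds the root state the
    # wrapper then advances. A hedged usage sketch with the default registry name; the
    # `players` keyword is assumed from OpenSpiel's own game documentation:
    #
    #     env = OpenSpielEnv("kuhn_poker")
    #     env = OpenSpielEnv("kuhn_poker"; players = 2)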
state = new_initial_state(game) OpenSpielEnv(state, game) @@ -60,7 +60,7 @@ RLBase.current_player(env::OpenSpielEnv) = OpenSpiel.current_player(env.state) RLBase.chance_player(env::OpenSpielEnv) = convert(Int, OpenSpiel.CHANCE_PLAYER) function RLBase.players(env::OpenSpielEnv) - p = 0:(num_players(env.game) - 1) + p = 0:(num_players(env.game)-1) if ChanceStyle(env) === EXPLICIT_STOCHASTIC (p..., RLBase.chance_player(env)) else @@ -91,7 +91,7 @@ function RLBase.prob(env::OpenSpielEnv, player) # @assert player == chance_player(env) p = zeros(length(action_space(env))) for (k, v) in chance_outcomes(env.state) - p[k + 1] = v + p[k+1] = v end p end @@ -102,7 +102,7 @@ function RLBase.legal_action_space_mask(env::OpenSpielEnv, player) num_distinct_actions(env.game) mask = BitArray(undef, n) for a in legal_actions(env.state, player) - mask[a + 1] = true + mask[a+1] = true end mask end @@ -138,12 +138,16 @@ end _state(env::OpenSpielEnv, ::RLBase.InformationSet{String}, player) = information_state_string(env.state, player) -_state(env::OpenSpielEnv, ::RLBase.InformationSet{Array}, player) = - reshape(information_state_tensor(env.state, player), reverse(information_state_tensor_shape(env.game))...) +_state(env::OpenSpielEnv, ::RLBase.InformationSet{Array}, player) = reshape( + information_state_tensor(env.state, player), + reverse(information_state_tensor_shape(env.game))..., +) _state(env::OpenSpielEnv, ::Observation{String}, player) = observation_string(env.state, player) -_state(env::OpenSpielEnv, ::Observation{Array}, player) = - reshape(observation_tensor(env.state, player), reverse(observation_tensor_shape(env.game))...) +_state(env::OpenSpielEnv, ::Observation{Array}, player) = reshape( + observation_tensor(env.state, player), + reverse(observation_tensor_shape(env.game))..., +) RLBase.state_space( env::OpenSpielEnv, @@ -151,16 +155,18 @@ RLBase.state_space( p, ) = WorldSpace{AbstractString}() -RLBase.state_space(env::OpenSpielEnv, ::InformationSet{Array}, - p, -) = Space( - fill(typemin(Float64)..typemax(Float64), reverse(information_state_tensor_shape(env.game))...), +RLBase.state_space(env::OpenSpielEnv, ::InformationSet{Array}, p) = Space( + fill( + typemin(Float64) .. typemax(Float64), + reverse(information_state_tensor_shape(env.game))..., + ), ) -RLBase.state_space(env::OpenSpielEnv, ::Observation{Array}, - p, -) = Space( - fill(typemin(Float64)..typemax(Float64), reverse(observation_tensor_shape(env.game))...), +RLBase.state_space(env::OpenSpielEnv, ::Observation{Array}, p) = Space( + fill( + typemin(Float64) .. typemax(Float64), + reverse(observation_tensor_shape(env.game))..., + ), ) Random.seed!(env::OpenSpielEnv, s) = @warn "seed!(OpenSpielEnv) is not supported currently." @@ -201,7 +207,9 @@ RLBase.RewardStyle(env::OpenSpielEnv) = reward_model(get_type(env.game)) == OpenSpiel.REWARDS ? 
RLBase.STEP_REWARD : RLBase.TERMINAL_REWARD -RLBase.StateStyle(env::OpenSpielEnv) = (RLBase.InformationSet{String}(), +RLBase.StateStyle(env::OpenSpielEnv) = ( + RLBase.InformationSet{String}(), RLBase.InformationSet{Array}(), RLBase.Observation{String}(), - RLBase.Observation{Array}(),) + RLBase.Observation{Array}(), +) diff --git a/src/ReinforcementLearningEnvironments/src/environments/3rd_party/snake.jl b/src/ReinforcementLearningEnvironments/src/environments/3rd_party/snake.jl index df3a9e87f..02cfdc201 100644 --- a/src/ReinforcementLearningEnvironments/src/environments/3rd_party/snake.jl +++ b/src/ReinforcementLearningEnvironments/src/environments/3rd_party/snake.jl @@ -43,7 +43,7 @@ end RLBase.action_space(env::SnakeGameEnv) = 1:4 RLBase.state(env::SnakeGameEnv) = env.game.board -RLBase.state_space(env::SnakeGameEnv) = Space(fill(false..true, size(env.game.board))) +RLBase.state_space(env::SnakeGameEnv) = Space(fill(false .. true, size(env.game.board))) RLBase.reward(env::SnakeGameEnv{<:Any,SINGLE_AGENT}) = length(env.game.snakes[]) - env.latest_snakes_length[] RLBase.reward(env::SnakeGameEnv) = length.(env.game.snakes) .- env.latest_snakes_length diff --git a/src/ReinforcementLearningEnvironments/src/environments/3rd_party/structs.jl b/src/ReinforcementLearningEnvironments/src/environments/3rd_party/structs.jl index 83586f4e3..0acd51427 100644 --- a/src/ReinforcementLearningEnvironments/src/environments/3rd_party/structs.jl +++ b/src/ReinforcementLearningEnvironments/src/environments/3rd_party/structs.jl @@ -6,7 +6,7 @@ struct GymEnv{T,Ta,To,P} <: AbstractEnv end export GymEnv -mutable struct AtariEnv{IsGrayScale,TerminalOnLifeLoss,N,S <: AbstractRNG} <: AbstractEnv +mutable struct AtariEnv{IsGrayScale,TerminalOnLifeLoss,N,S<:AbstractRNG} <: AbstractEnv ale::Ptr{Nothing} name::String screens::Tuple{Array{UInt8,N},Array{UInt8,N}} # for max-pooling @@ -65,7 +65,7 @@ end export AcrobotEnvParams -mutable struct AcrobotEnv{T,R <: AbstractRNG} <: AbstractEnv +mutable struct AcrobotEnv{T,R<:AbstractRNG} <: AbstractEnv params::AcrobotEnvParams{T} state::Vector{T} action::Int diff --git a/src/ReinforcementLearningEnvironments/src/environments/examples/BitFlippingEnv.jl b/src/ReinforcementLearningEnvironments/src/environments/examples/BitFlippingEnv.jl index 2c491bf63..74d0056c1 100644 --- a/src/ReinforcementLearningEnvironments/src/environments/examples/BitFlippingEnv.jl +++ b/src/ReinforcementLearningEnvironments/src/environments/examples/BitFlippingEnv.jl @@ -41,8 +41,8 @@ end RLBase.state(env::BitFlippingEnv) = state(env::BitFlippingEnv, Observation{BitArray{1}}()) RLBase.state(env::BitFlippingEnv, ::Observation) = env.state RLBase.state(env::BitFlippingEnv, ::GoalState) = env.goal_state -RLBase.state_space(env::BitFlippingEnv, ::Observation) = Space(fill(false..true, env.N)) -RLBase.state_space(env::BitFlippingEnv, ::GoalState) = Space(fill(false..true, env.N)) +RLBase.state_space(env::BitFlippingEnv, ::Observation) = Space(fill(false .. true, env.N)) +RLBase.state_space(env::BitFlippingEnv, ::GoalState) = Space(fill(false .. 
true, env.N)) RLBase.is_terminated(env::BitFlippingEnv) = (env.state == env.goal_state) || (env.t >= env.max_steps) diff --git a/src/ReinforcementLearningEnvironments/src/environments/examples/GraphShortestPathEnv.jl b/src/ReinforcementLearningEnvironments/src/environments/examples/GraphShortestPathEnv.jl index fd8c5af49..dd0cc04e9 100644 --- a/src/ReinforcementLearningEnvironments/src/environments/examples/GraphShortestPathEnv.jl +++ b/src/ReinforcementLearningEnvironments/src/environments/examples/GraphShortestPathEnv.jl @@ -5,7 +5,7 @@ using SparseArrays using LinearAlgebra -mutable struct GraphShortestPathEnv{G, R} <: AbstractEnv +mutable struct GraphShortestPathEnv{G,R} <: AbstractEnv graph::G pos::Int goal::Int @@ -31,7 +31,12 @@ Quoted **A.3** in the the paper [Decision Transformer: Reinforcement Learning vi > lengths and maximizing them corresponds to generating shortest paths. """ -function GraphShortestPathEnv(rng=Random.GLOBAL_RNG; n=20, sparsity=0.1, max_steps=10) +function GraphShortestPathEnv( + rng = Random.GLOBAL_RNG; + n = 20, + sparsity = 0.1, + max_steps = 10, +) graph = sprand(rng, Bool, n, n, sparsity) .| I(n) goal = rand(rng, 1:n) @@ -55,7 +60,8 @@ RLBase.state_space(env::GraphShortestPathEnv) = axes(env.graph, 2) RLBase.action_space(env::GraphShortestPathEnv) = axes(env.graph, 2) RLBase.legal_action_space(env::GraphShortestPathEnv) = (env.graph[:, env.pos]).nzind RLBase.reward(env::GraphShortestPathEnv) = env.reward -RLBase.is_terminated(env::GraphShortestPathEnv) = env.pos == env.goal || env.step >= env.max_steps +RLBase.is_terminated(env::GraphShortestPathEnv) = + env.pos == env.goal || env.step >= env.max_steps function RLBase.reset!(env::GraphShortestPathEnv) env.step = 0 @@ -144,4 +150,4 @@ barplot(1:10, [sum(h[1].steps .== i) for i in 1:10]) # random walk # 10 ┤■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■ 769 # └ ┘ # -=# \ No newline at end of file +=# diff --git a/src/ReinforcementLearningEnvironments/src/environments/examples/MountainCarEnv.jl b/src/ReinforcementLearningEnvironments/src/environments/examples/MountainCarEnv.jl index 5cafab01d..18094e54a 100644 --- a/src/ReinforcementLearningEnvironments/src/environments/examples/MountainCarEnv.jl +++ b/src/ReinforcementLearningEnvironments/src/environments/examples/MountainCarEnv.jl @@ -77,7 +77,7 @@ function MountainCarEnv(; env = MountainCarEnv( params, action_space, - Space([params.min_pos..params.max_pos, -params.max_speed..params.max_speed]), + Space([params.min_pos .. params.max_pos, -params.max_speed .. params.max_speed]), zeros(T, 2), rand(action_space), false, diff --git a/src/ReinforcementLearningEnvironments/src/environments/examples/PendulumEnv.jl b/src/ReinforcementLearningEnvironments/src/environments/examples/PendulumEnv.jl index 72dea511c..bacdc04e1 100644 --- a/src/ReinforcementLearningEnvironments/src/environments/examples/PendulumEnv.jl +++ b/src/ReinforcementLearningEnvironments/src/environments/examples/PendulumEnv.jl @@ -53,7 +53,7 @@ function PendulumEnv(; rng = Random.GLOBAL_RNG, ) high = T.([1, 1, max_speed]) - action_space = continuous ? -2.0..2.0 : Base.OneTo(n_actions) + action_space = continuous ? -2.0 .. 
2.0 : Base.OneTo(n_actions) env = PendulumEnv( PendulumEnvParams(max_speed, max_torque, g, m, l, dt, max_steps), action_space, diff --git a/src/ReinforcementLearningEnvironments/src/environments/examples/PigEnv.jl b/src/ReinforcementLearningEnvironments/src/environments/examples/PigEnv.jl index 1026ac8cc..0f38fb624 100644 --- a/src/ReinforcementLearningEnvironments/src/environments/examples/PigEnv.jl +++ b/src/ReinforcementLearningEnvironments/src/environments/examples/PigEnv.jl @@ -36,7 +36,7 @@ RLBase.prob(env::PigEnv, ::ChancePlayer) = fill(1 / 6, 6) # TODO: uniform distr RLBase.state(env::PigEnv, ::Observation{Vector{Int}}, p) = env.scores RLBase.state_space(env::PigEnv, ::Observation, p) = - Space([0..(PIG_TARGET_SCORE + PIG_N_SIDES - 1) for _ in env.scores]) + Space([0 .. (PIG_TARGET_SCORE + PIG_N_SIDES - 1) for _ in env.scores]) RLBase.is_terminated(env::PigEnv) = any(s >= PIG_TARGET_SCORE for s in env.scores) diff --git a/src/ReinforcementLearningEnvironments/src/environments/examples/SpeakerListenerEnv.jl b/src/ReinforcementLearningEnvironments/src/environments/examples/SpeakerListenerEnv.jl index f4b9572b8..f42d7217b 100644 --- a/src/ReinforcementLearningEnvironments/src/environments/examples/SpeakerListenerEnv.jl +++ b/src/ReinforcementLearningEnvironments/src/environments/examples/SpeakerListenerEnv.jl @@ -7,9 +7,9 @@ mutable struct SpeakerListenerEnv{T<:Vector{Float64}} <: AbstractEnv player_pos::T landmarks_pos::Vector{T} landmarks_num::Int - ϵ - damping - max_accel + ϵ::Any + damping::Any + max_accel::Any space_dim::Int init_step::Int play_step::Int @@ -46,7 +46,8 @@ function SpeakerListenerEnv(; max_accel = 0.5, space_dim::Int = 2, max_steps::Int = 50, - continuous::Bool = true) + continuous::Bool = true, +) SpeakerListenerEnv( zeros(N), zeros(N), @@ -74,21 +75,24 @@ function RLBase.reset!(env::SpeakerListenerEnv) env.landmarks_pos = [zeros(env.space_dim) for _ in Base.OneTo(env.landmarks_num)] end -RLBase.is_terminated(env::SpeakerListenerEnv) = (reward(env) > - env.ϵ) || (env.play_step > env.max_steps) +RLBase.is_terminated(env::SpeakerListenerEnv) = + (reward(env) > -env.ϵ) || (env.play_step > env.max_steps) RLBase.players(::SpeakerListenerEnv) = (:Speaker, :Listener, CHANCE_PLAYER) -RLBase.state(env::SpeakerListenerEnv, ::Observation{Any}, players::Tuple) = Dict(p => state(env, p) for p in players) +RLBase.state(env::SpeakerListenerEnv, ::Observation{Any}, players::Tuple) = + Dict(p => state(env, p) for p in players) -RLBase.state(env::SpeakerListenerEnv, ::Observation{Any}, player::Symbol) = - # for speaker, it can observe the target and help listener to arrive it. +RLBase.state(env::SpeakerListenerEnv, ::Observation{Any}, player::Symbol) = +# for speaker, it can observe the target and help listener to arrive it. if player == :Speaker env.target - # for listener, it can observe current velocity, relative positions of landmarks, and speaker's conveyed information. + # for listener, it can observe current velocity, relative positions of landmarks, and speaker's conveyed information. elseif player == :Listener vcat( env.player_vel..., ( - vcat((landmark_pos .- env.player_pos)...) for landmark_pos in env.landmarks_pos + vcat((landmark_pos .- env.player_pos)...) for + landmark_pos in env.landmarks_pos )..., env.content..., ) @@ -96,47 +100,60 @@ RLBase.state(env::SpeakerListenerEnv, ::Observation{Any}, player::Symbol) = @error "No player $player." 
end -RLBase.state(env::SpeakerListenerEnv, ::Observation{Any}, ::ChancePlayer) = vcat(env.landmarks_pos, [env.player_pos]) +RLBase.state(env::SpeakerListenerEnv, ::Observation{Any}, ::ChancePlayer) = + vcat(env.landmarks_pos, [env.player_pos]) -RLBase.state_space(env::SpeakerListenerEnv, ::Observation{Any}, players::Tuple) = +RLBase.state_space(env::SpeakerListenerEnv, ::Observation{Any}, players::Tuple) = Space(Dict(player => state_space(env, player) for player in players)) -RLBase.state_space(env::SpeakerListenerEnv, ::Observation{Any}, player::Symbol) = +RLBase.state_space(env::SpeakerListenerEnv, ::Observation{Any}, player::Symbol) = if player == :Speaker # env.target - Space([[0., 1.] for _ in Base.OneTo(env.landmarks_num)]) + Space([[0.0, 1.0] for _ in Base.OneTo(env.landmarks_num)]) elseif player == :Listener - Space(vcat( - # relative positions of landmarks, no bounds. - (vcat( - Space([ClosedInterval(-Inf, Inf) for _ in Base.OneTo(env.space_dim)])... - ) for _ in Base.OneTo(env.landmarks_num + 1))..., - # communication content from `Speaker` - [[0., 1.] for _ in Base.OneTo(env.landmarks_num)], - )) + Space( + vcat( + # relative positions of landmarks, no bounds. + ( + vcat( + Space([ + ClosedInterval(-Inf, Inf) for _ in Base.OneTo(env.space_dim) + ])..., + ) for _ in Base.OneTo(env.landmarks_num + 1) + )..., + # communication content from `Speaker` + [[0.0, 1.0] for _ in Base.OneTo(env.landmarks_num)], + ), + ) else @error "No player $player." end -RLBase.state_space(env::SpeakerListenerEnv, ::Observation{Any}, ::ChancePlayer) = - Space( - vcat( - # landmarks' positions - (Space([ClosedInterval(-1, 1) for _ in Base.OneTo(env.space_dim)]) for _ in Base.OneTo(env.landmarks_num))..., - # player's position, no bounds. - Space([ClosedInterval(-Inf, Inf) for _ in Base.OneTo(env.space_dim)]), - ) - ) - -RLBase.action_space(env::SpeakerListenerEnv, players::Tuple) = - Space(Dict(p => action_space(env, p) for p in players)) - -RLBase.action_space(env::SpeakerListenerEnv, player::Symbol) = +RLBase.state_space(env::SpeakerListenerEnv, ::Observation{Any}, ::ChancePlayer) = Space( + vcat( + # landmarks' positions + ( + Space([ClosedInterval(-1, 1) for _ in Base.OneTo(env.space_dim)]) for + _ in Base.OneTo(env.landmarks_num) + )..., + # player's position, no bounds. + Space([ClosedInterval(-Inf, Inf) for _ in Base.OneTo(env.space_dim)]), + ), +) + +RLBase.action_space(env::SpeakerListenerEnv, players::Tuple) = + Space(Dict(p => action_space(env, p) for p in players)) + +RLBase.action_space(env::SpeakerListenerEnv, player::Symbol) = if player == :Speaker - env.continuous ? Space([ClosedInterval(0, 1) for _ in Base.OneTo(env.landmarks_num)]) : Space([ZeroTo(1) for _ in Base.OneTo(env.landmarks_num)]) + env.continuous ? + Space([ClosedInterval(0, 1) for _ in Base.OneTo(env.landmarks_num)]) : + Space([ZeroTo(1) for _ in Base.OneTo(env.landmarks_num)]) elseif player == :Listener # there has two directions in each dimension. - env.continuous ? Space([ClosedInterval(0, 1) for _ in Base.OneTo(2 * env.space_dim)]) : Space([ZeroTo(1) for _ in Base.OneTo(2 * env.space_dim)]) + env.continuous ? + Space([ClosedInterval(0, 1) for _ in Base.OneTo(2 * env.space_dim)]) : + Space([ZeroTo(1) for _ in Base.OneTo(2 * env.space_dim)]) else @error "No player $player." end @@ -157,7 +174,7 @@ function (env::SpeakerListenerEnv)(action, ::ChancePlayer) env.player_pos = action else @assert action in Base.OneTo(env.landmarks_num) "The target should be assigned to one of the landmarks." - env.target[action] = 1. 
+ env.target[action] = 1.0 end end @@ -176,7 +193,7 @@ function (env::SpeakerListenerEnv)(action::Vector, player::Symbol) elseif player == :Listener # update velocity, here env.damping is for simulation physical rule. action = round.(action) - acceleration = [action[2 * i] - action[2 * i - 1] for i in Base.OneTo(env.space_dim)] + acceleration = [action[2*i] - action[2*i-1] for i in Base.OneTo(env.space_dim)] env.player_vel .*= (1 - env.damping) env.player_vel .+= (acceleration * env.max_accel) # update position @@ -190,14 +207,14 @@ RLBase.reward(::SpeakerListenerEnv, ::ChancePlayer) = -Inf function RLBase.reward(env::SpeakerListenerEnv, p) if sum(env.target) == 1 - goal = findfirst(env.target .== 1.) + goal = findfirst(env.target .== 1.0) -sum((env.landmarks_pos[goal] .- env.player_pos) .^ 2) else -Inf end end -RLBase.current_player(env::SpeakerListenerEnv) = +RLBase.current_player(env::SpeakerListenerEnv) = if env.init_step < env.landmarks_num + 2 CHANCE_PLAYER else diff --git a/src/ReinforcementLearningEnvironments/src/environments/examples/StockTradingEnv.jl b/src/ReinforcementLearningEnvironments/src/environments/examples/StockTradingEnv.jl index 6a3f6a6f6..0dc94ec4b 100644 --- a/src/ReinforcementLearningEnvironments/src/environments/examples/StockTradingEnv.jl +++ b/src/ReinforcementLearningEnvironments/src/environments/examples/StockTradingEnv.jl @@ -2,12 +2,12 @@ export StockTradingEnv, StockTradingEnvWithTurbulence using Pkg.Artifacts using DelimitedFiles -using LinearAlgebra:dot +using LinearAlgebra: dot using IntervalSets function load_default_stock_data(s) if s == "prices.csv" || s == "features.csv" - data, _ = readdlm(joinpath(artifact"stock_trading_data", s), ',', header=true) + data, _ = readdlm(joinpath(artifact"stock_trading_data", s), ',', header = true) collect(data') elseif s == "turbulence.csv" readdlm(joinpath(artifact"stock_trading_data", "turbulence.csv")) |> vec @@ -16,7 +16,8 @@ function load_default_stock_data(s) end end -mutable struct StockTradingEnv{F<:AbstractMatrix{Float64}, P<:AbstractMatrix{Float64}} <: AbstractEnv +mutable struct StockTradingEnv{F<:AbstractMatrix{Float64},P<:AbstractMatrix{Float64}} <: + AbstractEnv features::F prices::P HMAX_NORMALIZE::Float32 @@ -48,14 +49,14 @@ This environment is originally provided in [Deep Reinforcement Learning for Auto - `initial_account_balance=1_000_000`. """ function StockTradingEnv(; - initial_account_balance=1_000_000f0, - features=nothing, - prices=nothing, - first_day=nothing, - last_day=nothing, - HMAX_NORMALIZE = 100f0, + initial_account_balance = 1_000_000.0f0, + features = nothing, + prices = nothing, + first_day = nothing, + last_day = nothing, + HMAX_NORMALIZE = 100.0f0, TRANSACTION_FEE_PERCENT = 0.001f0, - REWARD_SCALING = 1f-4 + REWARD_SCALING = 1f-4, ) prices = isnothing(prices) ? load_default_stock_data("prices.csv") : prices features = isnothing(features) ? load_default_stock_data("features.csv") : features @@ -77,11 +78,11 @@ function StockTradingEnv(; REWARD_SCALING, initial_account_balance, state, - 0f0, + 0.0f0, day, first_day, last_day, - 0f0 + 0.0f0, ) _balance(env)[] = initial_account_balance @@ -108,10 +109,10 @@ function (env::StockTradingEnv)(actions) # then buy # better to shuffle? 
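        # Annotation at this step (illustrative, not part of the patch): in the buy
        # branch below, at most `div(balance, price)` shares are affordable, the order
        # is capped at `b * HMAX_NORMALIZE` shares, and `TRANSACTION_FEE_PERCENT` of
        # the amount spent is booked as transaction cost. With the defaults above and
        # illustrative numbers balance = 1_000, price = 30, b = 0.2:
        #     buy  = min(0.2 * 100, div(1_000, 30)) = 20      # shares
        #     cost = 20 * 30 * 0.001                = 0.6     # fee on the amount spent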
- for (i,b) in enumerate(actions) + for (i, b) in enumerate(actions) if b > 0 max_buy = div(_balance(env)[], _prices(env)[i]) - buy = min(b*env.HMAX_NORMALIZE, max_buy) + buy = min(b * env.HMAX_NORMALIZE, max_buy) _holds(env)[i] += buy deduction = buy * _prices(env)[i] cost = deduction * env.TRANSACTION_FEE_PERCENT @@ -136,12 +137,13 @@ function RLBase.reset!(env::StockTradingEnv) _balance(env)[] = env.initial_account_balance _prices(env) .= @view env.prices[:, env.day] _features(env) .= @view env.features[:, env.day] - env.total_cost = 0. - env.daily_reward = 0. + env.total_cost = 0.0 + env.daily_reward = 0.0 end -RLBase.state_space(env::StockTradingEnv) = Space(fill(-Inf32..Inf32, length(state(env)))) -RLBase.action_space(env::StockTradingEnv) = Space(fill(-1f0..1f0, length(_holds(env)))) +RLBase.state_space(env::StockTradingEnv) = Space(fill(-Inf32 .. Inf32, length(state(env)))) +RLBase.action_space(env::StockTradingEnv) = + Space(fill(-1.0f0 .. 1.0f0, length(_holds(env)))) RLBase.ChanceStyle(::StockTradingEnv) = DETERMINISTIC @@ -154,16 +156,16 @@ struct StockTradingEnvWithTurbulence{E<:StockTradingEnv} <: AbstractEnvWrapper end function StockTradingEnvWithTurbulence(; - turbulence_threshold=140., - turbulences=nothing, - kw... + turbulence_threshold = 140.0, + turbulences = nothing, + kw..., ) turbulences = isnothing(turbulences) && load_default_stock_data("turbulence.csv") StockTradingEnvWithTurbulence( - StockTradingEnv(;kw...), + StockTradingEnv(; kw...), turbulences, - turbulence_threshold + turbulence_threshold, ) end @@ -172,4 +174,4 @@ function (w::StockTradingEnvWithTurbulence)(actions) actions .= ifelse.(actions .< 0, -Inf32, 0) end w.env(actions) -end \ No newline at end of file +end diff --git a/src/ReinforcementLearningEnvironments/src/environments/examples/TicTacToeEnv.jl b/src/ReinforcementLearningEnvironments/src/environments/examples/TicTacToeEnv.jl index 455488619..d8301904e 100644 --- a/src/ReinforcementLearningEnvironments/src/environments/examples/TicTacToeEnv.jl +++ b/src/ReinforcementLearningEnvironments/src/environments/examples/TicTacToeEnv.jl @@ -76,7 +76,7 @@ RLBase.players(env::TicTacToeEnv) = (CROSS, NOUGHT) RLBase.state(env::TicTacToeEnv, ::Observation{BitArray{3}}, p) = env.board RLBase.state_space(env::TicTacToeEnv, ::Observation{BitArray{3}}, p) = - Space(fill(false..true, 3, 3, 3)) + Space(fill(false .. true, 3, 3, 3)) RLBase.state(env::TicTacToeEnv, ::Observation{Int}, p) = get_tic_tac_toe_state_info()[env].index RLBase.state_space(env::TicTacToeEnv, ::Observation{Int}, p) = diff --git a/src/ReinforcementLearningEnvironments/src/environments/non_interactive/pendulum.jl b/src/ReinforcementLearningEnvironments/src/environments/non_interactive/pendulum.jl index 8e6aef27d..a1621f1e5 100644 --- a/src/ReinforcementLearningEnvironments/src/environments/non_interactive/pendulum.jl +++ b/src/ReinforcementLearningEnvironments/src/environments/non_interactive/pendulum.jl @@ -70,7 +70,7 @@ RLBase.reward(env::PendulumNonInteractiveEnv) = 0 RLBase.is_terminated(env::PendulumNonInteractiveEnv) = env.done RLBase.state(env::PendulumNonInteractiveEnv) = env.state RLBase.state_space(env::PendulumNonInteractiveEnv{T}) where {T} = - Space([typemin(T)..typemax(T), typemin(T)..typemax(T)]) + Space([typemin(T) .. typemax(T), typemin(T) .. 
typemax(T)]) function RLBase.reset!(env::PendulumNonInteractiveEnv{Fl}) where {Fl} env.state .= (Fl(2 * pi) * rand(env.rng, Fl), randn(env.rng, Fl)) diff --git a/src/ReinforcementLearningEnvironments/src/environments/wrappers/ActionTransformedEnv.jl b/src/ReinforcementLearningEnvironments/src/environments/wrappers/ActionTransformedEnv.jl index 47efe8628..831f80f7d 100644 --- a/src/ReinforcementLearningEnvironments/src/environments/wrappers/ActionTransformedEnv.jl +++ b/src/ReinforcementLearningEnvironments/src/environments/wrappers/ActionTransformedEnv.jl @@ -13,15 +13,14 @@ end `legal_action_space(env)`. `action_mapping` will be applied to `action` before feeding it into `env`. """ -ActionTransformedEnv(env; action_mapping = identity, action_space_mapping = identity) = +ActionTransformedEnv(env; action_mapping = identity, action_space_mapping = identity) = ActionTransformedEnv(env, action_mapping, action_space_mapping) -Base.copy(env::ActionTransformedEnv) = - ActionTransformedEnv( - copy(env.env), - action_mapping = env.action_mapping, - action_space_mapping = env.action_space_mapping - ) +Base.copy(env::ActionTransformedEnv) = ActionTransformedEnv( + copy(env.env), + action_mapping = env.action_mapping, + action_space_mapping = env.action_space_mapping, +) RLBase.action_space(env::ActionTransformedEnv, args...) = env.action_space_mapping(action_space(env.env, args...)) diff --git a/src/ReinforcementLearningEnvironments/src/environments/wrappers/DefaultStateStyle.jl b/src/ReinforcementLearningEnvironments/src/environments/wrappers/DefaultStateStyle.jl index 5aff4a669..d75af9c34 100644 --- a/src/ReinforcementLearningEnvironments/src/environments/wrappers/DefaultStateStyle.jl +++ b/src/ReinforcementLearningEnvironments/src/environments/wrappers/DefaultStateStyle.jl @@ -13,11 +13,12 @@ DefaultStateStyleEnv{S}(env::E) where {S,E} = DefaultStateStyleEnv{S,E}(env) RLBase.DefaultStateStyle(::DefaultStateStyleEnv{S}) where {S} = S -Base.copy(env::DefaultStateStyleEnv{S}) where S = DefaultStateStyleEnv{S}(copy(env.env)) +Base.copy(env::DefaultStateStyleEnv{S}) where {S} = DefaultStateStyleEnv{S}(copy(env.env)) -RLBase.state(env::DefaultStateStyleEnv{S}) where S = state(env.env, S) +RLBase.state(env::DefaultStateStyleEnv{S}) where {S} = state(env.env, S) RLBase.state(env::DefaultStateStyleEnv, ss::RLBase.AbstractStateStyle) = state(env.env, ss) -RLBase.state(env::DefaultStateStyleEnv{S}, player) where S = state(env.env, S, player) +RLBase.state(env::DefaultStateStyleEnv{S}, player) where {S} = state(env.env, S, player) -RLBase.state_space(env::DefaultStateStyleEnv{S}) where S = state_space(env.env, S) -RLBase.state_space(env::DefaultStateStyleEnv, ss::RLBase.AbstractStateStyle) = state_space(env.env, ss) \ No newline at end of file +RLBase.state_space(env::DefaultStateStyleEnv{S}) where {S} = state_space(env.env, S) +RLBase.state_space(env::DefaultStateStyleEnv, ss::RLBase.AbstractStateStyle) = + state_space(env.env, ss) diff --git a/src/ReinforcementLearningEnvironments/src/environments/wrappers/SequentialEnv.jl b/src/ReinforcementLearningEnvironments/src/environments/wrappers/SequentialEnv.jl index 4f18af426..2a88903dd 100644 --- a/src/ReinforcementLearningEnvironments/src/environments/wrappers/SequentialEnv.jl +++ b/src/ReinforcementLearningEnvironments/src/environments/wrappers/SequentialEnv.jl @@ -9,7 +9,7 @@ mutable struct SequentialEnv{E<:AbstractEnv} <: AbstractEnvWrapper env::E current_player_idx::Int actions::Vector{Any} - function SequentialEnv(env::T) where T<:AbstractEnv + function 
SequentialEnv(env::T) where {T<:AbstractEnv} @assert DynamicStyle(env) === SIMULTANEOUS "The SequentialEnv wrapper can only be applied to SIMULTANEOUS environments" new{T}(env, 1, Vector{Any}(undef, length(players(env)))) end @@ -32,7 +32,8 @@ end RLBase.reward(env::SequentialEnv) = reward(env, current_player(env)) -RLBase.reward(env::SequentialEnv, player) = current_player(env) == 1 ? reward(env.env, player) : 0 +RLBase.reward(env::SequentialEnv, player) = + current_player(env) == 1 ? reward(env.env, player) : 0 function (env::SequentialEnv)(action) env.actions[env.current_player_idx] = action @@ -43,4 +44,3 @@ function (env::SequentialEnv)(action) env.current_player_idx += 1 end end - diff --git a/src/ReinforcementLearningEnvironments/src/environments/wrappers/StateCachedEnv.jl b/src/ReinforcementLearningEnvironments/src/environments/wrappers/StateCachedEnv.jl index e8626a3b8..97e18e928 100644 --- a/src/ReinforcementLearningEnvironments/src/environments/wrappers/StateCachedEnv.jl +++ b/src/ReinforcementLearningEnvironments/src/environments/wrappers/StateCachedEnv.jl @@ -6,7 +6,7 @@ the next interaction with `env`. This function is useful because some environments are stateful during each `state(env)`. For example: `StateTransformedEnv(StackFrames(...))`. """ -mutable struct StateCachedEnv{S,E <: AbstractEnv} <: AbstractEnvWrapper +mutable struct StateCachedEnv{S,E<:AbstractEnv} <: AbstractEnvWrapper s::S env::E is_state_cached::Bool diff --git a/src/ReinforcementLearningEnvironments/src/environments/wrappers/StateTransformedEnv.jl b/src/ReinforcementLearningEnvironments/src/environments/wrappers/StateTransformedEnv.jl index dfe90bddd..840c9ecdb 100644 --- a/src/ReinforcementLearningEnvironments/src/environments/wrappers/StateTransformedEnv.jl +++ b/src/ReinforcementLearningEnvironments/src/environments/wrappers/StateTransformedEnv.jl @@ -12,11 +12,11 @@ end `state_mapping` will be applied on the original state when calling `state(env)`, and similarly `state_space_mapping` will be applied when calling `state_space(env)`. """ -StateTransformedEnv(env; state_mapping=identity, state_space_mapping=identity) = +StateTransformedEnv(env; state_mapping = identity, state_space_mapping = identity) = StateTransformedEnv(env, state_mapping, state_space_mapping) RLBase.state(env::StateTransformedEnv, args...; kwargs...) = env.state_mapping(state(env.env, args...; kwargs...)) -RLBase.state_space(env::StateTransformedEnv, args...; kwargs...) = +RLBase.state_space(env::StateTransformedEnv, args...; kwargs...) = env.state_space_mapping(state_space(env.env, args...; kwargs...)) diff --git a/src/ReinforcementLearningEnvironments/src/plots.jl b/src/ReinforcementLearningEnvironments/src/plots.jl index a2c3ce7bc..0e9e5cb61 100644 --- a/src/ReinforcementLearningEnvironments/src/plots.jl +++ b/src/ReinforcementLearningEnvironments/src/plots.jl @@ -8,36 +8,36 @@ function plot(env::CartPoleEnv; kwargs...) 
xthreshold = env.params.xthreshold # set the frame plot( - xlims=(-xthreshold, xthreshold), - ylims=(-.1, l + 0.1), - legend=false, - border=:none, + xlims = (-xthreshold, xthreshold), + ylims = (-.1, l + 0.1), + legend = false, + border = :none, ) # plot the cart - plot!([x - 0.5, x - 0.5, x + 0.5, x + 0.5], [-.05, 0, 0, -.05]; - seriestype=:shape, - ) + plot!([x - 0.5, x - 0.5, x + 0.5, x + 0.5], [-.05, 0, 0, -.05]; seriestype = :shape) # plot the pole - plot!([x, x + l * sin(theta)], [0, l * cos(theta)]; - linewidth=3, - ) + plot!([x, x + l * sin(theta)], [0, l * cos(theta)]; linewidth = 3) # plot the arrow - plot!([x + (a == 1) - 0.5, x + 1.4 * (a == 1)-0.7], [ -.025, -.025]; - linewidth=3, - arrow=true, - color=2, + plot!( + [x + (a == 1) - 0.5, x + 1.4 * (a == 1) - 0.7], + [-.025, -.025]; + linewidth = 3, + arrow = true, + color = 2, ) # if done plot pink circle in top right if d - plot!([xthreshold - 0.2], [l]; - marker=:circle, - markersize=20, - markerstrokewidth=0., - color=:pink, + plot!( + [xthreshold - 0.2], + [l]; + marker = :circle, + markersize = 20, + markerstrokewidth = 0.0, + color = :pink, ) end - - plot!(;kwargs...) + + plot!(; kwargs...) end @@ -51,10 +51,10 @@ function plot(env::MountainCarEnv; kwargs...) d = env.done plot( - xlims=(env.params.min_pos - 0.1, env.params.max_pos + 0.2), - ylims=(-.1, height(env.params.max_pos) + 0.2), - legend=false, - border=:none, + xlims = (env.params.min_pos - 0.1, env.params.max_pos + 0.2), + ylims = (-.1, height(env.params.max_pos) + 0.2), + legend = false, + border = :none, ) # plot the terrain xs = LinRange(env.params.min_pos, env.params.max_pos, 100) @@ -72,17 +72,19 @@ function plot(env::MountainCarEnv; kwargs...) ys .+= clearance xs, ys = rotate(xs, ys, θ) xs, ys = translate(xs, ys, [x, height(x)]) - plot!(xs, ys; seriestype=:shape) + plot!(xs, ys; seriestype = :shape) # if done plot pink circle in top right if d - plot!([xthreshold - 0.2], [l]; - marker=:circle, - markersize=20, - markerstrokewidth=0., - color=:pink, + plot!( + [xthreshold - 0.2], + [l]; + marker = :circle, + markersize = 20, + markerstrokewidth = 0.0, + color = :pink, ) end - plot!(;kwargs...) - end + plot!(; kwargs...) 
+end diff --git a/src/ReinforcementLearningEnvironments/test/environments/3rd_party/gym.jl b/src/ReinforcementLearningEnvironments/test/environments/3rd_party/gym.jl index c7b3ab787..9c8b44de0 100644 --- a/src/ReinforcementLearningEnvironments/test/environments/3rd_party/gym.jl +++ b/src/ReinforcementLearningEnvironments/test/environments/3rd_party/gym.jl @@ -1,10 +1,6 @@ @testset "gym envs" begin gym_env_names = ReinforcementLearningEnvironments.list_gym_env_names( - modules = [ - "gym.envs.algorithmic", - "gym.envs.classic_control", - "gym.envs.unittest", - ], + modules = ["gym.envs.algorithmic", "gym.envs.classic_control", "gym.envs.unittest"], ) # mujoco, box2d, robotics are not tested here for x in gym_env_names diff --git a/src/ReinforcementLearningEnvironments/test/environments/examples/graph_shortest_path_env.jl b/src/ReinforcementLearningEnvironments/test/environments/examples/graph_shortest_path_env.jl index a912ecc24..e47ce9047 100644 --- a/src/ReinforcementLearningEnvironments/test/environments/examples/graph_shortest_path_env.jl +++ b/src/ReinforcementLearningEnvironments/test/environments/examples/graph_shortest_path_env.jl @@ -6,4 +6,3 @@ RLBase.test_runnable!(env) end - diff --git a/src/ReinforcementLearningEnvironments/test/environments/examples/stock_trading_env.jl b/src/ReinforcementLearningEnvironments/test/environments/examples/stock_trading_env.jl index 7cb138328..c826e0e4e 100644 --- a/src/ReinforcementLearningEnvironments/test/environments/examples/stock_trading_env.jl +++ b/src/ReinforcementLearningEnvironments/test/environments/examples/stock_trading_env.jl @@ -5,4 +5,3 @@ RLBase.test_interfaces!(env) RLBase.test_runnable!(env) end - diff --git a/src/ReinforcementLearningEnvironments/test/environments/wrappers/wrappers.jl b/src/ReinforcementLearningEnvironments/test/environments/wrappers/wrappers.jl index 208307a67..049defa22 100644 --- a/src/ReinforcementLearningEnvironments/test/environments/wrappers/wrappers.jl +++ b/src/ReinforcementLearningEnvironments/test/environments/wrappers/wrappers.jl @@ -1,11 +1,11 @@ @testset "wrappers" begin @testset "ActionTransformedEnv" begin - env = TigerProblemEnv(; rng=StableRNG(123)) + env = TigerProblemEnv(; rng = StableRNG(123)) env′ = ActionTransformedEnv( env; - action_space_mapping=x -> Base.OneTo(3), - action_mapping=i -> action_space(env)[i], + action_space_mapping = x -> Base.OneTo(3), + action_mapping = i -> action_space(env)[i], ) RLBase.test_interfaces!(env′) @@ -14,7 +14,7 @@ @testset "DefaultStateStyleEnv" begin rng = StableRNG(123) - env = TigerProblemEnv(; rng=rng) + env = TigerProblemEnv(; rng = rng) S = InternalState{Int}() env′ = DefaultStateStyleEnv{S}(env) @test DefaultStateStyle(env′) === S @@ -35,7 +35,7 @@ @testset "MaxTimeoutEnv" begin rng = StableRNG(123) - env = TigerProblemEnv(; rng=rng) + env = TigerProblemEnv(; rng = rng) n = 100 env′ = MaxTimeoutEnv(env, n) @@ -55,7 +55,7 @@ @testset "RewardOverriddenEnv" begin rng = StableRNG(123) - env = TigerProblemEnv(; rng=rng) + env = TigerProblemEnv(; rng = rng) env′ = RewardOverriddenEnv(env, x -> sign(x)) RLBase.test_interfaces!(env′) @@ -69,7 +69,7 @@ @testset "StateCachedEnv" begin rng = StableRNG(123) - env = CartPoleEnv(; rng=rng) + env = CartPoleEnv(; rng = rng) env′ = StateCachedEnv(env) RLBase.test_interfaces!(env′) @@ -85,7 +85,7 @@ @testset "StateTransformedEnv" begin rng = StableRNG(123) - env = TigerProblemEnv(; rng=rng) + env = TigerProblemEnv(; rng = rng) # S = (:door1, :door2, :door3, :none) # env′ = StateTransformedEnv(env, state_mapping=s 
-> s+1) # RLBase.state_space(env::typeof(env′), ::RLBase.AbstractStateStyle, ::Any) = S @@ -97,14 +97,14 @@ @testset "StochasticEnv" begin env = KuhnPokerEnv() rng = StableRNG(123) - env′ = StochasticEnv(env; rng=rng) + env′ = StochasticEnv(env; rng = rng) RLBase.test_interfaces!(env′) RLBase.test_runnable!(env′) end @testset "SequentialEnv" begin - env = RockPaperScissorsEnv() + env = RockPaperScissorsEnv() env′ = SequentialEnv(env) RLBase.test_interfaces!(env′) RLBase.test_runnable!(env′) diff --git a/src/ReinforcementLearningExperiments/deps/build.jl b/src/ReinforcementLearningExperiments/deps/build.jl index def42b559..6dfde6249 100644 --- a/src/ReinforcementLearningExperiments/deps/build.jl +++ b/src/ReinforcementLearningExperiments/deps/build.jl @@ -2,10 +2,11 @@ using Weave const DEST_DIR = joinpath(@__DIR__, "..", "src", "experiments") -for (root, dirs, files) in walkdir(joinpath(@__DIR__, "..", "..", "..", "docs", "experiments")) +for (root, dirs, files) in + walkdir(joinpath(@__DIR__, "..", "..", "..", "docs", "experiments")) for f in files if splitext(f)[2] == ".jl" - tangle(joinpath(root,f);informat="script", out_path=DEST_DIR) + tangle(joinpath(root, f); informat = "script", out_path = DEST_DIR) end end -end \ No newline at end of file +end diff --git a/src/ReinforcementLearningExperiments/src/ReinforcementLearningExperiments.jl b/src/ReinforcementLearningExperiments/src/ReinforcementLearningExperiments.jl index 0f615ba3f..2a273443a 100644 --- a/src/ReinforcementLearningExperiments/src/ReinforcementLearningExperiments.jl +++ b/src/ReinforcementLearningExperiments/src/ReinforcementLearningExperiments.jl @@ -23,7 +23,6 @@ for f in readdir(EXPERIMENTS_DIR) end # dynamic loading environments -function __init__() -end +function __init__() end end # module diff --git a/src/ReinforcementLearningZoo/src/algorithms/dqns/common.jl b/src/ReinforcementLearningZoo/src/algorithms/dqns/common.jl index ddd2c6ace..793eaa897 100644 --- a/src/ReinforcementLearningZoo/src/algorithms/dqns/common.jl +++ b/src/ReinforcementLearningZoo/src/algorithms/dqns/common.jl @@ -4,7 +4,10 @@ const PERLearners = Union{PrioritizedDQNLearner,RainbowLearner,IQNLearner} -function RLBase.update!(learner::Union{DQNLearner,QRDQNLearner,REMDQNLearner,PERLearners}, t::AbstractTrajectory) +function RLBase.update!( + learner::Union{DQNLearner,QRDQNLearner,REMDQNLearner,PERLearners}, + t::AbstractTrajectory, +) length(t[:terminal]) - learner.sampler.n <= learner.min_replay_history && return learner.update_step += 1 diff --git a/src/ReinforcementLearningZoo/src/algorithms/dqns/dqn.jl b/src/ReinforcementLearningZoo/src/algorithms/dqns/dqn.jl index b2ebab93c..ad2808f26 100644 --- a/src/ReinforcementLearningZoo/src/algorithms/dqns/dqn.jl +++ b/src/ReinforcementLearningZoo/src/algorithms/dqns/dqn.jl @@ -55,7 +55,7 @@ function DQNLearner(; traces = SARTS, update_step = 0, rng = Random.GLOBAL_RNG, - is_enable_double_DQN::Bool = true + is_enable_double_DQN::Bool = true, ) where {Tq,Tt,Tf} copyto!(approximator, target_approximator) sampler = NStepBatchSampler{traces}(; @@ -75,7 +75,7 @@ function DQNLearner(; sampler, rng, 0.0f0, - is_enable_double_DQN + is_enable_double_DQN, ) end @@ -117,14 +117,14 @@ function RLBase.update!(learner::DQNLearner, batch::NamedTuple) else q_values = Qₜ(s′) end - + if haskey(batch, :next_legal_actions_mask) l′ = send_to_device(D, batch[:next_legal_actions_mask]) q_values .+= ifelse.(l′, 0.0f0, typemin(Float32)) end if is_enable_double_DQN - selected_actions = dropdims(argmax(q_values, dims=1), dims=1) 
+ selected_actions = dropdims(argmax(q_values, dims = 1), dims = 1) q′ = Qₜ(s′)[selected_actions] else q′ = dropdims(maximum(q_values; dims = 1), dims = 1) diff --git a/src/ReinforcementLearningZoo/src/algorithms/dqns/dqns.jl b/src/ReinforcementLearningZoo/src/algorithms/dqns/dqns.jl index 4ec47c5ba..8f190ba2c 100644 --- a/src/ReinforcementLearningZoo/src/algorithms/dqns/dqns.jl +++ b/src/ReinforcementLearningZoo/src/algorithms/dqns/dqns.jl @@ -5,4 +5,4 @@ include("qr_dqn.jl") include("rem_dqn.jl") include("rainbow.jl") include("iqn.jl") -include("common.jl") \ No newline at end of file +include("common.jl") diff --git a/src/ReinforcementLearningZoo/src/algorithms/dqns/qr_dqn.jl b/src/ReinforcementLearningZoo/src/algorithms/dqns/qr_dqn.jl index 2832b1905..0f6c1b519 100644 --- a/src/ReinforcementLearningZoo/src/algorithms/dqns/qr_dqn.jl +++ b/src/ReinforcementLearningZoo/src/algorithms/dqns/qr_dqn.jl @@ -1,6 +1,6 @@ export QRDQNLearner, quantile_huber_loss -function quantile_huber_loss(ŷ, y; κ=1.0f0) +function quantile_huber_loss(ŷ, y; κ = 1.0f0) N, B = size(y) Δ = reshape(y, N, 1, B) .- reshape(ŷ, 1, N, B) abs_error = abs.(Δ) @@ -8,12 +8,13 @@ function quantile_huber_loss(ŷ, y; κ=1.0f0) linear = abs_error .- quadratic huber_loss = 0.5f0 .* quadratic .* quadratic .+ κ .* linear - cum_prob = send_to_device(device(y), range(0.5f0 / N; length=N, step=1.0f0 / N)) + cum_prob = send_to_device(device(y), range(0.5f0 / N; length = N, step = 1.0f0 / N)) loss = Zygote.dropgrad(abs.(cum_prob .- (Δ .< 0))) .* huber_loss - mean(sum(loss;dims=1)) + mean(sum(loss; dims = 1)) end -mutable struct QRDQNLearner{Tq <: AbstractApproximator,Tt <: AbstractApproximator,Tf,R} <: AbstractLearner +mutable struct QRDQNLearner{Tq<:AbstractApproximator,Tt<:AbstractApproximator,Tf,R} <: + AbstractLearner approximator::Tq target_approximator::Tt min_replay_history::Int @@ -51,25 +52,25 @@ See paper: [Distributional Reinforcement Learning with Quantile Regression](http function QRDQNLearner(; approximator, target_approximator, - stack_size::Union{Int,Nothing}=nothing, - γ::Float32=0.99f0, - batch_size::Int=32, - update_horizon::Int=1, - min_replay_history::Int=32, - update_freq::Int=1, - n_quantile::Int=1, - target_update_freq::Int=100, - traces=SARTS, - update_step=0, - loss_func=quantile_huber_loss, - rng=Random.GLOBAL_RNG + stack_size::Union{Int,Nothing} = nothing, + γ::Float32 = 0.99f0, + batch_size::Int = 32, + update_horizon::Int = 1, + min_replay_history::Int = 32, + update_freq::Int = 1, + n_quantile::Int = 1, + target_update_freq::Int = 100, + traces = SARTS, + update_step = 0, + loss_func = quantile_huber_loss, + rng = Random.GLOBAL_RNG, ) copyto!(approximator, target_approximator) sampler = NStepBatchSampler{traces}(; - γ=γ, - n=update_horizon, - stack_size=stack_size, - batch_size=batch_size, + γ = γ, + n = update_horizon, + stack_size = stack_size, + batch_size = batch_size, ) N = n_quantile @@ -100,7 +101,7 @@ function (learner::QRDQNLearner)(env) s = send_to_device(device(learner.approximator), state(env)) s = Flux.unsqueeze(s, ndims(s) + 1) q = reshape(learner.approximator(s), learner.n_quantile, :) - vec(mean(q, dims=1)) |> send_to_host + vec(mean(q, dims = 1)) |> send_to_host end function RLBase.update!(learner::QRDQNLearner, batch::NamedTuple) @@ -117,10 +118,12 @@ function RLBase.update!(learner::QRDQNLearner, batch::NamedTuple) a = CartesianIndex.(a, 1:batch_size) target_quantiles = reshape(Qₜ(s′), N, :, batch_size) - qₜ = dropdims(mean(target_quantiles; dims=1); dims=1) - aₜ = dropdims(argmax(qₜ, 
dims=1); dims=1) + qₜ = dropdims(mean(target_quantiles; dims = 1); dims = 1) + aₜ = dropdims(argmax(qₜ, dims = 1); dims = 1) @views target_quantile_aₜ = target_quantiles[:, aₜ] - y = reshape(r, 1, batch_size) .+ γ .* reshape(1 .- t, 1, batch_size) .* target_quantile_aₜ + y = + reshape(r, 1, batch_size) .+ + γ .* reshape(1 .- t, 1, batch_size) .* target_quantile_aₜ gs = gradient(params(Q)) do q = reshape(Q(s), N, :, batch_size) diff --git a/src/ReinforcementLearningZoo/src/algorithms/dqns/rem_dqn.jl b/src/ReinforcementLearningZoo/src/algorithms/dqns/rem_dqn.jl index 182ce253f..08f270429 100644 --- a/src/ReinforcementLearningZoo/src/algorithms/dqns/rem_dqn.jl +++ b/src/ReinforcementLearningZoo/src/algorithms/dqns/rem_dqn.jl @@ -120,7 +120,7 @@ function RLBase.update!(learner::REMDQNLearner, batch::NamedTuple) target_q = Qₜ(s′) target_q = convex_polygon .* reshape(target_q, :, ensemble_num, batch_size) - target_q = dropdims(sum(target_q, dims=2), dims=2) + target_q = dropdims(sum(target_q, dims = 2), dims = 2) if haskey(batch, :next_legal_actions_mask) l′ = send_to_device(D, batch[:next_legal_actions_mask]) @@ -133,7 +133,7 @@ function RLBase.update!(learner::REMDQNLearner, batch::NamedTuple) gs = gradient(params(Q)) do q = Q(s) q = convex_polygon .* reshape(q, :, ensemble_num, batch_size) - q = dropdims(sum(q, dims=2), dims=2)[a] + q = dropdims(sum(q, dims = 2), dims = 2)[a] loss = loss_func(G, q) ignore() do @@ -143,5 +143,4 @@ function RLBase.update!(learner::REMDQNLearner, batch::NamedTuple) end update!(Q, gs) -end - +end diff --git a/src/ReinforcementLearningZoo/src/algorithms/exploitability_descent/EDPolicy.jl b/src/ReinforcementLearningZoo/src/algorithms/exploitability_descent/EDPolicy.jl index c7a270f3d..d01b0e44a 100644 --- a/src/ReinforcementLearningZoo/src/algorithms/exploitability_descent/EDPolicy.jl +++ b/src/ReinforcementLearningZoo/src/algorithms/exploitability_descent/EDPolicy.jl @@ -24,7 +24,7 @@ performs the following update for each player: [Computing Approximate Equilibria in Sequential Adversarial Games by Exploitability Descent](https://arxiv.org/abs/1903.05614) """ -mutable struct EDPolicy{P<:NeuralNetworkApproximator, E<:AbstractExplorer} +mutable struct EDPolicy{P<:NeuralNetworkApproximator,E<:AbstractExplorer} opponent::Any learner::P explorer::E @@ -40,16 +40,16 @@ function (π::EDPolicy)(env::AbstractEnv) s = state(env) s = send_to_device(device(π.learner), Flux.unsqueeze(s, ndims(s) + 1)) logits = π.learner(s) |> vec |> send_to_host - ActionStyle(env) isa MinimalActionSet ? π.explorer(logits) : - π.explorer(logits, legal_action_space_mask(env)) + ActionStyle(env) isa MinimalActionSet ? π.explorer(logits) : + π.explorer(logits, legal_action_space_mask(env)) end function RLBase.prob(π::EDPolicy, env::AbstractEnv) s = @ignore state(env) |> - x-> send_to_device(device(π.learner), Flux.unsqueeze(x, ndims(x) + 1)) + x -> send_to_device(device(π.learner), Flux.unsqueeze(x, ndims(x) + 1)) logits = π.learner(s) |> vec |> send_to_host - ActionStyle(env) isa MinimalActionSet ? prob(π.explorer, logits) : - prob(π.explorer, logits, @ignore legal_action_space_mask(env)) + ActionStyle(env) isa MinimalActionSet ? 
prob(π.explorer, logits) : + prob(π.explorer, logits, @ignore legal_action_space_mask(env)) end function RLBase.prob(π::EDPolicy, env::AbstractEnv, action) @@ -66,12 +66,12 @@ function RLBase.prob(π::EDPolicy, env::AbstractEnv, action) end @error "action[$action] is not found in action space[$(action_space(env))]" end -end +end ## update policy function RLBase.update!( - π::EDPolicy, - Opponent_BR::BestResponsePolicy, + π::EDPolicy, + Opponent_BR::BestResponsePolicy, env::AbstractEnv, player::Any, ) @@ -79,10 +79,7 @@ function RLBase.update!( # construct policy vs best response policy_vs_br = PolicyVsBestReponse( - MultiAgentManager( - NamedPolicy(player, π), - NamedPolicy(π.opponent, Opponent_BR), - ), + MultiAgentManager(NamedPolicy(player, π), NamedPolicy(π.opponent, Opponent_BR)), env, player, ) @@ -94,7 +91,7 @@ function RLBase.update!( # compute expected reward from the start of `e` with policy_vs_best_reponse # baseline = ∑ₐ πᵢ(s, a) * q(s, a) baseline = @ignore [values_vs_br(policy_vs_br, e) for e in info_states] - + # Vector of shape `(length(info_states), length(action_space))` # compute expected reward from the start of `e` when playing each action. q_values = Flux.stack((q_value(π, policy_vs_br, e) for e in info_states), 1) @@ -106,18 +103,17 @@ function RLBase.update!( # get each info_state's loss # ∑ₐ πᵢ(s, a) * (q(s, a) - baseline), where baseline = ∑ₐ πᵢ(s, a) * q(s, a). - loss_per_state = - sum(policy_values .* advantage, dims=2) + loss_per_state = -sum(policy_values .* advantage, dims = 2) - sum(loss_per_state .* cfr_reach_prob) |> - x -> send_to_device(device(π.learner), x) + sum(loss_per_state .* cfr_reach_prob) |> x -> send_to_device(device(π.learner), x) end update!(π.learner, gs) end ## Supplement struct for Computing related results when player's policy versus opponent's best_response. -struct PolicyVsBestReponse{E, P<:MultiAgentManager} - info_reach_prob::Dict{E, Float64} - values_vs_br_cache::Dict{E, Float64} +struct PolicyVsBestReponse{E,P<:MultiAgentManager} + info_reach_prob::Dict{E,Float64} + values_vs_br_cache::Dict{E,Float64} player::Any policy::P end @@ -125,13 +121,8 @@ end function PolicyVsBestReponse(policy, env, player) E = typeof(env) - p = PolicyVsBestReponse( - Dict{E, Float64}(), - Dict{E, Float64}(), - player, - policy, - ) - + p = PolicyVsBestReponse(Dict{E,Float64}(), Dict{E,Float64}(), player, policy) + e = copy(env) RLBase.reset!(e) get_cfr_prob!(p, e) @@ -190,7 +181,7 @@ function values_vs_br(p::PolicyVsBestReponse, env::AbstractEnv) end function q_value(π::EDPolicy, p::PolicyVsBestReponse, env::AbstractEnv) - P, A = prob(π, env) , @ignore action_space(env) + P, A = prob(π, env), @ignore action_space(env) v = [] for (a, pₐ) in zip(A, P) value = pₐ == 0 ? pₐ : values_vs_br(p, @ignore child(env, a)) diff --git a/src/ReinforcementLearningZoo/src/algorithms/exploitability_descent/exploitability_descent.jl b/src/ReinforcementLearningZoo/src/algorithms/exploitability_descent/exploitability_descent.jl index 53a3b72be..8b8c02567 100644 --- a/src/ReinforcementLearningZoo/src/algorithms/exploitability_descent/exploitability_descent.jl +++ b/src/ReinforcementLearningZoo/src/algorithms/exploitability_descent/exploitability_descent.jl @@ -9,7 +9,7 @@ export EDManager A special MultiAgentManager in which all agents use Exploitability Descent(ED) algorithm to play the game. 
""" mutable struct EDManager <: AbstractPolicy - agents::Dict{Any, EDPolicy} + agents::Dict{Any,EDPolicy} end ## interactions @@ -22,7 +22,8 @@ function (π::EDManager)(env::AbstractEnv) end end -RLBase.prob(π::EDManager, env::AbstractEnv, args...) = prob(π.agents[current_player(env)], env, args...) +RLBase.prob(π::EDManager, env::AbstractEnv, args...) = + prob(π.agents[current_player(env)], env, args...) ## run function function Base.run( diff --git a/src/ReinforcementLearningZoo/src/algorithms/nfsp/abstract_nfsp.jl b/src/ReinforcementLearningZoo/src/algorithms/nfsp/abstract_nfsp.jl index 3c100545f..f713417af 100644 --- a/src/ReinforcementLearningZoo/src/algorithms/nfsp/abstract_nfsp.jl +++ b/src/ReinforcementLearningZoo/src/algorithms/nfsp/abstract_nfsp.jl @@ -35,4 +35,4 @@ function Base.run( end hook(POST_EXPERIMENT_STAGE, policy, env) hook -end \ No newline at end of file +end diff --git a/src/ReinforcementLearningZoo/src/algorithms/nfsp/nfsp.jl b/src/ReinforcementLearningZoo/src/algorithms/nfsp/nfsp.jl index 8a975936f..d7379e8a2 100644 --- a/src/ReinforcementLearningZoo/src/algorithms/nfsp/nfsp.jl +++ b/src/ReinforcementLearningZoo/src/algorithms/nfsp/nfsp.jl @@ -20,8 +20,8 @@ See the paper https://arxiv.org/abs/1603.01121 for more details. mutable struct NFSPAgent <: AbstractPolicy rl_agent::Agent sl_agent::Agent - η - rng + η::Any + rng::Any update_freq::Int update_step::Int mode::Bool @@ -96,7 +96,7 @@ function (π::NFSPAgent)(::PostEpisodeStage, env::AbstractEnv, player::Any) if haskey(rl.trajectory, :legal_actions_mask) push!(rl.trajectory[:legal_actions_mask], legal_action_space_mask(env, player)) end - + # update the policy π.update_step += 1 if π.update_step % π.update_freq == 0 @@ -113,7 +113,7 @@ end function rl_learn!(policy::QBasedPolicy, t::AbstractTrajectory) learner = policy.learner length(t[:terminal]) - learner.sampler.n <= learner.min_replay_history && return - + _, batch = sample(learner.rng, t, learner.sampler) if t isa PrioritizedTrajectory diff --git a/src/ReinforcementLearningZoo/src/algorithms/nfsp/nfsp_manager.jl b/src/ReinforcementLearningZoo/src/algorithms/nfsp/nfsp_manager.jl index 2e152509c..ba3280e3f 100644 --- a/src/ReinforcementLearningZoo/src/algorithms/nfsp/nfsp_manager.jl +++ b/src/ReinforcementLearningZoo/src/algorithms/nfsp/nfsp_manager.jl @@ -7,7 +7,7 @@ export NFSPAgentManager A special MultiAgentManager in which all agents use NFSP policy to play the game. """ mutable struct NFSPAgentManager <: AbstractPolicy - agents::Dict{Any, NFSPAgent} + agents::Dict{Any,NFSPAgent} end ## interactions between the policy and env. @@ -20,7 +20,8 @@ function (π::NFSPAgentManager)(env::AbstractEnv) end end -RLBase.prob(π::NFSPAgentManager, env::AbstractEnv, args...) = prob(π.agents[current_player(env)], env, args...) +RLBase.prob(π::NFSPAgentManager, env::AbstractEnv, args...) = + prob(π.agents[current_player(env)], env, args...) 
## update NFSPAgentManager function RLBase.update!(π::NFSPAgentManager, env::AbstractEnv) @@ -30,7 +31,10 @@ function RLBase.update!(π::NFSPAgentManager, env::AbstractEnv) update!(π.agents[current_player(env)], env) end -function (π::NFSPAgentManager)(stage::Union{PreEpisodeStage, PostEpisodeStage}, env::AbstractEnv) +function (π::NFSPAgentManager)( + stage::Union{PreEpisodeStage,PostEpisodeStage}, + env::AbstractEnv, +) @sync for (player, agent) in π.agents @async agent(stage, env, player) end diff --git a/src/ReinforcementLearningZoo/src/algorithms/offline_rl/BCQ.jl b/src/ReinforcementLearningZoo/src/algorithms/offline_rl/BCQ.jl index ed232f846..856a88837 100644 --- a/src/ReinforcementLearningZoo/src/algorithms/offline_rl/BCQ.jl +++ b/src/ReinforcementLearningZoo/src/algorithms/offline_rl/BCQ.jl @@ -99,7 +99,7 @@ end function (l::BCQLearner)(env) s = send_to_device(device(l.policy), state(env)) s = Flux.unsqueeze(s, ndims(s) + 1) - s = repeat(s, outer=(1, 1, l.p)) + s = repeat(s, outer = (1, 1, l.p)) action = l.policy(s, decode(l.vae.model, s)) q_value = l.qnetwork1(vcat(s, action)) idx = argmax(q_value) @@ -128,11 +128,15 @@ function update_learner!(l::BCQLearner, batch::NamedTuple{SARTS}) γ, τ, λ = l.γ, l.τ, l.λ - repeat_s′ = repeat(s′, outer=(1, 1, l.p)) + repeat_s′ = repeat(s′, outer = (1, 1, l.p)) repeat_a′ = l.target_policy(repeat_s′, decode(l.vae.model, repeat_s′)) q′_input = vcat(repeat_s′, repeat_a′) - q′ = maximum(λ .* min.(l.target_qnetwork1(q′_input), l.target_qnetwork2(q′_input)) + (1 - λ) .* max.(l.target_qnetwork1(q′_input), l.target_qnetwork2(q′_input)), dims=3) + q′ = maximum( + λ .* min.(l.target_qnetwork1(q′_input), l.target_qnetwork2(q′_input)) + + (1 - λ) .* max.(l.target_qnetwork1(q′_input), l.target_qnetwork2(q′_input)), + dims = 3, + ) y = r .+ γ .* (1 .- t) .* vec(q′) @@ -143,7 +147,7 @@ function update_learner!(l::BCQLearner, batch::NamedTuple{SARTS}) q_grad_1 = gradient(Flux.params(l.qnetwork1)) do q1 = l.qnetwork1(q_input) |> vec loss = mse(q1, y) - ignore() do + ignore() do l.critic_loss = loss end loss @@ -153,7 +157,7 @@ function update_learner!(l::BCQLearner, batch::NamedTuple{SARTS}) q_grad_2 = gradient(Flux.params(l.qnetwork2)) do q2 = l.qnetwork2(q_input) |> vec loss = mse(q2, y) - ignore() do + ignore() do l.critic_loss += loss end loss @@ -165,7 +169,7 @@ function update_learner!(l::BCQLearner, batch::NamedTuple{SARTS}) sampled_action = decode(l.vae.model, s) perturbed_action = l.policy(s, sampled_action) actor_loss = -mean(l.qnetwork1(vcat(s, perturbed_action))) - ignore() do + ignore() do l.actor_loss = actor_loss end actor_loss diff --git a/src/ReinforcementLearningZoo/src/algorithms/offline_rl/BEAR.jl b/src/ReinforcementLearningZoo/src/algorithms/offline_rl/BEAR.jl index 71321fdf3..40844f0c3 100644 --- a/src/ReinforcementLearningZoo/src/algorithms/offline_rl/BEAR.jl +++ b/src/ReinforcementLearningZoo/src/algorithms/offline_rl/BEAR.jl @@ -33,7 +33,7 @@ mutable struct BEARLearner{ # Logging actor_loss::Float32 critic_loss::Float32 - mmd_loss + mmd_loss::Any end """ @@ -122,8 +122,8 @@ end function (l::BEARLearner)(env) s = send_to_device(device(l.policy), state(env)) s = Flux.unsqueeze(s, ndims(s) + 1) - s = repeat(s, outer=(1, 1, l.p)) - action = l.policy(l.rng, s; is_sampling=true) + s = repeat(s, outer = (1, 1, l.p)) + action = l.policy(l.rng, s; is_sampling = true) q_value = l.qnetwork1(vcat(s, action)) idx = argmax(q_value) action[idx] @@ -134,13 +134,17 @@ function RLBase.update!(l::BEARLearner, batch::NamedTuple{SARTS}) γ, τ, λ = l.γ, l.τ, 
l.λ update_vae!(l, s, a) - - repeat_s′ = repeat(s′, outer=(1, 1, l.p)) - repeat_action′ = l.target_policy(l.rng, repeat_s′, is_sampling=true) + + repeat_s′ = repeat(s′, outer = (1, 1, l.p)) + repeat_action′ = l.target_policy(l.rng, repeat_s′, is_sampling = true) q′_input = vcat(repeat_s′, repeat_action′) - q′ = maximum(λ .* min.(l.target_qnetwork1(q′_input), l.target_qnetwork2(q′_input)) + (1 - λ) .* max.(l.target_qnetwork1(q′_input), l.target_qnetwork2(q′_input)), dims=3) + q′ = maximum( + λ .* min.(l.target_qnetwork1(q′_input), l.target_qnetwork2(q′_input)) + + (1 - λ) .* max.(l.target_qnetwork1(q′_input), l.target_qnetwork2(q′_input)), + dims = 3, + ) y = r .+ γ .* (1 .- t) .* vec(q′) @@ -151,7 +155,7 @@ function RLBase.update!(l::BEARLearner, batch::NamedTuple{SARTS}) q_grad_1 = gradient(Flux.params(l.qnetwork1)) do q1 = l.qnetwork1(q_input) |> vec loss = mse(q1, y) - ignore() do + ignore() do l.critic_loss = loss end loss @@ -161,30 +165,40 @@ function RLBase.update!(l::BEARLearner, batch::NamedTuple{SARTS}) q_grad_2 = gradient(Flux.params(l.qnetwork2)) do q2 = l.qnetwork2(q_input) |> vec loss = mse(q2, y) - ignore() do + ignore() do l.critic_loss += loss end loss end update!(l.qnetwork2, q_grad_2) - repeat_s = repeat(s, outer=(1, 1, l.p)) - repeat_a = repeat(a, outer=(1, 1, l.p)) - repeat_q1 = mean(l.target_qnetwork1(vcat(repeat_s, repeat_a)), dims=(1, 3)) - repeat_q2 = mean(l.target_qnetwork2(vcat(repeat_s, repeat_a)), dims=(1, 3)) + repeat_s = repeat(s, outer = (1, 1, l.p)) + repeat_a = repeat(a, outer = (1, 1, l.p)) + repeat_q1 = mean(l.target_qnetwork1(vcat(repeat_s, repeat_a)), dims = (1, 3)) + repeat_q2 = mean(l.target_qnetwork2(vcat(repeat_s, repeat_a)), dims = (1, 3)) q = vec(min.(repeat_q1, repeat_q2)) alpha = exp(l.log_α.model[1]) # Train Policy p_grad = gradient(Flux.params(l.policy)) do - raw_sample_action = decode(l.vae.model, repeat(s, outer=(1, 1, l.sample_num)); is_normalize=false) # action_dim * batch_size * sample_num - raw_actor_action = l.policy(repeat(s, outer=(1, 1, l.sample_num)); is_sampling=true) # action_dim * batch_size * sample_num - - mmd_loss = maximum_mean_discrepancy_loss(raw_sample_action, raw_actor_action, l.kernel_type, l.mmd_σ) + raw_sample_action = decode( + l.vae.model, + repeat(s, outer = (1, 1, l.sample_num)); + is_normalize = false, + ) # action_dim * batch_size * sample_num + raw_actor_action = + l.policy(repeat(s, outer = (1, 1, l.sample_num)); is_sampling = true) # action_dim * batch_size * sample_num + + mmd_loss = maximum_mean_discrepancy_loss( + raw_sample_action, + raw_actor_action, + l.kernel_type, + l.mmd_σ, + ) actor_loss = mean(-q .+ alpha .* mmd_loss) - ignore() do + ignore() do l.actor_loss = actor_loss l.mmd_loss = mmd_loss end @@ -193,11 +207,11 @@ function RLBase.update!(l::BEARLearner, batch::NamedTuple{SARTS}) update!(l.policy, p_grad) # Update lagrange multiplier - l_grad = gradient(Flux.params(l.log_α)) do + l_grad = gradient(Flux.params(l.log_α)) do mean(-q .+ alpha .* (l.mmd_loss .- l.ε)) end update!(l.log_α, l_grad) - + clamp!(l.log_α.model, -5.0f0, l.max_log_α) # polyak averaging diff --git a/src/ReinforcementLearningZoo/src/algorithms/offline_rl/CRR.jl b/src/ReinforcementLearningZoo/src/algorithms/offline_rl/CRR.jl index 102600f21..2a5ae8031 100644 --- a/src/ReinforcementLearningZoo/src/algorithms/offline_rl/CRR.jl +++ b/src/ReinforcementLearningZoo/src/algorithms/offline_rl/CRR.jl @@ -22,11 +22,7 @@ See paper: [Critic Regularized Regression](https://arxiv.org/abs/2006.15134). 
- `continuous::Bool`: type of action space. - `rng = Random.GLOBAL_RNG` """ -mutable struct CRRLearner{ - Aq<:ActorCritic, - At<:ActorCritic, - R<:AbstractRNG, -} <: AbstractLearner +mutable struct CRRLearner{Aq<:ActorCritic,At<:ActorCritic,R<:AbstractRNG} <: AbstractLearner approximator::Aq target_approximator::At γ::Float32 @@ -61,7 +57,7 @@ function CRRLearner(; target_update_freq::Int = 100, continuous::Bool, rng = Random.GLOBAL_RNG, -) where {Aq<:ActorCritic, At<:ActorCritic} +) where {Aq<:ActorCritic,At<:ActorCritic} copyto!(approximator, target_approximator) CRRLearner( approximator, @@ -95,7 +91,7 @@ function (learner::CRRLearner)(env) s = Flux.unsqueeze(s, ndims(s) + 1) s = send_to_device(device(learner), s) if learner.continuous - learner.approximator.actor(s; is_sampling=true) |> vec |> send_to_host + learner.approximator.actor(s; is_sampling = true) |> vec |> send_to_host else learner.approximator.actor(s) |> vec |> send_to_host end @@ -125,7 +121,7 @@ function continuous_update!(learner::CRRLearner, batch::NamedTuple) r = reshape(r, :, batch_size) t = reshape(t, :, batch_size) - target_a_t = target_AC.actor(s′; is_sampling=true) + target_a_t = target_AC.actor(s′; is_sampling = true) target_q_input = vcat(s′, target_a_t) expected_target_q = target_AC.critic(target_q_input) @@ -133,7 +129,7 @@ function continuous_update!(learner::CRRLearner, batch::NamedTuple) q_t = Matrix{Float32}(undef, learner.m, batch_size) for i in 1:learner.m - a_sample = AC.actor(s; is_sampling=true) + a_sample = AC.actor(s; is_sampling = true) q_t[i, :] = AC.critic(vcat(s, a_sample)) end @@ -142,14 +138,14 @@ function continuous_update!(learner::CRRLearner, batch::NamedTuple) # Critic loss qa_t = AC.critic(vcat(s, a)) critic_loss = Flux.Losses.mse(qa_t, target) - + # Actor loss log_π = AC.actor.model(s, a) if advantage_estimator == :max - advantage = qa_t .- maximum(q_t, dims=1) + advantage = qa_t .- maximum(q_t, dims = 1) elseif advantage_estimator == :mean - advantage = qa_t .- mean(q_t, dims=1) + advantage = qa_t .- mean(q_t, dims = 1) else error("Wrong parameter.") end @@ -168,7 +164,7 @@ function continuous_update!(learner::CRRLearner, batch::NamedTuple) learner.actor_loss = actor_loss learner.critic_loss = critic_loss end - + actor_loss + critic_loss end @@ -193,7 +189,7 @@ function discrete_update!(learner::CRRLearner, batch::NamedTuple) target_a_t = softmax(target_AC.actor(s′)) target_q_t = target_AC.critic(s′) - expected_target_q = sum(target_a_t .* target_q_t, dims=1) + expected_target_q = sum(target_a_t .* target_q_t, dims = 1) target = r .+ γ .* (1 .- t) .* expected_target_q @@ -203,14 +199,14 @@ function discrete_update!(learner::CRRLearner, batch::NamedTuple) q_t = AC.critic(s) qa_t = reshape(q_t[a], :, batch_size) critic_loss = Flux.Losses.mse(qa_t, target) - + # Actor loss a_t = softmax(AC.actor(s)) if advantage_estimator == :max - advantage = qa_t .- maximum(q_t, dims=1) + advantage = qa_t .- maximum(q_t, dims = 1) elseif advantage_estimator == :mean - advantage = qa_t .- mean(q_t, dims=1) + advantage = qa_t .- mean(q_t, dims = 1) else error("Wrong parameter.") end @@ -222,16 +218,16 @@ function discrete_update!(learner::CRRLearner, batch::NamedTuple) else error("Wrong parameter.") end - + actor_loss = mean(-log.(a_t[a]) .* actor_loss_coef) ignore() do learner.actor_loss = actor_loss learner.critic_loss = critic_loss end - + actor_loss + critic_loss end update!(AC, gs) -end \ No newline at end of file +end diff --git a/src/ReinforcementLearningZoo/src/algorithms/offline_rl/DiscreteBCQ.jl 
b/src/ReinforcementLearningZoo/src/algorithms/offline_rl/DiscreteBCQ.jl index d1ef288cd..ee7f6be35 100644 --- a/src/ReinforcementLearningZoo/src/algorithms/offline_rl/DiscreteBCQ.jl +++ b/src/ReinforcementLearningZoo/src/algorithms/offline_rl/DiscreteBCQ.jl @@ -18,11 +18,8 @@ See paper: [Benchmarking Batch Deep Reinforcement Learning Algorithms](https://a - `update_step::Int = 0` - `rng = Random.GLOBAL_RNG` """ -mutable struct BCQDLearner{ - Aq<:ActorCritic, - At<:ActorCritic, - R<:AbstractRNG, -} <: AbstractLearner +mutable struct BCQDLearner{Aq<:ActorCritic,At<:ActorCritic,R<:AbstractRNG} <: + AbstractLearner approximator::Aq target_approximator::At γ::Float32 @@ -49,7 +46,7 @@ function BCQDLearner(; update_freq::Int = 10, update_step::Int = 0, rng = Random.GLOBAL_RNG, -) where {Aq<:ActorCritic, At<:ActorCritic} +) where {Aq<:ActorCritic,At<:ActorCritic} copyto!(approximator, target_approximator) BCQDLearner( approximator, @@ -79,8 +76,8 @@ function (learner::BCQDLearner)(env) s = Flux.unsqueeze(s, ndims(s) + 1) s = send_to_device(device(learner), s) q = learner.approximator.critic(s) - prob = softmax(learner.approximator.actor(s), dims=1) - mask = Float32.((prob ./ maximum(prob, dims=1)) .> learner.threshold) + prob = softmax(learner.approximator.actor(s), dims = 1) + mask = Float32.((prob ./ maximum(prob, dims = 1)) .> learner.threshold) new_q = q .* mask .+ (1.0f0 .- mask) .* -1f8 new_q |> vec |> send_to_host end @@ -98,9 +95,9 @@ function RLBase.update!(learner::BCQDLearner, batch::NamedTuple) t = reshape(t, :, batch_size) prob = softmax(AC.actor(s′)) - mask = Float32.((prob ./ maximum(prob, dims=1)) .> learner.threshold) + mask = Float32.((prob ./ maximum(prob, dims = 1)) .> learner.threshold) q′ = AC.critic(s′) - a′ = argmax(q′ .* mask .+ (1.0f0 .- mask) .* -1f8, dims=1) + a′ = argmax(q′ .* mask .+ (1.0f0 .- mask) .* -1f8, dims = 1) target_q = target_AC.critic(s′) target = r .+ γ .* (1 .- t) .* target_q[a′] @@ -111,27 +108,25 @@ function RLBase.update!(learner::BCQDLearner, batch::NamedTuple) q_t = AC.critic(s) qa_t = reshape(q_t[a], :, batch_size) critic_loss = Flux.Losses.huber_loss(qa_t, target) - + # Actor loss logit = AC.actor(s) - log_prob = -log.(softmax(logit, dims=1)) + log_prob = -log.(softmax(logit, dims = 1)) actor_loss = mean(log_prob[a]) ignore() do learner.actor_loss = actor_loss learner.critic_loss = critic_loss end - + actor_loss + critic_loss + θ * mean(logit .^ 2) end update!(AC, gs) # polyak averaging - for (dest, src) in zip( - Flux.params([learner.target_approximator]), - Flux.params([learner.approximator]), - ) + for (dest, src) in + zip(Flux.params([learner.target_approximator]), Flux.params([learner.approximator])) dest .= (1 - τ) .* dest .+ τ .* src end end diff --git a/src/ReinforcementLearningZoo/src/algorithms/offline_rl/FisherBRC.jl b/src/ReinforcementLearningZoo/src/algorithms/offline_rl/FisherBRC.jl index 34c50913e..89602e3e1 100644 --- a/src/ReinforcementLearningZoo/src/algorithms/offline_rl/FisherBRC.jl +++ b/src/ReinforcementLearningZoo/src/algorithms/offline_rl/FisherBRC.jl @@ -91,7 +91,8 @@ function FisherBRCLearner(; ) copyto!(qnetwork1, target_qnetwork1) # force sync copyto!(qnetwork2, target_qnetwork2) # force sync - entropy_behavior_policy = EntropyBC(behavior_policy, 0.0f0, behavior_lr_alpha, Float32(-action_dims), 0.0f0) + entropy_behavior_policy = + EntropyBC(behavior_policy, 0.0f0, behavior_lr_alpha, Float32(-action_dims), 0.0f0) FisherBRCLearner( policy, entropy_behavior_policy, @@ -111,8 +112,8 @@ function FisherBRCLearner(; lr_alpha, 
Float32(-action_dims), rng, - 0f0, - 0f0, + 0.0f0, + 0.0f0, ) end @@ -120,7 +121,7 @@ function (l::FisherBRCLearner)(env) D = device(l.policy) s = send_to_device(D, state(env)) s = Flux.unsqueeze(s, ndims(s) + 1) - action = dropdims(l.policy(l.rng, s; is_sampling=true), dims=2) + action = dropdims(l.policy(l.rng, s; is_sampling = true), dims = 2) end function RLBase.update!(l::FisherBRCLearner, batch::NamedTuple{SARTS}) @@ -137,7 +138,7 @@ function update_behavior_policy!(l::EntropyBC, batch::NamedTuple{SARTS}) ps = Flux.params(l.policy) gs = gradient(ps) do log_π = l.policy.model(s, a) - _, entropy = l.policy.model(s; is_sampling=true, is_return_log_prob=true) + _, entropy = l.policy.model(s; is_sampling = true, is_return_log_prob = true) loss = mean(l.α .* entropy .- log_π) # Update entropy ignore() do @@ -154,7 +155,7 @@ function update_learner!(l::FisherBRCLearner, batch::NamedTuple{SARTS}) r .+= l.reward_bonus γ, τ, α = l.γ, l.τ, l.α - a′ = l.policy(l.rng, s′; is_sampling=true) + a′ = l.policy(l.rng, s′; is_sampling = true) q′_input = vcat(s′, a′) target_q′ = min.(l.target_qnetwork1(q′_input), l.target_qnetwork2(q′_input)) @@ -164,16 +165,16 @@ function update_learner!(l::FisherBRCLearner, batch::NamedTuple{SARTS}) a = reshape(a, :, l.batch_size) q_input = vcat(s, a) log_μ = l.behavior_policy.policy.model(s, a) |> vec - a_policy = l.policy(l.rng, s; is_sampling=true) + a_policy = l.policy(l.rng, s; is_sampling = true) q_grad_1 = gradient(Flux.params(l.qnetwork1)) do q1 = l.qnetwork1(q_input) |> vec - q1_grad_norm = gradient(Flux.params([a_policy])) do + q1_grad_norm = gradient(Flux.params([a_policy])) do q1_reg = mean(l.qnetwork1(vcat(s, a_policy))) end reg = mean(q1_grad_norm[a_policy] .^ 2) loss = mse(q1 .+ log_μ, y) + l.f_reg * reg - ignore() do + ignore() do l.qnetwork_loss = loss end loss @@ -182,12 +183,12 @@ function update_learner!(l::FisherBRCLearner, batch::NamedTuple{SARTS}) q_grad_2 = gradient(Flux.params(l.qnetwork2)) do q2 = l.qnetwork2(q_input) |> vec - q2_grad_norm = gradient(Flux.params([a_policy])) do + q2_grad_norm = gradient(Flux.params([a_policy])) do q2_reg = mean(l.qnetwork2(vcat(s, a_policy))) end reg = mean(q2_grad_norm[a_policy] .^ 2) loss = mse(q2 .+ log_μ, y) + l.f_reg * reg - ignore() do + ignore() do l.qnetwork_loss += loss end loss @@ -196,7 +197,7 @@ function update_learner!(l::FisherBRCLearner, batch::NamedTuple{SARTS}) # Train Policy p_grad = gradient(Flux.params(l.policy)) do - a, log_π = l.policy(l.rng, s; is_sampling=true, is_return_log_prob=true) + a, log_π = l.policy(l.rng, s; is_sampling = true, is_return_log_prob = true) q_input = vcat(s, a) q = min.(l.qnetwork1(q_input), l.qnetwork2(q_input)) .+ log_μ policy_loss = mean(α .* log_π .- q) diff --git a/src/ReinforcementLearningZoo/src/algorithms/offline_rl/PLAS.jl b/src/ReinforcementLearningZoo/src/algorithms/offline_rl/PLAS.jl index 8ced04148..1bcfd6477 100644 --- a/src/ReinforcementLearningZoo/src/algorithms/offline_rl/PLAS.jl +++ b/src/ReinforcementLearningZoo/src/algorithms/offline_rl/PLAS.jl @@ -96,7 +96,7 @@ function (l::PLASLearner)(env) s = send_to_device(device(l.policy), state(env)) s = Flux.unsqueeze(s, ndims(s) + 1) latent_action = tanh.(l.policy(s)) - action = dropdims(decode(l.vae.model, s, latent_action), dims=2) + action = dropdims(decode(l.vae.model, s, latent_action), dims = 2) end function RLBase.update!(l::PLASLearner, batch::NamedTuple{SARTS}) @@ -125,7 +125,9 @@ function update_learner!(l::PLASLearner, batch::NamedTuple{SARTS}) latent_action′ = tanh.(l.target_policy(s′)) 
action′ = decode(l.vae.model, s′, latent_action′) q′_input = vcat(s′, action′) - q′ = λ .* min.(l.target_qnetwork1(q′_input), l.target_qnetwork2(q′_input)) + (1 - λ) .* max.(l.target_qnetwork1(q′_input), l.target_qnetwork2(q′_input)) + q′ = + λ .* min.(l.target_qnetwork1(q′_input), l.target_qnetwork2(q′_input)) + + (1 - λ) .* max.(l.target_qnetwork1(q′_input), l.target_qnetwork2(q′_input)) y = r .+ γ .* (1 .- t) .* vec(q′) @@ -136,7 +138,7 @@ function update_learner!(l::PLASLearner, batch::NamedTuple{SARTS}) q_grad_1 = gradient(Flux.params(l.qnetwork1)) do q1 = l.qnetwork1(q_input) |> vec loss = mse(q1, y) - ignore() do + ignore() do l.critic_loss = loss end loss @@ -146,7 +148,7 @@ function update_learner!(l::PLASLearner, batch::NamedTuple{SARTS}) q_grad_2 = gradient(Flux.params(l.qnetwork2)) do q2 = l.qnetwork2(q_input) |> vec loss = mse(q2, y) - ignore() do + ignore() do l.critic_loss += loss end loss @@ -158,7 +160,7 @@ function update_learner!(l::PLASLearner, batch::NamedTuple{SARTS}) latent_action = tanh.(l.policy(s)) action = decode(l.vae.model, s, latent_action) actor_loss = -mean(l.qnetwork1(vcat(s, action))) - ignore() do + ignore() do l.actor_loss = actor_loss end actor_loss diff --git a/src/ReinforcementLearningZoo/src/algorithms/offline_rl/behavior_cloning.jl b/src/ReinforcementLearningZoo/src/algorithms/offline_rl/behavior_cloning.jl index c5b7bdd09..48c021035 100644 --- a/src/ReinforcementLearningZoo/src/algorithms/offline_rl/behavior_cloning.jl +++ b/src/ReinforcementLearningZoo/src/algorithms/offline_rl/behavior_cloning.jl @@ -19,19 +19,14 @@ end - `rng = Random.GLOBAL_RNG` """ function BehaviorCloningPolicy(; - approximator::A, - explorer::AbstractExplorer = GreedyExplorer(), - batch_size::Int = 32, - min_reservoir_history::Int = 100, - rng = Random.GLOBAL_RNG + approximator::A, + explorer::AbstractExplorer = GreedyExplorer(), + batch_size::Int = 32, + min_reservoir_history::Int = 100, + rng = Random.GLOBAL_RNG, ) where {A} sampler = BatchSampler{(:state, :action)}(batch_size; rng = rng) - BehaviorCloningPolicy( - approximator, - explorer, - sampler, - min_reservoir_history, - ) + BehaviorCloningPolicy(approximator, explorer, sampler, min_reservoir_history) end function (p::BehaviorCloningPolicy)(env::AbstractEnv) @@ -39,7 +34,8 @@ function (p::BehaviorCloningPolicy)(env::AbstractEnv) s_batch = Flux.unsqueeze(s, ndims(s) + 1) s_batch = send_to_device(device(p.approximator), s_batch) logits = p.approximator(s_batch) |> vec |> send_to_host # drop dimension - typeof(ActionStyle(env)) == MinimalActionSet ? p.explorer(logits) : p.explorer(logits, legal_action_space_mask(env)) + typeof(ActionStyle(env)) == MinimalActionSet ? p.explorer(logits) : + p.explorer(logits, legal_action_space_mask(env)) end function RLBase.update!(p::BehaviorCloningPolicy, batch::NamedTuple{(:state, :action)}) @@ -65,7 +61,8 @@ function RLBase.prob(p::BehaviorCloningPolicy, env::AbstractEnv) m = p.approximator s_batch = send_to_device(device(m), Flux.unsqueeze(s, ndims(s) + 1)) values = m(s_batch) |> vec |> send_to_host - typeof(ActionStyle(env)) == MinimalActionSet ? prob(p.explorer, values) : prob(p.explorer, values, legal_action_space_mask(env)) + typeof(ActionStyle(env)) == MinimalActionSet ? 
prob(p.explorer, values) : + prob(p.explorer, values, legal_action_space_mask(env)) end function RLBase.prob(p::BehaviorCloningPolicy, env::AbstractEnv, action) diff --git a/src/ReinforcementLearningZoo/src/algorithms/offline_rl/common.jl b/src/ReinforcementLearningZoo/src/algorithms/offline_rl/common.jl index 3c8401dd6..5bf0e0b7f 100644 --- a/src/ReinforcementLearningZoo/src/algorithms/offline_rl/common.jl +++ b/src/ReinforcementLearningZoo/src/algorithms/offline_rl/common.jl @@ -3,11 +3,11 @@ export OfflinePolicy, AtariRLTransition export calculate_CQL_loss, maximum_mean_discrepancy_loss struct AtariRLTransition - state - action - reward - terminal - next_state + state::Any + action::Any + reward::Any + terminal::Any + next_state::Any end Base.@kwdef struct OfflinePolicy{L,T} <: AbstractPolicy @@ -26,7 +26,8 @@ function (π::OfflinePolicy)(env, ::MinimalActionSet, ::Base.OneTo) findmax(π.learner(env))[2] end end -(π::OfflinePolicy)(env, ::FullActionSet, ::Base.OneTo) = findmax(π.learner(env), legal_action_space_mask(env))[2] +(π::OfflinePolicy)(env, ::FullActionSet, ::Base.OneTo) = + findmax(π.learner(env), legal_action_space_mask(env))[2] function (π::OfflinePolicy)(env, ::MinimalActionSet, A) if π.continuous @@ -35,7 +36,8 @@ function (π::OfflinePolicy)(env, ::MinimalActionSet, A) A[findmax(π.learner(env))[2]] end end -(π::OfflinePolicy)(env, ::FullActionSet, A) = A[findmax(π.learner(env), legal_action_space_mask(env))[2]] +(π::OfflinePolicy)(env, ::FullActionSet, A) = + A[findmax(π.learner(env), legal_action_space_mask(env))[2]] function RLBase.update!( p::OfflinePolicy, @@ -62,7 +64,8 @@ function RLBase.update!( l = p.learner l.update_step += 1 - if in(:target_update_freq, fieldnames(typeof(l))) && l.update_step % l.target_update_freq == 0 + if in(:target_update_freq, fieldnames(typeof(l))) && + l.update_step % l.target_update_freq == 0 copyto!(l.target_approximator, l.approximator) end @@ -99,19 +102,30 @@ end calculate_CQL_loss(q_value, action; method) See paper: [Conservative Q-Learning for Offline Reinforcement Learning](https://arxiv.org/abs/2006.04779) """ -function calculate_CQL_loss(q_value::Matrix{T}, action::Vector{R}; method = "CQL(H)") where {T, R} +function calculate_CQL_loss( + q_value::Matrix{T}, + action::Vector{R}; + method = "CQL(H)", +) where {T,R} if method == "CQL(H)" - cql_loss = mean(log.(sum(exp.(q_value), dims=1)) .- q_value[action]) + cql_loss = mean(log.(sum(exp.(q_value), dims = 1)) .- q_value[action]) else @error Wrong method parameter end return cql_loss end -function maximum_mean_discrepancy_loss(raw_sample_action, raw_actor_action, type::Symbol, mmd_σ::Float32=10.0f0) +function maximum_mean_discrepancy_loss( + raw_sample_action, + raw_actor_action, + type::Symbol, + mmd_σ::Float32 = 10.0f0, +) A, B, N = size(raw_sample_action) - diff_xx = reshape(raw_sample_action, A, B, N, 1) .- reshape(raw_sample_action, A, B, 1, N) - diff_xy = reshape(raw_sample_action, A, B, N, 1) .- reshape(raw_actor_action, A, B, 1, N) + diff_xx = + reshape(raw_sample_action, A, B, N, 1) .- reshape(raw_sample_action, A, B, 1, N) + diff_xy = + reshape(raw_sample_action, A, B, N, 1) .- reshape(raw_actor_action, A, B, 1, N) diff_yy = reshape(raw_actor_action, A, B, N, 1) .- reshape(raw_actor_action, A, B, 1, N) diff_xx = calculate_sample_distance(diff_xx, type, mmd_σ) diff_xy = calculate_sample_distance(diff_xy, type, mmd_σ) @@ -127,5 +141,5 @@ function calculate_sample_distance(diff, type::Symbol, mmd_σ::Float32) else error("Wrong parameter.") end - return vec(mean(exp.(-sum(diff, 
dims=1) ./ (2.0f0 * mmd_σ)), dims=(3, 4))) + return vec(mean(exp.(-sum(diff, dims = 1) ./ (2.0f0 * mmd_σ)), dims = (3, 4))) end diff --git a/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/ddpg.jl b/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/ddpg.jl index d49f971b0..0daac6a2e 100644 --- a/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/ddpg.jl +++ b/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/ddpg.jl @@ -118,7 +118,12 @@ function (p::DDPGPolicy)(env, player::Any = nothing) s = DynamicStyle(env) == SEQUENTIAL ? state(env) : state(env, player) s = Flux.unsqueeze(s, ndims(s) + 1) actions = p.behavior_actor(send_to_device(D, s)) |> vec |> send_to_host - c = clamp.(actions .+ randn(p.rng, p.na) .* repeat([p.act_noise], p.na), -p.act_limit, p.act_limit) + c = + clamp.( + actions .+ randn(p.rng, p.na) .* repeat([p.act_noise], p.na), + -p.act_limit, + p.act_limit, + ) p.na == 1 && return c[1] c end @@ -154,7 +159,7 @@ function RLBase.update!(p::DDPGPolicy, batch::NamedTuple{SARTS}) a′ = Aₜ(s′) qₜ = Cₜ(vcat(s′, a′)) |> vec y = r .+ γ .* (1 .- t) .* qₜ - a = Flux.unsqueeze(a, ndims(a)+1) + a = Flux.unsqueeze(a, ndims(a) + 1) gs1 = gradient(Flux.params(C)) do q = C(vcat(s, a)) |> vec diff --git a/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/maddpg.jl b/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/maddpg.jl index 8f2d67046..6357e74c0 100644 --- a/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/maddpg.jl +++ b/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/maddpg.jl @@ -16,8 +16,8 @@ See the paper https://arxiv.org/abs/1706.02275 for more details. - `rng::AbstractRNG`. """ mutable struct MADDPGManager <: AbstractPolicy - agents::Dict{<:Any, <:Agent} - traces + agents::Dict{<:Any,<:Agent} + traces::Any batch_size::Int update_freq::Int update_step::Int @@ -29,12 +29,10 @@ function (π::MADDPGManager)(env::AbstractEnv) while current_player(env) == chance_player(env) env |> legal_action_space |> rand |> env end - Dict( - player => agent.policy(env) - for (player, agent) in π.agents) + Dict(player => agent.policy(env) for (player, agent) in π.agents) end -function (π::MADDPGManager)(stage::Union{PreEpisodeStage, PostActStage}, env::AbstractEnv) +function (π::MADDPGManager)(stage::Union{PreEpisodeStage,PostActStage}, env::AbstractEnv) # only need to update trajectory. for (_, agent) in π.agents update!(agent.trajectory, agent.policy, env, stage) @@ -46,7 +44,7 @@ function (π::MADDPGManager)(stage::PreActStage, env::AbstractEnv, actions) for (player, agent) in π.agents update!(agent.trajectory, agent.policy, env, stage, actions[player]) end - + # update policy update!(π, env) end @@ -70,14 +68,18 @@ function RLBase.update!(π::MADDPGManager, env::AbstractEnv) length(agent.trajectory) > agent.policy.policy.update_after || return length(agent.trajectory) > π.batch_size || return end - + # get training data temp_player = collect(keys(π.agents))[1] t = π.agents[temp_player].trajectory inds = rand(π.rng, 1:length(t), π.batch_size) - batches = Dict((player, RLCore.fetch!(BatchSampler{π.traces}(π.batch_size), agent.trajectory, inds)) - for (player, agent) in π.agents) - + batches = Dict( + ( + player, + RLCore.fetch!(BatchSampler{π.traces}(π.batch_size), agent.trajectory, inds), + ) for (player, agent) in π.agents + ) + # get s, a, s′ for critic s = vcat((batches[player][:state] for (player, _) in π.agents)...) a = vcat((batches[player][:action] for (player, _) in π.agents)...) 
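# --- Editor's aside (illustrative only, not part of the patch) ----------------
# The maddpg.jl hunks above and below concatenate every agent's state and
# action batch along the feature dimension before feeding the centralised
# critic, i.e. the critic scores the joint observation-action pair. A minimal
# toy sketch of that layout, with made-up sizes (2 agents, 3-dim observations,
# 1-dim actions, batch of 4):
s1, s2 = rand(Float32, 3, 4), rand(Float32, 3, 4)   # per-agent observation batches
a1, a2 = rand(Float32, 1, 4), rand(Float32, 1, 4)   # per-agent action batches
joint_s = vcat(s1, s2)                               # 6×4 joint observation
joint_a = vcat(a1, a2)                               # 2×4 joint action
critic_input = vcat(joint_s, joint_a)                # 8×4, as in `Cₜ(vcat(s′, new_actions))`
# -------------------------------------------------------------------------------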
@@ -100,17 +102,17 @@ function RLBase.update!(π::MADDPGManager, env::AbstractEnv) t = batches[player][:terminal] # for training behavior_actor. mu_actions = vcat( - (( - batches[p][:next_state] |> - a.policy.policy.behavior_actor - ) for (p, a) in π.agents)... + ( + (batches[p][:next_state] |> a.policy.policy.behavior_actor) for + (p, a) in π.agents + )..., ) # for training behavior_critic. new_actions = vcat( - (( - batches[p][:next_state] |> - a.policy.policy.target_actor - ) for (p, a) in π.agents)... + ( + (batches[p][:next_state] |> a.policy.policy.target_actor) for + (p, a) in π.agents + )..., ) if π.traces == SLARTSL @@ -120,18 +122,18 @@ function RLBase.update!(π::MADDPGManager, env::AbstractEnv) @assert env isa ActionTransformedEnv mask = batches[player][:next_legal_actions_mask] - mu_l′ = Flux.batch( - (begin + mu_l′ = Flux.batch(( + begin actions = env.action_mapping(mu_actions[:, i]) mask[actions[player]] - end for i = 1:π.batch_size) - ) - new_l′ = Flux.batch( - (begin + end for i in 1:π.batch_size + )) + new_l′ = Flux.batch(( + begin actions = env.action_mapping(new_actions[:, i]) mask[actions[player]] - end for i = 1:π.batch_size) - ) + end for i in 1:π.batch_size + )) end qₜ = Cₜ(vcat(s′, new_actions)) |> vec @@ -157,7 +159,7 @@ function RLBase.update!(π::MADDPGManager, env::AbstractEnv) v .+= ifelse.(mu_l′, 0.0f0, typemin(Float32)) end reg = mean(A(batches[player][:state]) .^ 2) - loss = -mean(v) + reg * 1e-3 + loss = -mean(v) + reg * 1e-3 ignore() do p.actor_loss = loss end diff --git a/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/ppo.jl b/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/ppo.jl index fa06a9132..bf9e8f27e 100644 --- a/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/ppo.jl +++ b/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/ppo.jl @@ -148,7 +148,9 @@ function RLBase.prob( if p.update_step < p.n_random_start @error "todo" else - μ, logσ = p.approximator.actor(send_to_device(device(p.approximator), state)) |> send_to_host + μ, logσ = + p.approximator.actor(send_to_device(device(p.approximator), state)) |> + send_to_host StructArray{Normal}((μ, exp.(logσ))) end end @@ -256,11 +258,11 @@ function _update!(p::PPOPolicy, t::AbstractTrajectory) end s = send_to_device(D, select_last_dim(states_flatten, inds)) # !!! 
performance critical a = send_to_device(D, select_last_dim(actions_flatten, inds)) - + if eltype(a) === Int a = CartesianIndex.(a, 1:length(a)) end - + r = send_to_device(D, vec(returns)[inds]) log_p = send_to_device(D, vec(action_log_probs)[inds]) adv = send_to_device(D, vec(advantages)[inds]) @@ -275,7 +277,8 @@ function _update!(p::PPOPolicy, t::AbstractTrajectory) else log_p′ₐ = normlogpdf(μ, exp.(logσ), a) end - entropy_loss = mean(size(logσ, 1) * (log(2.0f0π) + 1) .+ sum(logσ; dims = 1)) / 2 + entropy_loss = + mean(size(logσ, 1) * (log(2.0f0π) + 1) .+ sum(logσ; dims = 1)) / 2 else # actor is assumed to return discrete logits logit′ = AC.actor(s) diff --git a/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/sac.jl b/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/sac.jl index b41c0397a..fef3c53d1 100644 --- a/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/sac.jl +++ b/src/ReinforcementLearningZoo/src/algorithms/policy_gradient/sac.jl @@ -104,8 +104,8 @@ function SACPolicy(; Float32(-action_dims), update_step, rng, - 0f0, - 0f0, + 0.0f0, + 0.0f0, ) end @@ -120,7 +120,7 @@ function (p::SACPolicy)(env) s = state(env) s = Flux.unsqueeze(s, ndims(s) + 1) # trainmode: - action = dropdims(p.policy(p.rng, s; is_sampling=true), dims=2) # Single action vec, drop second dim + action = dropdims(p.policy(p.rng, s; is_sampling = true), dims = 2) # Single action vec, drop second dim # testmode: # if testing dont sample an action, but act deterministically by @@ -146,7 +146,7 @@ function RLBase.update!(p::SACPolicy, batch::NamedTuple{SARTS}) γ, τ, α = p.γ, p.τ, p.α - a′, log_π = p.policy(p.rng, s′; is_sampling=true, is_return_log_prob=true) + a′, log_π = p.policy(p.rng, s′; is_sampling = true, is_return_log_prob = true) q′_input = vcat(s′, a′) q′ = min.(p.target_qnetwork1(q′_input), p.target_qnetwork2(q′_input)) @@ -168,12 +168,12 @@ function RLBase.update!(p::SACPolicy, batch::NamedTuple{SARTS}) # Train Policy p_grad = gradient(Flux.params(p.policy)) do - a, log_π = p.policy(p.rng, s; is_sampling=true, is_return_log_prob=true) + a, log_π = p.policy(p.rng, s; is_sampling = true, is_return_log_prob = true) q_input = vcat(s, a) q = min.(p.qnetwork1(q_input), p.qnetwork2(q_input)) reward = mean(q) entropy = mean(log_π) - ignore() do + ignore() do p.reward_term = reward p.entropy_term = entropy end diff --git a/src/ReinforcementLearningZoo/src/algorithms/tabular/tabular_policy.jl b/src/ReinforcementLearningZoo/src/algorithms/tabular/tabular_policy.jl index a91c22b62..d0bbedfa3 100644 --- a/src/ReinforcementLearningZoo/src/algorithms/tabular/tabular_policy.jl +++ b/src/ReinforcementLearningZoo/src/algorithms/tabular/tabular_policy.jl @@ -11,7 +11,7 @@ A `Dict` is used internally to store the mapping from state to action. """ Base.@kwdef struct TabularPolicy{S,A} <: AbstractPolicy table::Dict{S,A} = Dict{Int,Int}() - n_action::Union{Int, Nothing} = nothing + n_action::Union{Int,Nothing} = nothing end (p::TabularPolicy)(env::AbstractEnv) = p(state(env)) diff --git a/test/runtests.jl b/test/runtests.jl index 04be42cbd..d6dd95c9e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,5 +1,4 @@ using Test using ReinforcementLearning -@testset "ReinforcementLearning" begin -end +@testset "ReinforcementLearning" begin end
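
Editor's note: the sketch below is illustrative only and not part of the patch. The `calculate_CQL_loss` helper reformatted in offline_rl/common.jl computes the CQL(H) regularizer: the batch mean of logsumexp of Q(s, ·) minus Q(s, a) for the logged action. The toy Q matrix, the integer action indices, and the variable names here are made up for illustration; the helper itself indexes with a vector in `q_value[action]` form.

using Statistics

# Toy Q-value matrix: 3 actions × 2 batch samples (made-up numbers).
q_value = Float32[1.0 0.5; 2.0 1.5; 0.0 3.0]
actions = [2, 3]                         # logged action index for each sample

logsumexp_q = vec(log.(sum(exp.(q_value), dims = 1)))          # logsumexp over actions, per sample
q_logged    = [q_value[a, i] for (i, a) in enumerate(actions)] # Q(s, a) of the logged action
cql_h_loss  = mean(logsumexp_q .- q_logged)                    # CQL(H) penalty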