Commit b04e3f1

update the experiment's parameters (JuliaReinforcementLearning#440)

peterchen96 authored Aug 7, 2021
1 parent cad468e commit b04e3f1
Showing 3 changed files with 24 additions and 20 deletions.
20 changes: 10 additions & 10 deletions docs/experiments/experiments/NFSP/JuliaRL_NFSP_KuhnPoker.jl
@@ -52,21 +52,21 @@ function RL.Experiment(
         learner = DQNLearner(
             approximator = NeuralNetworkApproximator(
                 model = Chain(
-                    Dense(ns, 128, relu; init = glorot_normal(rng)),
-                    Dense(128, na; init = glorot_normal(rng))
+                    Dense(ns, 64, relu; init = glorot_normal(rng)),
+                    Dense(64, na; init = glorot_normal(rng))
                 ) |> cpu,
                 optimizer = Descent(0.01),
             ),
             target_approximator = NeuralNetworkApproximator(
                 model = Chain(
-                    Dense(ns, 128, relu; init = glorot_normal(rng)),
-                    Dense(128, na; init = glorot_normal(rng))
+                    Dense(ns, 64, relu; init = glorot_normal(rng)),
+                    Dense(64, na; init = glorot_normal(rng))
                 ) |> cpu,
             ),
             γ = 1.0f0,
             loss_func = huber_loss,
             batch_size = 128,
-            update_freq = 64,
+            update_freq = 128,
             min_replay_history = 1000,
             target_update_freq = 1000,
             rng = rng,
@@ -75,7 +75,7 @@
                 kind = :linear,
                 ϵ_init = 0.06,
                 ϵ_stable = 0.001,
-                decay_steps = 3_000_000,
+                decay_steps = 1_000_000,
                 rng = rng,
             ),
         ),
@@ -89,8 +89,8 @@
         policy = BehaviorCloningPolicy(;
             approximator = NeuralNetworkApproximator(
                 model = Chain(
-                    Dense(ns, 128, relu; init = glorot_normal(rng)),
-                    Dense(128, na; init = glorot_normal(rng))
+                    Dense(ns, 64, relu; init = glorot_normal(rng)),
+                    Dense(64, na; init = glorot_normal(rng))
                 ) |> cpu,
                 optimizer = Descent(0.01),
             ),
@@ -116,14 +116,14 @@ function RL.Experiment(
                 deepcopy(sl_agent),
                 η,
                 rng,
-                64, # update_freq
+                128, # update_freq
                 0, # initial update_step
                 true, # initial NFSPAgent's learn mode
             )) for player in players(wrapped_env) if player != chance_player(wrapped_env)
         )
     )

-    stop_condition = StopAfterEpisode(4_000_000, is_show_progress=!haskey(ENV, "CI"))
+    stop_condition = StopAfterEpisode(1_200_000, is_show_progress=!haskey(ENV, "CI"))
     hook = ResultNEpisode(10_000, 0, [], [])

     Experiment(nfsp, wrapped_env, stop_condition, hook, "# run NFSP on KuhnPokerEnv")
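For intuition on the decay_steps change above: under a linear schedule, ϵ is interpolated from ϵ_init down to ϵ_stable over decay_steps steps and then held constant, so cutting decay_steps from 3_000_000 to 1_000_000 reaches the stable exploration rate three times sooner. A standalone sketch of such a schedule with the experiment's values (an illustration only, not the library's exact EpsilonGreedyExplorer implementation, which also handles warmup steps):

# Hypothetical helper: linear ϵ decay from ϵ_init to ϵ_stable over decay_steps.
linear_eps(step; ϵ_init = 0.06, ϵ_stable = 0.001, decay_steps = 1_000_000) =
    ϵ_init + (ϵ_stable - ϵ_init) * clamp(step / decay_steps, 0, 1)

linear_eps(0)          # 0.06
linear_eps(500_000)    # ≈ 0.0305, halfway between ϵ_init and ϵ_stable
linear_eps(2_000_000)  # 0.001, held at ϵ_stable once decay_steps is passed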
20 changes: 12 additions & 8 deletions src/ReinforcementLearningZoo/src/algorithms/nfsp/nfsp.jl
@@ -56,9 +56,11 @@ function (π::NFSPAgent)(stage::PreActStage, env::AbstractEnv, action)
     # update policy
     π.update_step += 1
     if π.update_step % π.update_freq == 0
-        update!(sl.policy, sl.trajectory)
-        if !π.mode
-            rl_learn(rl.policy, rl.trajectory) # only update rl_policy's learner.
+        if π.mode
+            update!(sl.policy, sl.trajectory)
+        else
+            rl_learn!(rl.policy, rl.trajectory) # only update rl_policy's learner.
+            update!(sl.policy, sl.trajectory)
         end
     end
 end
@@ -87,18 +87,20 @@ function (π::NFSPAgent)(::PostEpisodeStage, env::AbstractEnv, player::Any)
         push!(rl.trajectory[:legal_actions_mask], legal_action_space_mask(env, player))
     end

-    # update the policy
+    # update the policy
     π.update_step += 1
     if π.update_step % π.update_freq == 0
-        update!(sl.policy, sl.trajectory)
-        if !π.mode
-            rl_learn(rl.policy, rl.trajectory)
+        if π.mode
+            update!(sl.policy, sl.trajectory)
+        else
+            rl_learn!(rl.policy, rl.trajectory) # only update rl_policy's learner.
+            update!(sl.policy, sl.trajectory)
         end
     end
 end

 # the supplement function
-function rl_learn(policy::QBasedPolicy, t::AbstractTrajectory)
+function rl_learn!(policy::QBasedPolicy, t::AbstractTrajectory)
     # just learn the approximator, not update target_approximator
     learner = policy.learner
     length(t[:terminal]) - learner.sampler.n <= learner.min_replay_history && return
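In both hunks the update logic now branches on the agent's mode flag: when π.mode is set only the supervised (average-policy) network is updated, otherwise the RL learner is trained first and the supervised policy afterwards; the helper is also renamed rl_learn → rl_learn! to follow Julia's convention of marking mutating functions with a trailing bang. A minimal runnable sketch of that gate, using stand-in closures rather than the actual RLZoo policies and trajectories:

# Hypothetical gated_update! mirroring the branch above; update_sl! and rl_learn! here are stand-in closures.
function gated_update!(mode::Bool; update_sl!, rl_learn!)
    if mode
        update_sl!()      # only the supervised (average-policy) network is trained
    else
        rl_learn!()       # train the RL learner first ...
        update_sl!()      # ... then the supervised policy
    end
end

gated_update!(true;  update_sl! = () -> println("SL update"), rl_learn! = () -> println("RL update"))
gated_update!(false; update_sl! = () -> println("SL update"), rl_learn! = () -> println("RL update"))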
@@ -28,7 +28,7 @@ function RLBase.update!(π::NFSPAgentManager, env::AbstractEnv)
 end

 function (π::NFSPAgentManager)(stage::Union{PreEpisodeStage, PostEpisodeStage}, env::AbstractEnv)
-    for (player, agent) in π.agents
-        agent(stage, env, player)
+    @sync for (player, agent) in π.agents
+        @async agent(stage, env, player)
     end
 end
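The last change swaps the plain per-player loop for Julia's task primitives: @async schedules each agent call as its own task and the enclosing @sync blocks until all tasks spawned inside it have finished. A small runnable sketch of the pattern with plain closures in place of the NFSP agents:

# Runnable illustration of the @sync / @async pattern used above.
agents = Dict(
    :player1 => stage -> println("player1 handles $stage"),
    :player2 => stage -> println("player2 handles $stage"),
)

@sync for (player, agent) in agents
    @async agent(:PreEpisodeStage)   # each call runs as a separate task
end                                  # @sync waits here until every task is done
println("all agents finished")

Note that @async gives single-threaded concurrency (tasks interleave at yield points); true multi-threaded parallelism would use Threads.@spawn instead.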
