This repository has been archived by the owner on Aug 17, 2021. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathkuhn_nfsp.jl
80 lines (64 loc) · 2.32 KB
/
kuhn_nfsp.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
"""
NFSP agents trained on Kuhn Poker game.
"""
using ReinforcementLearning
using Flux
using StableRNGs
using ProgressMeter: @showprogress
include("nfsp.jl")
# Encode the KuhnPokerEnv's states for training.
env = KuhnPokerEnv()

# Every information state seen by players 1 & 2 in Kuhn Poker: the empty
# pre-deal state, each private card, and each card + betting-history tuple.
states = [
    (), (:J,), (:Q,), (:K,),
    (:J, :bet), (:J, :pass), (:Q, :bet), (:Q, :pass), (:K, :bet), (:K, :pass),
    (:J, :pass, :bet), (:J, :bet, :bet), (:J, :bet, :pass), (:J, :pass, :pass),
    (:Q, :pass, :bet), (:Q, :bet, :bet), (:Q, :bet, :pass), (:Q, :pass, :pass),
    (:K, :pass, :bet), (:K, :bet, :bet), (:K, :bet, :pass), (:K, :pass, :pass),
    (:J, :pass, :bet, :pass), (:J, :pass, :bet, :bet), (:Q, :pass, :bet, :pass),
    (:Q, :pass, :bet, :bet), (:K, :pass, :bet, :pass), (:K, :pass, :bet, :bet),
] # all states for players 1 & 2

# Map each state tuple to its 1-based index, used to encode network inputs.
states_indexes_Dict = Dict(s => i for (i, s) in enumerate(states))
# Override the StateTransformedEnv accessors so that any extra positional
# arguments (here: the player) are forwarded into the mapping functions,
# letting the mappings distinguish the chance player from regular players.
function RLBase.state(env::StateTransformedEnv, args...; kwargs...)
    return env.state_mapping(state(env.env, args...; kwargs...), args...)
end

function RLBase.state_space(env::StateTransformedEnv, args...; kwargs...)
    return env.state_space_mapping(state_space(env.env, args...; kwargs...), args...)
end
# Wrap the environment so that non-chance players observe an integer-encoded
# state — a one-element vector holding the state's index — instead of the raw
# tuple; the chance player keeps the unmodified state and state space.
wrapped_env = StateTransformedEnv(
    env;
    state_mapping = function (s, player = current_player(env))
        if player == chance_player(env)
            s
        else
            [states_indexes_Dict[s]]
        end
    end,
    state_space_mapping = function (ss, player = current_player(env))
        if player == chance_player(env)
            ss
        else
            [[i] for i in 1:length(states)]
        end
    end,
)
# set parameters
seed = 123                      # RNG seed for reproducible runs
anticipatory_param = 0.1f0      # η: probability of acting from the best-response (RL) policy
used_device = Flux.cpu # Flux.gpu  — switch to Flux.gpu to train on a GPU
rng = StableRNG(seed)           # stable RNG shared by the agents
hidden_layers = (64, 64)        # hidden-layer sizes for the networks
eval_every = 10_000             # evaluate (nash_conv) every this many episodes
ϵ_decay = 2_000_000             # decay horizon of the ϵ-greedy exploration schedule
train_episodes = 10_000_000     # total number of self-play episodes
# Construct the NFSP agents (one per player) over the wrapped environment,
# using the hyperparameters configured above.
nfsp = NFSPAgents(
    wrapped_env;
    rng = rng,
    η = anticipatory_param,
    ϵ_decay = ϵ_decay,
    hidden_layers = hidden_layers,
    _device = used_device,
)
# Evaluation checkpoints. Typed element types (instead of `[]`, which makes a
# Vector{Any}) keep the containers concrete and plotting-friendly.
episodes = Int[]
results = Float64[] # exploitability trace; a `hook` could record this instead

# `1:train_episodes` is the idiomatic spelling of `range(1, length=train_episodes)`.
@showprogress for episode in 1:train_episodes
    reset!(wrapped_env)
    # Play one full self-play episode, updating both agents at every step.
    while !is_terminated(wrapped_env)
        RLBase.update!(nfsp, wrapped_env)
    end
    # Periodically record the current exploitability (NashConv halved).
    if episode % eval_every == 0
        push!(episodes, episode)
        push!(results, nash_conv(nfsp, wrapped_env) / 2)
    end
end
# save results
ENV["GKSwstype"] = "nul"  # headless GKS workstation so Plots works without a display
using Plots

# Log-scaled x-axis spreads the early, densely sampled episodes apart.
result_plot = plot(episodes, results, xaxis = :log)
savefig(result_plot, "result")