rainbow_comparison.py
from rainbow import *
import gym
import torch
import numpy as np
import random
import matplotlib.pyplot as plt
import argparse


def set_seed(seed, env):
    """Seed every source of randomness for a reproducible run."""
    torch.manual_seed(seed)
    if torch.backends.cudnn.enabled:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
    np.random.seed(seed)
    random.seed(seed)
    env.seed(seed)
if __name__ == '__main__':
    # Usage: python rainbow_comparison.py --num_frames 2000 --plotting_interval 100
    ap = argparse.ArgumentParser()
    ap.add_argument("-nf", "--num_frames", type=int, default=2000,
                    help="number of training frames")
    ap.add_argument("-plt", "--plot", default=False, action='store_true',
                    help="plot training stats during training for each network")
    ap.add_argument("-pi", "--plotting_interval", type=int, default=100,
                    help="number of steps between plot updates")
    args = ap.parse_args()

    # hyperparameters (replay memory size and target update period scale with run length)
    num_frames = args.num_frames
    memory_size = args.num_frames // 10   # integer division: buffer size must be an int
    batch_size = 32
    target_update = args.num_frames // 10
    plotting_interval = args.plotting_interval
    plot = args.plot
    # seed
    seed = 777

    # make environment
    env_id = "CartPole-v0"
    env = gym.make(env_id)
    set_seed(seed, env)
    # train: full Rainbow on single observations vs. Rainbow with a stack of 4 frames
    agent_rainbow = DQNAgent(env, memory_size, batch_size, target_update,
                             no_dueling=False, no_categorical=False, no_double=False,
                             no_n_step=False, no_noise=False, no_priority=False,
                             plot=plot, frame_interval=plotting_interval)
    agent_rainbow_4 = DQNAgent(env, memory_size, batch_size, target_update,
                               no_dueling=False, no_categorical=False, no_double=False,
                               no_n_step=False, no_noise=False, no_priority=False,
                               plot=plot, frame_interval=plotting_interval, n_frames_stack=4)
    agents = [agent_rainbow, agent_rainbow_4]
    labels = ["Rainbow", "Rainbow-4-frames"]

    scores = []
    losses = []
    for i, agent in enumerate(agents):
        print("Training agent", labels[i])
        score, loss = agent.train(num_frames)
        scores.append(score)
        losses.append(loss)
    # create a color palette
    palette = plt.get_cmap('Set1')
    plt.figure(figsize=(20, 5))
    plt.subplot(131)
    plt.title('Training frames: %s' % num_frames)
    for i in range(len(scores)):
        plt.plot(scores[i], marker='', color=palette(i), linewidth=3.,
                 alpha=1., label=labels[i])
    plt.legend(loc=2, ncol=1)
    plt.xlabel("Frames x " + str(plotting_interval))
    plt.ylabel("Score")