runner.py
import numpy as np


class AtariRunner:
    """
    Runs episodes of an environment with an agent and a prioritized replay buffer,
    training the agent on minibatches sampled every `replay_period` steps.
    """

    def __init__(self, env, agent, buffer, gamma=0.99, minibatch_size=1, replay_period=1):
        self.env = env
        self.agent = agent
        self.buffer = buffer
        self.gamma = gamma
        self.replay_period = replay_period
        self.minibatch_size = minibatch_size
        self.counter = 0
    def run_episode(self):
        """
        Runs a single episode and returns the number of steps taken and the total reward.
        """
        steps = 0
        rewards = 0
        done = False
        s = self.env.reset()
        while not done:
            a = self.agent.act(s)
            s_, r, done, info = self.env.step(a)
            sample = (s, a, r, s_, done)
            # Compute the initial TD error so the transition enters the buffer
            # with a meaningful priority
            x, y, error = self.get_targets([(0, sample)])
            self.buffer.add(error[0], sample)
            self.counter += 1
            if self.counter % self.replay_period == 0:
                batch = self.buffer.sample(self.minibatch_size)
                x, y, errors = self.get_targets(batch)
                # Update the priorities of the sampled transitions with their new TD errors
                for i in range(self.minibatch_size):
                    idx = batch[i][0]
                    self.buffer.update(idx, errors[i])
                # Train the agent on the sampled minibatch
                self.agent.train(x, y)
                self.counter = 0
            s = s_
            steps += 1
            rewards += r
        return steps, rewards
    def get_targets(self, batch):
        """
        Computes the training targets for the given batch of transitions and their TD errors.
        Each batch entry is a (buffer_index, (s, a, r, s_, done)) tuple.
        """
        states = np.array([obs[1][0] for obs in batch], dtype=np.float32)
        states_ = np.array([obs[1][3] for obs in batch], dtype=np.float32)
        values = self.agent.predict(states)
        values_next = self.agent.predict(states_)
        values_next_target = self.agent.predict(states_, target=True)
        length = len(batch)
        errors = np.zeros(length, dtype=np.float32)
        for i in range(length):
            s, a, r, s_, done = batch[i][1]
            old_value = values[i, a]
            if done:
                values[i, a] = r
            else:
                # Double DQN target: the online network selects the next action,
                # the target network evaluates it
                values[i, a] = r + self.gamma * values_next_target[i][np.argmax(values_next[i])]
            # Absolute TD error, used as the transition's priority
            errors[i] = abs(old_value - values[i, a])
        return states, values, errors
    def initialize_buffer(self, transitions):
        """
        Initializes the replay buffer with experiences generated by taking random actions.
        """
        s = self.env.reset()
        count = 0
        while count < transitions:
            a = self.env.random_action()
            s_, r, done, info = self.env.step(a)
            sample = (s, a, r, s_, done)
            self.buffer.add(abs(r), sample)  # use |reward| as the initial priority
            if done:
                s = self.env.reset()
            else:
                s = s_
            count += 1
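
# A minimal usage sketch, assuming env/agent/buffer objects that expose the
# interfaces this runner relies on (env.reset/step/random_action,
# agent.act/predict/train, buffer.add/sample/update). The class names and
# constructor arguments below are hypothetical placeholders, not defined in this file.
#
#     env = AtariEnv("Breakout")                         # hypothetical environment wrapper
#     agent = DQNAgent(env.state_shape, env.n_actions)   # hypothetical DQN agent
#     buffer = PrioritizedBuffer(capacity=100_000)       # hypothetical prioritized replay buffer
#     runner = AtariRunner(env, agent, buffer, gamma=0.99,
#                          minibatch_size=32, replay_period=4)
#     runner.initialize_buffer(10_000)  # fill the buffer with random-action transitions
#     for episode in range(1000):
#         steps, rewards = runner.run_episode()
#         print(f"episode {episode}: steps={steps}, reward={rewards}")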