replay_buffer.py
import numpy as np


class ReplayBuffer:
    """Fixed-size circular buffer of (state, action, reward, next state, done) transitions."""

    def __init__(self, size, input_shape):
        self.size = size          # maximum number of stored transitions
        self.counter = 0          # total number of transitions seen so far
        self.state_buffer = np.zeros((self.size, input_shape), dtype=np.float32)
        self.action_buffer = np.zeros(self.size, dtype=np.int32)
        self.reward_buffer = np.zeros(self.size, dtype=np.float32)
        self.new_state_buffer = np.zeros((self.size, input_shape), dtype=np.float32)
        self.terminal_buffer = np.zeros(self.size, dtype=np.bool_)

    def store_tuples(self, state, action, reward, new_state, done):
        # Write at the next free slot, wrapping around so the oldest
        # transition is overwritten once the buffer is full.
        idx = self.counter % self.size
        self.state_buffer[idx] = state
        self.action_buffer[idx] = action
        self.reward_buffer[idx] = reward
        self.new_state_buffer[idx] = new_state
        self.terminal_buffer[idx] = done
        self.counter += 1

    def sample_buffer(self, batch_size):
        # Sample uniformly from the slots filled so far only.
        max_buffer = min(self.counter, self.size)
        # replace=False requires batch_size <= max_buffer.
        batch = np.random.choice(max_buffer, batch_size, replace=False)
        state_batch = self.state_buffer[batch]
        action_batch = self.action_buffer[batch]
        reward_batch = self.reward_buffer[batch]
        new_state_batch = self.new_state_buffer[batch]
        done_batch = self.terminal_buffer[batch]
        return state_batch, action_batch, reward_batch, new_state_batch, done_batch
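

# Minimal usage sketch: the buffer size, observation dimension, and batch size
# below are illustrative assumptions, not values taken from the original file.
if __name__ == "__main__":
    buffer = ReplayBuffer(size=10000, input_shape=4)   # e.g. 4-dimensional observations

    # Store a handful of dummy transitions.
    for step in range(64):
        state = np.random.rand(4).astype(np.float32)
        new_state = np.random.rand(4).astype(np.float32)
        action = np.random.randint(0, 2)
        reward = 1.0
        done = step % 16 == 15
        buffer.store_tuples(state, action, reward, new_state, done)

    # Sample a training batch; batch_size must not exceed the number of stored transitions.
    states, actions, rewards, new_states, dones = buffer.sample_buffer(batch_size=32)
    print(states.shape, actions.shape, rewards.shape, new_states.shape, dones.shape)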