-
Notifications
You must be signed in to change notification settings - Fork 0
/
replay_memory.py
executable file
·56 lines (50 loc) · 1.95 KB
/
replay_memory.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
from collections import namedtuple
import random
import numpy as np
# Defined at module level (rather than inside __init__) so that Transition
# instances can be pickled, e.g. when saving the buffer to disk or passing
# sampled batches between processes; pickle looks classes up by module name.
Transition = namedtuple('Transition',
                        ['obs4', 'act', 'next_obs4', 'reward', 'done'])


class replay_memory:
    '''Fixed-capacity replay buffer of Transition tuples.

    Transitions are stored in a ring buffer: once ``capacity`` transitions
    have been added, each new one overwrites the oldest, so the buffer always
    holds the most recent ``capacity`` transitions.
    '''

    # Exposed as a class attribute so existing ``self.Transition`` and
    # ``replay_memory.Transition`` references keep working.
    Transition = Transition

    def __init__(self, capacity):
        '''Create an empty replay buffer.

        Parameters
        ----------
        capacity: int
            Maximum number of transitions kept in memory.

        Raises
        ------
        ValueError
            If ``capacity`` is not positive; such a buffer could never hold
            a transition.
        '''
        if capacity <= 0:
            raise ValueError(
                'capacity must be positive, got {!r}'.format(capacity))
        self.capacity = capacity
        self.memory = []
        # Index of the slot the next transition is written to.
        self.position = 0

    def __len__(self):
        '''Return the number of transitions currently stored.'''
        return len(self.memory)

    def add(self, *args, **kwargs):
        '''Add a transition to replay memory.

        Parameters
        ----------
        Fields of a Transition, positionally or by keyword, in the order
        (obs4, act, next_obs4, reward, done), e.g.
        ``replay_memory.add(obs4, action, next_obs4, reward, done)``.
        The original author's intended types (not enforced here) are:

        obs4: Tensor of shape torch.Size([4, 84, 84])
        act: Tensor of shape torch.Size([6])
        next_obs4: Tensor of shape torch.Size([4, 84, 84])
        reward: int
        done: bool, whether next_obs4 is a terminal state

        Notes
        -----
        When the buffer is full, the oldest stored transition is
        overwritten, so the buffer keeps the latest ``capacity`` samples.

        Raises
        ------
        TypeError
            If the arguments do not match the Transition fields.
        '''
        transition = self.Transition(*args, **kwargs)
        if len(self.memory) < self.capacity:
            # Still filling up: position == len(self.memory) here.
            self.memory.append(transition)
        else:
            self.memory[self.position] = transition
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        '''Sample a batch of distinct transitions uniformly at random.

        Parameters
        ----------
        batch_size: int
            How many transitions you want.

        Returns
        -------
        Transition
            A Transition whose fields (obs4, act, next_obs4, reward, done)
            are tuples of length ``batch_size``; element i of every field
            comes from the same sampled transition. The fields are tuples,
            not stacked tensors; callers must stack/concatenate them
            themselves if needed.

        Raises
        ------
        ValueError
            If ``batch_size`` is negative or larger than ``len(self)``
            (raised by ``random.sample``).
        '''
        batch = random.sample(self.memory, batch_size)
        if not batch:
            # zip(*[]) yields nothing, which would call Transition() with no
            # arguments (TypeError); return empty columns instead.
            return self.Transition(*[() for _ in self.Transition._fields])
        return self.Transition(*zip(*batch))