-
Notifications
You must be signed in to change notification settings - Fork 26
/
rl_utils.py
30 lines (23 loc) · 1.13 KB
/
rl_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import torch
import random
def huber(x, k=1.0):
    """Elementwise Huber loss: quadratic inside |x| < k, linear (slope k) outside.

    Args:
        x: tensor of residuals.
        k: transition point between the quadratic and linear regimes.

    Returns:
        Tensor of the same shape as ``x`` with the loss applied elementwise.
    """
    magnitude = x.abs()
    quadratic = 0.5 * x.pow(2)
    linear = k * (magnitude - 0.5 * k)
    return torch.where(magnitude < k, quadratic, linear)
class ReplayMemory:
    """Fixed-capacity FIFO replay buffer of transitions for off-policy RL.

    Transitions are stored as tuples of 1-row tensors; when the buffer is
    full, the oldest transition is dropped.
    """

    def __init__(self, capacity):
        """
        Args:
            capacity: maximum number of transitions retained.
        """
        self.capacity = capacity
        self.memory = []

    def push(self, state, action, next_state, reward, done):
        """Store one transition, evicting the oldest entry when over capacity.

        Each field is wrapped in an extra list so every stored tensor has a
        leading batch dimension of 1, which lets `sample` concatenate them.
        """
        transition = (torch.Tensor([state]), torch.Tensor([action]),
                      torch.Tensor([next_state]), torch.Tensor([reward]),
                      torch.Tensor([done]))
        self.memory.append(transition)
        if len(self.memory) > self.capacity:
            del self.memory[0]

    def sample(self, batch_size):
        """Return a uniform random batch as stacked tensors.

        Returns:
            (state, action, reward, next_state, done) where state/next_state
            are [batch, *feature], action is an int64 tensor of shape
            [batch, 1], and reward/done are [batch, 1].

        Raises:
            ValueError: if batch_size exceeds the current buffer size
                (propagated from random.sample).
        """
        sample = random.sample(self.memory, batch_size)
        batch_state, batch_action, batch_next_state, batch_reward, batch_done = zip(*sample)
        batch_state = torch.cat(batch_state)
        # Concatenate then cast instead of torch.LongTensor(tuple-of-tensors):
        # that constructor path is deprecated in modern PyTorch and converts
        # element-by-element; cat().long() yields the same [batch, 1] int64
        # tensor through the supported API.
        batch_action = torch.cat(batch_action).long().unsqueeze(1)
        batch_reward = torch.cat(batch_reward)
        batch_done = torch.cat(batch_done)
        batch_next_state = torch.cat(batch_next_state)
        return batch_state, batch_action, batch_reward.unsqueeze(1), batch_next_state, batch_done.unsqueeze(1)

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.memory)