loss.py
import torch
import numpy as np


def td_loss(num_frames, batch_size, gamma, model, buffer, optimizer):
    """One-step TD loss for vanilla DQN.

    num_frames is unused here; it is kept for a uniform signature with
    prioritized_td_loss below.
    """
    state, action, reward, next_state, done = buffer.sample(batch_size)

    state = torch.FloatTensor(np.float32(state))
    next_state = torch.FloatTensor(np.float32(next_state))
    action = torch.LongTensor(action)
    reward = torch.FloatTensor(reward)
    done = torch.FloatTensor(done)
    if torch.cuda.is_available():
        state, next_state, action, reward, done = (
            state.cuda(), next_state.cuda(), action.cuda(), reward.cuda(), done.cuda())

    q_values = model(state)
    next_q_values = model(next_state)

    # Q(s, a) for the actions actually taken in the batch
    q_value = q_values.gather(1, action.unsqueeze(1)).squeeze(1)
    # Bootstrapped target: max over next-state Q-values from the same network
    next_q_value = next_q_values.max(1)[0]
    expected_q_value = reward + gamma * next_q_value * (1 - done)  # if done == 1, there is no next state

    # Mean squared TD error; the target is detached so gradients flow only through q_value
    loss = (q_value - expected_q_value.detach()).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.detach()
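
A minimal sketch of how td_loss might be driven from a training loop. Apart from td_loss itself, every name here (env, ReplayBuffer, DQN, epsilon_by_frame, model.act) is an assumption for illustration and is not defined in this file.

# --- usage sketch (not part of loss.py); env, ReplayBuffer, DQN and
# --- epsilon_by_frame are hypothetical names assumed for illustration
buffer = ReplayBuffer(capacity=100_000)
model = DQN(env.observation_space.shape, env.action_space.n)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

num_frames, batch_size, gamma = 1_000_000, 32, 0.99
state = env.reset()
for frame_idx in range(1, num_frames + 1):
    action = model.act(state, epsilon_by_frame(frame_idx))  # assumed epsilon-greedy helper
    next_state, reward, done, _ = env.step(action)
    buffer.push(state, action, reward, next_state, done)
    state = env.reset() if done else next_state

    if len(buffer) >= batch_size:
        loss = td_loss(frame_idx, batch_size, gamma, model, buffer, optimizer)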
def double_td_loss(num_frames, batch_size, gamma, model, tmodel, buffer, optimizer):
    """Double DQN loss: the online network selects the next action, the target network evaluates it."""
    state, action, reward, next_state, done = buffer.sample(batch_size)

    state = torch.FloatTensor(np.float32(state))
    next_state = torch.FloatTensor(np.float32(next_state))
    action = torch.LongTensor(action)
    reward = torch.FloatTensor(reward)
    done = torch.FloatTensor(done)
    if torch.cuda.is_available():
        state, next_state, action, reward, done = (
            state.cuda(), next_state.cuda(), action.cuda(), reward.cuda(), done.cuda())

    q_values = model(state)                    # online network, current states
    next_q_values = model(next_state)          # online network, next states (action selection)
    next_q_state_values = tmodel(next_state)   # target network, next states (action evaluation)

    q_value = q_values.gather(1, action.unsqueeze(1)).squeeze(1)
    # Evaluate the target network at the online network's argmax action
    next_q_value = next_q_state_values.gather(1, torch.max(next_q_values, 1)[1].unsqueeze(1)).squeeze(1)
    expected_q_value = reward + gamma * next_q_value * (1 - done)

    loss = (q_value - expected_q_value.detach()).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.detach()
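
double_td_loss assumes tmodel is a separate target network that is periodically synchronized with the online network; that step is not shown in this file. Below is a minimal sketch of the usual hard update, assuming model and tmodel share the same architecture; sync_target and sync_interval are names chosen here for illustration.

def sync_target(model, tmodel):
    # Hard update: copy the online network's parameters into the target network.
    tmodel.load_state_dict(model.state_dict())

# e.g. inside the training loop (sync_interval is an assumed hyperparameter):
# if frame_idx % sync_interval == 0:
#     sync_target(model, tmodel)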
def prioritized_td_loss(num_frames, batch_size, gamma, model, tmodel, buffer, optimizer):
    """Double DQN loss with prioritized experience replay and importance-sampling weights."""
    beta = buffer.beta_func(num_frames)  # anneal the importance-sampling exponent over training
    state, action, reward, next_state, done, indices, weights = buffer.sample(batch_size, beta)

    state = torch.FloatTensor(np.float32(state))
    next_state = torch.FloatTensor(np.float32(next_state))
    action = torch.LongTensor(action)
    reward = torch.FloatTensor(reward)
    done = torch.FloatTensor(done)
    weights = torch.FloatTensor(weights)
    if torch.cuda.is_available():
        state, next_state, action, reward, done, weights = (
            state.cuda(), next_state.cuda(), action.cuda(), reward.cuda(), done.cuda(), weights.cuda())

    q_values = model(state)
    next_q_values = model(next_state)
    next_q_state_values = tmodel(next_state)

    q_value = q_values.gather(1, action.unsqueeze(1)).squeeze(1)
    next_q_value = next_q_state_values.gather(1, torch.max(next_q_values, 1)[1].unsqueeze(1)).squeeze(1)
    expected_q_value = reward + gamma * next_q_value * (1 - done)

    # Per-sample squared TD error, scaled by the importance-sampling weights
    loss = (q_value - expected_q_value.detach()).pow(2) * weights
    prios = loss + 1e-5  # small constant keeps every transition's priority strictly positive
    loss = loss.mean()

    buffer.update_priorities(indices, prios.detach().cpu().numpy())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.detach()
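
prioritized_td_loss relies on buffer.beta_func to anneal the importance-sampling exponent beta toward 1 as training progresses. The buffer's actual schedule is not shown in this file; a common choice is a linear ramp like the sketch below, where beta_start and beta_frames are assumed hyperparameters rather than values taken from this repository.

def linear_beta_schedule(frame_idx, beta_start=0.4, beta_frames=100_000):
    # Linearly anneal beta from beta_start to 1.0 over beta_frames frames,
    # then hold it at 1.0 (full importance-sampling correction).
    return min(1.0, beta_start + frame_idx * (1.0 - beta_start) / beta_frames)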