agent_noisy.py
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from memory import ReplayMemory
from model_noisy import *
from utils import *
from config import *

# Run on the GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class Agent():
    def __init__(self, action_size, epsilon=1.0, load_model=False, model_path=None):
        self.load_model = load_model
        self.action_size = action_size

        # Hyperparameters for the DQN. The epsilon schedule is retained but
        # unused here: exploration comes from the noisy layers, so
        # get_action() is purely greedy.
        self.discount_factor = 0.99
        self.epsilon = epsilon
        self.epsilon_min = 0.01
        self.explore_step = 1000000
        self.epsilon_decay = (self.epsilon - self.epsilon_min) / self.explore_step
        self.train_start = 100000
        self.update_target = 1000

        # Create the replay memory.
        self.memory = ReplayMemory()

        # Create the policy net and the target net.
        self.policy_net = NoisyDQN(action_size).to(device)
        self.target_net = NoisyDQN(action_size).to(device)

        # Optionally resume from a saved model. Loading happens before the
        # optimizer is built so that the optimizer tracks the parameters of
        # the network that is actually trained.
        if self.load_model:
            self.policy_net = torch.load(model_path, map_location=device)
            self.target_net = torch.load(model_path, map_location=device)
            self.target_net.eval()

        self.optimizer = optim.Adam(params=self.policy_net.parameters(), lr=learning_rate)

        # Initialize the target net with the policy net's weights.
        self.update_target_net()

    # After some time interval, update the target net to match the policy net.
    def update_target_net(self):
        self.target_net.load_state_dict(self.policy_net.state_dict())

    def get_action(self, state):
        # NoisyNet exploration: the stochastic weights in the noisy layers
        # provide exploration, so the chosen action is simply the greedy argmax.
        state = torch.from_numpy(state).unsqueeze(0).to(device)
        with torch.no_grad():
            a = self.policy_net(state).argmax(dim=1).cpu().numpy()[0]
        return a

    # Pick samples randomly from replay memory (with batch_size) and run one
    # optimization step on the policy net.
    def train_policy_net(self, frame):
        mini_batch = self.memory.sample_mini_batch(frame)
        # dtype=object keeps numpy from trying to broadcast the ragged
        # (history, action, reward, done) tuples into a rectangular array.
        mini_batch = np.array(mini_batch, dtype=object).transpose()

        history = np.stack(mini_batch[0], axis=0)               # shape: (batch_size, 5, 84, 84)
        states = np.float32(history[:, :4, :, :]) / 255.        # current state: frames 0 to 3
        actions = list(mini_batch[1])
        rewards = list(mini_batch[2])
        next_states = np.float32(history[:, 1:, :, :]) / 255.   # next state: frames 1 to 4
        dones = mini_batch[3]                                    # whether the episode ended

        current_q_values = QValues.get_current(self.policy_net, states, actions)
        next_q_values = QValues.get_next(self.target_net, next_states, dones)

        rewards = torch.from_numpy(np.float32(np.array(rewards))).to(device)
        target_q_values = (next_q_values * self.discount_factor) + rewards

        loss = F.mse_loss(current_q_values, target_q_values.unsqueeze(1))
        # loss = F.smooth_l1_loss(current_q_values, target_q_values.unsqueeze(1))

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Resample the factorised noise so the next forward passes use fresh
        # noisy weights in both networks.
        self.policy_net.reset_noise()
        self.target_net.reset_noise()
        return loss
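

# ---------------------------------------------------------------------------
# Illustration only (assumption): the layer below is NOT taken from this
# repo's model_noisy.py. It is a sketch of the factorised-Gaussian NoisyLinear
# layer (Fortunato et al., "Noisy Networks for Exploration") that a NoisyDQN
# is typically built from, included to show what reset_noise() -- called at
# the end of train_policy_net() -- actually resamples.
# ---------------------------------------------------------------------------
import math


class NoisyLinearSketch(nn.Module):
    def __init__(self, in_features, out_features, sigma_init=0.5):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Learnable means and standard deviations of the weights and biases.
        self.weight_mu = nn.Parameter(torch.empty(out_features, in_features))
        self.weight_sigma = nn.Parameter(torch.empty(out_features, in_features))
        self.bias_mu = nn.Parameter(torch.empty(out_features))
        self.bias_sigma = nn.Parameter(torch.empty(out_features))
        # Noise buffers: resampled by reset_noise(), never trained.
        self.register_buffer("weight_epsilon", torch.empty(out_features, in_features))
        self.register_buffer("bias_epsilon", torch.empty(out_features))
        mu_range = 1.0 / math.sqrt(in_features)
        self.weight_mu.data.uniform_(-mu_range, mu_range)
        self.weight_sigma.data.fill_(sigma_init / math.sqrt(in_features))
        self.bias_mu.data.uniform_(-mu_range, mu_range)
        self.bias_sigma.data.fill_(sigma_init / math.sqrt(in_features))
        self.reset_noise()

    @staticmethod
    def _scaled_noise(size):
        # f(x) = sign(x) * sqrt(|x|), applied to standard Gaussian samples.
        x = torch.randn(size)
        return x.sign() * x.abs().sqrt()

    def reset_noise(self):
        # Factorised noise: one vector per input unit, one per output unit.
        eps_in = self._scaled_noise(self.in_features)
        eps_out = self._scaled_noise(self.out_features)
        self.weight_epsilon.copy_(eps_out.unsqueeze(1) * eps_in.unsqueeze(0))
        self.bias_epsilon.copy_(eps_out)

    def forward(self, x):
        if self.training:
            weight = self.weight_mu + self.weight_sigma * self.weight_epsilon
            bias = self.bias_mu + self.bias_sigma * self.bias_epsilon
        else:
            # Evaluation uses the mean weights only.
            weight = self.weight_mu
            bias = self.bias_mu
        return F.linear(x, weight, bias)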


class QValues():
    # Helper with static methods that turn numpy batches into tensors and
    # query the networks on the module-level device.

    @staticmethod
    def get_current(policy_net, states, actions):
        # Q(s, a) for the actions that were actually taken.
        states = torch.from_numpy(states).to(device)
        actions = torch.from_numpy(np.array(actions)).to(device)
        return policy_net(states).gather(dim=1, index=actions.unsqueeze(-1))

    # Find q_values of states that are NOT terminal states;
    # q_values of terminal states are kept at 0.
    @staticmethod
    def get_next(target_net, next_states, dones):
        next_states = torch.from_numpy(next_states).to(device)
        dones = torch.from_numpy(dones.astype(bool)).to(device)
        non_final_state_locations = ~dones
        non_final_states = next_states[non_final_state_locations]
        batch_size = next_states.shape[0]
        values = torch.zeros(batch_size, device=device)
        values[non_final_state_locations] = target_net(non_final_states).max(dim=1)[0].detach()
        return values
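

# ---------------------------------------------------------------------------
# Usage sketch (assumption): a minimal smoke test of action selection. The
# action_size of 4 and the (4, 84, 84) stacked-frame state follow the
# conventions visible in train_policy_net; the repo's real training script
# may wire things up differently.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    agent = Agent(action_size=4)
    dummy_state = np.zeros((4, 84, 84), dtype=np.float32)
    print("greedy action:", agent.get_action(dummy_state))
    # A full run would also push transitions into agent.memory, call
    # agent.train_policy_net(frame) once frame > agent.train_start, and call
    # agent.update_target_net() every agent.update_target frames.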