"""
agent.py
--------
Description:
This file implements an agent class, based on the Categorical-DQN algorithm.
"""
import random

import torch

from memory import ReplayMemory
from network import Network
import utils


class Agent:
    """
    Class for the Categorical DQN (C51) agent.

    In essence, the network returns a value distribution for each action, and a
    statistic such as the mean of that distribution is used as the action-value.
    """

    def __init__(self,
                 n_actions: int,
                 capacity: int,
                 batch_size: int,
                 learning_rate: float,
                 obs_shape: tuple,
                 epsilon_start: float,
                 epsilon_end: float,
                 exploration_fraction: float,
                 training_steps: int,
                 gamma: float,
                 v_min: int,
                 v_max: int,
                 n_atoms: int,
                 image_obs: bool,
                 n_hidden_units: int,
                 device: torch.device):
        """
        Initialize the agent.

        Args:
            n_actions: Number of agent actions.
            capacity: Replay memory capacity.
            batch_size: Replay memory sample size.
            learning_rate: Optimizer learning rate.
            obs_shape: Env observation shape.
            epsilon_start: Initial epsilon value.
            epsilon_end: Final epsilon value.
            exploration_fraction: Fraction of training steps used to explore.
            training_steps: Number of training steps.
            gamma: Discount factor.
            v_min: Minimum return value.
            v_max: Maximum return value.
            n_atoms: Number of return values (atoms) between v_min and v_max.
            image_obs: Whether the env observation is an image.
            n_hidden_units: Number of units in the network's fully-connected layers.
            device: Torch device to run computations on.
        """
        self._n_actions = n_actions
        self._batch_size = batch_size
        self._epsilon_end = epsilon_end
        self._epsilon = epsilon_start
        self._epsilon_decay = (epsilon_start - self._epsilon_end) / (exploration_fraction * training_steps)
        self._gamma = gamma
        self._v_min = v_min
        self._v_max = v_max
        self._n_atoms = n_atoms
        self._image_obs = image_obs
        self._device = device

        # distance between adjacent atoms, and the atom support of the return distribution
        self._delta = (self._v_max - self._v_min) / (self._n_atoms - 1)
        self._z = torch.linspace(self._v_min, self._v_max, self._n_atoms).to(self._device)
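        # For example (illustrative values, not taken from this repo's config):
        # v_min=-10, v_max=10, n_atoms=51 gives delta = 0.4 and
        # z = [-10.0, -9.6, ..., 9.6, 10.0].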

        self.replay_memory = ReplayMemory(capacity=capacity, batch_size=batch_size)
        self._main_network = Network(obs_shape, n_actions, n_atoms, n_hidden_units, image_obs).to(self._device)
        self._target_network = Network(obs_shape, n_actions, n_atoms, n_hidden_units, image_obs).to(self._device)
        self.update_target_network()

        self._optimizer = torch.optim.Adam(self._main_network.parameters(), lr=learning_rate, eps=0.01 / batch_size)

    def act(self, state: torch.Tensor) -> utils.ActionInfo:
        """
        Sample an action for the given state.

        Actions are sampled uniformly at random during exploration; otherwise the
        greedy action is chosen, i.e. the action whose value distribution has the
        highest expected value.

        Args:
            state: Current state of the agent.

        Returns:
            action_info: Namedtuple with information about the sampled action.
        """
        with torch.no_grad():
            value_dists = self._main_network(state)
            expected_returns = (self._z * value_dists).sum(2)

            if random.random() > self._epsilon:
                # exploit: pick the action with the highest expected return
                action = expected_returns.argmax(1)
                action_probs = expected_returns.softmax(1)  # distribution over actions (not over the batch)
            else:
                # explore: pick a uniformly random action
                action = torch.randint(high=self._n_actions, size=(1,))
                action_probs = torch.ones(self._n_actions) / self._n_actions

            action_value = expected_returns[0, action].item()
            policy_entropy = -(action_probs * torch.log(action_probs + 1e-8)).sum().item()

        action_info = utils.ActionInfo(
            action=action.item(),
            action_value=round(action_value, 2),
            entropy=round(policy_entropy, 2),
        )

        return action_info

    def decrease_epsilon(self):
        """Linearly decay epsilon towards its final value."""
        if self._epsilon > self._epsilon_end:
            self._epsilon -= self._epsilon_decay

    def update_target_network(self):
        """Update the target network's parameters to match the main network's."""
        self._target_network.load_state_dict(self._main_network.state_dict())

    def learn(self) -> float:
        """Perform one learning step: update the main network via backpropagation and return the loss."""
        obs, actions, rewards, next_obs, terminals = self.replay_memory.sample()

        states = utils.to_tensor(array=obs, device=self._device, normalize=self._image_obs)
        actions = utils.to_tensor(array=actions, device=self._device).view(-1, 1).long()
        rewards = utils.to_tensor(array=rewards, device=self._device).view(-1, 1)
        next_states = utils.to_tensor(array=next_obs, device=self._device, normalize=self._image_obs)
        terminals = utils.to_tensor(array=terminals, device=self._device).view(-1, 1)

        # agent predictions: value distributions for the sampled states
        value_dists = self._main_network(states)
        # gather the predicted probabilities of the actions that were actually taken
        probs = value_dists[torch.arange(self._batch_size), actions.view(-1), :]

        # ------------------------------ Categorical algorithm ------------------------------
        #
        # Since we are dealing with value distributions rather than value functions,
        # we cannot minimize an MSE loss of the form MSE(reward + gamma * Q_target - Q).
        # Instead, we project the support of the target predictions T_hat * Z_target onto
        # the support of the agent predictions Z, and minimize the cross-entropy term of
        # the KL divergence KL(projected T_hat * Z_target || Z).
        #
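        # Worked example (illustrative numbers, not this agent's defaults): with
        # v_min = 0, v_max = 10 and n_atoms = 3, the support is z = [0, 5, 10] and
        # delta = 5. A target sample Tz = 7 gives bj = (7 - 0) / 5 = 1.4, so l = 1,
        # u = 2, and its probability mass is split 0.6 / 0.4 between atoms 1 and 2.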
        with torch.no_grad():
            # target network predictions for the next states
            target_value_dists = self._target_network(next_states)
            target_expected_returns = (self._z * target_value_dists).sum(2)
            target_actions = target_expected_returns.argmax(1)
            target_probs = target_value_dists[torch.arange(self._batch_size), target_actions, :]

            # project the target distribution onto the agent's support
            m = torch.zeros(self._batch_size * self._n_atoms).to(self._device)
            Tz = (rewards + (1 - terminals) * self._gamma * self._z).clip(self._v_min, self._v_max)
            bj = (Tz - self._v_min) / self._delta  # fractional atom index of each target value
            l, u = torch.floor(bj).long(), torch.ceil(bj).long()
            offset = (
                torch.linspace(0, (self._batch_size - 1) * self._n_atoms, self._batch_size)
                .long()
                .unsqueeze(1)
                .expand(self._batch_size, self._n_atoms)
                .to(self._device)
            )
            # distribute each target probability mass between the neighbouring atoms l and u,
            # proportionally to its distance from them; the (l == u) term keeps the full mass
            # on l when bj falls exactly on an atom
            m.index_add_(
                0,
                (l + offset).view(-1),
                (target_probs * (u + (l == u).long() - bj)).view(-1).float(),
            )
            m.index_add_(
                0, (u + offset).view(-1), (target_probs * (bj - l)).view(-1).float()
            )
            m = m.view(self._batch_size, self._n_atoms)
        # -----------------------------------------------------------------------------------

        # cross-entropy between the projected target distribution and the agent's prediction
        loss = (-((m * torch.log(probs + 1e-8)).sum(dim=1))).mean()

        self._optimizer.zero_grad()  # set all gradients to zero
        loss.backward()              # backpropagate loss through the network
        self._optimizer.step()       # update weights

        return round(loss.item(), 2)
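

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the original training pipeline. It
    # assumes that `Network` accepts a flat observation shape such as (4,) when
    # image_obs=False, and that `act()` expects a batched float tensor on the
    # chosen device; the hyperparameter values below are illustrative only.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    agent = Agent(
        n_actions=2,
        capacity=10_000,
        batch_size=32,
        learning_rate=2.5e-4,
        obs_shape=(4,),
        epsilon_start=1.0,
        epsilon_end=0.05,
        exploration_fraction=0.1,
        training_steps=100_000,
        gamma=0.99,
        v_min=-10,
        v_max=10,
        n_atoms=51,
        image_obs=False,
        n_hidden_units=128,
        device=device,
    )
    dummy_state = torch.zeros((1, 4), device=device)
    print(agent.act(dummy_state))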
# ============== END OF FILE ==============