from keras import Model
import tensorflow as tf


class DQN(Model):
    """
    This class implements a dueling DQN using a CNN architecture.
    """

    def __init__(self, num_actions=9, input_shape=(84, 84, 4)):
        super(DQN, self).__init__()
        self._mse = tf.keras.losses.MeanSquaredError()
        # convolutional feature extractor
        self._l1 = tf.keras.layers.Conv2D(64, (8, 8), strides=(2, 2), activation='relu', input_shape=input_shape)
        self._l2 = tf.keras.layers.MaxPooling2D((3, 3))
        self._l3 = tf.keras.layers.Conv2D(64, (4, 4), activation='relu')
        self._l4 = tf.keras.layers.MaxPooling2D((3, 3))
        self._l5 = tf.keras.layers.Conv2D(64, (3, 3), activation='relu')
        self._l6 = tf.keras.layers.GlobalMaxPooling2D()
        self._l7 = tf.keras.layers.Dense(512, activation="relu")
        # state-value head V(s)
        self._l10 = tf.keras.layers.Dense(1, activation="linear")
        # advantage head A(s, a), one output per action
        self._num_actions = num_actions
        self._l11 = tf.keras.layers.Dense(self._num_actions, activation="linear")
    @tf.function
    def call(self, x, training=False):
        # scale uint8 pixel input to [-1, 1]
        x = tf.cast(x, tf.float32)
        x = (x - 127.5) / 127.5
        x = self._l1(x, training=training)
        x = self._l2(x, training=training)
        x = self._l3(x, training=training)
        x = self._l4(x, training=training)
        x = self._l5(x, training=training)
        x = self._l6(x, training=training)
        x = self._l7(x, training=training)
        state_value = self._l10(x, training=training)
        advantages = self._l11(x, training=training)
        # dueling aggregation: Q(s, a) = V(s) + A(s, a) - mean over a' of A(s, a')
        return state_value + advantages - tf.expand_dims(
            tf.reduce_sum(advantages, axis=-1) / tf.cast(self._num_actions, tf.float32), axis=-1)
    @tf.function
    def step(self, s, a, r, s_new, done, optimizer, dqn_target):
        """
        Perform a training step with the DQN using double Q-learning.

        Args:
            - s <tf.Tensor> : state sampled from the environment
            - a <tf.Tensor> : action that was performed in s
            - r <tf.Tensor> : reward from performing a in s
            - s_new <tf.Tensor> : new state from performing a in s
            - done <tf.Tensor> : whether s_new is a terminal state
            - optimizer <tf.keras.optimizers.Optimizer> : optimizer used for performing the training step
            - dqn_target <DQN> : target network used for double Q-learning
        Returns:
            - loss <tf.Tensor> : MSE loss
        """
        with tf.GradientTape() as tape:
            with tf.device("/GPU:0"):
                # double Q-learning target: pick the greedy action with the
                # online network, evaluate it with the target network
                a_next = tf.argmax(self(s_new), axis=1)
                Q_next = tf.gather(params=dqn_target(s_new), indices=tf.cast(a_next, tf.int32), axis=1, batch_dims=1)
                # q value of the action actually taken in s
                Q_s_a = tf.gather(params=self(s), indices=tf.cast(a, tf.int32), axis=1, batch_dims=1)
                # bootstrapped target with discount factor 0.99, zeroed for terminal states
                target = r + tf.constant(0.99) * Q_next * (1.0 - tf.cast(done, tf.float32))
                # apply mean squared error loss
                loss = self._mse(Q_s_a, tf.stop_gradient(target))
        # perform gradient descent step
        grads = tape.gradient(loss, self.trainable_weights)
        optimizer.apply_gradients(zip(grads, self.trainable_weights))
        return loss
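

# --- Usage sketch (illustration only, not part of the original module) ---
# A minimal example of how the online and target networks might be wired
# together for a single training step. The batch size, learning rate, and the
# random dummy batch standing in for a replay-buffer sample are assumptions
# made for demonstration. Note that DQN.step pins computation to "/GPU:0",
# so running this assumes a visible GPU.
if __name__ == "__main__":
    num_actions = 9
    online_dqn = DQN(num_actions=num_actions)
    target_dqn = DQN(num_actions=num_actions)
    optimizer = tf.keras.optimizers.Adam(learning_rate=2.5e-4)

    # dummy transition batch (batch size 32) in place of a replay-buffer sample
    s = tf.random.uniform((32, 84, 84, 4), maxval=255, dtype=tf.int32)
    a = tf.random.uniform((32,), maxval=num_actions, dtype=tf.int32)
    r = tf.random.uniform((32,))
    s_new = tf.random.uniform((32, 84, 84, 4), maxval=255, dtype=tf.int32)
    done = tf.zeros((32,))

    # build both networks, then copy the online weights into the target network
    online_dqn(s)
    target_dqn(s_new)
    target_dqn.set_weights(online_dqn.get_weights())

    loss = online_dqn.step(s, a, r, s_new, done, optimizer, target_dqn)
    print("loss:", float(loss))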