Skip to content

Commit a7addbb

Browse files
fix bug in _td_error()
1 parent 8b88b1f commit a7addbb

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

rlzoo/algorithms/dqn/dqn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def _td_error(self, transitions, reward_gamma):
8989
b_o, b_a, b_r, b_o_, b_d = transitions
9090
b_d = tf.cast(b_d, tf.float32)
9191
b_a = tf.cast(b_a, tf.int64)
92-
b_r = tf.cast(b_a, tf.float32)
92+
b_r = tf.cast(b_r, tf.float32)
9393
if self.double_q:
9494
b_a_ = tf.one_hot(tf.argmax(self.network(b_o_), 1), self.network.action_shape[0])
9595
b_q_ = (1 - b_d) * tf.reduce_sum(self.target_network(b_o_) * b_a_, 1)

0 commit comments

Comments
 (0)