Skip to content

Commit 0394b6b

Browse files
committed
fix bug in SAC
1 parent a7addbb commit 0394b6b

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

rlzoo/algorithms/sac/sac.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def evaluate(self, state, epsilon=1e-6):
6565
std = tf.math.exp(log_std) # no clip in evaluation, clip affects gradients flow
6666

6767
normal = Normal(0, 1)
68-
z = normal.sample()
68+
z = normal.sample(mean.shape)
6969
action_0 = tf.math.tanh(mean + std * z) # TanhNormal distribution as actions; reparameterization trick
7070
# according to original paper, with an extra last term for normalizing different action range
7171
log_prob = Normal(mean, std).log_prob(mean + std * z) - tf.math.log(1. - action_0 ** 2 + epsilon)
@@ -80,11 +80,14 @@ def evaluate(self, state, epsilon=1e-6):
8080

8181
def get_action(self, state):
8282
""" generate action with state for interaction with envronment """
83-
return self.policy_net(np.array([state])).numpy()[0]
83+
action, _, _, _, _ = self.evaluate(np.array([state]))
84+
return action.numpy()[0]
8485

8586
def get_action_greedy(self, state):
8687
""" generate action with state for interaction with envronment """
87-
return self.policy_net(np.array([state]), greedy=True).numpy()[0]
88+
mean = self.policy_net(np.array([state]), greedy=True).numpy()[0]
89+
action = tf.math.tanh(mean) * self.policy_net.policy_dist.action_scale + self.policy_net.policy_dist.action_mean
90+
return action
8891

8992
def sample_action(self, ):
9093
""" generate random actions for exploration """

0 commit comments

Comments (0)