@@ -65,7 +65,7 @@ def evaluate(self, state, epsilon=1e-6):
         std = tf.math.exp(log_std)  # no clip in evaluation, clip affects gradients flow
 
         normal = Normal(0, 1)
-        z = normal.sample()
+        z = normal.sample(mean.shape)
         action_0 = tf.math.tanh(mean + std * z)  # TanhNormal distribution as actions; reparameterization trick
         # according to original paper, with an extra last term for normalizing different action range
         log_prob = Normal(mean, std).log_prob(mean + std * z) - tf.math.log(1. - action_0 ** 2 + epsilon)
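For context, here is a minimal sketch of what the sampling path in this hunk computes, assuming TensorFlow Probability's `Normal` (the repository may wrap its own distribution class); `squashed_sample` and its toy inputs are hypothetical names used only for illustration.

```python
import tensorflow as tf
import tensorflow_probability as tfp

def squashed_sample(mean, log_std, epsilon=1e-6):
    std = tf.math.exp(log_std)
    normal = tfp.distributions.Normal(0.0, 1.0)
    # sample() with no shape draws a single scalar noise value; passing mean.shape
    # draws independent noise per batch entry and per action dimension (the fix above)
    z = normal.sample(mean.shape)
    raw_action = mean + std * z            # reparameterized Gaussian sample
    action = tf.math.tanh(raw_action)      # squash into (-1, 1)
    # change-of-variables correction for the tanh squashing; in SAC the
    # per-dimension log-probs are typically summed over the action axis later
    log_prob = (tfp.distributions.Normal(mean, std).log_prob(raw_action)
                - tf.math.log(1.0 - action ** 2 + epsilon))
    return action, log_prob

# toy usage: batch of 1, action dimension of 2
a, lp = squashed_sample(tf.zeros((1, 2)), tf.fill((1, 2), -1.0))
```

The practical effect of the one-line change is that exploration noise is no longer a single scalar broadcast across every action dimension, but an independent draw per dimension.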
@@ -80,11 +80,14 @@ def evaluate(self, state, epsilon=1e-6):
 
     def get_action(self, state):
         """ generate action with state for interaction with environment """
-        return self.policy_net(np.array([state])).numpy()[0]
+        action, _, _, _, _ = self.evaluate(np.array([state]))
+        return action.numpy()[0]
 
     def get_action_greedy(self, state):
         """ generate action with state for interaction with environment """
-        return self.policy_net(np.array([state]), greedy=True).numpy()[0]
+        mean = self.policy_net(np.array([state]), greedy=True).numpy()[0]
+        action = tf.math.tanh(mean) * self.policy_net.policy_dist.action_scale + self.policy_net.policy_dist.action_mean
+        return action
 
     def sample_action(self, ):
         """ generate random actions for exploration """
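A hedged sketch of the rescaling that `get_action_greedy` now applies, assuming `action_scale` and `action_mean` are derived from the environment's action bounds as `(high - low) / 2` and `(high + low) / 2`; the 1-D bounds below are illustrative and not taken from the repo.

```python
import numpy as np
import tensorflow as tf

# Illustrative bounds (assumption): a single action clipped to [-2, 2], e.g. a torque limit.
low = np.array([-2.0], dtype=np.float32)
high = np.array([2.0], dtype=np.float32)
action_scale = (high - low) / 2.0   # half-range: maps (-1, 1) onto the bound width
action_mean = (high + low) / 2.0    # centre of the bounds

mean = tf.constant([0.3])           # deterministic (greedy) policy output before squashing
greedy_action = tf.math.tanh(mean) * action_scale + action_mean
print(greedy_action.numpy())        # tanh squashes into (-1, 1), then rescales into [low, high]
```

Rescaling after the tanh is what lets the greedy action land inside the environment's actual action range rather than the raw (-1, 1) interval.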