Skip to content

Commit 3fb5210

Browse files
Fix bug in SAC: pass output_activation=None to the StochasticPolicyNetwork in every default config
1 parent 0840a82 commit 3fb5210

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

rlzoo/algorithms/sac/default.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ def classic_control(env, default_seed=True):
55 55
with tf.name_scope('Policy'):
56 56
policy_net = StochasticPolicyNetwork(env.observation_space, env.action_space,
57 57
hidden_dim_list=num_hidden_layer * [hidden_dim],
58 +
output_activation=None,
58 59
state_conditioned=True)
59 60
net_list = [soft_q_net1, soft_q_net2, target_soft_q_net1, target_soft_q_net2, policy_net]
60 61
alg_params['net_list'] = net_list
@@ -110,6 +111,7 @@ def box2d(env, default_seed=True):
110 111
with tf.name_scope('Policy'):
111 112
policy_net = StochasticPolicyNetwork(env.observation_space, env.action_space,
112 113
hidden_dim_list=num_hidden_layer * [hidden_dim],
114 +
output_activation=None,
113 115
state_conditioned=True)
114 116
net_list = [soft_q_net1, soft_q_net2, target_soft_q_net1, target_soft_q_net2, policy_net]
115 117
alg_params['net_list'] = net_list
@@ -165,6 +167,7 @@ def mujoco(env, default_seed=True):
165 167
with tf.name_scope('Policy'):
166 168
policy_net = StochasticPolicyNetwork(env.observation_space, env.action_space,
167 169
hidden_dim_list=num_hidden_layer * [hidden_dim],
170 +
output_activation=None,
168 171
state_conditioned=True)
169 172
net_list = [soft_q_net1, soft_q_net2, target_soft_q_net1, target_soft_q_net2, policy_net]
170 173
alg_params['net_list'] = net_list
@@ -220,6 +223,7 @@ def robotics(env, default_seed=True):
220 223
with tf.name_scope('Policy'):
221 224
policy_net = StochasticPolicyNetwork(env.observation_space, env.action_space,
222 225
hidden_dim_list=num_hidden_layer * [hidden_dim],
226 +
output_activation=None,
223 227
state_conditioned=True)
224 228
net_list = [soft_q_net1, soft_q_net2, target_soft_q_net1, target_soft_q_net2, policy_net]
225 229
alg_params['net_list'] = net_list
@@ -275,6 +279,7 @@ def dm_control(env, default_seed=True):
275 279
with tf.name_scope('Policy'):
276 280
policy_net = StochasticPolicyNetwork(env.observation_space, env.action_space,
277 281
hidden_dim_list=num_hidden_layer * [hidden_dim],
282 +
output_activation=None,
278 283
state_conditioned=True)
279 284
net_list = [soft_q_net1, soft_q_net2, target_soft_q_net1, target_soft_q_net2, policy_net]
280 285
alg_params['net_list'] = net_list
@@ -330,6 +335,7 @@ def rlbench(env, default_seed=True):
330 335
with tf.name_scope('Policy'):
331 336
policy_net = StochasticPolicyNetwork(env.observation_space, env.action_space,
332 337
hidden_dim_list=num_hidden_layer * [hidden_dim],
338 +
output_activation=None,
333 339
state_conditioned=True)
334 340
net_list = [soft_q_net1, soft_q_net2, target_soft_q_net1, target_soft_q_net2, policy_net]
335 341
alg_params['net_list'] = net_list

0 commit comments

Comments
 (0)