@@ -55,6 +55,7 @@ def classic_control(env, default_seed=True):
5555 with tf .name_scope ('Policy' ):
5656 policy_net = StochasticPolicyNetwork (env .observation_space , env .action_space ,
5757 hidden_dim_list = num_hidden_layer * [hidden_dim ],
58+ output_activation = None ,
5859 state_conditioned = True )
5960 net_list = [soft_q_net1 , soft_q_net2 , target_soft_q_net1 , target_soft_q_net2 , policy_net ]
6061 alg_params ['net_list' ] = net_list
@@ -110,6 +111,7 @@ def box2d(env, default_seed=True):
110111 with tf .name_scope ('Policy' ):
111112 policy_net = StochasticPolicyNetwork (env .observation_space , env .action_space ,
112113 hidden_dim_list = num_hidden_layer * [hidden_dim ],
114+ output_activation = None ,
113115 state_conditioned = True )
114116 net_list = [soft_q_net1 , soft_q_net2 , target_soft_q_net1 , target_soft_q_net2 , policy_net ]
115117 alg_params ['net_list' ] = net_list
@@ -165,6 +167,7 @@ def mujoco(env, default_seed=True):
165167 with tf .name_scope ('Policy' ):
166168 policy_net = StochasticPolicyNetwork (env .observation_space , env .action_space ,
167169 hidden_dim_list = num_hidden_layer * [hidden_dim ],
170+ output_activation = None ,
168171 state_conditioned = True )
169172 net_list = [soft_q_net1 , soft_q_net2 , target_soft_q_net1 , target_soft_q_net2 , policy_net ]
170173 alg_params ['net_list' ] = net_list
@@ -220,6 +223,7 @@ def robotics(env, default_seed=True):
220223 with tf .name_scope ('Policy' ):
221224 policy_net = StochasticPolicyNetwork (env .observation_space , env .action_space ,
222225 hidden_dim_list = num_hidden_layer * [hidden_dim ],
226+ output_activation = None ,
223227 state_conditioned = True )
224228 net_list = [soft_q_net1 , soft_q_net2 , target_soft_q_net1 , target_soft_q_net2 , policy_net ]
225229 alg_params ['net_list' ] = net_list
@@ -275,6 +279,7 @@ def dm_control(env, default_seed=True):
275279 with tf .name_scope ('Policy' ):
276280 policy_net = StochasticPolicyNetwork (env .observation_space , env .action_space ,
277281 hidden_dim_list = num_hidden_layer * [hidden_dim ],
282+ output_activation = None ,
278283 state_conditioned = True )
279284 net_list = [soft_q_net1 , soft_q_net2 , target_soft_q_net1 , target_soft_q_net2 , policy_net ]
280285 alg_params ['net_list' ] = net_list
@@ -330,6 +335,7 @@ def rlbench(env, default_seed=True):
330335 with tf .name_scope ('Policy' ):
331336 policy_net = StochasticPolicyNetwork (env .observation_space , env .action_space ,
332337 hidden_dim_list = num_hidden_layer * [hidden_dim ],
338+ output_activation = None ,
333339 state_conditioned = True )
334340 net_list = [soft_q_net1 , soft_q_net2 , target_soft_q_net1 , target_soft_q_net2 , policy_net ]
335341 alg_params ['net_list' ] = net_list
0 commit comments