@@ -198,7 +198,7 @@ def _forward_learn(self, data: dict) -> Dict[str, Any]:
198
198
td_error_per_sample = []
199
199
for t in range(self._cfg.collect.unroll_len):
200
200
v_data = v_nstep_td_data(
201
- total_q[t], target_total_q[t], data['reward'][t], data['done'][t], data['weight'], self._gamma
201
+ total_q[t], target_total_q[t], data['reward'][t], data['done'][t], data['weight'], None
202
202
)
203
203
# calculate v_nstep_td critic_loss
204
204
loss_i, td_error_per_sample_i = v_nstep_td_error(v_data, self._gamma, self._nstep)
@@ -231,8 +231,12 @@ def _forward_learn(self, data: dict) -> Dict[str, Any]:
231
231
cooperation_loss_all = []
232
232
for t in range(self._cfg.collect.unroll_len):
233
233
v_data = v_nstep_td_data(
234
- cooperation_total_q[t], cooperation_target_total_q[t], data['reward'][t], data['done'][t],
235
- data['weight'], self._gamma
234
+ cooperation_total_q[t],
235
+ cooperation_target_total_q[t],
236
+ data['reward'][t],
237
+ data['done'][t],
238
+ data['weight'],
239
+ None,
236
240
)
237
241
cooperation_loss, _ = v_nstep_td_error(v_data, self._gamma, self._nstep)
238
242
cooperation_loss_all.append(cooperation_loss)
0 commit comments