Skip to content

Commit c7c3bac

Browse files
committed
fix(nyz): fix marl nstep td compatibility bug
1 parent 8392206 commit c7c3bac

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

ding/policy/madqn.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ def _forward_learn(self, data: dict) -> Dict[str, Any]:
198198
td_error_per_sample = []
199199
for t in range(self._cfg.collect.unroll_len):
200200
v_data = v_nstep_td_data(
201-
total_q[t], target_total_q[t], data['reward'][t], data['done'][t], data['weight'], self._gamma
201+
total_q[t], target_total_q[t], data['reward'][t], data['done'][t], data['weight'], None
202202
)
203203
# calculate v_nstep_td critic_loss
204204
loss_i, td_error_per_sample_i = v_nstep_td_error(v_data, self._gamma, self._nstep)
@@ -231,8 +231,12 @@ def _forward_learn(self, data: dict) -> Dict[str, Any]:
231231
cooperation_loss_all = []
232232
for t in range(self._cfg.collect.unroll_len):
233233
v_data = v_nstep_td_data(
234-
cooperation_total_q[t], cooperation_target_total_q[t], data['reward'][t], data['done'][t],
235-
data['weight'], self._gamma
234+
cooperation_total_q[t],
235+
cooperation_target_total_q[t],
236+
data['reward'][t],
237+
data['done'][t],
238+
data['weight'],
239+
None,
236240
)
237241
cooperation_loss, _ = v_nstep_td_error(v_data, self._gamma, self._nstep)
238242
cooperation_loss_all.append(cooperation_loss)

0 commit comments

Comments
 (0)