From 3956ae32636acbbf0c30e6a636e9563178397f65 Mon Sep 17 00:00:00 2001
From: nighood
Date: Mon, 22 Apr 2024 17:18:14 +0800
Subject: [PATCH] fix(rjy): fixed reward compatibility

---
 ding/rl_utils/adder.py                  | 7 +++++--
 dizoo/ising_env/envs/ising_model_env.py | 2 +-
 2 files changed, 6 insertions(+), 3 deletions(-)

diff --git a/ding/rl_utils/adder.py b/ding/rl_utils/adder.py
index f4faa572a8..b050c097fd 100644
--- a/ding/rl_utils/adder.py
+++ b/ding/rl_utils/adder.py
@@ -130,7 +130,10 @@ def get_nstep_return_data(
             if cum_reward:
                 data[i]['reward'] = sum([data[i + j]['reward'] * (gamma ** j) for j in range(nstep)])
             else:
-                data[i]['reward'] = torch.stack([data[i + j]['reward'] for j in range(nstep)], dim=-1)
+                # data[i]['reward'].shape = (1, ) or (agent_num, 1)
+                # single-agent env: shape (1, ) -> (n_step, )
+                # multi-agent env: shape (agent_num, 1) -> (agent_num, n_step)
+                data[i]['reward'] = torch.cat([data[i + j]['reward'] for j in range(nstep)], dim=-1)
             data[i]['done'] = data[i + nstep - 1]['done']
             if correct_terminate_gamma:
                 data[i]['value_gamma'] = gamma ** nstep
@@ -140,7 +143,7 @@ def get_nstep_return_data(
             if cum_reward:
                 data[i]['reward'] = sum([data[i + j]['reward'] * (gamma ** j) for j in range(len(data) - i)])
             else:
-                data[i]['reward'] = torch.stack(
+                data[i]['reward'] = torch.cat(
                     [data[i + j]['reward'] for j in range(len(data) - i)] +
                     [fake_reward for _ in range(nstep - (len(data) - i))],
                     dim=-1
diff --git a/dizoo/ising_env/envs/ising_model_env.py b/dizoo/ising_env/envs/ising_model_env.py
index 731d2a8c06..6de7c6ea69 100644
--- a/dizoo/ising_env/envs/ising_model_env.py
+++ b/dizoo/ising_env/envs/ising_model_env.py
@@ -103,7 +103,7 @@ def step(self, action: np.ndarray) -> BaseEnvTimestep:
         obs = np.stack(obs)
         obs = np.concatenate([obs, pre_action_prob], axis=1)
         obs = to_ndarray(obs).astype(np.float32)
-        rew = np.concatenate(rew)
+        rew = np.stack(rew)
         self._eval_episode_return += np.sum(rew)
         self.cur_step += 1
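
Why `torch.cat` rather than `torch.stack` in the adder: per the new comments, each transition's reward enters `get_nstep_return_data` with shape (1, ) in single-agent envs or (agent_num, 1) in multi-agent envs. `cat` along dim=-1 collapses the trailing singleton axis into (n_step, ) or (agent_num, n_step), whereas `stack` along dim=-1 inserts a new axis and yields (1, n_step) or (agent_num, 1, n_step). A minimal sketch of the shape difference (the input shapes follow the patch comments; nstep=3 and agent_num=2 are arbitrary example values):

    import torch

    nstep, agent_num = 3, 2

    # Single-agent env: each per-step reward has shape (1, ).
    single = [torch.tensor([0.5]) for _ in range(nstep)]
    print(torch.cat(single, dim=-1).shape)    # torch.Size([3])     -> (n_step, )
    print(torch.stack(single, dim=-1).shape)  # torch.Size([1, 3])  -> spurious leading axis

    # Multi-agent env: each per-step reward has shape (agent_num, 1).
    multi = [torch.ones(agent_num, 1) for _ in range(nstep)]
    print(torch.cat(multi, dim=-1).shape)     # torch.Size([2, 3])     -> (agent_num, n_step)
    print(torch.stack(multi, dim=-1).shape)   # torch.Size([2, 1, 3])  -> spurious middle axis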
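
On the env side, `IsingModelEnv.step` collects one reward per agent into the list `rew`. Assuming each element is a shape-(1, ) array, which matches the (agent_num, 1) per-step convention the adder now documents, `np.concatenate` would flatten the list to (agent_num, ) and drop the trailing axis, while `np.stack` preserves it. A small sketch under that assumption (agent_num=4 is an arbitrary example value):

    import numpy as np

    agent_num = 4
    # Assumed: one shape-(1, ) reward array per agent.
    rew = [np.array([1.0]) for _ in range(agent_num)]

    print(np.concatenate(rew).shape)  # (4,)   -> trailing axis lost
    print(np.stack(rew).shape)        # (4, 1) -> (agent_num, 1), the shape the adder expects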