Skip to content

Commit

Permalink
fix(rjy): fixed reward compatibility
Browse files Browse the repository at this point in the history
  • Loading branch information
nighood committed Apr 22, 2024
1 parent 13d04e6 commit 3956ae3
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 3 deletions.
7 changes: 5 additions & 2 deletions ding/rl_utils/adder.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,10 @@ def get_nstep_return_data(
if cum_reward:
data[i]['reward'] = sum([data[i + j]['reward'] * (gamma ** j) for j in range(nstep)])
else:
data[i]['reward'] = torch.stack([data[i + j]['reward'] for j in range(nstep)], dim=-1)
# data[i]['reward'].shape = (1) or (agent_num, 1)
# single agent env: shape (1) -> (n_step)
# multi-agent env: shape (agent_num, 1) -> (agent_num, n_step)
data[i]['reward'] = torch.cat([data[i + j]['reward'] for j in range(nstep)], dim=-1)
data[i]['done'] = data[i + nstep - 1]['done']
if correct_terminate_gamma:
data[i]['value_gamma'] = gamma ** nstep
Expand All @@ -140,7 +143,7 @@ def get_nstep_return_data(
if cum_reward:
data[i]['reward'] = sum([data[i + j]['reward'] * (gamma ** j) for j in range(len(data) - i)])
else:
data[i]['reward'] = torch.stack(
data[i]['reward'] = torch.cat(
[data[i + j]['reward']
for j in range(len(data) - i)] + [fake_reward for _ in range(nstep - (len(data) - i))],
dim=-1
Expand Down
2 changes: 1 addition & 1 deletion dizoo/ising_env/envs/ising_model_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def step(self, action: np.ndarray) -> BaseEnvTimestep:
obs = np.stack(obs)
obs = np.concatenate([obs, pre_action_prob], axis=1)
obs = to_ndarray(obs).astype(np.float32)
rew = np.concatenate(rew)
rew = np.stack(rew)
self._eval_episode_return += np.sum(rew)
self.cur_step += 1

Expand Down

0 comments on commit 3956ae3

Please sign in to comment.