
Commit 3d18513

Author: nighood
Commit message: fix(rjy): try to fix reward problem
Parent: ad8c53c

File tree: 5 files changed, +29 -15 lines


ding/model/common/head.py (+1)

@@ -61,6 +61,7 @@ def __init__(
                 norm_type=norm_type
             ), block(hidden_size, output_size)
         )
+        nn.init.normal_(self.Q[1].weight, 0, 0.2)
 
     def forward(self, x: torch.Tensor) -> Dict:
         """

ding/policy/common_utils.py (+3, -1)

@@ -62,7 +62,9 @@ def default_preprocess_learn(
         if len(reward.shape) == 1:
             reward = reward.unsqueeze(1)
         # reward: (batch_size, nstep) -> (nstep, batch_size)
-        data['reward'] = reward.permute(1, 0).contiguous()
+        # reversed_shape = [i for i in range(len(reward.shape))][::-1]
+        # data['reward'] = reward.permute(reversed_shape).contiguous()
+        data['reward'] = reward.transpose(0, -1).contiguous()
     else:
         if data['reward'].dim() == 2 and data['reward'].shape[1] == 1:
             data['reward'] = data['reward'].squeeze(-1)
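
The switch from permute(1, 0) to transpose(0, -1) still moves the n-step axis to the front for a 2-D reward, but also tolerates an extra agent dimension, which permute(1, 0) would reject. A shape-only sketch with assumed sizes (batch 32, 100 agents, 5-step return):

import torch

single = torch.zeros(32, 5)      # (batch_size, nstep)
multi = torch.zeros(32, 100, 5)  # (batch_size, agent_num, nstep)

print(single.permute(1, 0).shape)     # torch.Size([5, 32]) -- old code, 2-D only
print(single.transpose(0, -1).shape)  # torch.Size([5, 32]) -- same result in 2-D
print(multi.transpose(0, -1).shape)   # torch.Size([5, 100, 32]) -- nstep moved to the front
# multi.permute(1, 0) would raise an error because permute must name all 3 dims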

ding/rl_utils/adder.py (+9, -4)

@@ -121,7 +121,7 @@ def get_nstep_return_data(
         """
         if nstep == 1:
             return data
-        fake_reward = torch.zeros(1)
+        fake_reward = torch.zeros_like(data[0]['reward'])
         next_obs_flag = 'next_obs' in data[0]
         for i in range(len(data) - nstep):
             # update keys ['next_obs', 'reward', 'done'] with their n-step value
@@ -130,7 +130,7 @@ def get_nstep_return_data(
             if cum_reward:
                 data[i]['reward'] = sum([data[i + j]['reward'] * (gamma ** j) for j in range(nstep)])
             else:
-                data[i]['reward'] = torch.cat([data[i + j]['reward'] for j in range(nstep)])
+                data[i]['reward'] = torch.stack([data[i + j]['reward'] for j in range(nstep)], dim = -1)
             data[i]['done'] = data[i + nstep - 1]['done']
             if correct_terminate_gamma:
                 data[i]['value_gamma'] = gamma ** nstep
@@ -140,10 +140,15 @@ def get_nstep_return_data(
             if cum_reward:
                 data[i]['reward'] = sum([data[i + j]['reward'] * (gamma ** j) for j in range(len(data) - i)])
             else:
-                data[i]['reward'] = torch.cat(
+                data[i]['reward'] = torch.stack(
                     [data[i + j]['reward']
-                     for j in range(len(data) - i)] + [fake_reward for _ in range(nstep - (len(data) - i))]
+                     for j in range(len(data) - i)] + [fake_reward for _ in range(nstep - (len(data) - i))],
+                    dim = -1
                 )
+                try:
+                    assert len(data[i]['reward']) == 300
+                except:
+                    print(len(data[i]['reward']))
             data[i]['done'] = data[-1]['done']
             if correct_terminate_gamma:
                 data[i]['value_gamma'] = gamma ** (len(data) - i - 1)
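
Replacing torch.cat with torch.stack(..., dim=-1) keeps per-agent rewards on their own axis instead of fusing agents and steps into one vector, and torch.zeros_like makes the padding reward match the per-step reward's shape. A sketch with assumed sizes (100 agents, 3-step return):

import torch

nstep = 3
# assumed per-step reward carrying one entry per agent (100 agents)
per_step = [torch.full((100,), float(j)) for j in range(nstep)]

print(torch.cat(per_step).shape)            # torch.Size([300]) -- agents and steps fused
print(torch.stack(per_step, dim=-1).shape)  # torch.Size([100, 3]) -- (agent_num, nstep)

# zeros_like gives the padding reward the same per-agent shape as a real reward
fake_reward = torch.zeros_like(per_step[0])
print(fake_reward.shape)                    # torch.Size([100])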

dizoo/ising_env/config/ising_mfq_config.py (+1, -2)

@@ -12,7 +12,6 @@
         collector_env_num=8,
         evaluator_env_num=8,
         n_evaluator_episode=8,
-        stop_value=20,
         num_agents=num_agents,
         dim_spin=dim_spin,
         agent_view_sight=agent_view_sight,
@@ -62,4 +61,4 @@
 if __name__ == '__main__':
     # or you can enter `ding -m serial -c ising_mfq_config.py -s 0`
     from ding.entry import serial_pipeline
-    serial_pipeline((main_config, create_config), seed=0)
+    serial_pipeline((main_config, create_config), seed=0, max_env_step=1e5)

dizoo/ising_env/envs/ising_model_env.py (+15, -8)

@@ -30,15 +30,20 @@ def __init__(self, cfg: dict) -> None:
 
     def calculate_action_prob(self, actions):
         num_action = self._action_space.n
-        N = actions.shape[0]
+        N = actions.shape[0]  # agent_num
         # Convert actions to one_hot encoding
         one_hot_actions = np.eye(num_action)[actions.flatten()]
         action_prob = np.zeros((N, num_action))
 
         for i in range(N):
-            # Exclude agent i's actions and calculate the one_hot average of all other agent actions
-            exclude_current = np.delete(one_hot_actions, i, axis=0)
-            action_prob[i] = exclude_current.mean(axis=0)
+            # Select only the one_hot actions of agents visible to agent i
+            visible_actions = one_hot_actions[self._env.agents[i].spin_mask == 1]
+            if visible_actions.size > 0:
+                # Calculate the average of the one_hot encoding for visible agents only
+                action_prob[i] = visible_actions.mean(axis=0)
+            else:
+                # If no visible agents, action_prob remains zero for agent i
+                action_prob[i] = np.zeros(num_action)
 
         return action_prob
 
@@ -62,10 +67,11 @@ def reset(self) -> np.ndarray:
         obs = self._env._reset()
         obs = np.stack(obs)
         self.pre_action = np.zeros(self._cfg.num_agents, dtype=np.int32)
-        pre_action_prob = np.zeros((self._cfg.num_agents, self._action_space.n))
+        # consider the last global state as pre action prob
+        pre_action_prob = self.calculate_action_prob(self._env.world.global_state.flatten().astype(int))
         obs = np.concatenate([obs, pre_action_prob], axis=1)
         obs = to_ndarray(obs).astype(np.float32)
-        self._eval_episode_return = np.zeros((self._cfg.num_agents, 1), dtype=np.float32)
+        self._eval_episode_return = 0
         return obs
 
     def close(self) -> None:
@@ -90,8 +96,9 @@ def step(self, action: Union[np.ndarray, list]) -> BaseEnvTimestep:
         obs = np.concatenate([obs, pre_action_prob], axis=1)
         obs = to_ndarray(obs).astype(np.float32)
         rew = np.stack(rew)
-        rew = to_ndarray(rew).astype(np.float32)
-        self._eval_episode_return += rew
+        rew = np.squeeze(to_ndarray(rew).astype(np.float32), axis=1)
+        # rew = to_ndarray(rew).astype(np.float32)
+        self._eval_episode_return += np.sum(rew)
 
         done = done[0]  # dones are the same for all agents
         if done:
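
The rewritten calculate_action_prob averages one-hot actions over only the agents visible to agent i (its spin_mask) instead of over all other agents, falling back to zeros when nothing is visible. A standalone numpy sketch of that computation; the 4-agent visibility matrix below is made up for illustration:

import numpy as np

num_action = 2
actions = np.array([0, 1, 1, 0])   # one action per agent
visibility = np.array([            # stand-in for agents[i].spin_mask, one row per agent
    [0, 1, 1, 0],
    [1, 0, 0, 1],
    [1, 1, 0, 1],
    [0, 0, 0, 0],                  # agent 3 sees nobody
])

one_hot = np.eye(num_action)[actions]
action_prob = np.zeros((len(actions), num_action))
for i in range(len(actions)):
    visible = one_hot[visibility[i] == 1]
    action_prob[i] = visible.mean(axis=0) if visible.size > 0 else np.zeros(num_action)

print(action_prob)
# (rounded) [[0.   1.  ]   agent 0 sees agents 1 and 2, both chose action 1
#            [1.   0.  ]   agent 1 sees agents 0 and 3, both chose action 0
#            [0.67 0.33]   agent 2 sees agents 0, 1 and 3; two of them chose action 0
#            [0.   0.  ]]  agent 3 sees nobody, so its row stays zero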
