@@ -30,15 +30,20 @@ def __init__(self, cfg: dict) -> None:
 
     def calculate_action_prob(self, actions):
         num_action = self._action_space.n
-        N = actions.shape[0]
+        N = actions.shape[0]  # agent_num
         # Convert actions to one_hot encoding
         one_hot_actions = np.eye(num_action)[actions.flatten()]
         action_prob = np.zeros((N, num_action))
 
         for i in range(N):
-            # Exclude agent i's actions and calculate the one_hot average of all other agent actions
-            exclude_current = np.delete(one_hot_actions, i, axis=0)
-            action_prob[i] = exclude_current.mean(axis=0)
+            # Select only the one_hot actions of agents visible to agent i
+            visible_actions = one_hot_actions[self._env.agents[i].spin_mask == 1]
+            if visible_actions.size > 0:
+                # Calculate the average of the one_hot encoding for visible agents only
+                action_prob[i] = visible_actions.mean(axis=0)
+            else:
+                # If no visible agents, action_prob remains zero for agent i
+                action_prob[i] = np.zeros(num_action)
 
         return action_prob
 
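For context, a minimal standalone sketch of what the new per-agent averaging computes, assuming a toy 4-agent setup with 2 discrete actions and an illustrative visibility mask (in the diff the mask comes from `self._env.agents[i].spin_mask`; the values below are made up):

```python
import numpy as np

# Toy setting (assumption): 4 agents, 2 discrete actions.
actions = np.array([0, 1, 1, 0])
num_action = 2
one_hot_actions = np.eye(num_action)[actions.flatten()]  # shape (4, 2)

# Illustrative visibility mask for one agent: only agents 1 and 2 are visible.
spin_mask = np.array([0, 1, 1, 0])
visible_actions = one_hot_actions[spin_mask == 1]
if visible_actions.size > 0:
    action_prob_i = visible_actions.mean(axis=0)  # -> [0., 1.]: both visible agents chose action 1
else:
    action_prob_i = np.zeros(num_action)          # no visible agents -> all-zero row
```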
@@ -62,10 +67,11 @@ def reset(self) -> np.ndarray:
         obs = self._env._reset()
         obs = np.stack(obs)
         self.pre_action = np.zeros(self._cfg.num_agents, dtype=np.int32)
-        pre_action_prob = np.zeros((self._cfg.num_agents, self._action_space.n))
+        # Treat the last global state as the previous actions when computing pre_action_prob
+        pre_action_prob = self.calculate_action_prob(self._env.world.global_state.flatten().astype(int))
         obs = np.concatenate([obs, pre_action_prob], axis=1)
         obs = to_ndarray(obs).astype(np.float32)
-        self._eval_episode_return = np.zeros((self._cfg.num_agents, 1), dtype=np.float32)
+        self._eval_episode_return = 0
         return obs
 
     def close(self) -> None:
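A hedged note on the new `reset()` behaviour: assuming `self._env.world.global_state` holds one 0/1 spin value per agent (an assumption about this environment, not verified here), flattening it yields an action-like vector that `calculate_action_prob` can average over each agent's visible neighbours, instead of starting from an all-zero probability:

```python
import numpy as np

# Assumed shape/semantics: a 2x2 lattice of 0/1 spins, one per agent (illustrative only).
global_state = np.array([[1, 0],
                         [1, 1]])
pre_actions = global_state.flatten().astype(int)  # -> array([1, 0, 1, 1])
# This vector is then passed to calculate_action_prob(pre_actions), so the initial
# observation already encodes the visible neighbours' spins rather than zeros.
```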
@@ -90,8 +96,9 @@ def step(self, action: Union[np.ndarray, list]) -> BaseEnvTimestep:
         obs = np.concatenate([obs, pre_action_prob], axis=1)
         obs = to_ndarray(obs).astype(np.float32)
         rew = np.stack(rew)
-        rew = to_ndarray(rew).astype(np.float32)
-        self._eval_episode_return += rew
+        rew = np.squeeze(to_ndarray(rew).astype(np.float32), axis=1)
+        # rew = to_ndarray(rew).astype(np.float32)
+        self._eval_episode_return += np.sum(rew)
 
         done = done[0]  # dones are the same for all agents
         if done:
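A quick sanity check of the reward handling in `step()`, under the assumption that the wrapped env returns one reward per agent so the stacked array has shape `(num_agents, 1)` (the numbers are illustrative):

```python
import numpy as np

# Assumed per-agent rewards after np.stack: shape (num_agents, 1).
rew = np.array([[0.5], [-0.25], [1.0]], dtype=np.float32)
rew = np.squeeze(rew, axis=1)        # -> shape (3,), one scalar per agent
eval_episode_return = 0
eval_episode_return += np.sum(rew)   # episode return tracked as a single scalar
print(rew, eval_episode_return)      # [ 0.5  -0.25  1.  ] 1.25
```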