@@ -28,6 +28,20 @@ def __init__(self, cfg: dict) -> None:
         self._observation_space = gym.spaces.MultiBinary(4 * cfg.agent_view_sight)
         self._reward_space = gym.spaces.Box(low=float("-inf"), high=float("inf"), shape=(1, ), dtype=np.float32)
 
+    def calculate_action_prob(self, actions):
+        num_action = self._action_space.n
+        N = actions.shape[0]
+        # Convert actions to one-hot encodings
+        one_hot_actions = np.eye(num_action)[actions.flatten()]
+        action_prob = np.zeros((N, num_action))
+
+        for i in range(N):
+            # Exclude agent i's own action and average the one-hot actions of all other agents
+            exclude_current = np.delete(one_hot_actions, i, axis=0)
+            action_prob[i] = exclude_current.mean(axis=0)
+
+        return action_prob
+
     def reset(self) -> np.ndarray:
         if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed:
             np_seed = 100 * np.random.randint(1, 1000)
@@ -47,6 +61,9 @@ def reset(self) -> np.ndarray:
             self._init_flag = True
         obs = self._env._reset()
         obs = np.stack(obs)
+        self.pre_action = np.zeros(self._cfg.num_agents, dtype=np.int32)
+        pre_action_prob = np.zeros((self._cfg.num_agents, self._action_space.n))
+        obs = np.concatenate([obs, pre_action_prob], axis=1)
         obs = to_ndarray(obs).astype(np.float32)
         self._eval_episode_return = np.zeros((self._cfg.num_agents, 1), dtype=np.float32)
         return obs
@@ -63,13 +80,20 @@ def seed(self, seed: int, dynamic_seed: bool = True) -> None:
 
     def step(self, action: Union[np.ndarray, list]) -> BaseEnvTimestep:
         action = to_ndarray(action)
+        if len(action.shape) == 1:
+            action = np.expand_dims(action, axis=1)
         obs, rew, done, order_param, ups, downs = self._env._step(action)
-        info = {"order_param": order_param, "ups": ups, "downs": downs}
+        info = {"order_param": order_param, "ups": ups, "downs": downs, 'pre_action': self.pre_action}
+        pre_action_prob = self.calculate_action_prob(self.pre_action)
+        self.pre_action = action
         obs = np.stack(obs)
+        obs = np.concatenate([obs, pre_action_prob], axis=1)
         obs = to_ndarray(obs).astype(np.float32)
         rew = np.stack(rew)
         rew = to_ndarray(rew).astype(np.float32)
         self._eval_episode_return += rew
+
+        done = done[0]  # dones are the same for all agents
         if done:
             info['eval_episode_return'] = self._eval_episode_return
         return BaseEnvTimestep(obs, rew, done, info)
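For reference, the snippet below is a minimal standalone sketch of what the new `calculate_action_prob` helper computes: for each agent it returns the empirical distribution of the other agents' previous actions, which is then concatenated onto that agent's observation. The agent count, action count, and action values here are illustrative, not taken from the environment config.

```python
import numpy as np

# Hypothetical setup: 3 agents, 2 discrete actions (not from cfg).
num_action = 2
pre_action = np.array([0, 1, 1], dtype=np.int32)  # previous action of each agent

# Same steps as calculate_action_prob: one-hot encode, then average over the other agents.
one_hot = np.eye(num_action)[pre_action.flatten()]        # shape (3, 2)
action_prob = np.zeros((pre_action.shape[0], num_action))
for i in range(pre_action.shape[0]):
    others = np.delete(one_hot, i, axis=0)                 # drop agent i's own action
    action_prob[i] = others.mean(axis=0)                   # action distribution of the remaining agents

print(action_prob)
# [[0.  1. ]    agent 0: both other agents chose action 1
#  [0.5 0.5]    agent 1: the others chose actions 0 and 1
#  [0.5 0.5]]   agent 2: the others chose actions 0 and 1
```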