@@ -147,14 +147,36 @@ def __init__(
         embedding_size = hidden_size_list[-1]
         self.mixer = mixer
         if self.mixer:
-            if len(global_obs_shape) == 1:
+            global_obs_shape_type = self._get_global_obs_shape_type(global_obs_shape)
+
+            if global_obs_shape_type == "flat":
                 self._mixer = Mixer(agent_num, global_obs_shape, embedding_size, activation=activation)
                 self._global_state_encoder = nn.Identity()
-            elif len(global_obs_shape) == 3:
+            elif global_obs_shape_type == "image":
                 self._mixer = Mixer(agent_num, embedding_size, embedding_size, activation=activation)
-                self._global_state_encoder = ConvEncoder(global_obs_shape, hidden_size_list=hidden_size_list, activation=activation, norm_type='BN')
+                self._global_state_encoder = ConvEncoder(
+                    global_obs_shape, hidden_size_list=hidden_size_list, activation=activation, norm_type='BN'
+                )
             else:
-                raise ValueError("Not support global_obs_shape: {}".format(global_obs_shape))
+                raise ValueError(f"Unsupported global_obs_shape: {global_obs_shape}")
+
+    def _get_global_obs_shape_type(self, global_obs_shape: Union[int, List[int]]) -> str:
+        """
+        Overview:
+            Determine the type of the global observation shape.
+
+        Arguments:
+            - global_obs_shape (:obj:`int` or :obj:`List[int]`): The shape of the global observation.
+
+        Returns:
+            - str: 'flat' for a 1D observation or 'image' for a 3D observation.
+        """
+        if isinstance(global_obs_shape, int) or (isinstance(global_obs_shape, list) and len(global_obs_shape) == 1):
+            return "flat"
+        elif isinstance(global_obs_shape, list) and len(global_obs_shape) == 3:
+            return "image"
+        else:
+            raise ValueError(f"Unsupported global_obs_shape: {global_obs_shape}")
 
     def forward(self, data: dict, single_step: bool = True) -> dict:
         """
@@ -214,18 +236,37 @@ def forward(self, data: dict, single_step: bool = True) -> dict:
         agent_q_act = torch.gather(agent_q, dim=-1, index=action.unsqueeze(-1))
         agent_q_act = agent_q_act.squeeze(-1)  # T, B, A
         if self.mixer:
-            if len(global_state.shape) == 5:
-                global_state_embedding = self._global_state_encoder(global_state.reshape(-1, *global_state.shape[-3:])).reshape(global_state.shape[0], global_state.shape[1], -1)
-            else:
-                global_state_embedding = self._global_state_encoder(global_state)
+            global_state_embedding = self._process_global_state(global_state)
             total_q = self._mixer(agent_q_act, global_state_embedding)
         else:
-            total_q = agent_q_act.sum(-1)
+            total_q = agent_q_act.sum(dim=-1)
+
         if single_step:
             total_q, agent_q = total_q.squeeze(0), agent_q.squeeze(0)
+
         return {
             'total_q': total_q,
             'logit': agent_q,
             'next_state': next_state,
             'action_mask': data['obs']['action_mask']
         }
+    def _process_global_state(self, global_state: torch.Tensor) -> torch.Tensor:
+        """
+        Process the global state to obtain an embedding.
+
+        Arguments:
+            - global_state (:obj:`torch.Tensor`): The global state tensor.
+
+        Returns:
+            - (:obj:`torch.Tensor`): The processed global state embedding.
+        """
+        # A 5-dimensional global_state is image-like with leading time and batch dims: [T, B, C, H, W]
+        if global_state.dim() == 5:
+            # Collapse time and batch so the encoder sees a plain [T*B, C, H, W] batch
+            time_batch_shape = global_state.shape[:2]  # [T, B]
+            reshaped_state = global_state.reshape(-1, *global_state.shape[-3:])
+            encoded_state = self._global_state_encoder(reshaped_state)
+            return encoded_state.reshape(*time_batch_shape, -1)  # Restore to [T, B, embedding_dim]
+        else:
+            # For flat (non-image) states, apply the encoder directly
+            return self._global_state_encoder(global_state)
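To see why `_process_global_state` reshapes twice: a convolutional encoder accepts a single batch dimension, so the time and batch dims are folded together before encoding and unfolded afterwards. A minimal sketch with a dummy `nn.Sequential` standing in for `ConvEncoder` (all sizes here are made up for illustration):

```python
import torch
import torch.nn as nn

# Stand-in for ConvEncoder: any module mapping [N, C, H, W] -> [N, embedding_dim]
T, B, C, H, W, emb = 4, 2, 3, 8, 8, 32
encoder = nn.Sequential(nn.Flatten(), nn.Linear(C * H * W, emb))

global_state = torch.randn(T, B, C, H, W)

# Same collapse / encode / restore pattern as _process_global_state
flat = global_state.reshape(-1, C, H, W)        # [T*B, C, H, W]
embedding = encoder(flat).reshape(T, B, -1)     # [T, B, emb]
assert embedding.shape == (T, B, emb)
```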