- from typing import Union, List
+ from functools import reduce
+ from typing import List, Union
+
import torch
import torch.nn as nn
import torch.nn.functional as F
- from functools import reduce
- from ding.utils import list_split, MODEL_REGISTRY
- from ding.torch_utils import fc_block, MLP
+ from ding.torch_utils import MLP, fc_block
+ from ding.utils import MODEL_REGISTRY, list_split
+
+ from ..common import ConvEncoder
from .q_learning import DRQN
@@ -111,7 +114,7 @@ def __init__(
        self,
        agent_num: int,
        obs_shape: int,
-       global_obs_shape: int,
+       global_obs_shape: Union[int, List[int]],
        action_shape: int,
        hidden_size_list: list,
        mixer: bool = True,
@@ -146,8 +149,34 @@ def __init__(
        embedding_size = hidden_size_list[-1]
        self.mixer = mixer
        if self.mixer:
-           self._mixer = Mixer(agent_num, global_obs_shape, embedding_size, activation=activation)
-           self._global_state_encoder = nn.Identity()
+           global_obs_shape_type = self._get_global_obs_shape_type(global_obs_shape)
+
+           if global_obs_shape_type == "flat":
+               self._mixer = Mixer(agent_num, global_obs_shape, embedding_size, activation=activation)
+               self._global_state_encoder = nn.Identity()
+           elif global_obs_shape_type == "image":
+               self._mixer = Mixer(agent_num, embedding_size, embedding_size, activation=activation)
+               self._global_state_encoder = ConvEncoder(
+                   global_obs_shape, hidden_size_list=hidden_size_list, activation=activation, norm_type='BN'
+               )
+           else:
+               raise ValueError(f"Unsupported global_obs_shape: {global_obs_shape}")
+
+   def _get_global_obs_shape_type(self, global_obs_shape: Union[int, List[int]]) -> str:
+       """
+       Overview:
+           Determine the type of the global observation shape.
+       Arguments:
+           - global_obs_shape (:obj:`Union[int, List[int]]`): The shape of the global observation.
+       Returns:
+           - obs_shape_type (:obj:`str`): 'flat' for a 1D observation shape or 'image' for a 3D (image) observation shape.
+       """
+       if isinstance(global_obs_shape, int) or (isinstance(global_obs_shape, list) and len(global_obs_shape) == 1):
+           return "flat"
+       elif isinstance(global_obs_shape, list) and len(global_obs_shape) == 3:
+           return "image"
+       else:
+           raise ValueError(f"Unsupported global_obs_shape: {global_obs_shape}")

    def forward(self, data: dict, single_step: bool = True) -> dict:
        """
@@ -182,8 +211,16 @@ def forward(self, data: dict, single_step: bool = True) -> dict:
        agent_state, global_state, prev_state = data['obs']['agent_state'], data['obs']['global_state'], data[
            'prev_state']
        action = data.get('action', None)
+       # If single_step is True, add a time-step dimension at the front of agent_state.
+       # The model always expects a leading time dimension, so a single-step input is
+       # treated as a length-1 sequence.
        if single_step:
-           agent_state, global_state = agent_state.unsqueeze(0), global_state.unsqueeze(0)
+           agent_state = agent_state.unsqueeze(0)
+       # A flat (2-dimensional, [B, global_dim]) global_state likewise gets a leading
+       # time-step dimension so it follows the same [T, B, ...] layout as agent_state;
+       # a higher-dimensional (image) global_state is left unchanged here.
+       if single_step and len(global_state.shape) == 2:
+           global_state = global_state.unsqueeze(0)
        T, B, A = agent_state.shape[:3]
        assert len(prev_state) == B and all(
            [len(p) == A for p in prev_state]
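Shape-wise, the change above means a single-step call now only expands a flat global_state; a rough sketch with made-up sizes:

# hypothetical single-step inputs (a length-1 time dim is prepended)
agent_state = torch.randn(8, 5, 100)          # [B, A, obs_dim] -> [1, 8, 5, 100]
global_state = torch.randn(8, 120)            # [B, global_dim] -> [1, 8, 120]
global_state_img = torch.randn(8, 3, 48, 48)  # [B, C, H, W] stays 4-dimensional here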
@@ -205,15 +242,38 @@ def forward(self, data: dict, single_step: bool = True) -> dict:
        agent_q_act = torch.gather(agent_q, dim=-1, index=action.unsqueeze(-1))
        agent_q_act = agent_q_act.squeeze(-1)  # T, B, A
        if self.mixer:
-           global_state_embedding = self._global_state_encoder(global_state)
+           global_state_embedding = self._process_global_state(global_state)
            total_q = self._mixer(agent_q_act, global_state_embedding)
        else:
-           total_q = agent_q_act.sum(-1)
+           total_q = agent_q_act.sum(dim=-1)
+
        if single_step:
            total_q, agent_q = total_q.squeeze(0), agent_q.squeeze(0)
+
        return {
            'total_q': total_q,
            'logit': agent_q,
            'next_state': next_state,
            'action_mask': data['obs']['action_mask']
        }
+
+   def _process_global_state(self, global_state: torch.Tensor) -> torch.Tensor:
+       """
+       Overview:
+           Process the global state to obtain an embedding.
+       Arguments:
+           - global_state (:obj:`torch.Tensor`): The global state tensor.
+
+       Returns:
+           - global_state_embedding (:obj:`torch.Tensor`): The processed global state embedding.
+       """
+       # A 5-dimensional global_state is an image observation with leading time and batch dims, [T, B, C, H, W].
+       if global_state.dim() == 5:
+           # Collapse the time and batch dims so the conv encoder sees a plain [N, C, H, W] batch.
+           batch_time_shape = global_state.shape[:2]  # [T, B]
+           reshaped_state = global_state.view(-1, *global_state.shape[-3:])
+           encoded_state = self._global_state_encoder(reshaped_state)
+           return encoded_state.view(*batch_time_shape, -1)  # restore the leading dims: [T, B, embedding_dim]
+       else:
+           # For lower-dimensional (flat or single-step) states, apply the encoder directly.
+           return self._global_state_encoder(global_state)
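For the image case, _process_global_state folds the leading time and batch dimensions together before the conv encoder and restores them afterwards; a minimal sketch with illustrative sizes (`encoder` stands in for the ConvEncoder configured in __init__):

# hypothetical multi-step image global state: T=16, B=8, 3x48x48 frames
global_state = torch.randn(16, 8, 3, 48, 48)
flat = global_state.view(-1, 3, 48, 48)  # [128, 3, 48, 48] for the conv encoder
embedding = encoder(flat)                # assumed to return [128, embedding_dim]
embedding = embedding.view(16, 8, -1)    # back to [T, B, embedding_dim] for the mixer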