
Commit 9285a84

polish(pu): polish qmix.py
1 parent 5a0bdd8

File tree

1 file changed, +50 -9 lines


ding/model/template/qmix.py (+50 -9)
```diff
@@ -147,14 +147,36 @@ def __init__(
         embedding_size = hidden_size_list[-1]
         self.mixer = mixer
         if self.mixer:
-            if len(global_obs_shape) == 1:
+            global_obs_shape_type = self._get_global_obs_shape_type(global_obs_shape)
+
+            if global_obs_shape_type == "flat":
                 self._mixer = Mixer(agent_num, global_obs_shape, embedding_size, activation=activation)
                 self._global_state_encoder = nn.Identity()
-            elif len(global_obs_shape) == 3:
+            elif global_obs_shape_type == "image":
                 self._mixer = Mixer(agent_num, embedding_size, embedding_size, activation=activation)
-                self._global_state_encoder = ConvEncoder(global_obs_shape, hidden_size_list=hidden_size_list, activation=activation, norm_type='BN')
+                self._global_state_encoder = ConvEncoder(
+                    global_obs_shape, hidden_size_list=hidden_size_list, activation=activation, norm_type='BN'
+                )
             else:
-                raise ValueError("Not support global_obs_shape: {}".format(global_obs_shape))
+                raise ValueError(f"Unsupported global_obs_shape: {global_obs_shape}")
+
+    def _get_global_obs_shape_type(self, global_obs_shape: Union[int, List[int]]) -> str:
+        """
+        Overview:
+            Determine the type of global observation shape.
+
+        Arguments:
+            - global_obs_shape (:obj:`int` or :obj:`List[int]`): The global observation shape.
+
+        Returns:
+            - str: 'flat' for 1D observation or 'image' for 3D observation.
+        """
+        if isinstance(global_obs_shape, int) or (isinstance(global_obs_shape, list) and len(global_obs_shape) == 1):
+            return "flat"
+        elif isinstance(global_obs_shape, list) and len(global_obs_shape) == 3:
+            return "image"
+        else:
+            raise ValueError(f"Unsupported global_obs_shape: {global_obs_shape}")
 
     def forward(self, data: dict, single_step: bool = True) -> dict:
         """
```
```diff
@@ -214,18 +236,37 @@ def forward(self, data: dict, single_step: bool = True) -> dict:
         agent_q_act = torch.gather(agent_q, dim=-1, index=action.unsqueeze(-1))
         agent_q_act = agent_q_act.squeeze(-1)  # T, B, A
         if self.mixer:
-            if len(global_state.shape) == 5:
-                global_state_embedding = self._global_state_encoder(global_state.reshape(-1, *global_state.shape[-3:])).reshape(global_state.shape[0], global_state.shape[1], -1)
-            else:
-                global_state_embedding = self._global_state_encoder(global_state)
+            global_state_embedding = self._process_global_state(global_state)
             total_q = self._mixer(agent_q_act, global_state_embedding)
         else:
-            total_q = agent_q_act.sum(-1)
+            total_q = agent_q_act.sum(dim=-1)
+
         if single_step:
             total_q, agent_q = total_q.squeeze(0), agent_q.squeeze(0)
+
         return {
             'total_q': total_q,
             'logit': agent_q,
             'next_state': next_state,
             'action_mask': data['obs']['action_mask']
         }
+    def _process_global_state(self, global_state: torch.Tensor) -> torch.Tensor:
+        """
+        Process the global state to obtain an embedding.
+
+        Arguments:
+            - global_state (:obj:`torch.Tensor`): The global state tensor.
+
+        Returns:
+            - (:obj:`torch.Tensor`): The processed global state embedding.
+        """
+        # If global_state has 5 dimensions, it's likely in the form [batch_size, time_steps, C, H, W]
+        if global_state.dim() == 5:
+            # Reshape and apply the global state encoder
+            batch_time_shape = global_state.shape[:2]  # [batch_size, time_steps]
+            reshaped_state = global_state.view(-1, *global_state.shape[-3:])  # Collapse batch and time dims
+            encoded_state = self._global_state_encoder(reshaped_state)
+            return encoded_state.view(*batch_time_shape, -1)  # Back to [batch_size, time_steps, embedding_dim]
+        else:
+            # For lower-dimensional states, apply the encoder directly
+            return self._global_state_encoder(global_state)
```
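
To see that the reshape round trip in `_process_global_state` preserves the leading dimensions, here is a small sanity-check sketch; the `nn.Sequential` encoder below is a stand-in assumption, not DI-engine's actual `ConvEncoder`:

```python
import torch
import torch.nn as nn

# Stand-in encoder: maps a [N, C, H, W] batch to [N, embedding_dim].
C, H, W, embedding_dim = 3, 8, 8, 32
encoder = nn.Sequential(nn.Flatten(), nn.Linear(C * H * W, embedding_dim))

# A [batch_size, time_steps, C, H, W] global state, as seen in multi-step training.
global_state = torch.randn(4, 5, C, H, W)

# The same steps as _process_global_state: collapse batch/time, encode, restore.
batch_time_shape = global_state.shape[:2]          # [batch_size, time_steps]
encoded = encoder(global_state.view(-1, C, H, W))  # [batch*time, embedding_dim]
embedding = encoded.view(*batch_time_shape, -1)    # [batch, time, embedding_dim]
assert embedding.shape == (4, 5, embedding_dim)
```

One caveat of `view` over the original `reshape`: `view` requires a contiguous input, so a non-contiguous `global_state` (e.g. a permuted tensor) would raise a runtime error where `reshape` would have silently copied.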
