Skip to content

Commit d182256

Browse files
committed
polish(pu): polish comments in qmix
1 parent efd472e commit d182256

File tree

4 files changed

+42
-5
lines changed

4 files changed

+42
-5
lines changed

ding/model/template/qmix.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -167,9 +167,9 @@ def _get_global_obs_shape_type(self, global_obs_shape: Union[int, List[int]]) ->
167167
Overview:
168168
Determine the type of global observation shape.
169169
Arguments:
170-
- global_obs_shape (:obj:`int` or :obj:`List[int]`): The global observation state.
170+
- global_obs_shape (Union[:obj:`int`, :obj:`List[int]`]): The shape of the global observation.
171171
Returns:
172-
- str: 'flat' for 1D observation or 'image' for 3D observation.
172+
- (:obj:`str`): 'flat' for 1D observation or 'image' for 3D observation.
173173
"""
174174
if isinstance(global_obs_shape, int) or (isinstance(global_obs_shape, list) and len(global_obs_shape) == 1):
175175
return "flat"
@@ -211,8 +211,14 @@ def forward(self, data: dict, single_step: bool = True) -> dict:
211211
agent_state, global_state, prev_state = data['obs']['agent_state'], data['obs']['global_state'], data[
212212
'prev_state']
213213
action = data.get('action', None)
214+
# If single_step is True, add a new dimension at the front of agent_state
215+
# This is necessary to maintain the expected input shape for the model,
216+
# which requires a time step dimension even when processing a single step.
214217
if single_step:
215218
agent_state = agent_state.unsqueeze(0)
219+
# If single_step is True and global_state has 2 dimensions, add a new dimension at the front of global_state
220+
# This ensures that global_state has the same number of dimensions as agent_state,
221+
# allowing for consistent processing in the forward computation.
216222
if single_step and len(global_state.shape) == 2:
217223
global_state = global_state.unsqueeze(0)
218224
T, B, A = agent_state.shape[:3]

ding/model/template/tests/test_qmix.py

+31
Original file line numberDiff line numberDiff line change
@@ -43,3 +43,34 @@ def test_qmix():
4343
is_differentiable(loss, qmix_model)
4444
data.pop('action')
4545
output = qmix_model(data, single_step=False)
46+
47+
48+
@pytest.mark.unittest
49+
def test_qmix_process_global_state():
50+
# Test the behavior of the _process_global_state method with different global_obs_shape types
51+
agent_num, obs_dim, global_obs_dim, action_dim = 4, 32, 32 * 4, 9
52+
embedding_dim = 64
53+
54+
# Case 1: Test "flat" type global_obs_shape
55+
global_obs_shape = global_obs_dim # Flat global_obs_shape
56+
qmix_model_flat = QMix(agent_num, obs_dim, global_obs_shape, action_dim, [64, 128, embedding_dim], mixer=True)
57+
58+
# Simulate input for the "flat" type global_state
59+
batch_size, time_steps = 3, 8
60+
global_state_flat = torch.randn(batch_size, time_steps, global_obs_dim)
61+
processed_flat = qmix_model_flat._process_global_state(global_state_flat)
62+
63+
# Ensure the output shape is correct [batch_size, time_steps, global_obs_dim]
64+
assert processed_flat.shape == (batch_size, time_steps, global_obs_dim)
65+
66+
# Case 2: Test "image" type global_obs_shape
67+
global_obs_shape = [3, 64, 64] # Image-shaped global_obs_shape (C, H, W)
68+
qmix_model_image = QMix(agent_num, obs_dim, global_obs_shape, action_dim, [64, 128, embedding_dim], mixer=True)
69+
70+
# Simulate input for the "image" type global_state
71+
C, H, W = global_obs_shape
72+
global_state_image = torch.randn(batch_size, time_steps, C, H, W)
73+
processed_image = qmix_model_image._process_global_state(global_state_image)
74+
75+
# Ensure the output shape is correct [batch_size, time_steps, embedding_dim]
76+
assert processed_image.shape == (batch_size, time_steps, embedding_dim)

dizoo/petting_zoo/config/ptz_pistonball_qmix_config.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
n_pistons = 20
44
collector_env_num = 8
55
evaluator_env_num = 8
6+
max_env_step = 3e6
67

78
main_config = dict(
89
exp_name=f'data_pistonball/ptz_pistonball_n{n_pistons}_qmix_seed0',
@@ -18,7 +19,6 @@
1819
n_evaluator_episode=evaluator_env_num,
1920
stop_value=1e6,
2021
manager=dict(shared_memory=False,),
21-
max_env_step=3e6,
2222
),
2323
policy=dict(
2424
cuda=True,
@@ -76,4 +76,4 @@
7676
if __name__ == '__main__':
7777
# or you can enter `ding -m serial -c ptz_pistonball_qmix_config.py -s 0`
7878
from ding.entry import serial_pipeline
79-
serial_pipeline((main_config, create_config), seed=0, max_env_step=main_config.env.max_env_step)
79+
serial_pipeline((main_config, create_config), seed=0, max_env_step=max_env_step)

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@
7575
'responses', # interaction
7676
'URLObject', # interaction
7777
'pynng', # parallel
78-
'sniffio', # parallel
78+
'sniffio', # parallel
7979
'redis', # parallel
8080
'mpire>=2.3.5', # parallel
8181
],

0 commit comments

Comments
 (0)