
Commit 1158cd5

feature(pu): add pistonball_env, its unittest and qmix config (#833)

* feature(pu): add pistonball_env, its unittest and qmix config
* polish(pu): pistonball reuse PTZRecordVideo
* polish(pu): adapt qmix's mixer to support image obs
* fix(pu): fix qmix's mixer to support image obs
* sync code
* polish(pu): polish ptz_pistonball_qmix_config.py
* polish(pu): polish qmix.py
* polish(pu): add normalize_reward in pistonball_env
* polish(pu): polish hyper-parameters in ptz_pistonball_qmix_config.py
* polish(pu): polish ptz_pistonball_qmix_config.py
* style(pu): yapf format
* polish(pu): polish comments in qmix
* polish(pu): polish qmix comments

1 parent 1f198e9 · commit 1158cd5

File tree: 7 files changed, +539 −23 lines changed

ding/model/template/qmix.py (+70 −10)

@@ -1,10 +1,13 @@
-from typing import Union, List
+from functools import reduce
+from typing import List, Union
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from functools import reduce
-from ding.utils import list_split, MODEL_REGISTRY
-from ding.torch_utils import fc_block, MLP
+from ding.torch_utils import MLP, fc_block
+from ding.utils import MODEL_REGISTRY, list_split
+
+from ..common import ConvEncoder
 from .q_learning import DRQN
 
 
@@ -111,7 +114,7 @@ def __init__(
             self,
             agent_num: int,
             obs_shape: int,
-            global_obs_shape: int,
+            global_obs_shape: Union[int, List[int]],
             action_shape: int,
             hidden_size_list: list,
             mixer: bool = True,
@@ -146,8 +149,34 @@ def __init__(
         embedding_size = hidden_size_list[-1]
         self.mixer = mixer
         if self.mixer:
-            self._mixer = Mixer(agent_num, global_obs_shape, embedding_size, activation=activation)
-            self._global_state_encoder = nn.Identity()
+            global_obs_shape_type = self._get_global_obs_shape_type(global_obs_shape)
+
+            if global_obs_shape_type == "flat":
+                self._mixer = Mixer(agent_num, global_obs_shape, embedding_size, activation=activation)
+                self._global_state_encoder = nn.Identity()
+            elif global_obs_shape_type == "image":
+                self._mixer = Mixer(agent_num, embedding_size, embedding_size, activation=activation)
+                self._global_state_encoder = ConvEncoder(
+                    global_obs_shape, hidden_size_list=hidden_size_list, activation=activation, norm_type='BN'
+                )
+            else:
+                raise ValueError(f"Unsupported global_obs_shape: {global_obs_shape}")
+
+    def _get_global_obs_shape_type(self, global_obs_shape: Union[int, List[int]]) -> str:
+        """
+        Overview:
+            Determine the type of global observation shape.
+        Arguments:
+            - global_obs_shape (:obj:`Union[int, List[int]]`): The global observation state.
+        Returns:
+            - obs_shape_type (:obj:`str`): 'flat' for 1D observation or 'image' for 3D observation.
+        """
+        if isinstance(global_obs_shape, int) or (isinstance(global_obs_shape, list) and len(global_obs_shape) == 1):
+            return "flat"
+        elif isinstance(global_obs_shape, list) and len(global_obs_shape) == 3:
+            return "image"
+        else:
+            raise ValueError(f"Unsupported global_obs_shape: {global_obs_shape}")
 
     def forward(self, data: dict, single_step: bool = True) -> dict:
         """
@@ -182,8 +211,16 @@ def forward(self, data: dict, single_step: bool = True) -> dict:
         agent_state, global_state, prev_state = data['obs']['agent_state'], data['obs']['global_state'], data[
             'prev_state']
         action = data.get('action', None)
+        # If single_step is True, add a new dimension at the front of agent_state.
+        # This is necessary to maintain the expected input shape for the model,
+        # which requires a time step dimension even when processing a single step.
         if single_step:
-            agent_state, global_state = agent_state.unsqueeze(0), global_state.unsqueeze(0)
+            agent_state = agent_state.unsqueeze(0)
+        # If single_step is True and global_state has 2 dimensions, add a new dimension at the front of global_state.
+        # This ensures that global_state has the same number of dimensions as agent_state,
+        # allowing for consistent processing in the forward computation.
+        if single_step and len(global_state.shape) == 2:
+            global_state = global_state.unsqueeze(0)
         T, B, A = agent_state.shape[:3]
         assert len(prev_state) == B and all(
             [len(p) == A for p in prev_state]
@@ -205,15 +242,38 @@ def forward(self, data: dict, single_step: bool = True) -> dict:
             agent_q_act = torch.gather(agent_q, dim=-1, index=action.unsqueeze(-1))
             agent_q_act = agent_q_act.squeeze(-1)  # T, B, A
         if self.mixer:
-            global_state_embedding = self._global_state_encoder(global_state)
+            global_state_embedding = self._process_global_state(global_state)
             total_q = self._mixer(agent_q_act, global_state_embedding)
         else:
-            total_q = agent_q_act.sum(-1)
+            total_q = agent_q_act.sum(dim=-1)
+
         if single_step:
             total_q, agent_q = total_q.squeeze(0), agent_q.squeeze(0)
+
         return {
             'total_q': total_q,
             'logit': agent_q,
             'next_state': next_state,
             'action_mask': data['obs']['action_mask']
         }
+
+    def _process_global_state(self, global_state: torch.Tensor) -> torch.Tensor:
+        """
+        Overview:
+            Process the global state to obtain an embedding.
+        Arguments:
+            - global_state (:obj:`torch.Tensor`): The global state tensor.
+
+        Returns:
+            - global_state_embedding (:obj:`torch.Tensor`): The processed global state embedding.
+        """
+        # If global_state has 5 dimensions, it's likely in the form [batch_size, time_steps, C, H, W].
+        if global_state.dim() == 5:
+            # Reshape and apply the global state encoder.
+            batch_time_shape = global_state.shape[:2]  # [batch_size, time_steps]
+            reshaped_state = global_state.view(-1, *global_state.shape[-3:])  # Collapse batch and time dims
+            encoded_state = self._global_state_encoder(reshaped_state)
+            return encoded_state.view(*batch_time_shape, -1)  # Reshape back to [batch_size, time_steps, embedding_dim]
+        else:
+            # For lower-dimensional states, apply the encoder directly.
+            return self._global_state_encoder(global_state)
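
Aside (not part of the diff): the heart of _process_global_state above is a collapse-encode-restore reshape around the convolutional encoder. The sketch below shows that pattern in isolation, using a hypothetical stand-in nn.Sequential encoder instead of DI-engine's ConvEncoder; the layer sizes and shapes are illustrative assumptions, not the committed configuration.

import torch
import torch.nn as nn

# Hypothetical stand-in for ding's ConvEncoder (illustrative layer sizes only).
encoder = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=8, stride=4), nn.ReLU(),
    nn.Conv2d(16, 32, kernel_size=4, stride=2), nn.ReLU(),
    nn.Flatten(),
    nn.LazyLinear(256),  # produces a 256-dim embedding per frame
)

# A 5-D global state with two leading (batch/time) dims, as in the diff: [B, T, C, H, W].
global_state = torch.randn(3, 8, 3, 64, 64)

# Collapse the two leading dims so the 2-D convolutions see a plain image batch ...
leading = global_state.shape[:2]                        # (3, 8)
flat = global_state.view(-1, *global_state.shape[-3:])  # [3 * 8, 3, 64, 64]
embedding = encoder(flat)                               # [3 * 8, 256]

# ... then restore them so the mixer receives one embedding per (batch, time) pair.
embedding = embedding.view(*leading, -1)                # [3, 8, 256]
assert embedding.shape == (3, 8, 256)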

ding/model/template/tests/test_qmix.py (+31 −0)

@@ -43,3 +43,34 @@ def test_qmix():
     is_differentiable(loss, qmix_model)
     data.pop('action')
     output = qmix_model(data, single_step=False)
+
+
+@pytest.mark.unittest
+def test_qmix_process_global_state():
+    # Test the behavior of the _process_global_state method with different global_obs_shape types.
+    agent_num, obs_dim, global_obs_dim, action_dim = 4, 32, 32 * 4, 9
+    embedding_dim = 64
+
+    # Case 1: Test "flat" type global_obs_shape.
+    global_obs_shape = global_obs_dim  # Flat global_obs_shape
+    qmix_model_flat = QMix(agent_num, obs_dim, global_obs_shape, action_dim, [64, 128, embedding_dim], mixer=True)
+
+    # Simulate input for the "flat" type global_state.
+    batch_size, time_steps = 3, 8
+    global_state_flat = torch.randn(batch_size, time_steps, global_obs_dim)
+    processed_flat = qmix_model_flat._process_global_state(global_state_flat)
+
+    # The flat encoder is nn.Identity, so the shape stays [batch_size, time_steps, global_obs_dim].
+    assert processed_flat.shape == (batch_size, time_steps, global_obs_dim)
+
+    # Case 2: Test "image" type global_obs_shape.
+    global_obs_shape = [3, 64, 64]  # Image-shaped global_obs_shape (C, H, W)
+    qmix_model_image = QMix(agent_num, obs_dim, global_obs_shape, action_dim, [64, 128, embedding_dim], mixer=True)
+
+    # Simulate input for the "image" type global_state.
+    C, H, W = global_obs_shape
+    global_state_image = torch.randn(batch_size, time_steps, C, H, W)
+    processed_image = qmix_model_image._process_global_state(global_state_image)
+
+    # Ensure the output shape is correct: [batch_size, time_steps, embedding_dim].
+    assert processed_image.shape == (batch_size, time_steps, embedding_dim)
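
Aside (not part of the diff): the unit test above only exercises _process_global_state directly. A full forward pass with an image-shaped global state would look roughly like the sketch below; it assumes the same data-dict layout as the pre-existing test_qmix (agent_state shaped (T, B, A, obs_dim), prev_state as a B x A nested list of None so the underlying DRQN initializes its own hidden states). That layout is an assumption carried over from the existing test, not something this commit demonstrates.

import torch
from ding.model.template.qmix import QMix

T, B, A = 8, 3, 4                # time steps, batch size, number of agents
obs_dim, action_dim = 32, 9
global_obs_shape = [3, 64, 64]   # image-shaped global state (C, H, W)

model = QMix(A, obs_dim, global_obs_shape, action_dim, [64, 128, 64], mixer=True)

data = {
    'obs': {
        'agent_state': torch.randn(T, B, A, obs_dim),
        'global_state': torch.randn(T, B, *global_obs_shape),
        'action_mask': torch.randint(0, 2, size=(T, B, A, action_dim)),
    },
    # A None per (batch, agent) entry lets the DRQN zero-initialize its RNN state.
    'prev_state': [[None for _ in range(A)] for _ in range(B)],
    'action': torch.randint(0, action_dim, size=(T, B, A)),
}
output = model(data, single_step=False)
assert output['total_q'].shape == (T, B)               # one mixed Q-value per (time, batch)
assert output['logit'].shape == (T, B, A, action_dim)  # per-agent Q-values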
dizoo/petting_zoo/config/ptz_pistonball_qmix_config.py (+79 −0)

@@ -0,0 +1,79 @@
+from easydict import EasyDict
+
+n_pistons = 20
+collector_env_num = 8
+evaluator_env_num = 8
+max_env_step = 3e6
+
+main_config = dict(
+    exp_name=f'data_pistonball/ptz_pistonball_n{n_pistons}_qmix_seed0',
+    env=dict(
+        env_family='butterfly',
+        env_id='pistonball_v6',
+        n_pistons=n_pistons,
+        max_cycles=125,
+        agent_obs_only=False,
+        continuous_actions=False,
+        collector_env_num=collector_env_num,
+        evaluator_env_num=evaluator_env_num,
+        n_evaluator_episode=evaluator_env_num,
+        stop_value=1e6,
+        manager=dict(shared_memory=False, ),
+    ),
+    policy=dict(
+        cuda=True,
+        model=dict(
+            agent_num=n_pistons,
+            obs_shape=(3, 457, 120),  # RGB image observation shape for each piston agent
+            global_obs_shape=(3, 560, 880),  # Global state shape
+            action_shape=3,  # Discrete actions (0, 1, 2)
+            hidden_size_list=[32, 64, 128, 256],
+            mixer=True,
+        ),
+        learn=dict(
+            update_per_collect=20,
+            batch_size=32,
+            learning_rate=0.0001,
+            clip_value=5,
+            target_update_theta=0.001,
+            discount_factor=0.99,
+            double_q=True,
+        ),
+        collect=dict(
+            n_sample=16,
+            unroll_len=5,
+            env_num=collector_env_num,
+        ),
+        eval=dict(env_num=evaluator_env_num),
+        other=dict(
+            eps=dict(
+                type='exp',
+                start=1.0,
+                end=0.05,
+                decay=100000,
+            ),
+            replay_buffer=dict(
+                replay_buffer_size=5000,
+            ),
+        ),
+    ),
+)
+main_config = EasyDict(main_config)
+
+create_config = dict(
+    env=dict(
+        import_names=['dizoo.petting_zoo.envs.petting_zoo_pistonball_env'],
+        type='petting_zoo_pistonball',
+    ),
+    env_manager=dict(type='subprocess'),
+    policy=dict(type='qmix'),
+)
+create_config = EasyDict(create_config)
+
+ptz_pistonball_qmix_config = main_config
+ptz_pistonball_qmix_create_config = create_config
+
+if __name__ == '__main__':
+    # or you can enter `ding -m serial -c ptz_pistonball_qmix_config.py -s 0`
+    from ding.entry import serial_pipeline
+    serial_pipeline((main_config, create_config), seed=0, max_env_step=max_env_step)
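
Aside (not part of the diff): the eps block in the config drives epsilon-greedy exploration during collection. Assuming DI-engine's 'exp' schedule follows the common form eps(t) = end + (start - end) * exp(-t / decay) (an assumption here, not verified against ding.rl_utils), the configured values decay roughly as sketched below.

import math

# Assumed form of the 'exp' epsilon schedule; start/end/decay mirror the config above.
def epsilon(step: int, start: float = 1.0, end: float = 0.05, decay: int = 100000) -> float:
    return end + (start - end) * math.exp(-step / decay)

for step in (0, 100_000, 300_000, 1_000_000):
    print(step, round(epsilon(step), 3))
# Under this assumed form: 0 -> 1.0, 100000 -> ~0.399, 300000 -> ~0.097, 1000000 -> ~0.05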
