Skip to content

Commit 55dc254

Browse files
committed
polish(pu): adapt qmix's mixer to support image obs
1 parent e916841 commit 55dc254

File tree

2 files changed

+16
-7
lines changed

2 files changed

+16
-7
lines changed

ding/model/template/qmix.py

+13 -4
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,7 @@
55
from functools import reduce
66
from ding.utils import list_split, MODEL_REGISTRY
77
from ding.torch_utils import fc_block, MLP
8+
from ..common import ConvEncoder
89
from .q_learning import DRQN
910

1011

@@ -111,7 +112,7 @@ def __init__(
111112
self,
112113
agent_num: int,
113114
obs_shape: int,
114-
global_obs_shape: int,
115+
global_obs_shape: Union[int, List[int]],
115116
action_shape: int,
116117
hidden_size_list: list,
117118
mixer: bool = True,
@@ -146,8 +147,14 @@ def __init__(
146147
embedding_size = hidden_size_list[-1]
147148
self.mixer = mixer
148149
if self.mixer:
149-
self._mixer = Mixer(agent_num, global_obs_shape, embedding_size, activation=activation)
150-
self._global_state_encoder = nn.Identity()
150+
if len(global_obs_shape) == 1:
151+
self._mixer = Mixer(agent_num, global_obs_shape, embedding_size, activation=activation)
152+
self._global_state_encoder = nn.Identity()
153+
elif len(global_obs_shape) == 3:
154+
self._mixer = Mixer(agent_num, embedding_size, embedding_size, activation=activation)
155+
self._global_state_encoder = ConvEncoder(global_obs_shape, hidden_size_list=hidden_size_list, activation=activation, norm_type='BN')
156+
else:
157+
raise ValueError("Not support global_obs_shape: {}".format(global_obs_shape))
151158

152159
def forward(self, data: dict, single_step: bool = True) -> dict:
153160
"""
@@ -183,7 +190,9 @@ def forward(self, data: dict, single_step: bool = True) -> dict:
183190
'prev_state']
184191
action = data.get('action', None)
185192
if single_step:
186-
agent_state, global_state = agent_state.unsqueeze(0), global_state.unsqueeze(0)
193+
agent_state = agent_state.unsqueeze(0)
194+
if single_step and len(global_state.shape) == 2:
195+
global_state = global_state.unsqueeze(0)
187196
T, B, A = agent_state.shape[:3]
188197
assert len(prev_state) == B and all(
189198
[len(p) == A for p in prev_state]

dizoo/petting_zoo/config/ptz_pistonball_qmix_config.py

+3 -3
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,7 @@
2121
shared_memory=False,
2222
reset_timeout=6000,
2323
),
24+
max_env_step=3e6,
2425
),
2526
policy=dict(
2627
cuda=True,
@@ -30,8 +31,7 @@
3031
global_obs_shape=(3, 560, 880), # Global state shape
3132
action_shape=3, # Discrete actions (0, 1, 2)
3233
hidden_size_list=[128, 128, 64],
33-
# mixer=True, # TODO: mixer is not supported image observation now
34-
mixer=False,
34+
mixer=True,
3535
),
3636
learn=dict(
3737
update_per_collect=100,
@@ -73,4 +73,4 @@
7373
if __name__ == '__main__':
7474
# or you can enter `ding -m serial -c ptz_pistonball_qmix_config.py -s 0`
7575
from ding.entry import serial_pipeline
76-
serial_pipeline((main_config, create_config), seed=0)
76+
serial_pipeline((main_config, create_config), seed=0, max_env_step=main_config.env.max_env_step)

0 commit comments

Comments
 (0)