@@ -5,6 +5,7 @@
 from functools import reduce
 from ding.utils import list_split, MODEL_REGISTRY
 from ding.torch_utils import fc_block, MLP
+from ..common import ConvEncoder
 from .q_learning import DRQN
 
 
@@ -111,7 +112,7 @@ def __init__(
             self,
             agent_num: int,
             obs_shape: int,
-            global_obs_shape: int,
+            global_obs_shape: Union[int, List[int]],
             action_shape: int,
             hidden_size_list: list,
             mixer: bool = True,
@@ -146,8 +147,14 @@ def __init__(
         embedding_size = hidden_size_list[-1]
         self.mixer = mixer
         if self.mixer:
-            self._mixer = Mixer(agent_num, global_obs_shape, embedding_size, activation=activation)
-            self._global_state_encoder = nn.Identity()
+            if len(global_obs_shape) == 1:
+                self._mixer = Mixer(agent_num, global_obs_shape, embedding_size, activation=activation)
+                self._global_state_encoder = nn.Identity()
+            elif len(global_obs_shape) == 3:
+                self._mixer = Mixer(agent_num, embedding_size, embedding_size, activation=activation)
+                self._global_state_encoder = ConvEncoder(global_obs_shape, hidden_size_list=hidden_size_list, activation=activation, norm_type='BN')
+            else:
+                raise ValueError("Unsupported global_obs_shape: {}".format(global_obs_shape))
 
     def forward(self, data: dict, single_step: bool = True) -> dict:
         """
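The hunk above routes a 3-D (image-like) global observation through a convolutional encoder so the mixer consumes a fixed-size embedding instead of the raw flat state; a 1-D global state keeps the old nn.Identity path. Below is a minimal, self-contained sketch of that shape contract using a stand-in encoder rather than ding's ConvEncoder; the hidden_size_list and the 4x84x84 global state are assumed example values, not taken from this diff.

    # Stand-in for the 3-D branch: (B, C, H, W) global state -> (B, embedding_size)
    # feature that the mixer can consume. All sizes below are assumed example values.
    import torch
    import torch.nn as nn

    hidden_size_list = [32, 64, 128]      # assumed; the last entry acts as embedding_size
    embedding_size = hidden_size_list[-1]
    global_obs_shape = [4, 84, 84]        # assumed image-like global state (C, H, W)

    class StandInConvEncoder(nn.Module):
        """Maps (B, C, H, W) to (B, out_size), mirroring what the 3-D branch expects."""

        def __init__(self, obs_shape, out_size):
            super().__init__()
            self.conv = nn.Sequential(
                nn.Conv2d(obs_shape[0], 32, 8, stride=4), nn.ReLU(),
                nn.Conv2d(32, 64, 4, stride=2), nn.ReLU(),
                nn.Conv2d(64, 64, 3, stride=1), nn.ReLU(),
                nn.Flatten(),
            )
            # Infer the flattened conv output size from a dummy forward pass.
            with torch.no_grad():
                flat_dim = self.conv(torch.zeros(1, *obs_shape)).shape[1]
            self.fc = nn.Linear(flat_dim, out_size)

        def forward(self, x):
            return self.fc(self.conv(x))

    encoder = StandInConvEncoder(global_obs_shape, embedding_size)
    global_state = torch.randn(16, *global_obs_shape)   # batch of 16 global observations
    print(encoder(global_state).shape)                   # torch.Size([16, 128]); this feeds the mixer

Because the encoder output has size embedding_size, the 3-D branch builds the mixer with embedding_size as its state dimension instead of the raw global_obs_shape.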
@@ -183,7 +190,9 @@ def forward(self, data: dict, single_step: bool = True) -> dict:
             'prev_state']
         action = data.get('action', None)
         if single_step:
-            agent_state, global_state = agent_state.unsqueeze(0), global_state.unsqueeze(0)
+            agent_state = agent_state.unsqueeze(0)
+        if single_step and len(global_state.shape) == 2:
+            global_state = global_state.unsqueeze(0)
         T, B, A = agent_state.shape[:3]
         assert len(prev_state) == B and all(
             [len(p) == A for p in prev_state]
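The forward change only prepends the time axis to the global state when it is actually missing: a flat single-step global state arrives as (B, N) and needs the extra dimension, while an image-like (B, C, H, W) state already has more than two dimensions and is passed through unchanged. A small sketch of that check, with a hypothetical helper name and assumed tensor sizes:

    # Hypothetical helper illustrating the conditional unsqueeze from the diff.
    import torch

    def maybe_add_time_axis(global_state, single_step=True):
        # Only a flat (B, N) state gains a leading T axis; richer shapes are left alone.
        if single_step and len(global_state.shape) == 2:
            return global_state.unsqueeze(0)
        return global_state

    flat_state = torch.randn(16, 48)           # assumed (B, N) flat global state
    image_state = torch.randn(16, 4, 84, 84)   # assumed (B, C, H, W) image global state
    print(maybe_add_time_axis(flat_state).shape)    # torch.Size([1, 16, 48])
    print(maybe_add_time_axis(image_state).shape)   # torch.Size([16, 4, 84, 84])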