diff --git a/ding/model/template/q_learning.py b/ding/model/template/q_learning.py
index ece076bd81..1114fc3af0 100644
--- a/ding/model/template/q_learning.py
+++ b/ding/model/template/q_learning.py
@@ -26,16 +26,17 @@ class DQN(nn.Module):
     """
 
     def __init__(
-            self,
-            obs_shape: Union[int, SequenceType],
-            action_shape: Union[int, SequenceType],
-            encoder_hidden_size_list: SequenceType = [128, 128, 64],
-            dueling: bool = True,
-            head_hidden_size: Optional[int] = None,
-            head_layer_num: int = 1,
-            activation: Optional[nn.Module] = nn.ReLU(),
-            norm_type: Optional[str] = None,
-            dropout: Optional[float] = None
+        self,
+        obs_shape: Union[int, SequenceType],
+        action_shape: Union[int, SequenceType],
+        encoder_hidden_size_list: SequenceType = [128, 128, 64],
+        dueling: bool = True,
+        head_hidden_size: Optional[int] = None,
+        head_layer_num: int = 1,
+        activation: Optional[nn.Module] = nn.ReLU(),
+        norm_type: Optional[str] = None,
+        dropout: Optional[float] = None,
+        init_bias: Optional[float] = None,
     ) -> None:
         """
         Overview:
@@ -99,6 +100,9 @@ def __init__(
                 norm_type=norm_type,
                 dropout=dropout
             )
+        if init_bias is not None and head_cls == DuelingHead:
+            # Initialize the last layer bias of the advantage head with init_bias
+            self.head.A[-1][0].bias.data.fill_(init_bias)
 
     def forward(self, x: torch.Tensor) -> Dict:
         """
diff --git a/ding/policy/common_utils.py b/ding/policy/common_utils.py
index 77cc9a2a17..295def03f1 100644
--- a/ding/policy/common_utils.py
+++ b/ding/policy/common_utils.py
@@ -61,12 +61,17 @@ def default_preprocess_learn(
         reward = data['reward']
         if len(reward.shape) == 1:
             reward = reward.unsqueeze(1)
-        # reward: (batch_size, nstep) -> (nstep, batch_size)
-        # reversed_shape = [i for i in range(len(reward.shape))][::-1]
-        temp_shape = [i for i in range(len(reward.shape))]
-        temp_shape = [temp_shape[-1]] + temp_shape[:-1]
-        data['reward'] = reward.permute(temp_shape).contiguous()
-        # data['reward'] = reward.transpose(0, -1).contiguous()
+        # single agent reward: (batch_size, nstep) -> (nstep, batch_size)
+        # multi-agent reward: (batch_size, agent_dim, nstep) -> (nstep, batch_size, agent_dim)
+        # Assuming 'reward' is a PyTorch tensor with shape (batch_size, nstep) or (batch_size, agent_dim, nstep)
+        if reward.ndim == 2:
+            # For a 2D tensor, simply transpose it to get (nstep, batch_size)
+            data['reward'] = reward.transpose(0, 1).contiguous()
+        elif reward.ndim == 3:
+            # For a 3D tensor, move the last dimension to the front to get (nstep, batch_size, agent_dim)
+            data['reward'] = reward.transpose(0, 2).transpose(1, 2).contiguous()
+        else:
+            raise ValueError("The 'reward' tensor must be either 2D or 3D.")
     else:
         if data['reward'].dim() == 2 and data['reward'].shape[1] == 1:
             data['reward'] = data['reward'].squeeze(-1)
diff --git a/ding/rl_utils/adder.py b/ding/rl_utils/adder.py
index 413f137827..f4faa572a8 100644
--- a/ding/rl_utils/adder.py
+++ b/ding/rl_utils/adder.py
@@ -145,10 +145,6 @@ def get_nstep_return_data(
                  for j in range(len(data) - i)] + [fake_reward for _ in range(nstep - (len(data) - i))],
                 dim=-1
             )
-            # try:
-            #     assert len(data[i]['reward']) == 300
-            # except:
-            #     print(len(data[i]['reward']))
             data[i]['done'] = data[-1]['done']
             if correct_terminate_gamma:
                 data[i]['value_gamma'] = gamma ** (len(data) - i - 1)
diff --git a/ding/rl_utils/td.py b/ding/rl_utils/td.py
index fc85e2ce54..cb2443cf88 100644
--- a/ding/rl_utils/td.py
+++ b/ding/rl_utils/td.py
@@ -690,7 +690,10 @@ def q_nstep_td_error(
     if weight is None:
         weight = torch.ones_like(reward)
 
-    if len(action.shape) == 1 or len(action.shape) < len(q.shape):  # single agent case
+    if len(action.shape) == 1 or len(action.shape) < len(q.shape):
+        # Unsqueeze action so that it can be used to gather q values along the last dimension.
+        # e.g. single agent case: action is [B, ] and q is [B, action_shape]
+        # e.g. multi agent case: action is [B, agent_num] and q is [B, agent_num, action_shape]
         action = action.unsqueeze(-1)
     elif len(action.shape) > 1:  # MARL case
         reward = reward.unsqueeze(-1)
diff --git a/dizoo/ising_env/config/ising_mfq_config.py b/dizoo/ising_env/config/ising_mfq_config.py
index 9cdf1a0dd1..1480eb379e 100644
--- a/dizoo/ising_env/config/ising_mfq_config.py
+++ b/dizoo/ising_env/config/ising_mfq_config.py
@@ -8,7 +8,7 @@
 agent_view_sight = 1
 
 ising_mfq_config = dict(
-    exp_name='ising_mfq_seed0_debug',
+    exp_name='ising_mfq_seed0',
     env=dict(
         collector_env_num=8,
         evaluator_env_num=8,
@@ -24,6 +24,7 @@
             obs_shape=obs_shape + action_shape,  # for we will concat the pre_action_prob into obs
             action_shape=action_shape,
             encoder_hidden_size_list=[128, 128, 512],
+            init_bias=0,
         ),
         nstep=3,
         discount_factor=0.99,
@@ -53,7 +54,7 @@
         type='ising_model',
         import_names=['dizoo.ising_env.envs.ising_model_env'],
     ),
-    env_manager=dict(type='base'),
+    env_manager=dict(type='subprocess'),
     policy=dict(type='dqn'),
 )
 ising_mfq_create_config = EasyDict(ising_mfq_create_config)
@@ -65,7 +66,4 @@
     from ding.model import DQN
     seed = 1
     set_pkg_seed(seed)
-    model = DQN(**ising_mfq_config.policy.model)
-    model.head.A[-1][0].bias.data.fill_(0)  # zero last layer bias
-    # print("init model successful")
-    serial_pipeline((main_config, create_config), seed=seed, model=model, max_env_step=5e4)
+    serial_pipeline((main_config, create_config), seed=seed, max_env_step=5e4)
diff --git a/dizoo/ising_env/envs/ising_model_env.py b/dizoo/ising_env/envs/ising_model_env.py
index e6659d1bb1..731d2a8c06 100644
--- a/dizoo/ising_env/envs/ising_model_env.py
+++ b/dizoo/ising_env/envs/ising_model_env.py
@@ -8,18 +8,25 @@
 from ding.envs import BaseEnv, BaseEnvTimestep
 from ding.torch_utils import to_ndarray, to_list
 from ding.utils import ENV_REGISTRY
-
-# ising_model_dir = os.path.join(os.path.dirname(__file__), 'ising_model')
-# if ising_model_dir not in sys.path:
-#     sys.path.append(ising_model_dir)
-# from ising_model.multiagent.environment import IsingMultiAgentEnv
-# import ising_model
 from dizoo.ising_env.envs.ising_model.multiagent.environment import IsingMultiAgentEnv
 import dizoo.ising_env.envs.ising_model as ising_model_
 
 
 @ENV_REGISTRY.register('ising_model')
 class IsingModelEnv(BaseEnv):
+    """
+    Overview:
+        Ising Model Environment for Multi-Agent Reinforcement Learning, according to the paper: \
+        [Mean Field Multi-Agent Reinforcement Learning](https://arxiv.org/abs/1802.05438). \
+        The environment is a grid of agents, each of which can be in one of two states: \
+        spin up or spin down. The agents interact with their neighbors according to the Ising model, \
+        and the goal is to maximize the global order parameter, which is the average spin of all agents. \
+        Details of the environment can be found in the \
+        [DI-engine-Doc](https://di-engine-docs.readthedocs.io/zh-cn/latest/13_envs/index.html).
+    Interface:
+        `__init__`, `reset`, `close`, `seed`, `step`, `random_action`, `num_agents`, \
+        `observation_space`, `action_space`, `reward_space`.
+    """
 
     def __init__(self, cfg: dict) -> None:
         self._cfg = cfg
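
Note (not part of the patch): a minimal standalone sketch of the reward reshaping introduced in ding/policy/common_utils.py, checking that the 2D and 3D branches produce the documented shapes; the concrete sizes (batch_size=4, agent_dim=2, nstep=3) are arbitrary.

    import torch

    # single agent reward: (batch_size, nstep) -> (nstep, batch_size)
    single = torch.randn(4, 3)
    assert single.transpose(0, 1).shape == (3, 4)

    # multi-agent reward: (batch_size, agent_dim, nstep) -> (nstep, batch_size, agent_dim)
    multi = torch.randn(4, 2, 3)
    reshaped = multi.transpose(0, 2).transpose(1, 2)
    assert reshaped.shape == (3, 4, 2)
    # the two chained transposes are equivalent to a single permute
    assert torch.equal(reshaped, multi.permute(2, 0, 1))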