polish(rjy): polish ising model env
nighood committed Apr 17, 2024
1 parent 017bfba commit d43ddc6
Showing 6 changed files with 46 additions and 33 deletions.
24 changes: 14 additions & 10 deletions ding/model/template/q_learning.py
@@ -26,16 +26,17 @@ class DQN(nn.Module):
"""

def __init__(
self,
obs_shape: Union[int, SequenceType],
action_shape: Union[int, SequenceType],
encoder_hidden_size_list: SequenceType = [128, 128, 64],
dueling: bool = True,
head_hidden_size: Optional[int] = None,
head_layer_num: int = 1,
activation: Optional[nn.Module] = nn.ReLU(),
norm_type: Optional[str] = None,
dropout: Optional[float] = None
self,
obs_shape: Union[int, SequenceType],
action_shape: Union[int, SequenceType],
encoder_hidden_size_list: SequenceType = [128, 128, 64],
dueling: bool = True,
head_hidden_size: Optional[int] = None,
head_layer_num: int = 1,
activation: Optional[nn.Module] = nn.ReLU(),
norm_type: Optional[str] = None,
dropout: Optional[float] = None,
init_bias: Optional[float] = None,
) -> None:
"""
Overview:
@@ -99,6 +100,9 @@ def __init__(
norm_type=norm_type,
dropout=dropout
)
            if init_bias is not None and head_cls == DuelingHead:
                # Initialize the last-layer bias of the advantage head to init_bias
                self.head.A[-1][0].bias.data.fill_(init_bias)

def forward(self, x: torch.Tensor) -> Dict:
"""
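For context, a minimal sketch of how the new init_bias argument behaves (shapes are illustrative; it assumes the default dueling head, whose last linear layer is reachable as head.A[-1][0], as in the old ising config further below):

import torch
from ding.model import DQN

# With dueling=True (the default), init_bias fills the bias of the advantage
# head's last linear layer at construction time, replacing the manual
# `model.head.A[-1][0].bias.data.fill_(0)` previously done in the ising config.
model = DQN(obs_shape=10, action_shape=4, init_bias=0.)
assert torch.all(model.head.A[-1][0].bias.data == 0.)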
17 changes: 11 additions & 6 deletions ding/policy/common_utils.py
@@ -61,12 +61,17 @@ def default_preprocess_learn(
reward = data['reward']
if len(reward.shape) == 1:
reward = reward.unsqueeze(1)
# reward: (batch_size, nstep) -> (nstep, batch_size)
# reversed_shape = [i for i in range(len(reward.shape))][::-1]
temp_shape = [i for i in range(len(reward.shape))]
temp_shape = [temp_shape[-1]] + temp_shape[:-1]
data['reward'] = reward.permute(temp_shape).contiguous()
# data['reward'] = reward.transpose(0, -1).contiguous()
# single agent reward: (batch_size, nstep) -> (nstep, batch_size)
# multi-agent reward: (batch_size, agent_dim, nstep) -> (nstep, batch_size, agent_dim)
# Assuming 'reward' is a PyTorch tensor with shape (batch_size, nstep) or (batch_size, agent_dim, nstep)
if reward.ndim == 2:
# For a 2D tensor, simply transpose it to get (nstep, batch_size)
data['reward'] = reward.transpose(0, 1).contiguous()
elif reward.ndim == 3:
# For a 3D tensor, move the last dimension to the front to get (nstep, batch_size, agent_dim)
data['reward'] = reward.transpose(0, 2).transpose(1, 2).contiguous()
else:
raise ValueError("The 'reward' tensor must be either 2D or 3D.")
else:
if data['reward'].dim() == 2 and data['reward'].shape[1] == 1:
data['reward'] = data['reward'].squeeze(-1)
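A quick standalone check of the new reward reshaping (pure PyTorch; the sizes are arbitrary):

import torch

batch_size, agent_dim, nstep = 4, 3, 2

# single-agent: (batch_size, nstep) -> (nstep, batch_size)
r2 = torch.randn(batch_size, nstep)
assert r2.transpose(0, 1).contiguous().shape == (nstep, batch_size)

# multi-agent: (batch_size, agent_dim, nstep) -> (nstep, batch_size, agent_dim)
r3 = torch.randn(batch_size, agent_dim, nstep)
out = r3.transpose(0, 2).transpose(1, 2).contiguous()
assert out.shape == (nstep, batch_size, agent_dim)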
4 changes: 0 additions & 4 deletions ding/rl_utils/adder.py
@@ -145,10 +145,6 @@ def get_nstep_return_data(
for j in range(len(data) - i)] + [fake_reward for _ in range(nstep - (len(data) - i))],
dim=-1
)
# try:
# assert len(data[i]['reward']) == 300
# except:
# print(len(data[i]['reward']))
data[i]['done'] = data[-1]['done']
if correct_terminate_gamma:
data[i]['value_gamma'] = gamma ** (len(data) - i - 1)
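The surrounding logic in get_nstep_return_data pads each transition's reward with fake (zero) rewards once the trajectory runs out of real steps, and shrinks value_gamma to match the truncated horizon. A standalone sketch of that padding (the per-step reward list and its single-element shape are assumptions for illustration):

import torch

nstep, gamma = 3, 0.99
rewards = [torch.tensor([1.0]), torch.tensor([0.5]), torch.tensor([2.0])]  # tail of a trajectory
fake_reward = torch.zeros(1)

i = 1  # second-to-last transition: only 2 real rewards remain, so 1 fake reward is appended
n_step_reward = torch.cat(
    [rewards[i + j] for j in range(len(rewards) - i)] +
    [fake_reward for _ in range(nstep - (len(rewards) - i))],
    dim=-1
)  # shape (nstep,)
value_gamma = gamma ** (len(rewards) - i - 1)  # discount matching the truncated horizon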
5 changes: 4 additions & 1 deletion ding/rl_utils/td.py
@@ -690,7 +690,10 @@ def q_nstep_td_error(
if weight is None:
weight = torch.ones_like(reward)

if len(action.shape) == 1 or len(action.shape) < len(q.shape): # single agent case
    if len(action.shape) == 1 or len(action.shape) < len(q.shape):
        # unsqueeze action so it can be used to gather the corresponding Q value along the last dim of q
        # e.g. single agent case: action is [B, ] and q is [B, action_shape]
        # e.g. multi agent case: action is [B, agent_num] and q is [B, agent_num, action_shape]
action = action.unsqueeze(-1)
elif len(action.shape) > 1: # MARL case
reward = reward.unsqueeze(-1)
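A rough, standalone illustration of why the unsqueeze above is needed — the action tensor becomes a valid index for gathering Q-values along the last dimension. This mirrors only the shape handling, not the exact internals of q_nstep_td_error, and the sizes are made up:

import torch

B, agent_num, action_shape = 4, 3, 5

# single agent: action [B] picks one value per row from q [B, action_shape]
q = torch.randn(B, action_shape)
action = torch.randint(0, action_shape, (B, ))
q_s_a = q.gather(-1, action.unsqueeze(-1)).squeeze(-1)           # [B]

# multi agent: action [B, agent_num] picks per-agent values from q [B, agent_num, action_shape]
q_ma = torch.randn(B, agent_num, action_shape)
action_ma = torch.randint(0, action_shape, (B, agent_num))
q_s_a_ma = q_ma.gather(-1, action_ma.unsqueeze(-1)).squeeze(-1)  # [B, agent_num]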
10 changes: 4 additions & 6 deletions dizoo/ising_env/config/ising_mfq_config.py
@@ -8,7 +8,7 @@
agent_view_sight = 1

ising_mfq_config = dict(
exp_name='ising_mfq_seed0_debug',
exp_name='ising_mfq_seed0',
env=dict(
collector_env_num=8,
evaluator_env_num=8,
@@ -24,6 +24,7 @@
            obs_shape=obs_shape + action_shape,  # because the previous action probabilities (pre_action_prob) are concatenated into the obs
action_shape=action_shape,
encoder_hidden_size_list=[128, 128, 512],
init_bias=0,
),
nstep=3,
discount_factor=0.99,
@@ -53,7 +54,7 @@
type='ising_model',
import_names=['dizoo.ising_env.envs.ising_model_env'],
),
env_manager=dict(type='base'),
env_manager=dict(type='subprocess'),
policy=dict(type='dqn'),
)
ising_mfq_create_config = EasyDict(ising_mfq_create_config)
@@ -65,7 +66,4 @@
from ding.model import DQN
seed = 1
set_pkg_seed(seed)
model = DQN(**ising_mfq_config.policy.model)
model.head.A[-1][0].bias.data.fill_(0) # zero last layer bias
# print("init model successful")
serial_pipeline((main_config, create_config), seed=seed, model=model, max_env_step=5e4)
serial_pipeline((main_config, create_config), seed=seed, max_env_step=5e4)
19 changes: 13 additions & 6 deletions dizoo/ising_env/envs/ising_model_env.py
@@ -8,18 +8,25 @@
from ding.envs import BaseEnv, BaseEnvTimestep
from ding.torch_utils import to_ndarray, to_list
from ding.utils import ENV_REGISTRY

# ising_model_dir = os.path.join(os.path.dirname(__file__), 'ising_model')
# if ising_model_dir not in sys.path:
# sys.path.append(ising_model_dir)
# from ising_model.multiagent.environment import IsingMultiAgentEnv
# import ising_model
from dizoo.ising_env.envs.ising_model.multiagent.environment import IsingMultiAgentEnv
import dizoo.ising_env.envs.ising_model as ising_model_


@ENV_REGISTRY.register('ising_model')
class IsingModelEnv(BaseEnv):
"""
Overview:
Ising Model Environment for Multi-Agent Reinforcement Learning according to the paper: \
[Mean Field Multi-Agent Reinforcement Learning](https://arxiv.org/abs/1802.05438). \
The environment is a grid of agents, each of which can be in one of two states: \
spin up or spin down. The agents interact with their neighbors according to the Ising model, \
and the goal is to maximize the global order parameter, which is the average spin of all agents. \
Details of the environment can be found in the \
[DI-engine-Doc](https://di-engine-docs.readthedocs.io/zh-cn/latest/13_envs/index.html).
Interface:
`__init__`, `reset`, `close`, `seed`, `step`, `random_action`, `num_agents`, \
`observation_space`, `action_space`, `reward_space`.
"""

def __init__(self, cfg: dict) -> None:
self._cfg = cfg
