env(rjy): add ising model env (#782)
* env(rjy): add ising model env

* fix(rjy): modify ising for mean field RL

* fix(rjy): try to fix reward problem

* fix(rjy): fix the multi-agent reward

* polish(rjy): fix dqn net init

* polish(rjy): fix format

* polish(rjy): norm eval_episode_return

* polish(rjy): polish ising model env

* fix(rjy): fix subprocess manager

* fix(rjy): fixed reward compatibility

* polish(rjy): polish according to comments

* polish(rjy): add replay for ising

* polish(rjy): add ising replay
nighood authored Apr 23, 2024
1 parent 1ac9ad5 commit 8392206
Showing 17 changed files with 707 additions and 19 deletions.
5 changes: 2 additions & 3 deletions README.md
@@ -276,7 +276,6 @@ P.S: The `.py` file in `Runnable Demo` can be found in `dizoo`
<details open>
<summary>(Click to Collapse)</summary>


| No | Environment | Label | Visualization | Code and Doc Links |
| :-: | :--------------------------------------------------------------------------------------------: | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | :----------------------------------------------------------------------------------------------------: | :-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
| 1 | [Atari](https://github.com/openai/gym/tree/master/gym/envs/atari) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/atari/atari.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/atari/envs) <br>[env tutorial](https://di-engine-docs.readthedocs.io/en/latest/13_envs/atari.html)<br>[环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/atari_zh.html) |
@@ -316,8 +315,8 @@ P.S: The `.py` file in `Runnable Demo` can be found in `dizoo`
| 35 | [metadrive](https://github.com/metadriverse/metadrive) | ![continuous](https://img.shields.io/badge/-continous-green) | ![original](./dizoo/metadrive/metadrive_env.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/metadrive/env)<br> [环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/metadrive_zh.html) |
| 36 | [cliffwalking](https://github.com/openai/gym/blob/master/gym/envs/toy_text/cliffwalking.py) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/cliffwalking/cliff_walking.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/cliffwalking/envs)<br> env tutorial <br> 环境指南 |
| 37 | [tabmwp](https://promptpg.github.io/explore.html) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/tabmwp/tabmwp.jpeg) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/tabmwp) <br> env tutorial <br> 环境指南 |
| 38 | [frozen_lake](https://gymnasium.farama.org/environments/toy_text/frozen_lake) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/frozen_lake/FrozenLake.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/frozen_lake) <br> env tutorial <br> 环境指南 |
| 39 | [ising_model](https://github.com/mlii/mfrl/tree/master/examples/ising_model) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) ![marl](https://img.shields.io/badge/-MARL-yellow) | ![original](./dizoo/ising_env/ising_env.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/ising_env) <br> env tutorial <br> [环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/ising_model_zh.html) |

![discrete](https://img.shields.io/badge/-discrete-brightgreen) means discrete action space

25 changes: 15 additions & 10 deletions ding/model/template/q_learning.py
@@ -26,16 +26,17 @@ class DQN(nn.Module):
"""

def __init__(
self,
obs_shape: Union[int, SequenceType],
action_shape: Union[int, SequenceType],
encoder_hidden_size_list: SequenceType = [128, 128, 64],
dueling: bool = True,
head_hidden_size: Optional[int] = None,
head_layer_num: int = 1,
activation: Optional[nn.Module] = nn.ReLU(),
norm_type: Optional[str] = None,
dropout: Optional[float] = None
self,
obs_shape: Union[int, SequenceType],
action_shape: Union[int, SequenceType],
encoder_hidden_size_list: SequenceType = [128, 128, 64],
dueling: bool = True,
head_hidden_size: Optional[int] = None,
head_layer_num: int = 1,
activation: Optional[nn.Module] = nn.ReLU(),
norm_type: Optional[str] = None,
dropout: Optional[float] = None,
init_bias: Optional[float] = None,
) -> None:
"""
Overview:
@@ -55,6 +56,7 @@ def __init__(
``ding.torch_utils.fc_block`` for more details. you can choose one of ['BN', 'IN', 'SyncBN', 'LN']
- dropout (:obj:`Optional[float]`): The dropout rate of the dropout layer. \
if ``None`` then default disable dropout layer.
- init_bias (:obj:`Optional[float]`): The initial value of the last layer bias in the head network. \
if ``None`` then no extra bias initialization is applied (currently only effective for the dueling head).
"""
super(DQN, self).__init__()
# Squeeze data from tuple, list or dict to single object. For example, from (4, ) to 4
@@ -99,6 +101,9 @@ def __init__(
norm_type=norm_type,
dropout=dropout
)
if init_bias is not None and head_cls == DuelingHead:
# Initialize the last layer bias of the advantage head with init_bias
self.head.A[-1][0].bias.data.fill_(init_bias)

def forward(self, x: torch.Tensor) -> Dict:
"""
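To make the new argument concrete, here is a minimal usage sketch (assuming DI-engine is installed; the shapes are illustrative and only roughly match the Ising config below, and the forward output is expected to be a dict with a `logit` entry):

```python
import torch
from ding.model.template.q_learning import DQN

# Illustrative shapes: 4-dim neighbour state + 2-dim mean action prob, 2 discrete actions.
model = DQN(
    obs_shape=6,
    action_shape=2,
    encoder_hidden_size_list=[128, 128, 512],
    dueling=True,
    init_bias=0.,  # fill the last-layer bias of the advantage head with 0
)

obs = torch.randn(4, 6)          # batch of 4 observations
out = model(obs)
print(out['logit'].shape)        # expected: torch.Size([4, 2])
```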
13 changes: 11 additions & 2 deletions ding/policy/common_utils.py
@@ -61,8 +61,17 @@ def default_preprocess_learn(
reward = data['reward']
if len(reward.shape) == 1:
reward = reward.unsqueeze(1)
# reward: (batch_size, nstep) -> (nstep, batch_size)
data['reward'] = reward.permute(1, 0).contiguous()
# single agent reward: (batch_size, nstep) -> (nstep, batch_size)
# multi-agent reward: (batch_size, agent_dim, nstep) -> (nstep, batch_size, agent_dim)
# Assuming 'reward' is a PyTorch tensor with shape (batch_size, nstep) or (batch_size, agent_dim, nstep)
if reward.ndim == 2:
# For a 2D tensor, simply transpose it to get (nstep, batch_size)
data['reward'] = reward.transpose(0, 1).contiguous()
elif reward.ndim == 3:
# For a 3D tensor, move the last dimension to the front to get (nstep, batch_size, agent_dim)
data['reward'] = reward.permute(2, 0, 1).contiguous()
else:
raise ValueError("The 'reward' tensor must be either 2D or 3D. Got shape: {}".format(reward.shape))
else:
if data['reward'].dim() == 2 and data['reward'].shape[1] == 1:
data['reward'] = data['reward'].squeeze(-1)
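As a quick sanity check of the two branches above, a minimal sketch with illustrative sizes (`batch_size=4`, `agent_num=100`, `nstep=3`):

```python
import torch

batch_size, agent_num, nstep = 4, 100, 3

# Single-agent: (batch_size, nstep) -> (nstep, batch_size)
reward = torch.randn(batch_size, nstep)
print(reward.transpose(0, 1).contiguous().shape)   # torch.Size([3, 4])

# Multi-agent: (batch_size, agent_num, nstep) -> (nstep, batch_size, agent_num)
reward = torch.randn(batch_size, agent_num, nstep)
print(reward.permute(2, 0, 1).contiguous().shape)  # torch.Size([3, 4, 100])
```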
10 changes: 7 additions & 3 deletions ding/rl_utils/adder.py
@@ -122,7 +122,7 @@ def get_nstep_return_data(
"""
if nstep == 1:
return data
fake_reward = torch.zeros(1)
fake_reward = torch.zeros_like(data[0]['reward'])
next_obs_flag = 'next_obs' in data[0]
for i in range(len(data) - nstep):
# update keys ['next_obs', 'reward', 'done'] with their n-step value
@@ -131,7 +131,10 @@
if cum_reward:
data[i]['reward'] = sum([data[i + j]['reward'] * (gamma ** j) for j in range(nstep)])
else:
data[i]['reward'] = torch.cat([data[i + j]['reward'] for j in range(nstep)])
# data[i]['reward'].shape: (1, ) for a single-agent env or (agent_num, 1) for a multi-agent env
# single-agent env: shape (1, ) -> (nstep, )
# multi-agent env: shape (agent_num, 1) -> (agent_num, nstep)
data[i]['reward'] = torch.cat([data[i + j]['reward'] for j in range(nstep)], dim=-1)
data[i]['done'] = data[i + nstep - 1]['done']
if correct_terminate_gamma:
data[i]['value_gamma'] = gamma ** nstep
@@ -143,7 +146,8 @@
else:
data[i]['reward'] = torch.cat(
[data[i + j]['reward']
for j in range(len(data) - i)] + [fake_reward for _ in range(nstep - (len(data) - i))]
for j in range(len(data) - i)] + [fake_reward for _ in range(nstep - (len(data) - i))],
dim=-1
)
data[i]['done'] = data[-1]['done']
if correct_terminate_gamma:
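A toy illustration of what the `dim=-1` concatenation preserves for the two reward layouts described in the comments (sizes are illustrative). Padding the episode tail with `torch.zeros_like(data[0]['reward'])` instead of `torch.zeros(1)` keeps the same per-agent shape, so the same concatenation works near episode ends:

```python
import torch

nstep = 3

# Single-agent: each step's reward has shape (1,), the n-step stack has shape (nstep,)
data = [{'reward': torch.tensor([float(i)])} for i in range(5)]
stacked = torch.cat([data[j]['reward'] for j in range(nstep)], dim=-1)
print(stacked.shape)  # torch.Size([3])

# Multi-agent: each step's reward has shape (agent_num, 1), the stack has shape (agent_num, nstep)
agent_num = 100
data = [{'reward': torch.zeros(agent_num, 1)} for _ in range(5)]
stacked = torch.cat([data[j]['reward'] for j in range(nstep)], dim=-1)
print(stacked.shape)  # torch.Size([100, 3])
```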
7 changes: 6 additions & 1 deletion ding/rl_utils/td.py
@@ -266,6 +266,8 @@ def nstep_return(data: namedtuple, gamma: Union[float, list], nstep: int, value_
if value_gamma is None:
return_ = return_tmp + (gamma ** nstep) * next_value * (1 - done)
else:
value_gamma = view_similar(value_gamma, next_value)
done = view_similar(done, next_value)
return_ = return_tmp + value_gamma * next_value * (1 - done)

elif isinstance(gamma, list):
@@ -688,7 +690,10 @@ def q_nstep_td_error(
if weight is None:
weight = torch.ones_like(reward)

if len(action.shape) == 1: # single agent case
if len(action.shape) == 1 or len(action.shape) < len(q.shape):
# unsqueeze action so that it can index q along the last dimension
# e.g. single-agent case: action is [B, ] and q is [B, action_shape]
# e.g. multi agent case: action is [B, agent_num] and q is [B, agent_num, action_shape]
action = action.unsqueeze(-1)
elif len(action.shape) > 1: # MARL case
reward = reward.unsqueeze(-1)
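A shape-only sketch of why `action` gets the extra trailing dimension in the multi-agent case: after `unsqueeze(-1)` it can gather per-agent Q-values along the last axis. The sizes are illustrative, and this mirrors (rather than copies) the gather pattern used inside `q_nstep_td_error`:

```python
import torch

B, agent_num, action_shape = 4, 100, 2

# Multi-agent case: q is (B, agent_num, action_shape), action is (B, agent_num)
q = torch.randn(B, agent_num, action_shape)
action = torch.randint(0, action_shape, (B, agent_num))

# After unsqueeze, action is (B, agent_num, 1) and can index q along the last dim
q_a = q.gather(-1, action.unsqueeze(-1)).squeeze(-1)
print(q_a.shape)  # torch.Size([4, 100])
```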
Empty file added dizoo/ising_env/__init__.py
Empty file.
68 changes: 68 additions & 0 deletions dizoo/ising_env/config/ising_mfq_config.py
@@ -0,0 +1,68 @@
from easydict import EasyDict
from ding.utils import set_pkg_seed

obs_shape = 4
action_shape = 2
num_agents = 100
dim_spin = 2
agent_view_sight = 1

ising_mfq_config = dict(
exp_name='ising_mfq_seed0',
env=dict(
collector_env_num=8,
evaluator_env_num=8,
n_evaluator_episode=8,
num_agents=num_agents,
dim_spin=dim_spin,
agent_view_sight=agent_view_sight,
manager=dict(shared_memory=False, ),
),
policy=dict(
cuda=True,
priority=False,
model=dict(
obs_shape=obs_shape + action_shape,  # the previous action probabilities (pre_action_prob) are concatenated into the obs
action_shape=action_shape,
encoder_hidden_size_list=[128, 128, 512],
init_bias=0,
),
nstep=3,
discount_factor=0.99,
learn=dict(
update_per_collect=10,
batch_size=32,
learning_rate=0.0001,
target_update_freq=500,
),
collect=dict(n_sample=96, ),
eval=dict(evaluator=dict(eval_freq=1000, )),
other=dict(
eps=dict(
type='exp',
start=1.,
end=0.05,
decay=250000,
),
replay_buffer=dict(replay_buffer_size=100000, ),
),
),
)
ising_mfq_config = EasyDict(ising_mfq_config)
main_config = ising_mfq_config
ising_mfq_create_config = dict(
env=dict(
type='ising_model',
import_names=['dizoo.ising_env.envs.ising_model_env'],
),
env_manager=dict(type='subprocess'),
policy=dict(type='dqn'),
)
ising_mfq_create_config = EasyDict(ising_mfq_create_config)
create_config = ising_mfq_create_config

if __name__ == '__main__':
# or you can enter `ding -m serial -c ising_mfq_config.py -s 0`
from ding.entry import serial_pipeline
seed = 1
serial_pipeline((main_config, create_config), seed=seed, max_env_step=5e4)
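The `obs_shape + action_shape` line reflects the mean-field Q-learning setup: each agent observes its neighbours' spins together with a mean of the previous action probabilities. A rough sketch of that concatenation, under the assumption of one-hot previous actions and a plain global average for brevity (variable names are illustrative, not the env's API):

```python
import numpy as np

num_agents, action_shape, obs_shape = 100, 2, 4

# Illustrative: one-hot previous actions of all agents, shape (num_agents, action_shape)
prev_action_onehot = np.eye(action_shape)[np.random.randint(0, action_shape, num_agents)]

# Mean previous action probability (here a global average; the env uses each agent's neighbours)
mean_action_prob = prev_action_onehot.mean(axis=0)                 # shape (action_shape,)

neighbour_obs = np.random.randint(0, 2, (num_agents, obs_shape))   # spins of the 4 neighbours
agent_obs = np.concatenate(
    [neighbour_obs, np.tile(mean_action_prob, (num_agents, 1))], axis=1
)
print(agent_obs.shape)  # (100, 6) == (num_agents, obs_shape + action_shape)
```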
15 changes: 15 additions & 0 deletions dizoo/ising_env/entry/ising_mfq_eval.py
@@ -0,0 +1,15 @@
from dizoo.ising_env.config.ising_mfq_config import main_config, create_config
from ding.entry import eval


def main():
main_config.env.collector_env_num = 1
main_config.env.evaluator_env_num = 1
main_config.env.n_evaluator_episode = 1
ckpt_path = './ckpt_best.pth.tar'
replay_path = './replay_videos'
eval((main_config, create_config), seed=1, load_path=ckpt_path, replay_path=replay_path)


if __name__ == "__main__":
main()
1 change: 1 addition & 0 deletions dizoo/ising_env/envs/__init__.py
@@ -0,0 +1 @@
from .ising_model_env import IsingModelEnv
114 changes: 114 additions & 0 deletions dizoo/ising_env/envs/ising_model/Ising.py
@@ -0,0 +1,114 @@
import numpy as np

from dizoo.ising_env.envs.ising_model.multiagent.core import IsingWorld, IsingAgent


class Scenario():

def _calc_mask(self, agent, shape_size):
# compute the neighbour mask for each agent
if agent.view_sight == -1:
# fully observed
agent.spin_mask += 1
elif agent.view_sight == 0:
# observe itself
agent.spin_mask[agent.state.id] = 1
elif agent.view_sight > 0:
# observe neighbours
delta = list(range(-int(agent.view_sight), int(agent.view_sight) + 1, 1))
delta.remove(0) # agent itself is not counted as neighbour of itself
for dt in delta:
row = agent.state.p_pos[0]
col = agent.state.p_pos[1]
row_dt = row + dt
col_dt = col + dt
if row_dt in range(0, shape_size):
agent.spin_mask[agent.state.id + shape_size * dt] = 1
if col_dt in range(0, shape_size):
agent.spin_mask[agent.state.id + dt] = 1

# the lattice is cyclic: the leftmost and rightmost cells are neighbours
if agent.state.p_pos[0] < agent.view_sight:
tar = shape_size - (np.array(range(0, int(agent.view_sight - agent.state.p_pos[0]), 1)) + 1)
tar = tar * shape_size + agent.state.p_pos[1]
agent.spin_mask[tar] = [1] * len(tar)

if agent.state.p_pos[1] < agent.view_sight:
tar = shape_size - (np.array(range(0, int(agent.view_sight - agent.state.p_pos[1]), 1)) + 1)
tar = agent.state.p_pos[0] * shape_size + tar
agent.spin_mask[tar] = [1] * len(tar)

if agent.state.p_pos[0] >= shape_size - agent.view_sight:
tar = np.array(range(0, int(agent.view_sight - (shape_size - 1 - agent.state.p_pos[0])), 1))
tar = tar * shape_size + agent.state.p_pos[1]
agent.spin_mask[tar] = [1] * len(tar)

if agent.state.p_pos[1] >= shape_size - agent.view_sight:
tar = np.array(range(0, int(agent.view_sight - (shape_size - 1 - agent.state.p_pos[1])), 1))
tar = agent.state.p_pos[0] * shape_size + tar
agent.spin_mask[tar] = [1] * len(tar)

def make_world(self, num_agents=100, agent_view=1):
world = IsingWorld()
world.agent_view_sight = agent_view
world.dim_spin = 2
world.dim_pos = 2
world.n_agents = num_agents
world.shape_size = int(np.ceil(np.power(num_agents, 1.0 / world.dim_pos)))
world.global_state = np.zeros((world.shape_size, ) * world.dim_pos)
# assume 0 external magnetic field
world.field = np.zeros((world.shape_size, ) * world.dim_pos)

world.agents = [IsingAgent(view_sight=world.agent_view_sight) for i in range(num_agents)]

# make initial conditions
self.reset_world(world)

return world

def reset_world(self, world):

world_mat = np.array(
range(np.power(world.shape_size, world.dim_pos))). \
reshape((world.shape_size,) * world.dim_pos)
# init agent state and global state
for i, agent in enumerate(world.agents):
agent.name = 'agent %d' % i
agent.color = np.array([0.35, 0.35, 0.85])
agent.state.id = i
agent.state.p_pos = np.where(world_mat == i)
agent.state.spin = np.random.choice(world.dim_spin)
agent.spin_mask = np.zeros(world.n_agents)

assert world.dim_pos == 2, "cyclic neighbour only support 2D now"
self._calc_mask(agent, world.shape_size)
world.global_state[agent.state.p_pos] = agent.state.spin

n_ups = np.count_nonzero(world.global_state.flatten())
n_downs = world.n_agents - n_ups
world.order_param = abs(n_ups - n_downs) / (world.n_agents + 0.0)

def reward(self, agent, world):
# turn the state into -1/1 for easy computing
world.global_state[np.where(world.global_state == 0)] = -1

mask_display = agent.spin_mask.reshape((int(np.sqrt(world.n_agents)), -1))

local_reward = - 0.5 * world.global_state[agent.state.p_pos] \
* np.sum(world.global_state.flatten() * agent.spin_mask)

world.global_state[np.where(world.global_state == -1)] = 0
return -local_reward

def observation(self, agent, world):
# get positions of all entities in this agent's reference frame
# agent state is updated in the world.step() function already
# update the changes of the world

# return the neighbour state
return world.global_state.flatten()[np.where(agent.spin_mask == 1)]

def done(self, agent, world):
if world.order_param == 1.0:
return True
return False
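For reference, a small NumPy sketch of the two quantities the scenario computes: the order parameter checked by `done` and the local alignment reward returned by `reward`. It uses a toy 4x4 lattice with a hand-built neighbour mask and is not the env's own code:

```python
import numpy as np

size = 4
spins01 = np.random.randint(0, 2, (size, size))      # global_state with spins in {0, 1}

# Order parameter: |#ups - #downs| / n_agents; 1.0 means all spins are aligned (done)
n_ups = np.count_nonzero(spins01)
order_param = abs(n_ups - (size * size - n_ups)) / (size * size)

# Local reward for the agent at (1, 1) with its 4-neighbour mask, spins mapped to {-1, 1}
spins = np.where(spins01 == 0, -1, 1)
mask = np.zeros((size, size))
mask[0, 1] = mask[2, 1] = mask[1, 0] = mask[1, 2] = 1
local_energy = -0.5 * spins[1, 1] * np.sum(spins * mask)
reward = -local_energy   # aligned neighbours -> positive reward
print(order_param, reward)
```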