
Commit ad8c53c

Author: nighood
fix(rjy): modify ising for mean field RL
1 parent 98b81a0 commit ad8c53c

File tree

3 files changed: +93 -4 lines
ising_mfq_config.py (new file)

+65

@@ -0,0 +1,65 @@
+from easydict import EasyDict
+
+obs_shape = 4
+action_shape = 2
+num_agents = 100
+dim_spin = 2
+agent_view_sight = 1
+
+ising_mfq_config = dict(
+    exp_name='ising_mfq_seed0',
+    env=dict(
+        collector_env_num=8,
+        evaluator_env_num=8,
+        n_evaluator_episode=8,
+        stop_value=20,
+        num_agents=num_agents,
+        dim_spin=dim_spin,
+        agent_view_sight=agent_view_sight,
+    ),
+    policy=dict(
+        cuda=True,
+        priority=False,
+        model=dict(
+            obs_shape=obs_shape + action_shape,  # we will concat the pre_action_prob into obs
+            action_shape=action_shape,
+            encoder_hidden_size_list=[128, 128, 512],
+        ),
+        nstep=3,
+        discount_factor=0.99,
+        learn=dict(
+            update_per_collect=10,
+            batch_size=32,
+            learning_rate=0.0001,
+            target_update_freq=500,
+        ),
+        collect=dict(n_sample=96, ),
+        eval=dict(evaluator=dict(eval_freq=4000, )),
+        other=dict(
+            eps=dict(
+                type='exp',
+                start=1.,
+                end=0.05,
+                decay=250000,
+            ),
+            replay_buffer=dict(replay_buffer_size=100000, ),
+        ),
+    ),
+)
+ising_mfq_config = EasyDict(ising_mfq_config)
+main_config = ising_mfq_config
+ising_mfq_create_config = dict(
+    env=dict(
+        type='ising_model',
+        import_names=['dizoo.ising_env.envs.ising_model_env'],
+    ),
+    env_manager=dict(type='base'),
+    policy=dict(type='dqn'),
+)
+ising_mfq_create_config = EasyDict(ising_mfq_create_config)
+create_config = ising_mfq_create_config
+
+if __name__ == '__main__':
+    # or you can enter `ding -m serial -c ising_mfq_config.py -s 0`
+    from ding.entry import serial_pipeline
+    serial_pipeline((main_config, create_config), seed=0)
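
Note on model.obs_shape above: in this mean-field setting the environment appends the previous mean action probability of the other agents (action_shape = 2 values) to each agent's 4-dim local spin observation, so the DQN encoder actually sees a 6-dim input (see the env diff below). A minimal numpy sketch of that concatenation; illustrative only, the variable names here are not from the commit:

import numpy as np

num_agents, obs_shape, action_shape = 100, 4, 2
spin_obs = np.zeros((num_agents, obs_shape), dtype=np.float32)   # per-agent local spin observation
pre_action_prob = np.zeros((num_agents, action_shape))           # mean action prob of the other agents
full_obs = np.concatenate([spin_obs, pre_action_prob], axis=1)
assert full_obs.shape == (num_agents, obs_shape + action_shape)  # (100, 6), matching model.obs_shape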

dizoo/ising_env/envs/ising_model_env.py

+25 -1
@@ -28,6 +28,20 @@ def __init__(self, cfg: dict) -> None:
         self._observation_space = gym.spaces.MultiBinary(4 * cfg.agent_view_sight)
         self._reward_space = gym.spaces.Box(low=float("-inf"), high=float("inf"), shape=(1, ), dtype=np.float32)
 
+    def calculate_action_prob(self, actions):
+        num_action = self._action_space.n
+        N = actions.shape[0]
+        # Convert actions to one_hot encoding
+        one_hot_actions = np.eye(num_action)[actions.flatten()]
+        action_prob = np.zeros((N, num_action))
+
+        for i in range(N):
+            # Exclude agent i's actions and calculate the one_hot average of all other agent actions
+            exclude_current = np.delete(one_hot_actions, i, axis=0)
+            action_prob[i] = exclude_current.mean(axis=0)
+
+        return action_prob
+
     def reset(self) -> np.ndarray:
         if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed:
             np_seed = 100 * np.random.randint(1, 1000)

@@ -47,6 +61,9 @@ def reset(self) -> np.ndarray:
             self._init_flag = True
         obs = self._env._reset()
         obs = np.stack(obs)
+        self.pre_action = np.zeros(self._cfg.num_agents, dtype=np.int32)
+        pre_action_prob = np.zeros((self._cfg.num_agents, self._action_space.n))
+        obs = np.concatenate([obs, pre_action_prob], axis=1)
         obs = to_ndarray(obs).astype(np.float32)
         self._eval_episode_return = np.zeros((self._cfg.num_agents, 1), dtype=np.float32)
         return obs

@@ -63,13 +80,20 @@ def seed(self, seed: int, dynamic_seed: bool = True) -> None:
 
     def step(self, action: Union[np.ndarray, list]) -> BaseEnvTimestep:
         action = to_ndarray(action)
+        if (len(action.shape) == 1):
+            action = np.expand_dims(action, axis=1)
         obs, rew, done, order_param, ups, downs = self._env._step(action)
-        info = {"order_param": order_param, "ups": ups, "downs": downs}
+        info = {"order_param": order_param, "ups": ups, "downs": downs, 'pre_action': self.pre_action}
+        pre_action_prob = self.calculate_action_prob(self.pre_action)
+        self.pre_action = action
         obs = np.stack(obs)
+        obs = np.concatenate([obs, pre_action_prob], axis=1)
         obs = to_ndarray(obs).astype(np.float32)
         rew = np.stack(rew)
         rew = to_ndarray(rew).astype(np.float32)
         self._eval_episode_return += rew
+
+        done = done[0]  # dones are the same for all agents
         if done:
             info['eval_episode_return'] = self._eval_episode_return
         return BaseEnvTimestep(obs, rew, done, info)
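
On calculate_action_prob above: each row i of the returned matrix is the empirical action distribution of all agents except agent i, i.e. the mean-field signal that gets concatenated onto agent i's observation. A vectorized equivalent is sketched here for reference; it is not part of the commit, and mean_field_action_prob is an illustrative name:

import numpy as np

def mean_field_action_prob(actions: np.ndarray, num_action: int) -> np.ndarray:
    # For each agent i, average the one-hot actions of all the other agents.
    n = actions.shape[0]
    one_hot = np.eye(num_action)[actions.flatten()]   # (N, num_action)
    totals = one_hot.sum(axis=0, keepdims=True)       # sum of one-hot actions over all agents
    return (totals - one_hot) / (n - 1)               # drop agent i's own action, then average

# Quick check against the loop-based definition used in the diff above.
actions = np.random.randint(0, 2, size=(5, 1))
loop_ref = np.stack([
    np.delete(np.eye(2)[actions.flatten()], i, axis=0).mean(axis=0) for i in range(5)
])
assert np.allclose(mean_field_action_prob(actions, 2), loop_ref)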

dizoo/ising_env/envs/test_ising_model.py renamed to dizoo/ising_env/envs/test_ising_model_env.py

+3 -3
@@ -9,12 +9,12 @@
 @pytest.mark.envtest
 class TestIsingModelEnv:
 
-    def test_ising(self):
+    def test_ising():
         env = IsingModelEnv(EasyDict({'num_agents': num_agents, 'dim_spin': 2, 'agent_view_sight': 1}))
         env.seed(314, dynamic_seed=False)
         assert env._seed == 314
         obs = env.reset()
-        assert obs.shape == (100, 4)
+        assert obs.shape == (100, 4 + 2)
         for _ in range(5):
             env.reset()
             np.random.seed(314)

@@ -31,7 +31,7 @@ def test_ising(self):
             print('timestep', timestep, '\n')
             assert isinstance(timestep.obs, np.ndarray)
             assert isinstance(timestep.done[0], bool)
-            assert timestep.obs.shape == (100, 4)
+            assert timestep.obs.shape == (100, 4 + 2)
             assert timestep.reward.shape == (100, 1)
             assert timestep.reward[0] >= env.reward_space.low
             assert timestep.reward[0] <= env.reward_space.high
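
The two shape assertions change from (100, 4) to (100, 4 + 2) because each agent's observation now carries the 2-dim mean action probability of the other agents. A standalone version of the updated reset check, as a sketch assuming DI-engine with this commit is installed:

from easydict import EasyDict
from dizoo.ising_env.envs.ising_model_env import IsingModelEnv

env = IsingModelEnv(EasyDict({'num_agents': 100, 'dim_spin': 2, 'agent_view_sight': 1}))
env.seed(314, dynamic_seed=False)
obs = env.reset()
assert obs.shape == (100, 4 + 2)  # 4-dim local observation + 2-dim mean action probability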
