
Commit 7ac3e7d

polish(rjy): fix dqn net init

Author: nighood
1 parent: f10f952

4 files changed (+18 -14 lines)


ding/model/common/head.py (-1)

@@ -61,7 +61,6 @@ def __init__(
                 norm_type=norm_type
             ), block(hidden_size, output_size)
         )
-        nn.init.normal_(self.Q[1].weight, 0, 0.2)
 
     def forward(self, x: torch.Tensor) -> Dict:
         """

dizoo/ising_env/config/ising_mfq_config.py (+9 -2)

@@ -1,4 +1,5 @@
 from easydict import EasyDict
+from ding.utils import set_pkg_seed
 
 obs_shape = 4
 action_shape = 2
@@ -7,7 +8,7 @@
 agent_view_sight = 1
 
 ising_mfq_config = dict(
-    exp_name='ising_mfq_seed0',
+    exp_name='ising_mfq_seed0_debug',
     env=dict(
         collector_env_num=8,
         evaluator_env_num=8,
@@ -61,4 +62,10 @@
 if __name__ == '__main__':
     # or you can enter `ding -m serial -c ising_mfq_config.py -s 0`
     from ding.entry import serial_pipeline
-    serial_pipeline((main_config, create_config), seed=0, max_env_step=5e4)
+    from ding.model import DQN
+    seed = 1
+    set_pkg_seed(seed)
+    model = DQN(**ising_mfq_config.policy.model)
+    model.head.A[-1][0].bias.data.fill_(0)  # zero last layer bias
+    # print("init model successful")
+    serial_pipeline((main_config, create_config), seed=seed, model=model, max_env_step=5e4)
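
The entry now builds the DQN outside serial_pipeline so the extra head init runs before training starts. A sketch of the same pattern in isolation; the shapes are placeholders, and the head.A[-1][0] indexing assumes (as the config's own usage implies) a dueling head whose advantage branch A ends in a block whose first module is the final nn.Linear:

from ding.model import DQN

model = DQN(obs_shape=6, action_shape=2)  # dueling head is DI-engine's default

# Mirror the zero-bias init above on the last advantage layer.
model.head.A[-1][0].bias.data.fill_(0)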

dizoo/ising_env/envs/ising_model_env.py (+4 -6)

@@ -30,7 +30,7 @@ def __init__(self, cfg: dict) -> None:
 
     def calculate_action_prob(self, actions):
         num_action = self._action_space.n
-        N = actions.shape[0] # agent_num
+        N = actions.shape[0]  # agent_num
         # Convert actions to one_hot encoding
         one_hot_actions = np.eye(num_action)[actions.flatten()]
         action_prob = np.zeros((N, num_action))
@@ -84,9 +84,9 @@ def seed(self, seed: int, dynamic_seed: bool = True) -> None:
         self._dynamic_seed = dynamic_seed
         np.random.seed(self._seed)
 
-    def step(self, action: Union[np.ndarray, list]) -> BaseEnvTimestep:
+    def step(self, action: np.ndarray) -> BaseEnvTimestep:
         action = to_ndarray(action)
-        if (len(action.shape) == 1):
+        if len(action.shape) == 1:
             action = np.expand_dims(action, axis=1)
         obs, rew, done, order_param, ups, downs = self._env._step(action)
         info = {"order_param": order_param, "ups": ups, "downs": downs, 'pre_action': self.pre_action}
@@ -95,9 +95,7 @@ def step(self, action: Union[np.ndarray, list]) -> BaseEnvTimestep:
         obs = np.stack(obs)
         obs = np.concatenate([obs, pre_action_prob], axis=1)
         obs = to_ndarray(obs).astype(np.float32)
-        rew = np.stack(rew)
-        rew = np.squeeze(to_ndarray(rew).astype(np.float32), axis=1)
-        # rew = to_ndarray(rew).astype(np.float32)
+        rew = np.concatenate(rew)
         self._eval_episode_return += np.sum(rew)
 
         done = done[0]  # dones are the same for all agents
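
The reward fix only changes shape handling: the wrapped env returns a list of per-agent rewards, and np.concatenate flattens them to (N, ) in one step instead of the old stack-then-squeeze detour through (N, 1). A quick equivalence check, assuming each per-agent reward is a shape-(1, ) array:

import numpy as np

rew = [np.array([0.5]), np.array([-0.5]), np.array([1.0])]  # assumed per-agent layout

old = np.squeeze(np.stack(rew).astype(np.float32), axis=1)  # (3, 1) -> (3, )
new = np.concatenate(rew)                                   # (3, ) directly
assert old.shape == new.shape == (3, )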

dizoo/ising_env/envs/test_ising_model_env.py (+5 -5)

@@ -9,12 +9,12 @@
 @pytest.mark.envtest
 class TestIsingModelEnv:
 
-    def test_ising():
+    def test_ising(self):
         env = IsingModelEnv(EasyDict({'num_agents': num_agents, 'dim_spin': 2, 'agent_view_sight': 1}))
         env.seed(314, dynamic_seed=False)
         assert env._seed == 314
         obs = env.reset()
-        assert obs.shape == (100, 4 + 2)
+        assert obs.shape == (num_agents, 4 + 2)
         for _ in range(5):
             env.reset()
             np.random.seed(314)
@@ -30,9 +30,9 @@ def test_ising():
             timestep = env.step(random_action)
             print('timestep', timestep, '\n')
             assert isinstance(timestep.obs, np.ndarray)
-            assert isinstance(timestep.done[0], bool)
-            assert timestep.obs.shape == (100, 4 + 2)
-            assert timestep.reward.shape == (100, 1)
+            assert isinstance(timestep.done, bool)
+            assert timestep.obs.shape == (num_agents, 4 + 2)
+            assert timestep.reward.shape == (num_agents, )
             assert timestep.reward[0] >= env.reward_space.low
             assert timestep.reward[0] <= env.reward_space.high
             print(env.observation_space, env.action_space, env.reward_space)
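
Two notes on these test fixes. First, pytest instantiates TestIsingModelEnv and calls test_ising as a bound method, so the old zero-argument signature raised "TypeError: test_ising() takes 0 positional arguments but 1 was given"; adding self fixes it, as in this minimal sketch with hypothetical names:

import pytest


class TestExample:

    # pytest calls this as a bound method on a fresh instance,
    # so the first parameter must be self.
    def test_case(self):
        assert 1 + 1 == 2

Second, the assertions now track the env change above: done is a scalar bool, reward has shape (num_agents, ), and the hard-coded 100 is replaced by the num_agents constant so the test follows its own config.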
