
Commit 7451f46

fix(nyz): fix unittest bugs
1 parent 6e139b6 commit 7451f46

5 files changed: +26 -30 lines

ding/entry/tests/test_serial_entry_preference_based_irl_onpolicy.py (+3 -3)

@@ -15,16 +15,16 @@
 
 @pytest.mark.unittest
 def test_serial_pipeline_trex_onpolicy():
-    exp_name = 'test_serial_pipeline_trex_onpolicy_expert'
+    exp_name = 'trex_onpolicy_test_serial_pipeline_trex_onpolicy_expert'
     config = [deepcopy(cartpole_ppo_config), deepcopy(cartpole_ppo_create_config)]
     config[0].policy.learn.learner.hook.save_ckpt_after_iter = 100
     config[0].exp_name = exp_name
     expert_policy = serial_pipeline_onpolicy(config, seed=0)
 
-    exp_name = 'test_serial_pipeline_trex_onpolicy_collect'
+    exp_name = 'trex_onpolicy_test_serial_pipeline_trex_onpolicy_collect'
     config = [deepcopy(cartpole_trex_ppo_onpolicy_config), deepcopy(cartpole_trex_ppo_onpolicy_create_config)]
     config[0].exp_name = exp_name
-    config[0].reward_model.expert_model_path = 'test_serial_pipeline_trex_onpolicy_expert'
+    config[0].reward_model.expert_model_path = 'trex_onpolicy_test_serial_pipeline_trex_onpolicy_expert'
    config[0].reward_model.checkpoint_max = 100
    config[0].reward_model.checkpoint_step = 100
    config[0].reward_model.num_snippets = 100
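
A minimal sketch of the constraint this rename preserves: the directory written by the expert run (its `exp_name`) must be the same path the TREX reward model is later pointed at via `reward_model.expert_model_path`. The `ckpt` subdirectory layout below is an assumption for illustration only, not taken from this diff.

```python
import os

# Stage 1 writes checkpoints under its exp_name directory.
expert_exp_name = 'trex_onpolicy_test_serial_pipeline_trex_onpolicy_expert'

# Stage 2 must read from that same directory.
expert_model_path = 'trex_onpolicy_test_serial_pipeline_trex_onpolicy_expert'
assert expert_model_path == expert_exp_name

# Assumed layout: the save_ckpt_after_iter=100 hook saves checkpoints somewhere under here.
print(os.path.join(expert_model_path, 'ckpt'))
```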

ding/envs/env/tests/test_ding_env_wrapper.py (+1 -1)

@@ -181,7 +181,7 @@ def test_hybrid(self):
         print('random_action', action)
         assert isinstance(action, dict)
 
-    @pytest.mark.unittest
+    @pytest.mark.envtest
     def test_AllinObsWrapper(self):
         env_cfg = EasyDict(env_id='PongNoFrameskip-v4', env_wrapper='reward_in_obs')
         ding_env_aio = DingEnvWrapper(cfg=env_cfg)
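
A brief note on the marker swap (marker names come from the diff; how the suite is invoked is an assumption based on common pytest usage): moving the Pong-based wrapper test from `unittest` to `envtest` keeps it out of runs that select only unit-test markers, since it needs a real Atari environment.

```python
import pytest

# Sketch only: a test carrying the envtest marker is excluded by `pytest -m unittest`
# and selected by `pytest -m envtest`.
@pytest.mark.envtest
def test_requires_real_env():
    assert True
```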

ding/framework/tests/test_parallel.py (+1 -1)

@@ -24,7 +24,7 @@ def test_callback(key):
     time.sleep(0.7)
 
 
-@pytest.mark.unittest
+@pytest.mark.tmp
 def test_parallel_run():
     Parallel.runner(n_parallel_workers=2, startup_interval=0.1)(parallel_main)
     Parallel.runner(n_parallel_workers=2, protocol="tcp", startup_interval=0.1)(parallel_main)
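
For readers unfamiliar with the framework entry exercised here, a self-contained sketch of the call pattern from this test. The keyword arguments mirror the diff; the import path and the placeholder worker body are assumptions for illustration.

```python
from ding.framework import Parallel

def parallel_main():
    # Placeholder worker; the real body lives in test_parallel.py.
    pass

if __name__ == '__main__':
    # Mirrors the two invocations in the test: default protocol, then TCP,
    # each launching 2 worker processes.
    Parallel.runner(n_parallel_workers=2, startup_interval=0.1)(parallel_main)
    Parallel.runner(n_parallel_workers=2, protocol="tcp", startup_interval=0.1)(parallel_main)
```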

ding/model/template/tests/test_decision_transformer.py (+3 -6)

@@ -4,7 +4,7 @@
 import torch.nn.functional as F
 
 from ding.model.template import DecisionTransformer
-from ding.torch_utils import is_differentiable, one_hot
+from ding.torch_utils import is_differentiable
 
 args = ['continuous', 'discrete']
 
@@ -23,6 +23,7 @@ def test_decision_transformer(action_space):
         context_len=T,
         n_heads=2,
         drop_p=0.1,
+        continuous=(action_space == 'continuous')
     )
 
     is_continuous = True if action_space == 'continuous' else False
@@ -40,15 +41,11 @@ def test_decision_transformer(action_space):
     # all ones since no padding
     traj_mask = torch.ones([B, T], dtype=torch.long)  # B x T
 
-    # if discrete
-    if not is_continuous:
-        actions = one_hot(actions.squeeze(-1), num=act_dim)
-
-        assert actions.shape == (B, T, act_dim)
     if is_continuous:
         assert action_target.shape == (B, T, act_dim)
     else:
         assert action_target.shape == (B, T, 1)
+        actions = actions.squeeze(-1)
 
     returns_to_go = returns_to_go.float()
     state_preds, action_preds, return_preds = DT_model.forward(
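
A quick shape sketch of what the discrete branch now feeds the model: raw action indices instead of one-hot vectors. The values below are illustrative; only the shapes and the squeeze come from the diff.

```python
import torch

B, T, act_dim = 4, 6, 3
actions = torch.randint(0, act_dim, (B, T, 1))   # index actions, no one-hot encoding any more
action_target = torch.clone(actions).detach()
assert action_target.shape == (B, T, 1)

actions = actions.squeeze(-1)                    # what the transformer now consumes
assert actions.shape == (B, T)
```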

ding/policy/dt.py (+18 -19)

@@ -69,8 +69,10 @@ def _init_learn(self) -> None:
         self.act_dim = self._cfg.model.act_dim
 
         self._learn_model = self._model
+        self._atari_env = 'state_mean' not in self._cfg
+        self._basic_discrete_env = not self._cfg.model.continuous and 'state_mean' in self._cfg
 
-        if 'state_mean' not in self._cfg:
+        if self._atari_env:
             self._optimizer = self._learn_model.configure_optimizers(wt_decay, lr)
         else:
             self._optimizer = torch.optim.AdamW(self._learn_model.parameters(), lr=lr, weight_decay=wt_decay)
@@ -93,22 +95,18 @@ def _forward_learn(self, data: list) -> Dict[str, Any]:
         self._learn_model.train()
 
         timesteps, states, actions, returns_to_go, traj_mask = data
-        if actions.dtype is not torch.long:
-            actions = actions.to(torch.long)
-        action_target = torch.clone(actions).detach().to(self._device)
 
         # The shape of `returns_to_go` may differ with different dataset (B x T or B x T x 1),
         # and we need a 3-dim tensor
         if len(returns_to_go.shape) == 2:
             returns_to_go = returns_to_go.unsqueeze(-1)
 
-        # if discrete
-        if not self._cfg.model.continuous and 'state_mean' in self._cfg:
-            # actions = one_hot(actions.squeeze(-1), num=self.act_dim)
+        if self._basic_discrete_env:
+            actions = actions.to(torch.long)
             actions = actions.squeeze(-1)
         action_target = torch.clone(actions).detach().to(self._device)
 
-        if 'state_mean' not in self._cfg:
+        if self._atari_env:
             state_preds, action_preds, return_preds = self._learn_model.forward(
                 timesteps=timesteps, states=states, actions=actions, returns_to_go=returns_to_go, tar=1
             )
@@ -117,7 +115,7 @@ def _forward_learn(self, data: list) -> Dict[str, Any]:
                 timesteps=timesteps, states=states, actions=actions, returns_to_go=returns_to_go
             )
 
-        if 'state_mean' not in self._cfg:
+        if self._atari_env:
             action_loss = F.cross_entropy(action_preds.reshape(-1, action_preds.size(-1)), action_target.reshape(-1))
         else:
             traj_mask = traj_mask.view(-1, )
@@ -171,7 +169,9 @@ def _init_eval(self) -> None:
         self.actions = torch.zeros(
             (self.eval_batch_size, self.max_eval_ep_len, 1), dtype=torch.long, device=self._device
         )
-        if 'state_mean' not in self._cfg:
+        self._atari_env = 'state_mean' not in self._cfg
+        self._basic_discrete_env = not self._cfg.model.continuous and 'state_mean' in self._cfg
+        if self._atari_env:
             self.states = torch.zeros(
                 (
                     self.eval_batch_size,
@@ -201,7 +201,7 @@ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]:
 
         self._eval_model.eval()
         with torch.no_grad():
-            if 'state_mean' not in self._cfg:
+            if self._atari_env:
                 states = torch.zeros(
                     (
                         self.eval_batch_size,
@@ -228,15 +228,15 @@ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]:
                 (self.eval_batch_size, self.context_len, 1), dtype=torch.float32, device=self._device
             )
             for i in data_id:
-                if 'state_mean' not in self._cfg:
+                if self._atari_env:
                     self.states[i, self.t[i]] = data[i]['obs'].to(self._device)
                 else:
                     self.states[i, self.t[i]] = (data[i]['obs'].to(self._device) - self.state_mean) / self.state_std
                 self.running_rtg[i] = self.running_rtg[i] - data[i]['reward'].to(self._device)
                 self.rewards_to_go[i, self.t[i]] = self.running_rtg[i]
 
                 if self.t[i] <= self.context_len:
-                    if 'state_mean' not in self._cfg:
+                    if self._atari_env:
                         timesteps[i] = min(self.t[i], self._cfg.model.max_timestep) * torch.ones(
                             (1, 1), dtype=torch.int64
                         ).to(self._device)
@@ -246,7 +246,7 @@ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]:
                     actions[i] = self.actions[i, :self.context_len]
                     rewards_to_go[i] = self.rewards_to_go[i, :self.context_len]
                 else:
-                    if 'state_mean' not in self._cfg:
+                    if self._atari_env:
                         timesteps[i] = min(self.t[i], self._cfg.model.max_timestep) * torch.ones(
                             (1, 1), dtype=torch.int64
                         ).to(self._device)
@@ -255,15 +255,14 @@ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]:
                         states[i] = self.states[i, self.t[i] - self.context_len + 1:self.t[i] + 1]
                         actions[i] = self.actions[i, self.t[i] - self.context_len + 1:self.t[i] + 1]
                         rewards_to_go[i] = self.rewards_to_go[i, self.t[i] - self.context_len + 1:self.t[i] + 1]
-            if not self._cfg.model.continuous and 'state_mean' in self._cfg:
-                # actions = one_hot(actions.squeeze(-1), num=self.act_dim)
+            if self._basic_discrete_env:
                 actions = actions.squeeze(-1)
             _, act_preds, _ = self._eval_model.forward(timesteps, states, actions, rewards_to_go)
             del timesteps, states, actions, rewards_to_go
 
             logits = act_preds[:, -1, :]
             if not self._cfg.model.continuous:
-                if 'state_mean' not in self._cfg:
+                if self._atari_env:
                     probs = F.softmax(logits, dim=-1)
                     act = torch.zeros((self.eval_batch_size, 1), dtype=torch.long, device=self._device)
                     for i in data_id:
@@ -297,7 +296,7 @@ def _reset_eval(self, data_id: List[int] = None) -> None:
                 dtype=torch.float32,
                 device=self._device
             )
-            if 'state_mean' not in self._cfg:
+            if self._atari_env:
                 self.states = torch.zeros(
                     (
                         self.eval_batch_size,
@@ -327,7 +326,7 @@ def _reset_eval(self, data_id: List[int] = None) -> None:
                 self.actions[i] = torch.zeros(
                     (self.max_eval_ep_len, self.act_dim), dtype=torch.float32, device=self._device
                 )
-                if 'state_mean' not in self._cfg:
+                if self._atari_env:
                     self.states[i] = torch.zeros(
                         (self.max_eval_ep_len, ) + tuple(self.state_dim), dtype=torch.float32, device=self._device
                     )
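
A condensed sketch of the flag logic this refactor introduces: the repeated `'state_mean' not in self._cfg` checks are collapsed into `self._atari_env`, and the discrete non-Atari case into `self._basic_discrete_env`. The helper function and the example configs below are illustrative only, not part of the policy API.

```python
from easydict import EasyDict

def derive_flags(cfg):
    """Mirror of the two flags added in dt.py (illustrative helper)."""
    # Atari-style configs carry raw observations and no 'state_mean' key.
    atari_env = 'state_mean' not in cfg
    # "Basic" discrete envs normalize states (state_mean present) and use index actions,
    # which must be cast to long and squeezed before entering the transformer.
    basic_discrete_env = (not cfg.model.continuous) and ('state_mean' in cfg)
    return atari_env, basic_discrete_env

# Hypothetical configs for illustration:
atari_like = EasyDict(dict(model=dict(continuous=False)))                  # no state_mean key
basic_like = EasyDict(dict(model=dict(continuous=False), state_mean=0.0))  # normalized-state env
print(derive_flags(atari_like))  # (True, False)
print(derive_flags(basic_like))  # (False, True)
```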
