Commit 0be9555

fix(nyz): fix dt dataset compatibility bug
1 parent 08c42fa commit 0be9555

12 files changed: +23 -397 lines

README.md (+1 -1)

@@ -74,7 +74,7 @@ It provides **python-first** and **asynchronous-native** task and middleware abs
   - [ACE](https://github.com/opendilab/ACE): [AAAI 2023] ACE: Cooperative Multi-agent Q-learning with Bidirectional Action-Dependency
   - [GoBigger](https://github.com/opendilab/GoBigger): [ICLR 2023] Multi-Agent Decision Intelligence Environment
   - [DOS](https://github.com/opendilab/DOS): [CVPR 2023] ReasonNet: End-to-End Driving with Temporal and Global Reasoning
-  - [LightZero](https://github.com/opendilab/LightZero): LightZero: A lightweight and efficient MCTS/AlphaZero/MuZero algorithm toolkit
+  - [LightZero](https://github.com/opendilab/LightZero): A lightweight and efficient MCTS/AlphaZero/MuZero algorithm toolkit
 - Docs and Tutorials
   - [DI-engine-docs](https://github.com/opendilab/DI-engine-docs): Tutorials, best practice and the API reference.
   - [awesome-model-based-RL](https://github.com/opendilab/awesome-model-based-RL): A curated list of awesome Model-Based RL resources

assets/wechat.jpeg (297 KB, binary file)

ding/entry/serial_entry_decision_transformer.py (-88)

This file was deleted.

ding/example/dt.py (+1 -1)

@@ -1,6 +1,6 @@
 import gym
 from ditk import logging
-from ding.model.template.decision_transformer import DecisionTransformer
+from ding.model.template.dt import DecisionTransformer
 from ding.policy import DTPolicy
 from ding.envs import DingEnvWrapper, BaseEnvManager, BaseEnvManagerV2
 from ding.envs.env_wrappers.env_wrappers import AllinObsWrapper

ding/utils/data/dataset.py (+12 -10)

@@ -150,16 +150,18 @@ def __len__(self) -> int:
         return len(self._data['obs']) - self.context_len

     def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
-        # return {k: self._data[k][idx] for k in self._data.keys()}
-        block_size = self.context_len
-        done_idx = idx + block_size
-        idx = done_idx - block_size
-        states = torch.as_tensor(np.array(self._data['obs'][idx:done_idx]), dtype=torch.float32).view(block_size, -1)
-        actions = torch.as_tensor(self._data['action'][idx:done_idx], dtype=torch.long)
-        rtgs = torch.as_tensor(self._data['reward'][idx:done_idx, 0], dtype=torch.float32)
-        timesteps = torch.as_tensor(range(idx, done_idx), dtype=torch.int64)
-        traj_mask = torch.ones(self.context_len, dtype=torch.long)
-        return timesteps, states, actions, rtgs, traj_mask
+        if self.context_len == 0:  # for other offline RL algorithms
+            return {k: self._data[k][idx] for k in self._data.keys()}
+        else:  # for decision transformer
+            block_size = self.context_len
+            done_idx = idx + block_size
+            idx = done_idx - block_size
+            states = torch.as_tensor(np.array(self._data['obs'][idx:done_idx]), dtype=torch.float32).view(block_size, -1)
+            actions = torch.as_tensor(self._data['action'][idx:done_idx], dtype=torch.long)
+            rtgs = torch.as_tensor(self._data['reward'][idx:done_idx, 0], dtype=torch.float32)
+            timesteps = torch.as_tensor(range(idx, done_idx), dtype=torch.int64)
+            traj_mask = torch.ones(self.context_len, dtype=torch.long)
+            return timesteps, states, actions, rtgs, traj_mask

     def _load_data(self, dataset: Dict[str, np.ndarray]) -> None:
         self._data = {}
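The compatibility fix is the branch on `self.context_len`: with `context_len == 0` the dataset behaves like a plain per-transition offline dataset and returns a dict sample, while a positive `context_len` returns the `(timesteps, states, actions, rtgs, traj_mask)` tuple the decision transformer expects. The following is a minimal standalone sketch of that branching; the toy data shapes and the free-standing `getitem` function are illustrative only, not the DI-engine class in `ding/utils/data/dataset.py`.

    import numpy as np
    import torch

    # Toy stand-in for loaded offline data: 100 transitions of a 4-dim observation.
    data = {
        'obs': np.random.randn(100, 4).astype(np.float32),
        'action': np.random.randint(0, 2, size=(100, )),
        'reward': np.random.randn(100, 1).astype(np.float32),
    }

    def getitem(data, idx, context_len):
        # context_len == 0: plain per-transition sample for other offline RL algorithms.
        if context_len == 0:
            return {k: data[k][idx] for k in data.keys()}
        # context_len > 0: a fixed-length trajectory block for the decision transformer.
        done_idx = idx + context_len
        states = torch.as_tensor(np.array(data['obs'][idx:done_idx]), dtype=torch.float32).view(context_len, -1)
        actions = torch.as_tensor(data['action'][idx:done_idx], dtype=torch.long)
        rtgs = torch.as_tensor(data['reward'][idx:done_idx, 0], dtype=torch.float32)
        timesteps = torch.as_tensor(range(idx, done_idx), dtype=torch.int64)
        traj_mask = torch.ones(context_len, dtype=torch.long)
        return timesteps, states, actions, rtgs, traj_mask

    print(getitem(data, 5, context_len=0)['obs'].shape)   # (4,): single transition
    print(getitem(data, 5, context_len=20)[1].shape)       # torch.Size([20, 4]): state block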

ding/utils/pytorch_ddp_dist_helper.py (+8 -2)

@@ -181,6 +181,12 @@ def simple_group_split(world_size: int, rank: int, num_groups: int) -> List:

 def to_ddp_config(cfg: EasyDict) -> EasyDict:
     w = get_world_size()
-    cfg.policy.batch_size = int(np.ceil(cfg.policy.batch_size / w))
-    # cfg.policy.collect.n_sample = int(np.ceil(cfg.policy.collect.n_sample) / w)
+    if 'batch_size' in cfg.policy:
+        cfg.policy.batch_size = int(np.ceil(cfg.policy.batch_size / w))
+    if 'batch_size' in cfg.policy.learn:
+        cfg.policy.learn.batch_size = int(np.ceil(cfg.policy.learn.batch_size / w))
+    if 'n_sample' in cfg.policy.collect:
+        cfg.policy.collect.n_sample = int(np.ceil(cfg.policy.collect.n_sample / w))
+    if 'n_episode' in cfg.policy.collect:
+        cfg.policy.collect.n_episode = int(np.ceil(cfg.policy.collect.n_episode / w))
     return cfg
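The companion fix in `to_ddp_config` guards every per-rank rescaling with a key check, so a config that only defines `policy.learn.batch_size` (as the DT configs do) no longer needs a top-level `policy.batch_size` or `collect.n_sample`. Below is a minimal standalone sketch of the same guarded rescaling; the function name `to_ddp_config_sketch`, the explicit `world_size` argument (the real helper reads it from `get_world_size()`), and the toy config are assumptions for illustration.

    import numpy as np
    from easydict import EasyDict

    def to_ddp_config_sketch(cfg: EasyDict, world_size: int) -> EasyDict:
        # Only rescale keys the policy config actually defines, so configs without
        # e.g. collect.n_sample (such as purely offline DT configs) are left alone.
        if 'batch_size' in cfg.policy:
            cfg.policy.batch_size = int(np.ceil(cfg.policy.batch_size / world_size))
        if 'batch_size' in cfg.policy.learn:
            cfg.policy.learn.batch_size = int(np.ceil(cfg.policy.learn.batch_size / world_size))
        if 'n_sample' in cfg.policy.collect:
            cfg.policy.collect.n_sample = int(np.ceil(cfg.policy.collect.n_sample / world_size))
        if 'n_episode' in cfg.policy.collect:
            cfg.policy.collect.n_episode = int(np.ceil(cfg.policy.collect.n_episode / world_size))
        return cfg

    # A DT-style offline config: only learn.batch_size is set, no collect.n_sample.
    cfg = EasyDict(dict(policy=dict(learn=dict(batch_size=64), collect=dict())))
    cfg = to_ddp_config_sketch(cfg, world_size=4)
    print(cfg.policy.learn.batch_size)  # 16, i.e. 64 split across 4 ranks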

dizoo/atari/config/serial/pong/pong_dt_config.py (-5)

@@ -66,8 +66,3 @@
 )
 Pong_dt_create_config = EasyDict(Pong_dt_create_config)
 create_config = Pong_dt_create_config
-
-if __name__ == "__main__":
-    from ding.entry import serial_pipeline_dt
-    config = deepcopy([main_config, create_config])
-    serial_pipeline_dt(config, seed=0, max_train_iter=1000)

dizoo/box2d/lunarlander/config/lunarlander_dt_config.py (+1 -6)

@@ -39,7 +39,7 @@
         learn=dict(
             dataset_path='DI-engine/dizoo/box2d/lunarlander/offline_data/dt_data/dqn_data_1000eps.pkl',  # TODO
             learning_rate=3e-4,
-            batch_size=64, # training batch size
+            batch_size=64,  # training batch size
             target_update_freq=100,
         ),
         collect=dict(
@@ -62,8 +62,3 @@
 )
 lunarlander_dt_create_config = EasyDict(lunarlander_dt_create_config)
 create_config = lunarlander_dt_create_config
-
-if __name__ == "__main__":
-    from ding.entry import serial_pipeline_dt, collect_demo_data, eval, serial_pipeline
-    config = deepcopy([main_config, create_config])
-    serial_pipeline_dt(config, seed=0, max_train_iter=1000)

dizoo/classic_control/cartpole/config/cartpole_dt_config.py (-65)

This file was deleted.

dizoo/classic_control/cartpole/offline_data/cartpole_collect_data.py (-33)

This file was deleted.

dizoo/classic_control/cartpole/offline_data/cartpole_show_data.py (-50)

This file was deleted.
