
Commit 3a73dd4

refactor(lyd): refactor dt_policy in new pipeline and add img input support (#693)
* Revise old version dt pipeline
* Add new dt pipeline
* Add DT in new pipeline
* Add img input to support atari
* Fix according to comment
* Fix dt config files
* Fix abs path
* Accelerate DT train iter by replacing dataloader
* Simplify dt model and policy and config
* Reformat
* Reformat
* Change data fetcher func to class
* Add threading shift data to gpu
* Change action sample func
* Add configure optimizers
* Add multi gpu support
* Add dt policy test serial
* Fix multi gpu support and data fetcher
* Reformat
1 parent bc3ecd9 commit 3a73dd4

File tree

109 files changed (+2217, -1445 lines)

Note: large commits have some content hidden by default, so not every changed file is shown in full below.


ding/entry/tests/test_serial_entry.py

+65
@@ -45,6 +45,7 @@
 from dizoo.classic_control.pendulum.config.pendulum_cql_config import pendulum_cql_config, pendulum_cql_create_config  # noqa
 from dizoo.classic_control.cartpole.config.cartpole_qrdqn_generation_data_config import cartpole_qrdqn_generation_data_config, cartpole_qrdqn_generation_data_create_config  # noqa
 from dizoo.classic_control.cartpole.config.cartpole_cql_config import cartpole_discrete_cql_config, cartpole_discrete_cql_create_config  # noqa
+from dizoo.classic_control.cartpole.config.cartpole_dt_config import cartpole_discrete_dt_config, cartpole_discrete_dt_create_config  # noqa
 from dizoo.classic_control.pendulum.config.pendulum_td3_data_generation_config import pendulum_td3_generation_config, pendulum_td3_generation_create_config  # noqa
 from dizoo.classic_control.pendulum.config.pendulum_td3_bc_config import pendulum_td3_bc_config, pendulum_td3_bc_create_config  # noqa
 from dizoo.classic_control.pendulum.config.pendulum_ibc_config import pendulum_ibc_config, pendulum_ibc_create_config
@@ -621,6 +622,70 @@ def test_discrete_cql():
     os.popen('rm -rf cartpole cartpole_cql')


+@pytest.mark.platformtest
+@pytest.mark.unittest
+def test_discrete_dt():
+    # train expert
+    config = [deepcopy(cartpole_qrdqn_config), deepcopy(cartpole_qrdqn_create_config)]
+    config[0].policy.learn.update_per_collect = 1
+    config[0].exp_name = 'dt_cartpole'
+    try:
+        serial_pipeline(config, seed=0, max_train_iter=1)
+    except Exception:
+        assert False, "pipeline fail"
+    # collect expert data
+    import torch
+    config = [deepcopy(cartpole_qrdqn_generation_data_config), deepcopy(cartpole_qrdqn_generation_data_create_config)]
+    state_dict = torch.load('./dt_cartpole/ckpt/iteration_0.pth.tar', map_location='cpu')
+    try:
+        collect_demo_data(config, seed=0, collect_count=1000, state_dict=state_dict)
+    except Exception as e:
+        # log the collection error before failing the test
+        print(repr(e))
+        assert False, "pipeline fail"
+
+    # train dt
+    config = [deepcopy(cartpole_discrete_dt_config), deepcopy(cartpole_discrete_dt_create_config)]
+    config[0].policy.eval.evaluator.eval_freq = 5
+    try:
+        from ding.framework import task
+        from ding.framework.context import OfflineRLContext
+        from ding.envs import SubprocessEnvManagerV2, BaseEnvManagerV2
+        from ding.envs.env_wrappers.env_wrappers import AllinObsWrapper
+        from dizoo.classic_control.cartpole.envs import CartPoleEnv
+        from ding.utils import set_pkg_seed
+        from ding.data import create_dataset
+        from ding.config import compile_config
+        from ding.model.template.dt import DecisionTransformer
+        from ding.policy import DTPolicy
+        from ding.framework.middleware import interaction_evaluator, trainer, CkptSaver, \
+            offline_data_fetcher_from_mem_c, offline_logger, termination_checker
+        config = compile_config(config[0], create_cfg=config[1], auto=True)
+        with task.start(async_mode=False, ctx=OfflineRLContext()):
+            evaluator_env = BaseEnvManagerV2(
+                env_fn=[lambda: AllinObsWrapper(CartPoleEnv(config.env)) for _ in range(config.env.evaluator_env_num)],
+                cfg=config.env.manager
+            )
+
+            set_pkg_seed(config.seed, use_cuda=config.policy.cuda)
+
+            dataset = create_dataset(config)
+
+            model = DecisionTransformer(**config.policy.model)
+            policy = DTPolicy(config.policy, model=model)
+
+            task.use(termination_checker(max_train_iter=1))
+            task.use(interaction_evaluator(config, policy.eval_mode, evaluator_env))
+            task.use(offline_data_fetcher_from_mem_c(config, dataset))
+            task.use(trainer(config, policy.learn_mode))
+            task.use(CkptSaver(policy, config.exp_name, train_freq=100))
+            task.use(offline_logger(config.exp_name))
+            task.run()
+    except Exception:
+        assert False, "pipeline fail"
+    finally:
+        os.popen('rm -rf cartpole cartpole_dt')


 @pytest.mark.platformtest
 @pytest.mark.unittest
 def test_td3_bc():
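In outline, test_discrete_dt exercises the whole offline flow introduced by this commit: it trains a QRDQN expert for a single iteration, reloads that checkpoint to collect 1000 transitions with collect_demo_data, and then trains DT on the collected data through the new task pipeline, pulling batches with offline_data_fetcher_from_mem_c.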

ding/envs/env/tests/test_ding_env_wrapper.py

+19
@@ -180,3 +180,22 @@ def test_hybrid(self):
         action = ding_env_hybrid.random_action()
         print('random_action', action)
         assert isinstance(action, dict)
+
+    @pytest.mark.unittest
+    def test_AllinObsWrapper(self):
+        env_cfg = EasyDict(env_id='PongNoFrameskip-v4', env_wrapper='reward_in_obs')
+        ding_env_aio = DingEnvWrapper(cfg=env_cfg)
+
+        data = ding_env_aio.reset()
+        assert isinstance(data, dict)
+        assert 'obs' in data.keys() and 'reward' in data.keys()
+        assert data['obs'].shape == ding_env_aio.observation_space
+        while True:
+            action = ding_env_aio.random_action()
+            timestep = ding_env_aio.step(action)
+            # print(timestep.reward)
+            assert isinstance(timestep.obs, dict)
+            if timestep.done:
+                assert 'eval_episode_return' in timestep.info, timestep.info
+                break
+        print(ding_env_aio.observation_space, ding_env_aio.action_space, ding_env_aio.reward_space)
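Note that env_wrapper='reward_in_obs' is the registry key under which AllinObsWrapper is registered (see ding/envs/env_wrappers/env_wrappers.py below), so DingEnvWrapper can construct the wrapped env from config alone.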

ding/envs/env_wrappers/env_wrappers.py

+37
@@ -1215,6 +1215,43 @@ def reset(self):
         return self.env.reset()


+@ENV_WRAPPER_REGISTRY.register('reward_in_obs')
+class AllinObsWrapper(gym.Wrapper):
+    """
+    Overview:
+        This wrapper is used by the DT policy. It sets a dict
+        {'obs': obs, 'reward': reward} as the new wrapped observation,
+        which includes the current obs and the previous reward.
+    Interface:
+        ``__init__``, ``reset``, ``step``, ``seed``
+    Properties:
+        - env (:obj:`gym.Env`): the environment to wrap.
+    """
+
+    def __init__(self, env):
+        super().__init__(env)
+
+    def reset(self):
+        ret = {'obs': self.env.reset(), 'reward': np.array([0])}
+        self._observation_space = gym.spaces.Dict(
+            {
+                'obs': self.env.observation_space,
+                'reward': gym.spaces.Box(low=-np.inf, high=np.inf, dtype=np.float32)
+            }
+        )
+        return ret
+
+    def step(self, action):
+        obs, reward, done, info = self.env.step(action)
+        obs = {'obs': obs, 'reward': reward}
+        from ding.envs import BaseEnvTimestep
+        return BaseEnvTimestep(obs, reward, done, info)
+
+    def seed(self, seed: int, dynamic_seed: bool = True) -> None:
+        self.env.seed(seed, dynamic_seed)
+
+
 def update_shape(obs_shape, act_shape, rew_shape, wrapper_names):
     """
     Overview:
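For reference, a minimal usage sketch of the wrapper's contract; this is our own illustration rather than part of the commit, and it assumes the classic gym<0.26 reset/step API that this module targets (the env id is arbitrary):

import gym
from ding.envs.env_wrappers.env_wrappers import AllinObsWrapper

env = AllinObsWrapper(gym.make('CartPole-v0'))
data = env.reset()
# the wrapped observation is a dict; the initial 'reward' entry is np.array([0])
assert set(data.keys()) == {'obs', 'reward'}
timestep = env.step(env.action_space.sample())
# step() returns a BaseEnvTimestep whose obs carries the previous reward alongside the current obs
assert set(timestep.obs.keys()) == {'obs', 'reward'}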

ding/example/dt.py

+47
@@ -0,0 +1,47 @@
+import gym
+from ditk import logging
+from ding.model.template.decision_transformer import DecisionTransformer
+from ding.policy import DTPolicy
+from ding.envs import DingEnvWrapper, BaseEnvManager, BaseEnvManagerV2
+from ding.envs.env_wrappers.env_wrappers import AllinObsWrapper
+from ding.data import create_dataset
+from ding.config import compile_config
+from ding.framework import task, ding_init
+from ding.framework.context import OfflineRLContext
+from ding.framework.middleware import interaction_evaluator, trainer, CkptSaver, \
+    offline_data_fetcher, offline_logger, termination_checker, final_ctx_saver
+from ding.utils import set_pkg_seed
+from dizoo.box2d.lunarlander.envs.lunarlander_env import LunarLanderEnv
+from dizoo.box2d.lunarlander.config.lunarlander_dt_config import main_config, create_config
+
+
+def main():
+    # If you don't have offline data, you need to prepare it first and set the data_path in config.
+    # For demonstration, we can also train an RL policy (e.g. SAC) and collect some data.
+    logging.getLogger().setLevel(logging.INFO)
+    cfg = compile_config(main_config, create_cfg=create_config, auto=True)
+    ding_init(cfg)
+    with task.start(async_mode=False, ctx=OfflineRLContext()):
+        evaluator_env = BaseEnvManagerV2(
+            env_fn=[lambda: AllinObsWrapper(LunarLanderEnv(cfg.env)) for _ in range(cfg.env.evaluator_env_num)],
+            cfg=cfg.env.manager
+        )
+
+        set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+
+        dataset = create_dataset(cfg)
+        cfg.policy.state_mean, cfg.policy.state_std = dataset.get_state_stats()
+        model = DecisionTransformer(**cfg.policy.model)
+        policy = DTPolicy(cfg.policy, model=model)
+
+        task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
+        task.use(offline_data_fetcher(cfg, dataset))
+        task.use(trainer(cfg, policy.learn_mode))
+        task.use(termination_checker(max_train_iter=1e5))
+        task.use(CkptSaver(policy, cfg.exp_name, train_freq=100))
+        task.use(offline_logger())
+        task.run()
+
+
+if __name__ == "__main__":
+    main()
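A rough usage note: assuming lunarlander_dt_config's data_path points at a previously collected offline dataset, running python ding/example/dt.py trains DT for up to 1e5 iterations, evaluating with the wrapped LunarLander env and saving a checkpoint every 100 training iterations.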

ding/framework/context.py

+1
@@ -82,6 +82,7 @@ class OfflineRLContext(Context):

     # common
     total_step: int = 0
+    env_step: int = 0
     train_epoch: int = 0
     train_iter: int = 0
     train_data: Union[Dict, List] = None

ding/framework/middleware/__init__.py

+1
@@ -4,3 +4,4 @@
 from .ckpt_handler import CkptSaver
 from .distributer import ContextExchanger, ModelExchanger, PeriodicalModelExchanger
 from .barrier import Barrier, BarrierRuntime
+from .data_fetcher import offline_data_fetcher_from_mem_c
ding/framework/middleware/data_fetcher.py

+100

@@ -0,0 +1,100 @@
+from typing import TYPE_CHECKING
+from threading import Thread, Event
+from queue import Queue
+import time
+import torch
+import torch.distributed as dist
+from easydict import EasyDict
+from ding.framework import task
+from ding.data import Dataset, DataLoader
+from ding.utils import get_rank
+import numpy as np
+
+if TYPE_CHECKING:
+    from ding.framework import OfflineRLContext
+
+
+class offline_data_fetcher_from_mem_c:
+
+    def __new__(cls, *args, **kwargs):
+        if task.router.is_active and not task.has_role(task.role.FETCHER):
+            return task.void()
+        return super(offline_data_fetcher_from_mem_c, cls).__new__(cls)
+
+    def __init__(self, cfg: EasyDict, dataset: Dataset):
+        device = 'cuda:{}'.format(get_rank() % torch.cuda.device_count()) if cfg.policy.cuda else 'cpu'
+        if device != 'cpu':
+            stream = torch.cuda.Stream()
+
+        def producer(queue, dataset, batch_size, device, event):
+            torch.set_num_threads(4)
+            if device != 'cpu':
+                nonlocal stream
+            sbatch_size = batch_size * dist.get_world_size()
+            rank = get_rank()
+            idx_list = np.random.permutation(len(dataset))
+            temp_idx_list = []
+            for i in range(len(dataset) // sbatch_size):
+                temp_idx_list.extend(idx_list[i + rank * batch_size:i + (rank + 1) * batch_size])
+            idx_iter = iter(temp_idx_list)
+
+            if device != 'cpu':
+                with torch.cuda.stream(stream):
+                    while True:
+                        if queue.full():
+                            time.sleep(0.1)
+                        else:
+                            data = []
+                            for _ in range(batch_size):
+                                try:
+                                    data.append(dataset.__getitem__(next(idx_iter)))
+                                except StopIteration:
+                                    del idx_iter
+                                    idx_list = np.random.permutation(len(dataset))
+                                    idx_iter = iter(idx_list)
+                                    data.append(dataset.__getitem__(next(idx_iter)))
+                            data = [[i[j] for i in data] for j in range(len(data[0]))]
+                            data = [torch.stack(x).to(device) for x in data]
+                            queue.put(data)
+                        if event.is_set():
+                            break
+            else:
+                while True:
+                    if queue.full():
+                        time.sleep(0.1)
+                    else:
+                        data = []
+                        for _ in range(batch_size):
+                            try:
+                                data.append(dataset.__getitem__(next(idx_iter)))
+                            except StopIteration:
+                                del idx_iter
+                                idx_list = np.random.permutation(len(dataset))
+                                idx_iter = iter(idx_list)
+                                data.append(dataset.__getitem__(next(idx_iter)))
+                        data = [[i[j] for i in data] for j in range(len(data[0]))]
+                        data = [torch.stack(x) for x in data]
+                        queue.put(data)
+                    if event.is_set():
+                        break
+
+        self.queue = Queue(maxsize=50)
+        self.event = Event()
+        self.producer_thread = Thread(
+            target=producer,
+            args=(self.queue, dataset, cfg.policy.batch_size, device, self.event),
+            name='cuda_fetcher_producer'
+        )
+
+    def __call__(self, ctx: "OfflineRLContext"):
+        if not self.producer_thread.is_alive():
+            time.sleep(5)
+            self.producer_thread.start()
+        while self.queue.empty():
+            time.sleep(0.001)
+        ctx.train_data = self.queue.get()
+
+    def __del__(self):
+        if self.producer_thread.is_alive():
+            self.event.set()
+            del self.queue
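Design note: the class-based fetcher keeps a long-lived producer thread that indexes the in-memory dataset directly, stacks each batch, and (when CUDA is enabled) moves it to the GPU on a dedicated CUDA stream so host-to-device copies overlap with training. The bounded Queue(maxsize=50) provides backpressure, the Event lets __del__ stop the thread cleanly, and in the multi-GPU case each rank draws its own slice of a shared random permutation sized by dist.get_world_size(). This is the class form of the thread-based prefetching that, per the commit message, replaces the previous dataloader.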

ding/framework/middleware/functional/__init__.py

+1 -1

@@ -1,6 +1,6 @@
 from .trainer import trainer, multistep_trainer
 from .data_processor import offpolicy_data_fetcher, data_pusher, offline_data_fetcher, offline_data_saver, \
-    sqil_data_pusher, buffer_saver
+    offline_data_fetcher_from_mem, sqil_data_pusher, buffer_saver
 from .collector import inferencer, rolloutor, TransitionList
 from .evaluator import interaction_evaluator, interaction_evaluator_ttorch
 from .termination_checker import termination_checker, ddp_termination_checker

ding/framework/middleware/functional/data_processor.py

+46 -1

@@ -6,6 +6,7 @@
 from ding.data import Buffer, Dataset, DataLoader, offline_data_save_type
 from ding.data.buffer.middleware import PriorityExperienceReplay
 from ding.framework import task
+from ding.utils import get_rank

 if TYPE_CHECKING:
     from ding.framework import OnlineRLContext, OfflineRLContext
@@ -180,6 +181,51 @@ def _fetch(ctx: "OnlineRLContext"):
     return _fetch


+def offline_data_fetcher_from_mem(cfg: EasyDict, dataset: Dataset) -> Callable:
+
+    from threading import Thread
+    from queue import Queue
+    import time
+    stream = torch.cuda.Stream()
+
+    def producer(queue, dataset, batch_size, device):
+        torch.set_num_threads(4)
+        nonlocal stream
+        idx_iter = iter(range(len(dataset)))
+        with torch.cuda.stream(stream):
+            while True:
+                if queue.full():
+                    time.sleep(0.1)
+                else:
+                    try:
+                        start_idx = next(idx_iter)
+                    except StopIteration:
+                        del idx_iter
+                        idx_iter = iter(range(len(dataset)))
+                        start_idx = next(idx_iter)
+                    data = [dataset.__getitem__(idx) for idx in range(start_idx, start_idx + batch_size)]
+                    data = [[i[j] for i in data] for j in range(len(data[0]))]
+                    data = [torch.stack(x).to(device) for x in data]
+                    queue.put(data)
+
+    queue = Queue(maxsize=50)
+    device = 'cuda:{}'.format(get_rank() % torch.cuda.device_count()) if cfg.policy.cuda else 'cpu'
+    producer_thread = Thread(
+        target=producer, args=(queue, dataset, cfg.policy.batch_size, device), name='cuda_fetcher_producer'
+    )
+
+    def _fetch(ctx: "OfflineRLContext"):
+        nonlocal queue, producer_thread
+        if not producer_thread.is_alive():
+            time.sleep(5)
+            producer_thread.start()
+        while queue.empty():
+            time.sleep(0.001)
+        ctx.train_data = queue.get()
+
+    return _fetch
+
+
 def offline_data_fetcher(cfg: EasyDict, dataset: Dataset) -> Callable:
     """
     Overview:
@@ -208,7 +254,6 @@ def _fetch(ctx: "OfflineRLContext"):
         for i, data in enumerate(dataloader):
             ctx.train_data = data
             yield
-        ctx.train_epoch += 1
         # TODO apply data update (e.g. priority) in offline setting when necessary

     return _fetch