|
45 | 45 | from dizoo.classic_control.pendulum.config.pendulum_cql_config import pendulum_cql_config, pendulum_cql_create_config # noqa
|
46 | 46 | from dizoo.classic_control.cartpole.config.cartpole_qrdqn_generation_data_config import cartpole_qrdqn_generation_data_config, cartpole_qrdqn_generation_data_create_config # noqa
|
47 | 47 | from dizoo.classic_control.cartpole.config.cartpole_cql_config import cartpole_discrete_cql_config, cartpole_discrete_cql_create_config # noqa
|
| 48 | +from dizoo.classic_control.cartpole.config.cartpole_dt_config import cartpole_discrete_dt_config, cartpole_discrete_dt_create_config # noqa |
48 | 49 | from dizoo.classic_control.pendulum.config.pendulum_td3_data_generation_config import pendulum_td3_generation_config, pendulum_td3_generation_create_config # noqa
|
49 | 50 | from dizoo.classic_control.pendulum.config.pendulum_td3_bc_config import pendulum_td3_bc_config, pendulum_td3_bc_create_config # noqa
|
50 | 51 | from dizoo.classic_control.pendulum.config.pendulum_ibc_config import pendulum_ibc_config, pendulum_ibc_create_config
|
@@ -621,6 +622,70 @@ def test_discrete_cql():
|
621 | 622 | os.popen('rm -rf cartpole cartpole_cql')
|
622 | 623 |
|
623 | 624 |
|
@pytest.mark.platformtest
@pytest.mark.unittest
def test_discrete_dt():
    """
    Overview:
        Smoke test for the discrete Decision Transformer (DT) pipeline on CartPole. It runs three \
        stages in sequence: train a QRDQN expert for one iteration, collect 1000 expert samples with \
        the saved checkpoint, then run the DT task pipeline for one training iteration.
    Raises:
        - AssertionError: With message ``"pipeline fail"`` when any of the three stages raises.
    """
    # The outer ``finally`` guarantees that artifacts are removed no matter which stage fails,
    # including failures of ``torch.load`` that happen outside the per-stage ``try`` blocks.
    try:
        # train expert
        config = [deepcopy(cartpole_qrdqn_config), deepcopy(cartpole_qrdqn_create_config)]
        config[0].policy.learn.update_per_collect = 1
        config[0].exp_name = 'dt_cartpole'
        try:
            serial_pipeline(config, seed=0, max_train_iter=1)
        except Exception:
            assert False, "pipeline fail"

        # collect expert data
        import torch
        config = [
            deepcopy(cartpole_qrdqn_generation_data_config),
            deepcopy(cartpole_qrdqn_generation_data_create_config)
        ]
        # Checkpoint path matches the ``exp_name`` assigned to the expert config above.
        state_dict = torch.load('./dt_cartpole/ckpt/iteration_0.pth.tar', map_location='cpu')
        try:
            collect_demo_data(config, seed=0, collect_count=1000, state_dict=state_dict)
        except Exception as e:
            # Print the cause BEFORE failing; a statement placed after ``assert False`` never runs.
            print(repr(e))
            assert False, "pipeline fail"

        # train dt
        config = [deepcopy(cartpole_discrete_dt_config), deepcopy(cartpole_discrete_dt_create_config)]
        config[0].policy.eval.evaluator.eval_freq = 5
        try:
            # Imports stay inside the ``try`` so that an import error is also reported as a pipeline failure.
            from ding.framework import task
            from ding.framework.context import OfflineRLContext
            from ding.envs import SubprocessEnvManagerV2, BaseEnvManagerV2
            from ding.envs.env_wrappers.env_wrappers import AllinObsWrapper
            from dizoo.classic_control.cartpole.envs import CartPoleEnv
            from ding.utils import set_pkg_seed
            from ding.data import create_dataset
            from ding.config import compile_config
            from ding.model.template.dt import DecisionTransformer
            from ding.policy import DTPolicy
            from ding.framework.middleware import interaction_evaluator, trainer, CkptSaver, \
                offline_data_fetcher_from_mem_c, offline_logger, termination_checker
            config = compile_config(config[0], create_cfg=config[1], auto=True)
            with task.start(async_mode=False, ctx=OfflineRLContext()):
                evaluator_env = BaseEnvManagerV2(
                    env_fn=[
                        lambda: AllinObsWrapper(CartPoleEnv(config.env))
                        for _ in range(config.env.evaluator_env_num)
                    ],
                    cfg=config.env.manager
                )

                set_pkg_seed(config.seed, use_cuda=config.policy.cuda)

                dataset = create_dataset(config)

                model = DecisionTransformer(**config.policy.model)
                policy = DTPolicy(config.policy, model=model)

                # Middleware is registered in the same order as before; the termination checker
                # stops the task after a single training iteration.
                task.use(termination_checker(max_train_iter=1))
                task.use(interaction_evaluator(config, policy.eval_mode, evaluator_env))
                task.use(offline_data_fetcher_from_mem_c(config, dataset))
                task.use(trainer(config, policy.learn_mode))
                task.use(CkptSaver(policy, config.exp_name, train_freq=100))
                task.use(offline_logger(config.exp_name))
                task.run()
        except Exception:
            assert False, "pipeline fail"
    finally:
        # ``dt_cartpole`` is the expert experiment directory created by this test; ``cartpole`` and
        # ``cartpole_dt`` were already part of the cleanup (presumably the data-generation and DT
        # experiment directories -- confirm against the corresponding configs).
        os.popen('rm -rf cartpole cartpole_dt dt_cartpole')
| 687 | + |
| 688 | + |
624 | 689 | @pytest.mark.platformtest
|
625 | 690 | @pytest.mark.unittest
|
626 | 691 | def test_td3_bc():
|
|
0 commit comments