|
| 1 | +from easydict import EasyDict |
| 2 | +from copy import deepcopy |
| 3 | + |
| 4 | +from dizoo.classic_control.cartpole.config.cartpole_ppo_offpolicy_config import cartpole_ppo_offpolicy_config, cartpole_ppo_offpolicy_create_config # noqa |
| 5 | +from dizoo.classic_control.cartpole.config.cartpole_dqn_config import cartpole_dqn_config, cartpole_dqn_create_config |
| 6 | +from ding.entry import serial_pipeline, collect_demo_data, serial_pipeline_reward_model_offpolicy |
| 7 | + |
| 8 | + |
| 9 | +def cartpole_dqn_pwil_main(): |
| 10 | + reward_model_config = { |
| 11 | + 'type': 'pwil', |
| 12 | + 's_size': 4, |
| 13 | + 'a_size': 2, |
| 14 | + 'sample_size': 500, |
| 15 | + } |
| 16 | + |
| 17 | + # train a expert policy (PPO offpolicy) |
| 18 | + reward_model_config = EasyDict(reward_model_config) |
| 19 | + config = deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config) |
| 20 | + expert_policy = serial_pipeline(config, seed=0) |
| 21 | + |
| 22 | + # (optional) collect expert demo data |
| 23 | + collect_count = 10000 |
| 24 | + expert_data_path = 'expert_data.pkl' |
| 25 | + state_dict = expert_policy.collect_mode.state_dict() |
| 26 | + config = deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config) |
| 27 | + collect_demo_data( |
| 28 | + config, seed=0, state_dict=state_dict, expert_data_path=expert_data_path, collect_count=collect_count |
| 29 | + ) |
| 30 | + |
| 31 | + # irl + rl training |
| 32 | + cp_cartpole_dqn_config = deepcopy(cartpole_dqn_config) |
| 33 | + cp_cartpole_dqn_create_config = deepcopy(cartpole_dqn_create_config) |
| 34 | + cp_cartpole_dqn_create_config.reward_model = dict(type=reward_model_config.type) |
| 35 | + reward_model_config['expert_data_path'] = expert_data_path |
| 36 | + cp_cartpole_dqn_config.exp_name = 'cartpole_dqn_pwil' |
| 37 | + cp_cartpole_dqn_config.reward_model = reward_model_config |
| 38 | + cp_cartpole_dqn_config.policy.collect.n_sample = 128 |
| 39 | + |
| 40 | + serial_pipeline_reward_model_offpolicy((cp_cartpole_dqn_config, cp_cartpole_dqn_create_config), seed=0) |
| 41 | + |
| 42 | + |
| 43 | +if __name__ == "__main__": |
| 44 | + cartpole_dqn_pwil_main() |
0 commit comments