Skip to content

Commit c999b07

Browse files
committed
demo(nyz): add naive PWIL demo
1 parent 0591b5e commit c999b07

File tree

1 file changed

+44
-0
lines changed

1 file changed

+44
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
from easydict import EasyDict
2+
from copy import deepcopy
3+
4+
from dizoo.classic_control.cartpole.config.cartpole_ppo_offpolicy_config import cartpole_ppo_offpolicy_config, cartpole_ppo_offpolicy_create_config # noqa
5+
from dizoo.classic_control.cartpole.config.cartpole_dqn_config import cartpole_dqn_config, cartpole_dqn_create_config
6+
from ding.entry import serial_pipeline, collect_demo_data, serial_pipeline_reward_model_offpolicy
7+
8+
9+
def cartpole_dqn_pwil_main(
        seed: int = 0,
        collect_count: int = 10000,
        expert_data_path: str = 'expert_data.pkl',
) -> None:
    """Run the naive PWIL demo on CartPole: expert training, demo collection, then IRL + RL.

    The demo runs three stages in order:

    1. Train an expert policy with the off-policy PPO CartPole config via ``serial_pipeline``.
    2. Collect ``collect_count`` expert samples with the trained policy's collect-mode
       state dict and store them at ``expert_data_path`` via ``collect_demo_data``.
    3. Train DQN with a ``'pwil'`` reward model that reads the collected expert data, via
       ``serial_pipeline_reward_model_offpolicy``.

    Calling the function without arguments reproduces the original hard-coded demo.

    Args:
        seed: Seed passed unchanged to all three pipeline calls.
        collect_count: Number of expert samples requested from ``collect_demo_data``.
        expert_data_path: File path used both to write the expert data and to point the
            reward model at it.
    """
    # NOTE(review): 's_size'/'a_size' presumably are CartPole's observation and action
    # sizes as expected by the PWIL reward model -- confirm against its default config.
    reward_model_config = EasyDict(
        {
            'type': 'pwil',
            's_size': 4,
            'a_size': 2,
            'sample_size': 500,
            'expert_data_path': expert_data_path,
        }
    )

    # train an expert policy (PPO offpolicy)
    # The shared config objects are deep-copied before every pipeline call so that no
    # pipeline run can leak modifications into a later stage.
    expert_config = deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)
    expert_policy = serial_pipeline(expert_config, seed=seed)

    # (optional) collect expert demo data
    state_dict = expert_policy.collect_mode.state_dict()
    collect_config = deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)
    collect_demo_data(
        collect_config,
        seed=seed,
        state_dict=state_dict,
        expert_data_path=expert_data_path,
        collect_count=collect_count,
    )

    # irl + rl training
    cp_cartpole_dqn_config = deepcopy(cartpole_dqn_config)
    cp_cartpole_dqn_create_config = deepcopy(cartpole_dqn_create_config)
    # The create config only carries the reward model type; the full reward model
    # settings go into the main config.
    cp_cartpole_dqn_create_config.reward_model = dict(type=reward_model_config.type)
    cp_cartpole_dqn_config.exp_name = 'cartpole_dqn_pwil'
    cp_cartpole_dqn_config.reward_model = reward_model_config
    cp_cartpole_dqn_config.policy.collect.n_sample = 128

    serial_pipeline_reward_model_offpolicy((cp_cartpole_dqn_config, cp_cartpole_dqn_create_config), seed=seed)
41+
42+
43+
# Script entry point: run the PWIL demo only when this file is executed directly.
if __name__ == "__main__":
    cartpole_dqn_pwil_main()

0 commit comments

Comments
 (0)