-
Notifications
You must be signed in to change notification settings - Fork 250
/
quickstart.py
99 lines (83 loc) · 2.66 KB
/
quickstart.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
"""This is a simple example demonstrating how to clone the behavior of an expert.
Refer to the jupyter notebooks for more detailed examples of how to use the algorithms.
"""
import numpy as np
from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.ppo import MlpPolicy
from imitation.algorithms import bc
from imitation.data import rollout
from imitation.data.wrappers import RolloutInfoWrapper
from imitation.policies.serialize import load_policy
from imitation.util.util import make_vec_env
# One seeded generator, reused for environment creation, expert rollouts and
# BC training below, so the whole example is reproducible.
rng = np.random.default_rng(seed=0)
env = make_vec_env(
    "seals:seals/CartPole-v0",
    rng=rng,
    # Each sub-environment is wrapped in RolloutInfoWrapper, which is needed
    # for computing rollouts from this vectorized env.
    post_wrappers=[lambda inner_env, _env_index: RolloutInfoWrapper(inner_env)],
)
def train_expert(total_timesteps: int = 1_000):
    """Train a PPO expert from scratch on the module-level ``env``.

    Note: use ``download_expert`` instead to obtain a pretrained, competent
    expert without any training.

    Args:
        total_timesteps: Number of environment steps passed to ``PPO.learn``.
            The default of 1_000 keeps the example fast but does not produce a
            decent expert; use 100_000 for that.

    Returns:
        The trained ``PPO`` model.
    """
    print("Training an expert.")
    expert = PPO(
        policy=MlpPolicy,
        env=env,
        seed=0,
        batch_size=64,
        ent_coef=0.0,
        learning_rate=0.0003,
        n_epochs=10,
        n_steps=64,
    )
    expert.learn(total_timesteps)
    return expert
def download_expert():
    """Load the pretrained PPO expert for seals CartPole published on Hugging Face.

    The policy is fetched from the ``HumanCompatibleAI`` organization via the
    ``"ppo-huggingface"`` loader and loaded against the module-level ``env``.

    Returns:
        The loaded expert policy.
    """
    print("Downloading a pretrained expert.")
    return load_policy(
        "ppo-huggingface",
        venv=env,
        env_name="seals-CartPole-v0",
        organization="HumanCompatibleAI",
    )
def sample_expert_transitions(*, min_episodes: int = 50, train_own_expert: bool = False):
    """Roll out an expert in ``env`` and return its flattened transitions.

    Args:
        min_episodes: Minimum number of episodes to collect before sampling
            stops (passed to ``rollout.make_sample_until``).
        train_own_expert: If True, train a fresh expert with ``train_expert``
            instead of downloading the pretrained one via ``download_expert``.

    Returns:
        The sampled trajectories flattened by ``rollout.flatten_trajectories``;
        this is what the BC trainer consumes as demonstrations.
    """
    # Selecting the expert via a flag replaces the old commented-out line.
    expert = train_expert() if train_own_expert else download_expert()
    print("Sampling expert transitions.")
    rollouts = rollout.rollout(
        expert,
        env,
        rollout.make_sample_until(min_timesteps=None, min_episodes=min_episodes),
        rng=rng,
    )
    return rollout.flatten_trajectories(rollouts)
transitions = sample_expert_transitions()
# Behavior Cloning trainer fitted to the expert transitions sampled above;
# its spaces must match the training environment's.
bc_trainer = bc.BC(
    demonstrations=transitions,
    observation_space=env.observation_space,
    action_space=env.action_space,
    rng=rng,
)
# A separate environment used only for evaluation; it is created with the
# "human" render mode so evaluation episodes can be watched.
evaluation_env = make_vec_env(
    "seals:seals/CartPole-v0",
    env_make_kwargs={"render_mode": "human"},
    rng=rng,
)
print("Evaluating the untrained policy.")
# Baseline: evaluate_policy returns (mean reward, std of reward); only the
# mean is reported.
reward, _ = evaluate_policy(
    model=bc_trainer.policy,  # type: ignore[arg-type]
    env=evaluation_env,
    n_eval_episodes=3,
    render=True,  # comment out to speed up
)
print(f"Reward before training: {reward}")
print("Training a policy using Behavior Cloning")
bc_trainer.train(n_epochs=1)  # one pass over the expert demonstrations
print("Evaluating the trained policy.")
# Same evaluation protocol as the baseline, so the two rewards are comparable.
reward, _ = evaluate_policy(
    model=bc_trainer.policy,  # type: ignore[arg-type]
    env=evaluation_env,
    n_eval_episodes=3,
    render=True,  # comment out to speed up
)
print(f"Reward after training: {reward}")