-
Notifications
You must be signed in to change notification settings - Fork 144
/
behavioural_cloning.py
143 lines (116 loc) · 5.42 KB
/
behavioural_cloning.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
# Basic behavioural cloning
# Note: this uses gradient accumulation in batches of ones
# to perform training.
# This will fit inside even smaller GPUs (tested on 8GB one),
# but is slow.
# NOTE: This is _not_ the original code used for VPT!
# This is merely to illustrate how to fine-tune the models and includes
# the processing steps used.
# This will likely be much worse than what original VPT did:
# we are not training on full sequences, but only one step at a time to save VRAM.
from argparse import ArgumentParser
import pickle
import time
import gym
import minerl
import torch as th
import numpy as np
from agent import PI_HEAD_KWARGS, MineRLAgent
from data_loader import DataLoader
from lib.tree_util import tree_map
EPOCHS = 2
# Needs to be <= number of videos
BATCH_SIZE = 8
# Ideally more than batch size to create
# variation in datasets (otherwise, you will
# get a bunch of consecutive samples)
# Decrease this (and batch_size) if you run out of memory
N_WORKERS = 12
DEVICE = "cuda"
LOSS_REPORT_RATE = 100
LEARNING_RATE = 0.000181
WEIGHT_DECAY = 0.039428
MAX_GRAD_NORM = 5.0
def load_model_parameters(path_to_model_file):
agent_parameters = pickle.load(open(path_to_model_file, "rb"))
policy_kwargs = agent_parameters["model"]["args"]["net"]["args"]
pi_head_kwargs = agent_parameters["model"]["args"]["pi_head_opts"]
pi_head_kwargs["temperature"] = float(pi_head_kwargs["temperature"])
return policy_kwargs, pi_head_kwargs
def behavioural_cloning_train(data_dir, in_model, in_weights, out_weights):
agent_policy_kwargs, agent_pi_head_kwargs = load_model_parameters(in_model)
# To create model with the right environment.
# All basalt environments have the same settings, so any of them works here
env = gym.make("MineRLBasaltFindCave-v0")
agent = MineRLAgent(env, device=DEVICE, policy_kwargs=agent_policy_kwargs, pi_head_kwargs=agent_pi_head_kwargs)
agent.load_weights(in_weights)
env.close()
policy = agent.policy
trainable_parameters = policy.parameters()
# Parameters taken from the OpenAI VPT paper
optimizer = th.optim.Adam(
trainable_parameters,
lr=LEARNING_RATE,
weight_decay=WEIGHT_DECAY
)
data_loader = DataLoader(
dataset_dir=data_dir,
n_workers=N_WORKERS,
batch_size=BATCH_SIZE,
n_epochs=EPOCHS
)
start_time = time.time()
# Keep track of the hidden state per episode/trajectory.
# DataLoader provides unique id for each episode, which will
# be different even for the same trajectory when it is loaded
# up again
episode_hidden_states = {}
dummy_first = th.from_numpy(np.array((False,))).to(DEVICE)
loss_sum = 0
for batch_i, (batch_images, batch_actions, batch_episode_id) in enumerate(data_loader):
batch_loss = 0
for image, action, episode_id in zip(batch_images, batch_actions, batch_episode_id):
agent_action = agent._env_action_to_agent(action, to_torch=True, check_if_null=True)
if agent_action is None:
# Action was null
continue
agent_obs = agent._env_obs_to_agent({"pov": image})
if episode_id not in episode_hidden_states:
# TODO need to clean up this hidden state after worker is done with the work item.
# Leaks memory, but not tooooo much at these scales (will be a problem later).
episode_hidden_states[episode_id] = policy.initial_state(1)
agent_state = episode_hidden_states[episode_id]
pi_distribution, v_prediction, new_agent_state = policy.get_output_for_observation(
agent_obs,
agent_state,
dummy_first
)
log_prob = policy.get_logprob_of_action(pi_distribution, agent_action)
# Make sure we do not try to backprop through sequence
# (fails with current accumulation)
new_agent_state = tree_map(lambda x: x.detach(), new_agent_state)
episode_hidden_states[episode_id] = new_agent_state
# Finally, update the agent to increase the probability of the
# taken action.
# Remember to take mean over batch losses
loss = -log_prob / BATCH_SIZE
batch_loss += loss.item()
loss.backward()
th.nn.utils.clip_grad_norm_(trainable_parameters, MAX_GRAD_NORM)
optimizer.step()
optimizer.zero_grad()
loss_sum += batch_loss
if batch_i % LOSS_REPORT_RATE == 0:
time_since_start = time.time() - start_time
print(f"Time: {time_since_start:.2f}, Batches: {batch_i}, Avrg loss: {loss_sum / LOSS_REPORT_RATE:.4f}")
loss_sum = 0
state_dict = policy.state_dict()
th.save(state_dict, out_weights)
if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("--data-dir", type=str, required=True, help="Path to the directory containing recordings to be trained on")
parser.add_argument("--in-model", required=True, type=str, help="Path to the .model file to be finetuned")
parser.add_argument("--in-weights", required=True, type=str, help="Path to the .weights file to be finetuned")
parser.add_argument("--out-weights", required=True, type=str, help="Path where finetuned weights will be saved")
args = parser.parse_args()
behavioural_cloning_train(args.data_dir, args.in_model, args.in_weights, args.out_weights)