diff --git a/README.md b/README.md
index c19b62c..5a9f8a0 100644
--- a/README.md
+++ b/README.md
@@ -96,6 +96,27 @@ A window should pop up which shows the video frame-by-frame, showing the predict
 
 Note that `run_inverse_dynamics_model.py` is designed to be a demo of the IDM, not code to put it into practice.
 
+# Using behavioural cloning to fine-tune the models
+
+**Disclaimer:** This code is a rough demonstration only and not an exact recreation of what the original VPT paper did (but it contains some preprocessing steps you want to be aware of)! As such, do not expect to replicate the original experiments with this code. This code has been designed to be runnable on consumer hardware (e.g., 8GB of VRAM).
+
+Setup:
+* Install requirements: `pip install -r requirements.txt`
+* Download the `.weights` and `.model` files for the model you want to fine-tune.
+* Download contractor data (below) and place the `.mp4` and `.jsonl` files in the same directory (e.g., `data`). With default settings, you need at least 12 recordings.
+
+If you downloaded the "1x Width" models and placed some data under the `data` directory, you can perform fine-tuning with
+
+```
+python behavioural_cloning.py --data-dir data --in-model foundation-model-1x.model --in-weights foundation-model-1x.weights --out-weights finetuned-1x.weights
+```
+
+You can then use `finetuned-1x.weights` when running the agent. You can change the training settings at the top of `behavioural_cloning.py`.
+
+Major limitations:
+- Only trains a single step at a time, i.e., errors are not propagated through timesteps.
+- Computes gradients one sample at a time to keep memory use low, which also slows down training.
+
 # Contractor Demonstrations
 
 ### Versions
diff --git a/agent.py b/agent.py
index 874642a..12c98e2 100644
--- a/agent.py
+++ b/agent.py
@@ -139,7 +139,11 @@ def reset(self):
         self.hidden_state = self.policy.initial_state(1)
 
     def _env_obs_to_agent(self, minerl_obs):
-        """Turn observation from MineRL environment into model's observation"""
+        """
+        Turn observation from MineRL environment into model's observation
+
+        Returns torch tensors.
+        """
         agent_input = resize_image(minerl_obs["pov"], AGENT_RESOLUTION)[None]
         agent_input = {"img": th.from_numpy(agent_input).to(self.device)}
         return agent_input
@@ -149,17 +153,39 @@ def _agent_action_to_env(self, agent_action):
         # This is quite important step (for some reason).
         # For the sake of your sanity, remember to do this step (manual conversion to numpy)
         # before proceeding. Otherwise, your agent might be a little derp.
-        action = {
-            "buttons": agent_action["buttons"].cpu().numpy(),
-            "camera": agent_action["camera"].cpu().numpy()
-        }
+        action = agent_action
+        if isinstance(action["buttons"], th.Tensor):
+            action = {
+                "buttons": agent_action["buttons"].cpu().numpy(),
+                "camera": agent_action["camera"].cpu().numpy()
+            }
         minerl_action = self.action_mapper.to_factored(action)
         minerl_action_transformed = self.action_transformer.policy2env(minerl_action)
         return minerl_action_transformed
 
-    def _env_action_to_agent(self, minerl_action):
-        """Turn action from MineRL to model's action"""
-        raise NotImplementedError()
+    def _env_action_to_agent(self, minerl_action_transformed, to_torch=False, check_if_null=False):
+        """
+        Turn action from MineRL to model's action.
+
+        Note that this will add batch dimensions to the action.
+        Returns numpy arrays, unless `to_torch` is True, in which case it returns torch tensors.
+
+        If `check_if_null` is True, check if the action is null (no action) after the initial
+        transformation.
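+        (A null action presses no buttons and keeps the camera at its zero bin.)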
+        This matches the behaviour done in OpenAI's VPT work.
+        If the action is null, return None instead.
+        """
+        minerl_action = self.action_transformer.env2policy(minerl_action_transformed)
+        if check_if_null:
+            if np.all(minerl_action["buttons"] == 0) and np.all(minerl_action["camera"] == self.action_transformer.camera_zero_bin):
+                return None
+
+        # Add batch dims if they are missing
+        if minerl_action["camera"].ndim == 1:
+            minerl_action = {k: v[None] for k, v in minerl_action.items()}
+        action = self.action_mapper.from_factored(minerl_action)
+        if to_torch:
+            action = {k: th.from_numpy(v).to(self.device) for k, v in action.items()}
+        return action
 
     def get_action(self, minerl_obs):
         """
@@ -177,4 +203,4 @@ def get_action(self, minerl_obs):
             stochastic=True
         )
         minerl_action = self._agent_action_to_env(agent_action)
-        return minerl_action
\ No newline at end of file
+        return minerl_action
diff --git a/behavioural_cloning.py b/behavioural_cloning.py
new file mode 100644
index 0000000..07ea827
--- /dev/null
+++ b/behavioural_cloning.py
@@ -0,0 +1,143 @@
+# Basic behavioural cloning
+# Note: this uses gradient accumulation in batches of one
+# to perform training.
+# This will fit inside even smaller GPUs (tested on an 8GB one),
+# but is slow.
+# NOTE: This is _not_ the original code used for VPT!
+# This is merely to illustrate how to fine-tune the models and includes
+# the processing steps used.
+
+# This will likely be much worse than what the original VPT did:
+# we are not training on full sequences, but only one step at a time to save VRAM.
+
+from argparse import ArgumentParser
+import pickle
+import time
+
+import gym
+import minerl
+import torch as th
+import numpy as np
+
+from agent import PI_HEAD_KWARGS, MineRLAgent
+from data_loader import DataLoader
+from lib.tree_util import tree_map
+
+EPOCHS = 2
+# Needs to be <= number of videos
+BATCH_SIZE = 8
+# Ideally more than batch size to create
+# variation in batches (otherwise, you will
+# get a bunch of consecutive samples)
+# Decrease this (and batch_size) if you run out of memory
+N_WORKERS = 12
+DEVICE = "cuda"
+
+LOSS_REPORT_RATE = 100
+
+LEARNING_RATE = 0.000181
+WEIGHT_DECAY = 0.039428
+MAX_GRAD_NORM = 5.0
+
+def load_model_parameters(path_to_model_file):
+    agent_parameters = pickle.load(open(path_to_model_file, "rb"))
+    policy_kwargs = agent_parameters["model"]["args"]["net"]["args"]
+    pi_head_kwargs = agent_parameters["model"]["args"]["pi_head_opts"]
+    pi_head_kwargs["temperature"] = float(pi_head_kwargs["temperature"])
+    return policy_kwargs, pi_head_kwargs
+
+def behavioural_cloning_train(data_dir, in_model, in_weights, out_weights):
+    agent_policy_kwargs, agent_pi_head_kwargs = load_model_parameters(in_model)
+
+    # To create the model with the right environment.
+    # All basalt environments have the same settings, so any of them works here
+    env = gym.make("MineRLBasaltFindCave-v0")
+    agent = MineRLAgent(env, device=DEVICE, policy_kwargs=agent_policy_kwargs, pi_head_kwargs=agent_pi_head_kwargs)
+    agent.load_weights(in_weights)
+    env.close()
+
+    policy = agent.policy
+    # Materialize the parameters into a list: a generator would be exhausted by the
+    # optimizer, and the gradient clipping below would then silently do nothing.
+    trainable_parameters = list(policy.parameters())
+
+    # Parameters taken from the OpenAI VPT paper
+    optimizer = th.optim.Adam(
+        trainable_parameters,
+        lr=LEARNING_RATE,
+        weight_decay=WEIGHT_DECAY
+    )
+
+    data_loader = DataLoader(
+        dataset_dir=data_dir,
+        n_workers=N_WORKERS,
+        batch_size=BATCH_SIZE,
+        n_epochs=EPOCHS
+    )
+
+    start_time = time.time()
+
+    # Keep track of the hidden state per episode/trajectory.
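+    # The policy is recurrent, so the state from the previous sample of a
+    # trajectory has to be fed back in on the next step.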
+    # DataLoader provides unique id for each episode, which will
+    # be different even for the same trajectory when it is loaded
+    # up again
+    episode_hidden_states = {}
+    dummy_first = th.from_numpy(np.array((False,))).to(DEVICE)
+
+    loss_sum = 0
+    for batch_i, (batch_images, batch_actions, batch_episode_id) in enumerate(data_loader):
+        batch_loss = 0
+        for image, action, episode_id in zip(batch_images, batch_actions, batch_episode_id):
+            agent_action = agent._env_action_to_agent(action, to_torch=True, check_if_null=True)
+            if agent_action is None:
+                # Action was null
+                continue
+
+            agent_obs = agent._env_obs_to_agent({"pov": image})
+            if episode_id not in episode_hidden_states:
+                # TODO need to clean up this hidden state after worker is done with the work item.
+                #      Leaks memory, but not tooooo much at these scales (will be a problem later).
+                episode_hidden_states[episode_id] = policy.initial_state(1)
+            agent_state = episode_hidden_states[episode_id]
+
+            pi_distribution, v_prediction, new_agent_state = policy.get_output_for_observation(
+                agent_obs,
+                agent_state,
+                dummy_first
+            )
+
+            log_prob = policy.get_logprob_of_action(pi_distribution, agent_action)
+
+            # Make sure we do not try to backprop through the sequence
+            # (fails with current accumulation)
+            new_agent_state = tree_map(lambda x: x.detach(), new_agent_state)
+            episode_hidden_states[episode_id] = new_agent_state
+
+            # Finally, update the agent to increase the probability of the
+            # taken action.
+            # Remember to take the mean over batch losses
+            loss = -log_prob / BATCH_SIZE
+            batch_loss += loss.item()
+            loss.backward()
+
+        th.nn.utils.clip_grad_norm_(trainable_parameters, MAX_GRAD_NORM)
+        optimizer.step()
+        optimizer.zero_grad()
+
+        loss_sum += batch_loss
+        if batch_i % LOSS_REPORT_RATE == 0:
+            time_since_start = time.time() - start_time
+            print(f"Time: {time_since_start:.2f}, Batches: {batch_i}, Avg loss: {loss_sum / LOSS_REPORT_RATE:.4f}")
+            loss_sum = 0
+
+    state_dict = policy.state_dict()
+    th.save(state_dict, out_weights)
+
+
+if __name__ == "__main__":
+    parser = ArgumentParser()
+    parser.add_argument("--data-dir", type=str, required=True, help="Path to the directory containing recordings to be trained on")
+    parser.add_argument("--in-model", required=True, type=str, help="Path to the .model file to be finetuned")
+    parser.add_argument("--in-weights", required=True, type=str, help="Path to the .weights file to be finetuned")
+    parser.add_argument("--out-weights", required=True, type=str, help="Path where finetuned weights will be saved")
+
+    args = parser.parse_args()
+    behavioural_cloning_train(args.data_dir, args.in_model, args.in_weights, args.out_weights)
diff --git a/cursors/mouse_cursor_white_16x16.png b/cursors/mouse_cursor_white_16x16.png
new file mode 100644
index 0000000..def065b
Binary files /dev/null and b/cursors/mouse_cursor_white_16x16.png differ
diff --git a/data_loader.py b/data_loader.py
new file mode 100644
index 0000000..bbd6f3b
--- /dev/null
+++ b/data_loader.py
@@ -0,0 +1,222 @@
+# Code for loading OpenAI MineRL VPT datasets
+# NOTE: This is NOT original code used for the VPT experiments!
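+# This is a simplified, single-step loader written for the behavioural cloning demo in this repo.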
+# (But contains all [or at least most] steps done in the original data loading)
+
+import json
+import glob
+import os
+import random
+from multiprocessing import Process, Queue, Event
+
+import numpy as np
+import cv2
+
+from run_inverse_dynamics_model import json_action_to_env_action
+from agent import resize_image, AGENT_RESOLUTION
+
+QUEUE_TIMEOUT = 10
+
+CURSOR_FILE = os.path.join(os.path.dirname(__file__), "cursors", "mouse_cursor_white_16x16.png")
+
+MINEREC_ORIGINAL_HEIGHT_PX = 720
+
+# If the GUI is open, mouse dx/dy also need to be adjusted with these scalers.
+# If the data version is not present, assume it is 1.
+MINEREC_VERSION_SPECIFIC_SCALERS = {
+    "5.7": 0.5,
+    "5.8": 0.5,
+    "6.7": 2.0,
+    "6.8": 2.0,
+    "6.9": 2.0,
+}
+
+
+def composite_images_with_alpha(image1, image2, alpha, x, y):
+    """
+    Draw image2 over image1 at location x,y, using alpha as the opacity for image2.
+
+    Modifies image1 in-place
+    """
+    # Clamp the copied region so it is non-negative and fits inside both images
+    ch = max(0, min(image1.shape[0] - y, image2.shape[0]))
+    cw = max(0, min(image1.shape[1] - x, image2.shape[1]))
+    if ch == 0 or cw == 0:
+        return
+    alpha = alpha[:ch, :cw]
+    image1[y:y + ch, x:x + cw, :] = (image1[y:y + ch, x:x + cw, :] * (1 - alpha) + image2[:ch, :cw, :] * alpha).astype(np.uint8)
+
+
+def data_loader_worker(tasks_queue, output_queue, quit_workers_event):
+    """
+    Worker for the data loader.
+    """
+    cursor_image = cv2.imread(CURSOR_FILE, cv2.IMREAD_UNCHANGED)
+    # Assume 16x16
+    cursor_image = cursor_image[:16, :16, :]
+    cursor_alpha = cursor_image[:, :, 3:] / 255.0
+    cursor_image = cursor_image[:, :, :3]
+
+    while True:
+        task = tasks_queue.get()
+        if task is None:
+            break
+        trajectory_id, video_path, json_path = task
+        video = cv2.VideoCapture(video_path)
+        # NOTE: In some recordings, the game seems to start
+        #       with attack always down from the beginning, which
+        #       is stuck down until the player actually presses attack
+        # NOTE: It is uncertain if this was an issue with the original code.
+        attack_is_stuck = False
+        # Scrollwheel is an allowed way to change items, but this is
+        # not captured by the recorder.
+        # Work around this by keeping track of the selected hotbar item
+        # and updating "hotbar.#" actions when the hotbar selection changes.
+        # NOTE: It is uncertain if this was/is an issue with the contractor data
+        last_hotbar = 0
+
+        with open(json_path) as json_file:
+            json_lines = json_file.readlines()
+            json_data = "[" + ",".join(json_lines) + "]"
+            json_data = json.loads(json_data)
+            for i in range(len(json_data)):
+                if quit_workers_event.is_set():
+                    break
+                step_data = json_data[i]
+
+                if i == 0:
+                    # Check if attack will be stuck down
+                    if step_data["mouse"]["newButtons"] == [0]:
+                        attack_is_stuck = True
+                elif attack_is_stuck:
+                    # If attack is pressed down again, it is no longer considered stuck
+                    if 0 in step_data["mouse"]["newButtons"]:
+                        attack_is_stuck = False
+                # If still stuck, remove the action
+                if attack_is_stuck:
+                    step_data["mouse"]["buttons"] = [button for button in step_data["mouse"]["buttons"] if button != 0]
+
+                action, is_null_action = json_action_to_env_action(step_data)
+
+                # Update hotbar selection
+                current_hotbar = step_data["hotbar"]
+                if current_hotbar != last_hotbar:
+                    action["hotbar.{}".format(current_hotbar + 1)] = 1
+                last_hotbar = current_hotbar
+
+                # Read the frame even if this is a null action so we progress forward
+                ret, frame = video.read()
+                if ret:
+                    # Skip null actions as done in the VPT paper
+                    # NOTE: in the VPT paper, this was checked _after_ transforming into the agent's action space.
+                    #       We do this here as well to reduce the amount of data sent over.
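+                    # (behavioural_cloning.py additionally re-checks for null actions in the
+                    #  agent's action space via `_env_action_to_agent(..., check_if_null=True)`)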
+                    if is_null_action:
+                        continue
+                    if step_data["isGuiOpen"]:
+                        camera_scaling_factor = frame.shape[0] / MINEREC_ORIGINAL_HEIGHT_PX
+                        cursor_x = int(step_data["mouse"]["x"] * camera_scaling_factor)
+                        cursor_y = int(step_data["mouse"]["y"] * camera_scaling_factor)
+                        composite_images_with_alpha(frame, cursor_image, cursor_alpha, cursor_x, cursor_y)
+                    cv2.cvtColor(frame, code=cv2.COLOR_BGR2RGB, dst=frame)
+                    frame = np.asarray(np.clip(frame, 0, 255), dtype=np.uint8)
+                    frame = resize_image(frame, AGENT_RESOLUTION)
+                    output_queue.put((trajectory_id, frame, action), timeout=QUEUE_TIMEOUT)
+                else:
+                    print(f"Could not read frame from video {video_path}")
+        video.release()
+        if quit_workers_event.is_set():
+            break
+    # Tell the consumer that this worker has finished
+    output_queue.put(None)
+
+class DataLoader:
+    """
+    Generator class for loading batches from a dataset
+
+    This only returns a single step at a time per worker; no sub-sequences.
+    The idea is that you keep track of the model's hidden state and feed that in,
+    along with one sample at a time.
+
+    + Simpler loader code
+    + Supports lower end hardware
+    - Not very efficient (could be faster)
+    - No support for sub-sequences
+    - Loads up individual files as trajectory files (i.e. if a trajectory is split into multiple files,
+      this code will load each file up as a separate item).
+    """
+    def __init__(self, dataset_dir, n_workers=8, batch_size=8, n_epochs=1, max_queue_size=16):
+        assert n_workers >= batch_size, "Number of workers must be equal to or greater than the batch size"
+        self.dataset_dir = dataset_dir
+        self.n_workers = n_workers
+        self.n_epochs = n_epochs
+        self.batch_size = batch_size
+        self.max_queue_size = max_queue_size
+        unique_ids = glob.glob(os.path.join(dataset_dir, "*.mp4"))
+        unique_ids = list(set([os.path.basename(x).split(".")[0] for x in unique_ids]))
+        self.unique_ids = unique_ids
+        # Create tuples of (video_path, json_path) for each unique_id
+        demonstration_tuples = []
+        for unique_id in unique_ids:
+            video_path = os.path.abspath(os.path.join(dataset_dir, unique_id + ".mp4"))
+            json_path = os.path.abspath(os.path.join(dataset_dir, unique_id + ".jsonl"))
+            demonstration_tuples.append((video_path, json_path))
+
+        assert n_workers <= len(demonstration_tuples), f"n_workers should be less than or equal to the number of demonstrations {len(demonstration_tuples)}"
+
+        # Repeat the dataset n_epochs times, shuffling the order for
+        # each epoch
+        self.demonstration_tuples = []
+        for i in range(n_epochs):
+            random.shuffle(demonstration_tuples)
+            self.demonstration_tuples += demonstration_tuples
+
+        self.task_queue = Queue()
+        self.n_steps_processed = 0
+        for trajectory_id, task in enumerate(self.demonstration_tuples):
+            self.task_queue.put((trajectory_id, *task))
+        for _ in range(n_workers):
+            self.task_queue.put(None)
+
+        self.output_queues = [Queue(maxsize=max_queue_size) for _ in range(n_workers)]
+        self.quit_workers_event = Event()
+        self.processes = [
+            Process(
+                target=data_loader_worker,
+                args=(
+                    self.task_queue,
+                    output_queue,
+                    self.quit_workers_event,
+                ),
+                daemon=True
+            )
+            for output_queue in self.output_queues
+        ]
+        for process in self.processes:
+            process.start()
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        batch_frames = []
+        batch_actions = []
+        batch_episode_id = []
+
+        for i in range(self.batch_size):
+            workitem = self.output_queues[self.n_steps_processed % self.n_workers].get(timeout=QUEUE_TIMEOUT)
+            if workitem is None:
+                # Stop iteration when the first worker runs out of work to do.
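+                # (Batches are built round-robin, one sample per worker, so each batch
+                #  mixes frames from `batch_size` different trajectories.)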
+                # Yes, this has a chance of cutting out a lot of the work,
+                # but it ensures batches will remain diverse, instead
+                # of having bad ones at the end where potentially
+                # one worker contributes all samples of a batch.
+                raise StopIteration()
+            trajectory_id, frame, action = workitem
+            batch_frames.append(frame)
+            batch_actions.append(action)
+            batch_episode_id.append(trajectory_id)
+            self.n_steps_processed += 1
+        return batch_frames, batch_actions, batch_episode_id
+
+    def __del__(self):
+        for process in self.processes:
+            process.terminate()
+            process.join()
diff --git a/lib/policy.py b/lib/policy.py
index 7086c73..b517a08 100644
--- a/lib/policy.py
+++ b/lib/policy.py
@@ -268,6 +268,42 @@ def forward(self, obs, first: th.Tensor, state_in):
 
         return (pi_logits, vpred, None), state_out
 
+    def get_logprob_of_action(self, pd, action):
+        """
+        Get logprob of taking action `action` given probability distribution
+        (see `get_output_for_observation` to get this distribution)
+        """
+        ac = tree_map(lambda x: x.unsqueeze(1), action)
+        log_prob = self.pi_head.logprob(ac, pd)
+        assert not th.isnan(log_prob).any()
+        return log_prob[:, 0]
+
+    def get_kl_of_action_dists(self, pd1, pd2):
+        """
+        Get the KL divergence between two action probability distributions
+        """
+        return self.pi_head.kl_divergence(pd1, pd2)
+
+    def get_output_for_observation(self, obs, state_in, first):
+        """
+        Return gradient-enabled outputs for the given observation.
+
+        Use `get_logprob_of_action` to get the log probability of an action
+        under the returned probability distribution.
+
+        Returns:
+          - probability distribution given observation
+          - value prediction for given observation
+          - new state
+        """
+        # We need to add a fictitious time dimension everywhere
+        obs = tree_map(lambda x: x.unsqueeze(1), obs)
+        first = first.unsqueeze(1)
+
+        (pd, vpred, _), state_out = self(obs=obs, first=first, state_in=state_in)
+
+        return pd, self.value_head.denormalize(vpred)[:, 0], state_out
+
     @th.no_grad()
     def act(self, obs, first, state_in, stochastic: bool = True, taken_action=None, return_pd=False):
         # We need to add a fictitious time dimension everywhere
diff --git a/run_inverse_dynamics_model.py b/run_inverse_dynamics_model.py
index d7a5373..a932d92 100644
--- a/run_inverse_dynamics_model.py
+++ b/run_inverse_dynamics_model.py
@@ -80,39 +80,49 @@ def json_action_to_env_action(json_action):
     """
     Converts a json action into a MineRL action.
-
-    Note: in some recordings, the recording starts with "attack: 1" despite
-    player obviously not attacking. The IDM seems to reflect this (predicts "attack: 1")
-    at the beginning of the recording.
+    Returns (minerl_action, is_null_action)
     """
+    # This might be slow...
     env_action = NOOP_ACTION.copy()
+    # As a safeguard, make the camera action again so we do not override anything
     env_action["camera"] = np.array([0, 0])
+    is_null_action = True
     keyboard_keys = json_action["keyboard"]["keys"]
     for key in keyboard_keys:
         # You can have keys that we do not use, so just skip them
+        # NOTE in the original training code, ESC was removed and replaced with
+        #      the "inventory" action if the GUI was open.
+        #      Not doing it here, as BASALT uses ESC to quit the game.
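+        #      (A rough, untested sketch of that replacement would be something like
+        #      `if key == "key.keyboard.escape" and json_action.get("isGuiOpen"): env_action["inventory"] = 1`,
+        #      in case you ever need to replicate the original preprocessing.)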
         if key in KEYBOARD_BUTTON_MAPPING:
             env_action[KEYBOARD_BUTTON_MAPPING[key]] = 1
+            is_null_action = False
 
     mouse = json_action["mouse"]
     camera_action = env_action["camera"]
     camera_action[0] = mouse["dy"] * CAMERA_SCALER
     camera_action[1] = mouse["dx"] * CAMERA_SCALER
 
-    if abs(camera_action[0]) > 180:
-        camera_action[0] = 0
-    if abs(camera_action[1]) > 180:
-        camera_action[1] = 0
+    if mouse["dx"] != 0 or mouse["dy"] != 0:
+        is_null_action = False
+    else:
+        if abs(camera_action[0]) > 180:
+            camera_action[0] = 0
+        if abs(camera_action[1]) > 180:
+            camera_action[1] = 0
 
     mouse_buttons = mouse["buttons"]
     if 0 in mouse_buttons:
         env_action["attack"] = 1
+        is_null_action = False
     if 1 in mouse_buttons:
         env_action["use"] = 1
+        is_null_action = False
     if 2 in mouse_buttons:
         env_action["pickItem"] = 1
+        is_null_action = False
 
-    return env_action
+    return env_action, is_null_action
 
 
 def main(model, weights, video_path, json_path, n_batches, n_frames):
@@ -145,7 +155,8 @@ def main(model, weights, video_path, json_path, n_batches, n_frames):
             assert frame.shape[0] == required_resolution[1] and frame.shape[1] == required_resolution[0], "Video must be of resolution {}".format(required_resolution)
             # BGR -> RGB
             frames.append(frame[..., ::-1])
-            recorded_actions.append(json_action_to_env_action(json_data[json_index]))
+            env_action, _ = json_action_to_env_action(json_data[json_index])
+            recorded_actions.append(env_action)
             json_index += 1
         frames = np.stack(frames)
         print("=== Predicting actions ===")