diff --git a/muzero.py b/muzero.py
index 834580a0..9213cf03 100644
--- a/muzero.py
+++ b/muzero.py
@@ -110,7 +110,7 @@ def __init__(self, game_name, config=None, split_resources_in=1):
             "num_reanalysed_games": 0,
             "terminate": False,
         }
-        self.replay_buffer = {}
+        self.replay_buffer = os.path.join(self.config.results_path, "replay_buffer.db")
 
         cpu_actor = CPUActor.remote()
         cpu_weights = cpu_actor.get_initial_weights.remote(self.config)
diff --git a/replay_buffer.py b/replay_buffer.py
index e1e5aed4..30fd9daa 100644
--- a/replay_buffer.py
+++ b/replay_buffer.py
@@ -1,13 +1,218 @@
-import copy
+import collections.abc
+import pickle
 import time
+import math
 import numpy
+import random
 import ray
+import sqlite3
 import torch
 
 import models
 
 
+class GameHistoryDao(collections.abc.MutableMapping):
+    """
+    Data Access Object for the game histories comprising the replay buffer
+    """
+
+    @staticmethod
+    def assemble_game_history(result):
+
+        # reattach the separately stored priorities to the unpickled game history
+        # row structure: (id, game_priority, priorities, reanalysed_predicted_root_values, object)
+        game_history = pickle.loads(result[4])
+        game_history.game_priority = result[1]
+        game_history.priorities = pickle.loads(result[2])
+        game_history.reanalysed_predicted_root_values = pickle.loads(result[3])
+
+        return result[0], game_history
+
+    @staticmethod
+    def disassemble_game_history(value):
+
+        # detach the priorities from the game history
+        game_priority = value.game_priority
+        priorities = value.priorities
+        reanalysed_predicted_root_values = value.reanalysed_predicted_root_values
+
+        # avoid storing duplicate data (it will be reassembled when read back)
+        value.game_priority = None
+        value.priorities = None
+        value.reanalysed_predicted_root_values = None
+
+        return game_priority, priorities, reanalysed_predicted_root_values
+
+    def __init__(self, file):
+        self.connection = sqlite3.connect(file)
+        self.connection.create_function('log', 1, math.log10)
+        self.connection.create_function('rand', 0, random.random)
+        self.connection.execute("CREATE TABLE IF NOT EXISTS game_history("
+                                "    id INTEGER PRIMARY KEY ASC,"
+                                "    game_priority REAL,"
+                                "    priorities BLOB,"
+                                "    reanalysed_predicted_root_values BLOB,"
+                                "    object BLOB"
+                                ")")
+        self.connection.commit()
+
+    def __len__(self):
+        cursor = self.connection.cursor()
+        cursor.execute("SELECT COUNT(*) FROM game_history")
+        result = cursor.fetchone()[0]
+        return result
+
+    def __getitem__(self, key):
+        cursor = self.connection.cursor()
+        cursor.execute("SELECT id,"
+                       "    game_priority,"
+                       "    priorities,"
+                       "    reanalysed_predicted_root_values,"
+                       "    object"
+                       " FROM game_history"
+                       " WHERE id = ?", (int(key),))
+        result = cursor.fetchone()
+        if result is None:
+            raise KeyError(key)
+
+        return self.assemble_game_history(result)[1]
+
+    def __setitem__(self, key, value):
+
+        game_priority, priorities, reanalysed_predicted_root_values = self.disassemble_game_history(value)
+
+        cursor = self.connection.cursor()
+        cursor.execute("REPLACE INTO game_history("
+                       "    id,"
+                       "    game_priority,"
+                       "    priorities,"
+                       "    reanalysed_predicted_root_values,"
+                       "    object"
+                       ") VALUES(?, ?, ?, ?, ?)", (
+                           int(key),
+                           float(game_priority) if game_priority is not None else None,  # None until PER sets it
+                           pickle.dumps(priorities),
+                           pickle.dumps(reanalysed_predicted_root_values),
+                           pickle.dumps(value)
+                       ))
+        self.connection.commit()
+
+    def __delitem__(self, key):
+        cursor = self.connection.cursor()
+        cursor.execute("DELETE FROM game_history WHERE id = ?", (int(key),))
+        if cursor.rowcount == 0:
+            raise KeyError(key)
+        self.connection.commit()
+
+    def keys(self):
+        cursor = self.connection.cursor()
cursor.execute("SELECT id FROM game_history ORDER BY id ASC") + for row in cursor.fetchall(): + yield row[0] + + def values(self): + cursor = self.connection.cursor() + cursor.execute("SELECT id," + " game_priority," + " priorities," + " reanalysed_predicted_root_values," + " object" + " FROM game_history ORDER BY id ASC") + for row in cursor: + yield self.assemble_game_history(row)[1] + + def items(self): + cursor = self.connection.cursor() + cursor.execute("SELECT id," + " game_priority," + " priorities," + " reanalysed_predicted_root_values," + " object" + " FROM game_history ORDER BY id ASC") + for row in cursor: + yield self.assemble_game_history(row) + + def __contains__(self, key): + cursor = self.connection.cursor() + cursor.execute("SELECT COUNT(*) FROM game_history WHERE id = ?", (int(key),)) + return cursor.fetchone()[0] > 0 + + def __iter__(self): + cursor = self.connection.cursor() + cursor.execute("SELECT id FROM game_history ORDER BY id ASC") + for row in cursor: + yield row[0] + + def priorities(self, game_id): + cursor = self.connection.cursor() + cursor.execute("SELECT priorities FROM game_history WHERE id = ?", (game_id,)) + result = cursor.fetchone() + if result is None: + raise KeyError() + return pickle.loads(result[0]) + + def min_id(self): + cursor = self.connection.cursor() + cursor.execute("SELECT MIN(id) FROM game_history") + return cursor.fetchone()[0] + + def sample_n(self, n): + cursor = self.connection.cursor() + cursor.execute("SELECT id," + " game_priority," + " priorities," + " reanalysed_predicted_root_values," + " object" + " FROM game_history WHERE id IN (" + " SELECT id FROM game_history" + " ORDER BY RANDOM()" + " LIMIT ?" + " )", (int(n),)) + for row in cursor: + yield self.assemble_game_history(row) + + def sample_n_ranked(self, n): + # reference: https://stackoverflow.com/a/12301949 + cursor = self.connection.cursor() + cursor.execute("SELECT id," + " game_priority," + " priorities," + " reanalysed_predicted_root_values," + " object" + " FROM game_history WHERE id IN (" + " SELECT id FROM game_history" + " ORDER BY -LOG(1.0 - RAND()) / game_priority" + " LIMIT ?" + " )", (int(n),)) + for row in cursor: + yield self.assemble_game_history(row) + + def update_priorities(self, game_id, game_priority, priorities): + cursor = self.connection.cursor() + cursor.execute("UPDATE game_history" + " SET game_priority = ?," + " priorities = ?" + " WHERE" + " id = ?", ( + float(game_priority), + pickle.dumps(priorities), + int(game_id) + )) + self.connection.commit() + + def update_reanalysed_values(self, game_id, reanalysed_predicted_root_values): + cursor = self.connection.cursor() + cursor.execute("UPDATE game_history" + " SET reanalysed_predicted_root_values = ?" 
+ " WHERE" + " id = ?", ( + pickle.dumps(reanalysed_predicted_root_values), + int(game_id) + )) + self.connection.commit() + + @ray.remote class ReplayBuffer: """ @@ -16,7 +221,8 @@ class ReplayBuffer: def __init__(self, initial_checkpoint, initial_buffer, config): self.config = config - self.buffer = copy.deepcopy(initial_buffer) + self.buffer_file = initial_buffer + self.buffer = GameHistoryDao(file=self.buffer_file) self.num_played_games = initial_checkpoint["num_played_games"] self.num_played_steps = initial_checkpoint["num_played_steps"] self.total_samples = sum( @@ -65,7 +271,7 @@ def save_game(self, game_history, shared_storage=None): shared_storage.set_info.remote("num_played_steps", self.num_played_steps) def get_buffer(self): - return self.buffer + return self.buffer_file def get_batch(self): ( @@ -138,38 +344,19 @@ def sample_game(self, force_uniform=False): Sample game from buffer either uniformly or according to some priority. See paper appendix Training. """ - game_prob = None - if self.config.PER and not force_uniform: - game_probs = numpy.array( - [game_history.game_priority for game_history in self.buffer.values()], - dtype="float32", - ) - game_probs /= numpy.sum(game_probs) - game_index = numpy.random.choice(len(self.buffer), p=game_probs) - game_prob = game_probs[game_index] - else: - game_index = numpy.random.choice(len(self.buffer)) - game_id = self.num_played_games - len(self.buffer) + game_index - - return game_id, self.buffer[game_id], game_prob + return next(iter(self.sample_n_games(1))) def sample_n_games(self, n_games, force_uniform=False): if self.config.PER and not force_uniform: - game_id_list = [] - game_probs = [] - for game_id, game_history in self.buffer.items(): - game_id_list.append(game_id) - game_probs.append(game_history.game_priority) - game_probs = numpy.array(game_probs, dtype="float32") - game_probs /= numpy.sum(game_probs) - game_prob_dict = dict([(game_id, prob) for game_id, prob in zip(game_id_list, game_probs)]) - selected_games = numpy.random.choice(game_id_list, n_games, p=game_probs) + samples = self.buffer.sample_n_ranked(n_games) else: - selected_games = numpy.random.choice(list(self.buffer.keys()), n_games) - game_prob_dict = {} - ret = [(game_id, self.buffer[game_id], game_prob_dict.get(game_id)) - for game_id in selected_games] - return ret + samples = self.buffer.sample_n(n_games) + + for sample in samples: + game_id = sample[0] + game_history = sample[1] + game_prob = game_history.game_priority + yield game_id, game_history, game_prob def sample_position(self, game_history, force_uniform=False): """ @@ -186,13 +373,8 @@ def sample_position(self, game_history, force_uniform=False): return position_index, position_prob - def update_game_history(self, game_id, game_history): - # The element could have been removed since its selection and update - if next(iter(self.buffer)) <= game_id: - if self.config.PER: - # Avoid read only array when loading replay buffer from disk - game_history.priorities = numpy.copy(game_history.priorities) - self.buffer[game_id] = game_history + def update_reanalysed_values(self, game_id, reanalysed_predicted_root_values): + self.buffer.update_reanalysed_values(game_id, reanalysed_predicted_root_values) def update_priorities(self, priorities, index_info): """ @@ -203,22 +385,29 @@ def update_priorities(self, priorities, index_info): game_id, game_pos = index_info[i] # The element could have been removed since its selection and training - if next(iter(self.buffer)) <= game_id: + if self.buffer.min_id() <= 
+            if self.buffer.min_id() <= game_id:
+
+                # fetch the stored priorities (the record cannot be updated in place)
+                priorities_record = self.buffer.priorities(game_id)
+
                 # Update position priorities
                 priority = priorities[i, :]
                 start_index = game_pos
                 end_index = min(
-                    game_pos + len(priority), len(self.buffer[game_id].priorities)
+                    game_pos + len(priority), len(priorities_record)
                 )
-                self.buffer[game_id].priorities[start_index:end_index] = priority[
+                priorities_record[start_index:end_index] = priority[
                     : end_index - start_index
                 ]
 
                 # Update game priorities
-                self.buffer[game_id].game_priority = numpy.max(
-                    self.buffer[game_id].priorities
+                game_priority = numpy.max(
+                    priorities_record
                 )
 
+                # write the updated priorities back to the database
+                self.buffer.update_priorities(game_id, game_priority, priorities_record)
+
     def compute_target_value(self, game_history, index):
         # The value target is the discounted root value of the search tree td_steps into the
         # future, plus the discounted sum of all rewards until then.
@@ -333,6 +522,7 @@ def reanalyse(self, replay_buffer, shared_storage):
             )
 
             # Use the last model to provide a fresher, stable n-step value (See paper appendix Reanalyze)
+            reanalysed_predicted_root_values = game_history.reanalysed_predicted_root_values
             if self.config.use_last_model_value:
                 observations = [
                     game_history.get_stacked_observations(
@@ -350,11 +540,11 @@ def reanalyse(self, replay_buffer, shared_storage):
                     self.model.initial_inference(observations)[0],
                     self.config.support_size,
                 )
-                game_history.reanalysed_predicted_root_values = (
+                reanalysed_predicted_root_values = (
                     torch.squeeze(values).detach().cpu().numpy()
                 )
 
-            replay_buffer.update_game_history.remote(game_id, game_history)
+            replay_buffer.update_reanalysed_values.remote(game_id, reanalysed_predicted_root_values)
             self.num_reanalysed_games += 1
             shared_storage.set_info.remote(
                 "num_reanalysed_games", self.num_reanalysed_games
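
Reviewer note (not part of the patch): a minimal sketch of how the new DAO is exercised, assuming a picklable stand-in for the real GameHistory class; StubGameHistory and the in-memory database path below are hypothetical, for illustration only.

    import numpy
    from replay_buffer import GameHistoryDao

    # hypothetical stand-in for self_play.GameHistory with only the fields the
    # DAO touches; defined at module level so pickle can serialise it
    class StubGameHistory:
        def __init__(self):
            self.game_priority = 1.0
            self.priorities = numpy.array([0.5, 1.0], dtype="float32")
            self.reanalysed_predicted_root_values = None
            self.root_values = [0.0, 0.0]

    dao = GameHistoryDao(file=":memory:")  # sqlite3 also accepts an on-disk path
    dao[0] = StubGameHistory()             # REPLACE INTO game_history; priority fields go to their own columns
    assert len(dao) == 1 and 0 in dao
    game_history = dao[0]                  # row is unpickled and the detached fields reattached
    dao.update_priorities(0, 2.0, numpy.array([2.0, 1.0], dtype="float32"))
    for game_id, game_history in dao.sample_n_ranked(1):
        print(game_id, game_history.game_priority)  # -> 0 2.0, priority-weighted sample
    del dao[0]                             # DELETE FROM game_history WHERE id = 0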