Keep replay buffer on disk (not in memory), allowing it to grow to any size. #151

Open · wants to merge 17 commits into master
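The diff below replaces the in-memory dict of game histories with a SQLite-backed mapping, so the replay buffer lives on disk and can grow past available RAM. As orientation, here is a minimal, self-contained sketch of that general technique; this is illustrative code only, not the PR's (the PR's real implementation is the GameHistoryDao class in replay_buffer.py below):

import collections.abc
import pickle
import sqlite3

class DiskDict(collections.abc.MutableMapping):
    """A dict-like mapping whose entries are pickled into a SQLite file."""

    def __init__(self, path):
        self.conn = sqlite3.connect(path)
        self.conn.execute("CREATE TABLE IF NOT EXISTS kv(k INTEGER PRIMARY KEY, v BLOB)")
        self.conn.commit()

    def __setitem__(self, key, value):
        self.conn.execute("REPLACE INTO kv(k, v) VALUES(?, ?)",
                          (int(key), pickle.dumps(value)))
        self.conn.commit()

    def __getitem__(self, key):
        row = self.conn.execute("SELECT v FROM kv WHERE k = ?", (int(key),)).fetchone()
        if row is None:
            raise KeyError(key)
        return pickle.loads(row[0])

    def __delitem__(self, key):
        cursor = self.conn.execute("DELETE FROM kv WHERE k = ?", (int(key),))
        if cursor.rowcount == 0:
            raise KeyError(key)
        self.conn.commit()

    def __iter__(self):
        for (k,) in self.conn.execute("SELECT k FROM kv ORDER BY k"):
            yield k

    def __len__(self):
        return self.conn.execute("SELECT COUNT(*) FROM kv").fetchone()[0]

d = DiskDict(":memory:")  # a real run would pass a file path such as "replay_buffer.db"
d[0] = {"priorities": [0.5, 1.0]}
print(len(d), d[0])       # 1 {'priorities': [0.5, 1.0]}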
muzero.py (2 changes: 1 addition & 1 deletion)
@@ -110,7 +110,7 @@ def __init__(self, game_name, config=None, split_resources_in=1):
"num_reanalysed_games": 0,
"terminate": False,
}
-        self.replay_buffer = {}
+        self.replay_buffer = os.path.join(self.config.results_path, "replay_buffer.db")

cpu_actor = CPUActor.remote()
cpu_weights = cpu_actor.get_initial_weights.remote(self.config)
replay_buffer.py (280 changes: 235 additions & 45 deletions)
@@ -1,13 +1,218 @@
import copy
import collections
import pickle
import time

import math
import numpy
import random
import ray
import sqlite3
import torch

import models


class GameHistoryDao(collections.abc.MutableMapping):
"""
Data Access Object for the game histories comprising the replay buffer
"""

@staticmethod
def assemble_game_history(result):

# assemble priorities into the game history
# structure: (id, game_priority, priorities, reanalysed_predicted_root_values, object)
game_history = pickle.loads(result[4])
game_history.game_priority = result[1]
game_history.priorities = pickle.loads(result[2])
game_history.reanalysed_predicted_root_values = pickle.loads(result[3])

return result[0], game_history

@staticmethod
def disassemble_game_history(value):

# disassemble the priorities from the game history
game_priority = value.game_priority
priorities = value.priorities
reanalysed_predicted_root_values = value.reanalysed_predicted_root_values

# avoid storing duplicate data (it will be reassembled later)
value.game_priority = None
value.priorities = None
value.reanalysed_predicted_root_values = None

return game_priority, priorities, reanalysed_predicted_root_values

def __init__(self, file):
self.connection = sqlite3.connect(file)
self.connection.create_function('log', 1, math.log10)
self.connection.create_function('rand', 0, random.random)
self.connection.execute("CREATE TABLE IF NOT EXISTS game_history("
" id INTEGER PRIMARY KEY ASC,"
" game_priority REAL,"
" priorities TEXT,"
" reanalysed_predicted_root_values TEXT,"
" object TEXT"
")")
self.connection.commit()

def __len__(self):
cursor = self.connection.cursor()
cursor.execute("SELECT COUNT(*) FROM game_history")
result = cursor.fetchone()[0]
return result

def __getitem__(self, key):
cursor = self.connection.cursor()
cursor.execute("SELECT id,"
" game_priority,"
" priorities,"
" reanalysed_predicted_root_values,"
" object"
" FROM game_history"
" WHERE id = ?", (int(key),))
result = cursor.fetchone()
if result is None:
raise KeyError(key)

return self.assemble_game_history(result)

def __setitem__(self, key, value):

game_priority, priorities, reanalysed_predicted_root_values = self.disassemble_game_history(value)

cursor = self.connection.cursor()
cursor.execute("REPLACE INTO game_history("
" id,"
" game_priority,"
" priorities,"
" reanalysed_predicted_root_values,"
" object"
") VALUES(?, ?, ?, ?, ?)", (
int(key),
float(game_priority),
pickle.dumps(priorities),
pickle.dumps(reanalysed_predicted_root_values),
pickle.dumps(value)
))
self.connection.commit()

    def __delitem__(self, key):
        cursor = self.connection.cursor()
        cursor.execute("DELETE FROM game_history WHERE id = ?", (int(key),))
        if cursor.rowcount == 0:
            raise KeyError(key)
        self.connection.commit()

def keys(self):
cursor = self.connection.cursor()
cursor.execute("SELECT id FROM game_history ORDER BY id ASC")
for row in cursor.fetchall():
yield row[0]

def values(self):
cursor = self.connection.cursor()
cursor.execute("SELECT id,"
" game_priority,"
" priorities,"
" reanalysed_predicted_root_values,"
" object"
" FROM game_history ORDER BY id ASC")
for row in cursor:
yield self.assemble_game_history(row)[1]

def items(self):
cursor = self.connection.cursor()
cursor.execute("SELECT id,"
" game_priority,"
" priorities,"
" reanalysed_predicted_root_values,"
" object"
" FROM game_history ORDER BY id ASC")
for row in cursor:
yield self.assemble_game_history(row)

def __contains__(self, key):
cursor = self.connection.cursor()
cursor.execute("SELECT COUNT(*) FROM game_history WHERE id = ?", (int(key),))
return cursor.fetchone()[0] > 0

def __iter__(self):
cursor = self.connection.cursor()
cursor.execute("SELECT id FROM game_history ORDER BY id ASC")
for row in cursor:
yield row[0]

def priorities(self, game_id):
cursor = self.connection.cursor()
cursor.execute("SELECT priorities FROM game_history WHERE id = ?", (game_id,))
result = cursor.fetchone()
if result is None:
raise KeyError(game_id)
return pickle.loads(result[0])

def min_id(self):
cursor = self.connection.cursor()
cursor.execute("SELECT MIN(id) FROM game_history")
return cursor.fetchone()[0]

def sample_n(self, n):
cursor = self.connection.cursor()
cursor.execute("SELECT id,"
" game_priority,"
" priorities,"
" reanalysed_predicted_root_values,"
" object"
" FROM game_history WHERE id IN ("
" SELECT id FROM game_history"
" ORDER BY RANDOM()"
" LIMIT ?"
" )", (int(n),))
for row in cursor:
yield self.assemble_game_history(row)

def sample_n_ranked(self, n):
# reference: https://stackoverflow.com/a/12301949
cursor = self.connection.cursor()
cursor.execute("SELECT id,"
" game_priority,"
" priorities,"
" reanalysed_predicted_root_values,"
" object"
" FROM game_history WHERE id IN ("
" SELECT id FROM game_history"
" ORDER BY -LOG(1.0 - RAND()) / game_priority"
" LIMIT ?"
" )", (int(n),))
for row in cursor:
yield self.assemble_game_history(row)
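The ORDER BY -LOG(1.0 - RAND()) / game_priority clause draws, for each row, an exponential random variate whose rate is the row's priority; taking the n rows with the smallest keys yields a priority-weighted sample without replacement (the Efraimidis-Spirakis trick behind the linked answer). The base of the logarithm is irrelevant, since changing it only scales every key by the same constant. A small self-contained check in plain Python, illustrative and not part of the PR:

import collections
import math
import random

def sample_n_ranked(weights, n):
    # key_i = -log(1 - u) / w_i is Exp(w_i)-distributed; the n smallest keys
    # win, so heavier weights are proportionally more likely to be drawn
    keyed = [(-math.log10(1.0 - random.random()) / w, i)
             for i, w in enumerate(weights)]
    return [i for _, i in sorted(keyed)[:n]]

counts = collections.Counter()
for _ in range(100_000):
    counts.update(sample_n_ranked([1.0, 2.0, 4.0], 1))
print(counts)  # counts come out roughly proportional to the weights 1 : 2 : 4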

def update_priorities(self, game_id, game_priority, priorities):
cursor = self.connection.cursor()
cursor.execute("UPDATE game_history"
" SET game_priority = ?,"
" priorities = ?"
" WHERE"
" id = ?", (
float(game_priority),
pickle.dumps(priorities),
int(game_id)
))
self.connection.commit()

def update_reanalysed_values(self, game_id, reanalysed_predicted_root_values):
cursor = self.connection.cursor()
cursor.execute("UPDATE game_history"
" SET reanalysed_predicted_root_values = ?"
" WHERE"
" id = ?", (
pickle.dumps(reanalysed_predicted_root_values),
int(game_id)
))
self.connection.commit()
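A usage sketch for the DAO above; it assumes this PR's replay_buffer.py is importable, and MiniGameHistory is a hypothetical stand-in for muzero-general's GameHistory. Two quirks are worth knowing: __setitem__ nulls the three priority-related attributes on the object it is handed (they are stored in their own columns and restored on read), and __getitem__ returns an (id, history) pair rather than the bare value, a departure from the usual MutableMapping contract.

from replay_buffer import GameHistoryDao

class MiniGameHistory:  # hypothetical stand-in for GameHistory
    def __init__(self):
        self.game_priority = 2.0
        self.priorities = [1.0, 2.0, 0.5]
        self.reanalysed_predicted_root_values = None

dao = GameHistoryDao(file="replay_buffer.db")
game = MiniGameHistory()
dao[0] = game                  # REPLACE INTO; also nulls game's three priority fields
game_id, history = dao[0]      # note the (id, history) pair, not just the value
print(len(dao), history.game_priority)  # 1 2.0
del dao[0]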


@ray.remote
class ReplayBuffer:
"""
@@ -16,7 +221,8 @@ class ReplayBuffer:

def __init__(self, initial_checkpoint, initial_buffer, config):
self.config = config
-        self.buffer = copy.deepcopy(initial_buffer)
+        self.buffer_file = initial_buffer
+        self.buffer = GameHistoryDao(file=self.buffer_file)
self.num_played_games = initial_checkpoint["num_played_games"]
self.num_played_steps = initial_checkpoint["num_played_steps"]
self.total_samples = sum(
@@ -65,7 +271,7 @@ def save_game(self, game_history, shared_storage=None):
shared_storage.set_info.remote("num_played_steps", self.num_played_steps)

def get_buffer(self):
-        return self.buffer
+        return self.buffer_file

def get_batch(self):
(
@@ -138,38 +344,19 @@ def sample_game(self, force_uniform=False):
Sample game from buffer either uniformly or according to some priority.
See paper appendix Training.
"""
-        game_prob = None
-        if self.config.PER and not force_uniform:
-            game_probs = numpy.array(
-                [game_history.game_priority for game_history in self.buffer.values()],
-                dtype="float32",
-            )
-            game_probs /= numpy.sum(game_probs)
-            game_index = numpy.random.choice(len(self.buffer), p=game_probs)
-            game_prob = game_probs[game_index]
-        else:
-            game_index = numpy.random.choice(len(self.buffer))
-        game_id = self.num_played_games - len(self.buffer) + game_index
-
-        return game_id, self.buffer[game_id], game_prob
+        return next(iter(self.sample_n_games(1)))

def sample_n_games(self, n_games, force_uniform=False):
if self.config.PER and not force_uniform:
-            game_id_list = []
-            game_probs = []
-            for game_id, game_history in self.buffer.items():
-                game_id_list.append(game_id)
-                game_probs.append(game_history.game_priority)
-            game_probs = numpy.array(game_probs, dtype="float32")
-            game_probs /= numpy.sum(game_probs)
-            game_prob_dict = dict([(game_id, prob) for game_id, prob in zip(game_id_list, game_probs)])
-            selected_games = numpy.random.choice(game_id_list, n_games, p=game_probs)
+            samples = self.buffer.sample_n_ranked(n_games)
else:
-            selected_games = numpy.random.choice(list(self.buffer.keys()), n_games)
-            game_prob_dict = {}
-        ret = [(game_id, self.buffer[game_id], game_prob_dict.get(game_id))
-               for game_id in selected_games]
-        return ret
+            samples = self.buffer.sample_n(n_games)
+
+        for sample in samples:
+            game_id = sample[0]
+            game_history = sample[1]
+            game_prob = game_history.game_priority
+            yield game_id, game_history, game_prob

def sample_position(self, game_history, force_uniform=False):
"""
@@ -186,13 +373,8 @@ def sample_position(self, game_history, force_uniform=False):

return position_index, position_prob

-    def update_game_history(self, game_id, game_history):
-        # The element could have been removed since its selection and update
-        if next(iter(self.buffer)) <= game_id:
-            if self.config.PER:
-                # Avoid read only array when loading replay buffer from disk
-                game_history.priorities = numpy.copy(game_history.priorities)
-            self.buffer[game_id] = game_history
+    def update_reanalysed_values(self, game_id, reanalysed_predicted_root_values):
+        self.buffer.update_reanalysed_values(game_id, reanalysed_predicted_root_values)

def update_priorities(self, priorities, index_info):
"""
@@ -203,22 +385,29 @@ def update_priorities(self, priorities, index_info):
game_id, game_pos = index_info[i]

# The element could have been removed since its selection and training
-            if next(iter(self.buffer)) <= game_id:
+            if self.buffer.min_id() <= game_id:

+                # select record from database (can't update in place)
+                priorities_record = self.buffer.priorities(game_id)

                # Update position priorities
                priority = priorities[i, :]
                start_index = game_pos
                end_index = min(
-                    game_pos + len(priority), len(self.buffer[game_id].priorities)
+                    game_pos + len(priority), len(priorities_record)
                )
-                self.buffer[game_id].priorities[start_index:end_index] = priority[
+                priorities_record[start_index:end_index] = priority[
                    : end_index - start_index
                ]

                # Update game priorities
-                self.buffer[game_id].game_priority = numpy.max(
-                    self.buffer[game_id].priorities
+                game_priority = numpy.max(
+                    priorities_record
                )

+                # update record
+                self.buffer.update_priorities(game_id, game_priority, priorities_record)
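Because a pickled row cannot be mutated in place, update_priorities now does a read-modify-write: fetch the priorities array, patch the sampled slice, recompute the game-level maximum, and write both back in one UPDATE. A standalone sketch of the slice arithmetic with illustrative values, not taken from the PR:

import numpy

priorities_record = numpy.array([0.1, 0.2, 0.3, 0.4], dtype="float32")  # fetched row
priority = numpy.array([0.9, 0.8, 0.7], dtype="float32")                # new values
game_pos = 2                                                            # sampled position

start_index = game_pos
end_index = min(game_pos + len(priority), len(priorities_record))  # clip at game end
priorities_record[start_index:end_index] = priority[: end_index - start_index]
game_priority = priorities_record.max()

print(priorities_record, game_priority)  # [0.1 0.2 0.9 0.8] 0.9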

def compute_target_value(self, game_history, index):
# The value target is the discounted root value of the search tree td_steps into the
# future, plus the discounted sum of all rewards until then.
@@ -333,6 +522,7 @@ def reanalyse(self, replay_buffer, shared_storage):
)

# Use the last model to provide a fresher, stable n-step value (See paper appendix Reanalyze)
+            reanalysed_predicted_root_values = game_history.reanalysed_predicted_root_values
if self.config.use_last_model_value:
observations = [
game_history.get_stacked_observations(
@@ -350,11 +540,11 @@
self.model.initial_inference(observations)[0],
self.config.support_size,
)
-                game_history.reanalysed_predicted_root_values = (
+                reanalysed_predicted_root_values = (
torch.squeeze(values).detach().cpu().numpy()
)

-            replay_buffer.update_game_history.remote(game_id, game_history)
+            replay_buffer.update_reanalysed_values.remote(game_id, reanalysed_predicted_root_values)
self.num_reanalysed_games += 1
shared_storage.set_info.remote(
"num_reanalysed_games", self.num_reanalysed_games