From f18a35b736b897ddd37ce40024968e40d0abb012 Mon Sep 17 00:00:00 2001 From: me-unsolicited Date: Wed, 21 Apr 2021 22:34:57 -0400 Subject: [PATCH 01/17] Keep replay buffer on disk instead of in memory, using a sqlite database --- muzero.py | 2 +- replay_buffer.py | 81 ++++++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 80 insertions(+), 3 deletions(-) diff --git a/muzero.py b/muzero.py index 834580a0..9213cf03 100644 --- a/muzero.py +++ b/muzero.py @@ -110,7 +110,7 @@ def __init__(self, game_name, config=None, split_resources_in=1): "num_reanalysed_games": 0, "terminate": False, } - self.replay_buffer = {} + self.replay_buffer = os.path.join(self.config.results_path, "replay_buffer.db") cpu_actor = CPUActor.remote() cpu_weights = cpu_actor.get_initial_weights.remote(self.config) diff --git a/replay_buffer.py b/replay_buffer.py index e1e5aed4..d9a0332b 100644 --- a/replay_buffer.py +++ b/replay_buffer.py @@ -1,13 +1,89 @@ +import collections import copy +import pickle import time import numpy import ray import torch +import sqlite3 import models +class GameHistoryDao(collections.MutableMapping): + """ + Data Access Object for the game histories comprising the replay buffer + """ + + def __init__(self, file): + self.connection = sqlite3.connect(file) + self.connection.execute("CREATE TABLE IF NOT EXISTS game_history(" + " id INTEGER PRIMARY KEY ASC," + " value TEXT" + ")") + self.connection.commit() + self.temp = None + + def __len__(self): + cursor = self.connection.cursor() + cursor.execute("SELECT COUNT(*) FROM game_history") + result = cursor.fetchone()[0] + return result + + def __getitem__(self, key): + cursor = self.connection.cursor() + cursor.execute("SELECT value FROM game_history WHERE id = ?", (int(key),)) + result = cursor.fetchone() + if result is None: + raise KeyError() + as_text = result[0] + return pickle.loads(as_text) + + def __setitem__(self, key, value): + as_text = pickle.dumps(value) + cursor = self.connection.cursor() + cursor.execute("REPLACE INTO game_history(id, value) VALUES(?, ?)", (int(key), as_text)) + self.connection.commit() + self.temp = value + + def __delitem__(self, key): + cursor = self.connection.cursor() + cursor.execute("DELETE FROM game_history WHERE id = ?", (int(key),)) + self.connection.commit() + if cursor.rowcount == 0: + raise KeyError() + + def keys(self): + cursor = self.connection.cursor() + cursor.execute("SELECT id FROM game_history ORDER BY id ASC") + for row in cursor.fetchall(): + yield row[0] + + def values(self): + cursor = self.connection.cursor() + cursor.execute("SELECT value FROM game_history ORDER BY id ASC") + for row in cursor: + yield pickle.loads(row[0]) + + def items(self): + cursor = self.connection.cursor() + cursor.execute("SELECT id, value FROM game_history ORDER BY id ASC") + for row in cursor: + yield row[0], pickle.loads(row[1]) + + def __contains__(self, key): + cursor = self.connection.cursor() + cursor.execute("SELECT COUNT(*) FROM game_history WHERE id = ?", (int(key),)) + return cursor.fetchone()[0] > 0 + + def __iter__(self): + cursor = self.connection.cursor() + cursor.execute("SELECT id FROM game_history ORDER BY id ASC") + for row in cursor: + yield row[0] + + @ray.remote class ReplayBuffer: """ @@ -16,7 +92,8 @@ class ReplayBuffer: def __init__(self, initial_checkpoint, initial_buffer, config): self.config = config - self.buffer = copy.deepcopy(initial_buffer) + self.buffer_file = initial_buffer + self.buffer = GameHistoryDao(file=self.buffer_file) self.num_played_games = initial_checkpoint["num_played_games"] self.num_played_steps = initial_checkpoint["num_played_steps"] self.total_samples = sum( @@ -65,7 +142,7 @@ def save_game(self, game_history, shared_storage=None): shared_storage.set_info.remote("num_played_steps", self.num_played_steps) def get_buffer(self): - return self.buffer + return self.buffer_file def get_batch(self): ( From a22b6f586bf265df4018a53edbcca2a1f38f43b1 Mon Sep 17 00:00:00 2001 From: me-unsolicited Date: Wed, 21 Apr 2021 22:44:24 -0400 Subject: [PATCH 02/17] Update appropriately for database backed records --- replay_buffer.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/replay_buffer.py b/replay_buffer.py index d9a0332b..2d05478e 100644 --- a/replay_buffer.py +++ b/replay_buffer.py @@ -281,21 +281,28 @@ def update_priorities(self, priorities, index_info): # The element could have been removed since its selection and training if next(iter(self.buffer)) <= game_id: + + # select record from database (can't update in place) + game_history = self.buffer[game_id] + # Update position priorities priority = priorities[i, :] start_index = game_pos end_index = min( - game_pos + len(priority), len(self.buffer[game_id].priorities) + game_pos + len(priority), len(game_history.priorities) ) - self.buffer[game_id].priorities[start_index:end_index] = priority[ + game_history.priorities[start_index:end_index] = priority[ : end_index - start_index ] # Update game priorities - self.buffer[game_id].game_priority = numpy.max( + game_history.game_priority = numpy.max( self.buffer[game_id].priorities ) + # update record + self.buffer[game_id] = game_history + def compute_target_value(self, game_history, index): # The value target is the discounted root value of the search tree td_steps into the # future, plus the discounted sum of all rewards until then. From 4b46ed69de8d677197926bb7412035f557d8557b Mon Sep 17 00:00:00 2001 From: me-unsolicited Date: Mon, 26 Apr 2021 19:30:41 -0400 Subject: [PATCH 03/17] Split GameHistory data into multiple columns for quick access --- replay_buffer.py | 183 +++++++++++++++++++++++++++++++++++------------ 1 file changed, 136 insertions(+), 47 deletions(-) diff --git a/replay_buffer.py b/replay_buffer.py index 2d05478e..bb545666 100644 --- a/replay_buffer.py +++ b/replay_buffer.py @@ -16,14 +16,43 @@ class GameHistoryDao(collections.MutableMapping): Data Access Object for the game histories comprising the replay buffer """ + @staticmethod + def assemble_game_history(result): + + # assemble priorities into the game history + # structure: (id, game_priority, priorities, reanalyzed_predicted_root_values, object) + game_history = pickle.loads(result[4]) + game_history.game_priority = result[1] + game_history.priorities = pickle.loads(result[2]) + game_history.reanalysed_predicted_root_values = pickle.loads(result[3]) + + return result[0], game_history + + @staticmethod + def disassemble_game_history(value): + + # disassemble the priorities from the game history + game_priority = value.game_priority + priorities = value.priorities + reanalyzed_predicted_root_values = value.reanalyzed_predicted_root_values + + # avoid storing duplicate data (it will be reassembled later) + value.game_priority = None + value.priorities = None + value.reanalyzed_predicted_root_values = None + + return game_priority, priorities, reanalyzed_predicted_root_values + def __init__(self, file): self.connection = sqlite3.connect(file) self.connection.execute("CREATE TABLE IF NOT EXISTS game_history(" " id INTEGER PRIMARY KEY ASC," - " value TEXT" + " game_priority REAL," + " priorities TEXT," + " reanalyzed_predicted_root_values TEXT," + " object TEXT" ")") self.connection.commit() - self.temp = None def __len__(self): cursor = self.connection.cursor() @@ -33,19 +62,37 @@ def __len__(self): def __getitem__(self, key): cursor = self.connection.cursor() - cursor.execute("SELECT value FROM game_history WHERE id = ?", (int(key),)) + cursor.execute("SELECT game_priority," + " priorities," + " reanalyzed_predicted_root_values" + " object" + " FROM game_history" + " WHERE id = ?", (int(key),)) result = cursor.fetchone() if result is None: raise KeyError() - as_text = result[0] - return pickle.loads(as_text) + + return self.assemble_game_history(result) def __setitem__(self, key, value): - as_text = pickle.dumps(value) + + game_priority, priorities, reanalyzed_predicted_root_values = self.disassemble_game_history(value) + cursor = self.connection.cursor() - cursor.execute("REPLACE INTO game_history(id, value) VALUES(?, ?)", (int(key), as_text)) + cursor.execute("REPLACE INTO game_history(" + " id," + " game_priority," + " priorities," + " reanalyzed_predicted_root_values," + " object" + ") VALUES(?, ?, ?, ?, ?)", ( + int(key), + game_priority, + pickle.dumps(priorities), + pickle.dumps(reanalyzed_predicted_root_values), + pickle.dumps(value) + )) self.connection.commit() - self.temp = value def __delitem__(self, key): cursor = self.connection.cursor() @@ -83,6 +130,70 @@ def __iter__(self): for row in cursor: yield row[0] + def priorities(self, game_id): + cursor = self.connection.cursor() + cursor.execute("SELECT priorities FROM game_history WHERE id = ?", (game_id,)) + result = cursor.fetchone() + if result is None: + raise KeyError() + return pickle.loads(result[0]) + + def min_id(self): + cursor = self.connection.cursor() + cursor.execute("SELECT MIN(id) FROM game_history") + return cursor.fetchone()[0] + + def sample_n(self, n): + cursor = self.connection.cursor() + cursor.execute("SELECT id," + " game_priority," + " priorities," + " reanalyzed_predicted_root_values" + " object" + " FROM game_history" + " ORDER BY RANDOM()" + " LIMIT ?", (int(n),)) + for row in cursor: + yield self.assemble_game_history(row) + + def sample_n_ranked(self, n): + # reference: https://stackoverflow.com/a/12301949 + cursor = self.connection.cursor() + cursor.execute("SELECT id," + " game_priority," + " priorities," + " reanalyzed_predicted_root_values" + " object" + " FROM game_history" + " ORDER BY -LOG(1.0 - RAND()) / game_priority" + " LIMIT ?", (int(n),)) + for row in cursor: + yield self.assemble_game_history(row) + + def update_priorities(self, game_id, game_priority, priorities): + cursor = self.connection.cursor() + cursor.execute("UPDATE game_history" + "SET game_priority = ?," + " priorities = ?" + "WHERE" + " id = ?", ( + game_priority, + pickle.dumps(priorities), + int(game_id) + )) + self.connection.commit() + + def update_reanalyzed_values(self, game_id, reanalyzed_predicted_root_values): + cursor = self.connection.cursor() + cursor.execute("UPDATE game_history" + "SET reanalyzed_predicted_root_values = ?" + "WHERE" + " id = ?", ( + reanalyzed_predicted_root_values, + int(game_id) + )) + self.connection.commit() + @ray.remote class ReplayBuffer: @@ -215,38 +326,19 @@ def sample_game(self, force_uniform=False): Sample game from buffer either uniformly or according to some priority. See paper appendix Training. """ - game_prob = None - if self.config.PER and not force_uniform: - game_probs = numpy.array( - [game_history.game_priority for game_history in self.buffer.values()], - dtype="float32", - ) - game_probs /= numpy.sum(game_probs) - game_index = numpy.random.choice(len(self.buffer), p=game_probs) - game_prob = game_probs[game_index] - else: - game_index = numpy.random.choice(len(self.buffer)) - game_id = self.num_played_games - len(self.buffer) + game_index - - return game_id, self.buffer[game_id], game_prob + return next(iter(self.sample_n_games(1))) def sample_n_games(self, n_games, force_uniform=False): if self.config.PER and not force_uniform: - game_id_list = [] - game_probs = [] - for game_id, game_history in self.buffer.items(): - game_id_list.append(game_id) - game_probs.append(game_history.game_priority) - game_probs = numpy.array(game_probs, dtype="float32") - game_probs /= numpy.sum(game_probs) - game_prob_dict = dict([(game_id, prob) for game_id, prob in zip(game_id_list, game_probs)]) - selected_games = numpy.random.choice(game_id_list, n_games, p=game_probs) + samples = self.buffer.sample_n_ranked(n_games) else: - selected_games = numpy.random.choice(list(self.buffer.keys()), n_games) - game_prob_dict = {} - ret = [(game_id, self.buffer[game_id], game_prob_dict.get(game_id)) - for game_id in selected_games] - return ret + samples = self.buffer.sample_n(n_games) + + for sample in samples: + game_id = sample[0] + game_history = sample[1] + game_prob = game_history.game_priority + yield game_id, game_history, game_prob def sample_position(self, game_history, force_uniform=False): """ @@ -265,10 +357,7 @@ def sample_position(self, game_history, force_uniform=False): def update_game_history(self, game_id, game_history): # The element could have been removed since its selection and update - if next(iter(self.buffer)) <= game_id: - if self.config.PER: - # Avoid read only array when loading replay buffer from disk - game_history.priorities = numpy.copy(game_history.priorities) + if self.buffer.min_id() <= game_id: self.buffer[game_id] = game_history def update_priorities(self, priorities, index_info): @@ -280,28 +369,28 @@ def update_priorities(self, priorities, index_info): game_id, game_pos = index_info[i] # The element could have been removed since its selection and training - if next(iter(self.buffer)) <= game_id: + if self.buffer.min_id() <= game_id: # select record from database (can't update in place) - game_history = self.buffer[game_id] + priorities_record = self.buffer.priorities(game_id) # Update position priorities priority = priorities[i, :] start_index = game_pos end_index = min( - game_pos + len(priority), len(game_history.priorities) + game_pos + len(priority), len(priorities_record) ) - game_history.priorities[start_index:end_index] = priority[ + priorities_record[start_index:end_index] = priority[ : end_index - start_index ] # Update game priorities - game_history.game_priority = numpy.max( - self.buffer[game_id].priorities + game_priority = numpy.max( + priorities_record ) # update record - self.buffer[game_id] = game_history + self.buffer.update_priorities(game_id, game_priority, priorities_record) def compute_target_value(self, game_history, index): # The value target is the discounted root value of the search tree td_steps into the From 2415a6ca12d01c3ae11e0792009dea6eebbc2f8d Mon Sep 17 00:00:00 2001 From: me-unsolicited Date: Mon, 26 Apr 2021 20:00:25 -0400 Subject: [PATCH 04/17] Fix column name --- replay_buffer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/replay_buffer.py b/replay_buffer.py index bb545666..09dc6c4f 100644 --- a/replay_buffer.py +++ b/replay_buffer.py @@ -109,7 +109,7 @@ def keys(self): def values(self): cursor = self.connection.cursor() - cursor.execute("SELECT value FROM game_history ORDER BY id ASC") + cursor.execute("SELECT object FROM game_history ORDER BY id ASC") for row in cursor: yield pickle.loads(row[0]) From 0f45259051375c5cc9f5a7eab1089601b04135d0 Mon Sep 17 00:00:00 2001 From: me-unsolicited Date: Mon, 26 Apr 2021 20:09:55 -0400 Subject: [PATCH 05/17] Assemble game history properly when accessing replay buffer with container methods --- replay_buffer.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/replay_buffer.py b/replay_buffer.py index 09dc6c4f..1d27ab7f 100644 --- a/replay_buffer.py +++ b/replay_buffer.py @@ -109,15 +109,25 @@ def keys(self): def values(self): cursor = self.connection.cursor() - cursor.execute("SELECT object FROM game_history ORDER BY id ASC") + cursor.execute("SELECT id," + " game_priority," + " priorities," + " reanalyzed_predicted_root_values" + " object" + " FROM game_history ORDER BY id ASC") for row in cursor: - yield pickle.loads(row[0]) + yield self.assemble_game_history(row)[1] def items(self): cursor = self.connection.cursor() - cursor.execute("SELECT id, value FROM game_history ORDER BY id ASC") + cursor.execute("SELECT id," + " game_priority," + " priorities," + " reanalyzed_predicted_root_values" + " object" + " FROM game_history ORDER BY id ASC") for row in cursor: - yield row[0], pickle.loads(row[1]) + yield self.assemble_game_history(row) def __contains__(self, key): cursor = self.connection.cursor() From b513a76393b4eaaab11a40d7368d9b74d8a93154 Mon Sep 17 00:00:00 2001 From: me-unsolicited Date: Mon, 26 Apr 2021 20:16:51 -0400 Subject: [PATCH 06/17] Change spelling 'reanalyzed' to 'reanalysed' --- replay_buffer.py | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/replay_buffer.py b/replay_buffer.py index 1d27ab7f..73b7e8c0 100644 --- a/replay_buffer.py +++ b/replay_buffer.py @@ -20,7 +20,7 @@ class GameHistoryDao(collections.MutableMapping): def assemble_game_history(result): # assemble priorities into the game history - # structure: (id, game_priority, priorities, reanalyzed_predicted_root_values, object) + # structure: (id, game_priority, priorities, reanalysed_predicted_root_values, object) game_history = pickle.loads(result[4]) game_history.game_priority = result[1] game_history.priorities = pickle.loads(result[2]) @@ -34,14 +34,14 @@ def disassemble_game_history(value): # disassemble the priorities from the game history game_priority = value.game_priority priorities = value.priorities - reanalyzed_predicted_root_values = value.reanalyzed_predicted_root_values + reanalysed_predicted_root_values = value.reanalysed_predicted_root_values # avoid storing duplicate data (it will be reassembled later) value.game_priority = None value.priorities = None - value.reanalyzed_predicted_root_values = None + value.reanalysed_predicted_root_values = None - return game_priority, priorities, reanalyzed_predicted_root_values + return game_priority, priorities, reanalysed_predicted_root_values def __init__(self, file): self.connection = sqlite3.connect(file) @@ -49,7 +49,7 @@ def __init__(self, file): " id INTEGER PRIMARY KEY ASC," " game_priority REAL," " priorities TEXT," - " reanalyzed_predicted_root_values TEXT," + " reanalysed_predicted_root_values TEXT," " object TEXT" ")") self.connection.commit() @@ -64,7 +64,7 @@ def __getitem__(self, key): cursor = self.connection.cursor() cursor.execute("SELECT game_priority," " priorities," - " reanalyzed_predicted_root_values" + " reanalysed_predicted_root_values" " object" " FROM game_history" " WHERE id = ?", (int(key),)) @@ -76,20 +76,20 @@ def __getitem__(self, key): def __setitem__(self, key, value): - game_priority, priorities, reanalyzed_predicted_root_values = self.disassemble_game_history(value) + game_priority, priorities, reanalysed_predicted_root_values = self.disassemble_game_history(value) cursor = self.connection.cursor() cursor.execute("REPLACE INTO game_history(" " id," " game_priority," " priorities," - " reanalyzed_predicted_root_values," + " reanalysed_predicted_root_values," " object" ") VALUES(?, ?, ?, ?, ?)", ( int(key), game_priority, pickle.dumps(priorities), - pickle.dumps(reanalyzed_predicted_root_values), + pickle.dumps(reanalysed_predicted_root_values), pickle.dumps(value) )) self.connection.commit() @@ -112,7 +112,7 @@ def values(self): cursor.execute("SELECT id," " game_priority," " priorities," - " reanalyzed_predicted_root_values" + " reanalysed_predicted_root_values" " object" " FROM game_history ORDER BY id ASC") for row in cursor: @@ -123,7 +123,7 @@ def items(self): cursor.execute("SELECT id," " game_priority," " priorities," - " reanalyzed_predicted_root_values" + " reanalysed_predicted_root_values" " object" " FROM game_history ORDER BY id ASC") for row in cursor: @@ -158,7 +158,7 @@ def sample_n(self, n): cursor.execute("SELECT id," " game_priority," " priorities," - " reanalyzed_predicted_root_values" + " reanalysed_predicted_root_values" " object" " FROM game_history" " ORDER BY RANDOM()" @@ -172,7 +172,7 @@ def sample_n_ranked(self, n): cursor.execute("SELECT id," " game_priority," " priorities," - " reanalyzed_predicted_root_values" + " reanalysed_predicted_root_values" " object" " FROM game_history" " ORDER BY -LOG(1.0 - RAND()) / game_priority" @@ -193,13 +193,13 @@ def update_priorities(self, game_id, game_priority, priorities): )) self.connection.commit() - def update_reanalyzed_values(self, game_id, reanalyzed_predicted_root_values): + def update_reanalyzed_values(self, game_id, reanalysed_predicted_root_values): cursor = self.connection.cursor() cursor.execute("UPDATE game_history" - "SET reanalyzed_predicted_root_values = ?" + "SET reanalysed_predicted_root_values = ?" "WHERE" " id = ?", ( - reanalyzed_predicted_root_values, + reanalysed_predicted_root_values, int(game_id) )) self.connection.commit() From 4f4915d21ddca965aac8f490d68f52c2bf5277ea Mon Sep 17 00:00:00 2001 From: me-unsolicited Date: Mon, 26 Apr 2021 20:34:50 -0400 Subject: [PATCH 07/17] Add missing SQL functions --- replay_buffer.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/replay_buffer.py b/replay_buffer.py index 73b7e8c0..2c6ce5d7 100644 --- a/replay_buffer.py +++ b/replay_buffer.py @@ -1,12 +1,13 @@ import collections -import copy import pickle import time +import math import numpy +import random import ray -import torch import sqlite3 +import torch import models @@ -45,6 +46,8 @@ def disassemble_game_history(value): def __init__(self, file): self.connection = sqlite3.connect(file) + self.connection.create_function('log', 1, math.log10) + self.connection.create_function('rand', 0, random.random) self.connection.execute("CREATE TABLE IF NOT EXISTS game_history(" " id INTEGER PRIMARY KEY ASC," " game_priority REAL," From 23ac3a96359d7856dfabf6b42833836c7c8bef31 Mon Sep 17 00:00:00 2001 From: me-unsolicited Date: Mon, 26 Apr 2021 20:38:11 -0400 Subject: [PATCH 08/17] Add missing column to SELECT statement --- replay_buffer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/replay_buffer.py b/replay_buffer.py index 2c6ce5d7..6272f764 100644 --- a/replay_buffer.py +++ b/replay_buffer.py @@ -65,7 +65,8 @@ def __len__(self): def __getitem__(self, key): cursor = self.connection.cursor() - cursor.execute("SELECT game_priority," + cursor.execute("SELECT id," + " game_priority," " priorities," " reanalysed_predicted_root_values" " object" From 6f69e45d04ea0817866748bd40fae046d7fbef6f Mon Sep 17 00:00:00 2001 From: me-unsolicited Date: Mon, 26 Apr 2021 20:42:10 -0400 Subject: [PATCH 09/17] Add missing commas in SQL --- replay_buffer.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/replay_buffer.py b/replay_buffer.py index 6272f764..2d427ed6 100644 --- a/replay_buffer.py +++ b/replay_buffer.py @@ -68,7 +68,7 @@ def __getitem__(self, key): cursor.execute("SELECT id," " game_priority," " priorities," - " reanalysed_predicted_root_values" + " reanalysed_predicted_root_values," " object" " FROM game_history" " WHERE id = ?", (int(key),)) @@ -116,7 +116,7 @@ def values(self): cursor.execute("SELECT id," " game_priority," " priorities," - " reanalysed_predicted_root_values" + " reanalysed_predicted_root_values," " object" " FROM game_history ORDER BY id ASC") for row in cursor: @@ -127,7 +127,7 @@ def items(self): cursor.execute("SELECT id," " game_priority," " priorities," - " reanalysed_predicted_root_values" + " reanalysed_predicted_root_values," " object" " FROM game_history ORDER BY id ASC") for row in cursor: @@ -162,7 +162,7 @@ def sample_n(self, n): cursor.execute("SELECT id," " game_priority," " priorities," - " reanalysed_predicted_root_values" + " reanalysed_predicted_root_values," " object" " FROM game_history" " ORDER BY RANDOM()" @@ -176,7 +176,7 @@ def sample_n_ranked(self, n): cursor.execute("SELECT id," " game_priority," " priorities," - " reanalysed_predicted_root_values" + " reanalysed_predicted_root_values," " object" " FROM game_history" " ORDER BY -LOG(1.0 - RAND()) / game_priority" From 9084cb45a7c238746e306d68d7eaca8a92f2f201 Mon Sep 17 00:00:00 2001 From: me-unsolicited Date: Mon, 26 Apr 2021 20:50:21 -0400 Subject: [PATCH 10/17] Don't store numpy arrays --- replay_buffer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/replay_buffer.py b/replay_buffer.py index 2d427ed6..42643561 100644 --- a/replay_buffer.py +++ b/replay_buffer.py @@ -34,8 +34,8 @@ def disassemble_game_history(value): # disassemble the priorities from the game history game_priority = value.game_priority - priorities = value.priorities - reanalysed_predicted_root_values = value.reanalysed_predicted_root_values + priorities = list(value.priorities) + reanalysed_predicted_root_values = list(value.reanalysed_predicted_root_values) # avoid storing duplicate data (it will be reassembled later) value.game_priority = None From 9a25c09ce910bfb8f7cb5d958356ae6ac3cc99c3 Mon Sep 17 00:00:00 2001 From: me-unsolicited Date: Mon, 26 Apr 2021 20:57:28 -0400 Subject: [PATCH 11/17] Check for numpy array type before converting to list --- replay_buffer.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/replay_buffer.py b/replay_buffer.py index 42643561..9c6918b3 100644 --- a/replay_buffer.py +++ b/replay_buffer.py @@ -34,8 +34,13 @@ def disassemble_game_history(value): # disassemble the priorities from the game history game_priority = value.game_priority - priorities = list(value.priorities) - reanalysed_predicted_root_values = list(value.reanalysed_predicted_root_values) + priorities = value.priorities + reanalysed_predicted_root_values = value.reanalysed_predicted_root_values + + # don't store numpy arrays or else it's a problem later after reloading + priorities = list(priorities) if type(priorities) is numpy.ndarray else priorities + reanalysed_predicted_root_values = list(reanalysed_predicted_root_values) if type( + reanalysed_predicted_root_values) is numpy.ndarray else reanalysed_predicted_root_values # avoid storing duplicate data (it will be reassembled later) value.game_priority = None From dceed5900a2f99df659974c713ff38259b74fed9 Mon Sep 17 00:00:00 2001 From: me-unsolicited Date: Mon, 26 Apr 2021 21:16:17 -0400 Subject: [PATCH 12/17] Coerce game_priority to a regular float before saving --- replay_buffer.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/replay_buffer.py b/replay_buffer.py index 9c6918b3..5a758769 100644 --- a/replay_buffer.py +++ b/replay_buffer.py @@ -37,11 +37,6 @@ def disassemble_game_history(value): priorities = value.priorities reanalysed_predicted_root_values = value.reanalysed_predicted_root_values - # don't store numpy arrays or else it's a problem later after reloading - priorities = list(priorities) if type(priorities) is numpy.ndarray else priorities - reanalysed_predicted_root_values = list(reanalysed_predicted_root_values) if type( - reanalysed_predicted_root_values) is numpy.ndarray else reanalysed_predicted_root_values - # avoid storing duplicate data (it will be reassembled later) value.game_priority = None value.priorities = None @@ -96,7 +91,7 @@ def __setitem__(self, key, value): " object" ") VALUES(?, ?, ?, ?, ?)", ( int(key), - game_priority, + float(game_priority), pickle.dumps(priorities), pickle.dumps(reanalysed_predicted_root_values), pickle.dumps(value) @@ -196,7 +191,7 @@ def update_priorities(self, game_id, game_priority, priorities): " priorities = ?" "WHERE" " id = ?", ( - game_priority, + float(game_priority), pickle.dumps(priorities), int(game_id) )) From 641ff903d732cb758eaba6007e5840eedeac8d14 Mon Sep 17 00:00:00 2001 From: me-unsolicited Date: Mon, 26 Apr 2021 21:18:05 -0400 Subject: [PATCH 13/17] Fix SQL syntax errors with concatenated strings --- replay_buffer.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/replay_buffer.py b/replay_buffer.py index 5a758769..b8e9d663 100644 --- a/replay_buffer.py +++ b/replay_buffer.py @@ -187,10 +187,10 @@ def sample_n_ranked(self, n): def update_priorities(self, game_id, game_priority, priorities): cursor = self.connection.cursor() cursor.execute("UPDATE game_history" - "SET game_priority = ?," - " priorities = ?" - "WHERE" - " id = ?", ( + " SET game_priority = ?," + " priorities = ?" + " WHERE" + " id = ?", ( float(game_priority), pickle.dumps(priorities), int(game_id) @@ -200,9 +200,9 @@ def update_priorities(self, game_id, game_priority, priorities): def update_reanalyzed_values(self, game_id, reanalysed_predicted_root_values): cursor = self.connection.cursor() cursor.execute("UPDATE game_history" - "SET reanalysed_predicted_root_values = ?" - "WHERE" - " id = ?", ( + " SET reanalysed_predicted_root_values = ?" + " WHERE" + " id = ?", ( reanalysed_predicted_root_values, int(game_id) )) From 423e2b02b9aab2c503c2f0fd839948fb23238c87 Mon Sep 17 00:00:00 2001 From: me-unsolicited Date: Mon, 26 Apr 2021 21:55:40 -0400 Subject: [PATCH 14/17] Serialize reanalysed_predicted_root_values before update --- replay_buffer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/replay_buffer.py b/replay_buffer.py index b8e9d663..58c64af7 100644 --- a/replay_buffer.py +++ b/replay_buffer.py @@ -203,7 +203,7 @@ def update_reanalyzed_values(self, game_id, reanalysed_predicted_root_values): " SET reanalysed_predicted_root_values = ?" " WHERE" " id = ?", ( - reanalysed_predicted_root_values, + pickle.dumps(reanalysed_predicted_root_values), int(game_id) )) self.connection.commit() From f47e164eb47567a843d74f36642e556cdc749784 Mon Sep 17 00:00:00 2001 From: me-unsolicited Date: Tue, 27 Apr 2021 00:52:40 -0400 Subject: [PATCH 15/17] Sample the database more efficiently --- replay_buffer.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/replay_buffer.py b/replay_buffer.py index 58c64af7..0adeff02 100644 --- a/replay_buffer.py +++ b/replay_buffer.py @@ -164,9 +164,11 @@ def sample_n(self, n): " priorities," " reanalysed_predicted_root_values," " object" - " FROM game_history" - " ORDER BY RANDOM()" - " LIMIT ?", (int(n),)) + " FROM game_history WHERE id IN (" + " SELECT id FROM game_history" + " ORDER BY RANDOM()" + " LIMIT ?" + " )", (int(n),)) for row in cursor: yield self.assemble_game_history(row) @@ -178,9 +180,11 @@ def sample_n_ranked(self, n): " priorities," " reanalysed_predicted_root_values," " object" - " FROM game_history" - " ORDER BY -LOG(1.0 - RAND()) / game_priority" - " LIMIT ?", (int(n),)) + " FROM game_history WHERE id IN (" + " SELECT id FROM game_history" + " ORDER BY -LOG(1.0 - RAND()) / game_priority" + " LIMIT ?" + " )", (int(n),)) for row in cursor: yield self.assemble_game_history(row) From f9a3307331de40a4950b7ccf627a2fdd54df6d3d Mon Sep 17 00:00:00 2001 From: me-unsolicited Date: Tue, 27 Apr 2021 20:48:32 -0400 Subject: [PATCH 16/17] Use fast update for reanalysed_predicted_root_values --- replay_buffer.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/replay_buffer.py b/replay_buffer.py index 0adeff02..aa0d1bad 100644 --- a/replay_buffer.py +++ b/replay_buffer.py @@ -373,10 +373,8 @@ def sample_position(self, game_history, force_uniform=False): return position_index, position_prob - def update_game_history(self, game_id, game_history): - # The element could have been removed since its selection and update - if self.buffer.min_id() <= game_id: - self.buffer[game_id] = game_history + def update_reanalysed_values(self, game_id, reanalysed_predicted_root_values): + self.buffer.update_reanalyzed_values(game_id, reanalysed_predicted_root_values) def update_priorities(self, priorities, index_info): """ @@ -524,6 +522,7 @@ def reanalyse(self, replay_buffer, shared_storage): ) # Use the last model to provide a fresher, stable n-step value (See paper appendix Reanalyze) + reanalysed_predicted_root_values = game_history.reanalysed_predicted_root_values if self.config.use_last_model_value: observations = [ game_history.get_stacked_observations( @@ -541,11 +540,11 @@ def reanalyse(self, replay_buffer, shared_storage): self.model.initial_inference(observations)[0], self.config.support_size, ) - game_history.reanalysed_predicted_root_values = ( + reanalysed_predicted_root_values = ( torch.squeeze(values).detach().cpu().numpy() ) - replay_buffer.update_game_history.remote(game_id, game_history) + replay_buffer.update_reanalysed_values.remote(game_id, reanalysed_predicted_root_values) self.num_reanalysed_games += 1 shared_storage.set_info.remote( "num_reanalysed_games", self.num_reanalysed_games From 729daaaaa66d3a940873422deef9461dd30c92c2 Mon Sep 17 00:00:00 2001 From: me-unsolicited Date: Tue, 27 Apr 2021 22:19:30 -0400 Subject: [PATCH 17/17] Change spelling 'reanalyzed' to 'reanalysed' --- replay_buffer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/replay_buffer.py b/replay_buffer.py index aa0d1bad..30fd9daa 100644 --- a/replay_buffer.py +++ b/replay_buffer.py @@ -201,7 +201,7 @@ def update_priorities(self, game_id, game_priority, priorities): )) self.connection.commit() - def update_reanalyzed_values(self, game_id, reanalysed_predicted_root_values): + def update_reanalysed_values(self, game_id, reanalysed_predicted_root_values): cursor = self.connection.cursor() cursor.execute("UPDATE game_history" " SET reanalysed_predicted_root_values = ?" @@ -374,7 +374,7 @@ def sample_position(self, game_history, force_uniform=False): return position_index, position_prob def update_reanalysed_values(self, game_id, reanalysed_predicted_root_values): - self.buffer.update_reanalyzed_values(game_id, reanalysed_predicted_root_values) + self.buffer.update_reanalysed_values(game_id, reanalysed_predicted_root_values) def update_priorities(self, priorities, index_info): """