diff --git a/perfect_information_game/move_selection/mcts/__init__.py b/perfect_information_game/move_selection/mcts/__init__.py index 8a07807..3e3509b 100644 --- a/perfect_information_game/move_selection/mcts/__init__.py +++ b/perfect_information_game/move_selection/mcts/__init__.py @@ -1,5 +1,6 @@ from perfect_information_game.move_selection.mcts.abstract_node import AbstractNode from perfect_information_game.move_selection.mcts.rollout_node import RolloutNode from perfect_information_game.move_selection.mcts.heuristic_node import HeuristicNode +from perfect_information_game.move_selection.mcts.tablebase_node import TablebaseNode from perfect_information_game.move_selection.mcts.mcts import MCTS from perfect_information_game.move_selection.mcts.async_mcts import AsyncMCTS diff --git a/perfect_information_game/move_selection/mcts/abstract_node.py b/perfect_information_game/move_selection/mcts/abstract_node.py index 332cca4..c352bf0 100644 --- a/perfect_information_game/move_selection/mcts/abstract_node.py +++ b/perfect_information_game/move_selection/mcts/abstract_node.py @@ -1,24 +1,41 @@ from abc import ABC, abstractmethod import numpy as np +from perfect_information_game.tablebases import EmptyTablebaseManager class AbstractNode(ABC): - def __init__(self, position, parent, GameClass, c=np.sqrt(2), verbose=False): + def __init__(self, position, parent, GameClass, tablebase_manager=None, verbose=False): self.position = position self.parent = parent self.GameClass = GameClass - self.c = c - self.fully_expanded = GameClass.is_over(position) + self.tablebase_manager = tablebase_manager if tablebase_manager is not None \ + else EmptyTablebaseManager(GameClass) + self.verbose = verbose + + if self.GameClass.is_over(position): + self.fully_expanded = True + self.outcome = GameClass.get_winner(position) + else: + self.fully_expanded = False + self.outcome = None # this will be None if fully_expanded is False, otherwise it will be either -1, 0, or 1 + self.is_maximizing = GameClass.is_player_1_turn(position) self.children = None - self.verbose = verbose @abstractmethod def get_evaluation(self): + """ + Returns the prediction of the evaluation of the game assuming perfect play, which is in the range [-1, 1]. + This will be 1 if player 1 is winning, 0 if it is a draw, and -1 if player 1 is losing. + """ pass @abstractmethod def count_expansions(self): + """ + Returns the number of times that this node has been explored. + If the node is fully expanded, then np.inf will be returned. + """ pass @abstractmethod @@ -33,11 +50,15 @@ def get_puct_heuristic_for_child(self, i): def expand(self): pass - @abstractmethod def set_fully_expanded(self, minimax_evaluation): - pass + self.fully_expanded = True + self.outcome = minimax_evaluation def choose_best_node(self, return_probability_distribution=False, optimal=False): + """ + Chooses the best move based on the expansions performed so far. + This should only need to be called on the root node in order to choose a move for the AI. + """ distribution = [] optimal_value = 1 if self.is_maximizing else -1 @@ -55,10 +76,9 @@ def choose_best_node(self, return_probability_distribution=False, optimal=False) if child.fully_expanded and child.get_evaluation() == self.get_evaluation(): # TODO: when losing, consider the number of ways the opponent can win in response to a move depth_to_endgame = child.depth_to_end_game() - # if we are winning, weight smaller depths much more strongly by using e^-x - # if we are losing or drawing, weight larger depths much more strongly by using e^x - relative_probability = np.exp(-depth_to_endgame if self.get_evaluation() == optimal_value else - depth_to_endgame) + # if we want a short game, weight smaller depths much more strongly by using e^-x + # if we want a long game, weight larger depths much more strongly by using e^x + relative_probability = np.exp(depth_to_endgame if self.want_long_game() else -depth_to_endgame) distribution.append(relative_probability) else: distribution.append(0) @@ -81,12 +101,24 @@ def choose_best_node(self, return_probability_distribution=False, optimal=False) best_child = self.children[idx] return (best_child, distribution) if return_probability_distribution else best_child - def choose_expansion_node(self): + def choose_expansion_node(self, search_suboptimal=False): + """ + Searches the tree to find a node to expand. + Returns None if no nodes could be found because they are all fully expanded. + :param search_suboptimal: If True, then nodes will continue to be searched even if + they have siblings that lead to a win. + This should only be set to True once the entire tree has been searched and the best line has been determined. + """ + if search_suboptimal: + raise NotImplementedError() + # TODO: continue tree search in case the user makes a mistake and the game continues if self.fully_expanded: + # self must be the root node because a fully expanded node would never be chosen by its parent return None if self.count_expansions() == 0: + # this node itself has never been expanded return self self.ensure_children() @@ -99,11 +131,13 @@ def choose_expansion_node(self): # If this child is already optimal, then self is fully expanded and there is no point searching further if child.get_evaluation() == optimal_value: self.set_fully_expanded(optimal_value) + # delegate to the parent, which will now choose differently since self is fully expanded return self.parent.choose_expansion_node() if self.parent is not None else None # continue searching other children, there may be another child that is more optimal continue - # check puct heuristic before calling child.get_evaluation() because it may result in division by 0 + # check puct heuristic before calling child.get_evaluation() because + # it may result in division by 0 for RolloutNode puct_heuristic = self.get_puct_heuristic_for_child(i) if np.isinf(puct_heuristic): return child @@ -122,30 +156,48 @@ def choose_expansion_node(self): # if nothing was found because all children are fully expanded if best_child is None: if self.verbose and not self.fully_expanded and self.parent is None: + # print this message when the root node becomes fully expanded so that it is only printed once print('Fully expanded tree!') - minimax_evaluation = max([child.get_evaluation() for child in self.children]) if self.is_maximizing \ - else min([child.get_evaluation() for child in self.children]) + minimax_evaluation = (max if self.is_maximizing else min)( + [child.get_evaluation() for child in self.children]) self.set_fully_expanded(minimax_evaluation) # this node is now fully expanded, so ask the parent to try to choose again # if no parent is available (i.e. this is the root node) then the entire search tree has been expanded return self.parent.choose_expansion_node() if self.parent is not None else None + # the best child has been chosen, and the expansion node choice is delegated to it now return best_child.choose_expansion_node() - def depth_to_end_game(self): + def want_long_game(self): + """ + Can only be called on fully expanded nodes. + + Returns True if we want the game to be as long as possible and + False if we want the game to be as fast as possible. + """ if not self.fully_expanded: raise Exception('Node not fully expanded!') - if self.children is None: - return 0 optimal_value = 1 if self.is_maximizing else -1 - if self.get_evaluation() == optimal_value: + if self.outcome == optimal_value: # if we are winning, win as fast as possible - return 1 + min(child.depth_to_end_game() for child in self.children - if child.fully_expanded and child.get_evaluation() == self.get_evaluation()) + return False + elif self.outcome == -optimal_value: + # if we are losing, lose as slow as possible + return True else: - # if we are losing or it is a draw, lose as slow as possible - return 1 + max(child.depth_to_end_game() for child in self.children - if child.fully_expanded and child.get_evaluation() == self.get_evaluation()) + try: + heuristic = self.GameClass.heuristic(self.position) + # if the game is a draw, make it as slow as possible if we have a material advantage + # otherwise, finish the game as fast as possible + return heuristic > 0 if self.is_maximizing else heuristic < 0 + except NotImplementedError: + # if the game is a draw and no heuristic is defined, then finish the game as fast as possible + return False + + def depth_to_end_game(self): + return 1 + (max if self.want_long_game() else min)( + child.depth_to_end_game() for child in self.children + if child.fully_expanded and child.get_evaluation() == self.get_evaluation()) diff --git a/perfect_information_game/move_selection/mcts/async_mcts.py b/perfect_information_game/move_selection/mcts/async_mcts.py index 690dc6c..11e9740 100644 --- a/perfect_information_game/move_selection/mcts/async_mcts.py +++ b/perfect_information_game/move_selection/mcts/async_mcts.py @@ -1,10 +1,11 @@ -from multiprocessing import Pipe, Pool +from multiprocessing import Pipe from multiprocessing.context import Process from time import time import numpy as np -from perfect_information_game.move_selection.mcts import HeuristicNode -from perfect_information_game.move_selection.mcts import RolloutNode from perfect_information_game.move_selection import MoveChooser +from perfect_information_game.move_selection.mcts import TablebaseNode, RolloutNode, HeuristicNode +from perfect_information_game.tablebases import EmptyTablebaseManager +from perfect_information_game.utils import OptionalPool class AsyncMCTS(MoveChooser): @@ -13,7 +14,8 @@ class AsyncMCTS(MoveChooser): This is achieved using multiprocessing, and a Pipe for transferring data to and from the worker process. """ - def __init__(self, GameClass, starting_position, time_limit=3, network=None, c=np.sqrt(2), d=1, threads=1): + def __init__(self, GameClass, starting_position, time_limit=3, network=None, c=np.sqrt(2), d=1, threads=1, + tablebase_manager=None): """ Either: If network is provided, threads must be 1. @@ -23,10 +25,12 @@ def __init__(self, GameClass, starting_position, time_limit=3, network=None, c=n if network is not None and threads != 1: raise ValueError('Threads != 1 with Network != None') + if tablebase_manager is None: + tablebase_manager = EmptyTablebaseManager(GameClass) self.parent_pipe, worker_pipe = Pipe() self.worker_process = Process(target=self.loop_func, args=(GameClass, starting_position, time_limit, network, c, d, threads, - worker_pipe)) + tablebase_manager, worker_pipe)) def start(self): self.worker_process.start() @@ -60,69 +64,71 @@ def terminate(self): self.worker_process.join() @staticmethod - def loop_func(GameClass, position, time_limit, network, c, d, threads, worker_pipe): - if network is None: - pool = Pool(threads) if threads > 1 else None - root = RolloutNode(position, parent=None, GameClass=GameClass, c=c, rollout_batch_size=threads, pool=pool, - verbose=True) - else: - network.initialize() - root = HeuristicNode(position, None, GameClass, network, c, d, verbose=True) - - while True: - best_node = root.choose_expansion_node() - - if best_node is not None: - best_node.expand() - - if root.children is not None and worker_pipe.poll(): - user_chosen_position = worker_pipe.recv() - - if user_chosen_position is not None: - # an updated position has been received so we can truncate the tree - for child in root.children: - if np.all(child.position == user_chosen_position): - root = child - root.parent = None + def loop_func(GameClass, position, time_limit, network, c, d, threads, tablebase_manager, worker_pipe): + with OptionalPool(threads) as pool: + if network is not None: + network.initialize() + + root = TablebaseNode.attempt_create(position, None, GameClass, tablebase_manager, verbose=True, + backup_factory=lambda: + RolloutNode(position, None, GameClass, c, threads, pool, verbose=True) + if network is None else + HeuristicNode(position, None, GameClass, network, c, d, verbose=True)) + + while True: + best_node = root.choose_expansion_node() + + if best_node is not None: + best_node.expand() + + if root.children is not None and worker_pipe.poll(): + user_chosen_position = worker_pipe.recv() + + if user_chosen_position is not None: + # an updated position has been received so we can truncate the tree + for child in root.children: + if np.all(child.position == user_chosen_position): + root = child + root.parent = None + break + else: + print(user_chosen_position) + raise Exception('Invalid user chosen move!') + + if GameClass.is_over(root.position): + print('Game Over in Async MCTS: ', GameClass.get_winner(root.position)) break else: - print(user_chosen_position) - raise Exception('Invalid user chosen move!') - - if GameClass.is_over(root.position): - print('Game Over in Async MCTS: ', GameClass.get_winner(root.position)) - return - else: - # this move chooser has been requested to decide on a move via the choose_move function - start_time = time() - while time() - start_time < time_limit: - best_node = root.choose_expansion_node() - - # best_node will be None if the tree is fully expanded - if best_node is None: - break - - best_node.expand() - - is_ai_player_1 = GameClass.is_player_1_turn(root.position) - chosen_positions = [] - print(f'MCTS choosing move based on {root.count_expansions()} expansions!') - - # choose moves as long as it is still the ai's turn - while GameClass.is_player_1_turn(root.position) == is_ai_player_1: - if root.children is None: + # this move chooser has been requested to decide on a move via the choose_move function + start_time = time() + while time() - start_time < time_limit: best_node = root.choose_expansion_node() - if best_node is not None: - best_node.expand() - root, distribution = root.choose_best_node(return_probability_distribution=True, optimal=True) - chosen_positions.append((root.position, distribution)) - - print('Expected outcome: ', root.get_evaluation()) - root.parent = None # delete references to the parent and siblings - worker_pipe.send(chosen_positions) - if GameClass.is_over(root.position): - print('Game Over in Async MCTS: ', GameClass.get_winner(root.position)) - return + + # best_node will be None if the tree is fully expanded + if best_node is None: + break + + best_node.expand() + + is_ai_player_1 = GameClass.is_player_1_turn(root.position) + chosen_positions = [] + print(f'MCTS choosing move based on {root.count_expansions()} expansions!') + + # choose moves as long as it is still the ai's turn + while GameClass.is_player_1_turn(root.position) == is_ai_player_1: + if root.children is None: + best_node = root.choose_expansion_node() + if best_node is not None: + best_node.expand() + root, distribution = root.choose_best_node(return_probability_distribution=True, optimal=True) + chosen_positions.append((root.position, distribution)) + + print('Expected outcome: ', root.get_evaluation()) + root.parent = None # delete references to the parent and siblings + worker_pipe.send(chosen_positions) + if GameClass.is_over(root.position): + print('Game Over in Async MCTS: ', GameClass.get_winner(root.position)) + break def reset(self): raise NotImplementedError('') diff --git a/perfect_information_game/move_selection/mcts/heuristic_node.py b/perfect_information_game/move_selection/mcts/heuristic_node.py index 17011d3..7b68500 100644 --- a/perfect_information_game/move_selection/mcts/heuristic_node.py +++ b/perfect_information_game/move_selection/mcts/heuristic_node.py @@ -1,48 +1,49 @@ import numpy as np -from perfect_information_game.move_selection.mcts import AbstractNode +from perfect_information_game.move_selection.mcts import AbstractNode, TablebaseNode class HeuristicNode(AbstractNode): def __init__(self, position, parent, GameClass, network, c=np.sqrt(2), d=1, network_call_results=None, - verbose=False): - super().__init__(position, parent, GameClass, c, verbose) + tablebase_manager=None, verbose=False): + super().__init__(position, parent, GameClass, tablebase_manager, verbose) self.network = network + self.c = c self.d = d - if self.fully_expanded: - self.heuristic = GameClass.get_winner(position) - self.policy = None - self.expansions = np.inf - else: - self.policy, self.heuristic = network.call(position[np.newaxis, ...])[0] if network_call_results is None \ - else network_call_results - self.expansions = 0 + self.policy, self.heuristic = network.call(position[np.newaxis, ...])[0] if network_call_results is None \ + else network_call_results + self.expansions = 0 def count_expansions(self): return self.expansions def get_evaluation(self): - return self.heuristic + return self.outcome if self.fully_expanded else self.heuristic def expand(self, moves=None, network_call_results=None): + """ + When a HeuristicNode is expanded, its children are created and they're heuristics are created with the network. + Then the tree is updated so that every node's heuristic is the minimax value of its children. + """ if self.children is not None: + # this node was already expanded raise Exception('Node already has children!') if self.fully_expanded: - raise Exception('Node is terminal!') + raise Exception('Node is fully expanded!') self.ensure_children(moves, network_call_results) if self.children is None: + # this check is needed to prevent lint warnings raise Exception('Failed to create children!') - critical_value = max([child.heuristic for child in self.children]) if self.is_maximizing else \ - min([child.heuristic for child in self.children]) + critical_value = (max if self.is_maximizing else min)( + [child.get_evaluation() for child in self.children]) self.heuristic = critical_value # update heuristic for all parents if it beats their current best heuristic node = self.parent while node is not None: - if (node.is_maximizing and critical_value > node.heuristic) or \ - (not node.is_maximizing and critical_value < node.heuristic): + if critical_value > node.heuristic if node.is_maximizing else critical_value < node.heuristic: node.heuristic = critical_value node.expansions += 1 node = node.parent @@ -59,9 +60,8 @@ def expand(self, moves=None, network_call_results=None): node = node.parent def set_fully_expanded(self, minimax_evaluation): - self.heuristic = minimax_evaluation + super(HeuristicNode, self).set_fully_expanded(minimax_evaluation) self.expansions = np.inf - self.fully_expanded = True def get_puct_heuristic_for_child(self, i): exploration_term = self.c * np.sqrt(np.log(self.expansions) / (self.children[i].expansions + 1)) @@ -69,11 +69,14 @@ def get_puct_heuristic_for_child(self, i): return exploration_term + policy_term def ensure_children(self, moves=None, network_call_results=None): - if self.children is None: - moves = self.GameClass.get_possible_moves(self.position) if moves is None else moves - network_call_results = self.network.call(np.stack(moves, axis=0)) if network_call_results is None \ - else network_call_results - self.children = [HeuristicNode(move, self, self.GameClass, self.network, self.c, self.d, - network_call_results=network_call_result, verbose=self.verbose) - for move, network_call_result in zip(moves, network_call_results)] - self.expansions = 1 + if self.children is not None: + return + moves = self.GameClass.get_possible_moves(self.position) if moves is None else moves + network_call_results = self.network.call(np.stack(moves, axis=0)) if network_call_results is None \ + else network_call_results + self.children = [TablebaseNode.attempt_create(move, self, self.GameClass, self.tablebase_manager, self.verbose, + lambda: HeuristicNode(move, self, self.GameClass, self.network, + self.c, self.d, network_call_result, + self.verbose)) + for move, network_call_result in zip(moves, network_call_results)] + self.expansions = 1 diff --git a/perfect_information_game/move_selection/mcts/mcts.py b/perfect_information_game/move_selection/mcts/mcts.py index b18ae86..b2c920f 100644 --- a/perfect_information_game/move_selection/mcts/mcts.py +++ b/perfect_information_game/move_selection/mcts/mcts.py @@ -2,8 +2,8 @@ from multiprocessing import Pool import numpy as np from perfect_information_game.move_selection import MoveChooser -from perfect_information_game.move_selection.mcts import RolloutNode -from perfect_information_game.move_selection.mcts import HeuristicNode +from perfect_information_game.move_selection.mcts import TablebaseNode, RolloutNode, HeuristicNode +from perfect_information_game.tablebases import EmptyTablebaseManager # TODO: add hash table to keep track of multiple move combinations that lead to the same position @@ -15,7 +15,8 @@ class MCTS(MoveChooser): https://www.youtube.com/watch?v=UXW2yZndl7U """ - def __init__(self, GameClass, starting_position=None, network=None, c=np.sqrt(2), d=1, threads=1): + def __init__(self, GameClass, starting_position=None, network=None, c=np.sqrt(2), d=1, threads=1, + tablebase_manager=None): """ Either: If network is provided, threads must be 1. @@ -32,6 +33,8 @@ def __init__(self, GameClass, starting_position=None, network=None, c=np.sqrt(2) self.d = d self.threads = threads self.pool = Pool(threads) if threads > 1 else None + self.tablebase_manager = tablebase_manager if tablebase_manager is not None \ + else EmptyTablebaseManager(GameClass) def choose_move(self, return_distribution=False, time_limit=10): if return_distribution: @@ -40,11 +43,13 @@ def choose_move(self, return_distribution=False, time_limit=10): if self.GameClass.is_over(self.position): raise Exception('Game Finished!') - if self.network is None: - root = RolloutNode(self.position, parent=None, GameClass=self.GameClass, c=self.c, - rollout_batch_size=self.threads, pool=self.pool, verbose=True) - else: - root = HeuristicNode(self.position, None, self.GameClass, self.network, self.c, self.d, verbose=True) + root = TablebaseNode.attempt_create(self.position, None, self.GameClass, self.tablebase_manager, verbose=True, + backup_factory=lambda: + RolloutNode(self.position, None, self.GameClass, self.c, self.threads, + self.pool, verbose=True) + if self.network is None else + HeuristicNode(self.position, None, self.GameClass, self.network, self.c, + self.d, verbose=True)) start_time = time() while time() - start_time < time_limit: diff --git a/perfect_information_game/move_selection/mcts/rollout_node.py b/perfect_information_game/move_selection/mcts/rollout_node.py index 36375bd..abb2c09 100644 --- a/perfect_information_game/move_selection/mcts/rollout_node.py +++ b/perfect_information_game/move_selection/mcts/rollout_node.py @@ -1,36 +1,34 @@ import numpy as np -from perfect_information_game.move_selection.mcts import AbstractNode +from perfect_information_game.move_selection.mcts import AbstractNode, TablebaseNode +from perfect_information_game.utils import OptionalPool, choose_random class RolloutNode(AbstractNode): - def __init__(self, position, parent, GameClass, c=np.sqrt(2), rollout_batch_size=1, pool=None, verbose=False): - super().__init__(position, parent, GameClass, c, verbose) + def __init__(self, position, parent, GameClass, c=np.sqrt(2), rollout_batch_size=1, pool=None, + tablebase_manager=None, verbose=False): + super().__init__(position, parent, GameClass, tablebase_manager, verbose) + self.c = c self.rollout_batch_size = rollout_batch_size - self.pool = pool + self.pool = pool if pool is not None else OptionalPool(threads=1) - if self.fully_expanded: - self.rollout_sum = GameClass.get_winner(position) - self.rollout_count = np.inf - else: - self.rollout_sum = 0 - self.rollout_count = 0 + # track the sum and the number of rollouts so that the average can be updated as more rollouts are done. + self.rollout_sum = 0 + self.rollout_count = 0 def count_expansions(self): - return self.rollout_count + return self.rollout_count if not self.fully_expanded else np.inf def get_evaluation(self): - return self.rollout_sum / self.rollout_count if not self.fully_expanded else self.rollout_sum + return self.rollout_sum / self.rollout_count if not self.fully_expanded else self.outcome def ensure_children(self): - if self.children is None: - self.children = [RolloutNode(move, self, self.GameClass, self.c, self.rollout_batch_size, self.pool, - self.verbose) - for move in self.GameClass.get_possible_moves(self.position)] - - def set_fully_expanded(self, minimax_evaluation): - self.rollout_sum = minimax_evaluation - self.rollout_count = np.inf - self.fully_expanded = True + if self.children is not None: + return + self.children = [TablebaseNode.attempt_create(move, self, self.GameClass, self.tablebase_manager, self.verbose, + lambda: RolloutNode(move, self, self.GameClass, self.c, + self.rollout_batch_size, self.pool, + self.verbose)) + for move in self.GameClass.get_possible_moves(self.position)] def get_puct_heuristic_for_child(self, i): exploration_term = self.c * np.sqrt(np.log(self.rollout_count) / self.children[i].rollout_count) \ @@ -38,9 +36,16 @@ def get_puct_heuristic_for_child(self, i): return exploration_term def expand(self): - rollout_sum = sum(self.pool.starmap(self.execute_single_rollout, [() for _ in range(self.rollout_batch_size)]) - if self.pool is not None else - [self.execute_single_rollout() for _ in range(self.rollout_batch_size)]) + """ + When a RolloutNode is expanded, random rollouts are performed from the given node. + The node as well as all its parents are updated such that each node keeps track of the average rollout outcome + out of all rollouts from itself or any of its children. + + Although any node can be expanded, only leaf nodes determined by choose_expansion_node should be expanded. + If non-leaf nodes are expanded, then the tree will never grow and the children of the expanded node + would not have their average rollout updated. + """ + rollout_sum = sum(self.pool.starmap(self.execute_single_rollout, [() for _ in range(self.rollout_batch_size)])) # update this node and all its parents node = self @@ -51,8 +56,9 @@ def expand(self): def execute_single_rollout(self): state = self.position - while not self.GameClass.is_over(state): - sub_states = self.GameClass.get_possible_moves(state) - state = sub_states[np.random.randint(len(sub_states))] + outcome, _ = self.tablebase_manager.query_position(state, outcome_only=True) + while np.isnan(outcome): + state = choose_random(self.GameClass.get_possible_moves(state)) + outcome, _ = self.tablebase_manager.query_position(state, outcome_only=True) - return self.GameClass.get_winner(state) + return outcome diff --git a/perfect_information_game/move_selection/mcts/tablebase_node.py b/perfect_information_game/move_selection/mcts/tablebase_node.py new file mode 100644 index 0000000..623d6fd --- /dev/null +++ b/perfect_information_game/move_selection/mcts/tablebase_node.py @@ -0,0 +1,67 @@ +import numpy as np +from perfect_information_game.move_selection.mcts import AbstractNode +from perfect_information_game.tablebases import TablebaseException + + +class TablebaseNode(AbstractNode): + def __init__(self, position, parent, GameClass, tablebase_manager, verbose=False): + super().__init__(position, parent, GameClass, tablebase_manager, verbose=verbose) + if self.fully_expanded: + # the AbstractNode constructor already marked this node as fully_expanded because the game is over + self.best_move = None + self.terminal_distance = 0 + else: + self.fully_expanded = True + + self.best_move, self.outcome, self.terminal_distance = tablebase_manager.query_position(position) + if np.isnan(self.outcome): + raise TablebaseException('Given position was not found in any existing tablebase!') + + @staticmethod + def attempt_create(position, parent, GameClass, tablebase_manager, verbose=False, backup_factory=None): + try: + return TablebaseNode(position, parent, GameClass, tablebase_manager, verbose) + except TablebaseException: + if backup_factory is not None: + return backup_factory() + raise + + def get_evaluation(self): + return self.outcome + + def count_expansions(self): + return np.inf + + def set_fully_expanded(self, minimax_evaluation): + raise NotImplementedError() + + def ensure_children(self): + if self.children is None: + self.children = [TablebaseNode(move, self, self.GameClass, self.tablebase_manager, self.verbose) + for move in self.GameClass.get_possible_moves(self.position)] + + def get_puct_heuristic_for_child(self, i): + raise NotImplementedError() + + def choose_best_node(self, return_probability_distribution=False, optimal=True): + if not optimal: + print('Warning: non-optimal moves are not possible for TablebaseNode!') + self.ensure_children() + distribution = [1 if child.position == self.best_move else 0 for child in self.children] + if np.sum(distribution) != 1: + raise Exception('Inconsistent tablebase results!') + idx = np.argmax(distribution) + best_child = self.children[idx] + return (best_child, distribution) if return_probability_distribution else best_child + + def expand(self): + raise Exception('Node is fully expanded!') + + def choose_expansion_node(self, search_suboptimal=False): + # no reason to expand any nodes since they are all in the tablebase + return None + + def depth_to_end_game(self): + return self.terminal_distance + +